| 1 | //===- DecomposeAffineOps.cpp - Decompose affine ops into finer-grained ---===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements functionality to progressively decompose coarse-grained |
| 10 | // affine ops into finer-grained ops. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 15 | #include "mlir/Dialect/Affine/Transforms/Transforms.h" |
| 16 | #include "mlir/IR/PatternMatch.h" |
| 17 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 18 | #include "llvm/Support/Debug.h" |
| 19 | |
| 20 | using namespace mlir; |
| 21 | using namespace mlir::affine; |
| 22 | |
| 23 | #define DEBUG_TYPE "decompose-affine-ops" |
| 24 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
| 25 | #define DBGSNL() (llvm::dbgs() << "\n") |
| 26 | |
| 27 | /// Count the number of loops surrounding `operand` such that operand could be |
| 28 | /// hoisted above. |
| 29 | /// Stop counting at the first loop over which the operand cannot be hoisted. |
| 30 | static int64_t numEnclosingInvariantLoops(OpOperand &operand) { |
| 31 | int64_t count = 0; |
| 32 | Operation *currentOp = operand.getOwner(); |
| 33 | while (auto loopOp = currentOp->getParentOfType<LoopLikeOpInterface>()) { |
| 34 | if (!loopOp.isDefinedOutsideOfLoop(operand.get())) |
| 35 | break; |
| 36 | currentOp = loopOp; |
| 37 | count++; |
| 38 | } |
| 39 | return count; |
| 40 | } |
| 41 | |
| 42 | void mlir::affine::reorderOperandsByHoistability(RewriterBase &rewriter, |
| 43 | AffineApplyOp op) { |
| 44 | SmallVector<int64_t> numInvariant = llvm::to_vector( |
| 45 | llvm::map_range(op->getOpOperands(), [&](OpOperand &operand) { |
| 46 | return numEnclosingInvariantLoops(operand); |
| 47 | })); |
| 48 | |
| 49 | int64_t numOperands = op.getNumOperands(); |
| 50 | SmallVector<int64_t> operandPositions = |
| 51 | llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: numOperands)); |
| 52 | llvm::stable_sort(Range&: operandPositions, C: [&numInvariant](size_t i1, size_t i2) { |
| 53 | return numInvariant[i1] > numInvariant[i2]; |
| 54 | }); |
| 55 | |
| 56 | SmallVector<AffineExpr> replacements(numOperands); |
| 57 | SmallVector<Value> operands(numOperands); |
| 58 | for (int64_t i = 0; i < numOperands; ++i) { |
| 59 | operands[i] = op.getOperand(operandPositions[i]); |
| 60 | replacements[operandPositions[i]] = getAffineSymbolExpr(i, op.getContext()); |
| 61 | } |
| 62 | |
| 63 | AffineMap map = op.getAffineMap(); |
| 64 | ArrayRef<AffineExpr> repls{replacements}; |
| 65 | map = map.replaceDimsAndSymbols(dimReplacements: repls.take_front(N: map.getNumDims()), |
| 66 | symReplacements: repls.drop_front(N: map.getNumDims()), |
| 67 | /*numResultDims=*/0, |
| 68 | /*numResultSyms=*/numOperands); |
| 69 | map = AffineMap::get(0, numOperands, |
| 70 | simplifyAffineExpr(expr: map.getResult(idx: 0), numDims: 0, numSymbols: numOperands), |
| 71 | op->getContext()); |
| 72 | canonicalizeMapAndOperands(map: &map, operands: &operands); |
| 73 | |
| 74 | rewriter.startOpModification(op: op); |
| 75 | op.setMap(map); |
| 76 | op->setOperands(operands); |
| 77 | rewriter.finalizeOpModification(op: op); |
| 78 | } |
| 79 | |
| 80 | /// Build an affine.apply that is a subexpression `expr` of `originalOp`s affine |
| 81 | /// map and with the same operands. |
| 82 | /// Canonicalize the map and operands to deduplicate and drop dead operands |
| 83 | /// before returning but do not perform maximal composition of AffineApplyOp |
| 84 | /// which would defeat the purpose. |
| 85 | static AffineApplyOp createSubApply(RewriterBase &rewriter, |
| 86 | AffineApplyOp originalOp, AffineExpr expr) { |
| 87 | MLIRContext *ctx = originalOp->getContext(); |
| 88 | AffineMap m = originalOp.getAffineMap(); |
| 89 | auto rhsMap = AffineMap::get(dimCount: m.getNumDims(), symbolCount: m.getNumSymbols(), results: expr, context: ctx); |
| 90 | SmallVector<Value> rhsOperands = originalOp->getOperands(); |
| 91 | canonicalizeMapAndOperands(&rhsMap, &rhsOperands); |
| 92 | return rewriter.create<AffineApplyOp>(originalOp.getLoc(), rhsMap, |
| 93 | rhsOperands); |
| 94 | } |
| 95 | |
| 96 | FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter, |
| 97 | AffineApplyOp op) { |
| 98 | // 1. Preconditions: only handle dimensionless AffineApplyOp maps with a |
| 99 | // top-level binary expression that we can reassociate (i.e. add or mul). |
| 100 | AffineMap m = op.getAffineMap(); |
| 101 | if (m.getNumDims() > 0) |
| 102 | return rewriter.notifyMatchFailure(op, "expected no dims" ); |
| 103 | |
| 104 | AffineExpr remainingExp = m.getResult(idx: 0); |
| 105 | auto binExpr = dyn_cast<AffineBinaryOpExpr>(Val&: remainingExp); |
| 106 | if (!binExpr) |
| 107 | return rewriter.notifyMatchFailure(op, "terminal affine.apply" ); |
| 108 | |
| 109 | if (!isa<AffineBinaryOpExpr>(binExpr.getLHS()) && |
| 110 | !isa<AffineBinaryOpExpr>(binExpr.getRHS())) |
| 111 | return rewriter.notifyMatchFailure(op, "terminal affine.apply" ); |
| 112 | |
| 113 | bool supportedKind = ((binExpr.getKind() == AffineExprKind::Add) || |
| 114 | (binExpr.getKind() == AffineExprKind::Mul)); |
| 115 | if (!supportedKind) |
| 116 | return rewriter.notifyMatchFailure( |
| 117 | op, "only add or mul binary expr can be reassociated" ); |
| 118 | |
| 119 | LLVM_DEBUG(DBGS() << "Start decomposeIntoFinerGrainedOps: " << op << "\n" ); |
| 120 | |
| 121 | // 2. Iteratively extract the RHS subexpressions while the top-level binary |
| 122 | // expr kind remains the same. |
| 123 | MLIRContext *ctx = op->getContext(); |
| 124 | SmallVector<AffineExpr> subExpressions; |
| 125 | while (true) { |
| 126 | auto currentBinExpr = dyn_cast<AffineBinaryOpExpr>(Val&: remainingExp); |
| 127 | if (!currentBinExpr || currentBinExpr.getKind() != binExpr.getKind()) { |
| 128 | subExpressions.push_back(Elt: remainingExp); |
| 129 | LLVM_DEBUG(DBGS() << "--terminal: " << subExpressions.back() << "\n" ); |
| 130 | break; |
| 131 | } |
| 132 | subExpressions.push_back(Elt: currentBinExpr.getRHS()); |
| 133 | LLVM_DEBUG(DBGS() << "--subExpr: " << subExpressions.back() << "\n" ); |
| 134 | remainingExp = currentBinExpr.getLHS(); |
| 135 | } |
| 136 | |
| 137 | // 3. Reorder subExpressions by the min symbol they are a function of. |
| 138 | // This also takes care of properly reordering local variables. |
| 139 | // This however won't be able to split expression that cannot be reassociated |
| 140 | // such as ones that involve divs and multiple symbols. |
| 141 | auto getMaxSymbol = [&](AffineExpr e) -> int64_t { |
| 142 | for (int64_t i = m.getNumSymbols(); i >= 0; --i) |
| 143 | if (e.isFunctionOfSymbol(position: i)) |
| 144 | return i; |
| 145 | return -1; |
| 146 | }; |
| 147 | llvm::stable_sort(Range&: subExpressions, C: [&](AffineExpr e1, AffineExpr e2) { |
| 148 | return getMaxSymbol(e1) < getMaxSymbol(e2); |
| 149 | }); |
| 150 | LLVM_DEBUG( |
| 151 | llvm::interleaveComma(subExpressions, DBGS() << "--sorted subexprs: " ); |
| 152 | llvm::dbgs() << "\n" ); |
| 153 | |
| 154 | // 4. Merge sorted subExpressions iteratively, thus achieving reassociation. |
| 155 | auto s0 = getAffineSymbolExpr(position: 0, context: ctx); |
| 156 | auto s1 = getAffineSymbolExpr(position: 1, context: ctx); |
| 157 | AffineMap binMap = AffineMap::get( |
| 158 | /*dimCount=*/0, /*symbolCount=*/2, |
| 159 | getAffineBinaryOpExpr(binExpr.getKind(), s0, s1), ctx); |
| 160 | |
| 161 | auto current = createSubApply(rewriter, op, subExpressions[0]); |
| 162 | for (int64_t i = 1, e = subExpressions.size(); i < e; ++i) { |
| 163 | Value tmp = createSubApply(rewriter, op, subExpressions[i]); |
| 164 | current = rewriter.create<AffineApplyOp>(op.getLoc(), binMap, |
| 165 | ValueRange{current, tmp}); |
| 166 | LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n" ); |
| 167 | } |
| 168 | |
| 169 | // 5. Replace original op. |
| 170 | rewriter.replaceOp(op, current.getResult()); |
| 171 | return current; |
| 172 | } |
| 173 | |