| 1 | //===- DecomposeAffineOps.cpp - Decompose affine ops into finer-grained ---===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file implements functionality to progressively decompose coarse-grained |
| 10 | // affine ops into finer-grained ops. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 15 | #include "mlir/Dialect/Affine/Transforms/Transforms.h" |
| 16 | #include "mlir/IR/PatternMatch.h" |
| 17 | #include "llvm/Support/Debug.h" |
| 18 | |
| 19 | using namespace mlir; |
| 20 | using namespace mlir::affine; |
| 21 | |
// Tag consumed by LLVM_DEBUG to filter this file's debug output.
#define DEBUG_TYPE "decompose-affine-ops"
// Debug stream prefixed with "[decompose-affine-ops]: ".
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// Emits a bare newline on the debug stream.
#define DBGSNL() (llvm::dbgs() << "\n")
| 25 | |
| 26 | /// Count the number of loops surrounding `operand` such that operand could be |
| 27 | /// hoisted above. |
| 28 | /// Stop counting at the first loop over which the operand cannot be hoisted. |
| 29 | static int64_t numEnclosingInvariantLoops(OpOperand &operand) { |
| 30 | int64_t count = 0; |
| 31 | Operation *currentOp = operand.getOwner(); |
| 32 | while (auto loopOp = currentOp->getParentOfType<LoopLikeOpInterface>()) { |
| 33 | if (!loopOp.isDefinedOutsideOfLoop(value: operand.get())) |
| 34 | break; |
| 35 | currentOp = loopOp; |
| 36 | count++; |
| 37 | } |
| 38 | return count; |
| 39 | } |
| 40 | |
| 41 | void mlir::affine::reorderOperandsByHoistability(RewriterBase &rewriter, |
| 42 | AffineApplyOp op) { |
| 43 | SmallVector<int64_t> numInvariant = llvm::to_vector( |
| 44 | Range: llvm::map_range(C: op->getOpOperands(), F: [&](OpOperand &operand) { |
| 45 | return numEnclosingInvariantLoops(operand); |
| 46 | })); |
| 47 | |
| 48 | int64_t numOperands = op.getNumOperands(); |
| 49 | SmallVector<int64_t> operandPositions = |
| 50 | llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: numOperands)); |
| 51 | llvm::stable_sort(Range&: operandPositions, C: [&numInvariant](size_t i1, size_t i2) { |
| 52 | return numInvariant[i1] > numInvariant[i2]; |
| 53 | }); |
| 54 | |
| 55 | SmallVector<AffineExpr> replacements(numOperands); |
| 56 | SmallVector<Value> operands(numOperands); |
| 57 | for (int64_t i = 0; i < numOperands; ++i) { |
| 58 | operands[i] = op.getOperand(i: operandPositions[i]); |
| 59 | replacements[operandPositions[i]] = getAffineSymbolExpr(position: i, context: op.getContext()); |
| 60 | } |
| 61 | |
| 62 | AffineMap map = op.getAffineMap(); |
| 63 | ArrayRef<AffineExpr> repls{replacements}; |
| 64 | map = map.replaceDimsAndSymbols(dimReplacements: repls.take_front(N: map.getNumDims()), |
| 65 | symReplacements: repls.drop_front(N: map.getNumDims()), |
| 66 | /*numResultDims=*/0, |
| 67 | /*numResultSyms=*/numOperands); |
| 68 | map = AffineMap::get(dimCount: 0, symbolCount: numOperands, |
| 69 | results: simplifyAffineExpr(expr: map.getResult(idx: 0), numDims: 0, numSymbols: numOperands), |
| 70 | context: op->getContext()); |
| 71 | canonicalizeMapAndOperands(map: &map, operands: &operands); |
| 72 | |
| 73 | rewriter.startOpModification(op); |
| 74 | op.setMap(map); |
| 75 | op->setOperands(operands); |
| 76 | rewriter.finalizeOpModification(op); |
| 77 | } |
| 78 | |
| 79 | /// Build an affine.apply that is a subexpression `expr` of `originalOp`s affine |
| 80 | /// map and with the same operands. |
| 81 | /// Canonicalize the map and operands to deduplicate and drop dead operands |
| 82 | /// before returning but do not perform maximal composition of AffineApplyOp |
| 83 | /// which would defeat the purpose. |
| 84 | static AffineApplyOp createSubApply(RewriterBase &rewriter, |
| 85 | AffineApplyOp originalOp, AffineExpr expr) { |
| 86 | MLIRContext *ctx = originalOp->getContext(); |
| 87 | AffineMap m = originalOp.getAffineMap(); |
| 88 | auto rhsMap = AffineMap::get(dimCount: m.getNumDims(), symbolCount: m.getNumSymbols(), results: expr, context: ctx); |
| 89 | SmallVector<Value> rhsOperands = originalOp->getOperands(); |
| 90 | canonicalizeMapAndOperands(map: &rhsMap, operands: &rhsOperands); |
| 91 | return rewriter.create<AffineApplyOp>(location: originalOp.getLoc(), args&: rhsMap, |
| 92 | args&: rhsOperands); |
| 93 | } |
| 94 | |
| 95 | FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter, |
| 96 | AffineApplyOp op) { |
| 97 | // 1. Preconditions: only handle dimensionless AffineApplyOp maps with a |
| 98 | // top-level binary expression that we can reassociate (i.e. add or mul). |
| 99 | AffineMap m = op.getAffineMap(); |
| 100 | if (m.getNumDims() > 0) |
| 101 | return rewriter.notifyMatchFailure(arg&: op, msg: "expected no dims" ); |
| 102 | |
| 103 | AffineExpr remainingExp = m.getResult(idx: 0); |
| 104 | auto binExpr = dyn_cast<AffineBinaryOpExpr>(Val&: remainingExp); |
| 105 | if (!binExpr) |
| 106 | return rewriter.notifyMatchFailure(arg&: op, msg: "terminal affine.apply" ); |
| 107 | |
| 108 | if (!isa<AffineBinaryOpExpr>(Val: binExpr.getLHS()) && |
| 109 | !isa<AffineBinaryOpExpr>(Val: binExpr.getRHS())) |
| 110 | return rewriter.notifyMatchFailure(arg&: op, msg: "terminal affine.apply" ); |
| 111 | |
| 112 | bool supportedKind = ((binExpr.getKind() == AffineExprKind::Add) || |
| 113 | (binExpr.getKind() == AffineExprKind::Mul)); |
| 114 | if (!supportedKind) |
| 115 | return rewriter.notifyMatchFailure( |
| 116 | arg&: op, msg: "only add or mul binary expr can be reassociated" ); |
| 117 | |
| 118 | LLVM_DEBUG(DBGS() << "Start decomposeIntoFinerGrainedOps: " << op << "\n" ); |
| 119 | |
| 120 | // 2. Iteratively extract the RHS subexpressions while the top-level binary |
| 121 | // expr kind remains the same. |
| 122 | MLIRContext *ctx = op->getContext(); |
| 123 | SmallVector<AffineExpr> subExpressions; |
| 124 | while (true) { |
| 125 | auto currentBinExpr = dyn_cast<AffineBinaryOpExpr>(Val&: remainingExp); |
| 126 | if (!currentBinExpr || currentBinExpr.getKind() != binExpr.getKind()) { |
| 127 | subExpressions.push_back(Elt: remainingExp); |
| 128 | LLVM_DEBUG(DBGS() << "--terminal: " << subExpressions.back() << "\n" ); |
| 129 | break; |
| 130 | } |
| 131 | subExpressions.push_back(Elt: currentBinExpr.getRHS()); |
| 132 | LLVM_DEBUG(DBGS() << "--subExpr: " << subExpressions.back() << "\n" ); |
| 133 | remainingExp = currentBinExpr.getLHS(); |
| 134 | } |
| 135 | |
| 136 | // 3. Reorder subExpressions by the min symbol they are a function of. |
| 137 | // This also takes care of properly reordering local variables. |
| 138 | // This however won't be able to split expression that cannot be reassociated |
| 139 | // such as ones that involve divs and multiple symbols. |
| 140 | auto getMaxSymbol = [&](AffineExpr e) -> int64_t { |
| 141 | for (int64_t i = m.getNumSymbols(); i >= 0; --i) |
| 142 | if (e.isFunctionOfSymbol(position: i)) |
| 143 | return i; |
| 144 | return -1; |
| 145 | }; |
| 146 | llvm::stable_sort(Range&: subExpressions, C: [&](AffineExpr e1, AffineExpr e2) { |
| 147 | return getMaxSymbol(e1) < getMaxSymbol(e2); |
| 148 | }); |
| 149 | LLVM_DEBUG( |
| 150 | llvm::interleaveComma(subExpressions, DBGS() << "--sorted subexprs: " ); |
| 151 | llvm::dbgs() << "\n" ); |
| 152 | |
| 153 | // 4. Merge sorted subExpressions iteratively, thus achieving reassociation. |
| 154 | auto s0 = getAffineSymbolExpr(position: 0, context: ctx); |
| 155 | auto s1 = getAffineSymbolExpr(position: 1, context: ctx); |
| 156 | AffineMap binMap = AffineMap::get( |
| 157 | /*dimCount=*/0, /*symbolCount=*/2, |
| 158 | results: getAffineBinaryOpExpr(kind: binExpr.getKind(), lhs: s0, rhs: s1), context: ctx); |
| 159 | |
| 160 | auto current = createSubApply(rewriter, originalOp: op, expr: subExpressions[0]); |
| 161 | for (int64_t i = 1, e = subExpressions.size(); i < e; ++i) { |
| 162 | Value tmp = createSubApply(rewriter, originalOp: op, expr: subExpressions[i]); |
| 163 | current = rewriter.create<AffineApplyOp>(location: op.getLoc(), args&: binMap, |
| 164 | args: ValueRange{current, tmp}); |
| 165 | LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n" ); |
| 166 | } |
| 167 | |
| 168 | // 5. Replace original op. |
| 169 | rewriter.replaceOp(op, newValues: current.getResult()); |
| 170 | return current; |
| 171 | } |
| 172 | |