1 | //===- DecomposeAffineOps.cpp - Decompose affine ops into finer-grained ---===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements functionality to progressively decompose coarse-grained |
10 | // affine ops into finer-grained ops. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
15 | #include "mlir/Dialect/Affine/Transforms/Transforms.h" |
16 | #include "mlir/IR/PatternMatch.h" |
17 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
18 | #include "llvm/Support/Debug.h" |
19 | |
20 | using namespace mlir; |
21 | using namespace mlir::affine; |
22 | |
23 | #define DEBUG_TYPE "decompose-affine-ops" |
24 | #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") |
25 | #define DBGSNL() (llvm::dbgs() << "\n") |
26 | |
27 | /// Count the number of loops surrounding `operand` such that operand could be |
28 | /// hoisted above. |
29 | /// Stop counting at the first loop over which the operand cannot be hoisted. |
30 | static int64_t numEnclosingInvariantLoops(OpOperand &operand) { |
31 | int64_t count = 0; |
32 | Operation *currentOp = operand.getOwner(); |
33 | while (auto loopOp = currentOp->getParentOfType<LoopLikeOpInterface>()) { |
34 | if (!loopOp.isDefinedOutsideOfLoop(operand.get())) |
35 | break; |
36 | currentOp = loopOp; |
37 | count++; |
38 | } |
39 | return count; |
40 | } |
41 | |
42 | void mlir::affine::reorderOperandsByHoistability(RewriterBase &rewriter, |
43 | AffineApplyOp op) { |
44 | SmallVector<int64_t> numInvariant = llvm::to_vector( |
45 | llvm::map_range(op->getOpOperands(), [&](OpOperand &operand) { |
46 | return numEnclosingInvariantLoops(operand); |
47 | })); |
48 | |
49 | int64_t numOperands = op.getNumOperands(); |
50 | SmallVector<int64_t> operandPositions = |
51 | llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: numOperands)); |
52 | llvm::stable_sort(Range&: operandPositions, C: [&numInvariant](size_t i1, size_t i2) { |
53 | return numInvariant[i1] > numInvariant[i2]; |
54 | }); |
55 | |
56 | SmallVector<AffineExpr> replacements(numOperands); |
57 | SmallVector<Value> operands(numOperands); |
58 | for (int64_t i = 0; i < numOperands; ++i) { |
59 | operands[i] = op.getOperand(operandPositions[i]); |
60 | replacements[operandPositions[i]] = getAffineSymbolExpr(i, op.getContext()); |
61 | } |
62 | |
63 | AffineMap map = op.getAffineMap(); |
64 | ArrayRef<AffineExpr> repls{replacements}; |
65 | map = map.replaceDimsAndSymbols(dimReplacements: repls.take_front(N: map.getNumDims()), |
66 | symReplacements: repls.drop_front(N: map.getNumDims()), |
67 | /*numResultDims=*/0, |
68 | /*numResultSyms=*/numOperands); |
69 | map = AffineMap::get(0, numOperands, |
70 | simplifyAffineExpr(expr: map.getResult(idx: 0), numDims: 0, numSymbols: numOperands), |
71 | op->getContext()); |
72 | canonicalizeMapAndOperands(map: &map, operands: &operands); |
73 | |
74 | rewriter.startOpModification(op: op); |
75 | op.setMap(map); |
76 | op->setOperands(operands); |
77 | rewriter.finalizeOpModification(op: op); |
78 | } |
79 | |
80 | /// Build an affine.apply that is a subexpression `expr` of `originalOp`s affine |
81 | /// map and with the same operands. |
82 | /// Canonicalize the map and operands to deduplicate and drop dead operands |
83 | /// before returning but do not perform maximal composition of AffineApplyOp |
84 | /// which would defeat the purpose. |
85 | static AffineApplyOp createSubApply(RewriterBase &rewriter, |
86 | AffineApplyOp originalOp, AffineExpr expr) { |
87 | MLIRContext *ctx = originalOp->getContext(); |
88 | AffineMap m = originalOp.getAffineMap(); |
89 | auto rhsMap = AffineMap::get(dimCount: m.getNumDims(), symbolCount: m.getNumSymbols(), results: expr, context: ctx); |
90 | SmallVector<Value> rhsOperands = originalOp->getOperands(); |
91 | canonicalizeMapAndOperands(&rhsMap, &rhsOperands); |
92 | return rewriter.create<AffineApplyOp>(originalOp.getLoc(), rhsMap, |
93 | rhsOperands); |
94 | } |
95 | |
96 | FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter, |
97 | AffineApplyOp op) { |
98 | // 1. Preconditions: only handle dimensionless AffineApplyOp maps with a |
99 | // top-level binary expression that we can reassociate (i.e. add or mul). |
100 | AffineMap m = op.getAffineMap(); |
101 | if (m.getNumDims() > 0) |
102 | return rewriter.notifyMatchFailure(op, "expected no dims" ); |
103 | |
104 | AffineExpr remainingExp = m.getResult(idx: 0); |
105 | auto binExpr = dyn_cast<AffineBinaryOpExpr>(Val&: remainingExp); |
106 | if (!binExpr) |
107 | return rewriter.notifyMatchFailure(op, "terminal affine.apply" ); |
108 | |
109 | if (!isa<AffineBinaryOpExpr>(binExpr.getLHS()) && |
110 | !isa<AffineBinaryOpExpr>(binExpr.getRHS())) |
111 | return rewriter.notifyMatchFailure(op, "terminal affine.apply" ); |
112 | |
113 | bool supportedKind = ((binExpr.getKind() == AffineExprKind::Add) || |
114 | (binExpr.getKind() == AffineExprKind::Mul)); |
115 | if (!supportedKind) |
116 | return rewriter.notifyMatchFailure( |
117 | op, "only add or mul binary expr can be reassociated" ); |
118 | |
119 | LLVM_DEBUG(DBGS() << "Start decomposeIntoFinerGrainedOps: " << op << "\n" ); |
120 | |
121 | // 2. Iteratively extract the RHS subexpressions while the top-level binary |
122 | // expr kind remains the same. |
123 | MLIRContext *ctx = op->getContext(); |
124 | SmallVector<AffineExpr> subExpressions; |
125 | while (true) { |
126 | auto currentBinExpr = dyn_cast<AffineBinaryOpExpr>(Val&: remainingExp); |
127 | if (!currentBinExpr || currentBinExpr.getKind() != binExpr.getKind()) { |
128 | subExpressions.push_back(Elt: remainingExp); |
129 | LLVM_DEBUG(DBGS() << "--terminal: " << subExpressions.back() << "\n" ); |
130 | break; |
131 | } |
132 | subExpressions.push_back(Elt: currentBinExpr.getRHS()); |
133 | LLVM_DEBUG(DBGS() << "--subExpr: " << subExpressions.back() << "\n" ); |
134 | remainingExp = currentBinExpr.getLHS(); |
135 | } |
136 | |
137 | // 3. Reorder subExpressions by the min symbol they are a function of. |
138 | // This also takes care of properly reordering local variables. |
139 | // This however won't be able to split expression that cannot be reassociated |
140 | // such as ones that involve divs and multiple symbols. |
141 | auto getMaxSymbol = [&](AffineExpr e) -> int64_t { |
142 | for (int64_t i = m.getNumSymbols(); i >= 0; --i) |
143 | if (e.isFunctionOfSymbol(position: i)) |
144 | return i; |
145 | return -1; |
146 | }; |
147 | llvm::stable_sort(Range&: subExpressions, C: [&](AffineExpr e1, AffineExpr e2) { |
148 | return getMaxSymbol(e1) < getMaxSymbol(e2); |
149 | }); |
150 | LLVM_DEBUG( |
151 | llvm::interleaveComma(subExpressions, DBGS() << "--sorted subexprs: " ); |
152 | llvm::dbgs() << "\n" ); |
153 | |
154 | // 4. Merge sorted subExpressions iteratively, thus achieving reassociation. |
155 | auto s0 = getAffineSymbolExpr(position: 0, context: ctx); |
156 | auto s1 = getAffineSymbolExpr(position: 1, context: ctx); |
157 | AffineMap binMap = AffineMap::get( |
158 | /*dimCount=*/0, /*symbolCount=*/2, |
159 | getAffineBinaryOpExpr(binExpr.getKind(), s0, s1), ctx); |
160 | |
161 | auto current = createSubApply(rewriter, op, subExpressions[0]); |
162 | for (int64_t i = 1, e = subExpressions.size(); i < e; ++i) { |
163 | Value tmp = createSubApply(rewriter, op, subExpressions[i]); |
164 | current = rewriter.create<AffineApplyOp>(op.getLoc(), binMap, |
165 | ValueRange{current, tmp}); |
166 | LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n" ); |
167 | } |
168 | |
169 | // 5. Replace original op. |
170 | rewriter.replaceOp(op, current.getResult()); |
171 | return current; |
172 | } |
173 | |