//===- AffineExpandIndexOps.cpp - Affine expand index ops pass ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to expand affine index ops into sequences of
// more fundamental operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace affine {
#define GEN_PASS_DEF_AFFINEEXPANDINDEXOPS
#include "mlir/Dialect/Affine/Passes.h.inc"
} // namespace affine
} // namespace mlir

using namespace mlir;
using namespace mlir::affine;

/// Given a basis (in static and dynamic components), return the sequence of
/// suffix products of the basis, including the product of the entire basis,
/// which must **not** contain an outer bound.
///
/// If excess dynamic values are provided, the values at the beginning
/// will be ignored. This allows for dropping the outer bound without
/// needing to manipulate the dynamic value array. `knownNonNegative`
/// indicates that the values being used to compute the strides are known
/// to be non-negative.
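///
/// For example (a worked illustration of the loop below), a fully static
/// basis (3, 5, 7) yields the suffix products [105, 35, 7].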
static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
                                         ValueRange dynamicBasis,
                                         ArrayRef<int64_t> staticBasis,
                                         bool knownNonNegative) {
  if (staticBasis.empty())
    return {};

  SmallVector<Value> result;
  result.reserve(staticBasis.size());
  size_t dynamicIndex = dynamicBasis.size();
  Value dynamicPart = nullptr;
  int64_t staticPart = 1;
  // The products of the strides can't overflow by definition of
  // affine.*_index.
  arith::IntegerOverflowFlags ovflags = arith::IntegerOverflowFlags::nsw;
  if (knownNonNegative)
    ovflags = ovflags | arith::IntegerOverflowFlags::nuw;
  for (int64_t elem : llvm::reverse(staticBasis)) {
    if (ShapedType::isDynamic(elem)) {
      // Note: basis elements and their products are, definitionally,
      // non-negative, so `nuw` is justified whenever `knownNonNegative` holds.
      if (dynamicPart)
        dynamicPart = rewriter.create<arith::MulIOp>(
            loc, dynamicPart, dynamicBasis[dynamicIndex - 1], ovflags);
      else
        dynamicPart = dynamicBasis[dynamicIndex - 1];
      --dynamicIndex;
    } else {
      staticPart *= elem;
    }

    if (dynamicPart && staticPart == 1) {
      result.push_back(dynamicPart);
    } else {
      Value stride =
          rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
      if (dynamicPart)
        stride =
            rewriter.create<arith::MulIOp>(loc, dynamicPart, stride, ovflags);
      result.push_back(stride);
    }
  }
  std::reverse(result.begin(), result.end());
  return result;
}

namespace {
/// Lowers `affine.delinearize_index` into a sequence of division and remainder
/// operations.
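///
/// For example (a worked sketch of the emitted arithmetic in floor-division
/// notation, not verbatim IR):
///   %r:3 = affine.delinearize_index %i into (4, 5, 6) : index, index, index
/// becomes
///   %r#0 = %i floordiv 30
///   %r#1 = (%i mod 30) floordiv 6
///   %r#2 = %i mod 6
/// where each `mod` is the sign-corrected remainder built by `emitModTerm`
/// below.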
struct LowerDelinearizeIndexOps
    : public OpRewritePattern<AffineDelinearizeIndexOp> {
  using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value linearIdx = op.getLinearIndex();
    unsigned numResults = op.getNumResults();
    ArrayRef<int64_t> staticBasis = op.getStaticBasis();
    // A full-length basis includes the outer bound, which doesn't affect the
    // strides; drop it.
    if (numResults == staticBasis.size())
      staticBasis = staticBasis.drop_front();

    if (numResults == 1) {
      rewriter.replaceOp(op, linearIdx);
      return success();
    }

    SmallVector<Value> results;
    results.reserve(numResults);
    SmallVector<Value> strides =
        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
                       /*knownNonNegative=*/true);

    Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);

    Value initialPart =
        rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
    results.push_back(initialPart);

    auto emitModTerm = [&](Value stride) -> Value {
      Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
      Value remainderNegative = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::slt, remainder, zero);
      // If the correction is relevant, the corrected term lies in (0, stride),
      // and stride is known to be positive in `index`, so the `nsw` add can't
      // overflow. Otherwise, while `remainder + stride` might overflow, that
      // branch of the select isn't taken, so the risk of `poison` is fine.
      Value corrected = rewriter.create<arith::AddIOp>(
          loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
      Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
                                                   corrected, remainder);
      return mod;
    };
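    // Worked illustration of the lambda above (not generated IR): for stride 8
    // and linearIdx -3, `arith.remsi` yields -3; since that is negative, the
    // select picks the corrected value -3 + 8 == 5, the floor-semantics
    // modulus.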

    // Generate all the intermediate parts.
    for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
      Value thisStride = strides[i];
      Value nextStride = strides[i + 1];
      Value modulus = emitModTerm(thisStride);
      // The modulus is non-negative and the next stride is positive, so
      // floorDiv == div here. This could potentially be a divui, but it's
      // not clear if that would cause issues.
      Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
      results.push_back(divided);
    }

    results.push_back(emitModTerm(strides.back()));

    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Lowers `affine.linearize_index` into a sequence of multiplications and
/// additions. Makes a best effort to sort the input indices so that the most
/// loop-invariant terms come first in the additions, enabling loop-invariant
/// code motion.
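///
/// For example (a worked sketch, not verbatim output):
///   %0 = affine.linearize_index [%a, %b, %c] by (4, 5, 6) : index
/// expands to the sum %a * 30 + %b * 6 + %c. If, say, %b and %c are invariant
/// in more enclosing loops than %a, the terms are reordered to
/// (%b * 6) + %c + (%a * 30) so the leading partial sums can be hoisted.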
struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    // Should be folded away, included here for safety.
    if (op.getMultiIndex().empty()) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
      return success();
    }

    Location loc = op.getLoc();
    ValueRange multiIndex = op.getMultiIndex();
    size_t numIndexes = multiIndex.size();
    ArrayRef<int64_t> staticBasis = op.getStaticBasis();
    if (numIndexes == staticBasis.size())
      staticBasis = staticBasis.drop_front();

    SmallVector<Value> strides =
        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
                       /*knownNonNegative=*/op.getDisjoint());
    SmallVector<std::pair<Value, int64_t>> scaledValues;
    scaledValues.reserve(numIndexes);

    // Note: strides doesn't contain a value for the final element (stride 1)
    // and everything else lines up. We use the "mutable" accessor so we can
    // get our hands on an `OpOperand&` for the loop-invariance counting
    // function.
    for (auto [stride, idxOp] :
         llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
      Value scaledIdx = rewriter.create<arith::MulIOp>(
          loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
      int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
      scaledValues.emplace_back(scaledIdx, numHoistableLoops);
    }
    scaledValues.emplace_back(
        multiIndex.back(),
        numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));

    // Sort by how many enclosing loops there are, ties implicitly broken by
    // size of the stride (the stable sort preserves the decreasing-stride
    // input order).
    llvm::stable_sort(scaledValues,
                      [&](auto l, auto r) { return l.second > r.second; });

    Value result = scaledValues.front().first;
    for (auto [scaledValue, numHoistableLoops] :
         llvm::drop_begin(scaledValues)) {
      std::ignore = numHoistableLoops;
      result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
                                              arith::IntegerOverflowFlags::nsw);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

class ExpandAffineIndexOpsPass
    : public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
public:
  ExpandAffineIndexOpsPass() = default;

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    populateAffineExpandIndexOpsPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::affine::populateAffineExpandIndexOpsPatterns(
    RewritePatternSet &patterns) {
  patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
      patterns.getContext());
}
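
// Example use of the pattern-population entry point above (a minimal sketch
// that mirrors what the pass itself does; `module` and `ctx` stand in for a
// caller's operation and MLIRContext, and the failure handler is
// hypothetical):
//
//   RewritePatternSet patterns(ctx);
//   affine::populateAffineExpandIndexOpsPatterns(patterns);
//   if (failed(applyPatternsGreedily(module, std::move(patterns))))
//     reportLoweringFailure(); // hypothetical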

std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsPass() {
  return std::make_unique<ExpandAffineIndexOpsPass>();
}