//===- AffineExpandIndexOps.cpp - Affine expand index ops pass ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to expand affine index ops into one or more more
// fundamental operations.
//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Affine/LoopUtils.h"
14#include "mlir/Dialect/Affine/Passes.h"
15
16#include "mlir/Dialect/Affine/IR/AffineOps.h"
17#include "mlir/Dialect/Affine/Transforms/Transforms.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19
20namespace mlir {
21namespace affine {
22#define GEN_PASS_DEF_AFFINEEXPANDINDEXOPS
23#include "mlir/Dialect/Affine/Passes.h.inc"
24} // namespace affine
25} // namespace mlir
26
27using namespace mlir;
28using namespace mlir::affine;
29
30/// Given a basis (in static and dynamic components), return the sequence of
31/// suffix products of the basis, including the product of the entire basis,
32/// which must **not** contain an outer bound.
33///
34/// If excess dynamic values are provided, the values at the beginning
35/// will be ignored. This allows for dropping the outer bound without
36/// needing to manipulate the dynamic value array. `knownPositive`
37/// indicases that the values being used to compute the strides are known
38/// to be non-negative.
39static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
40 ValueRange dynamicBasis,
41 ArrayRef<int64_t> staticBasis,
42 bool knownNonNegative) {
43 if (staticBasis.empty())
44 return {};
45
46 SmallVector<Value> result;
47 result.reserve(N: staticBasis.size());
48 size_t dynamicIndex = dynamicBasis.size();
49 Value dynamicPart = nullptr;
50 int64_t staticPart = 1;
51 // The products of the strides can't have overflow by definition of
52 // affine.*_index.
53 arith::IntegerOverflowFlags ovflags = arith::IntegerOverflowFlags::nsw;
54 if (knownNonNegative)
55 ovflags = ovflags | arith::IntegerOverflowFlags::nuw;
56 for (int64_t elem : llvm::reverse(C&: staticBasis)) {
57 if (ShapedType::isDynamic(dValue: elem)) {
58 // Note: basis elements and their products are, definitionally,
59 // non-negative, so `nuw` is justified.
60 if (dynamicPart)
61 dynamicPart = rewriter.create<arith::MulIOp>(
62 location: loc, args&: dynamicPart, args: dynamicBasis[dynamicIndex - 1], args&: ovflags);
63 else
64 dynamicPart = dynamicBasis[dynamicIndex - 1];
65 --dynamicIndex;
66 } else {
67 staticPart *= elem;
68 }
69
70 if (dynamicPart && staticPart == 1) {
71 result.push_back(Elt: dynamicPart);
72 } else {
73 Value stride =
74 rewriter.createOrFold<arith::ConstantIndexOp>(location: loc, args&: staticPart);
75 if (dynamicPart)
76 stride =
77 rewriter.create<arith::MulIOp>(location: loc, args&: dynamicPart, args&: stride, args&: ovflags);
78 result.push_back(Elt: stride);
79 }
80 }
81 std::reverse(first: result.begin(), last: result.end());
82 return result;
83}
84
85LogicalResult
86affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
87 AffineDelinearizeIndexOp op) {
88 Location loc = op.getLoc();
89 Value linearIdx = op.getLinearIndex();
90 unsigned numResults = op.getNumResults();
91 ArrayRef<int64_t> staticBasis = op.getStaticBasis();
92 if (numResults == staticBasis.size())
93 staticBasis = staticBasis.drop_front();
94
95 if (numResults == 1) {
96 rewriter.replaceOp(op, newValues: linearIdx);
97 return success();
98 }
99
100 SmallVector<Value> results;
101 results.reserve(N: numResults);
102 SmallVector<Value> strides =
103 computeStrides(loc, rewriter, dynamicBasis: op.getDynamicBasis(), staticBasis,
104 /*knownNonNegative=*/true);
105
106 Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(location: loc, args: 0);
107
108 Value initialPart =
109 rewriter.create<arith::FloorDivSIOp>(location: loc, args&: linearIdx, args&: strides.front());
110 results.push_back(Elt: initialPart);
111
112 auto emitModTerm = [&](Value stride) -> Value {
113 Value remainder = rewriter.create<arith::RemSIOp>(location: loc, args&: linearIdx, args&: stride);
114 Value remainderNegative = rewriter.create<arith::CmpIOp>(
115 location: loc, args: arith::CmpIPredicate::slt, args&: remainder, args&: zero);
116 // If the correction is relevant, this term is <= stride, which is known
117 // to be positive in `index`. Otherwise, while 2 * stride might overflow,
118 // this branch won't be taken, so the risk of `poison` is fine.
119 Value corrected = rewriter.create<arith::AddIOp>(
120 location: loc, args&: remainder, args&: stride, args: arith::IntegerOverflowFlags::nsw);
121 Value mod = rewriter.create<arith::SelectOp>(location: loc, args&: remainderNegative,
122 args&: corrected, args&: remainder);
123 return mod;
124 };
125
126 // Generate all the intermediate parts
127 for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
128 Value thisStride = strides[i];
129 Value nextStride = strides[i + 1];
130 Value modulus = emitModTerm(thisStride);
131 // We know both inputs are positive, so floorDiv == div.
132 // This could potentially be a divui, but it's not clear if that would
133 // cause issues.
134 Value divided = rewriter.create<arith::DivSIOp>(location: loc, args&: modulus, args&: nextStride);
135 results.push_back(Elt: divided);
136 }
137
138 results.push_back(Elt: emitModTerm(strides.back()));
139
140 rewriter.replaceOp(op, newValues: results);
141 return success();
142}
143
144LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
145 AffineLinearizeIndexOp op) {
146 // Should be folded away, included here for safety.
147 if (op.getMultiIndex().empty()) {
148 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, args: 0);
149 return success();
150 }
151
152 Location loc = op.getLoc();
153 ValueRange multiIndex = op.getMultiIndex();
154 size_t numIndexes = multiIndex.size();
155 ArrayRef<int64_t> staticBasis = op.getStaticBasis();
156 if (numIndexes == staticBasis.size())
157 staticBasis = staticBasis.drop_front();
158
159 SmallVector<Value> strides =
160 computeStrides(loc, rewriter, dynamicBasis: op.getDynamicBasis(), staticBasis,
161 /*knownNonNegative=*/op.getDisjoint());
162 SmallVector<std::pair<Value, int64_t>> scaledValues;
163 scaledValues.reserve(N: numIndexes);
164
165 // Note: strides doesn't contain a value for the final element (stride 1)
166 // and everything else lines up. We use the "mutable" accessor so we can get
167 // our hands on an `OpOperand&` for the loop invariant counting function.
168 for (auto [stride, idxOp] :
169 llvm::zip_equal(t&: strides, u: llvm::drop_end(RangeOrContainer: op.getMultiIndexMutable()))) {
170 Value scaledIdx = rewriter.create<arith::MulIOp>(
171 location: loc, args: idxOp.get(), args&: stride, args: arith::IntegerOverflowFlags::nsw);
172 int64_t numHoistableLoops = numEnclosingInvariantLoops(operand&: idxOp);
173 scaledValues.emplace_back(Args&: scaledIdx, Args&: numHoistableLoops);
174 }
175 scaledValues.emplace_back(
176 Args: multiIndex.back(),
177 Args: numEnclosingInvariantLoops(operand&: op.getMultiIndexMutable()[numIndexes - 1]));
178
179 // Sort by how many enclosing loops there are, ties implicitly broken by
180 // size of the stride.
181 llvm::stable_sort(Range&: scaledValues,
182 C: [&](auto l, auto r) { return l.second > r.second; });
183
184 Value result = scaledValues.front().first;
185 for (auto [scaledValue, numHoistableLoops] : llvm::drop_begin(RangeOrContainer&: scaledValues)) {
186 std::ignore = numHoistableLoops;
187 result = rewriter.create<arith::AddIOp>(location: loc, args&: result, args&: scaledValue,
188 args: arith::IntegerOverflowFlags::nsw);
189 }
190 rewriter.replaceOp(op, newValues: result);
191 return success();
192}
193
194namespace {
195struct LowerDelinearizeIndexOps
196 : public OpRewritePattern<AffineDelinearizeIndexOp> {
197 using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
198 LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
199 PatternRewriter &rewriter) const override {
200 return affine::lowerAffineDelinearizeIndexOp(rewriter, op);
201 }
202};
203
204struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
205 using OpRewritePattern::OpRewritePattern;
206 LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
207 PatternRewriter &rewriter) const override {
208 return affine::lowerAffineLinearizeIndexOp(rewriter, op);
209 }
210};
211
212class ExpandAffineIndexOpsPass
213 : public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
214public:
215 ExpandAffineIndexOpsPass() = default;
216
217 void runOnOperation() override {
218 MLIRContext *context = &getContext();
219 RewritePatternSet patterns(context);
220 populateAffineExpandIndexOpsPatterns(patterns);
221 if (failed(Result: applyPatternsGreedily(op: getOperation(), patterns: std::move(patterns))))
222 return signalPassFailure();
223 }
224};
225
226} // namespace
227
228void mlir::affine::populateAffineExpandIndexOpsPatterns(
229 RewritePatternSet &patterns) {
230 patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
231 arg: patterns.getContext());
232}
233
234std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsPass() {
235 return std::make_unique<ExpandAffineIndexOpsPass>();
236}
237

source code of mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp