//===- AffineExpandIndexOps.cpp - Affine expand index ops pass ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to expand affine index ops into one or more
// more fundamental operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace affine {
#define GEN_PASS_DEF_AFFINEEXPANDINDEXOPS
#include "mlir/Dialect/Affine/Passes.h.inc"
} // namespace affine
} // namespace mlir

using namespace mlir;
using namespace mlir::affine;

/// Given a basis (in static and dynamic components), return the sequence of
/// suffix products of the basis, including the product of the entire basis,
/// which must **not** contain an outer bound.
///
/// If excess dynamic values are provided, the values at the beginning
/// will be ignored. This allows for dropping the outer bound without
/// needing to manipulate the dynamic value array. `knownNonNegative`
/// indicates that the values being used to compute the strides are known
/// to be non-negative.
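///
/// For example (illustrative, not from the original source): a fully static
/// basis (3, 5) yields the suffix products (15, 5); the implicit innermost
/// stride of 1 is never materialized. With a dynamic element, a basis
/// (%d, 5) yields (%d * 5, 5).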
static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
                                         ValueRange dynamicBasis,
                                         ArrayRef<int64_t> staticBasis,
                                         bool knownNonNegative) {
  if (staticBasis.empty())
    return {};

  SmallVector<Value> result;
  result.reserve(staticBasis.size());
  size_t dynamicIndex = dynamicBasis.size();
  Value dynamicPart = nullptr;
  int64_t staticPart = 1;
  // The products of the strides can't overflow by definition of
  // affine.*_index.
  arith::IntegerOverflowFlags ovflags = arith::IntegerOverflowFlags::nsw;
  if (knownNonNegative)
    ovflags = ovflags | arith::IntegerOverflowFlags::nuw;
  for (int64_t elem : llvm::reverse(staticBasis)) {
    if (ShapedType::isDynamic(elem)) {
      // Note: basis elements and their products are, definitionally,
      // non-negative, so `nuw` is justified when `knownNonNegative` is set.
      if (dynamicPart)
        dynamicPart = rewriter.create<arith::MulIOp>(
            loc, dynamicPart, dynamicBasis[dynamicIndex - 1], ovflags);
      else
        dynamicPart = dynamicBasis[dynamicIndex - 1];
      --dynamicIndex;
    } else {
      staticPart *= elem;
    }

    if (dynamicPart && staticPart == 1) {
      result.push_back(dynamicPart);
    } else {
      Value stride =
          rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
      if (dynamicPart)
        stride =
            rewriter.create<arith::MulIOp>(loc, dynamicPart, stride, ovflags);
      result.push_back(stride);
    }
  }
  std::reverse(result.begin(), result.end());
  return result;
}

namespace {
/// Lowers `affine.delinearize_index` into a sequence of division and remainder
/// operations.
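///
/// For instance (illustrative): `%r:2 = affine.delinearize_index %i into
/// (3, 5)` lowers to, roughly, `%r#0 = %i floordiv 5` and `%r#1 = %i mod 5`,
/// where the `mod` is emitted as `arith.remsi` plus a sign-correcting
/// `arith.select` (see `emitModTerm` below).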
struct LowerDelinearizeIndexOps
    : public OpRewritePattern<AffineDelinearizeIndexOp> {
  using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value linearIdx = op.getLinearIndex();
    unsigned numResults = op.getNumResults();
    ArrayRef<int64_t> staticBasis = op.getStaticBasis();
    if (numResults == staticBasis.size())
      staticBasis = staticBasis.drop_front();

    if (numResults == 1) {
      rewriter.replaceOp(op, linearIdx);
      return success();
    }

    SmallVector<Value> results;
    results.reserve(numResults);
    SmallVector<Value> strides =
        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
                       /*knownNonNegative=*/true);

    Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);

    Value initialPart =
        rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
    results.push_back(initialPart);

    auto emitModTerm = [&](Value stride) -> Value {
      Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
      Value remainderNegative = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::slt, remainder, zero);
      // If the correction is relevant, this term is <= stride, which is known
      // to be positive in `index`. Otherwise, while 2 * stride might overflow,
      // this branch won't be taken, so the risk of `poison` is fine.
      Value corrected = rewriter.create<arith::AddIOp>(
          loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
      Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
                                                   corrected, remainder);
      return mod;
    };
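    // Worked example (illustrative): for linearIdx = -7 and stride = 5,
    // `arith.remsi` yields -2; since that is negative, the select picks the
    // corrected value -2 + 5 = 3, i.e. the non-negative Euclidean modulus.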

    // Generate all the intermediate parts.
    for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
      Value thisStride = strides[i];
      Value nextStride = strides[i + 1];
      Value modulus = emitModTerm(thisStride);
      // We know both inputs are positive, so floorDiv == div.
      // This could potentially be a divui, but it's not clear if that would
      // cause issues.
      Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
      results.push_back(divided);
    }

    results.push_back(emitModTerm(strides.back()));

    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Lowers `affine.linearize_index` into a sequence of multiplications and
/// additions. Makes a best-effort attempt to sort the input indices so that
/// the most loop-invariant terms come first in the sum, enabling
/// loop-invariant code motion.
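///
/// For instance (illustrative): `affine.linearize_index [%a, %b] by (3, 5)`
/// lowers to `%a * 5 + %b`, built from `arith.muli` and `arith.addi` tagged
/// `nsw`.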
struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    // Should be folded away, included here for safety.
    if (op.getMultiIndex().empty()) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
      return success();
    }

    Location loc = op.getLoc();
    ValueRange multiIndex = op.getMultiIndex();
    size_t numIndexes = multiIndex.size();
    ArrayRef<int64_t> staticBasis = op.getStaticBasis();
    if (numIndexes == staticBasis.size())
      staticBasis = staticBasis.drop_front();

    SmallVector<Value> strides =
        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
                       /*knownNonNegative=*/op.getDisjoint());
    SmallVector<std::pair<Value, int64_t>> scaledValues;
    scaledValues.reserve(numIndexes);

    // Note: strides doesn't contain a value for the final element (stride 1)
    // and everything else lines up. We use the "mutable" accessor so we can get
    // our hands on an `OpOperand&` for the loop invariant counting function.
    for (auto [stride, idxOp] :
         llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
      Value scaledIdx = rewriter.create<arith::MulIOp>(
          loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
      int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
      scaledValues.emplace_back(scaledIdx, numHoistableLoops);
    }
    scaledValues.emplace_back(
        multiIndex.back(),
        numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));

    // Sort by how many enclosing loops there are, ties implicitly broken by
    // size of the stride.
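    // For instance (illustrative): if the operand scaled by the largest
    // stride is defined outside two enclosing loops while the innermost
    // operand changes every iteration, emitting the invariant term first
    // lets LICM hoist the leading partial sums out of those loops.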
    llvm::stable_sort(scaledValues,
                      [&](auto l, auto r) { return l.second > r.second; });

    Value result = scaledValues.front().first;
    for (auto [scaledValue, numHoistableLoops] :
         llvm::drop_begin(scaledValues)) {
      std::ignore = numHoistableLoops;
      result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
                                              arith::IntegerOverflowFlags::nsw);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

class ExpandAffineIndexOpsPass
    : public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
public:
  ExpandAffineIndexOpsPass() = default;

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    populateAffineExpandIndexOpsPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::affine::populateAffineExpandIndexOpsPatterns(
    RewritePatternSet &patterns) {
  patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
      patterns.getContext());
}

std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsPass() {
  return std::make_unique<ExpandAffineIndexOpsPass>();
}
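
// A minimal usage sketch (hypothetical driver code, not part of this file),
// assuming a ModuleOp `module` and an MLIRContext `ctx`:
//
//   PassManager pm(&ctx);
//   pm.addPass(affine::createAffineExpandIndexOpsPass());
//   if (failed(pm.run(module)))
//     llvm::report_fatal_error("failed to expand affine index ops");
//
// The same expansion is exposed to `mlir-opt` as `--affine-expand-index-ops`.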