//===- LoopSpecialization.cpp - scf.parallel/scf.for specialization -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Specializes parallel loops and for loops for easier unrolling and
10// vectorization.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/SCF/Transforms/Passes.h"
15
16#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
17#include "mlir/Dialect/Affine/IR/AffineOps.h"
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/SCF/IR/SCF.h"
20#include "mlir/Dialect/SCF/Transforms/Transforms.h"
21#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
22#include "mlir/Dialect/Utils/StaticValueUtils.h"
23#include "mlir/IR/AffineExpr.h"
24#include "mlir/IR/IRMapping.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27
28namespace mlir {
29#define GEN_PASS_DEF_SCFFORLOOPPEELING
30#define GEN_PASS_DEF_SCFFORLOOPSPECIALIZATION
31#define GEN_PASS_DEF_SCFPARALLELLOOPSPECIALIZATION
32#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
33} // namespace mlir
34
35using namespace mlir;
36using namespace mlir::affine;
37using scf::ForOp;
38using scf::ParallelOp;
39
40/// Rewrite a parallel loop with bounds defined by an affine.min with a constant
41/// into 2 loops after checking if the bounds are equal to that constant. This
42/// is beneficial if the loop will almost always have the constant bound and
43/// that version can be fully unrolled and vectorized.
44static void specializeParallelLoopForUnrolling(ParallelOp op) {
45 SmallVector<int64_t, 2> constantIndices;
46 constantIndices.reserve(N: op.getUpperBound().size());
47 for (auto bound : op.getUpperBound()) {
48 auto minOp = bound.getDefiningOp<AffineMinOp>();
49 if (!minOp)
50 return;
51 int64_t minConstant = std::numeric_limits<int64_t>::max();
52 for (AffineExpr expr : minOp.getMap().getResults()) {
53 if (auto constantIndex = dyn_cast<AffineConstantExpr>(Val&: expr))
54 minConstant = std::min(a: minConstant, b: constantIndex.getValue());
55 }
56 if (minConstant == std::numeric_limits<int64_t>::max())
57 return;
58 constantIndices.push_back(Elt: minConstant);
59 }
60
61 OpBuilder b(op);
62 IRMapping map;
63 Value cond;
64 for (auto bound : llvm::zip(t: op.getUpperBound(), u&: constantIndices)) {
65 Value constant =
66 b.create<arith::ConstantIndexOp>(location: op.getLoc(), args&: std::get<1>(t&: bound));
67 Value cmp = b.create<arith::CmpIOp>(location: op.getLoc(), args: arith::CmpIPredicate::eq,
68 args&: std::get<0>(t&: bound), args&: constant);
69 cond = cond ? b.create<arith::AndIOp>(location: op.getLoc(), args&: cond, args&: cmp) : cmp;
70 map.map(from: std::get<0>(t&: bound), to: constant);
71 }
72 auto ifOp = b.create<scf::IfOp>(location: op.getLoc(), args&: cond, /*withElseRegion=*/args: true);
73 ifOp.getThenBodyBuilder().clone(op&: *op.getOperation(), mapper&: map);
74 ifOp.getElseBodyBuilder().clone(op&: *op.getOperation());
75 op.erase();
76}
77
78/// Rewrite a for loop with bounds defined by an affine.min with a constant into
79/// 2 loops after checking if the bounds are equal to that constant. This is
80/// beneficial if the loop will almost always have the constant bound and that
81/// version can be fully unrolled and vectorized.
82static void specializeForLoopForUnrolling(ForOp op) {
83 auto bound = op.getUpperBound();
84 auto minOp = bound.getDefiningOp<AffineMinOp>();
85 if (!minOp)
86 return;
87 int64_t minConstant = std::numeric_limits<int64_t>::max();
88 for (AffineExpr expr : minOp.getMap().getResults()) {
89 if (auto constantIndex = dyn_cast<AffineConstantExpr>(Val&: expr))
90 minConstant = std::min(a: minConstant, b: constantIndex.getValue());
91 }
92 if (minConstant == std::numeric_limits<int64_t>::max())
93 return;
94
95 OpBuilder b(op);
96 IRMapping map;
97 Value constant = b.create<arith::ConstantIndexOp>(location: op.getLoc(), args&: minConstant);
98 Value cond = b.create<arith::CmpIOp>(location: op.getLoc(), args: arith::CmpIPredicate::eq,
99 args&: bound, args&: constant);
100 map.map(from: bound, to: constant);
101 auto ifOp = b.create<scf::IfOp>(location: op.getLoc(), args&: cond, /*withElseRegion=*/args: true);
102 ifOp.getThenBodyBuilder().clone(op&: *op.getOperation(), mapper&: map);
103 ifOp.getElseBodyBuilder().clone(op&: *op.getOperation());
104 op.erase();
105}
106
107/// Rewrite a for loop with bounds/step that potentially do not divide evenly
108/// into a for loop where the step divides the iteration space evenly, followed
109/// by an scf.if for the last (partial) iteration (if any).
110///
111/// This function rewrites the given scf.for loop in-place and creates a new
112/// scf.if operation for the last iteration. It replaces all uses of the
113/// unpeeled loop with the results of the newly generated scf.if.
114///
115/// The newly generated scf.if operation is returned via `ifOp`. The boundary
116/// at which the loop is split (new upper bound) is returned via `splitBound`.
117/// The return value indicates whether the loop was rewritten or not.
118static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
119 ForOp &partialIteration, Value &splitBound) {
120 RewriterBase::InsertionGuard guard(b);
121 auto lbInt = getConstantIntValue(ofr: forOp.getLowerBound());
122 auto ubInt = getConstantIntValue(ofr: forOp.getUpperBound());
123 auto stepInt = getConstantIntValue(ofr: forOp.getStep());
124
125 // No specialization necessary if step size is 1. Also bail out in case of an
126 // invalid zero or negative step which might have happened during folding.
127 if (stepInt && *stepInt <= 1)
128 return failure();
129
130 // No specialization necessary if step already divides upper bound evenly.
131 // Fast path: lb, ub and step are constants.
132 if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
133 return failure();
134 // Slow path: Examine the ops that define lb, ub and step.
135 AffineExpr sym0, sym1, sym2;
136 bindSymbols(ctx: b.getContext(), exprs&: sym0, exprs&: sym1, exprs&: sym2);
137 SmallVector<Value> operands{forOp.getLowerBound(), forOp.getUpperBound(),
138 forOp.getStep()};
139 AffineMap map = AffineMap::get(dimCount: 0, symbolCount: 3, result: {(sym1 - sym0) % sym2});
140 affine::fullyComposeAffineMapAndOperands(map: &map, operands: &operands);
141 if (auto constExpr = dyn_cast<AffineConstantExpr>(Val: map.getResult(idx: 0)))
142 if (constExpr.getValue() == 0)
143 return failure();
144
145 // New upper bound: %ub - (%ub - %lb) mod %step
146 auto modMap = AffineMap::get(dimCount: 0, symbolCount: 3, result: {sym1 - ((sym1 - sym0) % sym2)});
147 b.setInsertionPoint(forOp);
148 auto loc = forOp.getLoc();
149 splitBound = b.createOrFold<AffineApplyOp>(location: loc, args&: modMap,
150 args: ValueRange{forOp.getLowerBound(),
151 forOp.getUpperBound(),
152 forOp.getStep()});
153
154 // Create ForOp for partial iteration.
155 b.setInsertionPointAfter(forOp);
156 partialIteration = cast<ForOp>(Val: b.clone(op&: *forOp.getOperation()));
157 partialIteration.getLowerBoundMutable().assign(value: splitBound);
158 b.replaceAllUsesWith(from: forOp.getResults(), to: partialIteration->getResults());
159 partialIteration.getInitArgsMutable().assign(values: forOp->getResults());
160
161 // Set new upper loop bound.
162 b.modifyOpInPlace(root: forOp,
163 callable: [&]() { forOp.getUpperBoundMutable().assign(value: splitBound); });
164
165 return success();
166}
167
168static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, ForOp forOp,
169 ForOp partialIteration,
170 Value previousUb) {
171 Value mainIv = forOp.getInductionVar();
172 Value partialIv = partialIteration.getInductionVar();
173 assert(forOp.getStep() == partialIteration.getStep() &&
174 "expected same step in main and partial loop");
175 Value step = forOp.getStep();
176
177 forOp.walk(callback: [&](Operation *affineOp) {
178 if (!isa<AffineMinOp, AffineMaxOp>(Val: affineOp))
179 return WalkResult::advance();
180 (void)scf::rewritePeeledMinMaxOp(rewriter, op: affineOp, iv: mainIv, ub: previousUb,
181 step,
182 /*insideLoop=*/true);
183 return WalkResult::advance();
184 });
185 partialIteration.walk(callback: [&](Operation *affineOp) {
186 if (!isa<AffineMinOp, AffineMaxOp>(Val: affineOp))
187 return WalkResult::advance();
188 (void)scf::rewritePeeledMinMaxOp(rewriter, op: affineOp, iv: partialIv, ub: previousUb,
189 step, /*insideLoop=*/false);
190 return WalkResult::advance();
191 });
192}
193
194LogicalResult mlir::scf::peelForLoopAndSimplifyBounds(RewriterBase &rewriter,
195 ForOp forOp,
196 ForOp &partialIteration) {
197 Value previousUb = forOp.getUpperBound();
198 Value splitBound;
199 if (failed(Result: peelForLoop(b&: rewriter, forOp, partialIteration, splitBound)))
200 return failure();
201
202 // Rewrite affine.min and affine.max ops.
203 rewriteAffineOpAfterPeeling(rewriter, forOp, partialIteration, previousUb);
204
205 return success();
206}
207
208/// Rewrites the original scf::ForOp as two scf::ForOp Ops, the first
209/// scf::ForOp corresponds to the first iteration of the loop which can be
210/// canonicalized away in the following optimizations. The second loop Op
211/// contains the remaining iterations, with a lower bound updated as the
212/// original lower bound plus the step (i.e. skips the first iteration).
213LogicalResult mlir::scf::peelForLoopFirstIteration(RewriterBase &b, ForOp forOp,
214 ForOp &firstIteration) {
215 RewriterBase::InsertionGuard guard(b);
216 auto lbInt = getConstantIntValue(ofr: forOp.getLowerBound());
217 auto ubInt = getConstantIntValue(ofr: forOp.getUpperBound());
218 auto stepInt = getConstantIntValue(ofr: forOp.getStep());
219
220 // Peeling is not needed if there is one or less iteration.
221 if (lbInt && ubInt && stepInt && ceil(x: float(*ubInt - *lbInt) / *stepInt) <= 1)
222 return failure();
223
224 AffineExpr lbSymbol, stepSymbol;
225 bindSymbols(ctx: b.getContext(), exprs&: lbSymbol, exprs&: stepSymbol);
226
227 // New lower bound for main loop: %lb + %step
228 auto ubMap = AffineMap::get(dimCount: 0, symbolCount: 2, result: {lbSymbol + stepSymbol});
229 b.setInsertionPoint(forOp);
230 auto loc = forOp.getLoc();
231 Value splitBound = b.createOrFold<AffineApplyOp>(
232 location: loc, args&: ubMap, args: ValueRange{forOp.getLowerBound(), forOp.getStep()});
233
234 // Peel the first iteration.
235 IRMapping map;
236 map.map(from: forOp.getUpperBound(), to: splitBound);
237 firstIteration = cast<ForOp>(Val: b.clone(op&: *forOp.getOperation(), mapper&: map));
238
239 // Update main loop with new lower bound.
240 b.modifyOpInPlace(root: forOp, callable: [&]() {
241 forOp.getInitArgsMutable().assign(values: firstIteration->getResults());
242 forOp.getLowerBoundMutable().assign(value: splitBound);
243 });
244
245 return success();
246}
247
248static constexpr char kPeeledLoopLabel[] = "__peeled_loop__";
249static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
250
251namespace {
252struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
253 ForLoopPeelingPattern(MLIRContext *ctx, bool peelFront, bool skipPartial)
254 : OpRewritePattern<ForOp>(ctx), peelFront(peelFront),
255 skipPartial(skipPartial) {}
256
257 LogicalResult matchAndRewrite(ForOp forOp,
258 PatternRewriter &rewriter) const override {
259 // Do not peel already peeled loops.
260 if (forOp->hasAttr(name: kPeeledLoopLabel))
261 return failure();
262
263 scf::ForOp partialIteration;
264 // The case for peeling the first iteration of the loop.
265 if (peelFront) {
266 if (failed(
267 Result: peelForLoopFirstIteration(b&: rewriter, forOp, firstIteration&: partialIteration))) {
268 return failure();
269 }
270 } else {
271 if (skipPartial) {
272 // No peeling of loops inside the partial iteration of another peeled
273 // loop.
274 Operation *op = forOp.getOperation();
275 while ((op = op->getParentOfType<scf::ForOp>())) {
276 if (op->hasAttr(name: kPartialIterationLabel))
277 return failure();
278 }
279 }
280 // Apply loop peeling.
281 if (failed(
282 Result: peelForLoopAndSimplifyBounds(rewriter, forOp, partialIteration)))
283 return failure();
284 }
285
286 // Apply label, so that the same loop is not rewritten a second time.
287 rewriter.modifyOpInPlace(root: partialIteration, callable: [&]() {
288 partialIteration->setAttr(name: kPeeledLoopLabel, value: rewriter.getUnitAttr());
289 partialIteration->setAttr(name: kPartialIterationLabel, value: rewriter.getUnitAttr());
290 });
291 rewriter.modifyOpInPlace(root: forOp, callable: [&]() {
292 forOp->setAttr(name: kPeeledLoopLabel, value: rewriter.getUnitAttr());
293 });
294 return success();
295 }
296
297 // If set to true, the first iteration of the loop will be peeled. Otherwise,
298 // the unevenly divisible loop will be peeled at the end.
299 bool peelFront;
300
301 /// If set to true, loops inside partial iterations of another peeled loop
302 /// are not peeled. This reduces the size of the generated code. Partial
303 /// iterations are not usually performance critical.
304 /// Note: Takes into account the entire chain of parent operations, not just
305 /// the direct parent.
306 bool skipPartial;
307};
308} // namespace
309
310namespace {
311struct ParallelLoopSpecialization
312 : public impl::SCFParallelLoopSpecializationBase<
313 ParallelLoopSpecialization> {
314 void runOnOperation() override {
315 getOperation()->walk(
316 callback: [](ParallelOp op) { specializeParallelLoopForUnrolling(op); });
317 }
318};
319
320struct ForLoopSpecialization
321 : public impl::SCFForLoopSpecializationBase<ForLoopSpecialization> {
322 void runOnOperation() override {
323 getOperation()->walk(callback: [](ForOp op) { specializeForLoopForUnrolling(op); });
324 }
325};
326
327struct ForLoopPeeling : public impl::SCFForLoopPeelingBase<ForLoopPeeling> {
328 void runOnOperation() override {
329 auto *parentOp = getOperation();
330 MLIRContext *ctx = parentOp->getContext();
331 RewritePatternSet patterns(ctx);
332 patterns.add<ForLoopPeelingPattern>(arg&: ctx, args&: peelFront, args&: skipPartial);
333 (void)applyPatternsGreedily(op: parentOp, patterns: std::move(patterns));
334
335 // Drop the markers.
336 parentOp->walk(callback: [](Operation *op) {
337 op->removeAttr(name: kPeeledLoopLabel);
338 op->removeAttr(name: kPartialIterationLabel);
339 });
340 }
341};
342} // namespace
343
344std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
345 return std::make_unique<ParallelLoopSpecialization>();
346}
347
348std::unique_ptr<Pass> mlir::createForLoopSpecializationPass() {
349 return std::make_unique<ForLoopSpecialization>();
350}
351
352std::unique_ptr<Pass> mlir::createForLoopPeelingPass() {
353 return std::make_unique<ForLoopPeeling>();
354}
355

source code of mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp