1//===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Transforms `scf.while` ops into equivalent `scf.for` ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SCF/Transforms/Passes.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/SCF/IR/SCF.h"
17#include "mlir/Dialect/SCF/Transforms/Patterns.h"
18#include "mlir/IR/Dominance.h"
19#include "mlir/IR/PatternMatch.h"
20
21using namespace mlir;
22
23namespace {
24struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
25 using OpRewritePattern::OpRewritePattern;
26
27 LogicalResult matchAndRewrite(scf::WhileOp loop,
28 PatternRewriter &rewriter) const override {
29 return upliftWhileToForLoop(rewriter, loop);
30 }
31};
32} // namespace
33
34FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
35 scf::WhileOp loop) {
36 Block *beforeBody = loop.getBeforeBody();
37 if (!llvm::hasSingleElement(C: beforeBody->without_terminator()))
38 return rewriter.notifyMatchFailure(loop, "Loop body must have single op");
39
40 auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front());
41 if (!cmp)
42 return rewriter.notifyMatchFailure(loop,
43 "Loop body must have single cmp op");
44
45 scf::ConditionOp beforeTerm = loop.getConditionOp();
46 if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult())
47 return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
48 diag << "Expected single condition use: " << *cmp;
49 });
50
51 // All `before` block args must be directly forwarded to ConditionOp.
52 // They will be converted to `scf.for` `iter_vars` except induction var.
53 if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
54 return rewriter.notifyMatchFailure(loop, "Invalid args order");
55
56 using Pred = arith::CmpIPredicate;
57 Pred predicate = cmp.getPredicate();
58 if (predicate != Pred::slt && predicate != Pred::sgt)
59 return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
60 diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
61 });
62
63 BlockArgument inductionVar;
64 Value ub;
65 DominanceInfo dom;
66
67 // Check if cmp has a suitable form. One of the arguments must be a `before`
68 // block arg, other must be defined outside `scf.while` and will be treated
69 // as upper bound.
70 for (bool reverse : {false, true}) {
71 auto expectedPred = reverse ? Pred::sgt : Pred::slt;
72 if (cmp.getPredicate() != expectedPred)
73 continue;
74
75 auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
76 auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
77
78 auto blockArg = dyn_cast<BlockArgument>(arg1);
79 if (!blockArg || blockArg.getOwner() != beforeBody)
80 continue;
81
82 if (!dom.properlyDominates(arg2, loop))
83 continue;
84
85 inductionVar = blockArg;
86 ub = arg2;
87 break;
88 }
89
90 if (!inductionVar)
91 return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
92 diag << "Unrecognized cmp form: " << *cmp;
93 });
94
95 // inductionVar must have 2 uses: one is in `cmp` and other is `condition`
96 // arg.
97 if (!llvm::hasNItems(C: inductionVar.getUses(), N: 2))
98 return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
99 diag << "Unrecognized induction var: " << inductionVar;
100 });
101
102 Block *afterBody = loop.getAfterBody();
103 scf::YieldOp afterTerm = loop.getYieldOp();
104 unsigned argNumber = inductionVar.getArgNumber();
105 Value afterTermIndArg = afterTerm.getResults()[argNumber];
106
107 Value inductionVarAfter = afterBody->getArgument(i: argNumber);
108
109 // Find suitable `addi` op inside `after` block, one of the args must be an
110 // Induction var passed from `before` block and second arg must be defined
111 // outside of the loop and will be considered step value.
112 // TODO: Add `subi` support?
113 auto addOp = afterTermIndArg.getDefiningOp<arith::AddIOp>();
114 if (!addOp)
115 return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
116
117 Value step;
118 if (addOp.getLhs() == inductionVarAfter) {
119 step = addOp.getRhs();
120 } else if (addOp.getRhs() == inductionVarAfter) {
121 step = addOp.getLhs();
122 }
123
124 if (!step || !dom.properlyDominates(step, loop))
125 return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form");
126
127 Value lb = loop.getInits()[argNumber];
128
129 assert(lb.getType().isIntOrIndex());
130 assert(lb.getType() == ub.getType());
131 assert(lb.getType() == step.getType());
132
133 llvm::SmallVector<Value> newArgs;
134
135 // Populate inits for new `scf.for`, skip induction var.
136 newArgs.reserve(N: loop.getInits().size());
137 for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
138 if (i == argNumber)
139 continue;
140
141 newArgs.emplace_back(init);
142 }
143
144 Location loc = loop.getLoc();
145
146 // With `builder == nullptr`, ForOp::build will try to insert terminator at
147 // the end of newly created block and we don't want it. Provide empty
148 // dummy builder instead.
149 auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
150 auto newLoop =
151 rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder);
152
153 Block *newBody = newLoop.getBody();
154
155 // Populate block args for `scf.for` body, move induction var to the front.
156 newArgs.clear();
157 ValueRange newBodyArgs = newBody->getArguments();
158 for (auto i : llvm::seq<size_t>(0, newBodyArgs.size())) {
159 if (i < argNumber) {
160 newArgs.emplace_back(newBodyArgs[i + 1]);
161 } else if (i == argNumber) {
162 newArgs.emplace_back(newBodyArgs.front());
163 } else {
164 newArgs.emplace_back(newBodyArgs[i]);
165 }
166 }
167
168 rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
169 newArgs);
170
171 auto term = cast<scf::YieldOp>(newBody->getTerminator());
172
173 // Populate new yield args, skipping the induction var.
174 newArgs.clear();
175 for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
176 if (i == argNumber)
177 continue;
178
179 newArgs.emplace_back(arg);
180 }
181
182 OpBuilder::InsertionGuard g(rewriter);
183 rewriter.setInsertionPoint(term);
184 rewriter.replaceOpWithNewOp<scf::YieldOp>(term, newArgs);
185
186 // Compute induction var value after loop execution.
187 rewriter.setInsertionPointAfter(newLoop);
188 Value one;
189 if (isa<IndexType>(Val: step.getType())) {
190 one = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1);
191 } else {
192 one = rewriter.create<arith::ConstantIntOp>(location: loc, args: 1, args: step.getType());
193 }
194
195 Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
196 Value len = rewriter.create<arith::SubIOp>(loc, ub, lb);
197 len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
198 len = rewriter.create<arith::DivSIOp>(loc, len, step);
199 len = rewriter.create<arith::SubIOp>(loc, len, one);
200 Value res = rewriter.create<arith::MulIOp>(loc, len, step);
201 res = rewriter.create<arith::AddIOp>(loc, lb, res);
202
203 // Reconstruct `scf.while` results, inserting final induction var value
204 // into proper place.
205 newArgs.clear();
206 llvm::append_range(newArgs, newLoop.getResults());
207 newArgs.insert(I: newArgs.begin() + argNumber, Elt: res);
208 rewriter.replaceOp(loop, newArgs);
209 return newLoop;
210}
211
212void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
213 patterns.add<UpliftWhileOp>(arg: patterns.getContext());
214}
215

source code of mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp