1//===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Transforms `scf.while` loops of a recognized counted form into equivalent
// `scf.for` loops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/SCF/IR/SCF.h"
15#include "mlir/Dialect/SCF/Transforms/Patterns.h"
16#include "mlir/IR/Dominance.h"
17#include "mlir/IR/PatternMatch.h"
18
19using namespace mlir;
20
21namespace {
22struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
23 using OpRewritePattern::OpRewritePattern;
24
25 LogicalResult matchAndRewrite(scf::WhileOp loop,
26 PatternRewriter &rewriter) const override {
27 return upliftWhileToForLoop(rewriter, loop);
28 }
29};
30} // namespace
31
32FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
33 scf::WhileOp loop) {
34 Block *beforeBody = loop.getBeforeBody();
35 if (!llvm::hasSingleElement(C: beforeBody->without_terminator()))
36 return rewriter.notifyMatchFailure(arg&: loop, msg: "Loop body must have single op");
37
38 auto cmp = dyn_cast<arith::CmpIOp>(Val&: beforeBody->front());
39 if (!cmp)
40 return rewriter.notifyMatchFailure(arg&: loop,
41 msg: "Loop body must have single cmp op");
42
43 scf::ConditionOp beforeTerm = loop.getConditionOp();
44 if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult())
45 return rewriter.notifyMatchFailure(op: loop, reasonCallback: [&](Diagnostic &diag) {
46 diag << "Expected single condition use: " << *cmp;
47 });
48
49 // If all 'before' arguments are forwarded but the order is different from
50 // 'after' arguments, here is the mapping from the 'after' argument index to
51 // the 'before' argument index.
52 std::optional<SmallVector<unsigned>> argReorder;
53 // All `before` block args must be directly forwarded to ConditionOp.
54 // They will be converted to `scf.for` `iter_vars` except induction var.
55 if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs()) {
56 auto getArgReordering =
57 [](Block *beforeBody,
58 scf::ConditionOp cond) -> std::optional<SmallVector<unsigned>> {
59 // Skip further checking if their sizes mismatch.
60 if (beforeBody->getNumArguments() != cond.getArgs().size())
61 return std::nullopt;
62 // Bitset on which 'before' argument is forwarded.
63 llvm::SmallBitVector forwarded(beforeBody->getNumArguments(), false);
64 // The forwarding order of 'before' arguments.
65 SmallVector<unsigned> order;
66 for (Value a : cond.getArgs()) {
67 BlockArgument arg = dyn_cast<BlockArgument>(Val&: a);
68 // Skip if 'arg' is not a 'before' argument.
69 if (!arg || arg.getOwner() != beforeBody)
70 return std::nullopt;
71 unsigned idx = arg.getArgNumber();
72 // Skip if 'arg' is already forwarded in another place.
73 if (forwarded[idx])
74 return std::nullopt;
75 // Record the presence of 'arg' and its order.
76 forwarded[idx] = true;
77 order.push_back(Elt: idx);
78 }
79 // Skip if not all 'before' arguments are forwarded.
80 if (!forwarded.all())
81 return std::nullopt;
82 return order;
83 };
84 // Check if 'before' arguments are all forwarded but just reordered.
85 argReorder = getArgReordering(beforeBody, beforeTerm);
86 if (!argReorder)
87 return rewriter.notifyMatchFailure(arg&: loop, msg: "Invalid args order");
88 }
89
90 using Pred = arith::CmpIPredicate;
91 Pred predicate = cmp.getPredicate();
92 if (predicate != Pred::slt && predicate != Pred::sgt)
93 return rewriter.notifyMatchFailure(op: loop, reasonCallback: [&](Diagnostic &diag) {
94 diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
95 });
96
97 BlockArgument inductionVar;
98 Value ub;
99 DominanceInfo dom;
100
101 // Check if cmp has a suitable form. One of the arguments must be a `before`
102 // block arg, other must be defined outside `scf.while` and will be treated
103 // as upper bound.
104 for (bool reverse : {false, true}) {
105 auto expectedPred = reverse ? Pred::sgt : Pred::slt;
106 if (cmp.getPredicate() != expectedPred)
107 continue;
108
109 auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
110 auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
111
112 auto blockArg = dyn_cast<BlockArgument>(Val&: arg1);
113 if (!blockArg || blockArg.getOwner() != beforeBody)
114 continue;
115
116 if (!dom.properlyDominates(a: arg2, b: loop))
117 continue;
118
119 inductionVar = blockArg;
120 ub = arg2;
121 break;
122 }
123
124 if (!inductionVar)
125 return rewriter.notifyMatchFailure(op: loop, reasonCallback: [&](Diagnostic &diag) {
126 diag << "Unrecognized cmp form: " << *cmp;
127 });
128
129 // inductionVar must have 2 uses: one is in `cmp` and other is `condition`
130 // arg.
131 if (!llvm::hasNItems(C: inductionVar.getUses(), N: 2))
132 return rewriter.notifyMatchFailure(op: loop, reasonCallback: [&](Diagnostic &diag) {
133 diag << "Unrecognized induction var: " << inductionVar;
134 });
135
136 Block *afterBody = loop.getAfterBody();
137 scf::YieldOp afterTerm = loop.getYieldOp();
138 unsigned argNumber = inductionVar.getArgNumber();
139 Value afterTermIndArg = afterTerm.getResults()[argNumber];
140
141 auto findAfterArgNo = [](ArrayRef<unsigned> indices, unsigned beforeArgNo) {
142 return std::distance(first: indices.begin(),
143 last: llvm::find_if(Range&: indices, P: [beforeArgNo](unsigned n) {
144 return n == beforeArgNo;
145 }));
146 };
147 Value inductionVarAfter = afterBody->getArgument(
148 i: argReorder ? findAfterArgNo(*argReorder, argNumber) : argNumber);
149
150 // Find suitable `addi` op inside `after` block, one of the args must be an
151 // Induction var passed from `before` block and second arg must be defined
152 // outside of the loop and will be considered step value.
153 // TODO: Add `subi` support?
154 auto addOp = afterTermIndArg.getDefiningOp<arith::AddIOp>();
155 if (!addOp)
156 return rewriter.notifyMatchFailure(arg&: loop, msg: "Didn't found suitable 'addi' op");
157
158 Value step;
159 if (addOp.getLhs() == inductionVarAfter) {
160 step = addOp.getRhs();
161 } else if (addOp.getRhs() == inductionVarAfter) {
162 step = addOp.getLhs();
163 }
164
165 if (!step || !dom.properlyDominates(a: step, b: loop))
166 return rewriter.notifyMatchFailure(arg&: loop, msg: "Invalid 'addi' form");
167
168 Value lb = loop.getInits()[argNumber];
169
170 assert(lb.getType().isIntOrIndex());
171 assert(lb.getType() == ub.getType());
172 assert(lb.getType() == step.getType());
173
174 SmallVector<Value> newArgs;
175
176 // Populate inits for new `scf.for`, skip induction var.
177 newArgs.reserve(N: loop.getInits().size());
178 for (auto &&[i, init] : llvm::enumerate(First: loop.getInits())) {
179 if (i == argNumber)
180 continue;
181
182 newArgs.emplace_back(Args&: init);
183 }
184
185 Location loc = loop.getLoc();
186
187 // With `builder == nullptr`, ForOp::build will try to insert terminator at
188 // the end of newly created block and we don't want it. Provide empty
189 // dummy builder instead.
190 auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
191 auto newLoop =
192 rewriter.create<scf::ForOp>(location: loc, args&: lb, args&: ub, args&: step, args&: newArgs, args&: emptyBuilder);
193
194 Block *newBody = newLoop.getBody();
195
196 // Populate block args for `scf.for` body, move induction var to the front.
197 newArgs.clear();
198 ValueRange newBodyArgs = newBody->getArguments();
199 for (auto i : llvm::seq<size_t>(Begin: 0, End: newBodyArgs.size())) {
200 if (i < argNumber) {
201 newArgs.emplace_back(Args: newBodyArgs[i + 1]);
202 } else if (i == argNumber) {
203 newArgs.emplace_back(Args: newBodyArgs.front());
204 } else {
205 newArgs.emplace_back(Args: newBodyArgs[i]);
206 }
207 }
208 if (argReorder) {
209 // Reorder arguments following the 'after' argument order from the original
210 // 'while' loop.
211 SmallVector<Value> args;
212 for (unsigned order : *argReorder)
213 args.push_back(Elt: newArgs[order]);
214 newArgs = args;
215 }
216
217 rewriter.inlineBlockBefore(source: loop.getAfterBody(), dest: newBody, before: newBody->end(),
218 argValues: newArgs);
219
220 auto term = cast<scf::YieldOp>(Val: newBody->getTerminator());
221
222 // Populate new yield args, skipping the induction var.
223 newArgs.clear();
224 for (auto &&[i, arg] : llvm::enumerate(First: term.getResults())) {
225 if (i == argNumber)
226 continue;
227
228 newArgs.emplace_back(Args&: arg);
229 }
230
231 OpBuilder::InsertionGuard g(rewriter);
232 rewriter.setInsertionPoint(term);
233 rewriter.replaceOpWithNewOp<scf::YieldOp>(op: term, args&: newArgs);
234
235 // Compute induction var value after loop execution.
236 rewriter.setInsertionPointAfter(newLoop);
237 Value one;
238 if (isa<IndexType>(Val: step.getType())) {
239 one = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1);
240 } else {
241 one = rewriter.create<arith::ConstantIntOp>(location: loc, args: step.getType(), args: 1);
242 }
243
244 Value stepDec = rewriter.create<arith::SubIOp>(location: loc, args&: step, args&: one);
245 Value len = rewriter.create<arith::SubIOp>(location: loc, args&: ub, args&: lb);
246 len = rewriter.create<arith::AddIOp>(location: loc, args&: len, args&: stepDec);
247 len = rewriter.create<arith::DivSIOp>(location: loc, args&: len, args&: step);
248 len = rewriter.create<arith::SubIOp>(location: loc, args&: len, args&: one);
249 Value res = rewriter.create<arith::MulIOp>(location: loc, args&: len, args&: step);
250 res = rewriter.create<arith::AddIOp>(location: loc, args&: lb, args&: res);
251
252 // Reconstruct `scf.while` results, inserting final induction var value
253 // into proper place.
254 newArgs.clear();
255 llvm::append_range(C&: newArgs, R: newLoop.getResults());
256 newArgs.insert(I: newArgs.begin() + argNumber, Elt: res);
257 if (argReorder) {
258 // Reorder arguments following the 'after' argument order from the original
259 // 'while' loop.
260 SmallVector<Value> results;
261 for (unsigned order : *argReorder)
262 results.push_back(Elt: newArgs[order]);
263 newArgs = results;
264 }
265 rewriter.replaceOp(op: loop, newValues: newArgs);
266 return newLoop;
267}
268
269void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
270 patterns.add<UpliftWhileOp>(arg: patterns.getContext());
271}
272

source code of mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp