1 | //===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Transforms SCF.WhileOp's into SCF.ForOp's. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/SCF/Transforms/Passes.h" |
14 | |
15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | #include "mlir/Dialect/SCF/IR/SCF.h" |
17 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
18 | #include "mlir/IR/Dominance.h" |
19 | #include "mlir/IR/PatternMatch.h" |
20 | |
21 | using namespace mlir; |
22 | |
23 | namespace { |
24 | struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> { |
25 | using OpRewritePattern::OpRewritePattern; |
26 | |
27 | LogicalResult matchAndRewrite(scf::WhileOp loop, |
28 | PatternRewriter &rewriter) const override { |
29 | return upliftWhileToForLoop(rewriter, loop); |
30 | } |
31 | }; |
32 | } // namespace |
33 | |
34 | FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, |
35 | scf::WhileOp loop) { |
36 | Block *beforeBody = loop.getBeforeBody(); |
37 | if (!llvm::hasSingleElement(C: beforeBody->without_terminator())) |
38 | return rewriter.notifyMatchFailure(loop, "Loop body must have single op" ); |
39 | |
40 | auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front()); |
41 | if (!cmp) |
42 | return rewriter.notifyMatchFailure(loop, |
43 | "Loop body must have single cmp op" ); |
44 | |
45 | scf::ConditionOp beforeTerm = loop.getConditionOp(); |
46 | if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult()) |
47 | return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { |
48 | diag << "Expected single condition use: " << *cmp; |
49 | }); |
50 | |
51 | // If all 'before' arguments are forwarded but the order is different from |
52 | // 'after' arguments, here is the mapping from the 'after' argument index to |
53 | // the 'before' argument index. |
54 | std::optional<SmallVector<unsigned>> argReorder; |
55 | // All `before` block args must be directly forwarded to ConditionOp. |
56 | // They will be converted to `scf.for` `iter_vars` except induction var. |
57 | if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs()) { |
58 | auto getArgReordering = |
59 | [](Block *beforeBody, |
60 | scf::ConditionOp cond) -> std::optional<SmallVector<unsigned>> { |
61 | // Skip further checking if their sizes mismatch. |
62 | if (beforeBody->getNumArguments() != cond.getArgs().size()) |
63 | return std::nullopt; |
64 | // Bitset on which 'before' argument is forwarded. |
65 | llvm::SmallBitVector forwarded(beforeBody->getNumArguments(), false); |
66 | // The forwarding order of 'before' arguments. |
67 | SmallVector<unsigned> order; |
68 | for (Value a : cond.getArgs()) { |
69 | BlockArgument arg = dyn_cast<BlockArgument>(a); |
70 | // Skip if 'arg' is not a 'before' argument. |
71 | if (!arg || arg.getOwner() != beforeBody) |
72 | return std::nullopt; |
73 | unsigned idx = arg.getArgNumber(); |
74 | // Skip if 'arg' is already forwarded in another place. |
75 | if (forwarded[idx]) |
76 | return std::nullopt; |
77 | // Record the presence of 'arg' and its order. |
78 | forwarded[idx] = true; |
79 | order.push_back(idx); |
80 | } |
81 | // Skip if not all 'before' arguments are forwarded. |
82 | if (!forwarded.all()) |
83 | return std::nullopt; |
84 | return order; |
85 | }; |
86 | // Check if 'before' arguments are all forwarded but just reordered. |
87 | argReorder = getArgReordering(beforeBody, beforeTerm); |
88 | if (!argReorder) |
89 | return rewriter.notifyMatchFailure(loop, "Invalid args order" ); |
90 | } |
91 | |
92 | using Pred = arith::CmpIPredicate; |
93 | Pred predicate = cmp.getPredicate(); |
94 | if (predicate != Pred::slt && predicate != Pred::sgt) |
95 | return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { |
96 | diag << "Expected 'slt' or 'sgt' predicate: " << *cmp; |
97 | }); |
98 | |
99 | BlockArgument inductionVar; |
100 | Value ub; |
101 | DominanceInfo dom; |
102 | |
103 | // Check if cmp has a suitable form. One of the arguments must be a `before` |
104 | // block arg, other must be defined outside `scf.while` and will be treated |
105 | // as upper bound. |
106 | for (bool reverse : {false, true}) { |
107 | auto expectedPred = reverse ? Pred::sgt : Pred::slt; |
108 | if (cmp.getPredicate() != expectedPred) |
109 | continue; |
110 | |
111 | auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs(); |
112 | auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs(); |
113 | |
114 | auto blockArg = dyn_cast<BlockArgument>(arg1); |
115 | if (!blockArg || blockArg.getOwner() != beforeBody) |
116 | continue; |
117 | |
118 | if (!dom.properlyDominates(arg2, loop)) |
119 | continue; |
120 | |
121 | inductionVar = blockArg; |
122 | ub = arg2; |
123 | break; |
124 | } |
125 | |
126 | if (!inductionVar) |
127 | return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { |
128 | diag << "Unrecognized cmp form: " << *cmp; |
129 | }); |
130 | |
131 | // inductionVar must have 2 uses: one is in `cmp` and other is `condition` |
132 | // arg. |
133 | if (!llvm::hasNItems(C: inductionVar.getUses(), N: 2)) |
134 | return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) { |
135 | diag << "Unrecognized induction var: " << inductionVar; |
136 | }); |
137 | |
138 | Block *afterBody = loop.getAfterBody(); |
139 | scf::YieldOp afterTerm = loop.getYieldOp(); |
140 | unsigned argNumber = inductionVar.getArgNumber(); |
141 | Value afterTermIndArg = afterTerm.getResults()[argNumber]; |
142 | |
143 | auto findAfterArgNo = [](ArrayRef<unsigned> indices, unsigned beforeArgNo) { |
144 | return std::distance(first: indices.begin(), |
145 | last: llvm::find_if(Range&: indices, P: [beforeArgNo](unsigned n) { |
146 | return n == beforeArgNo; |
147 | })); |
148 | }; |
149 | Value inductionVarAfter = afterBody->getArgument( |
150 | i: argReorder ? findAfterArgNo(*argReorder, argNumber) : argNumber); |
151 | |
152 | // Find suitable `addi` op inside `after` block, one of the args must be an |
153 | // Induction var passed from `before` block and second arg must be defined |
154 | // outside of the loop and will be considered step value. |
155 | // TODO: Add `subi` support? |
156 | auto addOp = afterTermIndArg.getDefiningOp<arith::AddIOp>(); |
157 | if (!addOp) |
158 | return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op" ); |
159 | |
160 | Value step; |
161 | if (addOp.getLhs() == inductionVarAfter) { |
162 | step = addOp.getRhs(); |
163 | } else if (addOp.getRhs() == inductionVarAfter) { |
164 | step = addOp.getLhs(); |
165 | } |
166 | |
167 | if (!step || !dom.properlyDominates(step, loop)) |
168 | return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form" ); |
169 | |
170 | Value lb = loop.getInits()[argNumber]; |
171 | |
172 | assert(lb.getType().isIntOrIndex()); |
173 | assert(lb.getType() == ub.getType()); |
174 | assert(lb.getType() == step.getType()); |
175 | |
176 | SmallVector<Value> newArgs; |
177 | |
178 | // Populate inits for new `scf.for`, skip induction var. |
179 | newArgs.reserve(N: loop.getInits().size()); |
180 | for (auto &&[i, init] : llvm::enumerate(loop.getInits())) { |
181 | if (i == argNumber) |
182 | continue; |
183 | |
184 | newArgs.emplace_back(init); |
185 | } |
186 | |
187 | Location loc = loop.getLoc(); |
188 | |
189 | // With `builder == nullptr`, ForOp::build will try to insert terminator at |
190 | // the end of newly created block and we don't want it. Provide empty |
191 | // dummy builder instead. |
192 | auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {}; |
193 | auto newLoop = |
194 | rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder); |
195 | |
196 | Block *newBody = newLoop.getBody(); |
197 | |
198 | // Populate block args for `scf.for` body, move induction var to the front. |
199 | newArgs.clear(); |
200 | ValueRange newBodyArgs = newBody->getArguments(); |
201 | for (auto i : llvm::seq<size_t>(0, newBodyArgs.size())) { |
202 | if (i < argNumber) { |
203 | newArgs.emplace_back(newBodyArgs[i + 1]); |
204 | } else if (i == argNumber) { |
205 | newArgs.emplace_back(newBodyArgs.front()); |
206 | } else { |
207 | newArgs.emplace_back(newBodyArgs[i]); |
208 | } |
209 | } |
210 | if (argReorder) { |
211 | // Reorder arguments following the 'after' argument order from the original |
212 | // 'while' loop. |
213 | SmallVector<Value> args; |
214 | for (unsigned order : *argReorder) |
215 | args.push_back(Elt: newArgs[order]); |
216 | newArgs = args; |
217 | } |
218 | |
219 | rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(), |
220 | newArgs); |
221 | |
222 | auto term = cast<scf::YieldOp>(newBody->getTerminator()); |
223 | |
224 | // Populate new yield args, skipping the induction var. |
225 | newArgs.clear(); |
226 | for (auto &&[i, arg] : llvm::enumerate(term.getResults())) { |
227 | if (i == argNumber) |
228 | continue; |
229 | |
230 | newArgs.emplace_back(arg); |
231 | } |
232 | |
233 | OpBuilder::InsertionGuard g(rewriter); |
234 | rewriter.setInsertionPoint(term); |
235 | rewriter.replaceOpWithNewOp<scf::YieldOp>(term, newArgs); |
236 | |
237 | // Compute induction var value after loop execution. |
238 | rewriter.setInsertionPointAfter(newLoop); |
239 | Value one; |
240 | if (isa<IndexType>(Val: step.getType())) { |
241 | one = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1); |
242 | } else { |
243 | one = rewriter.create<arith::ConstantIntOp>(location: loc, args: 1, args: step.getType()); |
244 | } |
245 | |
246 | Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one); |
247 | Value len = rewriter.create<arith::SubIOp>(loc, ub, lb); |
248 | len = rewriter.create<arith::AddIOp>(loc, len, stepDec); |
249 | len = rewriter.create<arith::DivSIOp>(loc, len, step); |
250 | len = rewriter.create<arith::SubIOp>(loc, len, one); |
251 | Value res = rewriter.create<arith::MulIOp>(loc, len, step); |
252 | res = rewriter.create<arith::AddIOp>(loc, lb, res); |
253 | |
254 | // Reconstruct `scf.while` results, inserting final induction var value |
255 | // into proper place. |
256 | newArgs.clear(); |
257 | llvm::append_range(newArgs, newLoop.getResults()); |
258 | newArgs.insert(I: newArgs.begin() + argNumber, Elt: res); |
259 | if (argReorder) { |
260 | // Reorder arguments following the 'after' argument order from the original |
261 | // 'while' loop. |
262 | SmallVector<Value> results; |
263 | for (unsigned order : *argReorder) |
264 | results.push_back(Elt: newArgs[order]); |
265 | newArgs = results; |
266 | } |
267 | rewriter.replaceOp(loop, newArgs); |
268 | return newLoop; |
269 | } |
270 | |
271 | void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) { |
272 | patterns.add<UpliftWhileOp>(arg: patterns.getContext()); |
273 | } |
274 | |