1//===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
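// Uplifts `scf.while` ops that have the shape of a counted loop (a single
// `arith.cmpi` in the `before` region and an `arith.addi` increment in the
// `after` region) into equivalent `scf.for` ops.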
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SCF/Transforms/Passes.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/SCF/IR/SCF.h"
17#include "mlir/Dialect/SCF/Transforms/Patterns.h"
18#include "mlir/IR/Dominance.h"
19#include "mlir/IR/PatternMatch.h"
20
21using namespace mlir;
22
23namespace {
24struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
25 using OpRewritePattern::OpRewritePattern;
26
27 LogicalResult matchAndRewrite(scf::WhileOp loop,
28 PatternRewriter &rewriter) const override {
29 return upliftWhileToForLoop(rewriter, loop);
30 }
31};
32} // namespace
33
34FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
35 scf::WhileOp loop) {
36 Block *beforeBody = loop.getBeforeBody();
37 if (!llvm::hasSingleElement(C: beforeBody->without_terminator()))
38 return rewriter.notifyMatchFailure(loop, "Loop body must have single op");
39
40 auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front());
41 if (!cmp)
42 return rewriter.notifyMatchFailure(loop,
43 "Loop body must have single cmp op");
44
45 scf::ConditionOp beforeTerm = loop.getConditionOp();
46 if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult())
47 return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
48 diag << "Expected single condition use: " << *cmp;
49 });
50
51 // If all 'before' arguments are forwarded but the order is different from
52 // 'after' arguments, here is the mapping from the 'after' argument index to
53 // the 'before' argument index.
54 std::optional<SmallVector<unsigned>> argReorder;
55 // All `before` block args must be directly forwarded to ConditionOp.
56 // They will be converted to `scf.for` `iter_vars` except induction var.
57 if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs()) {
58 auto getArgReordering =
59 [](Block *beforeBody,
60 scf::ConditionOp cond) -> std::optional<SmallVector<unsigned>> {
61 // Skip further checking if their sizes mismatch.
62 if (beforeBody->getNumArguments() != cond.getArgs().size())
63 return std::nullopt;
64 // Bitset on which 'before' argument is forwarded.
65 llvm::SmallBitVector forwarded(beforeBody->getNumArguments(), false);
66 // The forwarding order of 'before' arguments.
67 SmallVector<unsigned> order;
68 for (Value a : cond.getArgs()) {
69 BlockArgument arg = dyn_cast<BlockArgument>(a);
70 // Skip if 'arg' is not a 'before' argument.
71 if (!arg || arg.getOwner() != beforeBody)
72 return std::nullopt;
73 unsigned idx = arg.getArgNumber();
74 // Skip if 'arg' is already forwarded in another place.
75 if (forwarded[idx])
76 return std::nullopt;
77 // Record the presence of 'arg' and its order.
78 forwarded[idx] = true;
79 order.push_back(idx);
80 }
81 // Skip if not all 'before' arguments are forwarded.
82 if (!forwarded.all())
83 return std::nullopt;
84 return order;
85 };
86 // Check if 'before' arguments are all forwarded but just reordered.
87 argReorder = getArgReordering(beforeBody, beforeTerm);
88 if (!argReorder)
89 return rewriter.notifyMatchFailure(loop, "Invalid args order");
90 }
91
92 using Pred = arith::CmpIPredicate;
93 Pred predicate = cmp.getPredicate();
94 if (predicate != Pred::slt && predicate != Pred::sgt)
95 return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
96 diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
97 });
98
99 BlockArgument inductionVar;
100 Value ub;
101 DominanceInfo dom;
102
103 // Check if cmp has a suitable form. One of the arguments must be a `before`
104 // block arg, other must be defined outside `scf.while` and will be treated
105 // as upper bound.
106 for (bool reverse : {false, true}) {
107 auto expectedPred = reverse ? Pred::sgt : Pred::slt;
108 if (cmp.getPredicate() != expectedPred)
109 continue;
110
111 auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
112 auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
113
114 auto blockArg = dyn_cast<BlockArgument>(arg1);
115 if (!blockArg || blockArg.getOwner() != beforeBody)
116 continue;
117
118 if (!dom.properlyDominates(arg2, loop))
119 continue;
120
121 inductionVar = blockArg;
122 ub = arg2;
123 break;
124 }
125
126 if (!inductionVar)
127 return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
128 diag << "Unrecognized cmp form: " << *cmp;
129 });
130
131 // inductionVar must have 2 uses: one is in `cmp` and other is `condition`
132 // arg.
133 if (!llvm::hasNItems(C: inductionVar.getUses(), N: 2))
134 return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
135 diag << "Unrecognized induction var: " << inductionVar;
136 });
137
138 Block *afterBody = loop.getAfterBody();
139 scf::YieldOp afterTerm = loop.getYieldOp();
140 unsigned argNumber = inductionVar.getArgNumber();
141 Value afterTermIndArg = afterTerm.getResults()[argNumber];
142
143 auto findAfterArgNo = [](ArrayRef<unsigned> indices, unsigned beforeArgNo) {
144 return std::distance(first: indices.begin(),
145 last: llvm::find_if(Range&: indices, P: [beforeArgNo](unsigned n) {
146 return n == beforeArgNo;
147 }));
148 };
149 Value inductionVarAfter = afterBody->getArgument(
150 i: argReorder ? findAfterArgNo(*argReorder, argNumber) : argNumber);
151
152 // Find suitable `addi` op inside `after` block, one of the args must be an
153 // Induction var passed from `before` block and second arg must be defined
154 // outside of the loop and will be considered step value.
155 // TODO: Add `subi` support?
156 auto addOp = afterTermIndArg.getDefiningOp<arith::AddIOp>();
157 if (!addOp)
158 return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
159
160 Value step;
161 if (addOp.getLhs() == inductionVarAfter) {
162 step = addOp.getRhs();
163 } else if (addOp.getRhs() == inductionVarAfter) {
164 step = addOp.getLhs();
165 }
166
167 if (!step || !dom.properlyDominates(step, loop))
168 return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form");
169
170 Value lb = loop.getInits()[argNumber];
171
172 assert(lb.getType().isIntOrIndex());
173 assert(lb.getType() == ub.getType());
174 assert(lb.getType() == step.getType());
175
176 SmallVector<Value> newArgs;
177
178 // Populate inits for new `scf.for`, skip induction var.
179 newArgs.reserve(N: loop.getInits().size());
180 for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
181 if (i == argNumber)
182 continue;
183
184 newArgs.emplace_back(init);
185 }
186
187 Location loc = loop.getLoc();
188
189 // With `builder == nullptr`, ForOp::build will try to insert terminator at
190 // the end of newly created block and we don't want it. Provide empty
191 // dummy builder instead.
192 auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
193 auto newLoop =
194 rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder);
195
196 Block *newBody = newLoop.getBody();
197
198 // Populate block args for `scf.for` body, move induction var to the front.
199 newArgs.clear();
200 ValueRange newBodyArgs = newBody->getArguments();
201 for (auto i : llvm::seq<size_t>(0, newBodyArgs.size())) {
202 if (i < argNumber) {
203 newArgs.emplace_back(newBodyArgs[i + 1]);
204 } else if (i == argNumber) {
205 newArgs.emplace_back(newBodyArgs.front());
206 } else {
207 newArgs.emplace_back(newBodyArgs[i]);
208 }
209 }
210 if (argReorder) {
211 // Reorder arguments following the 'after' argument order from the original
212 // 'while' loop.
213 SmallVector<Value> args;
214 for (unsigned order : *argReorder)
215 args.push_back(Elt: newArgs[order]);
216 newArgs = args;
217 }
218
219 rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
220 newArgs);
221
222 auto term = cast<scf::YieldOp>(newBody->getTerminator());
223
224 // Populate new yield args, skipping the induction var.
225 newArgs.clear();
226 for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
227 if (i == argNumber)
228 continue;
229
230 newArgs.emplace_back(arg);
231 }
232
233 OpBuilder::InsertionGuard g(rewriter);
234 rewriter.setInsertionPoint(term);
235 rewriter.replaceOpWithNewOp<scf::YieldOp>(term, newArgs);
236
237 // Compute induction var value after loop execution.
238 rewriter.setInsertionPointAfter(newLoop);
239 Value one;
240 if (isa<IndexType>(Val: step.getType())) {
241 one = rewriter.create<arith::ConstantIndexOp>(location: loc, args: 1);
242 } else {
243 one = rewriter.create<arith::ConstantIntOp>(location: loc, args: 1, args: step.getType());
244 }
245
246 Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
247 Value len = rewriter.create<arith::SubIOp>(loc, ub, lb);
248 len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
249 len = rewriter.create<arith::DivSIOp>(loc, len, step);
250 len = rewriter.create<arith::SubIOp>(loc, len, one);
251 Value res = rewriter.create<arith::MulIOp>(loc, len, step);
252 res = rewriter.create<arith::AddIOp>(loc, lb, res);
253
254 // Reconstruct `scf.while` results, inserting final induction var value
255 // into proper place.
256 newArgs.clear();
257 llvm::append_range(newArgs, newLoop.getResults());
258 newArgs.insert(I: newArgs.begin() + argNumber, Elt: res);
259 if (argReorder) {
260 // Reorder arguments following the 'after' argument order from the original
261 // 'while' loop.
262 SmallVector<Value> results;
263 for (unsigned order : *argReorder)
264 results.push_back(Elt: newArgs[order]);
265 newArgs = results;
266 }
267 rewriter.replaceOp(loop, newArgs);
268 return newLoop;
269}
270
271void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
272 patterns.add<UpliftWhileOp>(arg: patterns.getContext());
273}
274

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp