1 | //===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a pass to convert SCF dialect ops (scf.for, scf.if, |
10 | // scf.while, scf.parallel, ...) into ControlFlow dialect branch ops. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" |
15 | |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
18 | #include "mlir/Dialect/SCF/IR/SCF.h" |
19 | #include "mlir/IR/Builders.h" |
20 | #include "mlir/IR/BuiltinOps.h" |
21 | #include "mlir/IR/IRMapping.h" |
22 | #include "mlir/IR/MLIRContext.h" |
23 | #include "mlir/IR/PatternMatch.h" |
24 | #include "mlir/Transforms/DialectConversion.h" |
25 | #include "mlir/Transforms/Passes.h" |
26 | |
27 | namespace mlir { |
28 | #define GEN_PASS_DEF_SCFTOCONTROLFLOW |
29 | #include "mlir/Conversion/Passes.h.inc" |
30 | } // namespace mlir |
31 | |
32 | using namespace mlir; |
33 | using namespace mlir::scf; |
34 | |
35 | namespace { |
36 | |
37 | struct SCFToControlFlowPass |
38 | : public impl::SCFToControlFlowBase<SCFToControlFlowPass> { |
39 | void runOnOperation() override; |
40 | }; |
41 | |
42 | // Create a CFG subgraph for the loop around its body blocks (if the body |
43 | // contained other loops, they have been already lowered to a flow of blocks). |
44 | // Maintain the invariants that a CFG subgraph created for any loop has a single |
45 | // entry and a single exit, and that the entry/exit blocks are respectively |
46 | // first/last blocks in the parent region. The original loop operation is |
47 | // replaced by the initialization operations that set up the initial value of |
48 | // the loop induction variable (%iv) and computes the loop bounds that are loop- |
49 | // invariant for affine loops. The operations following the original scf.for |
50 | // are split out into a separate continuation (exit) block. A condition block is |
51 | // created before the continuation block. It checks the exit condition of the |
52 | // loop and branches either to the continuation block, or to the first block of |
53 | // the body. The condition block takes as arguments the values of the induction |
54 | // variable followed by loop-carried values. Since it dominates both the body |
55 | // blocks and the continuation block, loop-carried values are visible in all of |
56 | // those blocks. Induction variable modification is appended to the last block |
57 | // of the body (which is the exit block from the body subgraph thanks to the |
58 | // invariant we maintain) along with a branch that loops back to the condition |
59 | // block. Loop-carried values are the loop terminator operands, which are |
60 | // forwarded to the branch. |
61 | // |
62 | // +---------------------------------+ |
63 | // | <code before the ForOp> | |
64 | // | <definitions of %init...> | |
65 | // | <compute initial %iv value> | |
66 | // | cf.br cond(%iv, %init...) | |
67 | // +---------------------------------+ |
68 | // | |
69 | // -------| | |
70 | // | v v |
71 | // | +--------------------------------+ |
72 | // | | cond(%iv, %init...): | |
73 | // | | <compare %iv to upper bound> | |
74 | // | | cf.cond_br %r, body, end | |
75 | // | +--------------------------------+ |
76 | // | | | |
77 | // | | -------------| |
78 | // | v | |
79 | // | +--------------------------------+ | |
80 | // | | body-first: | | |
81 | // | | <%init visible by dominance> | | |
82 | // | | <body contents> | | |
83 | // | +--------------------------------+ | |
84 | // | | | |
85 | // | ... | |
86 | // | | | |
87 | // | +--------------------------------+ | |
88 | // | | body-last: | | |
89 | // | | <body contents> | | |
90 | // | | <operands of yield = %yields>| | |
91 | // | | %new_iv =<add step to %iv> | | |
92 | // | | cf.br cond(%new_iv, %yields) | | |
93 | // | +--------------------------------+ | |
94 | // | | | |
95 | // |----------- |-------------------- |
96 | // v |
97 | // +--------------------------------+ |
98 | // | end: | |
99 | // | <code after the ForOp> | |
100 | // | <%init visible by dominance> | |
101 | // +--------------------------------+ |
102 | // |
103 | struct ForLowering : public OpRewritePattern<ForOp> { |
104 | using OpRewritePattern<ForOp>::OpRewritePattern; |
105 | |
106 | LogicalResult matchAndRewrite(ForOp forOp, |
107 | PatternRewriter &rewriter) const override; |
108 | }; |
109 | |
110 | // Create a CFG subgraph for the scf.if operation (including its "then" and |
111 | // optional "else" operation blocks). We maintain the invariants that the |
112 | // subgraph has a single entry and a single exit point, and that the entry/exit |
113 | // blocks are respectively the first/last block of the enclosing region. The |
114 | // operations following the scf.if are split into a continuation (subgraph |
115 | // exit) block. The condition is lowered to a chain of blocks that implement the |
116 | // short-circuit scheme. The "scf.if" operation is replaced with a conditional |
117 | // branch to either the first block of the "then" region, or to the first block |
118 | // of the "else" region. In these blocks, "scf.yield" is unconditional branches |
119 | // to the post-dominating block. When the "scf.if" does not return values, the |
120 | // post-dominating block is the same as the continuation block. When it returns |
121 | // values, the post-dominating block is a new block with arguments that |
122 | // correspond to the values returned by the "scf.if" that unconditionally |
123 | // branches to the continuation block. This allows block arguments to dominate |
124 | // any uses of the hitherto "scf.if" results that they replaced. (Inserting a |
125 | // new block allows us to avoid modifying the argument list of an existing |
126 | // block, which is illegal in a conversion pattern). When the "else" region is |
127 | // empty, which is only allowed for "scf.if"s that don't return values, the |
128 | // condition branches directly to the continuation block. |
129 | // |
130 | // CFG for a scf.if with else and without results. |
131 | // |
132 | // +--------------------------------+ |
133 | // | <code before the IfOp> | |
134 | // | cf.cond_br %cond, %then, %else | |
135 | // +--------------------------------+ |
136 | // | | |
137 | // | --------------| |
138 | // v | |
139 | // +--------------------------------+ | |
140 | // | then: | | |
141 | // | <then contents> | | |
142 | // | cf.br continue | | |
143 | // +--------------------------------+ | |
144 | // | | |
145 | // |---------- |------------- |
146 | // | V |
147 | // | +--------------------------------+ |
148 | // | | else: | |
149 | // | | <else contents> | |
150 | // | | cf.br continue | |
151 | // | +--------------------------------+ |
152 | // | | |
153 | // ------| | |
154 | // v v |
155 | // +--------------------------------+ |
156 | // | continue: | |
157 | // | <code after the IfOp> | |
158 | // +--------------------------------+ |
159 | // |
160 | // CFG for a scf.if with results. |
161 | // |
162 | // +--------------------------------+ |
163 | // | <code before the IfOp> | |
164 | // | cf.cond_br %cond, %then, %else | |
165 | // +--------------------------------+ |
166 | // | | |
167 | // | --------------| |
168 | // v | |
169 | // +--------------------------------+ | |
170 | // | then: | | |
171 | // | <then contents> | | |
172 | // | cf.br dom(%args...) | | |
173 | // +--------------------------------+ | |
174 | // | | |
175 | // |---------- |------------- |
176 | // | V |
177 | // | +--------------------------------+ |
178 | // | | else: | |
179 | // | | <else contents> | |
180 | // | | cf.br dom(%args...) | |
181 | // | +--------------------------------+ |
182 | // | | |
183 | // ------| | |
184 | // v v |
185 | // +--------------------------------+ |
186 | // | dom(%args...): | |
187 | // | cf.br continue | |
188 | // +--------------------------------+ |
189 | // | |
190 | // v |
191 | // +--------------------------------+ |
192 | // | continue: | |
193 | // | <code after the IfOp> | |
194 | // +--------------------------------+ |
195 | // |
196 | struct IfLowering : public OpRewritePattern<IfOp> { |
197 | using OpRewritePattern<IfOp>::OpRewritePattern; |
198 | |
199 | LogicalResult matchAndRewrite(IfOp ifOp, |
200 | PatternRewriter &rewriter) const override; |
201 | }; |
202 | |
203 | struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> { |
204 | using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern; |
205 | |
206 | LogicalResult matchAndRewrite(ExecuteRegionOp op, |
207 | PatternRewriter &rewriter) const override; |
208 | }; |
209 | |
210 | struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> { |
211 | using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern; |
212 | |
213 | LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp, |
214 | PatternRewriter &rewriter) const override; |
215 | }; |
216 | |
217 | /// Create a CFG subgraph for this loop construct. The regions of the loop need |
218 | /// not be a single block anymore (for example, if other SCF constructs that |
219 | /// they contain have been already converted to CFG), but need to be single-exit |
220 | /// from the last block of each region. The operations following the original |
221 | /// WhileOp are split into a new continuation block. Both regions of the WhileOp |
222 | /// are inlined, and their terminators are rewritten to organize the control |
223 | /// flow implementing the loop as follows. |
224 | /// |
225 | /// +---------------------------------+ |
226 | /// | <code before the WhileOp> | |
227 | /// | cf.br ^before(%operands...) | |
228 | /// +---------------------------------+ |
229 | /// | |
230 | /// -------| | |
231 | /// | v v |
232 | /// | +--------------------------------+ |
233 | /// | | ^before(%bargs...): | |
234 | /// | | %vals... = <some payload> | |
235 | /// | +--------------------------------+ |
236 | /// | | |
237 | /// | ... |
238 | /// | | |
239 | /// | +--------------------------------+ |
240 | /// | | ^before-last: |
241 | /// | | %cond = <compute condition> | |
242 | /// | | cf.cond_br %cond, | |
243 | /// | | ^after(%vals...), ^cont | |
244 | /// | +--------------------------------+ |
245 | /// | | | |
246 | /// | | -------------| |
247 | /// | v | |
248 | /// | +--------------------------------+ | |
249 | /// | | ^after(%aargs...): | | |
250 | /// | | <body contents> | | |
251 | /// | +--------------------------------+ | |
252 | /// | | | |
253 | /// | ... | |
254 | /// | | | |
255 | /// | +--------------------------------+ | |
256 | /// | | ^after-last: | | |
257 | /// | | %yields... = <some payload> | | |
258 | /// | | cf.br ^before(%yields...) | | |
259 | /// | +--------------------------------+ | |
260 | /// | | | |
261 | /// |----------- |-------------------- |
262 | /// v |
263 | /// +--------------------------------+ |
264 | /// | ^cont: | |
265 | /// | <code after the WhileOp> | |
266 | /// | <%vals from 'before' region | |
267 | /// | visible by dominance> | |
268 | /// +--------------------------------+ |
269 | /// |
270 | /// Values are communicated between ex-regions (the groups of blocks that used |
271 | /// to form a region before inlining) through block arguments of their |
272 | /// entry blocks, which are visible in all other dominated blocks. Similarly, |
273 | /// the results of the WhileOp are defined in the 'before' region, which is |
274 | /// required to have a single existing block, and are therefore accessible in |
275 | /// the continuation block due to dominance. |
276 | struct WhileLowering : public OpRewritePattern<WhileOp> { |
277 | using OpRewritePattern<WhileOp>::OpRewritePattern; |
278 | |
279 | LogicalResult matchAndRewrite(WhileOp whileOp, |
280 | PatternRewriter &rewriter) const override; |
281 | }; |
282 | |
283 | /// Optimized version of the above for the case of the "after" region merely |
284 | /// forwarding its arguments back to the "before" region (i.e., a "do-while" |
285 | /// loop). This avoid inlining the "after" region completely and branches back |
286 | /// to the "before" entry instead. |
287 | struct DoWhileLowering : public OpRewritePattern<WhileOp> { |
288 | using OpRewritePattern<WhileOp>::OpRewritePattern; |
289 | |
290 | LogicalResult matchAndRewrite(WhileOp whileOp, |
291 | PatternRewriter &rewriter) const override; |
292 | }; |
293 | |
294 | /// Lower an `scf.index_switch` operation to a `cf.switch` operation. |
295 | struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> { |
296 | using OpRewritePattern::OpRewritePattern; |
297 | |
298 | LogicalResult matchAndRewrite(IndexSwitchOp op, |
299 | PatternRewriter &rewriter) const override; |
300 | }; |
301 | |
302 | /// Lower an `scf.forall` operation to an `scf.parallel` op, assuming that it |
303 | /// has no shared outputs. Ops with shared outputs should be bufferized first. |
304 | /// Specialized lowerings for `scf.forall` (e.g., for GPUs) exist in other |
305 | /// dialects/passes. |
306 | struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> { |
307 | using OpRewritePattern<mlir::scf::ForallOp>::OpRewritePattern; |
308 | |
309 | LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp, |
310 | PatternRewriter &rewriter) const override; |
311 | }; |
312 | |
313 | } // namespace |
314 | |
315 | LogicalResult ForLowering::matchAndRewrite(ForOp forOp, |
316 | PatternRewriter &rewriter) const { |
317 | Location loc = forOp.getLoc(); |
318 | |
319 | // Start by splitting the block containing the 'scf.for' into two parts. |
320 | // The part before will get the init code, the part after will be the end |
321 | // point. |
322 | auto *initBlock = rewriter.getInsertionBlock(); |
323 | auto initPosition = rewriter.getInsertionPoint(); |
324 | auto *endBlock = rewriter.splitBlock(block: initBlock, before: initPosition); |
325 | |
326 | // Use the first block of the loop body as the condition block since it is the |
327 | // block that has the induction variable and loop-carried values as arguments. |
328 | // Split out all operations from the first block into a new block. Move all |
329 | // body blocks from the loop body region to the region containing the loop. |
330 | auto *conditionBlock = &forOp.getRegion().front(); |
331 | auto *firstBodyBlock = |
332 | rewriter.splitBlock(block: conditionBlock, before: conditionBlock->begin()); |
333 | auto *lastBodyBlock = &forOp.getRegion().back(); |
334 | rewriter.inlineRegionBefore(forOp.getRegion(), endBlock); |
335 | auto iv = conditionBlock->getArgument(0); |
336 | |
337 | // Append the induction variable stepping logic to the last body block and |
338 | // branch back to the condition block. Loop-carried values are taken from |
339 | // operands of the loop terminator. |
340 | Operation *terminator = lastBodyBlock->getTerminator(); |
341 | rewriter.setInsertionPointToEnd(lastBodyBlock); |
342 | auto step = forOp.getStep(); |
343 | auto stepped = rewriter.create<arith::AddIOp>(loc, iv, step).getResult(); |
344 | if (!stepped) |
345 | return failure(); |
346 | |
347 | SmallVector<Value, 8> loopCarried; |
348 | loopCarried.push_back(Elt: stepped); |
349 | loopCarried.append(in_start: terminator->operand_begin(), in_end: terminator->operand_end()); |
350 | rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried); |
351 | rewriter.eraseOp(op: terminator); |
352 | |
353 | // Compute loop bounds before branching to the condition. |
354 | rewriter.setInsertionPointToEnd(initBlock); |
355 | Value lowerBound = forOp.getLowerBound(); |
356 | Value upperBound = forOp.getUpperBound(); |
357 | if (!lowerBound || !upperBound) |
358 | return failure(); |
359 | |
360 | // The initial values of loop-carried values is obtained from the operands |
361 | // of the loop operation. |
362 | SmallVector<Value, 8> destOperands; |
363 | destOperands.push_back(Elt: lowerBound); |
364 | llvm::append_range(destOperands, forOp.getInitArgs()); |
365 | rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands); |
366 | |
367 | // With the body block done, we can fill in the condition block. |
368 | rewriter.setInsertionPointToEnd(conditionBlock); |
369 | auto comparison = rewriter.create<arith::CmpIOp>( |
370 | loc, arith::CmpIPredicate::slt, iv, upperBound); |
371 | |
372 | rewriter.create<cf::CondBranchOp>(loc, comparison, firstBodyBlock, |
373 | ArrayRef<Value>(), endBlock, |
374 | ArrayRef<Value>()); |
375 | // The result of the loop operation is the values of the condition block |
376 | // arguments except the induction variable on the last iteration. |
377 | rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front()); |
378 | return success(); |
379 | } |
380 | |
381 | LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, |
382 | PatternRewriter &rewriter) const { |
383 | auto loc = ifOp.getLoc(); |
384 | |
385 | // Start by splitting the block containing the 'scf.if' into two parts. |
386 | // The part before will contain the condition, the part after will be the |
387 | // continuation point. |
388 | auto *condBlock = rewriter.getInsertionBlock(); |
389 | auto opPosition = rewriter.getInsertionPoint(); |
390 | auto *remainingOpsBlock = rewriter.splitBlock(block: condBlock, before: opPosition); |
391 | Block *continueBlock; |
392 | if (ifOp.getNumResults() == 0) { |
393 | continueBlock = remainingOpsBlock; |
394 | } else { |
395 | continueBlock = |
396 | rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(), |
397 | SmallVector<Location>(ifOp.getNumResults(), loc)); |
398 | rewriter.create<cf::BranchOp>(loc, remainingOpsBlock); |
399 | } |
400 | |
401 | // Move blocks from the "then" region to the region containing 'scf.if', |
402 | // place it before the continuation block, and branch to it. |
403 | auto &thenRegion = ifOp.getThenRegion(); |
404 | auto *thenBlock = &thenRegion.front(); |
405 | Operation *thenTerminator = thenRegion.back().getTerminator(); |
406 | ValueRange thenTerminatorOperands = thenTerminator->getOperands(); |
407 | rewriter.setInsertionPointToEnd(&thenRegion.back()); |
408 | rewriter.create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands); |
409 | rewriter.eraseOp(op: thenTerminator); |
410 | rewriter.inlineRegionBefore(thenRegion, continueBlock); |
411 | |
412 | // Move blocks from the "else" region (if present) to the region containing |
413 | // 'scf.if', place it before the continuation block and branch to it. It |
414 | // will be placed after the "then" regions. |
415 | auto *elseBlock = continueBlock; |
416 | auto &elseRegion = ifOp.getElseRegion(); |
417 | if (!elseRegion.empty()) { |
418 | elseBlock = &elseRegion.front(); |
419 | Operation *elseTerminator = elseRegion.back().getTerminator(); |
420 | ValueRange elseTerminatorOperands = elseTerminator->getOperands(); |
421 | rewriter.setInsertionPointToEnd(&elseRegion.back()); |
422 | rewriter.create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands); |
423 | rewriter.eraseOp(op: elseTerminator); |
424 | rewriter.inlineRegionBefore(elseRegion, continueBlock); |
425 | } |
426 | |
427 | rewriter.setInsertionPointToEnd(condBlock); |
428 | rewriter.create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock, |
429 | /*trueArgs=*/ArrayRef<Value>(), elseBlock, |
430 | /*falseArgs=*/ArrayRef<Value>()); |
431 | |
432 | // Ok, we're done! |
433 | rewriter.replaceOp(ifOp, continueBlock->getArguments()); |
434 | return success(); |
435 | } |
436 | |
437 | LogicalResult |
438 | ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op, |
439 | PatternRewriter &rewriter) const { |
440 | auto loc = op.getLoc(); |
441 | |
442 | auto *condBlock = rewriter.getInsertionBlock(); |
443 | auto opPosition = rewriter.getInsertionPoint(); |
444 | auto *remainingOpsBlock = rewriter.splitBlock(block: condBlock, before: opPosition); |
445 | |
446 | auto ®ion = op.getRegion(); |
447 | rewriter.setInsertionPointToEnd(condBlock); |
448 | rewriter.create<cf::BranchOp>(loc, ®ion.front()); |
449 | |
450 | for (Block &block : region) { |
451 | if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) { |
452 | ValueRange terminatorOperands = terminator->getOperands(); |
453 | rewriter.setInsertionPointToEnd(&block); |
454 | rewriter.create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands); |
455 | rewriter.eraseOp(terminator); |
456 | } |
457 | } |
458 | |
459 | rewriter.inlineRegionBefore(region, remainingOpsBlock); |
460 | |
461 | SmallVector<Value> vals; |
462 | SmallVector<Location> argLocs(op.getNumResults(), op->getLoc()); |
463 | for (auto arg : |
464 | remainingOpsBlock->addArguments(op->getResultTypes(), argLocs)) |
465 | vals.push_back(arg); |
466 | rewriter.replaceOp(op, vals); |
467 | return success(); |
468 | } |
469 | |
470 | LogicalResult |
471 | ParallelLowering::matchAndRewrite(ParallelOp parallelOp, |
472 | PatternRewriter &rewriter) const { |
473 | Location loc = parallelOp.getLoc(); |
474 | auto reductionOp = cast<ReduceOp>(parallelOp.getBody()->getTerminator()); |
475 | |
476 | // For a parallel loop, we essentially need to create an n-dimensional loop |
477 | // nest. We do this by translating to scf.for ops and have those lowered in |
478 | // a further rewrite. If a parallel loop contains reductions (and thus returns |
479 | // values), forward the initial values for the reductions down the loop |
480 | // hierarchy and bubble up the results by modifying the "yield" terminator. |
481 | SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.getInitVals()); |
482 | SmallVector<Value, 4> ivs; |
483 | ivs.reserve(N: parallelOp.getNumLoops()); |
484 | bool first = true; |
485 | SmallVector<Value, 4> loopResults(iterArgs); |
486 | for (auto [iv, lower, upper, step] : |
487 | llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(), |
488 | parallelOp.getUpperBound(), parallelOp.getStep())) { |
489 | ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs); |
490 | ivs.push_back(forOp.getInductionVar()); |
491 | auto iterRange = forOp.getRegionIterArgs(); |
492 | iterArgs.assign(iterRange.begin(), iterRange.end()); |
493 | |
494 | if (first) { |
495 | // Store the results of the outermost loop that will be used to replace |
496 | // the results of the parallel loop when it is fully rewritten. |
497 | loopResults.assign(forOp.result_begin(), forOp.result_end()); |
498 | first = false; |
499 | } else if (!forOp.getResults().empty()) { |
500 | // A loop is constructed with an empty "yield" terminator if there are |
501 | // no results. |
502 | rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); |
503 | rewriter.create<scf::YieldOp>(loc, forOp.getResults()); |
504 | } |
505 | |
506 | rewriter.setInsertionPointToStart(forOp.getBody()); |
507 | } |
508 | |
509 | // First, merge reduction blocks into the main region. |
510 | SmallVector<Value> yieldOperands; |
511 | yieldOperands.reserve(N: parallelOp.getNumResults()); |
512 | for (int64_t i = 0, e = parallelOp.getNumResults(); i < e; ++i) { |
513 | Block &reductionBody = reductionOp.getReductions()[i].front(); |
514 | Value arg = iterArgs[yieldOperands.size()]; |
515 | yieldOperands.push_back( |
516 | cast<ReduceReturnOp>(reductionBody.getTerminator()).getResult()); |
517 | rewriter.eraseOp(op: reductionBody.getTerminator()); |
518 | rewriter.inlineBlockBefore(&reductionBody, reductionOp, |
519 | {arg, reductionOp.getOperands()[i]}); |
520 | } |
521 | rewriter.eraseOp(op: reductionOp); |
522 | |
523 | // Then merge the loop body without the terminator. |
524 | Block *newBody = rewriter.getInsertionBlock(); |
525 | if (newBody->empty()) |
526 | rewriter.mergeBlocks(source: parallelOp.getBody(), dest: newBody, argValues: ivs); |
527 | else |
528 | rewriter.inlineBlockBefore(parallelOp.getBody(), newBody->getTerminator(), |
529 | ivs); |
530 | |
531 | // Finally, create the terminator if required (for loops with no results, it |
532 | // has been already created in loop construction). |
533 | if (!yieldOperands.empty()) { |
534 | rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); |
535 | rewriter.create<scf::YieldOp>(loc, yieldOperands); |
536 | } |
537 | |
538 | rewriter.replaceOp(parallelOp, loopResults); |
539 | |
540 | return success(); |
541 | } |
542 | |
543 | LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, |
544 | PatternRewriter &rewriter) const { |
545 | OpBuilder::InsertionGuard guard(rewriter); |
546 | Location loc = whileOp.getLoc(); |
547 | |
548 | // Split the current block before the WhileOp to create the inlining point. |
549 | Block *currentBlock = rewriter.getInsertionBlock(); |
550 | Block *continuation = |
551 | rewriter.splitBlock(block: currentBlock, before: rewriter.getInsertionPoint()); |
552 | |
553 | // Inline both regions. |
554 | Block *after = whileOp.getAfterBody(); |
555 | Block *before = whileOp.getBeforeBody(); |
556 | rewriter.inlineRegionBefore(whileOp.getAfter(), continuation); |
557 | rewriter.inlineRegionBefore(whileOp.getBefore(), after); |
558 | |
559 | // Branch to the "before" region. |
560 | rewriter.setInsertionPointToEnd(currentBlock); |
561 | rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits()); |
562 | |
563 | // Replace terminators with branches. Assuming bodies are SESE, which holds |
564 | // given only the patterns from this file, we only need to look at the last |
565 | // block. This should be reconsidered if we allow break/continue in SCF. |
566 | rewriter.setInsertionPointToEnd(before); |
567 | auto condOp = cast<ConditionOp>(before->getTerminator()); |
568 | rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(), |
569 | after, condOp.getArgs(), |
570 | continuation, ValueRange()); |
571 | |
572 | rewriter.setInsertionPointToEnd(after); |
573 | auto yieldOp = cast<scf::YieldOp>(after->getTerminator()); |
574 | rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before, |
575 | yieldOp.getResults()); |
576 | |
577 | // Replace the op with values "yielded" from the "before" region, which are |
578 | // visible by dominance. |
579 | rewriter.replaceOp(whileOp, condOp.getArgs()); |
580 | |
581 | return success(); |
582 | } |
583 | |
584 | LogicalResult |
585 | DoWhileLowering::matchAndRewrite(WhileOp whileOp, |
586 | PatternRewriter &rewriter) const { |
587 | Block &afterBlock = *whileOp.getAfterBody(); |
588 | if (!llvm::hasSingleElement(C&: afterBlock)) |
589 | return rewriter.notifyMatchFailure(whileOp, |
590 | "do-while simplification applicable " |
591 | "only if 'after' region has no payload" ); |
592 | |
593 | auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front()); |
594 | if (!yield || yield.getResults() != afterBlock.getArguments()) |
595 | return rewriter.notifyMatchFailure(whileOp, |
596 | "do-while simplification applicable " |
597 | "only to forwarding 'after' regions" ); |
598 | |
599 | // Split the current block before the WhileOp to create the inlining point. |
600 | OpBuilder::InsertionGuard guard(rewriter); |
601 | Block *currentBlock = rewriter.getInsertionBlock(); |
602 | Block *continuation = |
603 | rewriter.splitBlock(block: currentBlock, before: rewriter.getInsertionPoint()); |
604 | |
605 | // Only the "before" region should be inlined. |
606 | Block *before = whileOp.getBeforeBody(); |
607 | rewriter.inlineRegionBefore(whileOp.getBefore(), continuation); |
608 | |
609 | // Branch to the "before" region. |
610 | rewriter.setInsertionPointToEnd(currentBlock); |
611 | rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits()); |
612 | |
613 | // Loop around the "before" region based on condition. |
614 | rewriter.setInsertionPointToEnd(before); |
615 | auto condOp = cast<ConditionOp>(before->getTerminator()); |
616 | rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(), |
617 | before, condOp.getArgs(), |
618 | continuation, ValueRange()); |
619 | |
620 | // Replace the op with values "yielded" from the "before" region, which are |
621 | // visible by dominance. |
622 | rewriter.replaceOp(whileOp, condOp.getArgs()); |
623 | |
624 | return success(); |
625 | } |
626 | |
627 | LogicalResult |
628 | IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op, |
629 | PatternRewriter &rewriter) const { |
630 | // Split the block at the op. |
631 | Block *condBlock = rewriter.getInsertionBlock(); |
632 | Block *continueBlock = rewriter.splitBlock(block: condBlock, before: Block::iterator(op)); |
633 | |
634 | // Create the arguments on the continue block with which to replace the |
635 | // results of the op. |
636 | SmallVector<Value> results; |
637 | results.reserve(N: op.getNumResults()); |
638 | for (Type resultType : op.getResultTypes()) |
639 | results.push_back(continueBlock->addArgument(resultType, op.getLoc())); |
640 | |
641 | // Handle the regions. |
642 | auto convertRegion = [&](Region ®ion) -> FailureOr<Block *> { |
643 | Block *block = ®ion.front(); |
644 | |
645 | // Convert the yield terminator to a branch to the continue block. |
646 | auto yield = cast<scf::YieldOp>(block->getTerminator()); |
647 | rewriter.setInsertionPoint(yield); |
648 | rewriter.replaceOpWithNewOp<cf::BranchOp>(yield, continueBlock, |
649 | yield.getOperands()); |
650 | |
651 | // Inline the region. |
652 | rewriter.inlineRegionBefore(region, before: continueBlock); |
653 | return block; |
654 | }; |
655 | |
656 | // Convert the case regions. |
657 | SmallVector<Block *> caseSuccessors; |
658 | SmallVector<int32_t> caseValues; |
659 | caseSuccessors.reserve(N: op.getCases().size()); |
660 | caseValues.reserve(N: op.getCases().size()); |
661 | for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) { |
662 | FailureOr<Block *> block = convertRegion(region); |
663 | if (failed(block)) |
664 | return failure(); |
665 | caseSuccessors.push_back(*block); |
666 | caseValues.push_back(value); |
667 | } |
668 | |
669 | // Convert the default region. |
670 | FailureOr<Block *> defaultBlock = convertRegion(op.getDefaultRegion()); |
671 | if (failed(result: defaultBlock)) |
672 | return failure(); |
673 | |
674 | // Create the switch. |
675 | rewriter.setInsertionPointToEnd(condBlock); |
676 | SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {}); |
677 | |
678 | // Cast switch index to integer case value. |
679 | Value caseValue = rewriter.create<arith::IndexCastOp>( |
680 | op.getLoc(), rewriter.getI32Type(), op.getArg()); |
681 | |
682 | rewriter.create<cf::SwitchOp>( |
683 | op.getLoc(), caseValue, *defaultBlock, ValueRange(), |
684 | rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands); |
685 | rewriter.replaceOp(op, continueBlock->getArguments()); |
686 | return success(); |
687 | } |
688 | |
689 | LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp, |
690 | PatternRewriter &rewriter) const { |
691 | Location loc = forallOp.getLoc(); |
692 | if (!forallOp.getOutputs().empty()) |
693 | return rewriter.notifyMatchFailure( |
694 | forallOp, |
695 | "only fully bufferized scf.forall ops can be lowered to scf.parallel" ); |
696 | |
697 | // Convert mixed bounds and steps to SSA values. |
698 | SmallVector<Value> lbs = getValueOrCreateConstantIndexOp( |
699 | rewriter, loc, forallOp.getMixedLowerBound()); |
700 | SmallVector<Value> ubs = getValueOrCreateConstantIndexOp( |
701 | rewriter, loc, forallOp.getMixedUpperBound()); |
702 | SmallVector<Value> steps = |
703 | getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep()); |
704 | |
705 | // Create empty scf.parallel op. |
706 | auto parallelOp = rewriter.create<ParallelOp>(loc, lbs, ubs, steps); |
707 | rewriter.eraseBlock(block: ¶llelOp.getRegion().front()); |
708 | rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(), |
709 | parallelOp.getRegion().begin()); |
710 | // Replace the terminator. |
711 | rewriter.setInsertionPointToEnd(¶llelOp.getRegion().front()); |
712 | rewriter.replaceOpWithNewOp<scf::ReduceOp>( |
713 | parallelOp.getRegion().front().getTerminator()); |
714 | |
715 | // Erase the scf.forall op. |
716 | rewriter.replaceOp(forallOp, parallelOp); |
717 | return success(); |
718 | } |
719 | |
720 | void mlir::populateSCFToControlFlowConversionPatterns( |
721 | RewritePatternSet &patterns) { |
722 | patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering, |
723 | WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>( |
724 | arg: patterns.getContext()); |
725 | patterns.add<DoWhileLowering>(arg: patterns.getContext(), /*benefit=*/args: 2); |
726 | } |
727 | |
728 | void SCFToControlFlowPass::runOnOperation() { |
729 | RewritePatternSet patterns(&getContext()); |
730 | populateSCFToControlFlowConversionPatterns(patterns); |
731 | |
732 | // Configure conversion to lower out SCF operations. |
733 | ConversionTarget target(getContext()); |
734 | target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp, |
735 | scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>(); |
736 | target.markUnknownOpDynamicallyLegal(fn: [](Operation *) { return true; }); |
737 | if (failed( |
738 | applyPartialConversion(getOperation(), target, std::move(patterns)))) |
739 | signalPassFailure(); |
740 | } |
741 | |
742 | std::unique_ptr<Pass> mlir::createConvertSCFToCFPass() { |
743 | return std::make_unique<SCFToControlFlowPass>(); |
744 | } |
745 | |