1//===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a pass to convert SCF dialect ops (scf.for, scf.if,
// scf.while, scf.parallel, scf.forall, scf.execute_region and
// scf.index_switch) into ControlFlow dialect (CFG) ops.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
15
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
18#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19#include "mlir/Dialect/SCF/IR/SCF.h"
20#include "mlir/Dialect/SCF/Transforms/Transforms.h"
21#include "mlir/IR/Builders.h"
22#include "mlir/IR/MLIRContext.h"
23#include "mlir/IR/PatternMatch.h"
24#include "mlir/Transforms/DialectConversion.h"
25#include "mlir/Transforms/Passes.h"
26
27namespace mlir {
28#define GEN_PASS_DEF_SCFTOCONTROLFLOWPASS
29#include "mlir/Conversion/Passes.h.inc"
30} // namespace mlir
31
32using namespace mlir;
33using namespace mlir::scf;
34
35namespace {
36
37struct SCFToControlFlowPass
38 : public impl::SCFToControlFlowPassBase<SCFToControlFlowPass> {
39 void runOnOperation() override;
40};
41
42// Create a CFG subgraph for the loop around its body blocks (if the body
43// contained other loops, they have been already lowered to a flow of blocks).
44// Maintain the invariants that a CFG subgraph created for any loop has a single
45// entry and a single exit, and that the entry/exit blocks are respectively
46// first/last blocks in the parent region. The original loop operation is
47// replaced by the initialization operations that set up the initial value of
48// the loop induction variable (%iv) and computes the loop bounds that are loop-
49// invariant for affine loops. The operations following the original scf.for
50// are split out into a separate continuation (exit) block. A condition block is
51// created before the continuation block. It checks the exit condition of the
52// loop and branches either to the continuation block, or to the first block of
53// the body. The condition block takes as arguments the values of the induction
54// variable followed by loop-carried values. Since it dominates both the body
55// blocks and the continuation block, loop-carried values are visible in all of
56// those blocks. Induction variable modification is appended to the last block
57// of the body (which is the exit block from the body subgraph thanks to the
58// invariant we maintain) along with a branch that loops back to the condition
59// block. Loop-carried values are the loop terminator operands, which are
60// forwarded to the branch.
61//
62// +---------------------------------+
63// | <code before the ForOp> |
64// | <definitions of %init...> |
65// | <compute initial %iv value> |
66// | cf.br cond(%iv, %init...) |
67// +---------------------------------+
68// |
69// -------| |
70// | v v
71// | +--------------------------------+
72// | | cond(%iv, %init...): |
73// | | <compare %iv to upper bound> |
74// | | cf.cond_br %r, body, end |
75// | +--------------------------------+
76// | | |
77// | | -------------|
78// | v |
79// | +--------------------------------+ |
80// | | body-first: | |
81// | | <%init visible by dominance> | |
82// | | <body contents> | |
83// | +--------------------------------+ |
84// | | |
85// | ... |
86// | | |
87// | +--------------------------------+ |
88// | | body-last: | |
89// | | <body contents> | |
90// | | <operands of yield = %yields>| |
91// | | %new_iv =<add step to %iv> | |
92// | | cf.br cond(%new_iv, %yields) | |
93// | +--------------------------------+ |
94// | | |
95// |----------- |--------------------
96// v
97// +--------------------------------+
98// | end: |
99// | <code after the ForOp> |
100// | <%init visible by dominance> |
101// +--------------------------------+
102//
103struct ForLowering : public OpRewritePattern<ForOp> {
104 using OpRewritePattern<ForOp>::OpRewritePattern;
105
106 LogicalResult matchAndRewrite(ForOp forOp,
107 PatternRewriter &rewriter) const override;
108};
109
110// Create a CFG subgraph for the scf.if operation (including its "then" and
111// optional "else" operation blocks). We maintain the invariants that the
112// subgraph has a single entry and a single exit point, and that the entry/exit
113// blocks are respectively the first/last block of the enclosing region. The
114// operations following the scf.if are split into a continuation (subgraph
115// exit) block. The condition is lowered to a chain of blocks that implement the
116// short-circuit scheme. The "scf.if" operation is replaced with a conditional
117// branch to either the first block of the "then" region, or to the first block
118// of the "else" region. In these blocks, "scf.yield" is unconditional branches
119// to the post-dominating block. When the "scf.if" does not return values, the
120// post-dominating block is the same as the continuation block. When it returns
121// values, the post-dominating block is a new block with arguments that
122// correspond to the values returned by the "scf.if" that unconditionally
123// branches to the continuation block. This allows block arguments to dominate
124// any uses of the hitherto "scf.if" results that they replaced. (Inserting a
125// new block allows us to avoid modifying the argument list of an existing
126// block, which is illegal in a conversion pattern). When the "else" region is
127// empty, which is only allowed for "scf.if"s that don't return values, the
128// condition branches directly to the continuation block.
129//
130// CFG for a scf.if with else and without results.
131//
132// +--------------------------------+
133// | <code before the IfOp> |
134// | cf.cond_br %cond, %then, %else |
135// +--------------------------------+
136// | |
137// | --------------|
138// v |
139// +--------------------------------+ |
140// | then: | |
141// | <then contents> | |
142// | cf.br continue | |
143// +--------------------------------+ |
144// | |
145// |---------- |-------------
146// | V
147// | +--------------------------------+
148// | | else: |
149// | | <else contents> |
150// | | cf.br continue |
151// | +--------------------------------+
152// | |
153// ------| |
154// v v
155// +--------------------------------+
156// | continue: |
157// | <code after the IfOp> |
158// +--------------------------------+
159//
160// CFG for a scf.if with results.
161//
162// +--------------------------------+
163// | <code before the IfOp> |
164// | cf.cond_br %cond, %then, %else |
165// +--------------------------------+
166// | |
167// | --------------|
168// v |
169// +--------------------------------+ |
170// | then: | |
171// | <then contents> | |
172// | cf.br dom(%args...) | |
173// +--------------------------------+ |
174// | |
175// |---------- |-------------
176// | V
177// | +--------------------------------+
178// | | else: |
179// | | <else contents> |
180// | | cf.br dom(%args...) |
181// | +--------------------------------+
182// | |
183// ------| |
184// v v
185// +--------------------------------+
186// | dom(%args...): |
187// | cf.br continue |
188// +--------------------------------+
189// |
190// v
191// +--------------------------------+
192// | continue: |
193// | <code after the IfOp> |
194// +--------------------------------+
195//
196struct IfLowering : public OpRewritePattern<IfOp> {
197 using OpRewritePattern<IfOp>::OpRewritePattern;
198
199 LogicalResult matchAndRewrite(IfOp ifOp,
200 PatternRewriter &rewriter) const override;
201};
202
203struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
204 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
205
206 LogicalResult matchAndRewrite(ExecuteRegionOp op,
207 PatternRewriter &rewriter) const override;
208};
209
210struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
211 using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
212
213 LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
214 PatternRewriter &rewriter) const override;
215};
216
217/// Create a CFG subgraph for this loop construct. The regions of the loop need
218/// not be a single block anymore (for example, if other SCF constructs that
219/// they contain have been already converted to CFG), but need to be single-exit
220/// from the last block of each region. The operations following the original
221/// WhileOp are split into a new continuation block. Both regions of the WhileOp
222/// are inlined, and their terminators are rewritten to organize the control
223/// flow implementing the loop as follows.
224///
225/// +---------------------------------+
226/// | <code before the WhileOp> |
227/// | cf.br ^before(%operands...) |
228/// +---------------------------------+
229/// |
230/// -------| |
231/// | v v
232/// | +--------------------------------+
233/// | | ^before(%bargs...): |
234/// | | %vals... = <some payload> |
235/// | +--------------------------------+
236/// | |
237/// | ...
238/// | |
239/// | +--------------------------------+
240/// | | ^before-last:
241/// | | %cond = <compute condition> |
242/// | | cf.cond_br %cond, |
243/// | | ^after(%vals...), ^cont |
244/// | +--------------------------------+
245/// | | |
246/// | | -------------|
247/// | v |
248/// | +--------------------------------+ |
249/// | | ^after(%aargs...): | |
250/// | | <body contents> | |
251/// | +--------------------------------+ |
252/// | | |
253/// | ... |
254/// | | |
255/// | +--------------------------------+ |
256/// | | ^after-last: | |
257/// | | %yields... = <some payload> | |
258/// | | cf.br ^before(%yields...) | |
259/// | +--------------------------------+ |
260/// | | |
261/// |----------- |--------------------
262/// v
263/// +--------------------------------+
264/// | ^cont: |
265/// | <code after the WhileOp> |
266/// | <%vals from 'before' region |
267/// | visible by dominance> |
268/// +--------------------------------+
269///
270/// Values are communicated between ex-regions (the groups of blocks that used
271/// to form a region before inlining) through block arguments of their
272/// entry blocks, which are visible in all other dominated blocks. Similarly,
273/// the results of the WhileOp are defined in the 'before' region, which is
274/// required to have a single existing block, and are therefore accessible in
275/// the continuation block due to dominance.
276struct WhileLowering : public OpRewritePattern<WhileOp> {
277 using OpRewritePattern<WhileOp>::OpRewritePattern;
278
279 LogicalResult matchAndRewrite(WhileOp whileOp,
280 PatternRewriter &rewriter) const override;
281};
282
283/// Optimized version of the above for the case of the "after" region merely
284/// forwarding its arguments back to the "before" region (i.e., a "do-while"
285/// loop). This avoid inlining the "after" region completely and branches back
286/// to the "before" entry instead.
287struct DoWhileLowering : public OpRewritePattern<WhileOp> {
288 using OpRewritePattern<WhileOp>::OpRewritePattern;
289
290 LogicalResult matchAndRewrite(WhileOp whileOp,
291 PatternRewriter &rewriter) const override;
292};
293
294/// Lower an `scf.index_switch` operation to a `cf.switch` operation.
295struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
296 using OpRewritePattern::OpRewritePattern;
297
298 LogicalResult matchAndRewrite(IndexSwitchOp op,
299 PatternRewriter &rewriter) const override;
300};
301
302/// Lower an `scf.forall` operation to an `scf.parallel` op, assuming that it
303/// has no shared outputs. Ops with shared outputs should be bufferized first.
304/// Specialized lowerings for `scf.forall` (e.g., for GPUs) exist in other
305/// dialects/passes.
306struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
307 using OpRewritePattern<mlir::scf::ForallOp>::OpRewritePattern;
308
309 LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp,
310 PatternRewriter &rewriter) const override;
311};
312
313} // namespace
314
315LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
316 PatternRewriter &rewriter) const {
317 Location loc = forOp.getLoc();
318
319 // Start by splitting the block containing the 'scf.for' into two parts.
320 // The part before will get the init code, the part after will be the end
321 // point.
322 auto *initBlock = rewriter.getInsertionBlock();
323 auto initPosition = rewriter.getInsertionPoint();
324 auto *endBlock = rewriter.splitBlock(block: initBlock, before: initPosition);
325
326 // Use the first block of the loop body as the condition block since it is the
327 // block that has the induction variable and loop-carried values as arguments.
328 // Split out all operations from the first block into a new block. Move all
329 // body blocks from the loop body region to the region containing the loop.
330 auto *conditionBlock = &forOp.getRegion().front();
331 auto *firstBodyBlock =
332 rewriter.splitBlock(block: conditionBlock, before: conditionBlock->begin());
333 auto *lastBodyBlock = &forOp.getRegion().back();
334 rewriter.inlineRegionBefore(region&: forOp.getRegion(), before: endBlock);
335 auto iv = conditionBlock->getArgument(i: 0);
336
337 // Append the induction variable stepping logic to the last body block and
338 // branch back to the condition block. Loop-carried values are taken from
339 // operands of the loop terminator.
340 Operation *terminator = lastBodyBlock->getTerminator();
341 rewriter.setInsertionPointToEnd(lastBodyBlock);
342 auto step = forOp.getStep();
343 auto stepped = rewriter.create<arith::AddIOp>(location: loc, args&: iv, args&: step).getResult();
344 if (!stepped)
345 return failure();
346
347 SmallVector<Value, 8> loopCarried;
348 loopCarried.push_back(Elt: stepped);
349 loopCarried.append(in_start: terminator->operand_begin(), in_end: terminator->operand_end());
350 auto branchOp =
351 rewriter.create<cf::BranchOp>(location: loc, args&: conditionBlock, args&: loopCarried);
352
353 // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
354 // llvm.loop_annotation attribute.
355 // LLVM requires the loop metadata to be attached on the "latch" block. Which
356 // is the back-edge to the header block (conditionBlock)
357 SmallVector<NamedAttribute> llvmAttrs;
358 llvm::copy_if(Range: forOp->getAttrs(), Out: std::back_inserter(x&: llvmAttrs),
359 P: [](auto attr) {
360 return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
361 });
362 branchOp->setDiscardableAttrs(llvmAttrs);
363
364 rewriter.eraseOp(op: terminator);
365
366 // Compute loop bounds before branching to the condition.
367 rewriter.setInsertionPointToEnd(initBlock);
368 Value lowerBound = forOp.getLowerBound();
369 Value upperBound = forOp.getUpperBound();
370 if (!lowerBound || !upperBound)
371 return failure();
372
373 // The initial values of loop-carried values is obtained from the operands
374 // of the loop operation.
375 SmallVector<Value, 8> destOperands;
376 destOperands.push_back(Elt: lowerBound);
377 llvm::append_range(C&: destOperands, R: forOp.getInitArgs());
378 rewriter.create<cf::BranchOp>(location: loc, args&: conditionBlock, args&: destOperands);
379
380 // With the body block done, we can fill in the condition block.
381 rewriter.setInsertionPointToEnd(conditionBlock);
382 auto comparison = rewriter.create<arith::CmpIOp>(
383 location: loc, args: arith::CmpIPredicate::slt, args&: iv, args&: upperBound);
384
385 rewriter.create<cf::CondBranchOp>(location: loc, args&: comparison, args&: firstBodyBlock,
386 args: ArrayRef<Value>(), args&: endBlock,
387 args: ArrayRef<Value>());
388
389 // The result of the loop operation is the values of the condition block
390 // arguments except the induction variable on the last iteration.
391 rewriter.replaceOp(op: forOp, newValues: conditionBlock->getArguments().drop_front());
392 return success();
393}
394
395LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
396 PatternRewriter &rewriter) const {
397 auto loc = ifOp.getLoc();
398
399 // Start by splitting the block containing the 'scf.if' into two parts.
400 // The part before will contain the condition, the part after will be the
401 // continuation point.
402 auto *condBlock = rewriter.getInsertionBlock();
403 auto opPosition = rewriter.getInsertionPoint();
404 auto *remainingOpsBlock = rewriter.splitBlock(block: condBlock, before: opPosition);
405 Block *continueBlock;
406 if (ifOp.getNumResults() == 0) {
407 continueBlock = remainingOpsBlock;
408 } else {
409 continueBlock =
410 rewriter.createBlock(insertBefore: remainingOpsBlock, argTypes: ifOp.getResultTypes(),
411 locs: SmallVector<Location>(ifOp.getNumResults(), loc));
412 rewriter.create<cf::BranchOp>(location: loc, args&: remainingOpsBlock);
413 }
414
415 // Move blocks from the "then" region to the region containing 'scf.if',
416 // place it before the continuation block, and branch to it.
417 auto &thenRegion = ifOp.getThenRegion();
418 auto *thenBlock = &thenRegion.front();
419 Operation *thenTerminator = thenRegion.back().getTerminator();
420 ValueRange thenTerminatorOperands = thenTerminator->getOperands();
421 rewriter.setInsertionPointToEnd(&thenRegion.back());
422 rewriter.create<cf::BranchOp>(location: loc, args&: continueBlock, args&: thenTerminatorOperands);
423 rewriter.eraseOp(op: thenTerminator);
424 rewriter.inlineRegionBefore(region&: thenRegion, before: continueBlock);
425
426 // Move blocks from the "else" region (if present) to the region containing
427 // 'scf.if', place it before the continuation block and branch to it. It
428 // will be placed after the "then" regions.
429 auto *elseBlock = continueBlock;
430 auto &elseRegion = ifOp.getElseRegion();
431 if (!elseRegion.empty()) {
432 elseBlock = &elseRegion.front();
433 Operation *elseTerminator = elseRegion.back().getTerminator();
434 ValueRange elseTerminatorOperands = elseTerminator->getOperands();
435 rewriter.setInsertionPointToEnd(&elseRegion.back());
436 rewriter.create<cf::BranchOp>(location: loc, args&: continueBlock, args&: elseTerminatorOperands);
437 rewriter.eraseOp(op: elseTerminator);
438 rewriter.inlineRegionBefore(region&: elseRegion, before: continueBlock);
439 }
440
441 rewriter.setInsertionPointToEnd(condBlock);
442 rewriter.create<cf::CondBranchOp>(location: loc, args: ifOp.getCondition(), args&: thenBlock,
443 /*trueArgs=*/args: ArrayRef<Value>(), args&: elseBlock,
444 /*falseArgs=*/args: ArrayRef<Value>());
445
446 // Ok, we're done!
447 rewriter.replaceOp(op: ifOp, newValues: continueBlock->getArguments());
448 return success();
449}
450
451LogicalResult
452ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
453 PatternRewriter &rewriter) const {
454 auto loc = op.getLoc();
455
456 auto *condBlock = rewriter.getInsertionBlock();
457 auto opPosition = rewriter.getInsertionPoint();
458 auto *remainingOpsBlock = rewriter.splitBlock(block: condBlock, before: opPosition);
459
460 auto &region = op.getRegion();
461 rewriter.setInsertionPointToEnd(condBlock);
462 rewriter.create<cf::BranchOp>(location: loc, args: &region.front());
463
464 for (Block &block : region) {
465 if (auto terminator = dyn_cast<scf::YieldOp>(Val: block.getTerminator())) {
466 ValueRange terminatorOperands = terminator->getOperands();
467 rewriter.setInsertionPointToEnd(&block);
468 rewriter.create<cf::BranchOp>(location: loc, args&: remainingOpsBlock, args&: terminatorOperands);
469 rewriter.eraseOp(op: terminator);
470 }
471 }
472
473 rewriter.inlineRegionBefore(region, before: remainingOpsBlock);
474
475 SmallVector<Value> vals;
476 SmallVector<Location> argLocs(op.getNumResults(), op->getLoc());
477 for (auto arg :
478 remainingOpsBlock->addArguments(types: op->getResultTypes(), locs: argLocs))
479 vals.push_back(Elt: arg);
480 rewriter.replaceOp(op, newValues: vals);
481 return success();
482}
483
484LogicalResult
485ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
486 PatternRewriter &rewriter) const {
487 Location loc = parallelOp.getLoc();
488 auto reductionOp = dyn_cast<ReduceOp>(Val: parallelOp.getBody()->getTerminator());
489 if (!reductionOp) {
490 return failure();
491 }
492
493 // For a parallel loop, we essentially need to create an n-dimensional loop
494 // nest. We do this by translating to scf.for ops and have those lowered in
495 // a further rewrite. If a parallel loop contains reductions (and thus returns
496 // values), forward the initial values for the reductions down the loop
497 // hierarchy and bubble up the results by modifying the "yield" terminator.
498 SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(Range: parallelOp.getInitVals());
499 SmallVector<Value, 4> ivs;
500 ivs.reserve(N: parallelOp.getNumLoops());
501 bool first = true;
502 SmallVector<Value, 4> loopResults(iterArgs);
503 for (auto [iv, lower, upper, step] :
504 llvm::zip(t: parallelOp.getInductionVars(), u: parallelOp.getLowerBound(),
505 args: parallelOp.getUpperBound(), args: parallelOp.getStep())) {
506 ForOp forOp = rewriter.create<ForOp>(location: loc, args&: lower, args&: upper, args&: step, args&: iterArgs);
507 ivs.push_back(Elt: forOp.getInductionVar());
508 auto iterRange = forOp.getRegionIterArgs();
509 iterArgs.assign(in_start: iterRange.begin(), in_end: iterRange.end());
510
511 if (first) {
512 // Store the results of the outermost loop that will be used to replace
513 // the results of the parallel loop when it is fully rewritten.
514 loopResults.assign(in_start: forOp.result_begin(), in_end: forOp.result_end());
515 first = false;
516 } else if (!forOp.getResults().empty()) {
517 // A loop is constructed with an empty "yield" terminator if there are
518 // no results.
519 rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
520 rewriter.create<scf::YieldOp>(location: loc, args: forOp.getResults());
521 }
522
523 rewriter.setInsertionPointToStart(forOp.getBody());
524 }
525
526 // First, merge reduction blocks into the main region.
527 SmallVector<Value> yieldOperands;
528 yieldOperands.reserve(N: parallelOp.getNumResults());
529 for (int64_t i = 0, e = parallelOp.getNumResults(); i < e; ++i) {
530 Block &reductionBody = reductionOp.getReductions()[i].front();
531 Value arg = iterArgs[yieldOperands.size()];
532 yieldOperands.push_back(
533 Elt: cast<ReduceReturnOp>(Val: reductionBody.getTerminator()).getResult());
534 rewriter.eraseOp(op: reductionBody.getTerminator());
535 rewriter.inlineBlockBefore(source: &reductionBody, op: reductionOp,
536 argValues: {arg, reductionOp.getOperands()[i]});
537 }
538 rewriter.eraseOp(op: reductionOp);
539
540 // Then merge the loop body without the terminator.
541 Block *newBody = rewriter.getInsertionBlock();
542 if (newBody->empty())
543 rewriter.mergeBlocks(source: parallelOp.getBody(), dest: newBody, argValues: ivs);
544 else
545 rewriter.inlineBlockBefore(source: parallelOp.getBody(), op: newBody->getTerminator(),
546 argValues: ivs);
547
548 // Finally, create the terminator if required (for loops with no results, it
549 // has been already created in loop construction).
550 if (!yieldOperands.empty()) {
551 rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
552 rewriter.create<scf::YieldOp>(location: loc, args&: yieldOperands);
553 }
554
555 rewriter.replaceOp(op: parallelOp, newValues: loopResults);
556
557 return success();
558}
559
560LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
561 PatternRewriter &rewriter) const {
562 OpBuilder::InsertionGuard guard(rewriter);
563 Location loc = whileOp.getLoc();
564
565 // Split the current block before the WhileOp to create the inlining point.
566 Block *currentBlock = rewriter.getInsertionBlock();
567 Block *continuation =
568 rewriter.splitBlock(block: currentBlock, before: rewriter.getInsertionPoint());
569
570 // Inline both regions.
571 Block *after = whileOp.getAfterBody();
572 Block *before = whileOp.getBeforeBody();
573 rewriter.inlineRegionBefore(region&: whileOp.getAfter(), before: continuation);
574 rewriter.inlineRegionBefore(region&: whileOp.getBefore(), before: after);
575
576 // Branch to the "before" region.
577 rewriter.setInsertionPointToEnd(currentBlock);
578 rewriter.create<cf::BranchOp>(location: loc, args&: before, args: whileOp.getInits());
579
580 // Replace terminators with branches. Assuming bodies are SESE, which holds
581 // given only the patterns from this file, we only need to look at the last
582 // block. This should be reconsidered if we allow break/continue in SCF.
583 rewriter.setInsertionPointToEnd(before);
584 auto condOp = cast<ConditionOp>(Val: before->getTerminator());
585 SmallVector<Value> args = llvm::to_vector(Range: condOp.getArgs());
586 rewriter.replaceOpWithNewOp<cf::CondBranchOp>(op: condOp, args: condOp.getCondition(),
587 args&: after, args: condOp.getArgs(),
588 args&: continuation, args: ValueRange());
589
590 rewriter.setInsertionPointToEnd(after);
591 auto yieldOp = cast<scf::YieldOp>(Val: after->getTerminator());
592 rewriter.replaceOpWithNewOp<cf::BranchOp>(op: yieldOp, args&: before,
593 args: yieldOp.getResults());
594
595 // Replace the op with values "yielded" from the "before" region, which are
596 // visible by dominance.
597 rewriter.replaceOp(op: whileOp, newValues: args);
598
599 return success();
600}
601
602LogicalResult
603DoWhileLowering::matchAndRewrite(WhileOp whileOp,
604 PatternRewriter &rewriter) const {
605 Block &afterBlock = *whileOp.getAfterBody();
606 if (!llvm::hasSingleElement(C&: afterBlock))
607 return rewriter.notifyMatchFailure(arg&: whileOp,
608 msg: "do-while simplification applicable "
609 "only if 'after' region has no payload");
610
611 auto yield = dyn_cast<scf::YieldOp>(Val: &afterBlock.front());
612 if (!yield || yield.getResults() != afterBlock.getArguments())
613 return rewriter.notifyMatchFailure(arg&: whileOp,
614 msg: "do-while simplification applicable "
615 "only to forwarding 'after' regions");
616
617 // Split the current block before the WhileOp to create the inlining point.
618 OpBuilder::InsertionGuard guard(rewriter);
619 Block *currentBlock = rewriter.getInsertionBlock();
620 Block *continuation =
621 rewriter.splitBlock(block: currentBlock, before: rewriter.getInsertionPoint());
622
623 // Only the "before" region should be inlined.
624 Block *before = whileOp.getBeforeBody();
625 rewriter.inlineRegionBefore(region&: whileOp.getBefore(), before: continuation);
626
627 // Branch to the "before" region.
628 rewriter.setInsertionPointToEnd(currentBlock);
629 rewriter.create<cf::BranchOp>(location: whileOp.getLoc(), args&: before, args: whileOp.getInits());
630
631 // Loop around the "before" region based on condition.
632 rewriter.setInsertionPointToEnd(before);
633 auto condOp = cast<ConditionOp>(Val: before->getTerminator());
634 rewriter.create<cf::CondBranchOp>(location: condOp.getLoc(), args: condOp.getCondition(),
635 args&: before, args: condOp.getArgs(), args&: continuation,
636 args: ValueRange());
637
638 // Replace the op with values "yielded" from the "before" region, which are
639 // visible by dominance.
640 rewriter.replaceOp(op: whileOp, newValues: condOp.getArgs());
641
642 // Erase the condition op.
643 rewriter.eraseOp(op: condOp);
644 return success();
645}
646
647LogicalResult
648IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
649 PatternRewriter &rewriter) const {
650 // Split the block at the op.
651 Block *condBlock = rewriter.getInsertionBlock();
652 Block *continueBlock = rewriter.splitBlock(block: condBlock, before: Block::iterator(op));
653
654 // Create the arguments on the continue block with which to replace the
655 // results of the op.
656 SmallVector<Value> results;
657 results.reserve(N: op.getNumResults());
658 for (Type resultType : op.getResultTypes())
659 results.push_back(Elt: continueBlock->addArgument(type: resultType, loc: op.getLoc()));
660
661 // Handle the regions.
662 auto convertRegion = [&](Region &region) -> FailureOr<Block *> {
663 Block *block = &region.front();
664
665 // Convert the yield terminator to a branch to the continue block.
666 auto yield = cast<scf::YieldOp>(Val: block->getTerminator());
667 rewriter.setInsertionPoint(yield);
668 rewriter.replaceOpWithNewOp<cf::BranchOp>(op: yield, args&: continueBlock,
669 args: yield.getOperands());
670
671 // Inline the region.
672 rewriter.inlineRegionBefore(region, before: continueBlock);
673 return block;
674 };
675
676 // Convert the case regions.
677 SmallVector<Block *> caseSuccessors;
678 SmallVector<int32_t> caseValues;
679 caseSuccessors.reserve(N: op.getCases().size());
680 caseValues.reserve(N: op.getCases().size());
681 for (auto [region, value] : llvm::zip(t: op.getCaseRegions(), u: op.getCases())) {
682 FailureOr<Block *> block = convertRegion(region);
683 if (failed(Result: block))
684 return failure();
685 caseSuccessors.push_back(Elt: *block);
686 caseValues.push_back(Elt: value);
687 }
688
689 // Convert the default region.
690 FailureOr<Block *> defaultBlock = convertRegion(op.getDefaultRegion());
691 if (failed(Result: defaultBlock))
692 return failure();
693
694 // Create the switch.
695 rewriter.setInsertionPointToEnd(condBlock);
696 SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {});
697
698 // Cast switch index to integer case value.
699 Value caseValue = rewriter.create<arith::IndexCastOp>(
700 location: op.getLoc(), args: rewriter.getI32Type(), args: op.getArg());
701
702 rewriter.create<cf::SwitchOp>(
703 location: op.getLoc(), args&: caseValue, args&: *defaultBlock, args: ValueRange(),
704 args: rewriter.getDenseI32ArrayAttr(values: caseValues), args&: caseSuccessors, args&: caseOperands);
705 rewriter.replaceOp(op, newValues: continueBlock->getArguments());
706 return success();
707}
708
709LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
710 PatternRewriter &rewriter) const {
711 return scf::forallToParallelLoop(rewriter, forallOp);
712}
713
714void mlir::populateSCFToControlFlowConversionPatterns(
715 RewritePatternSet &patterns) {
716 patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering,
717 WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>(
718 arg: patterns.getContext());
719 patterns.add<DoWhileLowering>(arg: patterns.getContext(), /*benefit=*/args: 2);
720}
721
722void SCFToControlFlowPass::runOnOperation() {
723 RewritePatternSet patterns(&getContext());
724 populateSCFToControlFlowConversionPatterns(patterns);
725
726 // Configure conversion to lower out SCF operations.
727 ConversionTarget target(getContext());
728 target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
729 scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>();
730 target.markUnknownOpDynamicallyLegal(fn: [](Operation *) { return true; });
731 if (failed(
732 Result: applyPartialConversion(op: getOperation(), target, patterns: std::move(patterns))))
733 signalPassFailure();
734}
735

source code of mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp