1//===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a pass to convert SCF dialect ops (scf.for, scf.if,
// scf.while, scf.parallel, scf.forall, scf.execute_region and
// scf.index_switch) into ControlFlow (cf) dialect CFG ops.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
15
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
18#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19#include "mlir/Dialect/SCF/IR/SCF.h"
20#include "mlir/Dialect/SCF/Transforms/Transforms.h"
21#include "mlir/IR/Builders.h"
22#include "mlir/IR/BuiltinOps.h"
23#include "mlir/IR/IRMapping.h"
24#include "mlir/IR/MLIRContext.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/Transforms/DialectConversion.h"
27#include "mlir/Transforms/Passes.h"
28
29namespace mlir {
30#define GEN_PASS_DEF_SCFTOCONTROLFLOWPASS
31#include "mlir/Conversion/Passes.h.inc"
32} // namespace mlir
33
34using namespace mlir;
35using namespace mlir::scf;
36
37namespace {
38
39struct SCFToControlFlowPass
40 : public impl::SCFToControlFlowPassBase<SCFToControlFlowPass> {
41 void runOnOperation() override;
42};
43
44// Create a CFG subgraph for the loop around its body blocks (if the body
45// contained other loops, they have been already lowered to a flow of blocks).
46// Maintain the invariants that a CFG subgraph created for any loop has a single
47// entry and a single exit, and that the entry/exit blocks are respectively
48// first/last blocks in the parent region. The original loop operation is
49// replaced by the initialization operations that set up the initial value of
// the loop induction variable (%iv) and compute the loop bounds that are loop-
51// invariant for affine loops. The operations following the original scf.for
52// are split out into a separate continuation (exit) block. A condition block is
53// created before the continuation block. It checks the exit condition of the
54// loop and branches either to the continuation block, or to the first block of
55// the body. The condition block takes as arguments the values of the induction
56// variable followed by loop-carried values. Since it dominates both the body
57// blocks and the continuation block, loop-carried values are visible in all of
58// those blocks. Induction variable modification is appended to the last block
59// of the body (which is the exit block from the body subgraph thanks to the
60// invariant we maintain) along with a branch that loops back to the condition
61// block. Loop-carried values are the loop terminator operands, which are
62// forwarded to the branch.
63//
//      +---------------------------------+
//      |   <code before the ForOp>       |
//      |   <definitions of %init...>     |
//      |   <compute initial %iv value>   |
//      |   cf.br cond(%iv, %init...)     |
//      +---------------------------------+
//             |
//  -------|   |
//  |      v   v
//  |   +--------------------------------+
//  |   | cond(%iv, %init...):           |
//  |   |   <compare %iv to upper bound> |
//  |   |   cf.cond_br %r, body, end     |
//  |   +--------------------------------+
//  |          |               |
//  |          |               -------------|
//  |          v                            |
//  |   +--------------------------------+  |
//  |   | body-first:                    |  |
//  |   |   <%init visible by dominance> |  |
//  |   |   <body contents>              |  |
//  |   +--------------------------------+  |
//  |                   |                   |
//  |                  ...                  |
//  |                   |                   |
//  |   +--------------------------------+  |
//  |   | body-last:                     |  |
//  |   |   <body contents>              |  |
//  |   |   <operands of yield = %yields>|  |
//  |   |   %new_iv =<add step to %iv>   |  |
//  |   |   cf.br cond(%new_iv, %yields) |  |
//  |   +--------------------------------+  |
//  |          |                            |
//  |-----------        |--------------------
//                      v
//      +--------------------------------+
//      | end:                           |
//      |   <code after the ForOp>       |
//      |   <%init visible by dominance> |
//      +--------------------------------+
104//
105struct ForLowering : public OpRewritePattern<ForOp> {
106 using OpRewritePattern<ForOp>::OpRewritePattern;
107
108 LogicalResult matchAndRewrite(ForOp forOp,
109 PatternRewriter &rewriter) const override;
110};
111
112// Create a CFG subgraph for the scf.if operation (including its "then" and
113// optional "else" operation blocks). We maintain the invariants that the
114// subgraph has a single entry and a single exit point, and that the entry/exit
115// blocks are respectively the first/last block of the enclosing region. The
116// operations following the scf.if are split into a continuation (subgraph
117// exit) block. The condition is lowered to a chain of blocks that implement the
118// short-circuit scheme. The "scf.if" operation is replaced with a conditional
119// branch to either the first block of the "then" region, or to the first block
// of the "else" region. In these blocks, "scf.yield" is replaced with an
// unconditional branch to the post-dominating block. When the "scf.if" does
// not return values, the
122// post-dominating block is the same as the continuation block. When it returns
123// values, the post-dominating block is a new block with arguments that
124// correspond to the values returned by the "scf.if" that unconditionally
125// branches to the continuation block. This allows block arguments to dominate
126// any uses of the hitherto "scf.if" results that they replaced. (Inserting a
127// new block allows us to avoid modifying the argument list of an existing
128// block, which is illegal in a conversion pattern). When the "else" region is
129// empty, which is only allowed for "scf.if"s that don't return values, the
130// condition branches directly to the continuation block.
131//
132// CFG for a scf.if with else and without results.
133//
//      +--------------------------------+
//      | <code before the IfOp>         |
//      | cf.cond_br %cond, %then, %else |
//      +--------------------------------+
//             |              |
//             |              --------------|
//             v                            |
//      +--------------------------------+  |
//      | then:                          |  |
//      |   <then contents>              |  |
//      |   cf.br continue               |  |
//      +--------------------------------+  |
//             |                            |
//   |----------               |-------------
//   |                         V
//   |  +--------------------------------+
//   |  | else:                          |
//   |  |   <else contents>              |
//   |  |   cf.br continue               |
//   |  +--------------------------------+
//   |         |
//   ------|   |
//         v   v
//      +--------------------------------+
//      | continue:                      |
//      |   <code after the IfOp>        |
//      +--------------------------------+
161//
162// CFG for a scf.if with results.
163//
//      +--------------------------------+
//      | <code before the IfOp>         |
//      | cf.cond_br %cond, %then, %else |
//      +--------------------------------+
//             |              |
//             |              --------------|
//             v                            |
//      +--------------------------------+  |
//      | then:                          |  |
//      |   <then contents>              |  |
//      |   cf.br dom(%args...)          |  |
//      +--------------------------------+  |
//             |                            |
//   |----------               |-------------
//   |                         V
//   |  +--------------------------------+
//   |  | else:                          |
//   |  |   <else contents>              |
//   |  |   cf.br dom(%args...)          |
//   |  +--------------------------------+
//   |         |
//   ------|   |
//         v   v
//      +--------------------------------+
//      | dom(%args...):                 |
//      |   cf.br continue               |
//      +--------------------------------+
//             |
//             v
//      +--------------------------------+
//      | continue:                      |
//      |   <code after the IfOp>        |
//      +--------------------------------+
197//
198struct IfLowering : public OpRewritePattern<IfOp> {
199 using OpRewritePattern<IfOp>::OpRewritePattern;
200
201 LogicalResult matchAndRewrite(IfOp ifOp,
202 PatternRewriter &rewriter) const override;
203};
204
205struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
206 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
207
208 LogicalResult matchAndRewrite(ExecuteRegionOp op,
209 PatternRewriter &rewriter) const override;
210};
211
212struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
213 using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
214
215 LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
216 PatternRewriter &rewriter) const override;
217};
218
219/// Create a CFG subgraph for this loop construct. The regions of the loop need
220/// not be a single block anymore (for example, if other SCF constructs that
221/// they contain have been already converted to CFG), but need to be single-exit
222/// from the last block of each region. The operations following the original
223/// WhileOp are split into a new continuation block. Both regions of the WhileOp
224/// are inlined, and their terminators are rewritten to organize the control
225/// flow implementing the loop as follows.
226///
///      +---------------------------------+
///      |   <code before the WhileOp>     |
///      |   cf.br ^before(%operands...)   |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | ^before(%bargs...):            |
///  |   |   %vals... = <some payload>    |
///  |   +--------------------------------+
///  |                   |
///  |                  ...
///  |                   |
///  |   +--------------------------------+
///  |   | ^before-last:                  |
///  |   |   %cond = <compute condition>  |
///  |   |   cf.cond_br %cond,            |
///  |   |        ^after(%vals...), ^cont |
///  |   +--------------------------------+
///  |          |               |
///  |          |               -------------|
///  |          v                            |
///  |   +--------------------------------+  |
///  |   | ^after(%aargs...):             |  |
///  |   |   <body contents>              |  |
///  |   +--------------------------------+  |
///  |                   |                   |
///  |                  ...                  |
///  |                   |                   |
///  |   +--------------------------------+  |
///  |   | ^after-last:                   |  |
///  |   |   %yields... = <some payload>  |  |
///  |   |   cf.br ^before(%yields...)    |  |
///  |   +--------------------------------+  |
///  |          |                            |
///  |-----------        |--------------------
///                      v
///      +--------------------------------+
///      | ^cont:                         |
///      |   <code after the WhileOp>     |
///      |   <%vals from 'before' region  |
///      |          visible by dominance> |
///      +--------------------------------+
271///
272/// Values are communicated between ex-regions (the groups of blocks that used
273/// to form a region before inlining) through block arguments of their
274/// entry blocks, which are visible in all other dominated blocks. Similarly,
275/// the results of the WhileOp are defined in the 'before' region, which is
276/// required to have a single existing block, and are therefore accessible in
277/// the continuation block due to dominance.
278struct WhileLowering : public OpRewritePattern<WhileOp> {
279 using OpRewritePattern<WhileOp>::OpRewritePattern;
280
281 LogicalResult matchAndRewrite(WhileOp whileOp,
282 PatternRewriter &rewriter) const override;
283};
284
285/// Optimized version of the above for the case of the "after" region merely
286/// forwarding its arguments back to the "before" region (i.e., a "do-while"
/// loop). This avoids inlining the "after" region completely and branches back
288/// to the "before" entry instead.
289struct DoWhileLowering : public OpRewritePattern<WhileOp> {
290 using OpRewritePattern<WhileOp>::OpRewritePattern;
291
292 LogicalResult matchAndRewrite(WhileOp whileOp,
293 PatternRewriter &rewriter) const override;
294};
295
296/// Lower an `scf.index_switch` operation to a `cf.switch` operation.
297struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> {
298 using OpRewritePattern::OpRewritePattern;
299
300 LogicalResult matchAndRewrite(IndexSwitchOp op,
301 PatternRewriter &rewriter) const override;
302};
303
304/// Lower an `scf.forall` operation to an `scf.parallel` op, assuming that it
305/// has no shared outputs. Ops with shared outputs should be bufferized first.
306/// Specialized lowerings for `scf.forall` (e.g., for GPUs) exist in other
307/// dialects/passes.
308struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
309 using OpRewritePattern<mlir::scf::ForallOp>::OpRewritePattern;
310
311 LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp,
312 PatternRewriter &rewriter) const override;
313};
314
315} // namespace
316
317LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
318 PatternRewriter &rewriter) const {
319 Location loc = forOp.getLoc();
320
321 // Start by splitting the block containing the 'scf.for' into two parts.
322 // The part before will get the init code, the part after will be the end
323 // point.
324 auto *initBlock = rewriter.getInsertionBlock();
325 auto initPosition = rewriter.getInsertionPoint();
326 auto *endBlock = rewriter.splitBlock(block: initBlock, before: initPosition);
327
328 // Use the first block of the loop body as the condition block since it is the
329 // block that has the induction variable and loop-carried values as arguments.
330 // Split out all operations from the first block into a new block. Move all
331 // body blocks from the loop body region to the region containing the loop.
332 auto *conditionBlock = &forOp.getRegion().front();
333 auto *firstBodyBlock =
334 rewriter.splitBlock(block: conditionBlock, before: conditionBlock->begin());
335 auto *lastBodyBlock = &forOp.getRegion().back();
336 rewriter.inlineRegionBefore(forOp.getRegion(), endBlock);
337 auto iv = conditionBlock->getArgument(0);
338
339 // Append the induction variable stepping logic to the last body block and
340 // branch back to the condition block. Loop-carried values are taken from
341 // operands of the loop terminator.
342 Operation *terminator = lastBodyBlock->getTerminator();
343 rewriter.setInsertionPointToEnd(lastBodyBlock);
344 auto step = forOp.getStep();
345 auto stepped = rewriter.create<arith::AddIOp>(loc, iv, step).getResult();
346 if (!stepped)
347 return failure();
348
349 SmallVector<Value, 8> loopCarried;
350 loopCarried.push_back(Elt: stepped);
351 loopCarried.append(in_start: terminator->operand_begin(), in_end: terminator->operand_end());
352 rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried);
353 rewriter.eraseOp(op: terminator);
354
355 // Compute loop bounds before branching to the condition.
356 rewriter.setInsertionPointToEnd(initBlock);
357 Value lowerBound = forOp.getLowerBound();
358 Value upperBound = forOp.getUpperBound();
359 if (!lowerBound || !upperBound)
360 return failure();
361
362 // The initial values of loop-carried values is obtained from the operands
363 // of the loop operation.
364 SmallVector<Value, 8> destOperands;
365 destOperands.push_back(Elt: lowerBound);
366 llvm::append_range(destOperands, forOp.getInitArgs());
367 rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands);
368
369 // With the body block done, we can fill in the condition block.
370 rewriter.setInsertionPointToEnd(conditionBlock);
371 auto comparison = rewriter.create<arith::CmpIOp>(
372 loc, arith::CmpIPredicate::slt, iv, upperBound);
373
374 auto condBranchOp = rewriter.create<cf::CondBranchOp>(
375 loc, comparison, firstBodyBlock, ArrayRef<Value>(), endBlock,
376 ArrayRef<Value>());
377
378 // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
379 // llvm.loop_annotation attribute.
380 SmallVector<NamedAttribute> llvmAttrs;
381 llvm::copy_if(forOp->getAttrs(), std::back_inserter(x&: llvmAttrs),
382 [](auto attr) {
383 return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
384 });
385 condBranchOp->setDiscardableAttrs(llvmAttrs);
386 // The result of the loop operation is the values of the condition block
387 // arguments except the induction variable on the last iteration.
388 rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
389 return success();
390}
391
392LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
393 PatternRewriter &rewriter) const {
394 auto loc = ifOp.getLoc();
395
396 // Start by splitting the block containing the 'scf.if' into two parts.
397 // The part before will contain the condition, the part after will be the
398 // continuation point.
399 auto *condBlock = rewriter.getInsertionBlock();
400 auto opPosition = rewriter.getInsertionPoint();
401 auto *remainingOpsBlock = rewriter.splitBlock(block: condBlock, before: opPosition);
402 Block *continueBlock;
403 if (ifOp.getNumResults() == 0) {
404 continueBlock = remainingOpsBlock;
405 } else {
406 continueBlock =
407 rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(),
408 SmallVector<Location>(ifOp.getNumResults(), loc));
409 rewriter.create<cf::BranchOp>(loc, remainingOpsBlock);
410 }
411
412 // Move blocks from the "then" region to the region containing 'scf.if',
413 // place it before the continuation block, and branch to it.
414 auto &thenRegion = ifOp.getThenRegion();
415 auto *thenBlock = &thenRegion.front();
416 Operation *thenTerminator = thenRegion.back().getTerminator();
417 ValueRange thenTerminatorOperands = thenTerminator->getOperands();
418 rewriter.setInsertionPointToEnd(&thenRegion.back());
419 rewriter.create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands);
420 rewriter.eraseOp(op: thenTerminator);
421 rewriter.inlineRegionBefore(thenRegion, continueBlock);
422
423 // Move blocks from the "else" region (if present) to the region containing
424 // 'scf.if', place it before the continuation block and branch to it. It
425 // will be placed after the "then" regions.
426 auto *elseBlock = continueBlock;
427 auto &elseRegion = ifOp.getElseRegion();
428 if (!elseRegion.empty()) {
429 elseBlock = &elseRegion.front();
430 Operation *elseTerminator = elseRegion.back().getTerminator();
431 ValueRange elseTerminatorOperands = elseTerminator->getOperands();
432 rewriter.setInsertionPointToEnd(&elseRegion.back());
433 rewriter.create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands);
434 rewriter.eraseOp(op: elseTerminator);
435 rewriter.inlineRegionBefore(elseRegion, continueBlock);
436 }
437
438 rewriter.setInsertionPointToEnd(condBlock);
439 rewriter.create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock,
440 /*trueArgs=*/ArrayRef<Value>(), elseBlock,
441 /*falseArgs=*/ArrayRef<Value>());
442
443 // Ok, we're done!
444 rewriter.replaceOp(ifOp, continueBlock->getArguments());
445 return success();
446}
447
448LogicalResult
449ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
450 PatternRewriter &rewriter) const {
451 auto loc = op.getLoc();
452
453 auto *condBlock = rewriter.getInsertionBlock();
454 auto opPosition = rewriter.getInsertionPoint();
455 auto *remainingOpsBlock = rewriter.splitBlock(block: condBlock, before: opPosition);
456
457 auto &region = op.getRegion();
458 rewriter.setInsertionPointToEnd(condBlock);
459 rewriter.create<cf::BranchOp>(loc, &region.front());
460
461 for (Block &block : region) {
462 if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
463 ValueRange terminatorOperands = terminator->getOperands();
464 rewriter.setInsertionPointToEnd(&block);
465 rewriter.create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands);
466 rewriter.eraseOp(terminator);
467 }
468 }
469
470 rewriter.inlineRegionBefore(region, remainingOpsBlock);
471
472 SmallVector<Value> vals;
473 SmallVector<Location> argLocs(op.getNumResults(), op->getLoc());
474 for (auto arg :
475 remainingOpsBlock->addArguments(op->getResultTypes(), argLocs))
476 vals.push_back(arg);
477 rewriter.replaceOp(op, vals);
478 return success();
479}
480
481LogicalResult
482ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
483 PatternRewriter &rewriter) const {
484 Location loc = parallelOp.getLoc();
485 auto reductionOp = dyn_cast<ReduceOp>(parallelOp.getBody()->getTerminator());
486 if (!reductionOp) {
487 return failure();
488 }
489
490 // For a parallel loop, we essentially need to create an n-dimensional loop
491 // nest. We do this by translating to scf.for ops and have those lowered in
492 // a further rewrite. If a parallel loop contains reductions (and thus returns
493 // values), forward the initial values for the reductions down the loop
494 // hierarchy and bubble up the results by modifying the "yield" terminator.
495 SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.getInitVals());
496 SmallVector<Value, 4> ivs;
497 ivs.reserve(N: parallelOp.getNumLoops());
498 bool first = true;
499 SmallVector<Value, 4> loopResults(iterArgs);
500 for (auto [iv, lower, upper, step] :
501 llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(),
502 parallelOp.getUpperBound(), parallelOp.getStep())) {
503 ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs);
504 ivs.push_back(forOp.getInductionVar());
505 auto iterRange = forOp.getRegionIterArgs();
506 iterArgs.assign(iterRange.begin(), iterRange.end());
507
508 if (first) {
509 // Store the results of the outermost loop that will be used to replace
510 // the results of the parallel loop when it is fully rewritten.
511 loopResults.assign(forOp.result_begin(), forOp.result_end());
512 first = false;
513 } else if (!forOp.getResults().empty()) {
514 // A loop is constructed with an empty "yield" terminator if there are
515 // no results.
516 rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
517 rewriter.create<scf::YieldOp>(loc, forOp.getResults());
518 }
519
520 rewriter.setInsertionPointToStart(forOp.getBody());
521 }
522
523 // First, merge reduction blocks into the main region.
524 SmallVector<Value> yieldOperands;
525 yieldOperands.reserve(N: parallelOp.getNumResults());
526 for (int64_t i = 0, e = parallelOp.getNumResults(); i < e; ++i) {
527 Block &reductionBody = reductionOp.getReductions()[i].front();
528 Value arg = iterArgs[yieldOperands.size()];
529 yieldOperands.push_back(
530 cast<ReduceReturnOp>(reductionBody.getTerminator()).getResult());
531 rewriter.eraseOp(op: reductionBody.getTerminator());
532 rewriter.inlineBlockBefore(&reductionBody, reductionOp,
533 {arg, reductionOp.getOperands()[i]});
534 }
535 rewriter.eraseOp(op: reductionOp);
536
537 // Then merge the loop body without the terminator.
538 Block *newBody = rewriter.getInsertionBlock();
539 if (newBody->empty())
540 rewriter.mergeBlocks(source: parallelOp.getBody(), dest: newBody, argValues: ivs);
541 else
542 rewriter.inlineBlockBefore(parallelOp.getBody(), newBody->getTerminator(),
543 ivs);
544
545 // Finally, create the terminator if required (for loops with no results, it
546 // has been already created in loop construction).
547 if (!yieldOperands.empty()) {
548 rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
549 rewriter.create<scf::YieldOp>(loc, yieldOperands);
550 }
551
552 rewriter.replaceOp(parallelOp, loopResults);
553
554 return success();
555}
556
557LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
558 PatternRewriter &rewriter) const {
559 OpBuilder::InsertionGuard guard(rewriter);
560 Location loc = whileOp.getLoc();
561
562 // Split the current block before the WhileOp to create the inlining point.
563 Block *currentBlock = rewriter.getInsertionBlock();
564 Block *continuation =
565 rewriter.splitBlock(block: currentBlock, before: rewriter.getInsertionPoint());
566
567 // Inline both regions.
568 Block *after = whileOp.getAfterBody();
569 Block *before = whileOp.getBeforeBody();
570 rewriter.inlineRegionBefore(whileOp.getAfter(), continuation);
571 rewriter.inlineRegionBefore(whileOp.getBefore(), after);
572
573 // Branch to the "before" region.
574 rewriter.setInsertionPointToEnd(currentBlock);
575 rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits());
576
577 // Replace terminators with branches. Assuming bodies are SESE, which holds
578 // given only the patterns from this file, we only need to look at the last
579 // block. This should be reconsidered if we allow break/continue in SCF.
580 rewriter.setInsertionPointToEnd(before);
581 auto condOp = cast<ConditionOp>(before->getTerminator());
582 rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
583 after, condOp.getArgs(),
584 continuation, ValueRange());
585
586 rewriter.setInsertionPointToEnd(after);
587 auto yieldOp = cast<scf::YieldOp>(after->getTerminator());
588 rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
589 yieldOp.getResults());
590
591 // Replace the op with values "yielded" from the "before" region, which are
592 // visible by dominance.
593 rewriter.replaceOp(whileOp, condOp.getArgs());
594
595 return success();
596}
597
598LogicalResult
599DoWhileLowering::matchAndRewrite(WhileOp whileOp,
600 PatternRewriter &rewriter) const {
601 Block &afterBlock = *whileOp.getAfterBody();
602 if (!llvm::hasSingleElement(C&: afterBlock))
603 return rewriter.notifyMatchFailure(whileOp,
604 "do-while simplification applicable "
605 "only if 'after' region has no payload");
606
607 auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front());
608 if (!yield || yield.getResults() != afterBlock.getArguments())
609 return rewriter.notifyMatchFailure(whileOp,
610 "do-while simplification applicable "
611 "only to forwarding 'after' regions");
612
613 // Split the current block before the WhileOp to create the inlining point.
614 OpBuilder::InsertionGuard guard(rewriter);
615 Block *currentBlock = rewriter.getInsertionBlock();
616 Block *continuation =
617 rewriter.splitBlock(block: currentBlock, before: rewriter.getInsertionPoint());
618
619 // Only the "before" region should be inlined.
620 Block *before = whileOp.getBeforeBody();
621 rewriter.inlineRegionBefore(whileOp.getBefore(), continuation);
622
623 // Branch to the "before" region.
624 rewriter.setInsertionPointToEnd(currentBlock);
625 rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits());
626
627 // Loop around the "before" region based on condition.
628 rewriter.setInsertionPointToEnd(before);
629 auto condOp = cast<ConditionOp>(before->getTerminator());
630 rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(),
631 before, condOp.getArgs(),
632 continuation, ValueRange());
633
634 // Replace the op with values "yielded" from the "before" region, which are
635 // visible by dominance.
636 rewriter.replaceOp(whileOp, condOp.getArgs());
637
638 return success();
639}
640
641LogicalResult
642IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,
643 PatternRewriter &rewriter) const {
644 // Split the block at the op.
645 Block *condBlock = rewriter.getInsertionBlock();
646 Block *continueBlock = rewriter.splitBlock(block: condBlock, before: Block::iterator(op));
647
648 // Create the arguments on the continue block with which to replace the
649 // results of the op.
650 SmallVector<Value> results;
651 results.reserve(N: op.getNumResults());
652 for (Type resultType : op.getResultTypes())
653 results.push_back(continueBlock->addArgument(resultType, op.getLoc()));
654
655 // Handle the regions.
656 auto convertRegion = [&](Region &region) -> FailureOr<Block *> {
657 Block *block = &region.front();
658
659 // Convert the yield terminator to a branch to the continue block.
660 auto yield = cast<scf::YieldOp>(block->getTerminator());
661 rewriter.setInsertionPoint(yield);
662 rewriter.replaceOpWithNewOp<cf::BranchOp>(yield, continueBlock,
663 yield.getOperands());
664
665 // Inline the region.
666 rewriter.inlineRegionBefore(region, before: continueBlock);
667 return block;
668 };
669
670 // Convert the case regions.
671 SmallVector<Block *> caseSuccessors;
672 SmallVector<int32_t> caseValues;
673 caseSuccessors.reserve(N: op.getCases().size());
674 caseValues.reserve(N: op.getCases().size());
675 for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) {
676 FailureOr<Block *> block = convertRegion(region);
677 if (failed(block))
678 return failure();
679 caseSuccessors.push_back(*block);
680 caseValues.push_back(value);
681 }
682
683 // Convert the default region.
684 FailureOr<Block *> defaultBlock = convertRegion(op.getDefaultRegion());
685 if (failed(Result: defaultBlock))
686 return failure();
687
688 // Create the switch.
689 rewriter.setInsertionPointToEnd(condBlock);
690 SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {});
691
692 // Cast switch index to integer case value.
693 Value caseValue = rewriter.create<arith::IndexCastOp>(
694 op.getLoc(), rewriter.getI32Type(), op.getArg());
695
696 rewriter.create<cf::SwitchOp>(
697 op.getLoc(), caseValue, *defaultBlock, ValueRange(),
698 rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands);
699 rewriter.replaceOp(op, continueBlock->getArguments());
700 return success();
701}
702
703LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
704 PatternRewriter &rewriter) const {
705 return scf::forallToParallelLoop(rewriter, forallOp: forallOp);
706}
707
708void mlir::populateSCFToControlFlowConversionPatterns(
709 RewritePatternSet &patterns) {
710 patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering,
711 WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>(
712 arg: patterns.getContext());
713 patterns.add<DoWhileLowering>(arg: patterns.getContext(), /*benefit=*/args: 2);
714}
715
716void SCFToControlFlowPass::runOnOperation() {
717 RewritePatternSet patterns(&getContext());
718 populateSCFToControlFlowConversionPatterns(patterns);
719
720 // Configure conversion to lower out SCF operations.
721 ConversionTarget target(getContext());
722 target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
723 scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>();
724 target.markUnknownOpDynamicallyLegal(fn: [](Operation *) { return true; });
725 if (failed(
726 applyPartialConversion(getOperation(), target, std::move(patterns))))
727 signalPassFailure();
728}
729

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp