1//===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Transforms scf.for operations into equivalent scf.while operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SCF/Transforms/Passes.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/SCF/IR/SCF.h"
17#include "mlir/Dialect/SCF/Transforms/Transforms.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20
21namespace mlir {
22#define GEN_PASS_DEF_SCFFORTOWHILELOOP
23#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
24} // namespace mlir
25
26using namespace llvm;
27using namespace mlir;
28using scf::ForOp;
29using scf::WhileOp;
30
31namespace {
32
33struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
34 using OpRewritePattern<ForOp>::OpRewritePattern;
35
36 LogicalResult matchAndRewrite(ForOp forOp,
37 PatternRewriter &rewriter) const override {
38 // Generate type signature for the loop-carried values. The induction
39 // variable is placed first, followed by the forOp.iterArgs.
40 SmallVector<Type> lcvTypes;
41 SmallVector<Location> lcvLocs;
42 lcvTypes.push_back(Elt: forOp.getInductionVar().getType());
43 lcvLocs.push_back(Elt: forOp.getInductionVar().getLoc());
44 for (Value value : forOp.getInitArgs()) {
45 lcvTypes.push_back(Elt: value.getType());
46 lcvLocs.push_back(Elt: value.getLoc());
47 }
48
49 // Build scf.WhileOp
50 SmallVector<Value> initArgs;
51 initArgs.push_back(Elt: forOp.getLowerBound());
52 llvm::append_range(C&: initArgs, R: forOp.getInitArgs());
53 auto whileOp = rewriter.create<WhileOp>(location: forOp.getLoc(), args&: lcvTypes, args&: initArgs,
54 args: forOp->getAttrs());
55
56 // 'before' region contains the loop condition and forwarding of iteration
57 // arguments to the 'after' region.
58 auto *beforeBlock = rewriter.createBlock(
59 parent: &whileOp.getBefore(), insertPt: whileOp.getBefore().begin(), argTypes: lcvTypes, locs: lcvLocs);
60 rewriter.setInsertionPointToStart(whileOp.getBeforeBody());
61 auto cmpOp = rewriter.create<arith::CmpIOp>(
62 location: whileOp.getLoc(), args: arith::CmpIPredicate::slt,
63 args: beforeBlock->getArgument(i: 0), args: forOp.getUpperBound());
64 rewriter.create<scf::ConditionOp>(location: whileOp.getLoc(), args: cmpOp.getResult(),
65 args: beforeBlock->getArguments());
66
67 // Inline for-loop body into an executeRegion operation in the "after"
68 // region. The return type of the execRegionOp does not contain the
69 // iv - yields in the source for-loop contain only iterArgs.
70 auto *afterBlock = rewriter.createBlock(
71 parent: &whileOp.getAfter(), insertPt: whileOp.getAfter().begin(), argTypes: lcvTypes, locs: lcvLocs);
72
73 // Add induction variable incrementation
74 rewriter.setInsertionPointToEnd(afterBlock);
75 auto ivIncOp = rewriter.create<arith::AddIOp>(
76 location: whileOp.getLoc(), args: afterBlock->getArgument(i: 0), args: forOp.getStep());
77
78 // Rewrite uses of the for-loop block arguments to the new while-loop
79 // "after" arguments
80 for (const auto &barg : enumerate(First: forOp.getBody(idx: 0)->getArguments()))
81 rewriter.replaceAllUsesWith(from: barg.value(),
82 to: afterBlock->getArgument(i: barg.index()));
83
84 // Inline for-loop body operations into 'after' region.
85 for (auto &arg : llvm::make_early_inc_range(Range&: *forOp.getBody()))
86 rewriter.moveOpBefore(op: &arg, block: afterBlock, iterator: afterBlock->end());
87
88 // Add incremented IV to yield operations
89 for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
90 SmallVector<Value> yieldOperands = yieldOp.getOperands();
91 yieldOperands.insert(I: yieldOperands.begin(), Elt: ivIncOp.getResult());
92 rewriter.modifyOpInPlace(root: yieldOp,
93 callable: [&]() { yieldOp->setOperands(yieldOperands); });
94 }
95
96 // We cannot do a direct replacement of the forOp since the while op returns
97 // an extra value (the induction variable escapes the loop through being
98 // carried in the set of iterargs). Instead, rewrite uses of the forOp
99 // results.
100 for (const auto &arg : llvm::enumerate(First: forOp.getResults()))
101 rewriter.replaceAllUsesWith(from: arg.value(),
102 to: whileOp.getResult(i: arg.index() + 1));
103
104 rewriter.eraseOp(op: forOp);
105 return success();
106 }
107};
108
109struct ForToWhileLoop : public impl::SCFForToWhileLoopBase<ForToWhileLoop> {
110 void runOnOperation() override {
111 auto *parentOp = getOperation();
112 MLIRContext *ctx = parentOp->getContext();
113 RewritePatternSet patterns(ctx);
114 patterns.add<ForLoopLoweringPattern>(arg&: ctx);
115 (void)applyPatternsGreedily(op: parentOp, patterns: std::move(patterns));
116 }
117};
118} // namespace
119
120std::unique_ptr<Pass> mlir::createForToWhileLoopPass() {
121 return std::make_unique<ForToWhileLoop>();
122}
123

source code of mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp