1//===- ForToWhile.cpp - scf.for to scf.while loop conversion --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Transforms scf.for operations into equivalent scf.while operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SCF/Transforms/Passes.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/SCF/IR/SCF.h"
17#include "mlir/Dialect/SCF/Transforms/Transforms.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20
21namespace mlir {
22#define GEN_PASS_DEF_SCFFORTOWHILELOOP
23#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
24} // namespace mlir
25
26using namespace llvm;
27using namespace mlir;
28using scf::ForOp;
29using scf::WhileOp;
30
31namespace {
32
33struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
34 using OpRewritePattern<ForOp>::OpRewritePattern;
35
36 LogicalResult matchAndRewrite(ForOp forOp,
37 PatternRewriter &rewriter) const override {
38 // Generate type signature for the loop-carried values. The induction
39 // variable is placed first, followed by the forOp.iterArgs.
40 SmallVector<Type> lcvTypes;
41 SmallVector<Location> lcvLocs;
42 lcvTypes.push_back(forOp.getInductionVar().getType());
43 lcvLocs.push_back(forOp.getInductionVar().getLoc());
44 for (Value value : forOp.getInitArgs()) {
45 lcvTypes.push_back(value.getType());
46 lcvLocs.push_back(value.getLoc());
47 }
48
49 // Build scf.WhileOp
50 SmallVector<Value> initArgs;
51 initArgs.push_back(forOp.getLowerBound());
52 llvm::append_range(initArgs, forOp.getInitArgs());
53 auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs,
54 forOp->getAttrs());
55
56 // 'before' region contains the loop condition and forwarding of iteration
57 // arguments to the 'after' region.
58 auto *beforeBlock = rewriter.createBlock(
59 &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
60 rewriter.setInsertionPointToStart(whileOp.getBeforeBody());
61 auto cmpOp = rewriter.create<arith::CmpIOp>(
62 whileOp.getLoc(), arith::CmpIPredicate::slt,
63 beforeBlock->getArgument(0), forOp.getUpperBound());
64 rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(),
65 beforeBlock->getArguments());
66
67 // Inline for-loop body into an executeRegion operation in the "after"
68 // region. The return type of the execRegionOp does not contain the
69 // iv - yields in the source for-loop contain only iterArgs.
70 auto *afterBlock = rewriter.createBlock(
71 &whileOp.getAfter(), whileOp.getAfter().begin(), lcvTypes, lcvLocs);
72
73 // Add induction variable incrementation
74 rewriter.setInsertionPointToEnd(afterBlock);
75 auto ivIncOp = rewriter.create<arith::AddIOp>(
76 whileOp.getLoc(), afterBlock->getArgument(0), forOp.getStep());
77
78 // Rewrite uses of the for-loop block arguments to the new while-loop
79 // "after" arguments
80 for (const auto &barg : enumerate(forOp.getBody(0)->getArguments()))
81 rewriter.replaceAllUsesWith(barg.value(),
82 afterBlock->getArgument(barg.index()));
83
84 // Inline for-loop body operations into 'after' region.
85 for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
86 rewriter.moveOpBefore(&arg, afterBlock, afterBlock->end());
87
88 // Add incremented IV to yield operations
89 for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
90 SmallVector<Value> yieldOperands = yieldOp.getOperands();
91 yieldOperands.insert(yieldOperands.begin(), ivIncOp.getResult());
92 rewriter.modifyOpInPlace(yieldOp,
93 [&]() { yieldOp->setOperands(yieldOperands); });
94 }
95
96 // We cannot do a direct replacement of the forOp since the while op returns
97 // an extra value (the induction variable escapes the loop through being
98 // carried in the set of iterargs). Instead, rewrite uses of the forOp
99 // results.
100 for (const auto &arg : llvm::enumerate(forOp.getResults()))
101 rewriter.replaceAllUsesWith(arg.value(),
102 whileOp.getResult(arg.index() + 1));
103
104 rewriter.eraseOp(op: forOp);
105 return success();
106 }
107};
108
109struct ForToWhileLoop : public impl::SCFForToWhileLoopBase<ForToWhileLoop> {
110 void runOnOperation() override {
111 auto *parentOp = getOperation();
112 MLIRContext *ctx = parentOp->getContext();
113 RewritePatternSet patterns(ctx);
114 patterns.add<ForLoopLoweringPattern>(ctx);
115 (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
116 }
117};
118} // namespace
119
120std::unique_ptr<Pass> mlir::createForToWhileLoopPass() {
121 return std::make_unique<ForToWhileLoop>();
122}
123

// Source: mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp