1 | //===-- FIRToSCF.cpp ------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
10 | #include "flang/Optimizer/Transforms/Passes.h" |
11 | #include "mlir/Dialect/SCF/IR/SCF.h" |
12 | #include "mlir/Transforms/DialectConversion.h" |
13 | |
14 | namespace fir { |
15 | #define GEN_PASS_DEF_FIRTOSCFPASS |
16 | #include "flang/Optimizer/Transforms/Passes.h.inc" |
17 | } // namespace fir |
18 | |
19 | using namespace fir; |
20 | using namespace mlir; |
21 | |
22 | namespace { |
23 | class FIRToSCFPass : public fir::impl::FIRToSCFPassBase<FIRToSCFPass> { |
24 | public: |
25 | void runOnOperation() override; |
26 | }; |
27 | |
28 | struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> { |
29 | using OpRewritePattern<fir::DoLoopOp>::OpRewritePattern; |
30 | |
31 | LogicalResult matchAndRewrite(fir::DoLoopOp doLoopOp, |
32 | PatternRewriter &rewriter) const override { |
33 | auto loc = doLoopOp.getLoc(); |
34 | bool hasFinalValue = doLoopOp.getFinalValue().has_value(); |
35 | |
36 | // Get loop values from the DoLoopOp |
37 | auto low = doLoopOp.getLowerBound(); |
38 | auto high = doLoopOp.getUpperBound(); |
39 | assert(low && high && "must be a Value" ); |
40 | auto step = doLoopOp.getStep(); |
41 | llvm::SmallVector<Value> iterArgs; |
42 | if (hasFinalValue) |
43 | iterArgs.push_back(low); |
44 | iterArgs.append(doLoopOp.getIterOperands().begin(), |
45 | doLoopOp.getIterOperands().end()); |
46 | |
47 | // fir.do_loop iterates over the interval [%l, %u], and the step may be |
48 | // negative. But scf.for iterates over the interval [%l, %u), and the step |
49 | // must be a positive value. |
50 | // For easier conversion, we calculate the trip count and use a canonical |
51 | // induction variable. |
52 | auto diff = rewriter.create<arith::SubIOp>(loc, high, low); |
53 | auto distance = rewriter.create<arith::AddIOp>(loc, diff, step); |
54 | auto tripCount = rewriter.create<arith::DivSIOp>(loc, distance, step); |
55 | auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); |
56 | auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1); |
57 | auto scfForOp = |
58 | rewriter.create<scf::ForOp>(loc, zero, tripCount, one, iterArgs); |
59 | |
60 | auto &loopOps = doLoopOp.getBody()->getOperations(); |
61 | auto resultOp = cast<fir::ResultOp>(doLoopOp.getBody()->getTerminator()); |
62 | auto results = resultOp.getOperands(); |
63 | Block *loweredBody = scfForOp.getBody(); |
64 | |
65 | loweredBody->getOperations().splice(loweredBody->begin(), loopOps, |
66 | loopOps.begin(), |
67 | std::prev(loopOps.end())); |
68 | |
69 | rewriter.setInsertionPointToStart(loweredBody); |
70 | Value iv = |
71 | rewriter.create<arith::MulIOp>(loc, scfForOp.getInductionVar(), step); |
72 | iv = rewriter.create<arith::AddIOp>(loc, low, iv); |
73 | |
74 | if (!results.empty()) { |
75 | rewriter.setInsertionPointToEnd(loweredBody); |
76 | rewriter.create<scf::YieldOp>(resultOp->getLoc(), results); |
77 | } |
78 | doLoopOp.getInductionVar().replaceAllUsesWith(iv); |
79 | rewriter.replaceAllUsesWith(doLoopOp.getRegionIterArgs(), |
80 | hasFinalValue |
81 | ? scfForOp.getRegionIterArgs().drop_front() |
82 | : scfForOp.getRegionIterArgs()); |
83 | |
84 | // Copy all the attributes from the old to new op. |
85 | scfForOp->setAttrs(doLoopOp->getAttrs()); |
86 | rewriter.replaceOp(doLoopOp, scfForOp); |
87 | return success(); |
88 | } |
89 | }; |
90 | |
91 | struct IfConversion : public OpRewritePattern<fir::IfOp> { |
92 | using OpRewritePattern<fir::IfOp>::OpRewritePattern; |
93 | LogicalResult matchAndRewrite(fir::IfOp ifOp, |
94 | PatternRewriter &rewriter) const override { |
95 | mlir::Location loc = ifOp.getLoc(); |
96 | mlir::detail::TypedValue<mlir::IntegerType> condition = ifOp.getCondition(); |
97 | ValueTypeRange<ResultRange> resultTypes = ifOp.getResultTypes(); |
98 | mlir::scf::IfOp scfIfOp = rewriter.create<scf::IfOp>( |
99 | loc, resultTypes, condition, !ifOp.getElseRegion().empty()); |
100 | // then region |
101 | scfIfOp.getThenRegion().takeBody(ifOp.getThenRegion()); |
102 | Block &scfThenBlock = scfIfOp.getThenRegion().front(); |
103 | Operation *scfThenTerminator = scfThenBlock.getTerminator(); |
104 | // fir.result->scf.yield |
105 | rewriter.setInsertionPointToEnd(&scfThenBlock); |
106 | rewriter.replaceOpWithNewOp<scf::YieldOp>(scfThenTerminator, |
107 | scfThenTerminator->getOperands()); |
108 | |
109 | // else region |
110 | if (!ifOp.getElseRegion().empty()) { |
111 | scfIfOp.getElseRegion().takeBody(ifOp.getElseRegion()); |
112 | mlir::Block &elseBlock = scfIfOp.getElseRegion().front(); |
113 | mlir::Operation *elseTerminator = elseBlock.getTerminator(); |
114 | |
115 | rewriter.setInsertionPointToEnd(&elseBlock); |
116 | rewriter.replaceOpWithNewOp<scf::YieldOp>(elseTerminator, |
117 | elseTerminator->getOperands()); |
118 | } |
119 | |
120 | scfIfOp->setAttrs(ifOp->getAttrs()); |
121 | rewriter.replaceOp(ifOp, scfIfOp); |
122 | return success(); |
123 | } |
124 | }; |
125 | } // namespace |
126 | |
127 | void FIRToSCFPass::runOnOperation() { |
128 | RewritePatternSet patterns(&getContext()); |
129 | patterns.add<DoLoopConversion, IfConversion>(patterns.getContext()); |
130 | ConversionTarget target(getContext()); |
131 | target.addIllegalOp<fir::DoLoopOp, fir::IfOp>(); |
132 | target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); |
133 | if (failed( |
134 | applyPartialConversion(getOperation(), target, std::move(patterns)))) |
135 | signalPassFailure(); |
136 | } |
137 | |
138 | std::unique_ptr<Pass> fir::createFIRToSCFPass() { |
139 | return std::make_unique<FIRToSCFPass>(); |
140 | } |
141 | |