1//===-- FIRToSCF.cpp ------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "flang/Optimizer/Dialect/FIRDialect.h"
10#include "flang/Optimizer/Transforms/Passes.h"
11#include "mlir/Dialect/SCF/IR/SCF.h"
12#include "mlir/Transforms/DialectConversion.h"
13
14namespace fir {
15#define GEN_PASS_DEF_FIRTOSCFPASS
16#include "flang/Optimizer/Transforms/Passes.h.inc"
17} // namespace fir
18
19using namespace fir;
20using namespace mlir;
21
22namespace {
23class FIRToSCFPass : public fir::impl::FIRToSCFPassBase<FIRToSCFPass> {
24public:
25 void runOnOperation() override;
26};
27
28struct DoLoopConversion : public OpRewritePattern<fir::DoLoopOp> {
29 using OpRewritePattern<fir::DoLoopOp>::OpRewritePattern;
30
31 LogicalResult matchAndRewrite(fir::DoLoopOp doLoopOp,
32 PatternRewriter &rewriter) const override {
33 auto loc = doLoopOp.getLoc();
34 bool hasFinalValue = doLoopOp.getFinalValue().has_value();
35
36 // Get loop values from the DoLoopOp
37 auto low = doLoopOp.getLowerBound();
38 auto high = doLoopOp.getUpperBound();
39 assert(low && high && "must be a Value");
40 auto step = doLoopOp.getStep();
41 llvm::SmallVector<Value> iterArgs;
42 if (hasFinalValue)
43 iterArgs.push_back(low);
44 iterArgs.append(doLoopOp.getIterOperands().begin(),
45 doLoopOp.getIterOperands().end());
46
47 // fir.do_loop iterates over the interval [%l, %u], and the step may be
48 // negative. But scf.for iterates over the interval [%l, %u), and the step
49 // must be a positive value.
50 // For easier conversion, we calculate the trip count and use a canonical
51 // induction variable.
52 auto diff = rewriter.create<arith::SubIOp>(loc, high, low);
53 auto distance = rewriter.create<arith::AddIOp>(loc, diff, step);
54 auto tripCount = rewriter.create<arith::DivSIOp>(loc, distance, step);
55 auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
56 auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
57 auto scfForOp =
58 rewriter.create<scf::ForOp>(loc, zero, tripCount, one, iterArgs);
59
60 auto &loopOps = doLoopOp.getBody()->getOperations();
61 auto resultOp = cast<fir::ResultOp>(doLoopOp.getBody()->getTerminator());
62 auto results = resultOp.getOperands();
63 Block *loweredBody = scfForOp.getBody();
64
65 loweredBody->getOperations().splice(loweredBody->begin(), loopOps,
66 loopOps.begin(),
67 std::prev(loopOps.end()));
68
69 rewriter.setInsertionPointToStart(loweredBody);
70 Value iv =
71 rewriter.create<arith::MulIOp>(loc, scfForOp.getInductionVar(), step);
72 iv = rewriter.create<arith::AddIOp>(loc, low, iv);
73
74 if (!results.empty()) {
75 rewriter.setInsertionPointToEnd(loweredBody);
76 rewriter.create<scf::YieldOp>(resultOp->getLoc(), results);
77 }
78 doLoopOp.getInductionVar().replaceAllUsesWith(iv);
79 rewriter.replaceAllUsesWith(doLoopOp.getRegionIterArgs(),
80 hasFinalValue
81 ? scfForOp.getRegionIterArgs().drop_front()
82 : scfForOp.getRegionIterArgs());
83
84 // Copy all the attributes from the old to new op.
85 scfForOp->setAttrs(doLoopOp->getAttrs());
86 rewriter.replaceOp(doLoopOp, scfForOp);
87 return success();
88 }
89};
90
91struct IfConversion : public OpRewritePattern<fir::IfOp> {
92 using OpRewritePattern<fir::IfOp>::OpRewritePattern;
93 LogicalResult matchAndRewrite(fir::IfOp ifOp,
94 PatternRewriter &rewriter) const override {
95 mlir::Location loc = ifOp.getLoc();
96 mlir::detail::TypedValue<mlir::IntegerType> condition = ifOp.getCondition();
97 ValueTypeRange<ResultRange> resultTypes = ifOp.getResultTypes();
98 mlir::scf::IfOp scfIfOp = rewriter.create<scf::IfOp>(
99 loc, resultTypes, condition, !ifOp.getElseRegion().empty());
100 // then region
101 scfIfOp.getThenRegion().takeBody(ifOp.getThenRegion());
102 Block &scfThenBlock = scfIfOp.getThenRegion().front();
103 Operation *scfThenTerminator = scfThenBlock.getTerminator();
104 // fir.result->scf.yield
105 rewriter.setInsertionPointToEnd(&scfThenBlock);
106 rewriter.replaceOpWithNewOp<scf::YieldOp>(scfThenTerminator,
107 scfThenTerminator->getOperands());
108
109 // else region
110 if (!ifOp.getElseRegion().empty()) {
111 scfIfOp.getElseRegion().takeBody(ifOp.getElseRegion());
112 mlir::Block &elseBlock = scfIfOp.getElseRegion().front();
113 mlir::Operation *elseTerminator = elseBlock.getTerminator();
114
115 rewriter.setInsertionPointToEnd(&elseBlock);
116 rewriter.replaceOpWithNewOp<scf::YieldOp>(elseTerminator,
117 elseTerminator->getOperands());
118 }
119
120 scfIfOp->setAttrs(ifOp->getAttrs());
121 rewriter.replaceOp(ifOp, scfIfOp);
122 return success();
123 }
124};
125} // namespace
126
127void FIRToSCFPass::runOnOperation() {
128 RewritePatternSet patterns(&getContext());
129 patterns.add<DoLoopConversion, IfConversion>(patterns.getContext());
130 ConversionTarget target(getContext());
131 target.addIllegalOp<fir::DoLoopOp, fir::IfOp>();
132 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
133 if (failed(
134 applyPartialConversion(getOperation(), target, std::move(patterns))))
135 signalPassFailure();
136}
137
138std::unique_ptr<Pass> fir::createFIRToSCFPass() {
139 return std::make_unique<FIRToSCFPass>();
140}
141

source code of flang/lib/Optimizer/Transforms/FIRToSCF.cpp