1//===-- ControlFlowConverter.cpp ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "flang/Optimizer/Dialect/FIRDialect.h"
10#include "flang/Optimizer/Dialect/FIROps.h"
11#include "flang/Optimizer/Dialect/FIROpsSupport.h"
12#include "flang/Optimizer/Dialect/Support/FIRContext.h"
13#include "flang/Optimizer/Dialect/Support/KindMapping.h"
14#include "flang/Optimizer/Support/InternalNames.h"
15#include "flang/Optimizer/Support/TypeCode.h"
16#include "flang/Optimizer/Transforms/Passes.h"
17#include "flang/Runtime/derived-api.h"
18#include "mlir/Dialect/Affine/IR/AffineOps.h"
19#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/Pass/Pass.h"
22#include "mlir/Transforms/DialectConversion.h"
23#include "llvm/ADT/SmallSet.h"
24#include "llvm/Support/CommandLine.h"
25
26namespace fir {
27#define GEN_PASS_DEF_CFGCONVERSION
28#include "flang/Optimizer/Transforms/Passes.h.inc"
29} // namespace fir
30
31using namespace fir;
32using namespace mlir;
33
34namespace {
35
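// Conversion of FIR structured control-flow ops (`fir.do_loop`,
// `fir.iter_while`, and `fir.if`) to primitive control flow in the `cf`
// dialect.
//
// FIR loops that cannot be converted to the affine dialect remain as
// `fir.do_loop` operations; the patterns below lower them, together with the
// other structured ops, to plain blocks and branches.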
40
41/// Convert `fir.do_loop` to CFG
42class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> {
43public:
44 using OpRewritePattern::OpRewritePattern;
45
46 CfgLoopConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce, bool setNSW)
47 : mlir::OpRewritePattern<fir::DoLoopOp>(ctx),
48 forceLoopToExecuteOnce(forceLoopToExecuteOnce), setNSW(setNSW) {}
49
50 llvm::LogicalResult
51 matchAndRewrite(DoLoopOp loop,
52 mlir::PatternRewriter &rewriter) const override {
53 auto loc = loop.getLoc();
54 mlir::arith::IntegerOverflowFlags flags{};
55 if (setNSW)
56 flags = bitEnumSet(flags, mlir::arith::IntegerOverflowFlags::nsw);
57 auto iofAttr = mlir::arith::IntegerOverflowFlagsAttr::get(
58 rewriter.getContext(), flags);
59
60 // Create the start and end blocks that will wrap the DoLoopOp with an
61 // initalizer and an end point
62 auto *initBlock = rewriter.getInsertionBlock();
63 auto initPos = rewriter.getInsertionPoint();
64 auto *endBlock = rewriter.splitBlock(initBlock, initPos);
65
66 // Split the first DoLoopOp block in two parts. The part before will be the
67 // conditional block since it already has the induction variable and
68 // loop-carried values as arguments.
69 auto *conditionalBlock = &loop.getRegion().front();
70 conditionalBlock->addArgument(rewriter.getIndexType(), loc);
71 auto *firstBlock =
72 rewriter.splitBlock(conditionalBlock, conditionalBlock->begin());
73 auto *lastBlock = &loop.getRegion().back();
74
75 // Move the blocks from the DoLoopOp between initBlock and endBlock
76 rewriter.inlineRegionBefore(loop.getRegion(), endBlock);
77
78 // Get loop values from the DoLoopOp
79 auto low = loop.getLowerBound();
80 auto high = loop.getUpperBound();
81 assert(low && high && "must be a Value");
82 auto step = loop.getStep();
83
84 // Initalization block
85 rewriter.setInsertionPointToEnd(initBlock);
86 auto diff = rewriter.create<mlir::arith::SubIOp>(loc, high, low);
87 auto distance = rewriter.create<mlir::arith::AddIOp>(loc, diff, step);
88 mlir::Value iters =
89 rewriter.create<mlir::arith::DivSIOp>(loc, distance, step);
90
91 if (forceLoopToExecuteOnce) {
92 auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
93 auto cond = rewriter.create<mlir::arith::CmpIOp>(
94 loc, arith::CmpIPredicate::sle, iters, zero);
95 auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
96 iters = rewriter.create<mlir::arith::SelectOp>(loc, cond, one, iters);
97 }
98
99 llvm::SmallVector<mlir::Value> loopOperands;
100 loopOperands.push_back(low);
101 auto operands = loop.getIterOperands();
102 loopOperands.append(operands.begin(), operands.end());
103 loopOperands.push_back(iters);
104
105 rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopOperands);
106
107 // Last loop block
108 auto *terminator = lastBlock->getTerminator();
109 rewriter.setInsertionPointToEnd(lastBlock);
110 auto iv = conditionalBlock->getArgument(0);
111 mlir::Value steppedIndex =
112 rewriter.create<mlir::arith::AddIOp>(loc, iv, step, iofAttr);
113 assert(steppedIndex && "must be a Value");
114 auto lastArg = conditionalBlock->getNumArguments() - 1;
115 auto itersLeft = conditionalBlock->getArgument(lastArg);
116 auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
117 mlir::Value itersMinusOne =
118 rewriter.create<mlir::arith::SubIOp>(loc, itersLeft, one);
119
120 llvm::SmallVector<mlir::Value> loopCarried;
121 loopCarried.push_back(steppedIndex);
122 auto begin = loop.getFinalValue() ? std::next(terminator->operand_begin())
123 : terminator->operand_begin();
124 loopCarried.append(begin, terminator->operand_end());
125 loopCarried.push_back(itersMinusOne);
126 auto backEdge =
127 rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopCarried);
128 rewriter.eraseOp(terminator);
129
130 // Copy loop annotations from the do loop to the loop back edge.
131 if (auto ann = loop.getLoopAnnotation())
132 backEdge->setAttr("loop_annotation", *ann);
133
134 // Conditional block
135 rewriter.setInsertionPointToEnd(conditionalBlock);
136 auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
137 auto comparison = rewriter.create<mlir::arith::CmpIOp>(
138 loc, arith::CmpIPredicate::sgt, itersLeft, zero);
139
140 rewriter.create<mlir::cf::CondBranchOp>(
141 loc, comparison, firstBlock, llvm::ArrayRef<mlir::Value>(), endBlock,
142 llvm::ArrayRef<mlir::Value>());
143
144 // The result of the loop operation is the values of the condition block
145 // arguments except the induction variable on the last iteration.
146 auto args = loop.getFinalValue()
147 ? conditionalBlock->getArguments()
148 : conditionalBlock->getArguments().drop_front();
149 rewriter.replaceOp(loop, args.drop_back());
150 return success();
151 }
152
153private:
154 bool forceLoopToExecuteOnce;
155 bool setNSW;
156};
157
158/// Convert `fir.if` to control-flow
159class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> {
160public:
161 using OpRewritePattern::OpRewritePattern;
162
163 CfgIfConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce, bool setNSW)
164 : mlir::OpRewritePattern<fir::IfOp>(ctx) {}
165
166 llvm::LogicalResult
167 matchAndRewrite(IfOp ifOp, mlir::PatternRewriter &rewriter) const override {
168 auto loc = ifOp.getLoc();
169
170 // Split the block containing the 'fir.if' into two parts. The part before
171 // will contain the condition, the part after will be the continuation
172 // point.
173 auto *condBlock = rewriter.getInsertionBlock();
174 auto opPosition = rewriter.getInsertionPoint();
175 auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
176 mlir::Block *continueBlock;
177 if (ifOp.getNumResults() == 0) {
178 continueBlock = remainingOpsBlock;
179 } else {
180 continueBlock = rewriter.createBlock(
181 remainingOpsBlock, ifOp.getResultTypes(),
182 llvm::SmallVector<mlir::Location>(ifOp.getNumResults(), loc));
183 rewriter.create<mlir::cf::BranchOp>(loc, remainingOpsBlock);
184 }
185
186 // Move blocks from the "then" region to the region containing 'fir.if',
187 // place it before the continuation block, and branch to it.
188 auto &ifOpRegion = ifOp.getThenRegion();
189 auto *ifOpBlock = &ifOpRegion.front();
190 auto *ifOpTerminator = ifOpRegion.back().getTerminator();
191 auto ifOpTerminatorOperands = ifOpTerminator->getOperands();
192 rewriter.setInsertionPointToEnd(&ifOpRegion.back());
193 rewriter.create<mlir::cf::BranchOp>(loc, continueBlock,
194 ifOpTerminatorOperands);
195 rewriter.eraseOp(ifOpTerminator);
196 rewriter.inlineRegionBefore(ifOpRegion, continueBlock);
197
198 // Move blocks from the "else" region (if present) to the region containing
199 // 'fir.if', place it before the continuation block and branch to it. It
200 // will be placed after the "then" regions.
201 auto *otherwiseBlock = continueBlock;
202 auto &otherwiseRegion = ifOp.getElseRegion();
203 if (!otherwiseRegion.empty()) {
204 otherwiseBlock = &otherwiseRegion.front();
205 auto *otherwiseTerm = otherwiseRegion.back().getTerminator();
206 auto otherwiseTermOperands = otherwiseTerm->getOperands();
207 rewriter.setInsertionPointToEnd(&otherwiseRegion.back());
208 rewriter.create<mlir::cf::BranchOp>(loc, continueBlock,
209 otherwiseTermOperands);
210 rewriter.eraseOp(otherwiseTerm);
211 rewriter.inlineRegionBefore(otherwiseRegion, continueBlock);
212 }
213
214 rewriter.setInsertionPointToEnd(condBlock);
215 auto branchOp = rewriter.create<mlir::cf::CondBranchOp>(
216 loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
217 otherwiseBlock, llvm::ArrayRef<mlir::Value>());
218 llvm::ArrayRef<int32_t> weights = ifOp.getWeights();
219 if (!weights.empty())
220 branchOp.setWeights(weights);
221 rewriter.replaceOp(ifOp, continueBlock->getArguments());
222 return success();
223 }
224};
225
226/// Convert `fir.iter_while` to control-flow.
227class CfgIterWhileConv : public mlir::OpRewritePattern<fir::IterWhileOp> {
228public:
229 using OpRewritePattern::OpRewritePattern;
230
231 CfgIterWhileConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce,
232 bool setNSW)
233 : mlir::OpRewritePattern<fir::IterWhileOp>(ctx), setNSW(setNSW) {}
234
235 llvm::LogicalResult
236 matchAndRewrite(fir::IterWhileOp whileOp,
237 mlir::PatternRewriter &rewriter) const override {
238 auto loc = whileOp.getLoc();
239 mlir::arith::IntegerOverflowFlags flags{};
240 if (setNSW)
241 flags = bitEnumSet(flags, mlir::arith::IntegerOverflowFlags::nsw);
242 auto iofAttr = mlir::arith::IntegerOverflowFlagsAttr::get(
243 rewriter.getContext(), flags);
244
245 // Start by splitting the block containing the 'fir.do_loop' into two parts.
246 // The part before will get the init code, the part after will be the end
247 // point.
248 auto *initBlock = rewriter.getInsertionBlock();
249 auto initPosition = rewriter.getInsertionPoint();
250 auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
251
252 // Use the first block of the loop body as the condition block since it is
253 // the block that has the induction variable and loop-carried values as
254 // arguments. Split out all operations from the first block into a new
255 // block. Move all body blocks from the loop body region to the region
256 // containing the loop.
257 auto *conditionBlock = &whileOp.getRegion().front();
258 auto *firstBodyBlock =
259 rewriter.splitBlock(conditionBlock, conditionBlock->begin());
260 auto *lastBodyBlock = &whileOp.getRegion().back();
261 rewriter.inlineRegionBefore(whileOp.getRegion(), endBlock);
262 auto iv = conditionBlock->getArgument(0);
263 auto iterateVar = conditionBlock->getArgument(1);
264
265 // Append the induction variable stepping logic to the last body block and
266 // branch back to the condition block. Loop-carried values are taken from
267 // operands of the loop terminator.
268 auto *terminator = lastBodyBlock->getTerminator();
269 rewriter.setInsertionPointToEnd(lastBodyBlock);
270 auto step = whileOp.getStep();
271 mlir::Value stepped =
272 rewriter.create<mlir::arith::AddIOp>(loc, iv, step, iofAttr);
273 assert(stepped && "must be a Value");
274
275 llvm::SmallVector<mlir::Value> loopCarried;
276 loopCarried.push_back(stepped);
277 auto begin = whileOp.getFinalValue()
278 ? std::next(terminator->operand_begin())
279 : terminator->operand_begin();
280 loopCarried.append(begin, terminator->operand_end());
281 rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, loopCarried);
282 rewriter.eraseOp(terminator);
283
284 // Compute loop bounds before branching to the condition.
285 rewriter.setInsertionPointToEnd(initBlock);
286 auto lowerBound = whileOp.getLowerBound();
287 auto upperBound = whileOp.getUpperBound();
288 assert(lowerBound && upperBound && "must be a Value");
289
290 // The initial values of loop-carried values is obtained from the operands
291 // of the loop operation.
292 llvm::SmallVector<mlir::Value> destOperands;
293 destOperands.push_back(lowerBound);
294 auto iterOperands = whileOp.getIterOperands();
295 destOperands.append(iterOperands.begin(), iterOperands.end());
296 rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, destOperands);
297
298 // With the body block done, we can fill in the condition block.
299 rewriter.setInsertionPointToEnd(conditionBlock);
300 // The comparison depends on the sign of the step value. We fully expect
301 // this expression to be folded by the optimizer or LLVM. This expression
302 // is written this way so that `step == 0` always returns `false`.
303 auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
304 auto compl0 = rewriter.create<mlir::arith::CmpIOp>(
305 loc, arith::CmpIPredicate::slt, zero, step);
306 auto compl1 = rewriter.create<mlir::arith::CmpIOp>(
307 loc, arith::CmpIPredicate::sle, iv, upperBound);
308 auto compl2 = rewriter.create<mlir::arith::CmpIOp>(
309 loc, arith::CmpIPredicate::slt, step, zero);
310 auto compl3 = rewriter.create<mlir::arith::CmpIOp>(
311 loc, arith::CmpIPredicate::sle, upperBound, iv);
312 auto cmp0 = rewriter.create<mlir::arith::AndIOp>(loc, compl0, compl1);
313 auto cmp1 = rewriter.create<mlir::arith::AndIOp>(loc, compl2, compl3);
314 auto cmp2 = rewriter.create<mlir::arith::OrIOp>(loc, cmp0, cmp1);
315 // Remember to AND in the early-exit bool.
316 auto comparison =
317 rewriter.create<mlir::arith::AndIOp>(loc, iterateVar, cmp2);
318 rewriter.create<mlir::cf::CondBranchOp>(
319 loc, comparison, firstBodyBlock, llvm::ArrayRef<mlir::Value>(),
320 endBlock, llvm::ArrayRef<mlir::Value>());
321 // The result of the loop operation is the values of the condition block
322 // arguments except the induction variable on the last iteration.
323 auto args = whileOp.getFinalValue()
324 ? conditionBlock->getArguments()
325 : conditionBlock->getArguments().drop_front();
326 rewriter.replaceOp(whileOp, args);
327 return success();
328 }
329
330private:
331 bool setNSW;
332};
333
334/// Convert FIR structured control flow ops to CFG ops.
335class CfgConversion : public fir::impl::CFGConversionBase<CfgConversion> {
336public:
337 using CFGConversionBase<CfgConversion>::CFGConversionBase;
338
339 void runOnOperation() override {
340 auto *context = &this->getContext();
341 mlir::RewritePatternSet patterns(context);
342 fir::populateCfgConversionRewrites(patterns, this->forceLoopToExecuteOnce,
343 this->setNSW);
344 mlir::ConversionTarget target(*context);
345 target.addLegalDialect<mlir::affine::AffineDialect,
346 mlir::cf::ControlFlowDialect, FIROpsDialect,
347 mlir::func::FuncDialect>();
348
349 // apply the patterns
350 target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>();
351 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
352 if (mlir::failed(mlir::applyPartialConversion(this->getOperation(), target,
353 std::move(patterns)))) {
354 mlir::emitError(mlir::UnknownLoc::get(context),
355 "error in converting to CFG\n");
356 this->signalPassFailure();
357 }
358 }
359};
360
361} // namespace
362
363/// Expose conversion rewriters to other passes
364void fir::populateCfgConversionRewrites(mlir::RewritePatternSet &patterns,
365 bool forceLoopToExecuteOnce,
366 bool setNSW) {
367 patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>(
368 patterns.getContext(), forceLoopToExecuteOnce, setNSW);
369}
370

source code of flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp