//===-- ControlFlowConverter.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "flang/Optimizer/Dialect/FIRDialect.h"
10#include "flang/Optimizer/Dialect/FIROps.h"
11#include "flang/Optimizer/Dialect/FIROpsSupport.h"
12#include "flang/Optimizer/Dialect/Support/FIRContext.h"
13#include "flang/Optimizer/Dialect/Support/KindMapping.h"
14#include "flang/Optimizer/Support/InternalNames.h"
15#include "flang/Optimizer/Support/TypeCode.h"
16#include "flang/Optimizer/Transforms/Passes.h"
17#include "flang/Runtime/derived-api.h"
18#include "mlir/Dialect/Affine/IR/AffineOps.h"
19#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/Pass/Pass.h"
22#include "mlir/Transforms/DialectConversion.h"
23#include "llvm/ADT/SmallSet.h"
24#include "llvm/Support/CommandLine.h"
25
26namespace fir {
27#define GEN_PASS_DEF_CFGCONVERSION
28#include "flang/Optimizer/Transforms/Passes.h.inc"
29} // namespace fir
30
31using namespace fir;
32using namespace mlir;
33
34namespace {
35
// Conversion of fir control ops to more primitive control-flow.
//
// FIR loops that cannot be converted to the affine dialect will remain as
// `fir.do_loop` operations. These can be converted to control-flow operations.
40
41/// Convert `fir.do_loop` to CFG
42class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> {
43public:
44 using OpRewritePattern::OpRewritePattern;
45
46 CfgLoopConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce, bool setNSW)
47 : mlir::OpRewritePattern<fir::DoLoopOp>(ctx),
48 forceLoopToExecuteOnce(forceLoopToExecuteOnce), setNSW(setNSW) {}
49
50 llvm::LogicalResult
51 matchAndRewrite(DoLoopOp loop,
52 mlir::PatternRewriter &rewriter) const override {
53 auto loc = loop.getLoc();
54 mlir::arith::IntegerOverflowFlags flags{};
55 if (setNSW)
56 flags = bitEnumSet(flags, mlir::arith::IntegerOverflowFlags::nsw);
57 auto iofAttr = mlir::arith::IntegerOverflowFlagsAttr::get(
58 rewriter.getContext(), flags);
59
60 // Create the start and end blocks that will wrap the DoLoopOp with an
61 // initalizer and an end point
62 auto *initBlock = rewriter.getInsertionBlock();
63 auto initPos = rewriter.getInsertionPoint();
64 auto *endBlock = rewriter.splitBlock(initBlock, initPos);
65
66 // Split the first DoLoopOp block in two parts. The part before will be the
67 // conditional block since it already has the induction variable and
68 // loop-carried values as arguments.
69 auto *conditionalBlock = &loop.getRegion().front();
70 conditionalBlock->addArgument(rewriter.getIndexType(), loc);
71 auto *firstBlock =
72 rewriter.splitBlock(conditionalBlock, conditionalBlock->begin());
73 auto *lastBlock = &loop.getRegion().back();
74
75 // Move the blocks from the DoLoopOp between initBlock and endBlock
76 rewriter.inlineRegionBefore(loop.getRegion(), endBlock);
77
78 // Get loop values from the DoLoopOp
79 auto low = loop.getLowerBound();
80 auto high = loop.getUpperBound();
81 assert(low && high && "must be a Value");
82 auto step = loop.getStep();
83
84 // Initalization block
85 rewriter.setInsertionPointToEnd(initBlock);
86 auto diff = rewriter.create<mlir::arith::SubIOp>(loc, high, low);
87 auto distance = rewriter.create<mlir::arith::AddIOp>(loc, diff, step);
88 mlir::Value iters =
89 rewriter.create<mlir::arith::DivSIOp>(loc, distance, step);
90
91 if (forceLoopToExecuteOnce) {
92 auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
93 auto cond = rewriter.create<mlir::arith::CmpIOp>(
94 loc, arith::CmpIPredicate::sle, iters, zero);
95 auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
96 iters = rewriter.create<mlir::arith::SelectOp>(loc, cond, one, iters);
97 }
98
99 llvm::SmallVector<mlir::Value> loopOperands;
100 loopOperands.push_back(low);
101 auto operands = loop.getIterOperands();
102 loopOperands.append(operands.begin(), operands.end());
103 loopOperands.push_back(iters);
104
105 rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopOperands);
106
107 // Last loop block
108 auto *terminator = lastBlock->getTerminator();
109 rewriter.setInsertionPointToEnd(lastBlock);
110 auto iv = conditionalBlock->getArgument(0);
111 mlir::Value steppedIndex =
112 rewriter.create<mlir::arith::AddIOp>(loc, iv, step, iofAttr);
113 assert(steppedIndex && "must be a Value");
114 auto lastArg = conditionalBlock->getNumArguments() - 1;
115 auto itersLeft = conditionalBlock->getArgument(lastArg);
116 auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
117 mlir::Value itersMinusOne =
118 rewriter.create<mlir::arith::SubIOp>(loc, itersLeft, one);
119
120 llvm::SmallVector<mlir::Value> loopCarried;
121 loopCarried.push_back(steppedIndex);
122 auto begin = loop.getFinalValue() ? std::next(terminator->operand_begin())
123 : terminator->operand_begin();
124 loopCarried.append(begin, terminator->operand_end());
125 loopCarried.push_back(itersMinusOne);
126 auto backEdge =
127 rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopCarried);
128 rewriter.eraseOp(terminator);
129
130 // Copy loop annotations from the do loop to the loop back edge.
131 if (auto ann = loop.getLoopAnnotation())
132 backEdge->setAttr("loop_annotation", *ann);
133
134 // Conditional block
135 rewriter.setInsertionPointToEnd(conditionalBlock);
136 auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
137 auto comparison = rewriter.create<mlir::arith::CmpIOp>(
138 loc, arith::CmpIPredicate::sgt, itersLeft, zero);
139
140 rewriter.create<mlir::cf::CondBranchOp>(
141 loc, comparison, firstBlock, llvm::ArrayRef<mlir::Value>(), endBlock,
142 llvm::ArrayRef<mlir::Value>());
143
144 // The result of the loop operation is the values of the condition block
145 // arguments except the induction variable on the last iteration.
146 auto args = loop.getFinalValue()
147 ? conditionalBlock->getArguments()
148 : conditionalBlock->getArguments().drop_front();
149 rewriter.replaceOp(loop, args.drop_back());
150 return success();
151 }
152
153private:
154 bool forceLoopToExecuteOnce;
155 bool setNSW;
156};
157
158/// Convert `fir.if` to control-flow
159class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> {
160public:
161 using OpRewritePattern::OpRewritePattern;
162
163 CfgIfConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce, bool setNSW)
164 : mlir::OpRewritePattern<fir::IfOp>(ctx) {}
165
166 llvm::LogicalResult
167 matchAndRewrite(IfOp ifOp, mlir::PatternRewriter &rewriter) const override {
168 auto loc = ifOp.getLoc();
169
170 // Split the block containing the 'fir.if' into two parts. The part before
171 // will contain the condition, the part after will be the continuation
172 // point.
173 auto *condBlock = rewriter.getInsertionBlock();
174 auto opPosition = rewriter.getInsertionPoint();
175 auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
176 mlir::Block *continueBlock;
177 if (ifOp.getNumResults() == 0) {
178 continueBlock = remainingOpsBlock;
179 } else {
180 continueBlock = rewriter.createBlock(
181 remainingOpsBlock, ifOp.getResultTypes(),
182 llvm::SmallVector<mlir::Location>(ifOp.getNumResults(), loc));
183 rewriter.create<mlir::cf::BranchOp>(loc, remainingOpsBlock);
184 }
185
186 // Move blocks from the "then" region to the region containing 'fir.if',
187 // place it before the continuation block, and branch to it.
188 auto &ifOpRegion = ifOp.getThenRegion();
189 auto *ifOpBlock = &ifOpRegion.front();
190 auto *ifOpTerminator = ifOpRegion.back().getTerminator();
191 auto ifOpTerminatorOperands = ifOpTerminator->getOperands();
192 rewriter.setInsertionPointToEnd(&ifOpRegion.back());
193 rewriter.create<mlir::cf::BranchOp>(loc, continueBlock,
194 ifOpTerminatorOperands);
195 rewriter.eraseOp(ifOpTerminator);
196 rewriter.inlineRegionBefore(ifOpRegion, continueBlock);
197
198 // Move blocks from the "else" region (if present) to the region containing
199 // 'fir.if', place it before the continuation block and branch to it. It
200 // will be placed after the "then" regions.
201 auto *otherwiseBlock = continueBlock;
202 auto &otherwiseRegion = ifOp.getElseRegion();
203 if (!otherwiseRegion.empty()) {
204 otherwiseBlock = &otherwiseRegion.front();
205 auto *otherwiseTerm = otherwiseRegion.back().getTerminator();
206 auto otherwiseTermOperands = otherwiseTerm->getOperands();
207 rewriter.setInsertionPointToEnd(&otherwiseRegion.back());
208 rewriter.create<mlir::cf::BranchOp>(loc, continueBlock,
209 otherwiseTermOperands);
210 rewriter.eraseOp(otherwiseTerm);
211 rewriter.inlineRegionBefore(otherwiseRegion, continueBlock);
212 }
213
214 rewriter.setInsertionPointToEnd(condBlock);
215 rewriter.create<mlir::cf::CondBranchOp>(
216 loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
217 otherwiseBlock, llvm::ArrayRef<mlir::Value>());
218 rewriter.replaceOp(ifOp, continueBlock->getArguments());
219 return success();
220 }
221};
222
223/// Convert `fir.iter_while` to control-flow.
224class CfgIterWhileConv : public mlir::OpRewritePattern<fir::IterWhileOp> {
225public:
226 using OpRewritePattern::OpRewritePattern;
227
228 CfgIterWhileConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce,
229 bool setNSW)
230 : mlir::OpRewritePattern<fir::IterWhileOp>(ctx), setNSW(setNSW) {}
231
232 llvm::LogicalResult
233 matchAndRewrite(fir::IterWhileOp whileOp,
234 mlir::PatternRewriter &rewriter) const override {
235 auto loc = whileOp.getLoc();
236 mlir::arith::IntegerOverflowFlags flags{};
237 if (setNSW)
238 flags = bitEnumSet(flags, mlir::arith::IntegerOverflowFlags::nsw);
239 auto iofAttr = mlir::arith::IntegerOverflowFlagsAttr::get(
240 rewriter.getContext(), flags);
241
242 // Start by splitting the block containing the 'fir.do_loop' into two parts.
243 // The part before will get the init code, the part after will be the end
244 // point.
245 auto *initBlock = rewriter.getInsertionBlock();
246 auto initPosition = rewriter.getInsertionPoint();
247 auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
248
249 // Use the first block of the loop body as the condition block since it is
250 // the block that has the induction variable and loop-carried values as
251 // arguments. Split out all operations from the first block into a new
252 // block. Move all body blocks from the loop body region to the region
253 // containing the loop.
254 auto *conditionBlock = &whileOp.getRegion().front();
255 auto *firstBodyBlock =
256 rewriter.splitBlock(conditionBlock, conditionBlock->begin());
257 auto *lastBodyBlock = &whileOp.getRegion().back();
258 rewriter.inlineRegionBefore(whileOp.getRegion(), endBlock);
259 auto iv = conditionBlock->getArgument(0);
260 auto iterateVar = conditionBlock->getArgument(1);
261
262 // Append the induction variable stepping logic to the last body block and
263 // branch back to the condition block. Loop-carried values are taken from
264 // operands of the loop terminator.
265 auto *terminator = lastBodyBlock->getTerminator();
266 rewriter.setInsertionPointToEnd(lastBodyBlock);
267 auto step = whileOp.getStep();
268 mlir::Value stepped =
269 rewriter.create<mlir::arith::AddIOp>(loc, iv, step, iofAttr);
270 assert(stepped && "must be a Value");
271
272 llvm::SmallVector<mlir::Value> loopCarried;
273 loopCarried.push_back(stepped);
274 auto begin = whileOp.getFinalValue()
275 ? std::next(terminator->operand_begin())
276 : terminator->operand_begin();
277 loopCarried.append(begin, terminator->operand_end());
278 rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, loopCarried);
279 rewriter.eraseOp(terminator);
280
281 // Compute loop bounds before branching to the condition.
282 rewriter.setInsertionPointToEnd(initBlock);
283 auto lowerBound = whileOp.getLowerBound();
284 auto upperBound = whileOp.getUpperBound();
285 assert(lowerBound && upperBound && "must be a Value");
286
287 // The initial values of loop-carried values is obtained from the operands
288 // of the loop operation.
289 llvm::SmallVector<mlir::Value> destOperands;
290 destOperands.push_back(lowerBound);
291 auto iterOperands = whileOp.getIterOperands();
292 destOperands.append(iterOperands.begin(), iterOperands.end());
293 rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, destOperands);
294
295 // With the body block done, we can fill in the condition block.
296 rewriter.setInsertionPointToEnd(conditionBlock);
297 // The comparison depends on the sign of the step value. We fully expect
298 // this expression to be folded by the optimizer or LLVM. This expression
299 // is written this way so that `step == 0` always returns `false`.
300 auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
301 auto compl0 = rewriter.create<mlir::arith::CmpIOp>(
302 loc, arith::CmpIPredicate::slt, zero, step);
303 auto compl1 = rewriter.create<mlir::arith::CmpIOp>(
304 loc, arith::CmpIPredicate::sle, iv, upperBound);
305 auto compl2 = rewriter.create<mlir::arith::CmpIOp>(
306 loc, arith::CmpIPredicate::slt, step, zero);
307 auto compl3 = rewriter.create<mlir::arith::CmpIOp>(
308 loc, arith::CmpIPredicate::sle, upperBound, iv);
309 auto cmp0 = rewriter.create<mlir::arith::AndIOp>(loc, compl0, compl1);
310 auto cmp1 = rewriter.create<mlir::arith::AndIOp>(loc, compl2, compl3);
311 auto cmp2 = rewriter.create<mlir::arith::OrIOp>(loc, cmp0, cmp1);
312 // Remember to AND in the early-exit bool.
313 auto comparison =
314 rewriter.create<mlir::arith::AndIOp>(loc, iterateVar, cmp2);
315 rewriter.create<mlir::cf::CondBranchOp>(
316 loc, comparison, firstBodyBlock, llvm::ArrayRef<mlir::Value>(),
317 endBlock, llvm::ArrayRef<mlir::Value>());
318 // The result of the loop operation is the values of the condition block
319 // arguments except the induction variable on the last iteration.
320 auto args = whileOp.getFinalValue()
321 ? conditionBlock->getArguments()
322 : conditionBlock->getArguments().drop_front();
323 rewriter.replaceOp(whileOp, args);
324 return success();
325 }
326
327private:
328 bool setNSW;
329};
330
331/// Convert FIR structured control flow ops to CFG ops.
332class CfgConversion : public fir::impl::CFGConversionBase<CfgConversion> {
333public:
334 using CFGConversionBase<CfgConversion>::CFGConversionBase;
335
336 void runOnOperation() override {
337 auto *context = &this->getContext();
338 mlir::RewritePatternSet patterns(context);
339 fir::populateCfgConversionRewrites(patterns, this->forceLoopToExecuteOnce,
340 this->setNSW);
341 mlir::ConversionTarget target(*context);
342 target.addLegalDialect<mlir::affine::AffineDialect,
343 mlir::cf::ControlFlowDialect, FIROpsDialect,
344 mlir::func::FuncDialect>();
345
346 // apply the patterns
347 target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>();
348 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
349 if (mlir::failed(mlir::applyPartialConversion(this->getOperation(), target,
350 std::move(patterns)))) {
351 mlir::emitError(mlir::UnknownLoc::get(context),
352 "error in converting to CFG\n");
353 this->signalPassFailure();
354 }
355 }
356};
357
358} // namespace
359
360/// Expose conversion rewriters to other passes
361void fir::populateCfgConversionRewrites(mlir::RewritePatternSet &patterns,
362 bool forceLoopToExecuteOnce,
363 bool setNSW) {
364 patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>(
365 patterns.getContext(), forceLoopToExecuteOnce, setNSW);
366}
367

// Source: flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp