1 | //===-- ControlFlowConverter.cpp ------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
10 | #include "flang/Optimizer/Dialect/FIROps.h" |
11 | #include "flang/Optimizer/Dialect/FIROpsSupport.h" |
12 | #include "flang/Optimizer/Dialect/Support/FIRContext.h" |
13 | #include "flang/Optimizer/Dialect/Support/KindMapping.h" |
14 | #include "flang/Optimizer/Support/InternalNames.h" |
15 | #include "flang/Optimizer/Support/TypeCode.h" |
16 | #include "flang/Optimizer/Transforms/Passes.h" |
17 | #include "flang/Runtime/derived-api.h" |
18 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
19 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
20 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
21 | #include "mlir/Pass/Pass.h" |
22 | #include "mlir/Transforms/DialectConversion.h" |
23 | #include "llvm/ADT/SmallSet.h" |
24 | #include "llvm/Support/CommandLine.h" |
25 | |
26 | namespace fir { |
27 | #define GEN_PASS_DEF_CFGCONVERSION |
28 | #include "flang/Optimizer/Transforms/Passes.h.inc" |
29 | } // namespace fir |
30 | |
31 | using namespace fir; |
32 | using namespace mlir; |
33 | |
34 | namespace { |
35 | |
36 | // Conversion of fir control ops to more primitive control-flow. |
37 | // |
38 | // FIR loops that cannot be converted to the affine dialect will remain as |
39 | // `fir.do_loop` operations. These can be converted to control-flow operations. |
40 | |
41 | /// Convert `fir.do_loop` to CFG |
42 | class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> { |
43 | public: |
44 | using OpRewritePattern::OpRewritePattern; |
45 | |
46 | CfgLoopConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce, bool setNSW) |
47 | : mlir::OpRewritePattern<fir::DoLoopOp>(ctx), |
48 | forceLoopToExecuteOnce(forceLoopToExecuteOnce), setNSW(setNSW) {} |
49 | |
50 | llvm::LogicalResult |
51 | matchAndRewrite(DoLoopOp loop, |
52 | mlir::PatternRewriter &rewriter) const override { |
53 | auto loc = loop.getLoc(); |
54 | mlir::arith::IntegerOverflowFlags flags{}; |
55 | if (setNSW) |
56 | flags = bitEnumSet(flags, mlir::arith::IntegerOverflowFlags::nsw); |
57 | auto iofAttr = mlir::arith::IntegerOverflowFlagsAttr::get( |
58 | rewriter.getContext(), flags); |
59 | |
60 | // Create the start and end blocks that will wrap the DoLoopOp with an |
61 | // initalizer and an end point |
62 | auto *initBlock = rewriter.getInsertionBlock(); |
63 | auto initPos = rewriter.getInsertionPoint(); |
64 | auto *endBlock = rewriter.splitBlock(initBlock, initPos); |
65 | |
66 | // Split the first DoLoopOp block in two parts. The part before will be the |
67 | // conditional block since it already has the induction variable and |
68 | // loop-carried values as arguments. |
69 | auto *conditionalBlock = &loop.getRegion().front(); |
70 | conditionalBlock->addArgument(rewriter.getIndexType(), loc); |
71 | auto *firstBlock = |
72 | rewriter.splitBlock(conditionalBlock, conditionalBlock->begin()); |
73 | auto *lastBlock = &loop.getRegion().back(); |
74 | |
75 | // Move the blocks from the DoLoopOp between initBlock and endBlock |
76 | rewriter.inlineRegionBefore(loop.getRegion(), endBlock); |
77 | |
78 | // Get loop values from the DoLoopOp |
79 | auto low = loop.getLowerBound(); |
80 | auto high = loop.getUpperBound(); |
81 | assert(low && high && "must be a Value" ); |
82 | auto step = loop.getStep(); |
83 | |
84 | // Initalization block |
85 | rewriter.setInsertionPointToEnd(initBlock); |
86 | auto diff = rewriter.create<mlir::arith::SubIOp>(loc, high, low); |
87 | auto distance = rewriter.create<mlir::arith::AddIOp>(loc, diff, step); |
88 | mlir::Value iters = |
89 | rewriter.create<mlir::arith::DivSIOp>(loc, distance, step); |
90 | |
91 | if (forceLoopToExecuteOnce) { |
92 | auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); |
93 | auto cond = rewriter.create<mlir::arith::CmpIOp>( |
94 | loc, arith::CmpIPredicate::sle, iters, zero); |
95 | auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1); |
96 | iters = rewriter.create<mlir::arith::SelectOp>(loc, cond, one, iters); |
97 | } |
98 | |
99 | llvm::SmallVector<mlir::Value> loopOperands; |
100 | loopOperands.push_back(low); |
101 | auto operands = loop.getIterOperands(); |
102 | loopOperands.append(operands.begin(), operands.end()); |
103 | loopOperands.push_back(iters); |
104 | |
105 | rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopOperands); |
106 | |
107 | // Last loop block |
108 | auto *terminator = lastBlock->getTerminator(); |
109 | rewriter.setInsertionPointToEnd(lastBlock); |
110 | auto iv = conditionalBlock->getArgument(0); |
111 | mlir::Value steppedIndex = |
112 | rewriter.create<mlir::arith::AddIOp>(loc, iv, step, iofAttr); |
113 | assert(steppedIndex && "must be a Value" ); |
114 | auto lastArg = conditionalBlock->getNumArguments() - 1; |
115 | auto itersLeft = conditionalBlock->getArgument(lastArg); |
116 | auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1); |
117 | mlir::Value itersMinusOne = |
118 | rewriter.create<mlir::arith::SubIOp>(loc, itersLeft, one); |
119 | |
120 | llvm::SmallVector<mlir::Value> loopCarried; |
121 | loopCarried.push_back(steppedIndex); |
122 | auto begin = loop.getFinalValue() ? std::next(terminator->operand_begin()) |
123 | : terminator->operand_begin(); |
124 | loopCarried.append(begin, terminator->operand_end()); |
125 | loopCarried.push_back(itersMinusOne); |
126 | auto backEdge = |
127 | rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopCarried); |
128 | rewriter.eraseOp(terminator); |
129 | |
130 | // Copy loop annotations from the do loop to the loop back edge. |
131 | if (auto ann = loop.getLoopAnnotation()) |
132 | backEdge->setAttr("loop_annotation" , *ann); |
133 | |
134 | // Conditional block |
135 | rewriter.setInsertionPointToEnd(conditionalBlock); |
136 | auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); |
137 | auto comparison = rewriter.create<mlir::arith::CmpIOp>( |
138 | loc, arith::CmpIPredicate::sgt, itersLeft, zero); |
139 | |
140 | rewriter.create<mlir::cf::CondBranchOp>( |
141 | loc, comparison, firstBlock, llvm::ArrayRef<mlir::Value>(), endBlock, |
142 | llvm::ArrayRef<mlir::Value>()); |
143 | |
144 | // The result of the loop operation is the values of the condition block |
145 | // arguments except the induction variable on the last iteration. |
146 | auto args = loop.getFinalValue() |
147 | ? conditionalBlock->getArguments() |
148 | : conditionalBlock->getArguments().drop_front(); |
149 | rewriter.replaceOp(loop, args.drop_back()); |
150 | return success(); |
151 | } |
152 | |
153 | private: |
154 | bool forceLoopToExecuteOnce; |
155 | bool setNSW; |
156 | }; |
157 | |
158 | /// Convert `fir.if` to control-flow |
159 | class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> { |
160 | public: |
161 | using OpRewritePattern::OpRewritePattern; |
162 | |
163 | CfgIfConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce, bool setNSW) |
164 | : mlir::OpRewritePattern<fir::IfOp>(ctx) {} |
165 | |
166 | llvm::LogicalResult |
167 | matchAndRewrite(IfOp ifOp, mlir::PatternRewriter &rewriter) const override { |
168 | auto loc = ifOp.getLoc(); |
169 | |
170 | // Split the block containing the 'fir.if' into two parts. The part before |
171 | // will contain the condition, the part after will be the continuation |
172 | // point. |
173 | auto *condBlock = rewriter.getInsertionBlock(); |
174 | auto opPosition = rewriter.getInsertionPoint(); |
175 | auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); |
176 | mlir::Block *continueBlock; |
177 | if (ifOp.getNumResults() == 0) { |
178 | continueBlock = remainingOpsBlock; |
179 | } else { |
180 | continueBlock = rewriter.createBlock( |
181 | remainingOpsBlock, ifOp.getResultTypes(), |
182 | llvm::SmallVector<mlir::Location>(ifOp.getNumResults(), loc)); |
183 | rewriter.create<mlir::cf::BranchOp>(loc, remainingOpsBlock); |
184 | } |
185 | |
186 | // Move blocks from the "then" region to the region containing 'fir.if', |
187 | // place it before the continuation block, and branch to it. |
188 | auto &ifOpRegion = ifOp.getThenRegion(); |
189 | auto *ifOpBlock = &ifOpRegion.front(); |
190 | auto *ifOpTerminator = ifOpRegion.back().getTerminator(); |
191 | auto ifOpTerminatorOperands = ifOpTerminator->getOperands(); |
192 | rewriter.setInsertionPointToEnd(&ifOpRegion.back()); |
193 | rewriter.create<mlir::cf::BranchOp>(loc, continueBlock, |
194 | ifOpTerminatorOperands); |
195 | rewriter.eraseOp(ifOpTerminator); |
196 | rewriter.inlineRegionBefore(ifOpRegion, continueBlock); |
197 | |
198 | // Move blocks from the "else" region (if present) to the region containing |
199 | // 'fir.if', place it before the continuation block and branch to it. It |
200 | // will be placed after the "then" regions. |
201 | auto *otherwiseBlock = continueBlock; |
202 | auto &otherwiseRegion = ifOp.getElseRegion(); |
203 | if (!otherwiseRegion.empty()) { |
204 | otherwiseBlock = &otherwiseRegion.front(); |
205 | auto *otherwiseTerm = otherwiseRegion.back().getTerminator(); |
206 | auto otherwiseTermOperands = otherwiseTerm->getOperands(); |
207 | rewriter.setInsertionPointToEnd(&otherwiseRegion.back()); |
208 | rewriter.create<mlir::cf::BranchOp>(loc, continueBlock, |
209 | otherwiseTermOperands); |
210 | rewriter.eraseOp(otherwiseTerm); |
211 | rewriter.inlineRegionBefore(otherwiseRegion, continueBlock); |
212 | } |
213 | |
214 | rewriter.setInsertionPointToEnd(condBlock); |
215 | rewriter.create<mlir::cf::CondBranchOp>( |
216 | loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(), |
217 | otherwiseBlock, llvm::ArrayRef<mlir::Value>()); |
218 | rewriter.replaceOp(ifOp, continueBlock->getArguments()); |
219 | return success(); |
220 | } |
221 | }; |
222 | |
223 | /// Convert `fir.iter_while` to control-flow. |
224 | class CfgIterWhileConv : public mlir::OpRewritePattern<fir::IterWhileOp> { |
225 | public: |
226 | using OpRewritePattern::OpRewritePattern; |
227 | |
228 | CfgIterWhileConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce, |
229 | bool setNSW) |
230 | : mlir::OpRewritePattern<fir::IterWhileOp>(ctx), setNSW(setNSW) {} |
231 | |
232 | llvm::LogicalResult |
233 | matchAndRewrite(fir::IterWhileOp whileOp, |
234 | mlir::PatternRewriter &rewriter) const override { |
235 | auto loc = whileOp.getLoc(); |
236 | mlir::arith::IntegerOverflowFlags flags{}; |
237 | if (setNSW) |
238 | flags = bitEnumSet(flags, mlir::arith::IntegerOverflowFlags::nsw); |
239 | auto iofAttr = mlir::arith::IntegerOverflowFlagsAttr::get( |
240 | rewriter.getContext(), flags); |
241 | |
242 | // Start by splitting the block containing the 'fir.do_loop' into two parts. |
243 | // The part before will get the init code, the part after will be the end |
244 | // point. |
245 | auto *initBlock = rewriter.getInsertionBlock(); |
246 | auto initPosition = rewriter.getInsertionPoint(); |
247 | auto *endBlock = rewriter.splitBlock(initBlock, initPosition); |
248 | |
249 | // Use the first block of the loop body as the condition block since it is |
250 | // the block that has the induction variable and loop-carried values as |
251 | // arguments. Split out all operations from the first block into a new |
252 | // block. Move all body blocks from the loop body region to the region |
253 | // containing the loop. |
254 | auto *conditionBlock = &whileOp.getRegion().front(); |
255 | auto *firstBodyBlock = |
256 | rewriter.splitBlock(conditionBlock, conditionBlock->begin()); |
257 | auto *lastBodyBlock = &whileOp.getRegion().back(); |
258 | rewriter.inlineRegionBefore(whileOp.getRegion(), endBlock); |
259 | auto iv = conditionBlock->getArgument(0); |
260 | auto iterateVar = conditionBlock->getArgument(1); |
261 | |
262 | // Append the induction variable stepping logic to the last body block and |
263 | // branch back to the condition block. Loop-carried values are taken from |
264 | // operands of the loop terminator. |
265 | auto *terminator = lastBodyBlock->getTerminator(); |
266 | rewriter.setInsertionPointToEnd(lastBodyBlock); |
267 | auto step = whileOp.getStep(); |
268 | mlir::Value stepped = |
269 | rewriter.create<mlir::arith::AddIOp>(loc, iv, step, iofAttr); |
270 | assert(stepped && "must be a Value" ); |
271 | |
272 | llvm::SmallVector<mlir::Value> loopCarried; |
273 | loopCarried.push_back(stepped); |
274 | auto begin = whileOp.getFinalValue() |
275 | ? std::next(terminator->operand_begin()) |
276 | : terminator->operand_begin(); |
277 | loopCarried.append(begin, terminator->operand_end()); |
278 | rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, loopCarried); |
279 | rewriter.eraseOp(terminator); |
280 | |
281 | // Compute loop bounds before branching to the condition. |
282 | rewriter.setInsertionPointToEnd(initBlock); |
283 | auto lowerBound = whileOp.getLowerBound(); |
284 | auto upperBound = whileOp.getUpperBound(); |
285 | assert(lowerBound && upperBound && "must be a Value" ); |
286 | |
287 | // The initial values of loop-carried values is obtained from the operands |
288 | // of the loop operation. |
289 | llvm::SmallVector<mlir::Value> destOperands; |
290 | destOperands.push_back(lowerBound); |
291 | auto iterOperands = whileOp.getIterOperands(); |
292 | destOperands.append(iterOperands.begin(), iterOperands.end()); |
293 | rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, destOperands); |
294 | |
295 | // With the body block done, we can fill in the condition block. |
296 | rewriter.setInsertionPointToEnd(conditionBlock); |
297 | // The comparison depends on the sign of the step value. We fully expect |
298 | // this expression to be folded by the optimizer or LLVM. This expression |
299 | // is written this way so that `step == 0` always returns `false`. |
300 | auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); |
301 | auto compl0 = rewriter.create<mlir::arith::CmpIOp>( |
302 | loc, arith::CmpIPredicate::slt, zero, step); |
303 | auto compl1 = rewriter.create<mlir::arith::CmpIOp>( |
304 | loc, arith::CmpIPredicate::sle, iv, upperBound); |
305 | auto compl2 = rewriter.create<mlir::arith::CmpIOp>( |
306 | loc, arith::CmpIPredicate::slt, step, zero); |
307 | auto compl3 = rewriter.create<mlir::arith::CmpIOp>( |
308 | loc, arith::CmpIPredicate::sle, upperBound, iv); |
309 | auto cmp0 = rewriter.create<mlir::arith::AndIOp>(loc, compl0, compl1); |
310 | auto cmp1 = rewriter.create<mlir::arith::AndIOp>(loc, compl2, compl3); |
311 | auto cmp2 = rewriter.create<mlir::arith::OrIOp>(loc, cmp0, cmp1); |
312 | // Remember to AND in the early-exit bool. |
313 | auto comparison = |
314 | rewriter.create<mlir::arith::AndIOp>(loc, iterateVar, cmp2); |
315 | rewriter.create<mlir::cf::CondBranchOp>( |
316 | loc, comparison, firstBodyBlock, llvm::ArrayRef<mlir::Value>(), |
317 | endBlock, llvm::ArrayRef<mlir::Value>()); |
318 | // The result of the loop operation is the values of the condition block |
319 | // arguments except the induction variable on the last iteration. |
320 | auto args = whileOp.getFinalValue() |
321 | ? conditionBlock->getArguments() |
322 | : conditionBlock->getArguments().drop_front(); |
323 | rewriter.replaceOp(whileOp, args); |
324 | return success(); |
325 | } |
326 | |
327 | private: |
328 | bool setNSW; |
329 | }; |
330 | |
331 | /// Convert FIR structured control flow ops to CFG ops. |
332 | class CfgConversion : public fir::impl::CFGConversionBase<CfgConversion> { |
333 | public: |
334 | using CFGConversionBase<CfgConversion>::CFGConversionBase; |
335 | |
336 | void runOnOperation() override { |
337 | auto *context = &this->getContext(); |
338 | mlir::RewritePatternSet patterns(context); |
339 | fir::populateCfgConversionRewrites(patterns, this->forceLoopToExecuteOnce, |
340 | this->setNSW); |
341 | mlir::ConversionTarget target(*context); |
342 | target.addLegalDialect<mlir::affine::AffineDialect, |
343 | mlir::cf::ControlFlowDialect, FIROpsDialect, |
344 | mlir::func::FuncDialect>(); |
345 | |
346 | // apply the patterns |
347 | target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>(); |
348 | target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); |
349 | if (mlir::failed(mlir::applyPartialConversion(this->getOperation(), target, |
350 | std::move(patterns)))) { |
351 | mlir::emitError(mlir::UnknownLoc::get(context), |
352 | "error in converting to CFG\n" ); |
353 | this->signalPassFailure(); |
354 | } |
355 | } |
356 | }; |
357 | |
358 | } // namespace |
359 | |
360 | /// Expose conversion rewriters to other passes |
361 | void fir::populateCfgConversionRewrites(mlir::RewritePatternSet &patterns, |
362 | bool forceLoopToExecuteOnce, |
363 | bool setNSW) { |
364 | patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>( |
365 | patterns.getContext(), forceLoopToExecuteOnce, setNSW); |
366 | } |
367 | |