1 | //===-- ControlFlowConverter.cpp ------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
10 | #include "flang/Optimizer/Dialect/FIROps.h" |
11 | #include "flang/Optimizer/Dialect/FIROpsSupport.h" |
12 | #include "flang/Optimizer/Dialect/Support/FIRContext.h" |
13 | #include "flang/Optimizer/Dialect/Support/KindMapping.h" |
14 | #include "flang/Optimizer/Support/InternalNames.h" |
15 | #include "flang/Optimizer/Support/TypeCode.h" |
16 | #include "flang/Optimizer/Transforms/Passes.h" |
17 | #include "flang/Runtime/derived-api.h" |
18 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
19 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
20 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
21 | #include "mlir/Pass/Pass.h" |
22 | #include "mlir/Transforms/DialectConversion.h" |
23 | #include "llvm/ADT/SmallSet.h" |
24 | #include "llvm/Support/CommandLine.h" |
25 | |
26 | namespace fir { |
27 | #define GEN_PASS_DEF_CFGCONVERSION |
28 | #include "flang/Optimizer/Transforms/Passes.h.inc" |
29 | } // namespace fir |
30 | |
31 | using namespace fir; |
32 | using namespace mlir; |
33 | |
34 | namespace { |
35 | |
36 | // Conversion of fir control ops to more primitive control-flow. |
37 | // |
38 | // FIR loops that cannot be converted to the affine dialect will remain as |
39 | // `fir.do_loop` operations. These can be converted to control-flow operations. |
40 | |
41 | /// Convert `fir.do_loop` to CFG |
42 | class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> { |
43 | public: |
44 | using OpRewritePattern::OpRewritePattern; |
45 | |
46 | CfgLoopConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce) |
47 | : mlir::OpRewritePattern<fir::DoLoopOp>(ctx), |
48 | forceLoopToExecuteOnce(forceLoopToExecuteOnce) {} |
49 | |
50 | mlir::LogicalResult |
51 | matchAndRewrite(DoLoopOp loop, |
52 | mlir::PatternRewriter &rewriter) const override { |
53 | auto loc = loop.getLoc(); |
54 | |
55 | // Create the start and end blocks that will wrap the DoLoopOp with an |
56 | // initalizer and an end point |
57 | auto *initBlock = rewriter.getInsertionBlock(); |
58 | auto initPos = rewriter.getInsertionPoint(); |
59 | auto *endBlock = rewriter.splitBlock(initBlock, initPos); |
60 | |
61 | // Split the first DoLoopOp block in two parts. The part before will be the |
62 | // conditional block since it already has the induction variable and |
63 | // loop-carried values as arguments. |
64 | auto *conditionalBlock = &loop.getRegion().front(); |
65 | conditionalBlock->addArgument(rewriter.getIndexType(), loc); |
66 | auto *firstBlock = |
67 | rewriter.splitBlock(conditionalBlock, conditionalBlock->begin()); |
68 | auto *lastBlock = &loop.getRegion().back(); |
69 | |
70 | // Move the blocks from the DoLoopOp between initBlock and endBlock |
71 | rewriter.inlineRegionBefore(loop.getRegion(), endBlock); |
72 | |
73 | // Get loop values from the DoLoopOp |
74 | auto low = loop.getLowerBound(); |
75 | auto high = loop.getUpperBound(); |
76 | assert(low && high && "must be a Value" ); |
77 | auto step = loop.getStep(); |
78 | |
79 | // Initalization block |
80 | rewriter.setInsertionPointToEnd(initBlock); |
81 | auto diff = rewriter.create<mlir::arith::SubIOp>(loc, high, low); |
82 | auto distance = rewriter.create<mlir::arith::AddIOp>(loc, diff, step); |
83 | mlir::Value iters = |
84 | rewriter.create<mlir::arith::DivSIOp>(loc, distance, step); |
85 | |
86 | if (forceLoopToExecuteOnce) { |
87 | auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); |
88 | auto cond = rewriter.create<mlir::arith::CmpIOp>( |
89 | loc, arith::CmpIPredicate::sle, iters, zero); |
90 | auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1); |
91 | iters = rewriter.create<mlir::arith::SelectOp>(loc, cond, one, iters); |
92 | } |
93 | |
94 | llvm::SmallVector<mlir::Value> loopOperands; |
95 | loopOperands.push_back(low); |
96 | auto operands = loop.getIterOperands(); |
97 | loopOperands.append(operands.begin(), operands.end()); |
98 | loopOperands.push_back(iters); |
99 | |
100 | rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopOperands); |
101 | |
102 | // Last loop block |
103 | auto *terminator = lastBlock->getTerminator(); |
104 | rewriter.setInsertionPointToEnd(lastBlock); |
105 | auto iv = conditionalBlock->getArgument(0); |
106 | mlir::Value steppedIndex = |
107 | rewriter.create<mlir::arith::AddIOp>(loc, iv, step); |
108 | assert(steppedIndex && "must be a Value" ); |
109 | auto lastArg = conditionalBlock->getNumArguments() - 1; |
110 | auto itersLeft = conditionalBlock->getArgument(lastArg); |
111 | auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1); |
112 | mlir::Value itersMinusOne = |
113 | rewriter.create<mlir::arith::SubIOp>(loc, itersLeft, one); |
114 | |
115 | llvm::SmallVector<mlir::Value> loopCarried; |
116 | loopCarried.push_back(steppedIndex); |
117 | auto begin = loop.getFinalValue() ? std::next(terminator->operand_begin()) |
118 | : terminator->operand_begin(); |
119 | loopCarried.append(begin, terminator->operand_end()); |
120 | loopCarried.push_back(itersMinusOne); |
121 | rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopCarried); |
122 | rewriter.eraseOp(terminator); |
123 | |
124 | // Conditional block |
125 | rewriter.setInsertionPointToEnd(conditionalBlock); |
126 | auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); |
127 | auto comparison = rewriter.create<mlir::arith::CmpIOp>( |
128 | loc, arith::CmpIPredicate::sgt, itersLeft, zero); |
129 | |
130 | rewriter.create<mlir::cf::CondBranchOp>( |
131 | loc, comparison, firstBlock, llvm::ArrayRef<mlir::Value>(), endBlock, |
132 | llvm::ArrayRef<mlir::Value>()); |
133 | |
134 | // The result of the loop operation is the values of the condition block |
135 | // arguments except the induction variable on the last iteration. |
136 | auto args = loop.getFinalValue() |
137 | ? conditionalBlock->getArguments() |
138 | : conditionalBlock->getArguments().drop_front(); |
139 | rewriter.replaceOp(loop, args.drop_back()); |
140 | return success(); |
141 | } |
142 | |
143 | private: |
144 | bool forceLoopToExecuteOnce; |
145 | }; |
146 | |
147 | /// Convert `fir.if` to control-flow |
148 | class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> { |
149 | public: |
150 | using OpRewritePattern::OpRewritePattern; |
151 | |
152 | CfgIfConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce) |
153 | : mlir::OpRewritePattern<fir::IfOp>(ctx) {} |
154 | |
155 | mlir::LogicalResult |
156 | matchAndRewrite(IfOp ifOp, mlir::PatternRewriter &rewriter) const override { |
157 | auto loc = ifOp.getLoc(); |
158 | |
159 | // Split the block containing the 'fir.if' into two parts. The part before |
160 | // will contain the condition, the part after will be the continuation |
161 | // point. |
162 | auto *condBlock = rewriter.getInsertionBlock(); |
163 | auto opPosition = rewriter.getInsertionPoint(); |
164 | auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); |
165 | mlir::Block *continueBlock; |
166 | if (ifOp.getNumResults() == 0) { |
167 | continueBlock = remainingOpsBlock; |
168 | } else { |
169 | continueBlock = rewriter.createBlock( |
170 | remainingOpsBlock, ifOp.getResultTypes(), |
171 | llvm::SmallVector<mlir::Location>(ifOp.getNumResults(), loc)); |
172 | rewriter.create<mlir::cf::BranchOp>(loc, remainingOpsBlock); |
173 | } |
174 | |
175 | // Move blocks from the "then" region to the region containing 'fir.if', |
176 | // place it before the continuation block, and branch to it. |
177 | auto &ifOpRegion = ifOp.getThenRegion(); |
178 | auto *ifOpBlock = &ifOpRegion.front(); |
179 | auto *ifOpTerminator = ifOpRegion.back().getTerminator(); |
180 | auto ifOpTerminatorOperands = ifOpTerminator->getOperands(); |
181 | rewriter.setInsertionPointToEnd(&ifOpRegion.back()); |
182 | rewriter.create<mlir::cf::BranchOp>(loc, continueBlock, |
183 | ifOpTerminatorOperands); |
184 | rewriter.eraseOp(ifOpTerminator); |
185 | rewriter.inlineRegionBefore(ifOpRegion, continueBlock); |
186 | |
187 | // Move blocks from the "else" region (if present) to the region containing |
188 | // 'fir.if', place it before the continuation block and branch to it. It |
189 | // will be placed after the "then" regions. |
190 | auto *otherwiseBlock = continueBlock; |
191 | auto &otherwiseRegion = ifOp.getElseRegion(); |
192 | if (!otherwiseRegion.empty()) { |
193 | otherwiseBlock = &otherwiseRegion.front(); |
194 | auto *otherwiseTerm = otherwiseRegion.back().getTerminator(); |
195 | auto otherwiseTermOperands = otherwiseTerm->getOperands(); |
196 | rewriter.setInsertionPointToEnd(&otherwiseRegion.back()); |
197 | rewriter.create<mlir::cf::BranchOp>(loc, continueBlock, |
198 | otherwiseTermOperands); |
199 | rewriter.eraseOp(otherwiseTerm); |
200 | rewriter.inlineRegionBefore(otherwiseRegion, continueBlock); |
201 | } |
202 | |
203 | rewriter.setInsertionPointToEnd(condBlock); |
204 | rewriter.create<mlir::cf::CondBranchOp>( |
205 | loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(), |
206 | otherwiseBlock, llvm::ArrayRef<mlir::Value>()); |
207 | rewriter.replaceOp(ifOp, continueBlock->getArguments()); |
208 | return success(); |
209 | } |
210 | }; |
211 | |
212 | /// Convert `fir.iter_while` to control-flow. |
213 | class CfgIterWhileConv : public mlir::OpRewritePattern<fir::IterWhileOp> { |
214 | public: |
215 | using OpRewritePattern::OpRewritePattern; |
216 | |
217 | CfgIterWhileConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce) |
218 | : mlir::OpRewritePattern<fir::IterWhileOp>(ctx) {} |
219 | |
220 | mlir::LogicalResult |
221 | matchAndRewrite(fir::IterWhileOp whileOp, |
222 | mlir::PatternRewriter &rewriter) const override { |
223 | auto loc = whileOp.getLoc(); |
224 | |
225 | // Start by splitting the block containing the 'fir.do_loop' into two parts. |
226 | // The part before will get the init code, the part after will be the end |
227 | // point. |
228 | auto *initBlock = rewriter.getInsertionBlock(); |
229 | auto initPosition = rewriter.getInsertionPoint(); |
230 | auto *endBlock = rewriter.splitBlock(initBlock, initPosition); |
231 | |
232 | // Use the first block of the loop body as the condition block since it is |
233 | // the block that has the induction variable and loop-carried values as |
234 | // arguments. Split out all operations from the first block into a new |
235 | // block. Move all body blocks from the loop body region to the region |
236 | // containing the loop. |
237 | auto *conditionBlock = &whileOp.getRegion().front(); |
238 | auto *firstBodyBlock = |
239 | rewriter.splitBlock(conditionBlock, conditionBlock->begin()); |
240 | auto *lastBodyBlock = &whileOp.getRegion().back(); |
241 | rewriter.inlineRegionBefore(whileOp.getRegion(), endBlock); |
242 | auto iv = conditionBlock->getArgument(0); |
243 | auto iterateVar = conditionBlock->getArgument(1); |
244 | |
245 | // Append the induction variable stepping logic to the last body block and |
246 | // branch back to the condition block. Loop-carried values are taken from |
247 | // operands of the loop terminator. |
248 | auto *terminator = lastBodyBlock->getTerminator(); |
249 | rewriter.setInsertionPointToEnd(lastBodyBlock); |
250 | auto step = whileOp.getStep(); |
251 | mlir::Value stepped = rewriter.create<mlir::arith::AddIOp>(loc, iv, step); |
252 | assert(stepped && "must be a Value" ); |
253 | |
254 | llvm::SmallVector<mlir::Value> loopCarried; |
255 | loopCarried.push_back(stepped); |
256 | auto begin = whileOp.getFinalValue() |
257 | ? std::next(terminator->operand_begin()) |
258 | : terminator->operand_begin(); |
259 | loopCarried.append(begin, terminator->operand_end()); |
260 | rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, loopCarried); |
261 | rewriter.eraseOp(terminator); |
262 | |
263 | // Compute loop bounds before branching to the condition. |
264 | rewriter.setInsertionPointToEnd(initBlock); |
265 | auto lowerBound = whileOp.getLowerBound(); |
266 | auto upperBound = whileOp.getUpperBound(); |
267 | assert(lowerBound && upperBound && "must be a Value" ); |
268 | |
269 | // The initial values of loop-carried values is obtained from the operands |
270 | // of the loop operation. |
271 | llvm::SmallVector<mlir::Value> destOperands; |
272 | destOperands.push_back(lowerBound); |
273 | auto iterOperands = whileOp.getIterOperands(); |
274 | destOperands.append(iterOperands.begin(), iterOperands.end()); |
275 | rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, destOperands); |
276 | |
277 | // With the body block done, we can fill in the condition block. |
278 | rewriter.setInsertionPointToEnd(conditionBlock); |
279 | // The comparison depends on the sign of the step value. We fully expect |
280 | // this expression to be folded by the optimizer or LLVM. This expression |
281 | // is written this way so that `step == 0` always returns `false`. |
282 | auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); |
283 | auto compl0 = rewriter.create<mlir::arith::CmpIOp>( |
284 | loc, arith::CmpIPredicate::slt, zero, step); |
285 | auto compl1 = rewriter.create<mlir::arith::CmpIOp>( |
286 | loc, arith::CmpIPredicate::sle, iv, upperBound); |
287 | auto compl2 = rewriter.create<mlir::arith::CmpIOp>( |
288 | loc, arith::CmpIPredicate::slt, step, zero); |
289 | auto compl3 = rewriter.create<mlir::arith::CmpIOp>( |
290 | loc, arith::CmpIPredicate::sle, upperBound, iv); |
291 | auto cmp0 = rewriter.create<mlir::arith::AndIOp>(loc, compl0, compl1); |
292 | auto cmp1 = rewriter.create<mlir::arith::AndIOp>(loc, compl2, compl3); |
293 | auto cmp2 = rewriter.create<mlir::arith::OrIOp>(loc, cmp0, cmp1); |
294 | // Remember to AND in the early-exit bool. |
295 | auto comparison = |
296 | rewriter.create<mlir::arith::AndIOp>(loc, iterateVar, cmp2); |
297 | rewriter.create<mlir::cf::CondBranchOp>( |
298 | loc, comparison, firstBodyBlock, llvm::ArrayRef<mlir::Value>(), |
299 | endBlock, llvm::ArrayRef<mlir::Value>()); |
300 | // The result of the loop operation is the values of the condition block |
301 | // arguments except the induction variable on the last iteration. |
302 | auto args = whileOp.getFinalValue() |
303 | ? conditionBlock->getArguments() |
304 | : conditionBlock->getArguments().drop_front(); |
305 | rewriter.replaceOp(whileOp, args); |
306 | return success(); |
307 | } |
308 | }; |
309 | |
310 | /// Convert FIR structured control flow ops to CFG ops. |
311 | class CfgConversion : public fir::impl::CFGConversionBase<CfgConversion> { |
312 | public: |
313 | using CFGConversionBase<CfgConversion>::CFGConversionBase; |
314 | |
315 | void runOnOperation() override { |
316 | auto *context = &this->getContext(); |
317 | mlir::RewritePatternSet patterns(context); |
318 | fir::populateCfgConversionRewrites(patterns, this->forceLoopToExecuteOnce); |
319 | mlir::ConversionTarget target(*context); |
320 | target.addLegalDialect<mlir::affine::AffineDialect, |
321 | mlir::cf::ControlFlowDialect, FIROpsDialect, |
322 | mlir::func::FuncDialect>(); |
323 | |
324 | // apply the patterns |
325 | target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>(); |
326 | target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); |
327 | if (mlir::failed(mlir::applyPartialConversion(this->getOperation(), target, |
328 | std::move(patterns)))) { |
329 | mlir::emitError(mlir::UnknownLoc::get(context), |
330 | "error in converting to CFG\n" ); |
331 | this->signalPassFailure(); |
332 | } |
333 | } |
334 | }; |
335 | |
336 | } // namespace |
337 | |
338 | /// Expose conversion rewriters to other passes |
339 | void fir::populateCfgConversionRewrites(mlir::RewritePatternSet &patterns, |
340 | bool forceLoopToExecuteOnce) { |
341 | patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>( |
342 | patterns.getContext(), forceLoopToExecuteOnce); |
343 | } |
344 | |