1 | //===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a pass to convert SCF dialect ops (scf.for, scf.if, |
10 | // scf.while, scf.parallel, etc.) into ControlFlow (cf) dialect CFG ops. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" |
15 | |
16 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
18 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
19 | #include "mlir/Dialect/SCF/IR/SCF.h" |
20 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
21 | #include "mlir/IR/Builders.h" |
22 | #include "mlir/IR/BuiltinOps.h" |
23 | #include "mlir/IR/IRMapping.h" |
24 | #include "mlir/IR/MLIRContext.h" |
25 | #include "mlir/IR/PatternMatch.h" |
26 | #include "mlir/Transforms/DialectConversion.h" |
27 | #include "mlir/Transforms/Passes.h" |
28 | |
29 | namespace mlir { |
30 | #define GEN_PASS_DEF_SCFTOCONTROLFLOWPASS |
31 | #include "mlir/Conversion/Passes.h.inc" |
32 | } // namespace mlir |
33 | |
34 | using namespace mlir; |
35 | using namespace mlir::scf; |
36 | |
37 | namespace { |
38 | |
39 | struct SCFToControlFlowPass |
40 | : public impl::SCFToControlFlowPassBase<SCFToControlFlowPass> { |
41 | void runOnOperation() override; |
42 | }; |
43 | |
44 | // Create a CFG subgraph for the loop around its body blocks (if the body |
45 | // contained other loops, they have been already lowered to a flow of blocks). |
46 | // Maintain the invariants that a CFG subgraph created for any loop has a single |
47 | // entry and a single exit, and that the entry/exit blocks are respectively |
48 | // first/last blocks in the parent region. The original loop operation is |
49 | // replaced by the initialization operations that set up the initial value of |
50 | // the loop induction variable (%iv) and computes the loop bounds that are loop- |
51 | // invariant for affine loops. The operations following the original scf.for |
52 | // are split out into a separate continuation (exit) block. A condition block is |
53 | // created before the continuation block. It checks the exit condition of the |
54 | // loop and branches either to the continuation block, or to the first block of |
55 | // the body. The condition block takes as arguments the values of the induction |
56 | // variable followed by loop-carried values. Since it dominates both the body |
57 | // blocks and the continuation block, loop-carried values are visible in all of |
58 | // those blocks. Induction variable modification is appended to the last block |
59 | // of the body (which is the exit block from the body subgraph thanks to the |
60 | // invariant we maintain) along with a branch that loops back to the condition |
61 | // block. Loop-carried values are the loop terminator operands, which are |
62 | // forwarded to the branch. |
63 | // |
64 | // +---------------------------------+ |
65 | // | <code before the ForOp> | |
66 | // | <definitions of %init...> | |
67 | // | <compute initial %iv value> | |
68 | // | cf.br cond(%iv, %init...) | |
69 | // +---------------------------------+ |
70 | // | |
71 | // -------| | |
72 | // | v v |
73 | // | +--------------------------------+ |
74 | // | | cond(%iv, %init...): | |
75 | // | | <compare %iv to upper bound> | |
76 | // | | cf.cond_br %r, body, end | |
77 | // | +--------------------------------+ |
78 | // | | | |
79 | // | | -------------| |
80 | // | v | |
81 | // | +--------------------------------+ | |
82 | // | | body-first: | | |
83 | // | | <%init visible by dominance> | | |
84 | // | | <body contents> | | |
85 | // | +--------------------------------+ | |
86 | // | | | |
87 | // | ... | |
88 | // | | | |
89 | // | +--------------------------------+ | |
90 | // | | body-last: | | |
91 | // | | <body contents> | | |
92 | // | | <operands of yield = %yields>| | |
93 | // | | %new_iv =<add step to %iv> | | |
94 | // | | cf.br cond(%new_iv, %yields) | | |
95 | // | +--------------------------------+ | |
96 | // | | | |
97 | // |----------- |-------------------- |
98 | // v |
99 | // +--------------------------------+ |
100 | // | end: | |
101 | // | <code after the ForOp> | |
102 | // | <%init visible by dominance> | |
103 | // +--------------------------------+ |
104 | // |
105 | struct ForLowering : public OpRewritePattern<ForOp> { |
106 | using OpRewritePattern<ForOp>::OpRewritePattern; |
107 | |
108 | LogicalResult matchAndRewrite(ForOp forOp, |
109 | PatternRewriter &rewriter) const override; |
110 | }; |
111 | |
112 | // Create a CFG subgraph for the scf.if operation (including its "then" and |
113 | // optional "else" operation blocks). We maintain the invariants that the |
114 | // subgraph has a single entry and a single exit point, and that the entry/exit |
115 | // blocks are respectively the first/last block of the enclosing region. The |
116 | // operations following the scf.if are split into a continuation (subgraph |
117 | // exit) block. The condition is lowered to a chain of blocks that implement the |
118 | // short-circuit scheme. The "scf.if" operation is replaced with a conditional |
119 | // branch to either the first block of the "then" region, or to the first block |
120 | // of the "else" region. In these blocks, "scf.yield" is unconditional branches |
121 | // to the post-dominating block. When the "scf.if" does not return values, the |
122 | // post-dominating block is the same as the continuation block. When it returns |
123 | // values, the post-dominating block is a new block with arguments that |
124 | // correspond to the values returned by the "scf.if" that unconditionally |
125 | // branches to the continuation block. This allows block arguments to dominate |
126 | // any uses of the hitherto "scf.if" results that they replaced. (Inserting a |
127 | // new block allows us to avoid modifying the argument list of an existing |
128 | // block, which is illegal in a conversion pattern). When the "else" region is |
129 | // empty, which is only allowed for "scf.if"s that don't return values, the |
130 | // condition branches directly to the continuation block. |
131 | // |
132 | // CFG for a scf.if with else and without results. |
133 | // |
134 | // +--------------------------------+ |
135 | // | <code before the IfOp> | |
136 | // | cf.cond_br %cond, %then, %else | |
137 | // +--------------------------------+ |
138 | // | | |
139 | // | --------------| |
140 | // v | |
141 | // +--------------------------------+ | |
142 | // | then: | | |
143 | // | <then contents> | | |
144 | // | cf.br continue | | |
145 | // +--------------------------------+ | |
146 | // | | |
147 | // |---------- |------------- |
148 | // | V |
149 | // | +--------------------------------+ |
150 | // | | else: | |
151 | // | | <else contents> | |
152 | // | | cf.br continue | |
153 | // | +--------------------------------+ |
154 | // | | |
155 | // ------| | |
156 | // v v |
157 | // +--------------------------------+ |
158 | // | continue: | |
159 | // | <code after the IfOp> | |
160 | // +--------------------------------+ |
161 | // |
162 | // CFG for a scf.if with results. |
163 | // |
164 | // +--------------------------------+ |
165 | // | <code before the IfOp> | |
166 | // | cf.cond_br %cond, %then, %else | |
167 | // +--------------------------------+ |
168 | // | | |
169 | // | --------------| |
170 | // v | |
171 | // +--------------------------------+ | |
172 | // | then: | | |
173 | // | <then contents> | | |
174 | // | cf.br dom(%args...) | | |
175 | // +--------------------------------+ | |
176 | // | | |
177 | // |---------- |------------- |
178 | // | V |
179 | // | +--------------------------------+ |
180 | // | | else: | |
181 | // | | <else contents> | |
182 | // | | cf.br dom(%args...) | |
183 | // | +--------------------------------+ |
184 | // | | |
185 | // ------| | |
186 | // v v |
187 | // +--------------------------------+ |
188 | // | dom(%args...): | |
189 | // | cf.br continue | |
190 | // +--------------------------------+ |
191 | // | |
192 | // v |
193 | // +--------------------------------+ |
194 | // | continue: | |
195 | // | <code after the IfOp> | |
196 | // +--------------------------------+ |
197 | // |
198 | struct IfLowering : public OpRewritePattern<IfOp> { |
199 | using OpRewritePattern<IfOp>::OpRewritePattern; |
200 | |
201 | LogicalResult matchAndRewrite(IfOp ifOp, |
202 | PatternRewriter &rewriter) const override; |
203 | }; |
204 | |
205 | struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> { |
206 | using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern; |
207 | |
208 | LogicalResult matchAndRewrite(ExecuteRegionOp op, |
209 | PatternRewriter &rewriter) const override; |
210 | }; |
211 | |
212 | struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> { |
213 | using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern; |
214 | |
215 | LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp, |
216 | PatternRewriter &rewriter) const override; |
217 | }; |
218 | |
219 | /// Create a CFG subgraph for this loop construct. The regions of the loop need |
220 | /// not be a single block anymore (for example, if other SCF constructs that |
221 | /// they contain have been already converted to CFG), but need to be single-exit |
222 | /// from the last block of each region. The operations following the original |
223 | /// WhileOp are split into a new continuation block. Both regions of the WhileOp |
224 | /// are inlined, and their terminators are rewritten to organize the control |
225 | /// flow implementing the loop as follows. |
226 | /// |
227 | /// +---------------------------------+ |
228 | /// | <code before the WhileOp> | |
229 | /// | cf.br ^before(%operands...) | |
230 | /// +---------------------------------+ |
231 | /// | |
232 | /// -------| | |
233 | /// | v v |
234 | /// | +--------------------------------+ |
235 | /// | | ^before(%bargs...): | |
236 | /// | | %vals... = <some payload> | |
237 | /// | +--------------------------------+ |
238 | /// | | |
239 | /// | ... |
240 | /// | | |
241 | /// | +--------------------------------+ |
242 | /// | | ^before-last: | |
243 | /// | | %cond = <compute condition> | |
244 | /// | | cf.cond_br %cond, | |
245 | /// | | ^after(%vals...), ^cont | |
246 | /// | +--------------------------------+ |
247 | /// | | | |
248 | /// | | -------------| |
249 | /// | v | |
250 | /// | +--------------------------------+ | |
251 | /// | | ^after(%aargs...): | | |
252 | /// | | <body contents> | | |
253 | /// | +--------------------------------+ | |
254 | /// | | | |
255 | /// | ... | |
256 | /// | | | |
257 | /// | +--------------------------------+ | |
258 | /// | | ^after-last: | | |
259 | /// | | %yields... = <some payload> | | |
260 | /// | | cf.br ^before(%yields...) | | |
261 | /// | +--------------------------------+ | |
262 | /// | | | |
263 | /// |----------- |-------------------- |
264 | /// v |
265 | /// +--------------------------------+ |
266 | /// | ^cont: | |
267 | /// | <code after the WhileOp> | |
268 | /// | <%vals from 'before' region | |
269 | /// | visible by dominance> | |
270 | /// +--------------------------------+ |
271 | /// |
272 | /// Values are communicated between ex-regions (the groups of blocks that used |
273 | /// to form a region before inlining) through block arguments of their |
274 | /// entry blocks, which are visible in all other dominated blocks. Similarly, |
275 | /// the results of the WhileOp are defined in the 'before' region, which is |
276 | /// required to have a single existing block, and are therefore accessible in |
277 | /// the continuation block due to dominance. |
278 | struct WhileLowering : public OpRewritePattern<WhileOp> { |
279 | using OpRewritePattern<WhileOp>::OpRewritePattern; |
280 | |
281 | LogicalResult matchAndRewrite(WhileOp whileOp, |
282 | PatternRewriter &rewriter) const override; |
283 | }; |
284 | |
285 | /// Optimized version of the above for the case of the "after" region merely |
286 | /// forwarding its arguments back to the "before" region (i.e., a "do-while" |
287 | /// loop). This avoids inlining the "after" region completely and branches back |
288 | /// to the "before" entry instead. |
289 | struct DoWhileLowering : public OpRewritePattern<WhileOp> { |
290 | using OpRewritePattern<WhileOp>::OpRewritePattern; |
291 | |
292 | LogicalResult matchAndRewrite(WhileOp whileOp, |
293 | PatternRewriter &rewriter) const override; |
294 | }; |
295 | |
296 | /// Lower an `scf.index_switch` operation to a `cf.switch` operation. |
297 | struct IndexSwitchLowering : public OpRewritePattern<IndexSwitchOp> { |
298 | using OpRewritePattern::OpRewritePattern; |
299 | |
300 | LogicalResult matchAndRewrite(IndexSwitchOp op, |
301 | PatternRewriter &rewriter) const override; |
302 | }; |
303 | |
304 | /// Lower an `scf.forall` operation to an `scf.parallel` op, assuming that it |
305 | /// has no shared outputs. Ops with shared outputs should be bufferized first. |
306 | /// Specialized lowerings for `scf.forall` (e.g., for GPUs) exist in other |
307 | /// dialects/passes. |
308 | struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> { |
309 | using OpRewritePattern<mlir::scf::ForallOp>::OpRewritePattern; |
310 | |
311 | LogicalResult matchAndRewrite(mlir::scf::ForallOp forallOp, |
312 | PatternRewriter &rewriter) const override; |
313 | }; |
314 | |
315 | } // namespace |
316 | |
317 | LogicalResult ForLowering::matchAndRewrite(ForOp forOp, |
318 | PatternRewriter &rewriter) const { |
319 | Location loc = forOp.getLoc(); |
320 | |
321 | // Start by splitting the block containing the 'scf.for' into two parts. |
322 | // The part before will get the init code, the part after will be the end |
323 | // point. |
324 | auto *initBlock = rewriter.getInsertionBlock(); |
325 | auto initPosition = rewriter.getInsertionPoint(); |
326 | auto *endBlock = rewriter.splitBlock(block: initBlock, before: initPosition); |
327 | |
328 | // Use the first block of the loop body as the condition block since it is the |
329 | // block that has the induction variable and loop-carried values as arguments. |
330 | // Split out all operations from the first block into a new block. Move all |
331 | // body blocks from the loop body region to the region containing the loop. |
332 | auto *conditionBlock = &forOp.getRegion().front(); |
333 | auto *firstBodyBlock = |
334 | rewriter.splitBlock(block: conditionBlock, before: conditionBlock->begin()); |
335 | auto *lastBodyBlock = &forOp.getRegion().back(); |
336 | rewriter.inlineRegionBefore(forOp.getRegion(), endBlock); |
337 | auto iv = conditionBlock->getArgument(0); |
338 | |
339 | // Append the induction variable stepping logic to the last body block and |
340 | // branch back to the condition block. Loop-carried values are taken from |
341 | // operands of the loop terminator. |
342 | Operation *terminator = lastBodyBlock->getTerminator(); |
343 | rewriter.setInsertionPointToEnd(lastBodyBlock); |
344 | auto step = forOp.getStep(); |
345 | auto stepped = rewriter.create<arith::AddIOp>(loc, iv, step).getResult(); |
346 | if (!stepped) |
347 | return failure(); |
348 | |
349 | SmallVector<Value, 8> loopCarried; |
350 | loopCarried.push_back(Elt: stepped); |
351 | loopCarried.append(in_start: terminator->operand_begin(), in_end: terminator->operand_end()); |
352 | rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried); |
353 | rewriter.eraseOp(op: terminator); |
354 | |
355 | // Compute loop bounds before branching to the condition. |
356 | rewriter.setInsertionPointToEnd(initBlock); |
357 | Value lowerBound = forOp.getLowerBound(); |
358 | Value upperBound = forOp.getUpperBound(); |
359 | if (!lowerBound || !upperBound) |
360 | return failure(); |
361 | |
362 | // The initial values of loop-carried values is obtained from the operands |
363 | // of the loop operation. |
364 | SmallVector<Value, 8> destOperands; |
365 | destOperands.push_back(Elt: lowerBound); |
366 | llvm::append_range(destOperands, forOp.getInitArgs()); |
367 | rewriter.create<cf::BranchOp>(loc, conditionBlock, destOperands); |
368 | |
369 | // With the body block done, we can fill in the condition block. |
370 | rewriter.setInsertionPointToEnd(conditionBlock); |
371 | auto comparison = rewriter.create<arith::CmpIOp>( |
372 | loc, arith::CmpIPredicate::slt, iv, upperBound); |
373 | |
374 | auto condBranchOp = rewriter.create<cf::CondBranchOp>( |
375 | loc, comparison, firstBodyBlock, ArrayRef<Value>(), endBlock, |
376 | ArrayRef<Value>()); |
377 | |
378 | // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the |
379 | // llvm.loop_annotation attribute. |
380 | SmallVector<NamedAttribute> llvmAttrs; |
381 | llvm::copy_if(forOp->getAttrs(), std::back_inserter(x&: llvmAttrs), |
382 | [](auto attr) { |
383 | return isa<LLVM::LLVMDialect>(attr.getValue().getDialect()); |
384 | }); |
385 | condBranchOp->setDiscardableAttrs(llvmAttrs); |
386 | // The result of the loop operation is the values of the condition block |
387 | // arguments except the induction variable on the last iteration. |
388 | rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front()); |
389 | return success(); |
390 | } |
391 | |
392 | LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, |
393 | PatternRewriter &rewriter) const { |
394 | auto loc = ifOp.getLoc(); |
395 | |
396 | // Start by splitting the block containing the 'scf.if' into two parts. |
397 | // The part before will contain the condition, the part after will be the |
398 | // continuation point. |
399 | auto *condBlock = rewriter.getInsertionBlock(); |
400 | auto opPosition = rewriter.getInsertionPoint(); |
401 | auto *remainingOpsBlock = rewriter.splitBlock(block: condBlock, before: opPosition); |
402 | Block *continueBlock; |
403 | if (ifOp.getNumResults() == 0) { |
404 | continueBlock = remainingOpsBlock; |
405 | } else { |
406 | continueBlock = |
407 | rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes(), |
408 | SmallVector<Location>(ifOp.getNumResults(), loc)); |
409 | rewriter.create<cf::BranchOp>(loc, remainingOpsBlock); |
410 | } |
411 | |
412 | // Move blocks from the "then" region to the region containing 'scf.if', |
413 | // place it before the continuation block, and branch to it. |
414 | auto &thenRegion = ifOp.getThenRegion(); |
415 | auto *thenBlock = &thenRegion.front(); |
416 | Operation *thenTerminator = thenRegion.back().getTerminator(); |
417 | ValueRange thenTerminatorOperands = thenTerminator->getOperands(); |
418 | rewriter.setInsertionPointToEnd(&thenRegion.back()); |
419 | rewriter.create<cf::BranchOp>(loc, continueBlock, thenTerminatorOperands); |
420 | rewriter.eraseOp(op: thenTerminator); |
421 | rewriter.inlineRegionBefore(thenRegion, continueBlock); |
422 | |
423 | // Move blocks from the "else" region (if present) to the region containing |
424 | // 'scf.if', place it before the continuation block and branch to it. It |
425 | // will be placed after the "then" regions. |
426 | auto *elseBlock = continueBlock; |
427 | auto &elseRegion = ifOp.getElseRegion(); |
428 | if (!elseRegion.empty()) { |
429 | elseBlock = &elseRegion.front(); |
430 | Operation *elseTerminator = elseRegion.back().getTerminator(); |
431 | ValueRange elseTerminatorOperands = elseTerminator->getOperands(); |
432 | rewriter.setInsertionPointToEnd(&elseRegion.back()); |
433 | rewriter.create<cf::BranchOp>(loc, continueBlock, elseTerminatorOperands); |
434 | rewriter.eraseOp(op: elseTerminator); |
435 | rewriter.inlineRegionBefore(elseRegion, continueBlock); |
436 | } |
437 | |
438 | rewriter.setInsertionPointToEnd(condBlock); |
439 | rewriter.create<cf::CondBranchOp>(loc, ifOp.getCondition(), thenBlock, |
440 | /*trueArgs=*/ArrayRef<Value>(), elseBlock, |
441 | /*falseArgs=*/ArrayRef<Value>()); |
442 | |
443 | // Ok, we're done! |
444 | rewriter.replaceOp(ifOp, continueBlock->getArguments()); |
445 | return success(); |
446 | } |
447 | |
448 | LogicalResult |
449 | ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op, |
450 | PatternRewriter &rewriter) const { |
451 | auto loc = op.getLoc(); |
452 | |
453 | auto *condBlock = rewriter.getInsertionBlock(); |
454 | auto opPosition = rewriter.getInsertionPoint(); |
455 | auto *remainingOpsBlock = rewriter.splitBlock(block: condBlock, before: opPosition); |
456 | |
457 | auto ®ion = op.getRegion(); |
458 | rewriter.setInsertionPointToEnd(condBlock); |
459 | rewriter.create<cf::BranchOp>(loc, ®ion.front()); |
460 | |
461 | for (Block &block : region) { |
462 | if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) { |
463 | ValueRange terminatorOperands = terminator->getOperands(); |
464 | rewriter.setInsertionPointToEnd(&block); |
465 | rewriter.create<cf::BranchOp>(loc, remainingOpsBlock, terminatorOperands); |
466 | rewriter.eraseOp(terminator); |
467 | } |
468 | } |
469 | |
470 | rewriter.inlineRegionBefore(region, remainingOpsBlock); |
471 | |
472 | SmallVector<Value> vals; |
473 | SmallVector<Location> argLocs(op.getNumResults(), op->getLoc()); |
474 | for (auto arg : |
475 | remainingOpsBlock->addArguments(op->getResultTypes(), argLocs)) |
476 | vals.push_back(arg); |
477 | rewriter.replaceOp(op, vals); |
478 | return success(); |
479 | } |
480 | |
481 | LogicalResult |
482 | ParallelLowering::matchAndRewrite(ParallelOp parallelOp, |
483 | PatternRewriter &rewriter) const { |
484 | Location loc = parallelOp.getLoc(); |
485 | auto reductionOp = dyn_cast<ReduceOp>(parallelOp.getBody()->getTerminator()); |
486 | if (!reductionOp) { |
487 | return failure(); |
488 | } |
489 | |
490 | // For a parallel loop, we essentially need to create an n-dimensional loop |
491 | // nest. We do this by translating to scf.for ops and have those lowered in |
492 | // a further rewrite. If a parallel loop contains reductions (and thus returns |
493 | // values), forward the initial values for the reductions down the loop |
494 | // hierarchy and bubble up the results by modifying the "yield" terminator. |
495 | SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.getInitVals()); |
496 | SmallVector<Value, 4> ivs; |
497 | ivs.reserve(N: parallelOp.getNumLoops()); |
498 | bool first = true; |
499 | SmallVector<Value, 4> loopResults(iterArgs); |
500 | for (auto [iv, lower, upper, step] : |
501 | llvm::zip(parallelOp.getInductionVars(), parallelOp.getLowerBound(), |
502 | parallelOp.getUpperBound(), parallelOp.getStep())) { |
503 | ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs); |
504 | ivs.push_back(forOp.getInductionVar()); |
505 | auto iterRange = forOp.getRegionIterArgs(); |
506 | iterArgs.assign(iterRange.begin(), iterRange.end()); |
507 | |
508 | if (first) { |
509 | // Store the results of the outermost loop that will be used to replace |
510 | // the results of the parallel loop when it is fully rewritten. |
511 | loopResults.assign(forOp.result_begin(), forOp.result_end()); |
512 | first = false; |
513 | } else if (!forOp.getResults().empty()) { |
514 | // A loop is constructed with an empty "yield" terminator if there are |
515 | // no results. |
516 | rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); |
517 | rewriter.create<scf::YieldOp>(loc, forOp.getResults()); |
518 | } |
519 | |
520 | rewriter.setInsertionPointToStart(forOp.getBody()); |
521 | } |
522 | |
523 | // First, merge reduction blocks into the main region. |
524 | SmallVector<Value> yieldOperands; |
525 | yieldOperands.reserve(N: parallelOp.getNumResults()); |
526 | for (int64_t i = 0, e = parallelOp.getNumResults(); i < e; ++i) { |
527 | Block &reductionBody = reductionOp.getReductions()[i].front(); |
528 | Value arg = iterArgs[yieldOperands.size()]; |
529 | yieldOperands.push_back( |
530 | cast<ReduceReturnOp>(reductionBody.getTerminator()).getResult()); |
531 | rewriter.eraseOp(op: reductionBody.getTerminator()); |
532 | rewriter.inlineBlockBefore(&reductionBody, reductionOp, |
533 | {arg, reductionOp.getOperands()[i]}); |
534 | } |
535 | rewriter.eraseOp(op: reductionOp); |
536 | |
537 | // Then merge the loop body without the terminator. |
538 | Block *newBody = rewriter.getInsertionBlock(); |
539 | if (newBody->empty()) |
540 | rewriter.mergeBlocks(source: parallelOp.getBody(), dest: newBody, argValues: ivs); |
541 | else |
542 | rewriter.inlineBlockBefore(parallelOp.getBody(), newBody->getTerminator(), |
543 | ivs); |
544 | |
545 | // Finally, create the terminator if required (for loops with no results, it |
546 | // has been already created in loop construction). |
547 | if (!yieldOperands.empty()) { |
548 | rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); |
549 | rewriter.create<scf::YieldOp>(loc, yieldOperands); |
550 | } |
551 | |
552 | rewriter.replaceOp(parallelOp, loopResults); |
553 | |
554 | return success(); |
555 | } |
556 | |
557 | LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, |
558 | PatternRewriter &rewriter) const { |
559 | OpBuilder::InsertionGuard guard(rewriter); |
560 | Location loc = whileOp.getLoc(); |
561 | |
562 | // Split the current block before the WhileOp to create the inlining point. |
563 | Block *currentBlock = rewriter.getInsertionBlock(); |
564 | Block *continuation = |
565 | rewriter.splitBlock(block: currentBlock, before: rewriter.getInsertionPoint()); |
566 | |
567 | // Inline both regions. |
568 | Block *after = whileOp.getAfterBody(); |
569 | Block *before = whileOp.getBeforeBody(); |
570 | rewriter.inlineRegionBefore(whileOp.getAfter(), continuation); |
571 | rewriter.inlineRegionBefore(whileOp.getBefore(), after); |
572 | |
573 | // Branch to the "before" region. |
574 | rewriter.setInsertionPointToEnd(currentBlock); |
575 | rewriter.create<cf::BranchOp>(loc, before, whileOp.getInits()); |
576 | |
577 | // Replace terminators with branches. Assuming bodies are SESE, which holds |
578 | // given only the patterns from this file, we only need to look at the last |
579 | // block. This should be reconsidered if we allow break/continue in SCF. |
580 | rewriter.setInsertionPointToEnd(before); |
581 | auto condOp = cast<ConditionOp>(before->getTerminator()); |
582 | rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(), |
583 | after, condOp.getArgs(), |
584 | continuation, ValueRange()); |
585 | |
586 | rewriter.setInsertionPointToEnd(after); |
587 | auto yieldOp = cast<scf::YieldOp>(after->getTerminator()); |
588 | rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before, |
589 | yieldOp.getResults()); |
590 | |
591 | // Replace the op with values "yielded" from the "before" region, which are |
592 | // visible by dominance. |
593 | rewriter.replaceOp(whileOp, condOp.getArgs()); |
594 | |
595 | return success(); |
596 | } |
597 | |
598 | LogicalResult |
599 | DoWhileLowering::matchAndRewrite(WhileOp whileOp, |
600 | PatternRewriter &rewriter) const { |
601 | Block &afterBlock = *whileOp.getAfterBody(); |
602 | if (!llvm::hasSingleElement(C&: afterBlock)) |
603 | return rewriter.notifyMatchFailure(whileOp, |
604 | "do-while simplification applicable " |
605 | "only if 'after' region has no payload"); |
606 | |
607 | auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front()); |
608 | if (!yield || yield.getResults() != afterBlock.getArguments()) |
609 | return rewriter.notifyMatchFailure(whileOp, |
610 | "do-while simplification applicable " |
611 | "only to forwarding 'after' regions"); |
612 | |
613 | // Split the current block before the WhileOp to create the inlining point. |
614 | OpBuilder::InsertionGuard guard(rewriter); |
615 | Block *currentBlock = rewriter.getInsertionBlock(); |
616 | Block *continuation = |
617 | rewriter.splitBlock(block: currentBlock, before: rewriter.getInsertionPoint()); |
618 | |
619 | // Only the "before" region should be inlined. |
620 | Block *before = whileOp.getBeforeBody(); |
621 | rewriter.inlineRegionBefore(whileOp.getBefore(), continuation); |
622 | |
623 | // Branch to the "before" region. |
624 | rewriter.setInsertionPointToEnd(currentBlock); |
625 | rewriter.create<cf::BranchOp>(whileOp.getLoc(), before, whileOp.getInits()); |
626 | |
627 | // Loop around the "before" region based on condition. |
628 | rewriter.setInsertionPointToEnd(before); |
629 | auto condOp = cast<ConditionOp>(before->getTerminator()); |
630 | rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(), |
631 | before, condOp.getArgs(), |
632 | continuation, ValueRange()); |
633 | |
634 | // Replace the op with values "yielded" from the "before" region, which are |
635 | // visible by dominance. |
636 | rewriter.replaceOp(whileOp, condOp.getArgs()); |
637 | |
638 | return success(); |
639 | } |
640 | |
641 | LogicalResult |
642 | IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op, |
643 | PatternRewriter &rewriter) const { |
644 | // Split the block at the op. |
645 | Block *condBlock = rewriter.getInsertionBlock(); |
646 | Block *continueBlock = rewriter.splitBlock(block: condBlock, before: Block::iterator(op)); |
647 | |
648 | // Create the arguments on the continue block with which to replace the |
649 | // results of the op. |
650 | SmallVector<Value> results; |
651 | results.reserve(N: op.getNumResults()); |
652 | for (Type resultType : op.getResultTypes()) |
653 | results.push_back(continueBlock->addArgument(resultType, op.getLoc())); |
654 | |
655 | // Handle the regions. |
656 | auto convertRegion = [&](Region ®ion) -> FailureOr<Block *> { |
657 | Block *block = ®ion.front(); |
658 | |
659 | // Convert the yield terminator to a branch to the continue block. |
660 | auto yield = cast<scf::YieldOp>(block->getTerminator()); |
661 | rewriter.setInsertionPoint(yield); |
662 | rewriter.replaceOpWithNewOp<cf::BranchOp>(yield, continueBlock, |
663 | yield.getOperands()); |
664 | |
665 | // Inline the region. |
666 | rewriter.inlineRegionBefore(region, before: continueBlock); |
667 | return block; |
668 | }; |
669 | |
670 | // Convert the case regions. |
671 | SmallVector<Block *> caseSuccessors; |
672 | SmallVector<int32_t> caseValues; |
673 | caseSuccessors.reserve(N: op.getCases().size()); |
674 | caseValues.reserve(N: op.getCases().size()); |
675 | for (auto [region, value] : llvm::zip(op.getCaseRegions(), op.getCases())) { |
676 | FailureOr<Block *> block = convertRegion(region); |
677 | if (failed(block)) |
678 | return failure(); |
679 | caseSuccessors.push_back(*block); |
680 | caseValues.push_back(value); |
681 | } |
682 | |
683 | // Convert the default region. |
684 | FailureOr<Block *> defaultBlock = convertRegion(op.getDefaultRegion()); |
685 | if (failed(Result: defaultBlock)) |
686 | return failure(); |
687 | |
688 | // Create the switch. |
689 | rewriter.setInsertionPointToEnd(condBlock); |
690 | SmallVector<ValueRange> caseOperands(caseSuccessors.size(), {}); |
691 | |
692 | // Cast switch index to integer case value. |
693 | Value caseValue = rewriter.create<arith::IndexCastOp>( |
694 | op.getLoc(), rewriter.getI32Type(), op.getArg()); |
695 | |
696 | rewriter.create<cf::SwitchOp>( |
697 | op.getLoc(), caseValue, *defaultBlock, ValueRange(), |
698 | rewriter.getDenseI32ArrayAttr(caseValues), caseSuccessors, caseOperands); |
699 | rewriter.replaceOp(op, continueBlock->getArguments()); |
700 | return success(); |
701 | } |
702 | |
703 | LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp, |
704 | PatternRewriter &rewriter) const { |
705 | return scf::forallToParallelLoop(rewriter, forallOp: forallOp); |
706 | } |
707 | |
708 | void mlir::populateSCFToControlFlowConversionPatterns( |
709 | RewritePatternSet &patterns) { |
710 | patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering, |
711 | WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>( |
712 | arg: patterns.getContext()); |
713 | patterns.add<DoWhileLowering>(arg: patterns.getContext(), /*benefit=*/args: 2); |
714 | } |
715 | |
716 | void SCFToControlFlowPass::runOnOperation() { |
717 | RewritePatternSet patterns(&getContext()); |
718 | populateSCFToControlFlowConversionPatterns(patterns); |
719 | |
720 | // Configure conversion to lower out SCF operations. |
721 | ConversionTarget target(getContext()); |
722 | target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp, |
723 | scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>(); |
724 | target.markUnknownOpDynamicallyLegal(fn: [](Operation *) { return true; }); |
725 | if (failed( |
726 | applyPartialConversion(getOperation(), target, std::move(patterns)))) |
727 | signalPassFailure(); |
728 | } |
729 |
Definitions
- SCFToControlFlowPass
- ForLowering
- IfLowering
- ExecuteRegionLowering
- ParallelLowering
- WhileLowering
- DoWhileLowering
- IndexSwitchLowering
- ForallLowering
- matchAndRewrite
- matchAndRewrite
- matchAndRewrite
- matchAndRewrite
- matchAndRewrite
- matchAndRewrite
- matchAndRewrite
- matchAndRewrite
- populateSCFToControlFlowConversionPatterns
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more