1 | //===----------------------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements a pass that flattens structured CIR control-flow
// operations by inlining their regions into the parent function region and
// replacing them with explicit branches.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "PassDetail.h" |
15 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
16 | #include "mlir/IR/Block.h" |
17 | #include "mlir/IR/Builders.h" |
18 | #include "mlir/IR/PatternMatch.h" |
19 | #include "mlir/Support/LogicalResult.h" |
20 | #include "mlir/Transforms/DialectConversion.h" |
21 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
22 | #include "clang/CIR/Dialect/IR/CIRDialect.h" |
23 | #include "clang/CIR/Dialect/Passes.h" |
24 | #include "clang/CIR/MissingFeatures.h" |
25 | |
26 | using namespace mlir; |
27 | using namespace cir; |
28 | |
29 | namespace { |
30 | |
31 | /// Lowers operations with the terminator trait that have a single successor. |
32 | void lowerTerminator(mlir::Operation *op, mlir::Block *dest, |
33 | mlir::PatternRewriter &rewriter) { |
34 | assert(op->hasTrait<mlir::OpTrait::IsTerminator>() && "not a terminator" ); |
35 | mlir::OpBuilder::InsertionGuard guard(rewriter); |
36 | rewriter.setInsertionPoint(op); |
37 | rewriter.replaceOpWithNewOp<cir::BrOp>(op, dest); |
38 | } |
39 | |
40 | /// Walks a region while skipping operations of type `Ops`. This ensures the |
41 | /// callback is not applied to said operations and its children. |
42 | template <typename... Ops> |
43 | void walkRegionSkipping( |
44 | mlir::Region ®ion, |
45 | mlir::function_ref<mlir::WalkResult(mlir::Operation *)> callback) { |
46 | region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) { |
47 | if (isa<Ops...>(op)) |
48 | return mlir::WalkResult::skip(); |
49 | return callback(op); |
50 | }); |
51 | } |
52 | |
53 | struct CIRFlattenCFGPass : public CIRFlattenCFGBase<CIRFlattenCFGPass> { |
54 | |
55 | CIRFlattenCFGPass() = default; |
56 | void runOnOperation() override; |
57 | }; |
58 | |
59 | struct CIRIfFlattening : public mlir::OpRewritePattern<cir::IfOp> { |
60 | using OpRewritePattern<IfOp>::OpRewritePattern; |
61 | |
62 | mlir::LogicalResult |
63 | matchAndRewrite(cir::IfOp ifOp, |
64 | mlir::PatternRewriter &rewriter) const override { |
65 | mlir::OpBuilder::InsertionGuard guard(rewriter); |
66 | mlir::Location loc = ifOp.getLoc(); |
67 | bool emptyElse = ifOp.getElseRegion().empty(); |
68 | mlir::Block *currentBlock = rewriter.getInsertionBlock(); |
69 | mlir::Block *remainingOpsBlock = |
70 | rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); |
71 | mlir::Block *continueBlock; |
72 | if (ifOp->getResults().empty()) |
73 | continueBlock = remainingOpsBlock; |
74 | else |
75 | llvm_unreachable("NYI" ); |
76 | |
77 | // Inline the region |
78 | mlir::Block *thenBeforeBody = &ifOp.getThenRegion().front(); |
79 | mlir::Block *thenAfterBody = &ifOp.getThenRegion().back(); |
80 | rewriter.inlineRegionBefore(ifOp.getThenRegion(), continueBlock); |
81 | |
82 | rewriter.setInsertionPointToEnd(thenAfterBody); |
83 | if (auto thenYieldOp = |
84 | dyn_cast<cir::YieldOp>(thenAfterBody->getTerminator())) { |
85 | rewriter.replaceOpWithNewOp<cir::BrOp>(thenYieldOp, thenYieldOp.getArgs(), |
86 | continueBlock); |
87 | } |
88 | |
89 | rewriter.setInsertionPointToEnd(continueBlock); |
90 | |
91 | // Has else region: inline it. |
92 | mlir::Block *elseBeforeBody = nullptr; |
93 | mlir::Block *elseAfterBody = nullptr; |
94 | if (!emptyElse) { |
95 | elseBeforeBody = &ifOp.getElseRegion().front(); |
96 | elseAfterBody = &ifOp.getElseRegion().back(); |
97 | rewriter.inlineRegionBefore(ifOp.getElseRegion(), continueBlock); |
98 | } else { |
99 | elseBeforeBody = elseAfterBody = continueBlock; |
100 | } |
101 | |
102 | rewriter.setInsertionPointToEnd(currentBlock); |
103 | rewriter.create<cir::BrCondOp>(loc, ifOp.getCondition(), thenBeforeBody, |
104 | elseBeforeBody); |
105 | |
106 | if (!emptyElse) { |
107 | rewriter.setInsertionPointToEnd(elseAfterBody); |
108 | if (auto elseYieldOP = |
109 | dyn_cast<cir::YieldOp>(elseAfterBody->getTerminator())) { |
110 | rewriter.replaceOpWithNewOp<cir::BrOp>( |
111 | elseYieldOP, elseYieldOP.getArgs(), continueBlock); |
112 | } |
113 | } |
114 | |
115 | rewriter.replaceOp(ifOp, continueBlock->getArguments()); |
116 | return mlir::success(); |
117 | } |
118 | }; |
119 | |
120 | class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> { |
121 | public: |
122 | using OpRewritePattern<cir::ScopeOp>::OpRewritePattern; |
123 | |
124 | mlir::LogicalResult |
125 | matchAndRewrite(cir::ScopeOp scopeOp, |
126 | mlir::PatternRewriter &rewriter) const override { |
127 | mlir::OpBuilder::InsertionGuard guard(rewriter); |
128 | mlir::Location loc = scopeOp.getLoc(); |
129 | |
130 | // Empty scope: just remove it. |
131 | // TODO: Remove this logic once CIR uses MLIR infrastructure to remove |
132 | // trivially dead operations. MLIR canonicalizer is too aggressive and we |
133 | // need to either (a) make sure all our ops model all side-effects and/or |
134 | // (b) have more options in the canonicalizer in MLIR to temper |
135 | // aggressiveness level. |
136 | if (scopeOp.isEmpty()) { |
137 | rewriter.eraseOp(scopeOp); |
138 | return mlir::success(); |
139 | } |
140 | |
141 | // Split the current block before the ScopeOp to create the inlining |
142 | // point. |
143 | mlir::Block *currentBlock = rewriter.getInsertionBlock(); |
144 | mlir::Block *continueBlock = |
145 | rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); |
146 | if (scopeOp.getNumResults() > 0) |
147 | continueBlock->addArguments(scopeOp.getResultTypes(), loc); |
148 | |
149 | // Inline body region. |
150 | mlir::Block *beforeBody = &scopeOp.getScopeRegion().front(); |
151 | mlir::Block *afterBody = &scopeOp.getScopeRegion().back(); |
152 | rewriter.inlineRegionBefore(scopeOp.getScopeRegion(), continueBlock); |
153 | |
154 | // Save stack and then branch into the body of the region. |
155 | rewriter.setInsertionPointToEnd(currentBlock); |
156 | assert(!cir::MissingFeatures::stackSaveOp()); |
157 | rewriter.create<cir::BrOp>(loc, mlir::ValueRange(), beforeBody); |
158 | |
159 | // Replace the scopeop return with a branch that jumps out of the body. |
160 | // Stack restore before leaving the body region. |
161 | rewriter.setInsertionPointToEnd(afterBody); |
162 | if (auto yieldOp = dyn_cast<cir::YieldOp>(afterBody->getTerminator())) { |
163 | rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getArgs(), |
164 | continueBlock); |
165 | } |
166 | |
167 | // Replace the op with values return from the body region. |
168 | rewriter.replaceOp(scopeOp, continueBlock->getArguments()); |
169 | |
170 | return mlir::success(); |
171 | } |
172 | }; |
173 | |
174 | class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> { |
175 | public: |
176 | using OpRewritePattern<cir::SwitchOp>::OpRewritePattern; |
177 | |
178 | inline void rewriteYieldOp(mlir::PatternRewriter &rewriter, |
179 | cir::YieldOp yieldOp, |
180 | mlir::Block *destination) const { |
181 | rewriter.setInsertionPoint(yieldOp); |
182 | rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(), |
183 | destination); |
184 | } |
185 | |
186 | // Return the new defaultDestination block. |
187 | Block *condBrToRangeDestination(cir::SwitchOp op, |
188 | mlir::PatternRewriter &rewriter, |
189 | mlir::Block *rangeDestination, |
190 | mlir::Block *defaultDestination, |
191 | const APInt &lowerBound, |
192 | const APInt &upperBound) const { |
193 | assert(lowerBound.sle(upperBound) && "Invalid range" ); |
194 | mlir::Block *resBlock = rewriter.createBlock(defaultDestination); |
195 | cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true); |
196 | cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false); |
197 | |
198 | cir::ConstantOp rangeLength = rewriter.create<cir::ConstantOp>( |
199 | op.getLoc(), cir::IntAttr::get(sIntType, upperBound - lowerBound)); |
200 | |
201 | cir::ConstantOp lowerBoundValue = rewriter.create<cir::ConstantOp>( |
202 | op.getLoc(), cir::IntAttr::get(sIntType, lowerBound)); |
203 | cir::BinOp diffValue = |
204 | rewriter.create<cir::BinOp>(op.getLoc(), sIntType, cir::BinOpKind::Sub, |
205 | op.getCondition(), lowerBoundValue); |
206 | |
207 | // Use unsigned comparison to check if the condition is in the range. |
208 | cir::CastOp uDiffValue = rewriter.create<cir::CastOp>( |
209 | op.getLoc(), uIntType, CastKind::integral, diffValue); |
210 | cir::CastOp uRangeLength = rewriter.create<cir::CastOp>( |
211 | op.getLoc(), uIntType, CastKind::integral, rangeLength); |
212 | |
213 | cir::CmpOp cmpResult = rewriter.create<cir::CmpOp>( |
214 | op.getLoc(), cir::BoolType::get(op.getContext()), cir::CmpOpKind::le, |
215 | uDiffValue, uRangeLength); |
216 | rewriter.create<cir::BrCondOp>(op.getLoc(), cmpResult, rangeDestination, |
217 | defaultDestination); |
218 | return resBlock; |
219 | } |
220 | |
221 | mlir::LogicalResult |
222 | matchAndRewrite(cir::SwitchOp op, |
223 | mlir::PatternRewriter &rewriter) const override { |
224 | llvm::SmallVector<CaseOp> cases; |
225 | op.collectCases(cases); |
226 | |
227 | // Empty switch statement: just erase it. |
228 | if (cases.empty()) { |
229 | rewriter.eraseOp(op); |
230 | return mlir::success(); |
231 | } |
232 | |
233 | // Create exit block from the next node of cir.switch op. |
234 | mlir::Block *exitBlock = rewriter.splitBlock( |
235 | rewriter.getBlock(), op->getNextNode()->getIterator()); |
236 | |
237 | // We lower cir.switch op in the following process: |
238 | // 1. Inline the region from the switch op after switch op. |
239 | // 2. Traverse each cir.case op: |
240 | // a. Record the entry block, block arguments and condition for every |
241 | // case. b. Inline the case region after the case op. |
242 | // 3. Replace the empty cir.switch.op with the new cir.switchflat op by the |
243 | // recorded block and conditions. |
244 | |
245 | // inline everything from switch body between the switch op and the exit |
246 | // block. |
247 | { |
248 | cir::YieldOp switchYield = nullptr; |
249 | // Clear switch operation. |
250 | for (mlir::Block &block : |
251 | llvm::make_early_inc_range(op.getBody().getBlocks())) |
252 | if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator())) |
253 | switchYield = yieldOp; |
254 | |
255 | assert(!op.getBody().empty()); |
256 | mlir::Block *originalBlock = op->getBlock(); |
257 | mlir::Block *swopBlock = |
258 | rewriter.splitBlock(originalBlock, op->getIterator()); |
259 | rewriter.inlineRegionBefore(op.getBody(), exitBlock); |
260 | |
261 | if (switchYield) |
262 | rewriteYieldOp(rewriter, switchYield, exitBlock); |
263 | |
264 | rewriter.setInsertionPointToEnd(originalBlock); |
265 | rewriter.create<cir::BrOp>(op.getLoc(), swopBlock); |
266 | } |
267 | |
268 | // Allocate required data structures (disconsider default case in |
269 | // vectors). |
270 | llvm::SmallVector<mlir::APInt, 8> caseValues; |
271 | llvm::SmallVector<mlir::Block *, 8> caseDestinations; |
272 | llvm::SmallVector<mlir::ValueRange, 8> caseOperands; |
273 | |
274 | llvm::SmallVector<std::pair<APInt, APInt>> rangeValues; |
275 | llvm::SmallVector<mlir::Block *> rangeDestinations; |
276 | llvm::SmallVector<mlir::ValueRange> rangeOperands; |
277 | |
278 | // Initialize default case as optional. |
279 | mlir::Block *defaultDestination = exitBlock; |
280 | mlir::ValueRange defaultOperands = exitBlock->getArguments(); |
281 | |
282 | // Digest the case statements values and bodies. |
283 | for (cir::CaseOp caseOp : cases) { |
284 | mlir::Region ®ion = caseOp.getCaseRegion(); |
285 | |
286 | // Found default case: save destination and operands. |
287 | switch (caseOp.getKind()) { |
288 | case cir::CaseOpKind::Default: |
289 | defaultDestination = ®ion.front(); |
290 | defaultOperands = defaultDestination->getArguments(); |
291 | break; |
292 | case cir::CaseOpKind::Range: |
293 | assert(caseOp.getValue().size() == 2 && |
294 | "Case range should have 2 case value" ); |
295 | rangeValues.push_back( |
296 | {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(), |
297 | cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()}); |
298 | rangeDestinations.push_back(®ion.front()); |
299 | rangeOperands.push_back(rangeDestinations.back()->getArguments()); |
300 | break; |
301 | case cir::CaseOpKind::Anyof: |
302 | case cir::CaseOpKind::Equal: |
303 | // AnyOf cases kind can have multiple values, hence the loop below. |
304 | for (const mlir::Attribute &value : caseOp.getValue()) { |
305 | caseValues.push_back(cast<cir::IntAttr>(value).getValue()); |
306 | caseDestinations.push_back(®ion.front()); |
307 | caseOperands.push_back(caseDestinations.back()->getArguments()); |
308 | } |
309 | break; |
310 | } |
311 | |
312 | // Handle break statements. |
313 | walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>( |
314 | region, [&](mlir::Operation *op) { |
315 | if (!isa<cir::BreakOp>(op)) |
316 | return mlir::WalkResult::advance(); |
317 | |
318 | lowerTerminator(op, exitBlock, rewriter); |
319 | return mlir::WalkResult::skip(); |
320 | }); |
321 | |
322 | // Track fallthrough in cases. |
323 | for (mlir::Block &blk : region.getBlocks()) { |
324 | if (blk.getNumSuccessors()) |
325 | continue; |
326 | |
327 | if (auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) { |
328 | mlir::Operation *nextOp = caseOp->getNextNode(); |
329 | assert(nextOp && "caseOp is not expected to be the last op" ); |
330 | mlir::Block *oldBlock = nextOp->getBlock(); |
331 | mlir::Block *newBlock = |
332 | rewriter.splitBlock(oldBlock, nextOp->getIterator()); |
333 | rewriter.setInsertionPointToEnd(oldBlock); |
334 | rewriter.create<cir::BrOp>(nextOp->getLoc(), mlir::ValueRange(), |
335 | newBlock); |
336 | rewriteYieldOp(rewriter, yieldOp, newBlock); |
337 | } |
338 | } |
339 | |
340 | mlir::Block *oldBlock = caseOp->getBlock(); |
341 | mlir::Block *newBlock = |
342 | rewriter.splitBlock(oldBlock, caseOp->getIterator()); |
343 | |
344 | mlir::Block &entryBlock = caseOp.getCaseRegion().front(); |
345 | rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock); |
346 | |
347 | // Create a branch to the entry of the inlined region. |
348 | rewriter.setInsertionPointToEnd(oldBlock); |
349 | rewriter.create<cir::BrOp>(caseOp.getLoc(), &entryBlock); |
350 | } |
351 | |
352 | // Remove all cases since we've inlined the regions. |
353 | for (cir::CaseOp caseOp : cases) { |
354 | mlir::Block *caseBlock = caseOp->getBlock(); |
355 | // Erase the block with no predecessors here to make the generated code |
356 | // simpler a little bit. |
357 | if (caseBlock->hasNoPredecessors()) |
358 | rewriter.eraseBlock(caseBlock); |
359 | else |
360 | rewriter.eraseOp(caseOp); |
361 | } |
362 | |
363 | for (auto [rangeVal, operand, destination] : |
364 | llvm::zip(rangeValues, rangeOperands, rangeDestinations)) { |
365 | APInt lowerBound = rangeVal.first; |
366 | APInt upperBound = rangeVal.second; |
367 | |
368 | // The case range is unreachable, skip it. |
369 | if (lowerBound.sgt(upperBound)) |
370 | continue; |
371 | |
372 | // If range is small, add multiple switch instruction cases. |
373 | // This magical number is from the original CGStmt code. |
374 | constexpr int kSmallRangeThreshold = 64; |
375 | if ((upperBound - lowerBound) |
376 | .ult(llvm::APInt(32, kSmallRangeThreshold))) { |
377 | for (APInt iValue = lowerBound; iValue.sle(upperBound); ++iValue) { |
378 | caseValues.push_back(iValue); |
379 | caseOperands.push_back(operand); |
380 | caseDestinations.push_back(destination); |
381 | } |
382 | continue; |
383 | } |
384 | |
385 | defaultDestination = |
386 | condBrToRangeDestination(op, rewriter, destination, |
387 | defaultDestination, lowerBound, upperBound); |
388 | defaultOperands = operand; |
389 | } |
390 | |
391 | // Set switch op to branch to the newly created blocks. |
392 | rewriter.setInsertionPoint(op); |
393 | rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>( |
394 | op, op.getCondition(), defaultDestination, defaultOperands, caseValues, |
395 | caseDestinations, caseOperands); |
396 | |
397 | return mlir::success(); |
398 | } |
399 | }; |
400 | |
401 | class CIRLoopOpInterfaceFlattening |
402 | : public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> { |
403 | public: |
404 | using mlir::OpInterfaceRewritePattern< |
405 | cir::LoopOpInterface>::OpInterfaceRewritePattern; |
406 | |
407 | inline void lowerConditionOp(cir::ConditionOp op, mlir::Block *body, |
408 | mlir::Block *exit, |
409 | mlir::PatternRewriter &rewriter) const { |
410 | mlir::OpBuilder::InsertionGuard guard(rewriter); |
411 | rewriter.setInsertionPoint(op); |
412 | rewriter.replaceOpWithNewOp<cir::BrCondOp>(op, op.getCondition(), body, |
413 | exit); |
414 | } |
415 | |
416 | mlir::LogicalResult |
417 | matchAndRewrite(cir::LoopOpInterface op, |
418 | mlir::PatternRewriter &rewriter) const final { |
419 | // Setup CFG blocks. |
420 | mlir::Block *entry = rewriter.getInsertionBlock(); |
421 | mlir::Block *exit = |
422 | rewriter.splitBlock(entry, rewriter.getInsertionPoint()); |
423 | mlir::Block *cond = &op.getCond().front(); |
424 | mlir::Block *body = &op.getBody().front(); |
425 | mlir::Block *step = |
426 | (op.maybeGetStep() ? &op.maybeGetStep()->front() : nullptr); |
427 | |
428 | // Setup loop entry branch. |
429 | rewriter.setInsertionPointToEnd(entry); |
430 | rewriter.create<cir::BrOp>(op.getLoc(), &op.getEntry().front()); |
431 | |
432 | // Branch from condition region to body or exit. |
433 | auto conditionOp = cast<cir::ConditionOp>(cond->getTerminator()); |
434 | lowerConditionOp(conditionOp, body, exit, rewriter); |
435 | |
436 | // TODO(cir): Remove the walks below. It visits operations unnecessarily. |
437 | // However, to solve this we would likely need a custom DialectConversion |
438 | // driver to customize the order that operations are visited. |
439 | |
440 | // Lower continue statements. |
441 | mlir::Block *dest = (step ? step : cond); |
442 | op.walkBodySkippingNestedLoops([&](mlir::Operation *op) { |
443 | if (!isa<cir::ContinueOp>(op)) |
444 | return mlir::WalkResult::advance(); |
445 | |
446 | lowerTerminator(op, dest, rewriter); |
447 | return mlir::WalkResult::skip(); |
448 | }); |
449 | |
450 | // Lower break statements. |
451 | assert(!cir::MissingFeatures::switchOp()); |
452 | walkRegionSkipping<cir::LoopOpInterface>( |
453 | op.getBody(), [&](mlir::Operation *op) { |
454 | if (!isa<cir::BreakOp>(op)) |
455 | return mlir::WalkResult::advance(); |
456 | |
457 | lowerTerminator(op, exit, rewriter); |
458 | return mlir::WalkResult::skip(); |
459 | }); |
460 | |
461 | // Lower optional body region yield. |
462 | for (mlir::Block &blk : op.getBody().getBlocks()) { |
463 | auto bodyYield = dyn_cast<cir::YieldOp>(blk.getTerminator()); |
464 | if (bodyYield) |
465 | lowerTerminator(bodyYield, (step ? step : cond), rewriter); |
466 | } |
467 | |
468 | // Lower mandatory step region yield. |
469 | if (step) |
470 | lowerTerminator(cast<cir::YieldOp>(step->getTerminator()), cond, |
471 | rewriter); |
472 | |
473 | // Move region contents out of the loop op. |
474 | rewriter.inlineRegionBefore(op.getCond(), exit); |
475 | rewriter.inlineRegionBefore(op.getBody(), exit); |
476 | if (step) |
477 | rewriter.inlineRegionBefore(*op.maybeGetStep(), exit); |
478 | |
479 | rewriter.eraseOp(op); |
480 | return mlir::success(); |
481 | } |
482 | }; |
483 | |
484 | class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> { |
485 | public: |
486 | using OpRewritePattern<cir::TernaryOp>::OpRewritePattern; |
487 | |
488 | mlir::LogicalResult |
489 | matchAndRewrite(cir::TernaryOp op, |
490 | mlir::PatternRewriter &rewriter) const override { |
491 | Location loc = op->getLoc(); |
492 | Block *condBlock = rewriter.getInsertionBlock(); |
493 | Block::iterator opPosition = rewriter.getInsertionPoint(); |
494 | Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); |
495 | llvm::SmallVector<mlir::Location, 2> locs; |
496 | // Ternary result is optional, make sure to populate the location only |
497 | // when relevant. |
498 | if (op->getResultTypes().size()) |
499 | locs.push_back(loc); |
500 | Block *continueBlock = |
501 | rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs); |
502 | rewriter.create<cir::BrOp>(loc, remainingOpsBlock); |
503 | |
504 | Region &trueRegion = op.getTrueRegion(); |
505 | Block *trueBlock = &trueRegion.front(); |
506 | mlir::Operation *trueTerminator = trueRegion.back().getTerminator(); |
507 | rewriter.setInsertionPointToEnd(&trueRegion.back()); |
508 | auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator); |
509 | |
510 | rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(), |
511 | continueBlock); |
512 | rewriter.inlineRegionBefore(trueRegion, continueBlock); |
513 | |
514 | Block *falseBlock = continueBlock; |
515 | Region &falseRegion = op.getFalseRegion(); |
516 | |
517 | falseBlock = &falseRegion.front(); |
518 | mlir::Operation *falseTerminator = falseRegion.back().getTerminator(); |
519 | rewriter.setInsertionPointToEnd(&falseRegion.back()); |
520 | auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator); |
521 | rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp, falseYieldOp.getArgs(), |
522 | continueBlock); |
523 | rewriter.inlineRegionBefore(falseRegion, continueBlock); |
524 | |
525 | rewriter.setInsertionPointToEnd(condBlock); |
526 | rewriter.create<cir::BrCondOp>(loc, op.getCond(), trueBlock, falseBlock); |
527 | |
528 | rewriter.replaceOp(op, continueBlock->getArguments()); |
529 | |
530 | // Ok, we're done! |
531 | return mlir::success(); |
532 | } |
533 | }; |
534 | |
535 | void populateFlattenCFGPatterns(RewritePatternSet &patterns) { |
536 | patterns |
537 | .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening, |
538 | CIRSwitchOpFlattening, CIRTernaryOpFlattening>( |
539 | patterns.getContext()); |
540 | } |
541 | |
542 | void CIRFlattenCFGPass::runOnOperation() { |
543 | RewritePatternSet patterns(&getContext()); |
544 | populateFlattenCFGPatterns(patterns); |
545 | |
546 | // Collect operations to apply patterns. |
547 | llvm::SmallVector<Operation *, 16> ops; |
548 | getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) { |
549 | assert(!cir::MissingFeatures::ifOp()); |
550 | assert(!cir::MissingFeatures::switchOp()); |
551 | assert(!cir::MissingFeatures::tryOp()); |
552 | if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp>(op)) |
553 | ops.push_back(op); |
554 | }); |
555 | |
556 | // Apply patterns. |
557 | if (applyOpPatternsGreedily(ops, std::move(patterns)).failed()) |
558 | signalPassFailure(); |
559 | } |
560 | |
561 | } // namespace |
562 | |
563 | namespace mlir { |
564 | |
565 | std::unique_ptr<Pass> createCIRFlattenCFGPass() { |
566 | return std::make_unique<CIRFlattenCFGPass>(); |
567 | } |
568 | |
569 | } // namespace mlir |
570 | |