| 1 | //===----------------------------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
// This file implements a pass that flattens structured CIR control flow by
// inlining the regions of CIR operations into the parent function region.
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "PassDetail.h" |
| 15 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 16 | #include "mlir/IR/Block.h" |
| 17 | #include "mlir/IR/Builders.h" |
| 18 | #include "mlir/IR/PatternMatch.h" |
| 19 | #include "mlir/Support/LogicalResult.h" |
| 20 | #include "mlir/Transforms/DialectConversion.h" |
| 21 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 22 | #include "clang/CIR/Dialect/IR/CIRDialect.h" |
| 23 | #include "clang/CIR/Dialect/Passes.h" |
| 24 | #include "clang/CIR/MissingFeatures.h" |
| 25 | |
| 26 | using namespace mlir; |
| 27 | using namespace cir; |
| 28 | |
| 29 | namespace { |
| 30 | |
| 31 | /// Lowers operations with the terminator trait that have a single successor. |
| 32 | void lowerTerminator(mlir::Operation *op, mlir::Block *dest, |
| 33 | mlir::PatternRewriter &rewriter) { |
| 34 | assert(op->hasTrait<mlir::OpTrait::IsTerminator>() && "not a terminator" ); |
| 35 | mlir::OpBuilder::InsertionGuard guard(rewriter); |
| 36 | rewriter.setInsertionPoint(op); |
| 37 | rewriter.replaceOpWithNewOp<cir::BrOp>(op, dest); |
| 38 | } |
| 39 | |
| 40 | /// Walks a region while skipping operations of type `Ops`. This ensures the |
| 41 | /// callback is not applied to said operations and its children. |
| 42 | template <typename... Ops> |
| 43 | void walkRegionSkipping( |
| 44 | mlir::Region ®ion, |
| 45 | mlir::function_ref<mlir::WalkResult(mlir::Operation *)> callback) { |
| 46 | region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) { |
| 47 | if (isa<Ops...>(op)) |
| 48 | return mlir::WalkResult::skip(); |
| 49 | return callback(op); |
| 50 | }); |
| 51 | } |
| 52 | |
| 53 | struct CIRFlattenCFGPass : public CIRFlattenCFGBase<CIRFlattenCFGPass> { |
| 54 | |
| 55 | CIRFlattenCFGPass() = default; |
| 56 | void runOnOperation() override; |
| 57 | }; |
| 58 | |
| 59 | struct CIRIfFlattening : public mlir::OpRewritePattern<cir::IfOp> { |
| 60 | using OpRewritePattern<IfOp>::OpRewritePattern; |
| 61 | |
| 62 | mlir::LogicalResult |
| 63 | matchAndRewrite(cir::IfOp ifOp, |
| 64 | mlir::PatternRewriter &rewriter) const override { |
| 65 | mlir::OpBuilder::InsertionGuard guard(rewriter); |
| 66 | mlir::Location loc = ifOp.getLoc(); |
| 67 | bool emptyElse = ifOp.getElseRegion().empty(); |
| 68 | mlir::Block *currentBlock = rewriter.getInsertionBlock(); |
| 69 | mlir::Block *remainingOpsBlock = |
| 70 | rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); |
| 71 | mlir::Block *continueBlock; |
| 72 | if (ifOp->getResults().empty()) |
| 73 | continueBlock = remainingOpsBlock; |
| 74 | else |
| 75 | llvm_unreachable("NYI" ); |
| 76 | |
| 77 | // Inline the region |
| 78 | mlir::Block *thenBeforeBody = &ifOp.getThenRegion().front(); |
| 79 | mlir::Block *thenAfterBody = &ifOp.getThenRegion().back(); |
| 80 | rewriter.inlineRegionBefore(ifOp.getThenRegion(), continueBlock); |
| 81 | |
| 82 | rewriter.setInsertionPointToEnd(thenAfterBody); |
| 83 | if (auto thenYieldOp = |
| 84 | dyn_cast<cir::YieldOp>(thenAfterBody->getTerminator())) { |
| 85 | rewriter.replaceOpWithNewOp<cir::BrOp>(thenYieldOp, thenYieldOp.getArgs(), |
| 86 | continueBlock); |
| 87 | } |
| 88 | |
| 89 | rewriter.setInsertionPointToEnd(continueBlock); |
| 90 | |
| 91 | // Has else region: inline it. |
| 92 | mlir::Block *elseBeforeBody = nullptr; |
| 93 | mlir::Block *elseAfterBody = nullptr; |
| 94 | if (!emptyElse) { |
| 95 | elseBeforeBody = &ifOp.getElseRegion().front(); |
| 96 | elseAfterBody = &ifOp.getElseRegion().back(); |
| 97 | rewriter.inlineRegionBefore(ifOp.getElseRegion(), continueBlock); |
| 98 | } else { |
| 99 | elseBeforeBody = elseAfterBody = continueBlock; |
| 100 | } |
| 101 | |
| 102 | rewriter.setInsertionPointToEnd(currentBlock); |
| 103 | rewriter.create<cir::BrCondOp>(loc, ifOp.getCondition(), thenBeforeBody, |
| 104 | elseBeforeBody); |
| 105 | |
| 106 | if (!emptyElse) { |
| 107 | rewriter.setInsertionPointToEnd(elseAfterBody); |
| 108 | if (auto elseYieldOP = |
| 109 | dyn_cast<cir::YieldOp>(elseAfterBody->getTerminator())) { |
| 110 | rewriter.replaceOpWithNewOp<cir::BrOp>( |
| 111 | elseYieldOP, elseYieldOP.getArgs(), continueBlock); |
| 112 | } |
| 113 | } |
| 114 | |
| 115 | rewriter.replaceOp(ifOp, continueBlock->getArguments()); |
| 116 | return mlir::success(); |
| 117 | } |
| 118 | }; |
| 119 | |
| 120 | class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> { |
| 121 | public: |
| 122 | using OpRewritePattern<cir::ScopeOp>::OpRewritePattern; |
| 123 | |
| 124 | mlir::LogicalResult |
| 125 | matchAndRewrite(cir::ScopeOp scopeOp, |
| 126 | mlir::PatternRewriter &rewriter) const override { |
| 127 | mlir::OpBuilder::InsertionGuard guard(rewriter); |
| 128 | mlir::Location loc = scopeOp.getLoc(); |
| 129 | |
| 130 | // Empty scope: just remove it. |
| 131 | // TODO: Remove this logic once CIR uses MLIR infrastructure to remove |
| 132 | // trivially dead operations. MLIR canonicalizer is too aggressive and we |
| 133 | // need to either (a) make sure all our ops model all side-effects and/or |
| 134 | // (b) have more options in the canonicalizer in MLIR to temper |
| 135 | // aggressiveness level. |
| 136 | if (scopeOp.isEmpty()) { |
| 137 | rewriter.eraseOp(scopeOp); |
| 138 | return mlir::success(); |
| 139 | } |
| 140 | |
| 141 | // Split the current block before the ScopeOp to create the inlining |
| 142 | // point. |
| 143 | mlir::Block *currentBlock = rewriter.getInsertionBlock(); |
| 144 | mlir::Block *continueBlock = |
| 145 | rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); |
| 146 | if (scopeOp.getNumResults() > 0) |
| 147 | continueBlock->addArguments(scopeOp.getResultTypes(), loc); |
| 148 | |
| 149 | // Inline body region. |
| 150 | mlir::Block *beforeBody = &scopeOp.getScopeRegion().front(); |
| 151 | mlir::Block *afterBody = &scopeOp.getScopeRegion().back(); |
| 152 | rewriter.inlineRegionBefore(scopeOp.getScopeRegion(), continueBlock); |
| 153 | |
| 154 | // Save stack and then branch into the body of the region. |
| 155 | rewriter.setInsertionPointToEnd(currentBlock); |
| 156 | assert(!cir::MissingFeatures::stackSaveOp()); |
| 157 | rewriter.create<cir::BrOp>(loc, mlir::ValueRange(), beforeBody); |
| 158 | |
| 159 | // Replace the scopeop return with a branch that jumps out of the body. |
| 160 | // Stack restore before leaving the body region. |
| 161 | rewriter.setInsertionPointToEnd(afterBody); |
| 162 | if (auto yieldOp = dyn_cast<cir::YieldOp>(afterBody->getTerminator())) { |
| 163 | rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getArgs(), |
| 164 | continueBlock); |
| 165 | } |
| 166 | |
| 167 | // Replace the op with values return from the body region. |
| 168 | rewriter.replaceOp(scopeOp, continueBlock->getArguments()); |
| 169 | |
| 170 | return mlir::success(); |
| 171 | } |
| 172 | }; |
| 173 | |
| 174 | class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> { |
| 175 | public: |
| 176 | using OpRewritePattern<cir::SwitchOp>::OpRewritePattern; |
| 177 | |
| 178 | inline void rewriteYieldOp(mlir::PatternRewriter &rewriter, |
| 179 | cir::YieldOp yieldOp, |
| 180 | mlir::Block *destination) const { |
| 181 | rewriter.setInsertionPoint(yieldOp); |
| 182 | rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(), |
| 183 | destination); |
| 184 | } |
| 185 | |
| 186 | // Return the new defaultDestination block. |
| 187 | Block *condBrToRangeDestination(cir::SwitchOp op, |
| 188 | mlir::PatternRewriter &rewriter, |
| 189 | mlir::Block *rangeDestination, |
| 190 | mlir::Block *defaultDestination, |
| 191 | const APInt &lowerBound, |
| 192 | const APInt &upperBound) const { |
| 193 | assert(lowerBound.sle(upperBound) && "Invalid range" ); |
| 194 | mlir::Block *resBlock = rewriter.createBlock(defaultDestination); |
| 195 | cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true); |
| 196 | cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false); |
| 197 | |
| 198 | cir::ConstantOp rangeLength = rewriter.create<cir::ConstantOp>( |
| 199 | op.getLoc(), cir::IntAttr::get(sIntType, upperBound - lowerBound)); |
| 200 | |
| 201 | cir::ConstantOp lowerBoundValue = rewriter.create<cir::ConstantOp>( |
| 202 | op.getLoc(), cir::IntAttr::get(sIntType, lowerBound)); |
| 203 | cir::BinOp diffValue = |
| 204 | rewriter.create<cir::BinOp>(op.getLoc(), sIntType, cir::BinOpKind::Sub, |
| 205 | op.getCondition(), lowerBoundValue); |
| 206 | |
| 207 | // Use unsigned comparison to check if the condition is in the range. |
| 208 | cir::CastOp uDiffValue = rewriter.create<cir::CastOp>( |
| 209 | op.getLoc(), uIntType, CastKind::integral, diffValue); |
| 210 | cir::CastOp uRangeLength = rewriter.create<cir::CastOp>( |
| 211 | op.getLoc(), uIntType, CastKind::integral, rangeLength); |
| 212 | |
| 213 | cir::CmpOp cmpResult = rewriter.create<cir::CmpOp>( |
| 214 | op.getLoc(), cir::BoolType::get(op.getContext()), cir::CmpOpKind::le, |
| 215 | uDiffValue, uRangeLength); |
| 216 | rewriter.create<cir::BrCondOp>(op.getLoc(), cmpResult, rangeDestination, |
| 217 | defaultDestination); |
| 218 | return resBlock; |
| 219 | } |
| 220 | |
| 221 | mlir::LogicalResult |
| 222 | matchAndRewrite(cir::SwitchOp op, |
| 223 | mlir::PatternRewriter &rewriter) const override { |
| 224 | llvm::SmallVector<CaseOp> cases; |
| 225 | op.collectCases(cases); |
| 226 | |
| 227 | // Empty switch statement: just erase it. |
| 228 | if (cases.empty()) { |
| 229 | rewriter.eraseOp(op); |
| 230 | return mlir::success(); |
| 231 | } |
| 232 | |
| 233 | // Create exit block from the next node of cir.switch op. |
| 234 | mlir::Block *exitBlock = rewriter.splitBlock( |
| 235 | rewriter.getBlock(), op->getNextNode()->getIterator()); |
| 236 | |
| 237 | // We lower cir.switch op in the following process: |
| 238 | // 1. Inline the region from the switch op after switch op. |
| 239 | // 2. Traverse each cir.case op: |
| 240 | // a. Record the entry block, block arguments and condition for every |
| 241 | // case. b. Inline the case region after the case op. |
| 242 | // 3. Replace the empty cir.switch.op with the new cir.switchflat op by the |
| 243 | // recorded block and conditions. |
| 244 | |
| 245 | // inline everything from switch body between the switch op and the exit |
| 246 | // block. |
| 247 | { |
| 248 | cir::YieldOp switchYield = nullptr; |
| 249 | // Clear switch operation. |
| 250 | for (mlir::Block &block : |
| 251 | llvm::make_early_inc_range(op.getBody().getBlocks())) |
| 252 | if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator())) |
| 253 | switchYield = yieldOp; |
| 254 | |
| 255 | assert(!op.getBody().empty()); |
| 256 | mlir::Block *originalBlock = op->getBlock(); |
| 257 | mlir::Block *swopBlock = |
| 258 | rewriter.splitBlock(originalBlock, op->getIterator()); |
| 259 | rewriter.inlineRegionBefore(op.getBody(), exitBlock); |
| 260 | |
| 261 | if (switchYield) |
| 262 | rewriteYieldOp(rewriter, switchYield, exitBlock); |
| 263 | |
| 264 | rewriter.setInsertionPointToEnd(originalBlock); |
| 265 | rewriter.create<cir::BrOp>(op.getLoc(), swopBlock); |
| 266 | } |
| 267 | |
| 268 | // Allocate required data structures (disconsider default case in |
| 269 | // vectors). |
| 270 | llvm::SmallVector<mlir::APInt, 8> caseValues; |
| 271 | llvm::SmallVector<mlir::Block *, 8> caseDestinations; |
| 272 | llvm::SmallVector<mlir::ValueRange, 8> caseOperands; |
| 273 | |
| 274 | llvm::SmallVector<std::pair<APInt, APInt>> rangeValues; |
| 275 | llvm::SmallVector<mlir::Block *> rangeDestinations; |
| 276 | llvm::SmallVector<mlir::ValueRange> rangeOperands; |
| 277 | |
| 278 | // Initialize default case as optional. |
| 279 | mlir::Block *defaultDestination = exitBlock; |
| 280 | mlir::ValueRange defaultOperands = exitBlock->getArguments(); |
| 281 | |
| 282 | // Digest the case statements values and bodies. |
| 283 | for (cir::CaseOp caseOp : cases) { |
| 284 | mlir::Region ®ion = caseOp.getCaseRegion(); |
| 285 | |
| 286 | // Found default case: save destination and operands. |
| 287 | switch (caseOp.getKind()) { |
| 288 | case cir::CaseOpKind::Default: |
| 289 | defaultDestination = ®ion.front(); |
| 290 | defaultOperands = defaultDestination->getArguments(); |
| 291 | break; |
| 292 | case cir::CaseOpKind::Range: |
| 293 | assert(caseOp.getValue().size() == 2 && |
| 294 | "Case range should have 2 case value" ); |
| 295 | rangeValues.push_back( |
| 296 | {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(), |
| 297 | cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()}); |
| 298 | rangeDestinations.push_back(®ion.front()); |
| 299 | rangeOperands.push_back(rangeDestinations.back()->getArguments()); |
| 300 | break; |
| 301 | case cir::CaseOpKind::Anyof: |
| 302 | case cir::CaseOpKind::Equal: |
| 303 | // AnyOf cases kind can have multiple values, hence the loop below. |
| 304 | for (const mlir::Attribute &value : caseOp.getValue()) { |
| 305 | caseValues.push_back(cast<cir::IntAttr>(value).getValue()); |
| 306 | caseDestinations.push_back(®ion.front()); |
| 307 | caseOperands.push_back(caseDestinations.back()->getArguments()); |
| 308 | } |
| 309 | break; |
| 310 | } |
| 311 | |
| 312 | // Handle break statements. |
| 313 | walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>( |
| 314 | region, [&](mlir::Operation *op) { |
| 315 | if (!isa<cir::BreakOp>(op)) |
| 316 | return mlir::WalkResult::advance(); |
| 317 | |
| 318 | lowerTerminator(op, exitBlock, rewriter); |
| 319 | return mlir::WalkResult::skip(); |
| 320 | }); |
| 321 | |
| 322 | // Track fallthrough in cases. |
| 323 | for (mlir::Block &blk : region.getBlocks()) { |
| 324 | if (blk.getNumSuccessors()) |
| 325 | continue; |
| 326 | |
| 327 | if (auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) { |
| 328 | mlir::Operation *nextOp = caseOp->getNextNode(); |
| 329 | assert(nextOp && "caseOp is not expected to be the last op" ); |
| 330 | mlir::Block *oldBlock = nextOp->getBlock(); |
| 331 | mlir::Block *newBlock = |
| 332 | rewriter.splitBlock(oldBlock, nextOp->getIterator()); |
| 333 | rewriter.setInsertionPointToEnd(oldBlock); |
| 334 | rewriter.create<cir::BrOp>(nextOp->getLoc(), mlir::ValueRange(), |
| 335 | newBlock); |
| 336 | rewriteYieldOp(rewriter, yieldOp, newBlock); |
| 337 | } |
| 338 | } |
| 339 | |
| 340 | mlir::Block *oldBlock = caseOp->getBlock(); |
| 341 | mlir::Block *newBlock = |
| 342 | rewriter.splitBlock(oldBlock, caseOp->getIterator()); |
| 343 | |
| 344 | mlir::Block &entryBlock = caseOp.getCaseRegion().front(); |
| 345 | rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock); |
| 346 | |
| 347 | // Create a branch to the entry of the inlined region. |
| 348 | rewriter.setInsertionPointToEnd(oldBlock); |
| 349 | rewriter.create<cir::BrOp>(caseOp.getLoc(), &entryBlock); |
| 350 | } |
| 351 | |
| 352 | // Remove all cases since we've inlined the regions. |
| 353 | for (cir::CaseOp caseOp : cases) { |
| 354 | mlir::Block *caseBlock = caseOp->getBlock(); |
| 355 | // Erase the block with no predecessors here to make the generated code |
| 356 | // simpler a little bit. |
| 357 | if (caseBlock->hasNoPredecessors()) |
| 358 | rewriter.eraseBlock(caseBlock); |
| 359 | else |
| 360 | rewriter.eraseOp(caseOp); |
| 361 | } |
| 362 | |
| 363 | for (auto [rangeVal, operand, destination] : |
| 364 | llvm::zip(rangeValues, rangeOperands, rangeDestinations)) { |
| 365 | APInt lowerBound = rangeVal.first; |
| 366 | APInt upperBound = rangeVal.second; |
| 367 | |
| 368 | // The case range is unreachable, skip it. |
| 369 | if (lowerBound.sgt(upperBound)) |
| 370 | continue; |
| 371 | |
| 372 | // If range is small, add multiple switch instruction cases. |
| 373 | // This magical number is from the original CGStmt code. |
| 374 | constexpr int kSmallRangeThreshold = 64; |
| 375 | if ((upperBound - lowerBound) |
| 376 | .ult(llvm::APInt(32, kSmallRangeThreshold))) { |
| 377 | for (APInt iValue = lowerBound; iValue.sle(upperBound); ++iValue) { |
| 378 | caseValues.push_back(iValue); |
| 379 | caseOperands.push_back(operand); |
| 380 | caseDestinations.push_back(destination); |
| 381 | } |
| 382 | continue; |
| 383 | } |
| 384 | |
| 385 | defaultDestination = |
| 386 | condBrToRangeDestination(op, rewriter, destination, |
| 387 | defaultDestination, lowerBound, upperBound); |
| 388 | defaultOperands = operand; |
| 389 | } |
| 390 | |
| 391 | // Set switch op to branch to the newly created blocks. |
| 392 | rewriter.setInsertionPoint(op); |
| 393 | rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>( |
| 394 | op, op.getCondition(), defaultDestination, defaultOperands, caseValues, |
| 395 | caseDestinations, caseOperands); |
| 396 | |
| 397 | return mlir::success(); |
| 398 | } |
| 399 | }; |
| 400 | |
| 401 | class CIRLoopOpInterfaceFlattening |
| 402 | : public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> { |
| 403 | public: |
| 404 | using mlir::OpInterfaceRewritePattern< |
| 405 | cir::LoopOpInterface>::OpInterfaceRewritePattern; |
| 406 | |
| 407 | inline void lowerConditionOp(cir::ConditionOp op, mlir::Block *body, |
| 408 | mlir::Block *exit, |
| 409 | mlir::PatternRewriter &rewriter) const { |
| 410 | mlir::OpBuilder::InsertionGuard guard(rewriter); |
| 411 | rewriter.setInsertionPoint(op); |
| 412 | rewriter.replaceOpWithNewOp<cir::BrCondOp>(op, op.getCondition(), body, |
| 413 | exit); |
| 414 | } |
| 415 | |
| 416 | mlir::LogicalResult |
| 417 | matchAndRewrite(cir::LoopOpInterface op, |
| 418 | mlir::PatternRewriter &rewriter) const final { |
| 419 | // Setup CFG blocks. |
| 420 | mlir::Block *entry = rewriter.getInsertionBlock(); |
| 421 | mlir::Block *exit = |
| 422 | rewriter.splitBlock(entry, rewriter.getInsertionPoint()); |
| 423 | mlir::Block *cond = &op.getCond().front(); |
| 424 | mlir::Block *body = &op.getBody().front(); |
| 425 | mlir::Block *step = |
| 426 | (op.maybeGetStep() ? &op.maybeGetStep()->front() : nullptr); |
| 427 | |
| 428 | // Setup loop entry branch. |
| 429 | rewriter.setInsertionPointToEnd(entry); |
| 430 | rewriter.create<cir::BrOp>(op.getLoc(), &op.getEntry().front()); |
| 431 | |
| 432 | // Branch from condition region to body or exit. |
| 433 | auto conditionOp = cast<cir::ConditionOp>(cond->getTerminator()); |
| 434 | lowerConditionOp(conditionOp, body, exit, rewriter); |
| 435 | |
| 436 | // TODO(cir): Remove the walks below. It visits operations unnecessarily. |
| 437 | // However, to solve this we would likely need a custom DialectConversion |
| 438 | // driver to customize the order that operations are visited. |
| 439 | |
| 440 | // Lower continue statements. |
| 441 | mlir::Block *dest = (step ? step : cond); |
| 442 | op.walkBodySkippingNestedLoops([&](mlir::Operation *op) { |
| 443 | if (!isa<cir::ContinueOp>(op)) |
| 444 | return mlir::WalkResult::advance(); |
| 445 | |
| 446 | lowerTerminator(op, dest, rewriter); |
| 447 | return mlir::WalkResult::skip(); |
| 448 | }); |
| 449 | |
| 450 | // Lower break statements. |
| 451 | assert(!cir::MissingFeatures::switchOp()); |
| 452 | walkRegionSkipping<cir::LoopOpInterface>( |
| 453 | op.getBody(), [&](mlir::Operation *op) { |
| 454 | if (!isa<cir::BreakOp>(op)) |
| 455 | return mlir::WalkResult::advance(); |
| 456 | |
| 457 | lowerTerminator(op, exit, rewriter); |
| 458 | return mlir::WalkResult::skip(); |
| 459 | }); |
| 460 | |
| 461 | // Lower optional body region yield. |
| 462 | for (mlir::Block &blk : op.getBody().getBlocks()) { |
| 463 | auto bodyYield = dyn_cast<cir::YieldOp>(blk.getTerminator()); |
| 464 | if (bodyYield) |
| 465 | lowerTerminator(bodyYield, (step ? step : cond), rewriter); |
| 466 | } |
| 467 | |
| 468 | // Lower mandatory step region yield. |
| 469 | if (step) |
| 470 | lowerTerminator(cast<cir::YieldOp>(step->getTerminator()), cond, |
| 471 | rewriter); |
| 472 | |
| 473 | // Move region contents out of the loop op. |
| 474 | rewriter.inlineRegionBefore(op.getCond(), exit); |
| 475 | rewriter.inlineRegionBefore(op.getBody(), exit); |
| 476 | if (step) |
| 477 | rewriter.inlineRegionBefore(*op.maybeGetStep(), exit); |
| 478 | |
| 479 | rewriter.eraseOp(op); |
| 480 | return mlir::success(); |
| 481 | } |
| 482 | }; |
| 483 | |
| 484 | class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> { |
| 485 | public: |
| 486 | using OpRewritePattern<cir::TernaryOp>::OpRewritePattern; |
| 487 | |
| 488 | mlir::LogicalResult |
| 489 | matchAndRewrite(cir::TernaryOp op, |
| 490 | mlir::PatternRewriter &rewriter) const override { |
| 491 | Location loc = op->getLoc(); |
| 492 | Block *condBlock = rewriter.getInsertionBlock(); |
| 493 | Block::iterator opPosition = rewriter.getInsertionPoint(); |
| 494 | Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); |
| 495 | llvm::SmallVector<mlir::Location, 2> locs; |
| 496 | // Ternary result is optional, make sure to populate the location only |
| 497 | // when relevant. |
| 498 | if (op->getResultTypes().size()) |
| 499 | locs.push_back(loc); |
| 500 | Block *continueBlock = |
| 501 | rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs); |
| 502 | rewriter.create<cir::BrOp>(loc, remainingOpsBlock); |
| 503 | |
| 504 | Region &trueRegion = op.getTrueRegion(); |
| 505 | Block *trueBlock = &trueRegion.front(); |
| 506 | mlir::Operation *trueTerminator = trueRegion.back().getTerminator(); |
| 507 | rewriter.setInsertionPointToEnd(&trueRegion.back()); |
| 508 | auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator); |
| 509 | |
| 510 | rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(), |
| 511 | continueBlock); |
| 512 | rewriter.inlineRegionBefore(trueRegion, continueBlock); |
| 513 | |
| 514 | Block *falseBlock = continueBlock; |
| 515 | Region &falseRegion = op.getFalseRegion(); |
| 516 | |
| 517 | falseBlock = &falseRegion.front(); |
| 518 | mlir::Operation *falseTerminator = falseRegion.back().getTerminator(); |
| 519 | rewriter.setInsertionPointToEnd(&falseRegion.back()); |
| 520 | auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator); |
| 521 | rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp, falseYieldOp.getArgs(), |
| 522 | continueBlock); |
| 523 | rewriter.inlineRegionBefore(falseRegion, continueBlock); |
| 524 | |
| 525 | rewriter.setInsertionPointToEnd(condBlock); |
| 526 | rewriter.create<cir::BrCondOp>(loc, op.getCond(), trueBlock, falseBlock); |
| 527 | |
| 528 | rewriter.replaceOp(op, continueBlock->getArguments()); |
| 529 | |
| 530 | // Ok, we're done! |
| 531 | return mlir::success(); |
| 532 | } |
| 533 | }; |
| 534 | |
| 535 | void populateFlattenCFGPatterns(RewritePatternSet &patterns) { |
| 536 | patterns |
| 537 | .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening, |
| 538 | CIRSwitchOpFlattening, CIRTernaryOpFlattening>( |
| 539 | patterns.getContext()); |
| 540 | } |
| 541 | |
| 542 | void CIRFlattenCFGPass::runOnOperation() { |
| 543 | RewritePatternSet patterns(&getContext()); |
| 544 | populateFlattenCFGPatterns(patterns); |
| 545 | |
| 546 | // Collect operations to apply patterns. |
| 547 | llvm::SmallVector<Operation *, 16> ops; |
| 548 | getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) { |
| 549 | assert(!cir::MissingFeatures::ifOp()); |
| 550 | assert(!cir::MissingFeatures::switchOp()); |
| 551 | assert(!cir::MissingFeatures::tryOp()); |
| 552 | if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp>(op)) |
| 553 | ops.push_back(op); |
| 554 | }); |
| 555 | |
| 556 | // Apply patterns. |
| 557 | if (applyOpPatternsGreedily(ops, std::move(patterns)).failed()) |
| 558 | signalPassFailure(); |
| 559 | } |
| 560 | |
| 561 | } // namespace |
| 562 | |
| 563 | namespace mlir { |
| 564 | |
| 565 | std::unique_ptr<Pass> createCIRFlattenCFGPass() { |
| 566 | return std::make_unique<CIRFlattenCFGPass>(); |
| 567 | } |
| 568 | |
| 569 | } // namespace mlir |
| 570 | |