//===- ControlFlowToSCF.cpp - ControlFlow to SCF --------------------------===//
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Define conversions from the ControlFlow dialect to the SCF dialect. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h" |
| 14 | |
| 15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 16 | #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" |
| 17 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
| 18 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
| 19 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 20 | #include "mlir/Dialect/UB/IR/UBOps.h" |
| 21 | #include "mlir/Pass/Pass.h" |
| 22 | #include "mlir/Transforms/CFGToSCF.h" |
| 23 | |
| 24 | namespace mlir { |
| 25 | #define GEN_PASS_DEF_LIFTCONTROLFLOWTOSCFPASS |
| 26 | #include "mlir/Conversion/Passes.h.inc" |
| 27 | } // namespace mlir |
| 28 | |
| 29 | using namespace mlir; |
| 30 | |
| 31 | FailureOr<Operation *> |
| 32 | ControlFlowToSCFTransformation::createStructuredBranchRegionOp( |
| 33 | OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes, |
| 34 | MutableArrayRef<Region> regions) { |
| 35 | if (auto condBrOp = dyn_cast<cf::CondBranchOp>(controlFlowCondOp)) { |
| 36 | assert(regions.size() == 2); |
| 37 | auto ifOp = builder.create<scf::IfOp>(controlFlowCondOp->getLoc(), |
| 38 | resultTypes, condBrOp.getCondition()); |
| 39 | ifOp.getThenRegion().takeBody(regions[0]); |
| 40 | ifOp.getElseRegion().takeBody(regions[1]); |
| 41 | return ifOp.getOperation(); |
| 42 | } |
| 43 | |
| 44 | if (auto switchOp = dyn_cast<cf::SwitchOp>(controlFlowCondOp)) { |
| 45 | // `getCFGSwitchValue` returns an i32 that we need to convert to index |
| 46 | // fist. |
| 47 | auto cast = builder.create<arith::IndexCastUIOp>( |
| 48 | controlFlowCondOp->getLoc(), builder.getIndexType(), |
| 49 | switchOp.getFlag()); |
| 50 | SmallVector<int64_t> cases; |
| 51 | if (auto caseValues = switchOp.getCaseValues()) |
| 52 | llvm::append_range( |
| 53 | cases, llvm::map_range(*caseValues, [](const llvm::APInt &apInt) { |
| 54 | return apInt.getZExtValue(); |
| 55 | })); |
| 56 | |
| 57 | assert(regions.size() == cases.size() + 1); |
| 58 | |
| 59 | auto indexSwitchOp = builder.create<scf::IndexSwitchOp>( |
| 60 | controlFlowCondOp->getLoc(), resultTypes, cast, cases, cases.size()); |
| 61 | |
| 62 | indexSwitchOp.getDefaultRegion().takeBody(regions[0]); |
| 63 | for (auto &&[targetRegion, sourceRegion] : |
| 64 | llvm::zip(indexSwitchOp.getCaseRegions(), llvm::drop_begin(regions))) |
| 65 | targetRegion.takeBody(sourceRegion); |
| 66 | |
| 67 | return indexSwitchOp.getOperation(); |
| 68 | } |
| 69 | |
| 70 | controlFlowCondOp->emitOpError( |
| 71 | message: "Cannot convert unknown control flow op to structured control flow" ); |
| 72 | return failure(); |
| 73 | } |
| 74 | |
| 75 | LogicalResult |
| 76 | ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp( |
| 77 | Location loc, OpBuilder &builder, Operation *branchRegionOp, |
| 78 | Operation *replacedControlFlowOp, ValueRange results) { |
| 79 | builder.create<scf::YieldOp>(loc, results); |
| 80 | return success(); |
| 81 | } |
| 82 | |
| 83 | FailureOr<Operation *> |
| 84 | ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp( |
| 85 | OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit, |
| 86 | Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) { |
| 87 | Location loc = replacedOp->getLoc(); |
| 88 | auto whileOp = builder.create<scf::WhileOp>(loc, loopVariablesInit.getTypes(), |
| 89 | loopVariablesInit); |
| 90 | |
| 91 | whileOp.getBefore().takeBody(loopBody); |
| 92 | |
| 93 | builder.setInsertionPointToEnd(&whileOp.getBefore().back()); |
| 94 | // `getCFGSwitchValue` returns a i32. We therefore need to truncate the |
| 95 | // condition to i1 first. It is guaranteed to be either 0 or 1 already. |
| 96 | builder.create<scf::ConditionOp>( |
| 97 | loc, builder.create<arith::TruncIOp>(loc, builder.getI1Type(), condition), |
| 98 | loopVariablesNextIter); |
| 99 | |
| 100 | Block *afterBlock = builder.createBlock(&whileOp.getAfter()); |
| 101 | afterBlock->addArguments( |
| 102 | types: loopVariablesInit.getTypes(), |
| 103 | locs: SmallVector<Location>(loopVariablesInit.size(), loc)); |
| 104 | builder.create<scf::YieldOp>(loc, afterBlock->getArguments()); |
| 105 | |
| 106 | return whileOp.getOperation(); |
| 107 | } |
| 108 | |
| 109 | Value ControlFlowToSCFTransformation::getCFGSwitchValue(Location loc, |
| 110 | OpBuilder &builder, |
| 111 | unsigned int value) { |
| 112 | return builder.create<arith::ConstantOp>(loc, |
| 113 | builder.getI32IntegerAttr(value)); |
| 114 | } |
| 115 | |
| 116 | void ControlFlowToSCFTransformation::createCFGSwitchOp( |
| 117 | Location loc, OpBuilder &builder, Value flag, |
| 118 | ArrayRef<unsigned int> caseValues, BlockRange caseDestinations, |
| 119 | ArrayRef<ValueRange> caseArguments, Block *defaultDest, |
| 120 | ValueRange defaultArgs) { |
| 121 | builder.create<cf::SwitchOp>(loc, flag, defaultDest, defaultArgs, |
| 122 | llvm::to_vector_of<int32_t>(caseValues), |
| 123 | caseDestinations, caseArguments); |
| 124 | } |
| 125 | |
| 126 | Value ControlFlowToSCFTransformation::getUndefValue(Location loc, |
| 127 | OpBuilder &builder, |
| 128 | Type type) { |
| 129 | return builder.create<ub::PoisonOp>(loc, type, nullptr); |
| 130 | } |
| 131 | |
| 132 | FailureOr<Operation *> |
| 133 | ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc, |
| 134 | OpBuilder &builder, |
| 135 | Region ®ion) { |
| 136 | |
| 137 | // TODO: This should create a `ub.unreachable` op. Once such an operation |
| 138 | // exists to make the pass independent of the func dialect. For now just |
| 139 | // return poison values. |
| 140 | Operation *parentOp = region.getParentOp(); |
| 141 | auto funcOp = dyn_cast<func::FuncOp>(parentOp); |
| 142 | if (!funcOp) |
| 143 | return emitError(loc, message: "Cannot create unreachable terminator for '" ) |
| 144 | << parentOp->getName() << "'" ; |
| 145 | |
| 146 | return builder |
| 147 | .create<func::ReturnOp>( |
| 148 | loc, llvm::map_to_vector(funcOp.getResultTypes(), |
| 149 | [&](Type type) { |
| 150 | return getUndefValue(loc, builder, type); |
| 151 | })) |
| 152 | .getOperation(); |
| 153 | } |
| 154 | |
| 155 | namespace { |
| 156 | |
| 157 | struct LiftControlFlowToSCF |
| 158 | : public impl::LiftControlFlowToSCFPassBase<LiftControlFlowToSCF> { |
| 159 | |
| 160 | using Base::Base; |
| 161 | |
| 162 | void runOnOperation() override { |
| 163 | ControlFlowToSCFTransformation transformation; |
| 164 | |
| 165 | bool changed = false; |
| 166 | Operation *op = getOperation(); |
| 167 | WalkResult result = op->walk(callback: [&](func::FuncOp funcOp) { |
| 168 | if (funcOp.getBody().empty()) |
| 169 | return WalkResult::advance(); |
| 170 | |
| 171 | auto &domInfo = funcOp != op ? getChildAnalysis<DominanceInfo>(funcOp) |
| 172 | : getAnalysis<DominanceInfo>(); |
| 173 | |
| 174 | auto visitor = [&](Operation *innerOp) -> WalkResult { |
| 175 | for (Region ® : innerOp->getRegions()) { |
| 176 | FailureOr<bool> changedFunc = |
| 177 | transformCFGToSCF(reg, transformation, domInfo); |
| 178 | if (failed(Result: changedFunc)) |
| 179 | return WalkResult::interrupt(); |
| 180 | |
| 181 | changed |= *changedFunc; |
| 182 | } |
| 183 | return WalkResult::advance(); |
| 184 | }; |
| 185 | |
| 186 | if (funcOp->walk<WalkOrder::PostOrder>(visitor).wasInterrupted()) |
| 187 | return WalkResult::interrupt(); |
| 188 | |
| 189 | return WalkResult::advance(); |
| 190 | }); |
| 191 | if (result.wasInterrupted()) |
| 192 | return signalPassFailure(); |
| 193 | |
| 194 | if (!changed) |
| 195 | markAllAnalysesPreserved(); |
| 196 | } |
| 197 | }; |
| 198 | } // namespace |
| 199 | |