//===- ControlFlowToSCF.cpp - ControlFlow to SCF --------------------------===//
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Define conversions from the ControlFlow dialect to the SCF dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h" |
14 | |
15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" |
17 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
18 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
19 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
20 | #include "mlir/Dialect/SCF/IR/SCF.h" |
21 | #include "mlir/Dialect/UB/IR/UBOps.h" |
22 | #include "mlir/Pass/Pass.h" |
23 | #include "mlir/Transforms/CFGToSCF.h" |
24 | |
25 | namespace mlir { |
26 | #define GEN_PASS_DEF_LIFTCONTROLFLOWTOSCFPASS |
27 | #include "mlir/Conversion/Passes.h.inc" |
28 | } // namespace mlir |
29 | |
30 | using namespace mlir; |
31 | |
32 | FailureOr<Operation *> |
33 | ControlFlowToSCFTransformation::createStructuredBranchRegionOp( |
34 | OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes, |
35 | MutableArrayRef<Region> regions) { |
36 | if (auto condBrOp = dyn_cast<cf::CondBranchOp>(controlFlowCondOp)) { |
37 | assert(regions.size() == 2); |
38 | auto ifOp = builder.create<scf::IfOp>(controlFlowCondOp->getLoc(), |
39 | resultTypes, condBrOp.getCondition()); |
40 | ifOp.getThenRegion().takeBody(regions[0]); |
41 | ifOp.getElseRegion().takeBody(regions[1]); |
42 | return ifOp.getOperation(); |
43 | } |
44 | |
45 | if (auto switchOp = dyn_cast<cf::SwitchOp>(controlFlowCondOp)) { |
46 | // `getCFGSwitchValue` returns an i32 that we need to convert to index |
47 | // fist. |
48 | auto cast = builder.create<arith::IndexCastUIOp>( |
49 | controlFlowCondOp->getLoc(), builder.getIndexType(), |
50 | switchOp.getFlag()); |
51 | SmallVector<int64_t> cases; |
52 | if (auto caseValues = switchOp.getCaseValues()) |
53 | llvm::append_range( |
54 | cases, llvm::map_range(*caseValues, [](const llvm::APInt &apInt) { |
55 | return apInt.getZExtValue(); |
56 | })); |
57 | |
58 | assert(regions.size() == cases.size() + 1); |
59 | |
60 | auto indexSwitchOp = builder.create<scf::IndexSwitchOp>( |
61 | controlFlowCondOp->getLoc(), resultTypes, cast, cases, cases.size()); |
62 | |
63 | indexSwitchOp.getDefaultRegion().takeBody(regions[0]); |
64 | for (auto &&[targetRegion, sourceRegion] : |
65 | llvm::zip(indexSwitchOp.getCaseRegions(), llvm::drop_begin(regions))) |
66 | targetRegion.takeBody(sourceRegion); |
67 | |
68 | return indexSwitchOp.getOperation(); |
69 | } |
70 | |
71 | controlFlowCondOp->emitOpError( |
72 | message: "Cannot convert unknown control flow op to structured control flow" ); |
73 | return failure(); |
74 | } |
75 | |
76 | LogicalResult |
77 | ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp( |
78 | Location loc, OpBuilder &builder, Operation *branchRegionOp, |
79 | Operation *replacedControlFlowOp, ValueRange results) { |
80 | builder.create<scf::YieldOp>(loc, results); |
81 | return success(); |
82 | } |
83 | |
84 | FailureOr<Operation *> |
85 | ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp( |
86 | OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit, |
87 | Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) { |
88 | Location loc = replacedOp->getLoc(); |
89 | auto whileOp = builder.create<scf::WhileOp>(loc, loopVariablesInit.getTypes(), |
90 | loopVariablesInit); |
91 | |
92 | whileOp.getBefore().takeBody(loopBody); |
93 | |
94 | builder.setInsertionPointToEnd(&whileOp.getBefore().back()); |
95 | // `getCFGSwitchValue` returns a i32. We therefore need to truncate the |
96 | // condition to i1 first. It is guaranteed to be either 0 or 1 already. |
97 | builder.create<scf::ConditionOp>( |
98 | loc, builder.create<arith::TruncIOp>(loc, builder.getI1Type(), condition), |
99 | loopVariablesNextIter); |
100 | |
101 | Block *afterBlock = builder.createBlock(&whileOp.getAfter()); |
102 | afterBlock->addArguments( |
103 | types: loopVariablesInit.getTypes(), |
104 | locs: SmallVector<Location>(loopVariablesInit.size(), loc)); |
105 | builder.create<scf::YieldOp>(loc, afterBlock->getArguments()); |
106 | |
107 | return whileOp.getOperation(); |
108 | } |
109 | |
110 | Value ControlFlowToSCFTransformation::getCFGSwitchValue(Location loc, |
111 | OpBuilder &builder, |
112 | unsigned int value) { |
113 | return builder.create<arith::ConstantOp>(loc, |
114 | builder.getI32IntegerAttr(value)); |
115 | } |
116 | |
117 | void ControlFlowToSCFTransformation::createCFGSwitchOp( |
118 | Location loc, OpBuilder &builder, Value flag, |
119 | ArrayRef<unsigned int> caseValues, BlockRange caseDestinations, |
120 | ArrayRef<ValueRange> caseArguments, Block *defaultDest, |
121 | ValueRange defaultArgs) { |
122 | builder.create<cf::SwitchOp>(loc, flag, defaultDest, defaultArgs, |
123 | llvm::to_vector_of<int32_t>(caseValues), |
124 | caseDestinations, caseArguments); |
125 | } |
126 | |
127 | Value ControlFlowToSCFTransformation::getUndefValue(Location loc, |
128 | OpBuilder &builder, |
129 | Type type) { |
130 | return builder.create<ub::PoisonOp>(loc, type, nullptr); |
131 | } |
132 | |
133 | FailureOr<Operation *> |
134 | ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc, |
135 | OpBuilder &builder, |
136 | Region ®ion) { |
137 | |
138 | // TODO: This should create a `ub.unreachable` op. Once such an operation |
139 | // exists to make the pass independent of the func dialect. For now just |
140 | // return poison values. |
141 | Operation *parentOp = region.getParentOp(); |
142 | auto funcOp = dyn_cast<func::FuncOp>(parentOp); |
143 | if (!funcOp) |
144 | return emitError(loc, message: "Cannot create unreachable terminator for '" ) |
145 | << parentOp->getName() << "'" ; |
146 | |
147 | return builder |
148 | .create<func::ReturnOp>( |
149 | loc, llvm::map_to_vector(funcOp.getResultTypes(), |
150 | [&](Type type) { |
151 | return getUndefValue(loc, builder, type); |
152 | })) |
153 | .getOperation(); |
154 | } |
155 | |
156 | namespace { |
157 | |
158 | struct LiftControlFlowToSCF |
159 | : public impl::LiftControlFlowToSCFPassBase<LiftControlFlowToSCF> { |
160 | |
161 | using Base::Base; |
162 | |
163 | void runOnOperation() override { |
164 | ControlFlowToSCFTransformation transformation; |
165 | |
166 | bool changed = false; |
167 | Operation *op = getOperation(); |
168 | WalkResult result = op->walk(callback: [&](func::FuncOp funcOp) { |
169 | if (funcOp.getBody().empty()) |
170 | return WalkResult::advance(); |
171 | |
172 | auto &domInfo = funcOp != op ? getChildAnalysis<DominanceInfo>(funcOp) |
173 | : getAnalysis<DominanceInfo>(); |
174 | |
175 | auto visitor = [&](Operation *innerOp) -> WalkResult { |
176 | for (Region ® : innerOp->getRegions()) { |
177 | FailureOr<bool> changedFunc = |
178 | transformCFGToSCF(reg, transformation, domInfo); |
179 | if (failed(result: changedFunc)) |
180 | return WalkResult::interrupt(); |
181 | |
182 | changed |= *changedFunc; |
183 | } |
184 | return WalkResult::advance(); |
185 | }; |
186 | |
187 | if (funcOp->walk<WalkOrder::PostOrder>(visitor).wasInterrupted()) |
188 | return WalkResult::interrupt(); |
189 | |
190 | return WalkResult::advance(); |
191 | }); |
192 | if (result.wasInterrupted()) |
193 | return signalPassFailure(); |
194 | |
195 | if (!changed) |
196 | markAllAnalysesPreserved(); |
197 | } |
198 | }; |
199 | } // namespace |
200 | |