//===- ControlFlowToSCF.cpp - ControlFlow to SCF --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines conversions from the ControlFlow dialect to the SCF dialect.
//
//===----------------------------------------------------------------------===//
12 | |
13 | #include "mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h" |
14 | |
15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" |
17 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
18 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
19 | #include "mlir/Dialect/SCF/IR/SCF.h" |
20 | #include "mlir/Dialect/UB/IR/UBOps.h" |
21 | #include "mlir/Pass/Pass.h" |
22 | #include "mlir/Transforms/CFGToSCF.h" |
23 | |
24 | namespace mlir { |
25 | #define GEN_PASS_DEF_LIFTCONTROLFLOWTOSCFPASS |
26 | #include "mlir/Conversion/Passes.h.inc" |
27 | } // namespace mlir |
28 | |
29 | using namespace mlir; |
30 | |
31 | FailureOr<Operation *> |
32 | ControlFlowToSCFTransformation::createStructuredBranchRegionOp( |
33 | OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes, |
34 | MutableArrayRef<Region> regions) { |
35 | if (auto condBrOp = dyn_cast<cf::CondBranchOp>(controlFlowCondOp)) { |
36 | assert(regions.size() == 2); |
37 | auto ifOp = builder.create<scf::IfOp>(controlFlowCondOp->getLoc(), |
38 | resultTypes, condBrOp.getCondition()); |
39 | ifOp.getThenRegion().takeBody(regions[0]); |
40 | ifOp.getElseRegion().takeBody(regions[1]); |
41 | return ifOp.getOperation(); |
42 | } |
43 | |
44 | if (auto switchOp = dyn_cast<cf::SwitchOp>(controlFlowCondOp)) { |
45 | // `getCFGSwitchValue` returns an i32 that we need to convert to index |
46 | // fist. |
47 | auto cast = builder.create<arith::IndexCastUIOp>( |
48 | controlFlowCondOp->getLoc(), builder.getIndexType(), |
49 | switchOp.getFlag()); |
50 | SmallVector<int64_t> cases; |
51 | if (auto caseValues = switchOp.getCaseValues()) |
52 | llvm::append_range( |
53 | cases, llvm::map_range(*caseValues, [](const llvm::APInt &apInt) { |
54 | return apInt.getZExtValue(); |
55 | })); |
56 | |
57 | assert(regions.size() == cases.size() + 1); |
58 | |
59 | auto indexSwitchOp = builder.create<scf::IndexSwitchOp>( |
60 | controlFlowCondOp->getLoc(), resultTypes, cast, cases, cases.size()); |
61 | |
62 | indexSwitchOp.getDefaultRegion().takeBody(regions[0]); |
63 | for (auto &&[targetRegion, sourceRegion] : |
64 | llvm::zip(indexSwitchOp.getCaseRegions(), llvm::drop_begin(regions))) |
65 | targetRegion.takeBody(sourceRegion); |
66 | |
67 | return indexSwitchOp.getOperation(); |
68 | } |
69 | |
70 | controlFlowCondOp->emitOpError( |
71 | message: "Cannot convert unknown control flow op to structured control flow" ); |
72 | return failure(); |
73 | } |
74 | |
75 | LogicalResult |
76 | ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp( |
77 | Location loc, OpBuilder &builder, Operation *branchRegionOp, |
78 | Operation *replacedControlFlowOp, ValueRange results) { |
79 | builder.create<scf::YieldOp>(loc, results); |
80 | return success(); |
81 | } |
82 | |
83 | FailureOr<Operation *> |
84 | ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp( |
85 | OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit, |
86 | Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) { |
87 | Location loc = replacedOp->getLoc(); |
88 | auto whileOp = builder.create<scf::WhileOp>(loc, loopVariablesInit.getTypes(), |
89 | loopVariablesInit); |
90 | |
91 | whileOp.getBefore().takeBody(loopBody); |
92 | |
93 | builder.setInsertionPointToEnd(&whileOp.getBefore().back()); |
94 | // `getCFGSwitchValue` returns a i32. We therefore need to truncate the |
95 | // condition to i1 first. It is guaranteed to be either 0 or 1 already. |
96 | builder.create<scf::ConditionOp>( |
97 | loc, builder.create<arith::TruncIOp>(loc, builder.getI1Type(), condition), |
98 | loopVariablesNextIter); |
99 | |
100 | Block *afterBlock = builder.createBlock(&whileOp.getAfter()); |
101 | afterBlock->addArguments( |
102 | types: loopVariablesInit.getTypes(), |
103 | locs: SmallVector<Location>(loopVariablesInit.size(), loc)); |
104 | builder.create<scf::YieldOp>(loc, afterBlock->getArguments()); |
105 | |
106 | return whileOp.getOperation(); |
107 | } |
108 | |
109 | Value ControlFlowToSCFTransformation::getCFGSwitchValue(Location loc, |
110 | OpBuilder &builder, |
111 | unsigned int value) { |
112 | return builder.create<arith::ConstantOp>(loc, |
113 | builder.getI32IntegerAttr(value)); |
114 | } |
115 | |
116 | void ControlFlowToSCFTransformation::createCFGSwitchOp( |
117 | Location loc, OpBuilder &builder, Value flag, |
118 | ArrayRef<unsigned int> caseValues, BlockRange caseDestinations, |
119 | ArrayRef<ValueRange> caseArguments, Block *defaultDest, |
120 | ValueRange defaultArgs) { |
121 | builder.create<cf::SwitchOp>(loc, flag, defaultDest, defaultArgs, |
122 | llvm::to_vector_of<int32_t>(caseValues), |
123 | caseDestinations, caseArguments); |
124 | } |
125 | |
126 | Value ControlFlowToSCFTransformation::getUndefValue(Location loc, |
127 | OpBuilder &builder, |
128 | Type type) { |
129 | return builder.create<ub::PoisonOp>(loc, type, nullptr); |
130 | } |
131 | |
132 | FailureOr<Operation *> |
133 | ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc, |
134 | OpBuilder &builder, |
135 | Region ®ion) { |
136 | |
137 | // TODO: This should create a `ub.unreachable` op. Once such an operation |
138 | // exists to make the pass independent of the func dialect. For now just |
139 | // return poison values. |
140 | Operation *parentOp = region.getParentOp(); |
141 | auto funcOp = dyn_cast<func::FuncOp>(parentOp); |
142 | if (!funcOp) |
143 | return emitError(loc, message: "Cannot create unreachable terminator for '" ) |
144 | << parentOp->getName() << "'" ; |
145 | |
146 | return builder |
147 | .create<func::ReturnOp>( |
148 | loc, llvm::map_to_vector(funcOp.getResultTypes(), |
149 | [&](Type type) { |
150 | return getUndefValue(loc, builder, type); |
151 | })) |
152 | .getOperation(); |
153 | } |
154 | |
155 | namespace { |
156 | |
157 | struct LiftControlFlowToSCF |
158 | : public impl::LiftControlFlowToSCFPassBase<LiftControlFlowToSCF> { |
159 | |
160 | using Base::Base; |
161 | |
162 | void runOnOperation() override { |
163 | ControlFlowToSCFTransformation transformation; |
164 | |
165 | bool changed = false; |
166 | Operation *op = getOperation(); |
167 | WalkResult result = op->walk(callback: [&](func::FuncOp funcOp) { |
168 | if (funcOp.getBody().empty()) |
169 | return WalkResult::advance(); |
170 | |
171 | auto &domInfo = funcOp != op ? getChildAnalysis<DominanceInfo>(funcOp) |
172 | : getAnalysis<DominanceInfo>(); |
173 | |
174 | auto visitor = [&](Operation *innerOp) -> WalkResult { |
175 | for (Region ® : innerOp->getRegions()) { |
176 | FailureOr<bool> changedFunc = |
177 | transformCFGToSCF(reg, transformation, domInfo); |
178 | if (failed(Result: changedFunc)) |
179 | return WalkResult::interrupt(); |
180 | |
181 | changed |= *changedFunc; |
182 | } |
183 | return WalkResult::advance(); |
184 | }; |
185 | |
186 | if (funcOp->walk<WalkOrder::PostOrder>(visitor).wasInterrupted()) |
187 | return WalkResult::interrupt(); |
188 | |
189 | return WalkResult::advance(); |
190 | }); |
191 | if (result.wasInterrupted()) |
192 | return signalPassFailure(); |
193 | |
194 | if (!changed) |
195 | markAllAnalysesPreserved(); |
196 | } |
197 | }; |
198 | } // namespace |
199 | |