//===- ControlFlowToSCF.cpp - ControlFlow to SCF --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Define conversions from the ControlFlow dialect to the SCF dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/Dialect/SCF/IR/SCF.h"
19#include "mlir/Dialect/UB/IR/UBOps.h"
20#include "mlir/Pass/Pass.h"
21#include "mlir/Transforms/CFGToSCF.h"
22
23namespace mlir {
24#define GEN_PASS_DEF_LIFTCONTROLFLOWTOSCFPASS
25#include "mlir/Conversion/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29
30FailureOr<Operation *>
31ControlFlowToSCFTransformation::createStructuredBranchRegionOp(
32 OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes,
33 MutableArrayRef<Region> regions) {
34 if (auto condBrOp = dyn_cast<cf::CondBranchOp>(Val: controlFlowCondOp)) {
35 assert(regions.size() == 2);
36 auto ifOp = builder.create<scf::IfOp>(location: controlFlowCondOp->getLoc(),
37 args&: resultTypes, args: condBrOp.getCondition());
38 ifOp.getThenRegion().takeBody(other&: regions[0]);
39 ifOp.getElseRegion().takeBody(other&: regions[1]);
40 return ifOp.getOperation();
41 }
42
43 if (auto switchOp = dyn_cast<cf::SwitchOp>(Val: controlFlowCondOp)) {
44 // `getCFGSwitchValue` returns an i32 that we need to convert to index
45 // fist.
46 auto cast = builder.create<arith::IndexCastUIOp>(
47 location: controlFlowCondOp->getLoc(), args: builder.getIndexType(),
48 args: switchOp.getFlag());
49 SmallVector<int64_t> cases;
50 if (auto caseValues = switchOp.getCaseValues())
51 llvm::append_range(
52 C&: cases, R: llvm::map_range(C&: *caseValues, F: [](const llvm::APInt &apInt) {
53 return apInt.getZExtValue();
54 }));
55
56 assert(regions.size() == cases.size() + 1);
57
58 auto indexSwitchOp = builder.create<scf::IndexSwitchOp>(
59 location: controlFlowCondOp->getLoc(), args&: resultTypes, args&: cast, args&: cases, args: cases.size());
60
61 indexSwitchOp.getDefaultRegion().takeBody(other&: regions[0]);
62 for (auto &&[targetRegion, sourceRegion] :
63 llvm::zip(t: indexSwitchOp.getCaseRegions(), u: llvm::drop_begin(RangeOrContainer&: regions)))
64 targetRegion.takeBody(other&: sourceRegion);
65
66 return indexSwitchOp.getOperation();
67 }
68
69 controlFlowCondOp->emitOpError(
70 message: "Cannot convert unknown control flow op to structured control flow");
71 return failure();
72}
73
74LogicalResult
75ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp(
76 Location loc, OpBuilder &builder, Operation *branchRegionOp,
77 Operation *replacedControlFlowOp, ValueRange results) {
78 builder.create<scf::YieldOp>(location: loc, args&: results);
79 return success();
80}
81
82FailureOr<Operation *>
83ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp(
84 OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit,
85 Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) {
86 Location loc = replacedOp->getLoc();
87 auto whileOp = builder.create<scf::WhileOp>(location: loc, args: loopVariablesInit.getTypes(),
88 args&: loopVariablesInit);
89
90 whileOp.getBefore().takeBody(other&: loopBody);
91
92 builder.setInsertionPointToEnd(&whileOp.getBefore().back());
93 // `getCFGSwitchValue` returns a i32. We therefore need to truncate the
94 // condition to i1 first. It is guaranteed to be either 0 or 1 already.
95 builder.create<scf::ConditionOp>(
96 location: loc, args: builder.create<arith::TruncIOp>(location: loc, args: builder.getI1Type(), args&: condition),
97 args&: loopVariablesNextIter);
98
99 Block *afterBlock = builder.createBlock(parent: &whileOp.getAfter());
100 afterBlock->addArguments(
101 types: loopVariablesInit.getTypes(),
102 locs: SmallVector<Location>(loopVariablesInit.size(), loc));
103 builder.create<scf::YieldOp>(location: loc, args: afterBlock->getArguments());
104
105 return whileOp.getOperation();
106}
107
108Value ControlFlowToSCFTransformation::getCFGSwitchValue(Location loc,
109 OpBuilder &builder,
110 unsigned int value) {
111 return builder.create<arith::ConstantOp>(location: loc,
112 args: builder.getI32IntegerAttr(value));
113}
114
115void ControlFlowToSCFTransformation::createCFGSwitchOp(
116 Location loc, OpBuilder &builder, Value flag,
117 ArrayRef<unsigned int> caseValues, BlockRange caseDestinations,
118 ArrayRef<ValueRange> caseArguments, Block *defaultDest,
119 ValueRange defaultArgs) {
120 builder.create<cf::SwitchOp>(location: loc, args&: flag, args&: defaultDest, args&: defaultArgs,
121 args: llvm::to_vector_of<int32_t>(Range&: caseValues),
122 args&: caseDestinations, args&: caseArguments);
123}
124
125Value ControlFlowToSCFTransformation::getUndefValue(Location loc,
126 OpBuilder &builder,
127 Type type) {
128 return builder.create<ub::PoisonOp>(location: loc, args&: type, args: nullptr);
129}
130
131FailureOr<Operation *>
132ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc,
133 OpBuilder &builder,
134 Region &region) {
135
136 // TODO: This should create a `ub.unreachable` op. Once such an operation
137 // exists to make the pass independent of the func dialect. For now just
138 // return poison values.
139 Operation *parentOp = region.getParentOp();
140 auto funcOp = dyn_cast<func::FuncOp>(Val: parentOp);
141 if (!funcOp)
142 return emitError(loc, message: "Cannot create unreachable terminator for '")
143 << parentOp->getName() << "'";
144
145 return builder
146 .create<func::ReturnOp>(
147 location: loc, args: llvm::map_to_vector(C: funcOp.getResultTypes(),
148 F: [&](Type type) {
149 return getUndefValue(loc, builder, type);
150 }))
151 .getOperation();
152}
153
154namespace {
155
156struct LiftControlFlowToSCF
157 : public impl::LiftControlFlowToSCFPassBase<LiftControlFlowToSCF> {
158
159 using Base::Base;
160
161 void runOnOperation() override {
162 ControlFlowToSCFTransformation transformation;
163
164 bool changed = false;
165 Operation *op = getOperation();
166 WalkResult result = op->walk(callback: [&](func::FuncOp funcOp) {
167 if (funcOp.getBody().empty())
168 return WalkResult::advance();
169
170 auto &domInfo = funcOp != op ? getChildAnalysis<DominanceInfo>(child: funcOp)
171 : getAnalysis<DominanceInfo>();
172
173 auto visitor = [&](Operation *innerOp) -> WalkResult {
174 for (Region &reg : innerOp->getRegions()) {
175 FailureOr<bool> changedFunc =
176 transformCFGToSCF(region&: reg, interface&: transformation, dominanceInfo&: domInfo);
177 if (failed(Result: changedFunc))
178 return WalkResult::interrupt();
179
180 changed |= *changedFunc;
181 }
182 return WalkResult::advance();
183 };
184
185 if (funcOp->walk<WalkOrder::PostOrder>(callback&: visitor).wasInterrupted())
186 return WalkResult::interrupt();
187
188 return WalkResult::advance();
189 });
190 if (result.wasInterrupted())
191 return signalPassFailure();
192
193 if (!changed)
194 markAllAnalysesPreserved();
195 }
196};
197} // namespace
198

source code of mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp