//===- ControlFlowToSCF.cpp - ControlFlow to SCF -----------*- C++ ------*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Define conversions from the ControlFlow dialect to the SCF dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
17#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20#include "mlir/Dialect/SCF/IR/SCF.h"
21#include "mlir/Dialect/UB/IR/UBOps.h"
22#include "mlir/Pass/Pass.h"
23#include "mlir/Transforms/CFGToSCF.h"
24
25namespace mlir {
26#define GEN_PASS_DEF_LIFTCONTROLFLOWTOSCFPASS
27#include "mlir/Conversion/Passes.h.inc"
28} // namespace mlir
29
30using namespace mlir;
31
32FailureOr<Operation *>
33ControlFlowToSCFTransformation::createStructuredBranchRegionOp(
34 OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes,
35 MutableArrayRef<Region> regions) {
36 if (auto condBrOp = dyn_cast<cf::CondBranchOp>(controlFlowCondOp)) {
37 assert(regions.size() == 2);
38 auto ifOp = builder.create<scf::IfOp>(controlFlowCondOp->getLoc(),
39 resultTypes, condBrOp.getCondition());
40 ifOp.getThenRegion().takeBody(regions[0]);
41 ifOp.getElseRegion().takeBody(regions[1]);
42 return ifOp.getOperation();
43 }
44
45 if (auto switchOp = dyn_cast<cf::SwitchOp>(controlFlowCondOp)) {
46 // `getCFGSwitchValue` returns an i32 that we need to convert to index
47 // fist.
48 auto cast = builder.create<arith::IndexCastUIOp>(
49 controlFlowCondOp->getLoc(), builder.getIndexType(),
50 switchOp.getFlag());
51 SmallVector<int64_t> cases;
52 if (auto caseValues = switchOp.getCaseValues())
53 llvm::append_range(
54 cases, llvm::map_range(*caseValues, [](const llvm::APInt &apInt) {
55 return apInt.getZExtValue();
56 }));
57
58 assert(regions.size() == cases.size() + 1);
59
60 auto indexSwitchOp = builder.create<scf::IndexSwitchOp>(
61 controlFlowCondOp->getLoc(), resultTypes, cast, cases, cases.size());
62
63 indexSwitchOp.getDefaultRegion().takeBody(regions[0]);
64 for (auto &&[targetRegion, sourceRegion] :
65 llvm::zip(indexSwitchOp.getCaseRegions(), llvm::drop_begin(regions)))
66 targetRegion.takeBody(sourceRegion);
67
68 return indexSwitchOp.getOperation();
69 }
70
71 controlFlowCondOp->emitOpError(
72 message: "Cannot convert unknown control flow op to structured control flow");
73 return failure();
74}
75
76LogicalResult
77ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp(
78 Location loc, OpBuilder &builder, Operation *branchRegionOp,
79 Operation *replacedControlFlowOp, ValueRange results) {
80 builder.create<scf::YieldOp>(loc, results);
81 return success();
82}
83
84FailureOr<Operation *>
85ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp(
86 OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit,
87 Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) {
88 Location loc = replacedOp->getLoc();
89 auto whileOp = builder.create<scf::WhileOp>(loc, loopVariablesInit.getTypes(),
90 loopVariablesInit);
91
92 whileOp.getBefore().takeBody(loopBody);
93
94 builder.setInsertionPointToEnd(&whileOp.getBefore().back());
95 // `getCFGSwitchValue` returns a i32. We therefore need to truncate the
96 // condition to i1 first. It is guaranteed to be either 0 or 1 already.
97 builder.create<scf::ConditionOp>(
98 loc, builder.create<arith::TruncIOp>(loc, builder.getI1Type(), condition),
99 loopVariablesNextIter);
100
101 Block *afterBlock = builder.createBlock(&whileOp.getAfter());
102 afterBlock->addArguments(
103 types: loopVariablesInit.getTypes(),
104 locs: SmallVector<Location>(loopVariablesInit.size(), loc));
105 builder.create<scf::YieldOp>(loc, afterBlock->getArguments());
106
107 return whileOp.getOperation();
108}
109
110Value ControlFlowToSCFTransformation::getCFGSwitchValue(Location loc,
111 OpBuilder &builder,
112 unsigned int value) {
113 return builder.create<arith::ConstantOp>(loc,
114 builder.getI32IntegerAttr(value));
115}
116
117void ControlFlowToSCFTransformation::createCFGSwitchOp(
118 Location loc, OpBuilder &builder, Value flag,
119 ArrayRef<unsigned int> caseValues, BlockRange caseDestinations,
120 ArrayRef<ValueRange> caseArguments, Block *defaultDest,
121 ValueRange defaultArgs) {
122 builder.create<cf::SwitchOp>(loc, flag, defaultDest, defaultArgs,
123 llvm::to_vector_of<int32_t>(caseValues),
124 caseDestinations, caseArguments);
125}
126
127Value ControlFlowToSCFTransformation::getUndefValue(Location loc,
128 OpBuilder &builder,
129 Type type) {
130 return builder.create<ub::PoisonOp>(loc, type, nullptr);
131}
132
133FailureOr<Operation *>
134ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc,
135 OpBuilder &builder,
136 Region &region) {
137
138 // TODO: This should create a `ub.unreachable` op. Once such an operation
139 // exists to make the pass independent of the func dialect. For now just
140 // return poison values.
141 Operation *parentOp = region.getParentOp();
142 auto funcOp = dyn_cast<func::FuncOp>(parentOp);
143 if (!funcOp)
144 return emitError(loc, message: "Cannot create unreachable terminator for '")
145 << parentOp->getName() << "'";
146
147 return builder
148 .create<func::ReturnOp>(
149 loc, llvm::map_to_vector(funcOp.getResultTypes(),
150 [&](Type type) {
151 return getUndefValue(loc, builder, type);
152 }))
153 .getOperation();
154}
155
156namespace {
157
158struct LiftControlFlowToSCF
159 : public impl::LiftControlFlowToSCFPassBase<LiftControlFlowToSCF> {
160
161 using Base::Base;
162
163 void runOnOperation() override {
164 ControlFlowToSCFTransformation transformation;
165
166 bool changed = false;
167 Operation *op = getOperation();
168 WalkResult result = op->walk(callback: [&](func::FuncOp funcOp) {
169 if (funcOp.getBody().empty())
170 return WalkResult::advance();
171
172 auto &domInfo = funcOp != op ? getChildAnalysis<DominanceInfo>(funcOp)
173 : getAnalysis<DominanceInfo>();
174
175 auto visitor = [&](Operation *innerOp) -> WalkResult {
176 for (Region &reg : innerOp->getRegions()) {
177 FailureOr<bool> changedFunc =
178 transformCFGToSCF(reg, transformation, domInfo);
179 if (failed(result: changedFunc))
180 return WalkResult::interrupt();
181
182 changed |= *changedFunc;
183 }
184 return WalkResult::advance();
185 };
186
187 if (funcOp->walk<WalkOrder::PostOrder>(visitor).wasInterrupted())
188 return WalkResult::interrupt();
189
190 return WalkResult::advance();
191 });
192 if (result.wasInterrupted())
193 return signalPassFailure();
194
195 if (!changed)
196 markAllAnalysesPreserved();
197 }
198};
199} // namespace
200

source code of mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp