//===- ControlFlowToSCF.cpp - ControlFlow to SCF -----------*- C++ ------*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Define conversions from the ControlFlow dialect to the SCF dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h"
14
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
17#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Dialect/SCF/IR/SCF.h"
20#include "mlir/Dialect/UB/IR/UBOps.h"
21#include "mlir/Pass/Pass.h"
22#include "mlir/Transforms/CFGToSCF.h"
23
24namespace mlir {
25#define GEN_PASS_DEF_LIFTCONTROLFLOWTOSCFPASS
26#include "mlir/Conversion/Passes.h.inc"
27} // namespace mlir
28
29using namespace mlir;
30
31FailureOr<Operation *>
32ControlFlowToSCFTransformation::createStructuredBranchRegionOp(
33 OpBuilder &builder, Operation *controlFlowCondOp, TypeRange resultTypes,
34 MutableArrayRef<Region> regions) {
35 if (auto condBrOp = dyn_cast<cf::CondBranchOp>(controlFlowCondOp)) {
36 assert(regions.size() == 2);
37 auto ifOp = builder.create<scf::IfOp>(controlFlowCondOp->getLoc(),
38 resultTypes, condBrOp.getCondition());
39 ifOp.getThenRegion().takeBody(regions[0]);
40 ifOp.getElseRegion().takeBody(regions[1]);
41 return ifOp.getOperation();
42 }
43
44 if (auto switchOp = dyn_cast<cf::SwitchOp>(controlFlowCondOp)) {
45 // `getCFGSwitchValue` returns an i32 that we need to convert to index
46 // fist.
47 auto cast = builder.create<arith::IndexCastUIOp>(
48 controlFlowCondOp->getLoc(), builder.getIndexType(),
49 switchOp.getFlag());
50 SmallVector<int64_t> cases;
51 if (auto caseValues = switchOp.getCaseValues())
52 llvm::append_range(
53 cases, llvm::map_range(*caseValues, [](const llvm::APInt &apInt) {
54 return apInt.getZExtValue();
55 }));
56
57 assert(regions.size() == cases.size() + 1);
58
59 auto indexSwitchOp = builder.create<scf::IndexSwitchOp>(
60 controlFlowCondOp->getLoc(), resultTypes, cast, cases, cases.size());
61
62 indexSwitchOp.getDefaultRegion().takeBody(regions[0]);
63 for (auto &&[targetRegion, sourceRegion] :
64 llvm::zip(indexSwitchOp.getCaseRegions(), llvm::drop_begin(regions)))
65 targetRegion.takeBody(sourceRegion);
66
67 return indexSwitchOp.getOperation();
68 }
69
70 controlFlowCondOp->emitOpError(
71 message: "Cannot convert unknown control flow op to structured control flow");
72 return failure();
73}
74
75LogicalResult
76ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp(
77 Location loc, OpBuilder &builder, Operation *branchRegionOp,
78 Operation *replacedControlFlowOp, ValueRange results) {
79 builder.create<scf::YieldOp>(loc, results);
80 return success();
81}
82
83FailureOr<Operation *>
84ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp(
85 OpBuilder &builder, Operation *replacedOp, ValueRange loopVariablesInit,
86 Value condition, ValueRange loopVariablesNextIter, Region &&loopBody) {
87 Location loc = replacedOp->getLoc();
88 auto whileOp = builder.create<scf::WhileOp>(loc, loopVariablesInit.getTypes(),
89 loopVariablesInit);
90
91 whileOp.getBefore().takeBody(loopBody);
92
93 builder.setInsertionPointToEnd(&whileOp.getBefore().back());
94 // `getCFGSwitchValue` returns a i32. We therefore need to truncate the
95 // condition to i1 first. It is guaranteed to be either 0 or 1 already.
96 builder.create<scf::ConditionOp>(
97 loc, builder.create<arith::TruncIOp>(loc, builder.getI1Type(), condition),
98 loopVariablesNextIter);
99
100 Block *afterBlock = builder.createBlock(&whileOp.getAfter());
101 afterBlock->addArguments(
102 types: loopVariablesInit.getTypes(),
103 locs: SmallVector<Location>(loopVariablesInit.size(), loc));
104 builder.create<scf::YieldOp>(loc, afterBlock->getArguments());
105
106 return whileOp.getOperation();
107}
108
109Value ControlFlowToSCFTransformation::getCFGSwitchValue(Location loc,
110 OpBuilder &builder,
111 unsigned int value) {
112 return builder.create<arith::ConstantOp>(loc,
113 builder.getI32IntegerAttr(value));
114}
115
116void ControlFlowToSCFTransformation::createCFGSwitchOp(
117 Location loc, OpBuilder &builder, Value flag,
118 ArrayRef<unsigned int> caseValues, BlockRange caseDestinations,
119 ArrayRef<ValueRange> caseArguments, Block *defaultDest,
120 ValueRange defaultArgs) {
121 builder.create<cf::SwitchOp>(loc, flag, defaultDest, defaultArgs,
122 llvm::to_vector_of<int32_t>(caseValues),
123 caseDestinations, caseArguments);
124}
125
126Value ControlFlowToSCFTransformation::getUndefValue(Location loc,
127 OpBuilder &builder,
128 Type type) {
129 return builder.create<ub::PoisonOp>(loc, type, nullptr);
130}
131
132FailureOr<Operation *>
133ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc,
134 OpBuilder &builder,
135 Region &region) {
136
137 // TODO: This should create a `ub.unreachable` op. Once such an operation
138 // exists to make the pass independent of the func dialect. For now just
139 // return poison values.
140 Operation *parentOp = region.getParentOp();
141 auto funcOp = dyn_cast<func::FuncOp>(parentOp);
142 if (!funcOp)
143 return emitError(loc, message: "Cannot create unreachable terminator for '")
144 << parentOp->getName() << "'";
145
146 return builder
147 .create<func::ReturnOp>(
148 loc, llvm::map_to_vector(funcOp.getResultTypes(),
149 [&](Type type) {
150 return getUndefValue(loc, builder, type);
151 }))
152 .getOperation();
153}
154
155namespace {
156
157struct LiftControlFlowToSCF
158 : public impl::LiftControlFlowToSCFPassBase<LiftControlFlowToSCF> {
159
160 using Base::Base;
161
162 void runOnOperation() override {
163 ControlFlowToSCFTransformation transformation;
164
165 bool changed = false;
166 Operation *op = getOperation();
167 WalkResult result = op->walk(callback: [&](func::FuncOp funcOp) {
168 if (funcOp.getBody().empty())
169 return WalkResult::advance();
170
171 auto &domInfo = funcOp != op ? getChildAnalysis<DominanceInfo>(funcOp)
172 : getAnalysis<DominanceInfo>();
173
174 auto visitor = [&](Operation *innerOp) -> WalkResult {
175 for (Region &reg : innerOp->getRegions()) {
176 FailureOr<bool> changedFunc =
177 transformCFGToSCF(reg, transformation, domInfo);
178 if (failed(Result: changedFunc))
179 return WalkResult::interrupt();
180
181 changed |= *changedFunc;
182 }
183 return WalkResult::advance();
184 };
185
186 if (funcOp->walk<WalkOrder::PostOrder>(visitor).wasInterrupted())
187 return WalkResult::interrupt();
188
189 return WalkResult::advance();
190 });
191 if (result.wasInterrupted())
192 return signalPassFailure();
193
194 if (!changed)
195 markAllAnalysesPreserved();
196 }
197};
198} // namespace
199

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp