1//===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements patterns and a pass to convert the MLIR ControlFlow
// (`cf`) dialect into the LLVM IR dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
15
16#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
17#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
18#include "mlir/Conversion/LLVMCommon/Pattern.h"
19#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
20#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
21#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
22#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23#include "mlir/IR/BuiltinOps.h"
24#include "mlir/IR/PatternMatch.h"
25#include "mlir/Pass/Pass.h"
26#include "mlir/Transforms/DialectConversion.h"
27
28namespace mlir {
29#define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
30#include "mlir/Conversion/Passes.h.inc"
31} // namespace mlir
32
33using namespace mlir;
34
35#define PASS_NAME "convert-cf-to-llvm"
36
37namespace {
38/// Lower `cf.assert`. The default lowering calls the `abort` function if the
39/// assertion is violated and has no effect otherwise. The failure message is
40/// ignored by the default lowering but should be propagated by any custom
41/// lowering.
42struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
43 explicit AssertOpLowering(const LLVMTypeConverter &typeConverter,
44 bool abortOnFailedAssert = true,
45 SymbolTableCollection *symbolTables = nullptr)
46 : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
47 abortOnFailedAssert(abortOnFailedAssert), symbolTables(symbolTables) {}
48
49 LogicalResult
50 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
51 ConversionPatternRewriter &rewriter) const override {
52 auto loc = op.getLoc();
53 auto module = op->getParentOfType<ModuleOp>();
54
55 // Split block at `assert` operation.
56 Block *opBlock = rewriter.getInsertionBlock();
57 auto opPosition = rewriter.getInsertionPoint();
58 Block *continuationBlock = rewriter.splitBlock(block: opBlock, before: opPosition);
59
60 // Failed block: Generate IR to print the message and call `abort`.
61 Block *failureBlock = rewriter.createBlock(parent: opBlock->getParent());
62 auto createResult = LLVM::createPrintStrCall(
63 builder&: rewriter, loc, moduleOp: module, symbolName: "assert_msg", string: op.getMsg(), typeConverter: *getTypeConverter(),
64 /*addNewLine=*/addNewline: false,
65 /*runtimeFunctionName=*/"puts", symbolTables);
66 if (createResult.failed())
67 return failure();
68
69 if (abortOnFailedAssert) {
70 // Insert the `abort` declaration if necessary.
71 auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(name: "abort");
72 if (!abortFunc) {
73 OpBuilder::InsertionGuard guard(rewriter);
74 rewriter.setInsertionPointToStart(module.getBody());
75 auto abortFuncTy = LLVM::LLVMFunctionType::get(result: getVoidType(), arguments: {});
76 abortFunc = rewriter.create<LLVM::LLVMFuncOp>(location: rewriter.getUnknownLoc(),
77 args: "abort", args&: abortFuncTy);
78 }
79 rewriter.create<LLVM::CallOp>(location: loc, args&: abortFunc, args: ValueRange());
80 rewriter.create<LLVM::UnreachableOp>(location: loc);
81 } else {
82 rewriter.create<LLVM::BrOp>(location: loc, args: ValueRange(), args&: continuationBlock);
83 }
84
85 // Generate assertion test.
86 rewriter.setInsertionPointToEnd(opBlock);
87 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
88 op, args: adaptor.getArg(), args&: continuationBlock, args&: failureBlock);
89
90 return success();
91 }
92
93private:
94 /// If set to `false`, messages are printed but program execution continues.
95 /// This is useful for testing asserts.
96 bool abortOnFailedAssert = true;
97
98 SymbolTableCollection *symbolTables = nullptr;
99};
100
101/// Helper function for converting branch ops. This function converts the
102/// signature of the given block. If the new block signature is different from
103/// `expectedTypes`, returns "failure".
104static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
105 const TypeConverter *converter,
106 Operation *branchOp, Block *block,
107 TypeRange expectedTypes) {
108 assert(converter && "expected non-null type converter");
109 assert(!block->isEntryBlock() && "entry blocks have no predecessors");
110
111 // There is nothing to do if the types already match.
112 if (block->getArgumentTypes() == expectedTypes)
113 return block;
114
115 // Compute the new block argument types and convert the block.
116 std::optional<TypeConverter::SignatureConversion> conversion =
117 converter->convertBlockSignature(block);
118 if (!conversion)
119 return rewriter.notifyMatchFailure(arg&: branchOp,
120 msg: "could not compute block signature");
121 if (expectedTypes != conversion->getConvertedTypes())
122 return rewriter.notifyMatchFailure(
123 arg&: branchOp,
124 msg: "mismatch between adaptor operand types and computed block signature");
125 return rewriter.applySignatureConversion(block, conversion&: *conversion, converter);
126}
127
128/// Convert the destination block signature (if necessary) and lower the branch
129/// op to llvm.br.
130struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
131 using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
132
133 LogicalResult
134 matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
135 ConversionPatternRewriter &rewriter) const override {
136 FailureOr<Block *> convertedBlock =
137 getConvertedBlock(rewriter, converter: getTypeConverter(), branchOp: op, block: op.getSuccessor(),
138 expectedTypes: TypeRange(adaptor.getOperands()));
139 if (failed(Result: convertedBlock))
140 return failure();
141 DictionaryAttr attrs = op->getAttrDictionary();
142 Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
143 op, args: adaptor.getOperands(), args&: *convertedBlock);
144 // TODO: We should not just forward all attributes like that. But there are
145 // existing Flang tests that depend on this behavior.
146 newOp->setAttrs(attrs);
147 return success();
148 }
149};
150
151/// Convert the destination block signatures (if necessary) and lower the
152/// branch op to llvm.cond_br.
153struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
154 using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
155
156 LogicalResult
157 matchAndRewrite(cf::CondBranchOp op,
158 typename cf::CondBranchOp::Adaptor adaptor,
159 ConversionPatternRewriter &rewriter) const override {
160 FailureOr<Block *> convertedTrueBlock =
161 getConvertedBlock(rewriter, converter: getTypeConverter(), branchOp: op, block: op.getTrueDest(),
162 expectedTypes: TypeRange(adaptor.getTrueDestOperands()));
163 if (failed(Result: convertedTrueBlock))
164 return failure();
165 FailureOr<Block *> convertedFalseBlock =
166 getConvertedBlock(rewriter, converter: getTypeConverter(), branchOp: op, block: op.getFalseDest(),
167 expectedTypes: TypeRange(adaptor.getFalseDestOperands()));
168 if (failed(Result: convertedFalseBlock))
169 return failure();
170 DictionaryAttr attrs = op->getAttrDictionary();
171 auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
172 op, args: adaptor.getCondition(), args: adaptor.getTrueDestOperands(),
173 args: adaptor.getFalseDestOperands(), args: op.getBranchWeightsAttr(),
174 args&: *convertedTrueBlock, args&: *convertedFalseBlock);
175 // TODO: We should not just forward all attributes like that. But there are
176 // existing Flang tests that depend on this behavior.
177 newOp->setAttrs(attrs);
178 return success();
179 }
180};
181
182/// Convert the destination block signatures (if necessary) and lower the
183/// switch op to llvm.switch.
184struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
185 using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
186
187 LogicalResult
188 matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
189 ConversionPatternRewriter &rewriter) const override {
190 // Get or convert default block.
191 FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
192 rewriter, converter: getTypeConverter(), branchOp: op, block: op.getDefaultDestination(),
193 expectedTypes: TypeRange(adaptor.getDefaultOperands()));
194 if (failed(Result: convertedDefaultBlock))
195 return failure();
196
197 // Get or convert all case blocks.
198 SmallVector<Block *> caseDestinations;
199 SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
200 for (auto it : llvm::enumerate(First: op.getCaseDestinations())) {
201 Block *b = it.value();
202 FailureOr<Block *> convertedBlock =
203 getConvertedBlock(rewriter, converter: getTypeConverter(), branchOp: op, block: b,
204 expectedTypes: TypeRange(caseOperands[it.index()]));
205 if (failed(Result: convertedBlock))
206 return failure();
207 caseDestinations.push_back(Elt: *convertedBlock);
208 }
209
210 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
211 op, args: adaptor.getFlag(), args&: *convertedDefaultBlock,
212 args: adaptor.getDefaultOperands(), args: adaptor.getCaseValuesAttr(),
213 args&: caseDestinations, args&: caseOperands);
214 return success();
215 }
216};
217
218} // namespace
219
220void mlir::cf::populateControlFlowToLLVMConversionPatterns(
221 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
222 // clang-format off
223 patterns.add<
224 BranchOpLowering,
225 CondBranchOpLowering,
226 SwitchOpLowering>(arg: converter);
227 // clang-format on
228}
229
230void mlir::cf::populateAssertToLLVMConversionPattern(
231 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
232 bool abortOnFailure, SymbolTableCollection *symbolTables) {
233 patterns.add<AssertOpLowering>(arg: converter, args&: abortOnFailure, args&: symbolTables);
234}
235
236//===----------------------------------------------------------------------===//
237// Pass Definition
238//===----------------------------------------------------------------------===//
239
240namespace {
241/// A pass converting MLIR operations into the LLVM IR dialect.
242struct ConvertControlFlowToLLVM
243 : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> {
244
245 using Base::Base;
246
247 /// Run the dialect converter on the module.
248 void runOnOperation() override {
249 MLIRContext *ctx = &getContext();
250 LLVMConversionTarget target(*ctx);
251 // This pass lowers only CF dialect ops, but it also modifies block
252 // signatures inside other ops. These ops should be treated as legal. They
253 // are lowered by other passes.
254 target.markUnknownOpDynamicallyLegal(fn: [&](Operation *op) {
255 return op->getDialect() !=
256 ctx->getLoadedDialect<cf::ControlFlowDialect>();
257 });
258
259 LowerToLLVMOptions options(ctx);
260 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
261 options.overrideIndexBitwidth(bitwidth: indexBitwidth);
262
263 LLVMTypeConverter converter(ctx, options);
264 RewritePatternSet patterns(ctx);
265 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
266 mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns);
267
268 if (failed(Result: applyPartialConversion(op: getOperation(), target,
269 patterns: std::move(patterns))))
270 signalPassFailure();
271 }
272};
273} // namespace
274
275//===----------------------------------------------------------------------===//
276// ConvertToLLVMPatternInterface implementation
277//===----------------------------------------------------------------------===//
278
279namespace {
280/// Implement the interface to convert MemRef to LLVM.
281struct ControlFlowToLLVMDialectInterface
282 : public ConvertToLLVMPatternInterface {
283 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
284 void loadDependentDialects(MLIRContext *context) const final {
285 context->loadDialect<LLVM::LLVMDialect>();
286 }
287
288 /// Hook for derived dialect interface to provide conversion patterns
289 /// and mark dialect legal for the conversion target.
290 void populateConvertToLLVMConversionPatterns(
291 ConversionTarget &target, LLVMTypeConverter &typeConverter,
292 RewritePatternSet &patterns) const final {
293 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter: typeConverter,
294 patterns);
295 mlir::cf::populateAssertToLLVMConversionPattern(converter: typeConverter, patterns);
296 }
297};
298} // namespace
299
300void mlir::cf::registerConvertControlFlowToLLVMInterface(
301 DialectRegistry &registry) {
302 registry.addExtension(extensionFn: +[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) {
303 dialect->addInterfaces<ControlFlowToLLVMDialectInterface>();
304 });
305}
306

source code of mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp