1//===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements the lowering patterns, the pass, and the ConvertToLLVM
// interface that convert the MLIR ControlFlow (`cf`) dialect into the LLVM IR
// dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
15
16#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
17#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
18#include "mlir/Conversion/LLVMCommon/Pattern.h"
19#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
20#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
21#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
22#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
23#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
24#include "mlir/IR/BuiltinOps.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/Pass/Pass.h"
27#include "mlir/Transforms/DialectConversion.h"
28#include "llvm/ADT/StringRef.h"
29#include <functional>
30
31namespace mlir {
32#define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
33#include "mlir/Conversion/Passes.h.inc"
34} // namespace mlir
35
36using namespace mlir;
37
38#define PASS_NAME "convert-cf-to-llvm"
39
40namespace {
41/// Lower `cf.assert`. The default lowering calls the `abort` function if the
42/// assertion is violated and has no effect otherwise. The failure message is
43/// ignored by the default lowering but should be propagated by any custom
44/// lowering.
45struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
46 explicit AssertOpLowering(LLVMTypeConverter &typeConverter,
47 bool abortOnFailedAssert = true)
48 : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
49 abortOnFailedAssert(abortOnFailedAssert) {}
50
51 LogicalResult
52 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
53 ConversionPatternRewriter &rewriter) const override {
54 auto loc = op.getLoc();
55 auto module = op->getParentOfType<ModuleOp>();
56
57 // Split block at `assert` operation.
58 Block *opBlock = rewriter.getInsertionBlock();
59 auto opPosition = rewriter.getInsertionPoint();
60 Block *continuationBlock = rewriter.splitBlock(block: opBlock, before: opPosition);
61
62 // Failed block: Generate IR to print the message and call `abort`.
63 Block *failureBlock = rewriter.createBlock(parent: opBlock->getParent());
64 LLVM::createPrintStrCall(builder&: rewriter, loc: loc, moduleOp: module, symbolName: "assert_msg", string: op.getMsg(),
65 typeConverter: *getTypeConverter(), /*addNewLine=*/addNewline: false,
66 /*runtimeFunctionName=*/"puts");
67 if (abortOnFailedAssert) {
68 // Insert the `abort` declaration if necessary.
69 auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
70 if (!abortFunc) {
71 OpBuilder::InsertionGuard guard(rewriter);
72 rewriter.setInsertionPointToStart(module.getBody());
73 auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
74 abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
75 "abort", abortFuncTy);
76 }
77 rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt);
78 rewriter.create<LLVM::UnreachableOp>(loc);
79 } else {
80 rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock);
81 }
82
83 // Generate assertion test.
84 rewriter.setInsertionPointToEnd(opBlock);
85 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
86 op, adaptor.getArg(), continuationBlock, failureBlock);
87
88 return success();
89 }
90
91private:
92 /// If set to `false`, messages are printed but program execution continues.
93 /// This is useful for testing asserts.
94 bool abortOnFailedAssert = true;
95};
96
97/// The cf->LLVM lowerings for branching ops require that the blocks they jump
98/// to first have updated types which should be handled by a pattern operating
99/// on the parent op.
100static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter,
101 ValueRange operands,
102 ValueRange blockArgs, Location loc,
103 llvm::StringRef messagePrefix) {
104 for (const auto &idxAndTypes :
105 llvm::enumerate(llvm::zip(blockArgs, operands))) {
106 int64_t i = idxAndTypes.index();
107 Value argValue =
108 rewriter.getRemappedValue(std::get<0>(idxAndTypes.value()));
109 Type operandType = std::get<1>(idxAndTypes.value()).getType();
110 // In the case of an invalid jump, the block argument will have been
111 // remapped to an UnrealizedConversionCast. In the case of a valid jump,
112 // there might still be a no-op conversion cast with both types being equal.
113 // Consider both of these details to see if the jump would be invalid.
114 if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>(
115 argValue.getDefiningOp())) {
116 if (op.getOperandTypes().front() != operandType) {
117 return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) {
118 diag << messagePrefix;
119 diag << "mismatched types from operand # " << i << " ";
120 diag << operandType;
121 diag << " not compatible with destination block argument type ";
122 diag << op.getOperandTypes().front();
123 diag << " which should be converted with the parent op.";
124 });
125 }
126 }
127 }
128 return success();
129}
130
131/// Ensure that all block types were updated and then create an LLVM::BrOp
132struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
133 using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
134
135 LogicalResult
136 matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
137 ConversionPatternRewriter &rewriter) const override {
138 if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(),
139 op.getSuccessor()->getArguments(),
140 op.getLoc(),
141 /*messagePrefix=*/"")))
142 return failure();
143
144 rewriter.replaceOpWithNewOp<LLVM::BrOp>(
145 op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
146 return success();
147 }
148};
149
150/// Ensure that all block types were updated and then create an LLVM::CondBrOp
151struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
152 using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
153
154 LogicalResult
155 matchAndRewrite(cf::CondBranchOp op,
156 typename cf::CondBranchOp::Adaptor adaptor,
157 ConversionPatternRewriter &rewriter) const override {
158 if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(),
159 op.getFalseDest()->getArguments(),
160 op.getLoc(), "in false case branch ")))
161 return failure();
162 if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(),
163 op.getTrueDest()->getArguments(),
164 op.getLoc(), "in true case branch ")))
165 return failure();
166
167 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
168 op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
169 return success();
170 }
171};
172
173/// Ensure that all block types were updated and then create an LLVM::SwitchOp
174struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
175 using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
176
177 LogicalResult
178 matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
179 ConversionPatternRewriter &rewriter) const override {
180 if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(),
181 op.getDefaultDestination()->getArguments(),
182 op.getLoc(), "in switch default case ")))
183 return failure();
184
185 for (const auto &i : llvm::enumerate(
186 llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) {
187 if (failed(verifyMatchingValues(
188 rewriter, std::get<0>(i.value()),
189 std::get<1>(i.value())->getArguments(), op.getLoc(),
190 "in switch case " + std::to_string(i.index()) + " "))) {
191 return failure();
192 }
193 }
194
195 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
196 op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs());
197 return success();
198 }
199};
200
201} // namespace
202
203void mlir::cf::populateControlFlowToLLVMConversionPatterns(
204 LLVMTypeConverter &converter, RewritePatternSet &patterns) {
205 // clang-format off
206 patterns.add<
207 AssertOpLowering,
208 BranchOpLowering,
209 CondBranchOpLowering,
210 SwitchOpLowering>(arg&: converter);
211 // clang-format on
212}
213
214void mlir::cf::populateAssertToLLVMConversionPattern(
215 LLVMTypeConverter &converter, RewritePatternSet &patterns,
216 bool abortOnFailure) {
217 patterns.add<AssertOpLowering>(arg&: converter, args&: abortOnFailure);
218}
219
220//===----------------------------------------------------------------------===//
221// Pass Definition
222//===----------------------------------------------------------------------===//
223
224namespace {
225/// A pass converting MLIR operations into the LLVM IR dialect.
226struct ConvertControlFlowToLLVM
227 : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> {
228
229 using Base::Base;
230
231 /// Run the dialect converter on the module.
232 void runOnOperation() override {
233 LLVMConversionTarget target(getContext());
234 RewritePatternSet patterns(&getContext());
235
236 LowerToLLVMOptions options(&getContext());
237 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
238 options.overrideIndexBitwidth(indexBitwidth);
239
240 LLVMTypeConverter converter(&getContext(), options);
241 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
242
243 if (failed(applyPartialConversion(getOperation(), target,
244 std::move(patterns))))
245 signalPassFailure();
246 }
247};
248} // namespace
249
250//===----------------------------------------------------------------------===//
251// ConvertToLLVMPatternInterface implementation
252//===----------------------------------------------------------------------===//
253
254namespace {
255/// Implement the interface to convert MemRef to LLVM.
256struct ControlFlowToLLVMDialectInterface
257 : public ConvertToLLVMPatternInterface {
258 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
259 void loadDependentDialects(MLIRContext *context) const final {
260 context->loadDialect<LLVM::LLVMDialect>();
261 }
262
263 /// Hook for derived dialect interface to provide conversion patterns
264 /// and mark dialect legal for the conversion target.
265 void populateConvertToLLVMConversionPatterns(
266 ConversionTarget &target, LLVMTypeConverter &typeConverter,
267 RewritePatternSet &patterns) const final {
268 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter&: typeConverter,
269 patterns);
270 }
271};
272} // namespace
273
274void mlir::cf::registerConvertControlFlowToLLVMInterface(
275 DialectRegistry &registry) {
276 registry.addExtension(extensionFn: +[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) {
277 dialect->addInterfaces<ControlFlowToLLVMDialectInterface>();
278 });
279}
280

source code of mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp