1//===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements patterns and a pass to convert the MLIR ControlFlow
// (`cf`) dialect into the LLVM IR dialect.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
15
16#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
17#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
18#include "mlir/Conversion/LLVMCommon/Pattern.h"
19#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
20#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
21#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
22#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
23#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
24#include "mlir/IR/BuiltinOps.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/Pass/Pass.h"
27#include "mlir/Transforms/DialectConversion.h"
28#include "llvm/ADT/StringRef.h"
29#include <functional>
30
31namespace mlir {
32#define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
33#include "mlir/Conversion/Passes.h.inc"
34} // namespace mlir
35
36using namespace mlir;
37
38#define PASS_NAME "convert-cf-to-llvm"
39
40namespace {
41/// Lower `cf.assert`. The default lowering calls the `abort` function if the
42/// assertion is violated and has no effect otherwise. The failure message is
43/// ignored by the default lowering but should be propagated by any custom
44/// lowering.
45struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
46 explicit AssertOpLowering(const LLVMTypeConverter &typeConverter,
47 bool abortOnFailedAssert = true)
48 : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
49 abortOnFailedAssert(abortOnFailedAssert) {}
50
51 LogicalResult
52 matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
53 ConversionPatternRewriter &rewriter) const override {
54 auto loc = op.getLoc();
55 auto module = op->getParentOfType<ModuleOp>();
56
57 // Split block at `assert` operation.
58 Block *opBlock = rewriter.getInsertionBlock();
59 auto opPosition = rewriter.getInsertionPoint();
60 Block *continuationBlock = rewriter.splitBlock(block: opBlock, before: opPosition);
61
62 // Failed block: Generate IR to print the message and call `abort`.
63 Block *failureBlock = rewriter.createBlock(parent: opBlock->getParent());
64 auto createResult = LLVM::createPrintStrCall(
65 builder&: rewriter, loc: loc, moduleOp: module, symbolName: "assert_msg", string: op.getMsg(), typeConverter: *getTypeConverter(),
66 /*addNewLine=*/addNewline: false,
67 /*runtimeFunctionName=*/"puts");
68 if (createResult.failed())
69 return failure();
70
71 if (abortOnFailedAssert) {
72 // Insert the `abort` declaration if necessary.
73 auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
74 if (!abortFunc) {
75 OpBuilder::InsertionGuard guard(rewriter);
76 rewriter.setInsertionPointToStart(module.getBody());
77 auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
78 abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
79 "abort", abortFuncTy);
80 }
81 rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt);
82 rewriter.create<LLVM::UnreachableOp>(loc);
83 } else {
84 rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock);
85 }
86
87 // Generate assertion test.
88 rewriter.setInsertionPointToEnd(opBlock);
89 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
90 op, adaptor.getArg(), continuationBlock, failureBlock);
91
92 return success();
93 }
94
95private:
96 /// If set to `false`, messages are printed but program execution continues.
97 /// This is useful for testing asserts.
98 bool abortOnFailedAssert = true;
99};
100
101/// Helper function for converting branch ops. This function converts the
102/// signature of the given block. If the new block signature is different from
103/// `expectedTypes`, returns "failure".
104static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
105 const TypeConverter *converter,
106 Operation *branchOp, Block *block,
107 TypeRange expectedTypes) {
108 assert(converter && "expected non-null type converter");
109 assert(!block->isEntryBlock() && "entry blocks have no predecessors");
110
111 // There is nothing to do if the types already match.
112 if (block->getArgumentTypes() == expectedTypes)
113 return block;
114
115 // Compute the new block argument types and convert the block.
116 std::optional<TypeConverter::SignatureConversion> conversion =
117 converter->convertBlockSignature(block);
118 if (!conversion)
119 return rewriter.notifyMatchFailure(branchOp,
120 "could not compute block signature");
121 if (expectedTypes != conversion->getConvertedTypes())
122 return rewriter.notifyMatchFailure(
123 branchOp,
124 "mismatch between adaptor operand types and computed block signature");
125 return rewriter.applySignatureConversion(block, conversion&: *conversion, converter);
126}
127
128/// Convert the destination block signature (if necessary) and lower the branch
129/// op to llvm.br.
130struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
131 using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
132
133 LogicalResult
134 matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
135 ConversionPatternRewriter &rewriter) const override {
136 FailureOr<Block *> convertedBlock =
137 getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
138 TypeRange(adaptor.getOperands()));
139 if (failed(convertedBlock))
140 return failure();
141 Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
142 op, adaptor.getOperands(), *convertedBlock);
143 // TODO: We should not just forward all attributes like that. But there are
144 // existing Flang tests that depend on this behavior.
145 newOp->setAttrs(op->getAttrDictionary());
146 return success();
147 }
148};
149
150/// Convert the destination block signatures (if necessary) and lower the
151/// branch op to llvm.cond_br.
152struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
153 using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
154
155 LogicalResult
156 matchAndRewrite(cf::CondBranchOp op,
157 typename cf::CondBranchOp::Adaptor adaptor,
158 ConversionPatternRewriter &rewriter) const override {
159 FailureOr<Block *> convertedTrueBlock =
160 getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
161 TypeRange(adaptor.getTrueDestOperands()));
162 if (failed(convertedTrueBlock))
163 return failure();
164 FailureOr<Block *> convertedFalseBlock =
165 getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
166 TypeRange(adaptor.getFalseDestOperands()));
167 if (failed(convertedFalseBlock))
168 return failure();
169 Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
170 op, adaptor.getCondition(), *convertedTrueBlock,
171 adaptor.getTrueDestOperands(), *convertedFalseBlock,
172 adaptor.getFalseDestOperands());
173 // TODO: We should not just forward all attributes like that. But there are
174 // existing Flang tests that depend on this behavior.
175 newOp->setAttrs(op->getAttrDictionary());
176 return success();
177 }
178};
179
180/// Convert the destination block signatures (if necessary) and lower the
181/// switch op to llvm.switch.
182struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
183 using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
184
185 LogicalResult
186 matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
187 ConversionPatternRewriter &rewriter) const override {
188 // Get or convert default block.
189 FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
190 rewriter, getTypeConverter(), op, op.getDefaultDestination(),
191 TypeRange(adaptor.getDefaultOperands()));
192 if (failed(convertedDefaultBlock))
193 return failure();
194
195 // Get or convert all case blocks.
196 SmallVector<Block *> caseDestinations;
197 SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands();
198 for (auto it : llvm::enumerate(op.getCaseDestinations())) {
199 Block *b = it.value();
200 FailureOr<Block *> convertedBlock =
201 getConvertedBlock(rewriter, getTypeConverter(), op, b,
202 TypeRange(caseOperands[it.index()]));
203 if (failed(convertedBlock))
204 return failure();
205 caseDestinations.push_back(*convertedBlock);
206 }
207
208 rewriter.replaceOpWithNewOp<LLVM::SwitchOp>(
209 op, adaptor.getFlag(), *convertedDefaultBlock,
210 adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(),
211 caseDestinations, caseOperands);
212 return success();
213 }
214};
215
216} // namespace
217
218void mlir::cf::populateControlFlowToLLVMConversionPatterns(
219 const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
220 // clang-format off
221 patterns.add<
222 BranchOpLowering,
223 CondBranchOpLowering,
224 SwitchOpLowering>(arg: converter);
225 // clang-format on
226}
227
228void mlir::cf::populateAssertToLLVMConversionPattern(
229 const LLVMTypeConverter &converter, RewritePatternSet &patterns,
230 bool abortOnFailure) {
231 patterns.add<AssertOpLowering>(arg: converter, args&: abortOnFailure);
232}
233
234//===----------------------------------------------------------------------===//
235// Pass Definition
236//===----------------------------------------------------------------------===//
237
238namespace {
239/// A pass converting MLIR operations into the LLVM IR dialect.
240struct ConvertControlFlowToLLVM
241 : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> {
242
243 using Base::Base;
244
245 /// Run the dialect converter on the module.
246 void runOnOperation() override {
247 MLIRContext *ctx = &getContext();
248 LLVMConversionTarget target(*ctx);
249 // This pass lowers only CF dialect ops, but it also modifies block
250 // signatures inside other ops. These ops should be treated as legal. They
251 // are lowered by other passes.
252 target.markUnknownOpDynamicallyLegal([&](Operation *op) {
253 return op->getDialect() !=
254 ctx->getLoadedDialect<cf::ControlFlowDialect>();
255 });
256
257 LowerToLLVMOptions options(ctx);
258 if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
259 options.overrideIndexBitwidth(indexBitwidth);
260
261 LLVMTypeConverter converter(ctx, options);
262 RewritePatternSet patterns(ctx);
263 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
264 mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns);
265
266 if (failed(applyPartialConversion(getOperation(), target,
267 std::move(patterns))))
268 signalPassFailure();
269 }
270};
271} // namespace
272
273//===----------------------------------------------------------------------===//
274// ConvertToLLVMPatternInterface implementation
275//===----------------------------------------------------------------------===//
276
277namespace {
278/// Implement the interface to convert MemRef to LLVM.
279struct ControlFlowToLLVMDialectInterface
280 : public ConvertToLLVMPatternInterface {
281 using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
282 void loadDependentDialects(MLIRContext *context) const final {
283 context->loadDialect<LLVM::LLVMDialect>();
284 }
285
286 /// Hook for derived dialect interface to provide conversion patterns
287 /// and mark dialect legal for the conversion target.
288 void populateConvertToLLVMConversionPatterns(
289 ConversionTarget &target, LLVMTypeConverter &typeConverter,
290 RewritePatternSet &patterns) const final {
291 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter: typeConverter,
292 patterns);
293 mlir::cf::populateAssertToLLVMConversionPattern(converter: typeConverter, patterns);
294 }
295};
296} // namespace
297
298void mlir::cf::registerConvertControlFlowToLLVMInterface(
299 DialectRegistry &registry) {
300 registry.addExtension(extensionFn: +[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) {
301 dialect->addInterfaces<ControlFlowToLLVMDialectInterface>();
302 });
303}
304

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp