1 | //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements a pass, and the patterns it uses, to convert the MLIR
// ControlFlow (`cf`) dialect into the LLVM IR dialect.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" |
15 | |
16 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
17 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
18 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
19 | #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h" |
20 | #include "mlir/Conversion/LLVMCommon/VectorPattern.h" |
21 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
22 | #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" |
23 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
24 | #include "mlir/IR/BuiltinOps.h" |
25 | #include "mlir/IR/PatternMatch.h" |
26 | #include "mlir/Pass/Pass.h" |
27 | #include "mlir/Transforms/DialectConversion.h" |
28 | #include "llvm/ADT/StringRef.h" |
29 | #include <functional> |
30 | |
31 | namespace mlir { |
32 | #define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS |
33 | #include "mlir/Conversion/Passes.h.inc" |
34 | } // namespace mlir |
35 | |
36 | using namespace mlir; |
37 | |
38 | #define PASS_NAME "convert-cf-to-llvm" |
39 | |
40 | namespace { |
41 | /// Lower `cf.assert`. The default lowering calls the `abort` function if the |
42 | /// assertion is violated and has no effect otherwise. The failure message is |
43 | /// ignored by the default lowering but should be propagated by any custom |
44 | /// lowering. |
45 | struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> { |
46 | explicit AssertOpLowering(const LLVMTypeConverter &typeConverter, |
47 | bool abortOnFailedAssert = true) |
48 | : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1), |
49 | abortOnFailedAssert(abortOnFailedAssert) {} |
50 | |
51 | LogicalResult |
52 | matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, |
53 | ConversionPatternRewriter &rewriter) const override { |
54 | auto loc = op.getLoc(); |
55 | auto module = op->getParentOfType<ModuleOp>(); |
56 | |
57 | // Split block at `assert` operation. |
58 | Block *opBlock = rewriter.getInsertionBlock(); |
59 | auto opPosition = rewriter.getInsertionPoint(); |
60 | Block *continuationBlock = rewriter.splitBlock(block: opBlock, before: opPosition); |
61 | |
62 | // Failed block: Generate IR to print the message and call `abort`. |
63 | Block *failureBlock = rewriter.createBlock(parent: opBlock->getParent()); |
64 | auto createResult = LLVM::createPrintStrCall( |
65 | builder&: rewriter, loc: loc, moduleOp: module, symbolName: "assert_msg", string: op.getMsg(), typeConverter: *getTypeConverter(), |
66 | /*addNewLine=*/addNewline: false, |
67 | /*runtimeFunctionName=*/"puts"); |
68 | if (createResult.failed()) |
69 | return failure(); |
70 | |
71 | if (abortOnFailedAssert) { |
72 | // Insert the `abort` declaration if necessary. |
73 | auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort"); |
74 | if (!abortFunc) { |
75 | OpBuilder::InsertionGuard guard(rewriter); |
76 | rewriter.setInsertionPointToStart(module.getBody()); |
77 | auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); |
78 | abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(), |
79 | "abort", abortFuncTy); |
80 | } |
81 | rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt); |
82 | rewriter.create<LLVM::UnreachableOp>(loc); |
83 | } else { |
84 | rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock); |
85 | } |
86 | |
87 | // Generate assertion test. |
88 | rewriter.setInsertionPointToEnd(opBlock); |
89 | rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( |
90 | op, adaptor.getArg(), continuationBlock, failureBlock); |
91 | |
92 | return success(); |
93 | } |
94 | |
95 | private: |
96 | /// If set to `false`, messages are printed but program execution continues. |
97 | /// This is useful for testing asserts. |
98 | bool abortOnFailedAssert = true; |
99 | }; |
100 | |
101 | /// Helper function for converting branch ops. This function converts the |
102 | /// signature of the given block. If the new block signature is different from |
103 | /// `expectedTypes`, returns "failure". |
104 | static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter, |
105 | const TypeConverter *converter, |
106 | Operation *branchOp, Block *block, |
107 | TypeRange expectedTypes) { |
108 | assert(converter && "expected non-null type converter"); |
109 | assert(!block->isEntryBlock() && "entry blocks have no predecessors"); |
110 | |
111 | // There is nothing to do if the types already match. |
112 | if (block->getArgumentTypes() == expectedTypes) |
113 | return block; |
114 | |
115 | // Compute the new block argument types and convert the block. |
116 | std::optional<TypeConverter::SignatureConversion> conversion = |
117 | converter->convertBlockSignature(block); |
118 | if (!conversion) |
119 | return rewriter.notifyMatchFailure(branchOp, |
120 | "could not compute block signature"); |
121 | if (expectedTypes != conversion->getConvertedTypes()) |
122 | return rewriter.notifyMatchFailure( |
123 | branchOp, |
124 | "mismatch between adaptor operand types and computed block signature"); |
125 | return rewriter.applySignatureConversion(block, conversion&: *conversion, converter); |
126 | } |
127 | |
128 | /// Convert the destination block signature (if necessary) and lower the branch |
129 | /// op to llvm.br. |
130 | struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> { |
131 | using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern; |
132 | |
133 | LogicalResult |
134 | matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor, |
135 | ConversionPatternRewriter &rewriter) const override { |
136 | FailureOr<Block *> convertedBlock = |
137 | getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(), |
138 | TypeRange(adaptor.getOperands())); |
139 | if (failed(convertedBlock)) |
140 | return failure(); |
141 | Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>( |
142 | op, adaptor.getOperands(), *convertedBlock); |
143 | // TODO: We should not just forward all attributes like that. But there are |
144 | // existing Flang tests that depend on this behavior. |
145 | newOp->setAttrs(op->getAttrDictionary()); |
146 | return success(); |
147 | } |
148 | }; |
149 | |
150 | /// Convert the destination block signatures (if necessary) and lower the |
151 | /// branch op to llvm.cond_br. |
152 | struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> { |
153 | using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern; |
154 | |
155 | LogicalResult |
156 | matchAndRewrite(cf::CondBranchOp op, |
157 | typename cf::CondBranchOp::Adaptor adaptor, |
158 | ConversionPatternRewriter &rewriter) const override { |
159 | FailureOr<Block *> convertedTrueBlock = |
160 | getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(), |
161 | TypeRange(adaptor.getTrueDestOperands())); |
162 | if (failed(convertedTrueBlock)) |
163 | return failure(); |
164 | FailureOr<Block *> convertedFalseBlock = |
165 | getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(), |
166 | TypeRange(adaptor.getFalseDestOperands())); |
167 | if (failed(convertedFalseBlock)) |
168 | return failure(); |
169 | Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( |
170 | op, adaptor.getCondition(), *convertedTrueBlock, |
171 | adaptor.getTrueDestOperands(), *convertedFalseBlock, |
172 | adaptor.getFalseDestOperands()); |
173 | // TODO: We should not just forward all attributes like that. But there are |
174 | // existing Flang tests that depend on this behavior. |
175 | newOp->setAttrs(op->getAttrDictionary()); |
176 | return success(); |
177 | } |
178 | }; |
179 | |
180 | /// Convert the destination block signatures (if necessary) and lower the |
181 | /// switch op to llvm.switch. |
182 | struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> { |
183 | using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern; |
184 | |
185 | LogicalResult |
186 | matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor, |
187 | ConversionPatternRewriter &rewriter) const override { |
188 | // Get or convert default block. |
189 | FailureOr<Block *> convertedDefaultBlock = getConvertedBlock( |
190 | rewriter, getTypeConverter(), op, op.getDefaultDestination(), |
191 | TypeRange(adaptor.getDefaultOperands())); |
192 | if (failed(convertedDefaultBlock)) |
193 | return failure(); |
194 | |
195 | // Get or convert all case blocks. |
196 | SmallVector<Block *> caseDestinations; |
197 | SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands(); |
198 | for (auto it : llvm::enumerate(op.getCaseDestinations())) { |
199 | Block *b = it.value(); |
200 | FailureOr<Block *> convertedBlock = |
201 | getConvertedBlock(rewriter, getTypeConverter(), op, b, |
202 | TypeRange(caseOperands[it.index()])); |
203 | if (failed(convertedBlock)) |
204 | return failure(); |
205 | caseDestinations.push_back(*convertedBlock); |
206 | } |
207 | |
208 | rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( |
209 | op, adaptor.getFlag(), *convertedDefaultBlock, |
210 | adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(), |
211 | caseDestinations, caseOperands); |
212 | return success(); |
213 | } |
214 | }; |
215 | |
216 | } // namespace |
217 | |
218 | void mlir::cf::populateControlFlowToLLVMConversionPatterns( |
219 | const LLVMTypeConverter &converter, RewritePatternSet &patterns) { |
220 | // clang-format off |
221 | patterns.add< |
222 | BranchOpLowering, |
223 | CondBranchOpLowering, |
224 | SwitchOpLowering>(arg: converter); |
225 | // clang-format on |
226 | } |
227 | |
228 | void mlir::cf::populateAssertToLLVMConversionPattern( |
229 | const LLVMTypeConverter &converter, RewritePatternSet &patterns, |
230 | bool abortOnFailure) { |
231 | patterns.add<AssertOpLowering>(arg: converter, args&: abortOnFailure); |
232 | } |
233 | |
234 | //===----------------------------------------------------------------------===// |
235 | // Pass Definition |
236 | //===----------------------------------------------------------------------===// |
237 | |
238 | namespace { |
239 | /// A pass converting MLIR operations into the LLVM IR dialect. |
240 | struct ConvertControlFlowToLLVM |
241 | : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> { |
242 | |
243 | using Base::Base; |
244 | |
245 | /// Run the dialect converter on the module. |
246 | void runOnOperation() override { |
247 | MLIRContext *ctx = &getContext(); |
248 | LLVMConversionTarget target(*ctx); |
249 | // This pass lowers only CF dialect ops, but it also modifies block |
250 | // signatures inside other ops. These ops should be treated as legal. They |
251 | // are lowered by other passes. |
252 | target.markUnknownOpDynamicallyLegal([&](Operation *op) { |
253 | return op->getDialect() != |
254 | ctx->getLoadedDialect<cf::ControlFlowDialect>(); |
255 | }); |
256 | |
257 | LowerToLLVMOptions options(ctx); |
258 | if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) |
259 | options.overrideIndexBitwidth(indexBitwidth); |
260 | |
261 | LLVMTypeConverter converter(ctx, options); |
262 | RewritePatternSet patterns(ctx); |
263 | mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); |
264 | mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns); |
265 | |
266 | if (failed(applyPartialConversion(getOperation(), target, |
267 | std::move(patterns)))) |
268 | signalPassFailure(); |
269 | } |
270 | }; |
271 | } // namespace |
272 | |
273 | //===----------------------------------------------------------------------===// |
274 | // ConvertToLLVMPatternInterface implementation |
275 | //===----------------------------------------------------------------------===// |
276 | |
277 | namespace { |
278 | /// Implement the interface to convert MemRef to LLVM. |
279 | struct ControlFlowToLLVMDialectInterface |
280 | : public ConvertToLLVMPatternInterface { |
281 | using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; |
282 | void loadDependentDialects(MLIRContext *context) const final { |
283 | context->loadDialect<LLVM::LLVMDialect>(); |
284 | } |
285 | |
286 | /// Hook for derived dialect interface to provide conversion patterns |
287 | /// and mark dialect legal for the conversion target. |
288 | void populateConvertToLLVMConversionPatterns( |
289 | ConversionTarget &target, LLVMTypeConverter &typeConverter, |
290 | RewritePatternSet &patterns) const final { |
291 | mlir::cf::populateControlFlowToLLVMConversionPatterns(converter: typeConverter, |
292 | patterns); |
293 | mlir::cf::populateAssertToLLVMConversionPattern(converter: typeConverter, patterns); |
294 | } |
295 | }; |
296 | } // namespace |
297 | |
298 | void mlir::cf::registerConvertControlFlowToLLVMInterface( |
299 | DialectRegistry ®istry) { |
300 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) { |
301 | dialect->addInterfaces<ControlFlowToLLVMDialectInterface>(); |
302 | }); |
303 | } |
304 |
Definitions
- AssertOpLowering
- AssertOpLowering
- matchAndRewrite
- getConvertedBlock
- BranchOpLowering
- matchAndRewrite
- CondBranchOpLowering
- matchAndRewrite
- SwitchOpLowering
- matchAndRewrite
- populateControlFlowToLLVMConversionPatterns
- populateAssertToLLVMConversionPattern
- ConvertControlFlowToLLVM
- runOnOperation
- ControlFlowToLLVMDialectInterface
- loadDependentDialects
- populateConvertToLLVMConversionPatterns
Improve your Profiling and Debugging skills
Find out more