1 | //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file implements patterns and a pass to convert the MLIR ControlFlow
// (`cf`) dialect into the LLVM IR dialect.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" |
15 | |
16 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
17 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
18 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
19 | #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h" |
20 | #include "mlir/Conversion/LLVMCommon/VectorPattern.h" |
21 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
22 | #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" |
23 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
24 | #include "mlir/IR/BuiltinOps.h" |
25 | #include "mlir/IR/PatternMatch.h" |
26 | #include "mlir/Pass/Pass.h" |
27 | #include "mlir/Transforms/DialectConversion.h" |
28 | #include "llvm/ADT/StringRef.h" |
29 | #include <functional> |
30 | |
namespace mlir {
// Instantiate the TableGen-generated pass definition for this conversion; it
// provides `impl::ConvertControlFlowToLLVMPassBase`, which the pass below
// derives from.
#define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

// Presumably the pass's command-line name. NOTE(review): not referenced
// anywhere else in this file as shown — confirm it is still needed.
#define PASS_NAME "convert-cf-to-llvm"
39 | |
40 | namespace { |
41 | /// Lower `cf.assert`. The default lowering calls the `abort` function if the |
42 | /// assertion is violated and has no effect otherwise. The failure message is |
43 | /// ignored by the default lowering but should be propagated by any custom |
44 | /// lowering. |
45 | struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> { |
46 | explicit AssertOpLowering(LLVMTypeConverter &typeConverter, |
47 | bool abortOnFailedAssert = true) |
48 | : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1), |
49 | abortOnFailedAssert(abortOnFailedAssert) {} |
50 | |
51 | LogicalResult |
52 | matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, |
53 | ConversionPatternRewriter &rewriter) const override { |
54 | auto loc = op.getLoc(); |
55 | auto module = op->getParentOfType<ModuleOp>(); |
56 | |
57 | // Split block at `assert` operation. |
58 | Block *opBlock = rewriter.getInsertionBlock(); |
59 | auto opPosition = rewriter.getInsertionPoint(); |
60 | Block *continuationBlock = rewriter.splitBlock(block: opBlock, before: opPosition); |
61 | |
62 | // Failed block: Generate IR to print the message and call `abort`. |
63 | Block *failureBlock = rewriter.createBlock(parent: opBlock->getParent()); |
64 | LLVM::createPrintStrCall(builder&: rewriter, loc: loc, moduleOp: module, symbolName: "assert_msg" , string: op.getMsg(), |
65 | typeConverter: *getTypeConverter(), /*addNewLine=*/addNewline: false, |
66 | /*runtimeFunctionName=*/"puts" ); |
67 | if (abortOnFailedAssert) { |
68 | // Insert the `abort` declaration if necessary. |
69 | auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort" ); |
70 | if (!abortFunc) { |
71 | OpBuilder::InsertionGuard guard(rewriter); |
72 | rewriter.setInsertionPointToStart(module.getBody()); |
73 | auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); |
74 | abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(), |
75 | "abort" , abortFuncTy); |
76 | } |
77 | rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt); |
78 | rewriter.create<LLVM::UnreachableOp>(loc); |
79 | } else { |
80 | rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock); |
81 | } |
82 | |
83 | // Generate assertion test. |
84 | rewriter.setInsertionPointToEnd(opBlock); |
85 | rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( |
86 | op, adaptor.getArg(), continuationBlock, failureBlock); |
87 | |
88 | return success(); |
89 | } |
90 | |
91 | private: |
92 | /// If set to `false`, messages are printed but program execution continues. |
93 | /// This is useful for testing asserts. |
94 | bool abortOnFailedAssert = true; |
95 | }; |
96 | |
97 | /// The cf->LLVM lowerings for branching ops require that the blocks they jump |
98 | /// to first have updated types which should be handled by a pattern operating |
99 | /// on the parent op. |
100 | static LogicalResult verifyMatchingValues(ConversionPatternRewriter &rewriter, |
101 | ValueRange operands, |
102 | ValueRange blockArgs, Location loc, |
103 | llvm::StringRef messagePrefix) { |
104 | for (const auto &idxAndTypes : |
105 | llvm::enumerate(llvm::zip(blockArgs, operands))) { |
106 | int64_t i = idxAndTypes.index(); |
107 | Value argValue = |
108 | rewriter.getRemappedValue(std::get<0>(idxAndTypes.value())); |
109 | Type operandType = std::get<1>(idxAndTypes.value()).getType(); |
110 | // In the case of an invalid jump, the block argument will have been |
111 | // remapped to an UnrealizedConversionCast. In the case of a valid jump, |
112 | // there might still be a no-op conversion cast with both types being equal. |
113 | // Consider both of these details to see if the jump would be invalid. |
114 | if (auto op = dyn_cast_or_null<UnrealizedConversionCastOp>( |
115 | argValue.getDefiningOp())) { |
116 | if (op.getOperandTypes().front() != operandType) { |
117 | return rewriter.notifyMatchFailure(loc, [&](Diagnostic &diag) { |
118 | diag << messagePrefix; |
119 | diag << "mismatched types from operand # " << i << " " ; |
120 | diag << operandType; |
121 | diag << " not compatible with destination block argument type " ; |
122 | diag << op.getOperandTypes().front(); |
123 | diag << " which should be converted with the parent op." ; |
124 | }); |
125 | } |
126 | } |
127 | } |
128 | return success(); |
129 | } |
130 | |
131 | /// Ensure that all block types were updated and then create an LLVM::BrOp |
132 | struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> { |
133 | using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern; |
134 | |
135 | LogicalResult |
136 | matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor, |
137 | ConversionPatternRewriter &rewriter) const override { |
138 | if (failed(verifyMatchingValues(rewriter, adaptor.getDestOperands(), |
139 | op.getSuccessor()->getArguments(), |
140 | op.getLoc(), |
141 | /*messagePrefix=*/"" ))) |
142 | return failure(); |
143 | |
144 | rewriter.replaceOpWithNewOp<LLVM::BrOp>( |
145 | op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); |
146 | return success(); |
147 | } |
148 | }; |
149 | |
150 | /// Ensure that all block types were updated and then create an LLVM::CondBrOp |
151 | struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> { |
152 | using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern; |
153 | |
154 | LogicalResult |
155 | matchAndRewrite(cf::CondBranchOp op, |
156 | typename cf::CondBranchOp::Adaptor adaptor, |
157 | ConversionPatternRewriter &rewriter) const override { |
158 | if (failed(verifyMatchingValues(rewriter, adaptor.getFalseDestOperands(), |
159 | op.getFalseDest()->getArguments(), |
160 | op.getLoc(), "in false case branch " ))) |
161 | return failure(); |
162 | if (failed(verifyMatchingValues(rewriter, adaptor.getTrueDestOperands(), |
163 | op.getTrueDest()->getArguments(), |
164 | op.getLoc(), "in true case branch " ))) |
165 | return failure(); |
166 | |
167 | rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( |
168 | op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); |
169 | return success(); |
170 | } |
171 | }; |
172 | |
173 | /// Ensure that all block types were updated and then create an LLVM::SwitchOp |
174 | struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> { |
175 | using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern; |
176 | |
177 | LogicalResult |
178 | matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor, |
179 | ConversionPatternRewriter &rewriter) const override { |
180 | if (failed(verifyMatchingValues(rewriter, adaptor.getDefaultOperands(), |
181 | op.getDefaultDestination()->getArguments(), |
182 | op.getLoc(), "in switch default case " ))) |
183 | return failure(); |
184 | |
185 | for (const auto &i : llvm::enumerate( |
186 | llvm::zip(adaptor.getCaseOperands(), op.getCaseDestinations()))) { |
187 | if (failed(verifyMatchingValues( |
188 | rewriter, std::get<0>(i.value()), |
189 | std::get<1>(i.value())->getArguments(), op.getLoc(), |
190 | "in switch case " + std::to_string(i.index()) + " " ))) { |
191 | return failure(); |
192 | } |
193 | } |
194 | |
195 | rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( |
196 | op, adaptor.getOperands(), op->getSuccessors(), op->getAttrs()); |
197 | return success(); |
198 | } |
199 | }; |
200 | |
201 | } // namespace |
202 | |
203 | void mlir::cf::populateControlFlowToLLVMConversionPatterns( |
204 | LLVMTypeConverter &converter, RewritePatternSet &patterns) { |
205 | // clang-format off |
206 | patterns.add< |
207 | AssertOpLowering, |
208 | BranchOpLowering, |
209 | CondBranchOpLowering, |
210 | SwitchOpLowering>(arg&: converter); |
211 | // clang-format on |
212 | } |
213 | |
214 | void mlir::cf::populateAssertToLLVMConversionPattern( |
215 | LLVMTypeConverter &converter, RewritePatternSet &patterns, |
216 | bool abortOnFailure) { |
217 | patterns.add<AssertOpLowering>(arg&: converter, args&: abortOnFailure); |
218 | } |
219 | |
220 | //===----------------------------------------------------------------------===// |
221 | // Pass Definition |
222 | //===----------------------------------------------------------------------===// |
223 | |
224 | namespace { |
225 | /// A pass converting MLIR operations into the LLVM IR dialect. |
226 | struct ConvertControlFlowToLLVM |
227 | : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> { |
228 | |
229 | using Base::Base; |
230 | |
231 | /// Run the dialect converter on the module. |
232 | void runOnOperation() override { |
233 | LLVMConversionTarget target(getContext()); |
234 | RewritePatternSet patterns(&getContext()); |
235 | |
236 | LowerToLLVMOptions options(&getContext()); |
237 | if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) |
238 | options.overrideIndexBitwidth(indexBitwidth); |
239 | |
240 | LLVMTypeConverter converter(&getContext(), options); |
241 | mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); |
242 | |
243 | if (failed(applyPartialConversion(getOperation(), target, |
244 | std::move(patterns)))) |
245 | signalPassFailure(); |
246 | } |
247 | }; |
248 | } // namespace |
249 | |
250 | //===----------------------------------------------------------------------===// |
251 | // ConvertToLLVMPatternInterface implementation |
252 | //===----------------------------------------------------------------------===// |
253 | |
254 | namespace { |
255 | /// Implement the interface to convert MemRef to LLVM. |
256 | struct ControlFlowToLLVMDialectInterface |
257 | : public ConvertToLLVMPatternInterface { |
258 | using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; |
259 | void loadDependentDialects(MLIRContext *context) const final { |
260 | context->loadDialect<LLVM::LLVMDialect>(); |
261 | } |
262 | |
263 | /// Hook for derived dialect interface to provide conversion patterns |
264 | /// and mark dialect legal for the conversion target. |
265 | void populateConvertToLLVMConversionPatterns( |
266 | ConversionTarget &target, LLVMTypeConverter &typeConverter, |
267 | RewritePatternSet &patterns) const final { |
268 | mlir::cf::populateControlFlowToLLVMConversionPatterns(converter&: typeConverter, |
269 | patterns); |
270 | } |
271 | }; |
272 | } // namespace |
273 | |
274 | void mlir::cf::registerConvertControlFlowToLLVMInterface( |
275 | DialectRegistry ®istry) { |
276 | registry.addExtension(extensionFn: +[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) { |
277 | dialect->addInterfaces<ControlFlowToLLVMDialectInterface>(); |
278 | }); |
279 | } |
280 | |