1 | //===- ConvertToLLVMPass.cpp - MLIR LLVM Conversion -----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
#include <memory>
#include <string>
#include <utility>
19 | |
20 | #define DEBUG_TYPE "convert-to-llvm" |
21 | |
22 | namespace mlir { |
23 | #define GEN_PASS_DEF_CONVERTTOLLVMPASS |
24 | #include "mlir/Conversion/Passes.h.inc" |
25 | } // namespace mlir |
26 | |
27 | using namespace mlir; |
28 | |
29 | namespace { |
30 | |
31 | /// This DialectExtension can be attached to the context, which will invoke the |
32 | /// `apply()` method for every loaded dialect. If a dialect implements the |
33 | /// `ConvertToLLVMPatternInterface` interface, we load dependent dialects |
34 | /// through the interface. This extension is loaded in the context before |
35 | /// starting a pass pipeline that involves dialect conversion to LLVM. |
36 | class LoadDependentDialectExtension : public DialectExtensionBase { |
37 | public: |
38 | LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {} |
39 | |
40 | void apply(MLIRContext *context, |
41 | MutableArrayRef<Dialect *> dialects) const final { |
42 | LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM extension load\n" ); |
43 | for (Dialect *dialect : dialects) { |
44 | auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); |
45 | if (!iface) |
46 | continue; |
47 | LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM found dialect interface for " |
48 | << dialect->getNamespace() << "\n" ); |
49 | iface->loadDependentDialects(context); |
50 | } |
51 | } |
52 | |
53 | /// Return a copy of this extension. |
54 | std::unique_ptr<DialectExtensionBase> clone() const final { |
55 | return std::make_unique<LoadDependentDialectExtension>(*this); |
56 | } |
57 | }; |
58 | |
59 | /// This is a generic pass to convert to LLVM, it uses the |
60 | /// `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects |
61 | /// the injection of conversion patterns. |
62 | class ConvertToLLVMPass |
63 | : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> { |
64 | std::shared_ptr<const FrozenRewritePatternSet> patterns; |
65 | std::shared_ptr<const ConversionTarget> target; |
66 | std::shared_ptr<const LLVMTypeConverter> typeConverter; |
67 | |
68 | public: |
69 | using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase; |
70 | void getDependentDialects(DialectRegistry ®istry) const final { |
71 | registry.insert<LLVM::LLVMDialect>(); |
72 | registry.addExtensions<LoadDependentDialectExtension>(); |
73 | } |
74 | |
75 | LogicalResult initialize(MLIRContext *context) final { |
76 | RewritePatternSet tempPatterns(context); |
77 | auto target = std::make_shared<ConversionTarget>(*context); |
78 | target->addLegalDialect<LLVM::LLVMDialect>(); |
79 | auto typeConverter = std::make_shared<LLVMTypeConverter>(context); |
80 | |
81 | if (!filterDialects.empty()) { |
82 | // Test mode: Populate only patterns from the specified dialects. Produce |
83 | // an error if the dialect is not loaded or does not implement the |
84 | // interface. |
85 | for (std::string &dialectName : filterDialects) { |
86 | Dialect *dialect = context->getLoadedDialect(dialectName); |
87 | if (!dialect) |
88 | return emitError(UnknownLoc::get(context)) |
89 | << "dialect not loaded: " << dialectName << "\n" ; |
90 | auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); |
91 | if (!iface) |
92 | return emitError(UnknownLoc::get(context)) |
93 | << "dialect does not implement ConvertToLLVMPatternInterface: " |
94 | << dialectName << "\n" ; |
95 | iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter, |
96 | tempPatterns); |
97 | } |
98 | } else { |
99 | // Normal mode: Populate all patterns from all dialects that implement the |
100 | // interface. |
101 | for (Dialect *dialect : context->getLoadedDialects()) { |
102 | // First time we encounter this dialect: if it implements the interface, |
103 | // let's populate patterns ! |
104 | auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); |
105 | if (!iface) |
106 | continue; |
107 | iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter, |
108 | tempPatterns); |
109 | } |
110 | } |
111 | |
112 | this->patterns = |
113 | std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns)); |
114 | this->target = target; |
115 | this->typeConverter = typeConverter; |
116 | return success(); |
117 | } |
118 | |
119 | void runOnOperation() final { |
120 | if (failed(applyPartialConversion(getOperation(), *target, *patterns))) |
121 | signalPassFailure(); |
122 | } |
123 | }; |
124 | |
125 | } // namespace |
126 | |
127 | void mlir::registerConvertToLLVMDependentDialectLoading( |
128 | DialectRegistry ®istry) { |
129 | registry.addExtensions<LoadDependentDialectExtension>(); |
130 | } |
131 | |
132 | std::unique_ptr<Pass> mlir::createConvertToLLVMPass() { |
133 | return std::make_unique<ConvertToLLVMPass>(); |
134 | } |
135 | |