1//===- ConvertToLLVMPass.cpp - MLIR LLVM Conversion -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis/DataLayoutAnalysis.h"
10#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
11#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
12#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14#include "mlir/IR/PatternMatch.h"
15#include "mlir/Rewrite/FrozenRewritePatternSet.h"
16#include "mlir/Transforms/DialectConversion.h"
17#include <memory>
18
19#define DEBUG_TYPE "convert-to-llvm"
20
21namespace mlir {
22#define GEN_PASS_DEF_CONVERTTOLLVMPASS
23#include "mlir/Conversion/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27
28namespace {
29/// Base class for creating the internal implementation of `convert-to-llvm`
30/// passes.
31class ConvertToLLVMPassInterface {
32public:
33 ConvertToLLVMPassInterface(MLIRContext *context,
34 ArrayRef<std::string> filterDialects);
35 virtual ~ConvertToLLVMPassInterface() = default;
36
37 /// Get the dependent dialects used by `convert-to-llvm`.
38 static void getDependentDialects(DialectRegistry &registry);
39
40 /// Initialize the internal state of the `convert-to-llvm` pass
41 /// implementation. This method is invoked by `ConvertToLLVMPass::initialize`.
42 /// This method returns whether the initialization process failed.
43 virtual LogicalResult initialize() = 0;
44
45 /// Transform `op` to LLVM with the conversions available in the pass. The
46 /// analysis manager can be used to query analyzes like `DataLayoutAnalysis`
47 /// to further configure the conversion process. This method is invoked by
48 /// `ConvertToLLVMPass::runOnOperation`. This method returns whether the
49 /// transformation process failed.
50 virtual LogicalResult transform(Operation *op,
51 AnalysisManager manager) const = 0;
52
53protected:
54 /// Visit the `ConvertToLLVMPatternInterface` dialect interfaces and call
55 /// `visitor` with each of the interfaces. If `filterDialects` is non-empty,
56 /// then `visitor` is invoked only with the dialects in the `filterDialects`
57 /// list.
58 LogicalResult visitInterfaces(
59 llvm::function_ref<void(ConvertToLLVMPatternInterface *)> visitor);
60 MLIRContext *context;
61 /// List of dialects names to use as filters.
62 ArrayRef<std::string> filterDialects;
63};
64
65/// This DialectExtension can be attached to the context, which will invoke the
66/// `apply()` method for every loaded dialect. If a dialect implements the
67/// `ConvertToLLVMPatternInterface` interface, we load dependent dialects
68/// through the interface. This extension is loaded in the context before
69/// starting a pass pipeline that involves dialect conversion to LLVM.
70class LoadDependentDialectExtension : public DialectExtensionBase {
71public:
72 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension)
73
74 LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
75
76 void apply(MLIRContext *context,
77 MutableArrayRef<Dialect *> dialects) const final {
78 LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM extension load\n");
79 for (Dialect *dialect : dialects) {
80 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(Val: dialect);
81 if (!iface)
82 continue;
83 LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM found dialect interface for "
84 << dialect->getNamespace() << "\n");
85 iface->loadDependentDialects(context);
86 }
87 }
88
89 /// Return a copy of this extension.
90 std::unique_ptr<DialectExtensionBase> clone() const final {
91 return std::make_unique<LoadDependentDialectExtension>(args: *this);
92 }
93};
94
95//===----------------------------------------------------------------------===//
96// StaticConvertToLLVM
97//===----------------------------------------------------------------------===//
98
99/// Static implementation of the `convert-to-llvm` pass. This version only looks
100/// at dialect interfaces to configure the conversion process.
101struct StaticConvertToLLVM : public ConvertToLLVMPassInterface {
102 /// Pattern set with conversions to LLVM.
103 std::shared_ptr<const FrozenRewritePatternSet> patterns;
104 /// The conversion target.
105 std::shared_ptr<const ConversionTarget> target;
106 /// The LLVM type converter.
107 std::shared_ptr<const LLVMTypeConverter> typeConverter;
108 using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
109
110 /// Configure the conversion to LLVM at pass initialization.
111 LogicalResult initialize() final {
112 auto target = std::make_shared<ConversionTarget>(args&: *context);
113 auto typeConverter = std::make_shared<LLVMTypeConverter>(args&: context);
114 RewritePatternSet tempPatterns(context);
115 target->addLegalDialect<LLVM::LLVMDialect>();
116 // Populate the patterns with the dialect interface.
117 if (failed(Result: visitInterfaces(visitor: [&](ConvertToLLVMPatternInterface *iface) {
118 iface->populateConvertToLLVMConversionPatterns(
119 target&: *target, typeConverter&: *typeConverter, patterns&: tempPatterns);
120 })))
121 return failure();
122 this->patterns =
123 std::make_unique<FrozenRewritePatternSet>(args: std::move(tempPatterns));
124 this->target = target;
125 this->typeConverter = typeConverter;
126 return success();
127 }
128
129 /// Apply the conversion driver.
130 LogicalResult transform(Operation *op, AnalysisManager manager) const final {
131 if (failed(Result: applyPartialConversion(op, target: *target, patterns: *patterns)))
132 return failure();
133 return success();
134 }
135};
136
137//===----------------------------------------------------------------------===//
138// DynamicConvertToLLVM
139//===----------------------------------------------------------------------===//
140
141/// Dynamic implementation of the `convert-to-llvm` pass. This version inspects
142/// the IR to configure the conversion to LLVM.
143struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface {
144 /// A list of all the `ConvertToLLVMPatternInterface` dialect interfaces used
145 /// to partially configure the conversion process.
146 std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
147 interfaces;
148 using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
149
150 /// Collect the dialect interfaces used to configure the conversion process.
151 LogicalResult initialize() final {
152 auto interfaces =
153 std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
154 // Collect the interfaces.
155 if (failed(Result: visitInterfaces(visitor: [&](ConvertToLLVMPatternInterface *iface) {
156 interfaces->push_back(Elt: iface);
157 })))
158 return failure();
159 this->interfaces = interfaces;
160 return success();
161 }
162
163 /// Configure the conversion process and apply the conversion driver.
164 LogicalResult transform(Operation *op, AnalysisManager manager) const final {
165 RewritePatternSet patterns(context);
166 ConversionTarget target(*context);
167 target.addLegalDialect<LLVM::LLVMDialect>();
168 // Get the data layout analysis.
169 const auto &dlAnalysis = manager.getAnalysis<DataLayoutAnalysis>();
170 LLVMTypeConverter typeConverter(context, &dlAnalysis);
171
172 // Configure the conversion with dialect level interfaces.
173 for (ConvertToLLVMPatternInterface *iface : *interfaces)
174 iface->populateConvertToLLVMConversionPatterns(target, typeConverter,
175 patterns);
176
177 // Configure the conversion attribute interfaces.
178 populateOpConvertToLLVMConversionPatterns(op, target, typeConverter,
179 patterns);
180
181 // Apply the conversion.
182 if (failed(Result: applyPartialConversion(op, target, patterns: std::move(patterns))))
183 return failure();
184 return success();
185 }
186};
187
188//===----------------------------------------------------------------------===//
189// ConvertToLLVMPass
190//===----------------------------------------------------------------------===//
191
192/// This is a generic pass to convert to LLVM, it uses the
193/// `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
194/// the injection of conversion patterns.
195class ConvertToLLVMPass
196 : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
197 std::shared_ptr<const ConvertToLLVMPassInterface> impl;
198
199public:
200 using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
201 void getDependentDialects(DialectRegistry &registry) const final {
202 ConvertToLLVMPassInterface::getDependentDialects(registry);
203 }
204
205 LogicalResult initialize(MLIRContext *context) final {
206 std::shared_ptr<ConvertToLLVMPassInterface> impl;
207 // Choose the pass implementation.
208 if (useDynamic)
209 impl = std::make_shared<DynamicConvertToLLVM>(args&: context, args&: filterDialects);
210 else
211 impl = std::make_shared<StaticConvertToLLVM>(args&: context, args&: filterDialects);
212 if (failed(Result: impl->initialize()))
213 return failure();
214 this->impl = impl;
215 return success();
216 }
217
218 void runOnOperation() final {
219 if (failed(Result: impl->transform(op: getOperation(), manager: getAnalysisManager())))
220 return signalPassFailure();
221 }
222};
223
224} // namespace
225
226//===----------------------------------------------------------------------===//
227// ConvertToLLVMPassInterface
228//===----------------------------------------------------------------------===//
229
230ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
231 MLIRContext *context, ArrayRef<std::string> filterDialects)
232 : context(context), filterDialects(filterDialects) {}
233
234void ConvertToLLVMPassInterface::getDependentDialects(
235 DialectRegistry &registry) {
236 registry.insert<LLVM::LLVMDialect>();
237 registry.addExtensions<LoadDependentDialectExtension>();
238}
239
240LogicalResult ConvertToLLVMPassInterface::visitInterfaces(
241 llvm::function_ref<void(ConvertToLLVMPatternInterface *)> visitor) {
242 if (!filterDialects.empty()) {
243 // Test mode: Populate only patterns from the specified dialects. Produce
244 // an error if the dialect is not loaded or does not implement the
245 // interface.
246 for (StringRef dialectName : filterDialects) {
247 Dialect *dialect = context->getLoadedDialect(name: dialectName);
248 if (!dialect)
249 return emitError(loc: UnknownLoc::get(context))
250 << "dialect not loaded: " << dialectName << "\n";
251 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(Val: dialect);
252 if (!iface)
253 return emitError(loc: UnknownLoc::get(context))
254 << "dialect does not implement ConvertToLLVMPatternInterface: "
255 << dialectName << "\n";
256 visitor(iface);
257 }
258 } else {
259 // Normal mode: Populate all patterns from all dialects that implement the
260 // interface.
261 for (Dialect *dialect : context->getLoadedDialects()) {
262 // First time we encounter this dialect: if it implements the interface,
263 // let's populate patterns !
264 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(Val: dialect);
265 if (!iface)
266 continue;
267 visitor(iface);
268 }
269 }
270 return success();
271}
272
273//===----------------------------------------------------------------------===//
274// API
275//===----------------------------------------------------------------------===//
276
277void mlir::registerConvertToLLVMDependentDialectLoading(
278 DialectRegistry &registry) {
279 registry.addExtensions<LoadDependentDialectExtension>();
280}
281

// Source: mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp