1//===- ConvertToLLVMPass.cpp - MLIR LLVM Conversion -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Analysis/DataLayoutAnalysis.h"
10#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
11#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
12#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
15#include "mlir/IR/PatternMatch.h"
16#include "mlir/Pass/Pass.h"
17#include "mlir/Rewrite/FrozenRewritePatternSet.h"
18#include "mlir/Transforms/DialectConversion.h"
19#include <memory>
20
21#define DEBUG_TYPE "convert-to-llvm"
22
23namespace mlir {
24#define GEN_PASS_DEF_CONVERTTOLLVMPASS
25#include "mlir/Conversion/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29
30namespace {
31/// Base class for creating the internal implementation of `convert-to-llvm`
32/// passes.
33class ConvertToLLVMPassInterface {
34public:
35 ConvertToLLVMPassInterface(MLIRContext *context,
36 ArrayRef<std::string> filterDialects);
37 virtual ~ConvertToLLVMPassInterface() = default;
38
39 /// Get the dependent dialects used by `convert-to-llvm`.
40 static void getDependentDialects(DialectRegistry &registry);
41
42 /// Initialize the internal state of the `convert-to-llvm` pass
43 /// implementation. This method is invoked by `ConvertToLLVMPass::initialize`.
44 /// This method returns whether the initialization process failed.
45 virtual LogicalResult initialize() = 0;
46
47 /// Transform `op` to LLVM with the conversions available in the pass. The
48 /// analysis manager can be used to query analyzes like `DataLayoutAnalysis`
49 /// to further configure the conversion process. This method is invoked by
50 /// `ConvertToLLVMPass::runOnOperation`. This method returns whether the
51 /// transformation process failed.
52 virtual LogicalResult transform(Operation *op,
53 AnalysisManager manager) const = 0;
54
55protected:
56 /// Visit the `ConvertToLLVMPatternInterface` dialect interfaces and call
57 /// `visitor` with each of the interfaces. If `filterDialects` is non-empty,
58 /// then `visitor` is invoked only with the dialects in the `filterDialects`
59 /// list.
60 LogicalResult visitInterfaces(
61 llvm::function_ref<void(ConvertToLLVMPatternInterface *)> visitor);
62 MLIRContext *context;
63 /// List of dialects names to use as filters.
64 ArrayRef<std::string> filterDialects;
65};
66
67/// This DialectExtension can be attached to the context, which will invoke the
68/// `apply()` method for every loaded dialect. If a dialect implements the
69/// `ConvertToLLVMPatternInterface` interface, we load dependent dialects
70/// through the interface. This extension is loaded in the context before
71/// starting a pass pipeline that involves dialect conversion to LLVM.
72class LoadDependentDialectExtension : public DialectExtensionBase {
73public:
74 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension)
75
76 LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
77
78 void apply(MLIRContext *context,
79 MutableArrayRef<Dialect *> dialects) const final {
80 LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM extension load\n");
81 for (Dialect *dialect : dialects) {
82 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
83 if (!iface)
84 continue;
85 LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM found dialect interface for "
86 << dialect->getNamespace() << "\n");
87 iface->loadDependentDialects(context);
88 }
89 }
90
91 /// Return a copy of this extension.
92 std::unique_ptr<DialectExtensionBase> clone() const final {
93 return std::make_unique<LoadDependentDialectExtension>(*this);
94 }
95};
96
97//===----------------------------------------------------------------------===//
98// StaticConvertToLLVM
99//===----------------------------------------------------------------------===//
100
101/// Static implementation of the `convert-to-llvm` pass. This version only looks
102/// at dialect interfaces to configure the conversion process.
103struct StaticConvertToLLVM : public ConvertToLLVMPassInterface {
104 /// Pattern set with conversions to LLVM.
105 std::shared_ptr<const FrozenRewritePatternSet> patterns;
106 /// The conversion target.
107 std::shared_ptr<const ConversionTarget> target;
108 /// The LLVM type converter.
109 std::shared_ptr<const LLVMTypeConverter> typeConverter;
110 using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
111
112 /// Configure the conversion to LLVM at pass initialization.
113 LogicalResult initialize() final {
114 auto target = std::make_shared<ConversionTarget>(*context);
115 auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
116 RewritePatternSet tempPatterns(context);
117 target->addLegalDialect<LLVM::LLVMDialect>();
118 // Populate the patterns with the dialect interface.
119 if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
120 iface->populateConvertToLLVMConversionPatterns(
121 target&: *target, typeConverter&: *typeConverter, patterns&: tempPatterns);
122 })))
123 return failure();
124 this->patterns =
125 std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
126 this->target = target;
127 this->typeConverter = typeConverter;
128 return success();
129 }
130
131 /// Apply the conversion driver.
132 LogicalResult transform(Operation *op, AnalysisManager manager) const final {
133 if (failed(applyPartialConversion(op, *target, *patterns)))
134 return failure();
135 return success();
136 }
137};
138
139//===----------------------------------------------------------------------===//
140// DynamicConvertToLLVM
141//===----------------------------------------------------------------------===//
142
143/// Dynamic implementation of the `convert-to-llvm` pass. This version inspects
144/// the IR to configure the conversion to LLVM.
145struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface {
146 /// A list of all the `ConvertToLLVMPatternInterface` dialect interfaces used
147 /// to partially configure the conversion process.
148 std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
149 interfaces;
150 using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
151
152 /// Collect the dialect interfaces used to configure the conversion process.
153 LogicalResult initialize() final {
154 auto interfaces =
155 std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
156 // Collect the interfaces.
157 if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
158 interfaces->push_back(iface);
159 })))
160 return failure();
161 this->interfaces = interfaces;
162 return success();
163 }
164
165 /// Configure the conversion process and apply the conversion driver.
166 LogicalResult transform(Operation *op, AnalysisManager manager) const final {
167 RewritePatternSet patterns(context);
168 ConversionTarget target(*context);
169 target.addLegalDialect<LLVM::LLVMDialect>();
170 // Get the data layout analysis.
171 const auto &dlAnalysis = manager.getAnalysis<DataLayoutAnalysis>();
172 LLVMTypeConverter typeConverter(context, &dlAnalysis);
173
174 // Configure the conversion with dialect level interfaces.
175 for (ConvertToLLVMPatternInterface *iface : *interfaces)
176 iface->populateConvertToLLVMConversionPatterns(target, typeConverter,
177 patterns);
178
179 // Configure the conversion attribute interfaces.
180 populateOpConvertToLLVMConversionPatterns(op, target, typeConverter,
181 patterns);
182
183 // Apply the conversion.
184 if (failed(applyPartialConversion(op, target, std::move(patterns))))
185 return failure();
186 return success();
187 }
188};
189
190//===----------------------------------------------------------------------===//
191// ConvertToLLVMPass
192//===----------------------------------------------------------------------===//
193
194/// This is a generic pass to convert to LLVM, it uses the
195/// `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
196/// the injection of conversion patterns.
197class ConvertToLLVMPass
198 : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
199 std::shared_ptr<const ConvertToLLVMPassInterface> impl;
200
201public:
202 using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
203 void getDependentDialects(DialectRegistry &registry) const final {
204 ConvertToLLVMPassInterface::getDependentDialects(registry);
205 }
206
207 LogicalResult initialize(MLIRContext *context) final {
208 std::shared_ptr<ConvertToLLVMPassInterface> impl;
209 // Choose the pass implementation.
210 if (useDynamic)
211 impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects);
212 else
213 impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects);
214 if (failed(impl->initialize()))
215 return failure();
216 this->impl = impl;
217 return success();
218 }
219
220 void runOnOperation() final {
221 if (failed(impl->transform(getOperation(), getAnalysisManager())))
222 return signalPassFailure();
223 }
224};
225
226} // namespace
227
228//===----------------------------------------------------------------------===//
229// ConvertToLLVMPassInterface
230//===----------------------------------------------------------------------===//
231
232ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
233 MLIRContext *context, ArrayRef<std::string> filterDialects)
234 : context(context), filterDialects(filterDialects) {}
235
236void ConvertToLLVMPassInterface::getDependentDialects(
237 DialectRegistry &registry) {
238 registry.insert<LLVM::LLVMDialect>();
239 registry.addExtensions<LoadDependentDialectExtension>();
240}
241
242LogicalResult ConvertToLLVMPassInterface::visitInterfaces(
243 llvm::function_ref<void(ConvertToLLVMPatternInterface *)> visitor) {
244 if (!filterDialects.empty()) {
245 // Test mode: Populate only patterns from the specified dialects. Produce
246 // an error if the dialect is not loaded or does not implement the
247 // interface.
248 for (StringRef dialectName : filterDialects) {
249 Dialect *dialect = context->getLoadedDialect(name: dialectName);
250 if (!dialect)
251 return emitError(UnknownLoc::get(context))
252 << "dialect not loaded: " << dialectName << "\n";
253 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(Val: dialect);
254 if (!iface)
255 return emitError(UnknownLoc::get(context))
256 << "dialect does not implement ConvertToLLVMPatternInterface: "
257 << dialectName << "\n";
258 visitor(iface);
259 }
260 } else {
261 // Normal mode: Populate all patterns from all dialects that implement the
262 // interface.
263 for (Dialect *dialect : context->getLoadedDialects()) {
264 // First time we encounter this dialect: if it implements the interface,
265 // let's populate patterns !
266 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(Val: dialect);
267 if (!iface)
268 continue;
269 visitor(iface);
270 }
271 }
272 return success();
273}
274
275//===----------------------------------------------------------------------===//
276// API
277//===----------------------------------------------------------------------===//
278
279void mlir::registerConvertToLLVMDependentDialectLoading(
280 DialectRegistry &registry) {
281 registry.addExtensions<LoadDependentDialectExtension>();
282}
283

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp