//===- ConvertToLLVMPass.cpp - MLIR LLVM Conversion -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
10#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
11#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14#include "mlir/IR/PatternMatch.h"
15#include "mlir/Pass/Pass.h"
16#include "mlir/Rewrite/FrozenRewritePatternSet.h"
17#include "mlir/Transforms/DialectConversion.h"
18#include <memory>
19
20#define DEBUG_TYPE "convert-to-llvm"
21
22namespace mlir {
23#define GEN_PASS_DEF_CONVERTTOLLVMPASS
24#include "mlir/Conversion/Passes.h.inc"
25} // namespace mlir
26
27using namespace mlir;
28
29namespace {
30
31/// This DialectExtension can be attached to the context, which will invoke the
32/// `apply()` method for every loaded dialect. If a dialect implements the
33/// `ConvertToLLVMPatternInterface` interface, we load dependent dialects
34/// through the interface. This extension is loaded in the context before
35/// starting a pass pipeline that involves dialect conversion to LLVM.
36class LoadDependentDialectExtension : public DialectExtensionBase {
37public:
38 LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
39
40 void apply(MLIRContext *context,
41 MutableArrayRef<Dialect *> dialects) const final {
42 LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM extension load\n");
43 for (Dialect *dialect : dialects) {
44 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
45 if (!iface)
46 continue;
47 LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM found dialect interface for "
48 << dialect->getNamespace() << "\n");
49 iface->loadDependentDialects(context);
50 }
51 }
52
53 /// Return a copy of this extension.
54 std::unique_ptr<DialectExtensionBase> clone() const final {
55 return std::make_unique<LoadDependentDialectExtension>(*this);
56 }
57};
58
59/// This is a generic pass to convert to LLVM, it uses the
60/// `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
61/// the injection of conversion patterns.
62class ConvertToLLVMPass
63 : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
64 std::shared_ptr<const FrozenRewritePatternSet> patterns;
65 std::shared_ptr<const ConversionTarget> target;
66 std::shared_ptr<const LLVMTypeConverter> typeConverter;
67
68public:
69 using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
70 void getDependentDialects(DialectRegistry &registry) const final {
71 registry.insert<LLVM::LLVMDialect>();
72 registry.addExtensions<LoadDependentDialectExtension>();
73 }
74
75 LogicalResult initialize(MLIRContext *context) final {
76 RewritePatternSet tempPatterns(context);
77 auto target = std::make_shared<ConversionTarget>(*context);
78 target->addLegalDialect<LLVM::LLVMDialect>();
79 auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
80
81 if (!filterDialects.empty()) {
82 // Test mode: Populate only patterns from the specified dialects. Produce
83 // an error if the dialect is not loaded or does not implement the
84 // interface.
85 for (std::string &dialectName : filterDialects) {
86 Dialect *dialect = context->getLoadedDialect(dialectName);
87 if (!dialect)
88 return emitError(UnknownLoc::get(context))
89 << "dialect not loaded: " << dialectName << "\n";
90 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
91 if (!iface)
92 return emitError(UnknownLoc::get(context))
93 << "dialect does not implement ConvertToLLVMPatternInterface: "
94 << dialectName << "\n";
95 iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
96 tempPatterns);
97 }
98 } else {
99 // Normal mode: Populate all patterns from all dialects that implement the
100 // interface.
101 for (Dialect *dialect : context->getLoadedDialects()) {
102 // First time we encounter this dialect: if it implements the interface,
103 // let's populate patterns !
104 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
105 if (!iface)
106 continue;
107 iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
108 tempPatterns);
109 }
110 }
111
112 this->patterns =
113 std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
114 this->target = target;
115 this->typeConverter = typeConverter;
116 return success();
117 }
118
119 void runOnOperation() final {
120 if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
121 signalPassFailure();
122 }
123};
124
125} // namespace
126
127void mlir::registerConvertToLLVMDependentDialectLoading(
128 DialectRegistry &registry) {
129 registry.addExtensions<LoadDependentDialectExtension>();
130}
131
132std::unique_ptr<Pass> mlir::createConvertToLLVMPass() {
133 return std::make_unique<ConvertToLLVMPass>();
134}
135

// Source code of mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp