1//===- ConvertToEmitCPass.cpp - Conversion to EmitC pass --------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/ConvertToEmitC/ConvertToEmitCPass.h"
10
11#include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h"
12#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
13#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14#include "mlir/Dialect/EmitC/IR/EmitC.h"
15#include "mlir/Pass/Pass.h"
16#include "mlir/Transforms/DialectConversion.h"
17#include "llvm/Support/Debug.h"
18
19#include <memory>
20
21#define DEBUG_TYPE "convert-to-emitc"
22
23namespace mlir {
24#define GEN_PASS_DEF_CONVERTTOEMITC
25#include "mlir/Conversion/Passes.h.inc"
26} // namespace mlir
27
28using namespace mlir;
29
30namespace {
31/// Base class for creating the internal implementation of `convert-to-emitc`
32/// passes.
33class ConvertToEmitCPassInterface {
34public:
35 ConvertToEmitCPassInterface(MLIRContext *context,
36 ArrayRef<std::string> filterDialects);
37 virtual ~ConvertToEmitCPassInterface() = default;
38
39 /// Get the dependent dialects used by `convert-to-emitc`.
40 static void getDependentDialects(DialectRegistry &registry);
41
42 /// Initialize the internal state of the `convert-to-emitc` pass
43 /// implementation. This method is invoked by `ConvertToEmitC::initialize`.
44 /// This method returns whether the initialization process failed.
45 virtual LogicalResult initialize() = 0;
46
47 /// Transform `op` to the EmitC dialect with the conversions available in the
48 /// pass. The analysis manager can be used to query analyzes like
49 /// `DataLayoutAnalysis` to further configure the conversion process. This
50 /// method is invoked by `ConvertToEmitC::runOnOperation`. This method returns
51 /// whether the transformation process failed.
52 virtual LogicalResult transform(Operation *op,
53 AnalysisManager manager) const = 0;
54
55protected:
56 /// Visit the `ConvertToEmitCPatternInterface` dialect interfaces and call
57 /// `visitor` with each of the interfaces. If `filterDialects` is non-empty,
58 /// then `visitor` is invoked only with the dialects in the `filterDialects`
59 /// list.
60 LogicalResult visitInterfaces(
61 llvm::function_ref<void(ConvertToEmitCPatternInterface *)> visitor);
62 MLIRContext *context;
63 /// List of dialects names to use as filters.
64 ArrayRef<std::string> filterDialects;
65};
66
67/// This DialectExtension can be attached to the context, which will invoke the
68/// `apply()` method for every loaded dialect. If a dialect implements the
69/// `ConvertToEmitCPatternInterface` interface, we load dependent dialects
70/// through the interface. This extension is loaded in the context before
71/// starting a pass pipeline that involves dialect conversion to the EmitC
72/// dialect.
73class LoadDependentDialectExtension : public DialectExtensionBase {
74public:
75 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension)
76
77 LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
78
79 void apply(MLIRContext *context,
80 MutableArrayRef<Dialect *> dialects) const final {
81 LLVM_DEBUG(llvm::dbgs() << "Convert to EmitC extension load\n");
82 for (Dialect *dialect : dialects) {
83 auto *iface = dyn_cast<ConvertToEmitCPatternInterface>(dialect);
84 if (!iface)
85 continue;
86 LLVM_DEBUG(llvm::dbgs() << "Convert to EmitC found dialect interface for "
87 << dialect->getNamespace() << "\n");
88 }
89 }
90
91 /// Return a copy of this extension.
92 std::unique_ptr<DialectExtensionBase> clone() const final {
93 return std::make_unique<LoadDependentDialectExtension>(*this);
94 }
95};
96
97//===----------------------------------------------------------------------===//
98// StaticConvertToEmitC
99//===----------------------------------------------------------------------===//
100
101/// Static implementation of the `convert-to-emitc` pass. This version only
102/// looks at dialect interfaces to configure the conversion process.
103struct StaticConvertToEmitC : public ConvertToEmitCPassInterface {
104 /// Pattern set with conversions to the EmitC dialect.
105 std::shared_ptr<const FrozenRewritePatternSet> patterns;
106 /// The conversion target.
107 std::shared_ptr<const ConversionTarget> target;
108 /// The type converter.
109 std::shared_ptr<const TypeConverter> typeConverter;
110 using ConvertToEmitCPassInterface::ConvertToEmitCPassInterface;
111
112 /// Configure the conversion to EmitC at pass initialization.
113 LogicalResult initialize() final {
114 auto target = std::make_shared<ConversionTarget>(*context);
115 auto typeConverter = std::make_shared<TypeConverter>();
116
117 // Add fallback identity converison.
118 typeConverter->addConversion([](Type type) -> std::optional<Type> {
119 if (emitc::isSupportedEmitCType(type))
120 return type;
121 return std::nullopt;
122 });
123
124 RewritePatternSet tempPatterns(context);
125 target->addLegalDialect<emitc::EmitCDialect>();
126 // Populate the patterns with the dialect interface.
127 if (failed(visitInterfaces([&](ConvertToEmitCPatternInterface *iface) {
128 iface->populateConvertToEmitCConversionPatterns(
129 target&: *target, typeConverter&: *typeConverter, patterns&: tempPatterns);
130 })))
131 return failure();
132 this->patterns =
133 std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
134 this->target = target;
135 this->typeConverter = typeConverter;
136 return success();
137 }
138
139 /// Apply the conversion driver.
140 LogicalResult transform(Operation *op, AnalysisManager manager) const final {
141 if (failed(applyPartialConversion(op, *target, *patterns)))
142 return failure();
143 return success();
144 }
145};
146
147//===----------------------------------------------------------------------===//
148// ConvertToEmitC
149//===----------------------------------------------------------------------===//
150
151/// This is a generic pass to convert to the EmitC dialect. It uses the
152/// `ConvertToEmitCPatternInterface` dialect interface to delegate the injection
153/// of conversion patterns to dialects.
154class ConvertToEmitC : public impl::ConvertToEmitCBase<ConvertToEmitC> {
155 std::shared_ptr<const ConvertToEmitCPassInterface> impl;
156
157public:
158 using impl::ConvertToEmitCBase<ConvertToEmitC>::ConvertToEmitCBase;
159 void getDependentDialects(DialectRegistry &registry) const final {
160 ConvertToEmitCPassInterface::getDependentDialects(registry);
161 }
162
163 LogicalResult initialize(MLIRContext *context) final {
164 std::shared_ptr<ConvertToEmitCPassInterface> impl;
165 impl = std::make_shared<StaticConvertToEmitC>(context, filterDialects);
166 if (failed(impl->initialize()))
167 return failure();
168 this->impl = impl;
169 return success();
170 }
171
172 void runOnOperation() final {
173 if (failed(impl->transform(getOperation(), getAnalysisManager())))
174 return signalPassFailure();
175 }
176};
177
178} // namespace
179
180//===----------------------------------------------------------------------===//
181// ConvertToEmitCPassInterface
182//===----------------------------------------------------------------------===//
183
184ConvertToEmitCPassInterface::ConvertToEmitCPassInterface(
185 MLIRContext *context, ArrayRef<std::string> filterDialects)
186 : context(context), filterDialects(filterDialects) {}
187
188void ConvertToEmitCPassInterface::getDependentDialects(
189 DialectRegistry &registry) {
190 registry.insert<emitc::EmitCDialect>();
191 registry.addExtensions<LoadDependentDialectExtension>();
192}
193
194LogicalResult ConvertToEmitCPassInterface::visitInterfaces(
195 llvm::function_ref<void(ConvertToEmitCPatternInterface *)> visitor) {
196 if (!filterDialects.empty()) {
197 // Test mode: Populate only patterns from the specified dialects. Produce
198 // an error if the dialect is not loaded or does not implement the
199 // interface.
200 for (StringRef dialectName : filterDialects) {
201 Dialect *dialect = context->getLoadedDialect(name: dialectName);
202 if (!dialect)
203 return emitError(UnknownLoc::get(context))
204 << "dialect not loaded: " << dialectName << "\n";
205 auto *iface = dyn_cast<ConvertToEmitCPatternInterface>(Val: dialect);
206 if (!iface)
207 return emitError(UnknownLoc::get(context))
208 << "dialect does not implement ConvertToEmitCPatternInterface: "
209 << dialectName << "\n";
210 visitor(iface);
211 }
212 } else {
213 // Normal mode: Populate all patterns from all dialects that implement the
214 // interface.
215 for (Dialect *dialect : context->getLoadedDialects()) {
216 auto *iface = dyn_cast<ConvertToEmitCPatternInterface>(Val: dialect);
217 if (!iface)
218 continue;
219 visitor(iface);
220 }
221 }
222 return success();
223}
224

source code of mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp