1 | //===- DialectRegistry.h - Dialect Registration and Extension ---*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This file defines functionality for registering and extending dialects.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_IR_DIALECTREGISTRY_H |
14 | #define MLIR_IR_DIALECTREGISTRY_H |
15 | |
16 | #include "mlir/IR/MLIRContext.h" |
17 | #include "llvm/ADT/ArrayRef.h" |
18 | #include "llvm/ADT/SmallVector.h" |
19 | #include "llvm/ADT/StringRef.h" |
20 | |
21 | #include <map> |
22 | #include <tuple> |
23 | |
24 | namespace mlir { |
25 | class Dialect; |
26 | |
27 | using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>; |
28 | using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>; |
29 | using DynamicDialectPopulationFunction = |
30 | std::function<void(MLIRContext *, DynamicDialect *)>; |
31 | |
32 | //===----------------------------------------------------------------------===// |
33 | // DialectExtension |
34 | //===----------------------------------------------------------------------===// |
35 | |
36 | /// This class represents an opaque dialect extension. It contains a set of |
37 | /// required dialects and an application function. The required dialects control |
38 | /// when the extension is applied, i.e. the extension is applied when all |
39 | /// required dialects are loaded. The application function can be used to attach |
40 | /// additional functionality to attributes, dialects, operations, types, etc., |
41 | /// and may also load additional necessary dialects. |
42 | class DialectExtensionBase { |
43 | public: |
44 | virtual ~DialectExtensionBase(); |
45 | |
46 | /// Return the dialects that our required by this extension to be loaded |
47 | /// before applying. If empty then the extension is invoked for every loaded |
48 | /// dialect indepently. |
49 | ArrayRef<StringRef> getRequiredDialects() const { return dialectNames; } |
50 | |
51 | /// Apply this extension to the given context and the required dialects. |
52 | virtual void apply(MLIRContext *context, |
53 | MutableArrayRef<Dialect *> dialects) const = 0; |
54 | |
55 | /// Return a copy of this extension. |
56 | virtual std::unique_ptr<DialectExtensionBase> clone() const = 0; |
57 | |
58 | protected: |
59 | /// Initialize the extension with a set of required dialects. |
60 | /// If the list is empty, the extension is invoked for every loaded dialect |
61 | /// independently. |
62 | DialectExtensionBase(ArrayRef<StringRef> dialectNames) |
63 | : dialectNames(dialectNames.begin(), dialectNames.end()) {} |
64 | |
65 | private: |
66 | /// The names of the dialects affected by this extension. |
67 | SmallVector<StringRef> dialectNames; |
68 | }; |
69 | |
70 | /// This class represents a dialect extension anchored on the given set of |
71 | /// dialects. When all of the specified dialects have been loaded, the |
72 | /// application function of this extension will be executed. |
73 | template <typename DerivedT, typename... DialectsT> |
74 | class DialectExtension : public DialectExtensionBase { |
75 | public: |
76 | /// Applies this extension to the given context and set of required dialects. |
77 | virtual void apply(MLIRContext *context, DialectsT *...dialects) const = 0; |
78 | |
79 | /// Return a copy of this extension. |
80 | std::unique_ptr<DialectExtensionBase> clone() const final { |
81 | return std::make_unique<DerivedT>(static_cast<const DerivedT &>(*this)); |
82 | } |
83 | |
84 | protected: |
85 | DialectExtension() |
86 | : DialectExtensionBase( |
87 | ArrayRef<StringRef>({DialectsT::getDialectNamespace()...})) {} |
88 | |
89 | /// Override the base apply method to allow providing the exact dialect types. |
90 | void apply(MLIRContext *context, |
91 | MutableArrayRef<Dialect *> dialects) const final { |
92 | unsigned dialectIdx = 0; |
93 | auto derivedDialects = std::tuple<DialectsT *...>{ |
94 | static_cast<DialectsT *>(dialects[dialectIdx++])...}; |
95 | std::apply([&](DialectsT *...dialect) { apply(context, dialect...); }, |
96 | derivedDialects); |
97 | } |
98 | }; |
99 | |
100 | namespace dialect_extension_detail { |
101 | |
102 | /// Checks if the given interface, which is attempting to be used, is a |
103 | /// promised interface of this dialect that has yet to be implemented. If so, |
104 | /// emits a fatal error. |
105 | void handleUseOfUndefinedPromisedInterface(Dialect &dialect, |
106 | TypeID interfaceRequestorID, |
107 | TypeID interfaceID, |
108 | StringRef interfaceName); |
109 | |
110 | /// Checks if the given interface, which is attempting to be attached, is a |
111 | /// promised interface of this dialect that has yet to be implemented. If so, |
112 | /// the promised interface is marked as resolved. |
113 | void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect, |
114 | TypeID interfaceRequestorID, |
115 | TypeID interfaceID); |
116 | |
117 | /// Checks if a promise has been made for the interface/requestor pair. |
118 | bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID, |
119 | TypeID interfaceID); |
120 | |
121 | /// Checks if a promise has been made for the interface/requestor pair. |
122 | template <typename ConcreteT, typename InterfaceT> |
123 | bool hasPromisedInterface(Dialect &dialect) { |
124 | return hasPromisedInterface(dialect, TypeID::get<ConcreteT>(), |
125 | InterfaceT::getInterfaceID()); |
126 | } |
127 | |
128 | } // namespace dialect_extension_detail |
129 | |
130 | //===----------------------------------------------------------------------===// |
131 | // DialectRegistry |
132 | //===----------------------------------------------------------------------===// |
133 | |
134 | /// The DialectRegistry maps a dialect namespace to a constructor for the |
135 | /// matching dialect. This allows for decoupling the list of dialects |
136 | /// "available" from the dialects loaded in the Context. The parser in |
137 | /// particular will lazily load dialects in the Context as operations are |
138 | /// encountered. |
139 | class DialectRegistry { |
140 | using MapTy = |
141 | std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>; |
142 | |
143 | public: |
144 | explicit DialectRegistry(); |
145 | |
146 | template <typename ConcreteDialect> |
147 | void insert() { |
148 | insert(TypeID::get<ConcreteDialect>(), |
149 | ConcreteDialect::getDialectNamespace(), |
150 | static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) { |
151 | // Just allocate the dialect, the context |
152 | // takes ownership of it. |
153 | return ctx->getOrLoadDialect<ConcreteDialect>(); |
154 | }))); |
155 | } |
156 | |
157 | template <typename ConcreteDialect, typename OtherDialect, |
158 | typename... MoreDialects> |
159 | void insert() { |
160 | insert<ConcreteDialect>(); |
161 | insert<OtherDialect, MoreDialects...>(); |
162 | } |
163 | |
164 | /// Add a new dialect constructor to the registry. The constructor must be |
165 | /// calling MLIRContext::getOrLoadDialect in order for the context to take |
166 | /// ownership of the dialect and for delayed interface registration to happen. |
167 | void insert(TypeID typeID, StringRef name, |
168 | const DialectAllocatorFunction &ctor); |
169 | |
170 | /// Add a new dynamic dialect constructor in the registry. The constructor |
171 | /// provides as argument the created dynamic dialect, and is expected to |
172 | /// register the dialect types, attributes, and ops, using the |
173 | /// methods defined in ExtensibleDialect such as registerDynamicOperation. |
174 | void insertDynamic(StringRef name, |
175 | const DynamicDialectPopulationFunction &ctor); |
176 | |
177 | /// Return an allocation function for constructing the dialect identified |
178 | /// by its namespace, or nullptr if the namespace is not in this registry. |
179 | DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const; |
180 | |
181 | // Register all dialects available in the current registry with the registry |
182 | // in the provided context. |
183 | void appendTo(DialectRegistry &destination) const { |
184 | for (const auto &nameAndRegistrationIt : registry) |
185 | destination.insert(typeID: nameAndRegistrationIt.second.first, |
186 | name: nameAndRegistrationIt.first, |
187 | ctor: nameAndRegistrationIt.second.second); |
188 | // Merge the extensions. |
189 | for (const auto &extension : extensions) |
190 | destination.extensions.push_back(x: extension->clone()); |
191 | } |
192 | |
193 | /// Return the names of dialects known to this registry. |
194 | auto getDialectNames() const { |
195 | return llvm::map_range( |
196 | C: registry, |
197 | F: [](const MapTy::value_type &item) -> StringRef { return item.first; }); |
198 | } |
199 | |
200 | /// Apply any held extensions that require the given dialect. Users are not |
201 | /// expected to call this directly. |
202 | void applyExtensions(Dialect *dialect) const; |
203 | |
204 | /// Apply any applicable extensions to the given context. Users are not |
205 | /// expected to call this directly. |
206 | void applyExtensions(MLIRContext *ctx) const; |
207 | |
208 | /// Add the given extension to the registry. |
209 | void addExtension(std::unique_ptr<DialectExtensionBase> extension) { |
210 | extensions.push_back(x: std::move(extension)); |
211 | } |
212 | |
213 | /// Add the given extensions to the registry. |
214 | template <typename... ExtensionsT> |
215 | void addExtensions() { |
216 | (addExtension(std::make_unique<ExtensionsT>()), ...); |
217 | } |
218 | |
219 | /// Add an extension function that requires the given dialects. |
220 | /// Note: This bare functor overload is provided in addition to the |
221 | /// std::function variant to enable dialect type deduction, e.g.: |
222 | /// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { ... }) |
223 | /// |
224 | /// is equivalent to: |
225 | /// registry.addExtension<MyDialect>( |
226 | /// [](MLIRContext *ctx, MyDialect *dialect){ ... } |
227 | /// ) |
228 | template <typename... DialectsT> |
229 | void addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) { |
230 | addExtension<DialectsT...>( |
231 | std::function<void(MLIRContext *, DialectsT * ...)>(extensionFn)); |
232 | } |
233 | template <typename... DialectsT> |
234 | void |
235 | addExtension(std::function<void(MLIRContext *, DialectsT *...)> extensionFn) { |
236 | using ExtensionFnT = std::function<void(MLIRContext *, DialectsT * ...)>; |
237 | |
238 | struct Extension : public DialectExtension<Extension, DialectsT...> { |
239 | Extension(const Extension &) = default; |
240 | Extension(ExtensionFnT extensionFn) |
241 | : extensionFn(std::move(extensionFn)) {} |
242 | ~Extension() override = default; |
243 | |
244 | void apply(MLIRContext *context, DialectsT *...dialects) const final { |
245 | extensionFn(context, dialects...); |
246 | } |
247 | ExtensionFnT extensionFn; |
248 | }; |
249 | addExtension(std::make_unique<Extension>(std::move(extensionFn))); |
250 | } |
251 | |
252 | /// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs' |
253 | /// contains all of the components of this registry. |
254 | bool isSubsetOf(const DialectRegistry &rhs) const; |
255 | |
256 | private: |
257 | MapTy registry; |
258 | std::vector<std::unique_ptr<DialectExtensionBase>> extensions; |
259 | }; |
260 | |
261 | } // namespace mlir |
262 | |
263 | #endif // MLIR_IR_DIALECTREGISTRY_H |
264 | |