1//===- DialectRegistry.h - Dialect Registration and Extension ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file defines functionality for registering and extending dialects.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_DIALECTREGISTRY_H
14#define MLIR_IR_DIALECTREGISTRY_H
15
16#include "mlir/IR/MLIRContext.h"
17#include "llvm/ADT/ArrayRef.h"
18#include "llvm/ADT/SmallVector.h"
19#include "llvm/ADT/StringRef.h"
20
21#include <map>
22#include <tuple>
23
24namespace mlir {
25class Dialect;
26
27using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
28using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>;
29using DynamicDialectPopulationFunction =
30 std::function<void(MLIRContext *, DynamicDialect *)>;
31
32//===----------------------------------------------------------------------===//
33// DialectExtension
34//===----------------------------------------------------------------------===//
35
36/// This class represents an opaque dialect extension. It contains a set of
37/// required dialects and an application function. The required dialects control
38/// when the extension is applied, i.e. the extension is applied when all
39/// required dialects are loaded. The application function can be used to attach
40/// additional functionality to attributes, dialects, operations, types, etc.,
41/// and may also load additional necessary dialects.
42class DialectExtensionBase {
43public:
44 virtual ~DialectExtensionBase();
45
46 /// Return the dialects that our required by this extension to be loaded
47 /// before applying. If empty then the extension is invoked for every loaded
48 /// dialect indepently.
49 ArrayRef<StringRef> getRequiredDialects() const { return dialectNames; }
50
51 /// Apply this extension to the given context and the required dialects.
52 virtual void apply(MLIRContext *context,
53 MutableArrayRef<Dialect *> dialects) const = 0;
54
55 /// Return a copy of this extension.
56 virtual std::unique_ptr<DialectExtensionBase> clone() const = 0;
57
58protected:
59 /// Initialize the extension with a set of required dialects.
60 /// If the list is empty, the extension is invoked for every loaded dialect
61 /// independently.
62 DialectExtensionBase(ArrayRef<StringRef> dialectNames)
63 : dialectNames(dialectNames.begin(), dialectNames.end()) {}
64
65private:
66 /// The names of the dialects affected by this extension.
67 SmallVector<StringRef> dialectNames;
68};
69
70/// This class represents a dialect extension anchored on the given set of
71/// dialects. When all of the specified dialects have been loaded, the
72/// application function of this extension will be executed.
73template <typename DerivedT, typename... DialectsT>
74class DialectExtension : public DialectExtensionBase {
75public:
76 /// Applies this extension to the given context and set of required dialects.
77 virtual void apply(MLIRContext *context, DialectsT *...dialects) const = 0;
78
79 /// Return a copy of this extension.
80 std::unique_ptr<DialectExtensionBase> clone() const final {
81 return std::make_unique<DerivedT>(static_cast<const DerivedT &>(*this));
82 }
83
84protected:
85 DialectExtension()
86 : DialectExtensionBase(
87 ArrayRef<StringRef>({DialectsT::getDialectNamespace()...})) {}
88
89 /// Override the base apply method to allow providing the exact dialect types.
90 void apply(MLIRContext *context,
91 MutableArrayRef<Dialect *> dialects) const final {
92 unsigned dialectIdx = 0;
93 auto derivedDialects = std::tuple<DialectsT *...>{
94 static_cast<DialectsT *>(dialects[dialectIdx++])...};
95 std::apply([&](DialectsT *...dialect) { apply(context, dialect...); },
96 derivedDialects);
97 }
98};
99
100namespace dialect_extension_detail {
101
102/// Checks if the given interface, which is attempting to be used, is a
103/// promised interface of this dialect that has yet to be implemented. If so,
104/// emits a fatal error.
105void handleUseOfUndefinedPromisedInterface(Dialect &dialect,
106 TypeID interfaceRequestorID,
107 TypeID interfaceID,
108 StringRef interfaceName);
109
110/// Checks if the given interface, which is attempting to be attached, is a
111/// promised interface of this dialect that has yet to be implemented. If so,
112/// the promised interface is marked as resolved.
113void handleAdditionOfUndefinedPromisedInterface(Dialect &dialect,
114 TypeID interfaceRequestorID,
115 TypeID interfaceID);
116
117/// Checks if a promise has been made for the interface/requestor pair.
118bool hasPromisedInterface(Dialect &dialect, TypeID interfaceRequestorID,
119 TypeID interfaceID);
120
121/// Checks if a promise has been made for the interface/requestor pair.
122template <typename ConcreteT, typename InterfaceT>
123bool hasPromisedInterface(Dialect &dialect) {
124 return hasPromisedInterface(dialect, TypeID::get<ConcreteT>(),
125 InterfaceT::getInterfaceID());
126}
127
128} // namespace dialect_extension_detail
129
130//===----------------------------------------------------------------------===//
131// DialectRegistry
132//===----------------------------------------------------------------------===//
133
134/// The DialectRegistry maps a dialect namespace to a constructor for the
135/// matching dialect. This allows for decoupling the list of dialects
136/// "available" from the dialects loaded in the Context. The parser in
137/// particular will lazily load dialects in the Context as operations are
138/// encountered.
139class DialectRegistry {
140 using MapTy =
141 std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
142
143public:
144 explicit DialectRegistry();
145
146 template <typename ConcreteDialect>
147 void insert() {
148 insert(TypeID::get<ConcreteDialect>(),
149 ConcreteDialect::getDialectNamespace(),
150 static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
151 // Just allocate the dialect, the context
152 // takes ownership of it.
153 return ctx->getOrLoadDialect<ConcreteDialect>();
154 })));
155 }
156
157 template <typename ConcreteDialect, typename OtherDialect,
158 typename... MoreDialects>
159 void insert() {
160 insert<ConcreteDialect>();
161 insert<OtherDialect, MoreDialects...>();
162 }
163
164 /// Add a new dialect constructor to the registry. The constructor must be
165 /// calling MLIRContext::getOrLoadDialect in order for the context to take
166 /// ownership of the dialect and for delayed interface registration to happen.
167 void insert(TypeID typeID, StringRef name,
168 const DialectAllocatorFunction &ctor);
169
170 /// Add a new dynamic dialect constructor in the registry. The constructor
171 /// provides as argument the created dynamic dialect, and is expected to
172 /// register the dialect types, attributes, and ops, using the
173 /// methods defined in ExtensibleDialect such as registerDynamicOperation.
174 void insertDynamic(StringRef name,
175 const DynamicDialectPopulationFunction &ctor);
176
177 /// Return an allocation function for constructing the dialect identified
178 /// by its namespace, or nullptr if the namespace is not in this registry.
179 DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const;
180
181 // Register all dialects available in the current registry with the registry
182 // in the provided context.
183 void appendTo(DialectRegistry &destination) const {
184 for (const auto &nameAndRegistrationIt : registry)
185 destination.insert(typeID: nameAndRegistrationIt.second.first,
186 name: nameAndRegistrationIt.first,
187 ctor: nameAndRegistrationIt.second.second);
188 // Merge the extensions.
189 for (const auto &extension : extensions)
190 destination.extensions.push_back(x: extension->clone());
191 }
192
193 /// Return the names of dialects known to this registry.
194 auto getDialectNames() const {
195 return llvm::map_range(
196 C: registry,
197 F: [](const MapTy::value_type &item) -> StringRef { return item.first; });
198 }
199
200 /// Apply any held extensions that require the given dialect. Users are not
201 /// expected to call this directly.
202 void applyExtensions(Dialect *dialect) const;
203
204 /// Apply any applicable extensions to the given context. Users are not
205 /// expected to call this directly.
206 void applyExtensions(MLIRContext *ctx) const;
207
208 /// Add the given extension to the registry.
209 void addExtension(std::unique_ptr<DialectExtensionBase> extension) {
210 extensions.push_back(x: std::move(extension));
211 }
212
213 /// Add the given extensions to the registry.
214 template <typename... ExtensionsT>
215 void addExtensions() {
216 (addExtension(std::make_unique<ExtensionsT>()), ...);
217 }
218
219 /// Add an extension function that requires the given dialects.
220 /// Note: This bare functor overload is provided in addition to the
221 /// std::function variant to enable dialect type deduction, e.g.:
222 /// registry.addExtension(+[](MLIRContext *ctx, MyDialect *dialect) { ... })
223 ///
224 /// is equivalent to:
225 /// registry.addExtension<MyDialect>(
226 /// [](MLIRContext *ctx, MyDialect *dialect){ ... }
227 /// )
228 template <typename... DialectsT>
229 void addExtension(void (*extensionFn)(MLIRContext *, DialectsT *...)) {
230 addExtension<DialectsT...>(
231 std::function<void(MLIRContext *, DialectsT * ...)>(extensionFn));
232 }
233 template <typename... DialectsT>
234 void
235 addExtension(std::function<void(MLIRContext *, DialectsT *...)> extensionFn) {
236 using ExtensionFnT = std::function<void(MLIRContext *, DialectsT * ...)>;
237
238 struct Extension : public DialectExtension<Extension, DialectsT...> {
239 Extension(const Extension &) = default;
240 Extension(ExtensionFnT extensionFn)
241 : extensionFn(std::move(extensionFn)) {}
242 ~Extension() override = default;
243
244 void apply(MLIRContext *context, DialectsT *...dialects) const final {
245 extensionFn(context, dialects...);
246 }
247 ExtensionFnT extensionFn;
248 };
249 addExtension(std::make_unique<Extension>(std::move(extensionFn)));
250 }
251
252 /// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs'
253 /// contains all of the components of this registry.
254 bool isSubsetOf(const DialectRegistry &rhs) const;
255
256private:
257 MapTy registry;
258 std::vector<std::unique_ptr<DialectExtensionBase>> extensions;
259};
260
261} // namespace mlir
262
263#endif // MLIR_IR_DIALECTREGISTRY_H
264

source code of mlir/include/mlir/IR/DialectRegistry.h