1//===- Dialect.cpp - Dialect implementation -------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/Dialect.h"
10#include "mlir/IR/BuiltinDialect.h"
11#include "mlir/IR/Diagnostics.h"
12#include "mlir/IR/DialectImplementation.h"
13#include "mlir/IR/DialectInterface.h"
14#include "mlir/IR/DialectRegistry.h"
15#include "mlir/IR/ExtensibleDialect.h"
16#include "mlir/IR/MLIRContext.h"
17#include "mlir/IR/Operation.h"
18#include "mlir/Support/TypeID.h"
19#include "llvm/ADT/MapVector.h"
20#include "llvm/ADT/SmallVectorExtras.h"
21#include "llvm/ADT/Twine.h"
22#include "llvm/Support/Debug.h"
23#include "llvm/Support/Regex.h"
24#include <memory>
25
26#define DEBUG_TYPE "dialect"
27
28using namespace mlir;
29using namespace detail;
30
31//===----------------------------------------------------------------------===//
32// Dialect
33//===----------------------------------------------------------------------===//
34
35Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
36 : name(name), dialectID(id), context(context) {
37 assert(isValidNamespace(name) && "invalid dialect namespace");
38}
39
40Dialect::~Dialect() = default;
41
42/// Verify an attribute from this dialect on the argument at 'argIndex' for
43/// the region at 'regionIndex' on the given operation. Returns failure if
44/// the verification failed, success otherwise. This hook may optionally be
45/// invoked from any operation containing a region.
46LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
47 NamedAttribute) {
48 return success();
49}
50
51/// Verify an attribute from this dialect on the result at 'resultIndex' for
52/// the region at 'regionIndex' on the given operation. Returns failure if
53/// the verification failed, success otherwise. This hook may optionally be
54/// invoked from any operation containing a region.
55LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
56 unsigned, NamedAttribute) {
57 return success();
58}
59
60/// Parse an attribute registered to this dialect.
61Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
62 parser.emitError(loc: parser.getNameLoc())
63 << "dialect '" << getNamespace()
64 << "' provides no attribute parsing hook";
65 return Attribute();
66}
67
68/// Parse a type registered to this dialect.
69Type Dialect::parseType(DialectAsmParser &parser) const {
70 // If this dialect allows unknown types, then represent this with OpaqueType.
71 if (allowsUnknownTypes()) {
72 StringAttr ns = StringAttr::get(context: getContext(), bytes: getNamespace());
73 return OpaqueType::get(dialectNamespace: ns, typeData: parser.getFullSymbolSpec());
74 }
75
76 parser.emitError(loc: parser.getNameLoc())
77 << "dialect '" << getNamespace() << "' provides no type parsing hook";
78 return Type();
79}
80
81std::optional<Dialect::ParseOpHook>
82Dialect::getParseOperationHook(StringRef opName) const {
83 return std::nullopt;
84}
85
86llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
87Dialect::getOperationPrinter(Operation *op) const {
88 assert(op->getDialect() == this &&
89 "Dialect hook invoked on non-dialect owned operation");
90 return nullptr;
91}
92
93/// Utility function that returns if the given string is a valid dialect
94/// namespace
95bool Dialect::isValidNamespace(StringRef str) {
96 llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
97 return dialectNameRegex.match(String: str);
98}
99
100/// Register a set of dialect interfaces with this dialect instance.
101void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
102 // Handle the case where the models resolve a promised interface.
103 handleAdditionOfUndefinedPromisedInterface(interfaceRequestorID: getTypeID(), interfaceID: interface->getID());
104
105 auto it = registeredInterfaces.try_emplace(Key: interface->getID(),
106 Args: std::move(interface));
107 (void)it;
108 LLVM_DEBUG({
109 if (!it.second) {
110 llvm::dbgs() << "[" DEBUG_TYPE
111 "] repeated interface registration for dialect "
112 << getNamespace();
113 }
114 });
115}
116
117//===----------------------------------------------------------------------===//
118// Dialect Interface
119//===----------------------------------------------------------------------===//
120
121DialectInterface::~DialectInterface() = default;
122
123MLIRContext *DialectInterface::getContext() const {
124 return dialect->getContext();
125}
126
127DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
128 MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
129 for (auto *dialect : ctx->getLoadedDialects()) {
130#ifndef NDEBUG
131 dialect->handleUseOfUndefinedPromisedInterface(
132 dialect->getTypeID(), interfaceKind, interfaceName);
133#endif
134 if (auto *interface = dialect->getRegisteredInterface(interfaceID: interfaceKind)) {
135 interfaces.insert(V: interface);
136 orderedInterfaces.push_back(x: interface);
137 }
138 }
139}
140
141DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() = default;
142
143/// Get the interface for the dialect of given operation, or null if one
144/// is not registered.
145const DialectInterface *
146DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
147 return getInterfaceFor(dialect: op->getDialect());
148}
149
150//===----------------------------------------------------------------------===//
151// DialectExtension
152//===----------------------------------------------------------------------===//
153
154DialectExtensionBase::~DialectExtensionBase() = default;
155
156void dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
157 Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID,
158 StringRef interfaceName) {
159 dialect.handleUseOfUndefinedPromisedInterface(interfaceRequestorID,
160 interfaceID, interfaceName);
161}
162
163void dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
164 Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID) {
165 dialect.handleAdditionOfUndefinedPromisedInterface(interfaceRequestorID,
166 interfaceID);
167}
168
169bool dialect_extension_detail::hasPromisedInterface(Dialect &dialect,
170 TypeID interfaceRequestorID,
171 TypeID interfaceID) {
172 return dialect.hasPromisedInterface(interfaceRequestorID, interfaceID);
173}
174
175//===----------------------------------------------------------------------===//
176// DialectRegistry
177//===----------------------------------------------------------------------===//
178
179namespace {
180template <typename Fn>
181void applyExtensionsFn(
182 Fn &&applyExtension,
183 const llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>>
184 &extensions) {
185 // Note: Additional extensions may be added while applying an extension.
186 // The iterators will be invalidated if extensions are added so we'll keep
187 // a copy of the extensions for ourselves.
188
189 const auto extractExtension =
190 [](const auto &entry) -> DialectExtensionBase * {
191 return entry.second.get();
192 };
193
194 auto startIt = extensions.begin(), endIt = extensions.end();
195 size_t count = 0;
196 while (startIt != endIt) {
197 count += endIt - startIt;
198
199 // Grab the subset of extensions we'll apply in this iteration.
200 const auto subset =
201 llvm::map_to_vector(llvm::make_range(x: startIt, y: endIt), extractExtension);
202
203 for (const auto *ext : subset)
204 applyExtension(*ext);
205
206 // Book-keep for the next iteration.
207 startIt = extensions.begin() + count;
208 endIt = extensions.end();
209 }
210}
211} // namespace
212
213DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
214
215DialectAllocatorFunctionRef
216DialectRegistry::getDialectAllocator(StringRef name) const {
217 auto it = registry.find(x: name);
218 if (it == registry.end())
219 return nullptr;
220 return it->second.second;
221}
222
223void DialectRegistry::insert(TypeID typeID, StringRef name,
224 const DialectAllocatorFunction &ctor) {
225 auto inserted = registry.insert(
226 x: std::make_pair(x: std::string(name), y: std::make_pair(x&: typeID, y: ctor)));
227 if (!inserted.second && inserted.first->second.first != typeID) {
228 llvm::report_fatal_error(
229 reason: "Trying to register different dialects for the same namespace: " +
230 name);
231 }
232}
233
234void DialectRegistry::insertDynamic(
235 StringRef name, const DynamicDialectPopulationFunction &ctor) {
236 // This TypeID marks dynamic dialects. We cannot give a TypeID for the
237 // dialect yet, since the TypeID of a dynamic dialect is defined at its
238 // construction.
239 TypeID typeID = TypeID::get<void>();
240
241 // Create the dialect, and then call ctor, which allocates its components.
242 auto constructor = [nameStr = name.str(), ctor](MLIRContext *ctx) {
243 auto *dynDialect = ctx->getOrLoadDynamicDialect(
244 dialectNamespace: nameStr, ctor: [ctx, ctor](DynamicDialect *dialect) { ctor(ctx, dialect); });
245 assert(dynDialect && "Dynamic dialect creation unexpectedly failed");
246 return dynDialect;
247 };
248
249 insert(typeID, name, ctor: constructor);
250}
251
252void DialectRegistry::applyExtensions(Dialect *dialect) const {
253 MLIRContext *ctx = dialect->getContext();
254 StringRef dialectName = dialect->getNamespace();
255
256 // Functor used to try to apply the given extension.
257 auto applyExtension = [&](const DialectExtensionBase &extension) {
258 ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
259 // An empty set is equivalent to always invoke.
260 if (dialectNames.empty()) {
261 extension.apply(context: ctx, dialects: dialect);
262 return;
263 }
264
265 // Handle the simple case of a single dialect name. In this case, the
266 // required dialect should be the current dialect.
267 if (dialectNames.size() == 1) {
268 if (dialectNames.front() == dialectName)
269 extension.apply(context: ctx, dialects: dialect);
270 return;
271 }
272
273 // Otherwise, check to see if this extension requires this dialect.
274 const StringRef *nameIt = llvm::find(Range&: dialectNames, Val: dialectName);
275 if (nameIt == dialectNames.end())
276 return;
277
278 // If it does, ensure that all of the other required dialects have been
279 // loaded.
280 SmallVector<Dialect *> requiredDialects;
281 requiredDialects.reserve(N: dialectNames.size());
282 for (auto it = dialectNames.begin(), e = dialectNames.end(); it != e;
283 ++it) {
284 // The current dialect is known to be loaded.
285 if (it == nameIt) {
286 requiredDialects.push_back(Elt: dialect);
287 continue;
288 }
289 // Otherwise, check if it is loaded.
290 Dialect *loadedDialect = ctx->getLoadedDialect(name: *it);
291 if (!loadedDialect)
292 return;
293 requiredDialects.push_back(Elt: loadedDialect);
294 }
295 extension.apply(context: ctx, dialects: requiredDialects);
296 };
297
298 applyExtensionsFn(applyExtension, extensions);
299}
300
301void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
302 // Functor used to try to apply the given extension.
303 auto applyExtension = [&](const DialectExtensionBase &extension) {
304 ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
305 if (dialectNames.empty()) {
306 auto loadedDialects = ctx->getLoadedDialects();
307 extension.apply(context: ctx, dialects: loadedDialects);
308 return;
309 }
310
311 // Check to see if all of the dialects for this extension are loaded.
312 SmallVector<Dialect *> requiredDialects;
313 requiredDialects.reserve(N: dialectNames.size());
314 for (StringRef dialectName : dialectNames) {
315 Dialect *loadedDialect = ctx->getLoadedDialect(name: dialectName);
316 if (!loadedDialect)
317 return;
318 requiredDialects.push_back(Elt: loadedDialect);
319 }
320 extension.apply(context: ctx, dialects: requiredDialects);
321 };
322
323 applyExtensionsFn(applyExtension, extensions);
324}
325
326bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
327 // Check that all extension keys are present in 'rhs'.
328 const auto hasExtension = [&](const auto &key) {
329 return rhs.extensions.contains(Key: key);
330 };
331 if (!llvm::all_of(Range: make_first_range(c: extensions), P: hasExtension))
332 return false;
333
334 // Check that the current dialects fully overlap with the dialects in 'rhs'.
335 return llvm::all_of(
336 Range: registry, P: [&](const auto &it) { return rhs.registry.count(it.first); });
337}
338

// Source code of mlir/lib/IR/Dialect.cpp