1//===- Dialect.cpp - Dialect implementation -------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/Dialect.h"
10#include "mlir/IR/BuiltinDialect.h"
11#include "mlir/IR/Diagnostics.h"
12#include "mlir/IR/DialectImplementation.h"
13#include "mlir/IR/DialectInterface.h"
14#include "mlir/IR/ExtensibleDialect.h"
15#include "mlir/IR/MLIRContext.h"
16#include "mlir/IR/Operation.h"
17#include "llvm/ADT/MapVector.h"
18#include "llvm/ADT/Twine.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/ManagedStatic.h"
21#include "llvm/Support/Regex.h"
22
23#define DEBUG_TYPE "dialect"
24
25using namespace mlir;
26using namespace detail;
27
28//===----------------------------------------------------------------------===//
29// Dialect
30//===----------------------------------------------------------------------===//
31
32Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
33 : name(name), dialectID(id), context(context) {
34 assert(isValidNamespace(name) && "invalid dialect namespace");
35}
36
37Dialect::~Dialect() = default;
38
39/// Verify an attribute from this dialect on the argument at 'argIndex' for
40/// the region at 'regionIndex' on the given operation. Returns failure if
41/// the verification failed, success otherwise. This hook may optionally be
42/// invoked from any operation containing a region.
43LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
44 NamedAttribute) {
45 return success();
46}
47
48/// Verify an attribute from this dialect on the result at 'resultIndex' for
49/// the region at 'regionIndex' on the given operation. Returns failure if
50/// the verification failed, success otherwise. This hook may optionally be
51/// invoked from any operation containing a region.
52LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
53 unsigned, NamedAttribute) {
54 return success();
55}
56
57/// Parse an attribute registered to this dialect.
58Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
59 parser.emitError(loc: parser.getNameLoc())
60 << "dialect '" << getNamespace()
61 << "' provides no attribute parsing hook";
62 return Attribute();
63}
64
65/// Parse a type registered to this dialect.
66Type Dialect::parseType(DialectAsmParser &parser) const {
67 // If this dialect allows unknown types, then represent this with OpaqueType.
68 if (allowsUnknownTypes()) {
69 StringAttr ns = StringAttr::get(getContext(), getNamespace());
70 return OpaqueType::get(ns, parser.getFullSymbolSpec());
71 }
72
73 parser.emitError(loc: parser.getNameLoc())
74 << "dialect '" << getNamespace() << "' provides no type parsing hook";
75 return Type();
76}
77
78std::optional<Dialect::ParseOpHook>
79Dialect::getParseOperationHook(StringRef opName) const {
80 return std::nullopt;
81}
82
83llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
84Dialect::getOperationPrinter(Operation *op) const {
85 assert(op->getDialect() == this &&
86 "Dialect hook invoked on non-dialect owned operation");
87 return nullptr;
88}
89
90/// Utility function that returns if the given string is a valid dialect
91/// namespace
92bool Dialect::isValidNamespace(StringRef str) {
93 llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
94 return dialectNameRegex.match(String: str);
95}
96
97/// Register a set of dialect interfaces with this dialect instance.
98void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
99 // Handle the case where the models resolve a promised interface.
100 handleAdditionOfUndefinedPromisedInterface(interfaceRequestorID: getTypeID(), interfaceID: interface->getID());
101
102 auto it = registeredInterfaces.try_emplace(Key: interface->getID(),
103 Args: std::move(interface));
104 (void)it;
105 LLVM_DEBUG({
106 if (!it.second) {
107 llvm::dbgs() << "[" DEBUG_TYPE
108 "] repeated interface registration for dialect "
109 << getNamespace();
110 }
111 });
112}
113
114//===----------------------------------------------------------------------===//
115// Dialect Interface
116//===----------------------------------------------------------------------===//
117
118DialectInterface::~DialectInterface() = default;
119
120MLIRContext *DialectInterface::getContext() const {
121 return dialect->getContext();
122}
123
124DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
125 MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
126 for (auto *dialect : ctx->getLoadedDialects()) {
127#ifndef NDEBUG
128 dialect->handleUseOfUndefinedPromisedInterface(
129 interfaceRequestorID: dialect->getTypeID(), interfaceID: interfaceKind, interfaceName);
130#endif
131 if (auto *interface = dialect->getRegisteredInterface(interfaceID: interfaceKind)) {
132 interfaces.insert(V: interface);
133 orderedInterfaces.push_back(x: interface);
134 }
135 }
136}
137
138DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() = default;
139
140/// Get the interface for the dialect of given operation, or null if one
141/// is not registered.
142const DialectInterface *
143DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
144 return getInterfaceFor(dialect: op->getDialect());
145}
146
147//===----------------------------------------------------------------------===//
148// DialectExtension
149//===----------------------------------------------------------------------===//
150
151DialectExtensionBase::~DialectExtensionBase() = default;
152
153void dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
154 Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID,
155 StringRef interfaceName) {
156 dialect.handleUseOfUndefinedPromisedInterface(interfaceRequestorID,
157 interfaceID, interfaceName);
158}
159
160void dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
161 Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID) {
162 dialect.handleAdditionOfUndefinedPromisedInterface(interfaceRequestorID,
163 interfaceID);
164}
165
166bool dialect_extension_detail::hasPromisedInterface(Dialect &dialect,
167 TypeID interfaceRequestorID,
168 TypeID interfaceID) {
169 return dialect.hasPromisedInterface(interfaceRequestorID, interfaceID);
170}
171
172//===----------------------------------------------------------------------===//
173// DialectRegistry
174//===----------------------------------------------------------------------===//
175
176DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
177
178DialectAllocatorFunctionRef
179DialectRegistry::getDialectAllocator(StringRef name) const {
180 auto it = registry.find(x: name.str());
181 if (it == registry.end())
182 return nullptr;
183 return it->second.second;
184}
185
186void DialectRegistry::insert(TypeID typeID, StringRef name,
187 const DialectAllocatorFunction &ctor) {
188 auto inserted = registry.insert(
189 x: std::make_pair(x: std::string(name), y: std::make_pair(x&: typeID, y: ctor)));
190 if (!inserted.second && inserted.first->second.first != typeID) {
191 llvm::report_fatal_error(
192 reason: "Trying to register different dialects for the same namespace: " +
193 name);
194 }
195}
196
197void DialectRegistry::insertDynamic(
198 StringRef name, const DynamicDialectPopulationFunction &ctor) {
199 // This TypeID marks dynamic dialects. We cannot give a TypeID for the
200 // dialect yet, since the TypeID of a dynamic dialect is defined at its
201 // construction.
202 TypeID typeID = TypeID::get<void>();
203
204 // Create the dialect, and then call ctor, which allocates its components.
205 auto constructor = [nameStr = name.str(), ctor](MLIRContext *ctx) {
206 auto *dynDialect = ctx->getOrLoadDynamicDialect(
207 dialectNamespace: nameStr, ctor: [ctx, ctor](DynamicDialect *dialect) { ctor(ctx, dialect); });
208 assert(dynDialect && "Dynamic dialect creation unexpectedly failed");
209 return dynDialect;
210 };
211
212 insert(typeID, name, ctor: constructor);
213}
214
215void DialectRegistry::applyExtensions(Dialect *dialect) const {
216 MLIRContext *ctx = dialect->getContext();
217 StringRef dialectName = dialect->getNamespace();
218
219 // Functor used to try to apply the given extension.
220 auto applyExtension = [&](const DialectExtensionBase &extension) {
221 ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
222 // An empty set is equivalent to always invoke.
223 if (dialectNames.empty()) {
224 extension.apply(context: ctx, dialects: dialect);
225 return;
226 }
227
228 // Handle the simple case of a single dialect name. In this case, the
229 // required dialect should be the current dialect.
230 if (dialectNames.size() == 1) {
231 if (dialectNames.front() == dialectName)
232 extension.apply(context: ctx, dialects: dialect);
233 return;
234 }
235
236 // Otherwise, check to see if this extension requires this dialect.
237 const StringRef *nameIt = llvm::find(Range&: dialectNames, Val: dialectName);
238 if (nameIt == dialectNames.end())
239 return;
240
241 // If it does, ensure that all of the other required dialects have been
242 // loaded.
243 SmallVector<Dialect *> requiredDialects;
244 requiredDialects.reserve(N: dialectNames.size());
245 for (auto it = dialectNames.begin(), e = dialectNames.end(); it != e;
246 ++it) {
247 // The current dialect is known to be loaded.
248 if (it == nameIt) {
249 requiredDialects.push_back(Elt: dialect);
250 continue;
251 }
252 // Otherwise, check if it is loaded.
253 Dialect *loadedDialect = ctx->getLoadedDialect(name: *it);
254 if (!loadedDialect)
255 return;
256 requiredDialects.push_back(Elt: loadedDialect);
257 }
258 extension.apply(context: ctx, dialects: requiredDialects);
259 };
260
261 // Note: Additional extensions may be added while applying an extension.
262 for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
263 applyExtension(*extensions[i]);
264}
265
266void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
267 // Functor used to try to apply the given extension.
268 auto applyExtension = [&](const DialectExtensionBase &extension) {
269 ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
270 if (dialectNames.empty()) {
271 auto loadedDialects = ctx->getLoadedDialects();
272 extension.apply(context: ctx, dialects: loadedDialects);
273 return;
274 }
275
276 // Check to see if all of the dialects for this extension are loaded.
277 SmallVector<Dialect *> requiredDialects;
278 requiredDialects.reserve(N: dialectNames.size());
279 for (StringRef dialectName : dialectNames) {
280 Dialect *loadedDialect = ctx->getLoadedDialect(name: dialectName);
281 if (!loadedDialect)
282 return;
283 requiredDialects.push_back(Elt: loadedDialect);
284 }
285 extension.apply(context: ctx, dialects: requiredDialects);
286 };
287
288 // Note: Additional extensions may be added while applying an extension.
289 for (int i = 0; i < static_cast<int>(extensions.size()); ++i)
290 applyExtension(*extensions[i]);
291}
292
293bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
294 // Treat any extensions conservatively.
295 if (!extensions.empty())
296 return false;
297 // Check that the current dialects fully overlap with the dialects in 'rhs'.
298 return llvm::all_of(
299 Range: registry, P: [&](const auto &it) { return rhs.registry.count(it.first); });
300}
301

source code of mlir/lib/IR/Dialect.cpp