1//===- Dialect.cpp - Dialect implementation -------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/Dialect.h"
10#include "mlir/IR/BuiltinDialect.h"
11#include "mlir/IR/Diagnostics.h"
12#include "mlir/IR/DialectImplementation.h"
13#include "mlir/IR/DialectInterface.h"
14#include "mlir/IR/DialectRegistry.h"
15#include "mlir/IR/ExtensibleDialect.h"
16#include "mlir/IR/MLIRContext.h"
17#include "mlir/IR/Operation.h"
18#include "mlir/Support/TypeID.h"
19#include "llvm/ADT/MapVector.h"
20#include "llvm/ADT/SetOperations.h"
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/ADT/SmallVectorExtras.h"
23#include "llvm/ADT/Twine.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/ManagedStatic.h"
26#include "llvm/Support/Regex.h"
27#include <memory>
28
29#define DEBUG_TYPE "dialect"
30
31using namespace mlir;
32using namespace detail;
33
34//===----------------------------------------------------------------------===//
35// Dialect
36//===----------------------------------------------------------------------===//
37
38Dialect::Dialect(StringRef name, MLIRContext *context, TypeID id)
39 : name(name), dialectID(id), context(context) {
40 assert(isValidNamespace(name) && "invalid dialect namespace");
41}
42
43Dialect::~Dialect() = default;
44
45/// Verify an attribute from this dialect on the argument at 'argIndex' for
46/// the region at 'regionIndex' on the given operation. Returns failure if
47/// the verification failed, success otherwise. This hook may optionally be
48/// invoked from any operation containing a region.
49LogicalResult Dialect::verifyRegionArgAttribute(Operation *, unsigned, unsigned,
50 NamedAttribute) {
51 return success();
52}
53
54/// Verify an attribute from this dialect on the result at 'resultIndex' for
55/// the region at 'regionIndex' on the given operation. Returns failure if
56/// the verification failed, success otherwise. This hook may optionally be
57/// invoked from any operation containing a region.
58LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned,
59 unsigned, NamedAttribute) {
60 return success();
61}
62
63/// Parse an attribute registered to this dialect.
64Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
65 parser.emitError(loc: parser.getNameLoc())
66 << "dialect '" << getNamespace()
67 << "' provides no attribute parsing hook";
68 return Attribute();
69}
70
71/// Parse a type registered to this dialect.
72Type Dialect::parseType(DialectAsmParser &parser) const {
73 // If this dialect allows unknown types, then represent this with OpaqueType.
74 if (allowsUnknownTypes()) {
75 StringAttr ns = StringAttr::get(getContext(), getNamespace());
76 return OpaqueType::get(ns, parser.getFullSymbolSpec());
77 }
78
79 parser.emitError(loc: parser.getNameLoc())
80 << "dialect '" << getNamespace() << "' provides no type parsing hook";
81 return Type();
82}
83
84std::optional<Dialect::ParseOpHook>
85Dialect::getParseOperationHook(StringRef opName) const {
86 return std::nullopt;
87}
88
89llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
90Dialect::getOperationPrinter(Operation *op) const {
91 assert(op->getDialect() == this &&
92 "Dialect hook invoked on non-dialect owned operation");
93 return nullptr;
94}
95
96/// Utility function that returns if the given string is a valid dialect
97/// namespace
98bool Dialect::isValidNamespace(StringRef str) {
99 llvm::Regex dialectNameRegex("^[a-zA-Z_][a-zA-Z_0-9\\$]*$");
100 return dialectNameRegex.match(String: str);
101}
102
103/// Register a set of dialect interfaces with this dialect instance.
104void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) {
105 // Handle the case where the models resolve a promised interface.
106 handleAdditionOfUndefinedPromisedInterface(interfaceRequestorID: getTypeID(), interfaceID: interface->getID());
107
108 auto it = registeredInterfaces.try_emplace(Key: interface->getID(),
109 Args: std::move(interface));
110 (void)it;
111 LLVM_DEBUG({
112 if (!it.second) {
113 llvm::dbgs() << "[" DEBUG_TYPE
114 "] repeated interface registration for dialect "
115 << getNamespace();
116 }
117 });
118}
119
120//===----------------------------------------------------------------------===//
121// Dialect Interface
122//===----------------------------------------------------------------------===//
123
124DialectInterface::~DialectInterface() = default;
125
126MLIRContext *DialectInterface::getContext() const {
127 return dialect->getContext();
128}
129
130DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
131 MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
132 for (auto *dialect : ctx->getLoadedDialects()) {
133#ifndef NDEBUG
134 dialect->handleUseOfUndefinedPromisedInterface(
135 interfaceRequestorID: dialect->getTypeID(), interfaceID: interfaceKind, interfaceName);
136#endif
137 if (auto *interface = dialect->getRegisteredInterface(interfaceID: interfaceKind)) {
138 interfaces.insert(V: interface);
139 orderedInterfaces.push_back(x: interface);
140 }
141 }
142}
143
144DialectInterfaceCollectionBase::~DialectInterfaceCollectionBase() = default;
145
146/// Get the interface for the dialect of given operation, or null if one
147/// is not registered.
148const DialectInterface *
149DialectInterfaceCollectionBase::getInterfaceFor(Operation *op) const {
150 return getInterfaceFor(dialect: op->getDialect());
151}
152
153//===----------------------------------------------------------------------===//
154// DialectExtension
155//===----------------------------------------------------------------------===//
156
157DialectExtensionBase::~DialectExtensionBase() = default;
158
159void dialect_extension_detail::handleUseOfUndefinedPromisedInterface(
160 Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID,
161 StringRef interfaceName) {
162 dialect.handleUseOfUndefinedPromisedInterface(interfaceRequestorID,
163 interfaceID, interfaceName);
164}
165
166void dialect_extension_detail::handleAdditionOfUndefinedPromisedInterface(
167 Dialect &dialect, TypeID interfaceRequestorID, TypeID interfaceID) {
168 dialect.handleAdditionOfUndefinedPromisedInterface(interfaceRequestorID,
169 interfaceID);
170}
171
172bool dialect_extension_detail::hasPromisedInterface(Dialect &dialect,
173 TypeID interfaceRequestorID,
174 TypeID interfaceID) {
175 return dialect.hasPromisedInterface(interfaceRequestorID, interfaceID);
176}
177
178//===----------------------------------------------------------------------===//
179// DialectRegistry
180//===----------------------------------------------------------------------===//
181
182namespace {
183template <typename Fn>
184void applyExtensionsFn(
185 Fn &&applyExtension,
186 const llvm::MapVector<TypeID, std::unique_ptr<DialectExtensionBase>>
187 &extensions) {
188 // Note: Additional extensions may be added while applying an extension.
189 // The iterators will be invalidated if extensions are added so we'll keep
190 // a copy of the extensions for ourselves.
191
192 const auto extractExtension =
193 [](const auto &entry) -> DialectExtensionBase * {
194 return entry.second.get();
195 };
196
197 auto startIt = extensions.begin(), endIt = extensions.end();
198 size_t count = 0;
199 while (startIt != endIt) {
200 count += endIt - startIt;
201
202 // Grab the subset of extensions we'll apply in this iteration.
203 const auto subset =
204 llvm::map_to_vector(llvm::make_range(x: startIt, y: endIt), extractExtension);
205
206 for (const auto *ext : subset)
207 applyExtension(*ext);
208
209 // Book-keep for the next iteration.
210 startIt = extensions.begin() + count;
211 endIt = extensions.end();
212 }
213}
214} // namespace
215
216DialectRegistry::DialectRegistry() { insert<BuiltinDialect>(); }
217
218DialectAllocatorFunctionRef
219DialectRegistry::getDialectAllocator(StringRef name) const {
220 auto it = registry.find(x: name);
221 if (it == registry.end())
222 return nullptr;
223 return it->second.second;
224}
225
226void DialectRegistry::insert(TypeID typeID, StringRef name,
227 const DialectAllocatorFunction &ctor) {
228 auto inserted = registry.insert(
229 x: std::make_pair(x: std::string(name), y: std::make_pair(x&: typeID, y: ctor)));
230 if (!inserted.second && inserted.first->second.first != typeID) {
231 llvm::report_fatal_error(
232 reason: "Trying to register different dialects for the same namespace: " +
233 name);
234 }
235}
236
237void DialectRegistry::insertDynamic(
238 StringRef name, const DynamicDialectPopulationFunction &ctor) {
239 // This TypeID marks dynamic dialects. We cannot give a TypeID for the
240 // dialect yet, since the TypeID of a dynamic dialect is defined at its
241 // construction.
242 TypeID typeID = TypeID::get<void>();
243
244 // Create the dialect, and then call ctor, which allocates its components.
245 auto constructor = [nameStr = name.str(), ctor](MLIRContext *ctx) {
246 auto *dynDialect = ctx->getOrLoadDynamicDialect(
247 dialectNamespace: nameStr, ctor: [ctx, ctor](DynamicDialect *dialect) { ctor(ctx, dialect); });
248 assert(dynDialect && "Dynamic dialect creation unexpectedly failed");
249 return dynDialect;
250 };
251
252 insert(typeID, name, ctor: constructor);
253}
254
255void DialectRegistry::applyExtensions(Dialect *dialect) const {
256 MLIRContext *ctx = dialect->getContext();
257 StringRef dialectName = dialect->getNamespace();
258
259 // Functor used to try to apply the given extension.
260 auto applyExtension = [&](const DialectExtensionBase &extension) {
261 ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
262 // An empty set is equivalent to always invoke.
263 if (dialectNames.empty()) {
264 extension.apply(context: ctx, dialects: dialect);
265 return;
266 }
267
268 // Handle the simple case of a single dialect name. In this case, the
269 // required dialect should be the current dialect.
270 if (dialectNames.size() == 1) {
271 if (dialectNames.front() == dialectName)
272 extension.apply(context: ctx, dialects: dialect);
273 return;
274 }
275
276 // Otherwise, check to see if this extension requires this dialect.
277 const StringRef *nameIt = llvm::find(Range&: dialectNames, Val: dialectName);
278 if (nameIt == dialectNames.end())
279 return;
280
281 // If it does, ensure that all of the other required dialects have been
282 // loaded.
283 SmallVector<Dialect *> requiredDialects;
284 requiredDialects.reserve(N: dialectNames.size());
285 for (auto it = dialectNames.begin(), e = dialectNames.end(); it != e;
286 ++it) {
287 // The current dialect is known to be loaded.
288 if (it == nameIt) {
289 requiredDialects.push_back(Elt: dialect);
290 continue;
291 }
292 // Otherwise, check if it is loaded.
293 Dialect *loadedDialect = ctx->getLoadedDialect(name: *it);
294 if (!loadedDialect)
295 return;
296 requiredDialects.push_back(Elt: loadedDialect);
297 }
298 extension.apply(context: ctx, dialects: requiredDialects);
299 };
300
301 applyExtensionsFn(applyExtension, extensions);
302}
303
304void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
305 // Functor used to try to apply the given extension.
306 auto applyExtension = [&](const DialectExtensionBase &extension) {
307 ArrayRef<StringRef> dialectNames = extension.getRequiredDialects();
308 if (dialectNames.empty()) {
309 auto loadedDialects = ctx->getLoadedDialects();
310 extension.apply(context: ctx, dialects: loadedDialects);
311 return;
312 }
313
314 // Check to see if all of the dialects for this extension are loaded.
315 SmallVector<Dialect *> requiredDialects;
316 requiredDialects.reserve(N: dialectNames.size());
317 for (StringRef dialectName : dialectNames) {
318 Dialect *loadedDialect = ctx->getLoadedDialect(name: dialectName);
319 if (!loadedDialect)
320 return;
321 requiredDialects.push_back(Elt: loadedDialect);
322 }
323 extension.apply(context: ctx, dialects: requiredDialects);
324 };
325
326 applyExtensionsFn(applyExtension, extensions);
327}
328
329bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
330 // Check that all extension keys are present in 'rhs'.
331 const auto hasExtension = [&](const auto &key) {
332 return rhs.extensions.contains(Key: key);
333 };
334 if (!llvm::all_of(Range: make_first_range(c: extensions), P: hasExtension))
335 return false;
336
337 // Check that the current dialects fully overlap with the dialects in 'rhs'.
338 return llvm::all_of(
339 Range: registry, P: [&](const auto &it) { return rhs.registry.count(it.first); });
340}
341

source code of mlir/lib/IR/Dialect.cpp