1//===- MLIRContext.cpp - MLIR Type Classes --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/MLIRContext.h"
10#include "AffineExprDetail.h"
11#include "AffineMapDetail.h"
12#include "AttributeDetail.h"
13#include "IntegerSetDetail.h"
14#include "TypeDetail.h"
15#include "mlir/IR/Action.h"
16#include "mlir/IR/AffineExpr.h"
17#include "mlir/IR/AffineMap.h"
18#include "mlir/IR/Attributes.h"
19#include "mlir/IR/BuiltinAttributes.h"
20#include "mlir/IR/BuiltinDialect.h"
21#include "mlir/IR/Diagnostics.h"
22#include "mlir/IR/Dialect.h"
23#include "mlir/IR/ExtensibleDialect.h"
24#include "mlir/IR/IntegerSet.h"
25#include "mlir/IR/Location.h"
26#include "mlir/IR/OpImplementation.h"
27#include "mlir/IR/OperationSupport.h"
28#include "llvm/ADT/DenseMap.h"
29#include "llvm/ADT/Twine.h"
30#include "llvm/Support/Allocator.h"
31#include "llvm/Support/CommandLine.h"
32#include "llvm/Support/Compiler.h"
33#include "llvm/Support/Debug.h"
34#include "llvm/Support/ManagedStatic.h"
35#include "llvm/Support/Mutex.h"
36#include "llvm/Support/RWMutex.h"
37#include "llvm/Support/ThreadPool.h"
38#include "llvm/Support/raw_ostream.h"
39#include <memory>
40#include <optional>
41
42#define DEBUG_TYPE "mlircontext"
43
44using namespace mlir;
45using namespace mlir::detail;
46
47//===----------------------------------------------------------------------===//
48// MLIRContext CommandLine Options
49//===----------------------------------------------------------------------===//
50
51namespace {
52/// This struct contains command line options that can be used to initialize
53/// various bits of an MLIRContext. This uses a struct wrapper to avoid the need
54/// for global command line options.
55struct MLIRContextOptions {
56 llvm::cl::opt<bool> disableThreading{
57 "mlir-disable-threading",
58 llvm::cl::desc("Disable multi-threading within MLIR, overrides any "
59 "further call to MLIRContext::enableMultiThreading()")};
60
61 llvm::cl::opt<bool> printOpOnDiagnostic{
62 "mlir-print-op-on-diagnostic",
63 llvm::cl::desc("When a diagnostic is emitted on an operation, also print "
64 "the operation as an attached note"),
65 llvm::cl::init(Val: true)};
66
67 llvm::cl::opt<bool> printStackTraceOnDiagnostic{
68 "mlir-print-stacktrace-on-diagnostic",
69 llvm::cl::desc("When a diagnostic is emitted, also print the stack trace "
70 "as an attached note")};
71};
72} // namespace
73
74static llvm::ManagedStatic<MLIRContextOptions> clOptions;
75
/// Returns true when threading must stay off regardless of per-context
/// settings: either LLVM was built without thread support, or the
/// `--mlir-disable-threading` option has been registered and set.
static bool isThreadingGloballyDisabled() {
#if LLVM_ENABLE_THREADS == 0
  return true;
#else
  // The option only exists once registerMLIRContextCLOptions() has run.
  if (!clOptions.isConstructed())
    return false;
  return clOptions->disableThreading;
#endif
}
83
84/// Register a set of useful command-line options that can be used to configure
85/// various flags within the MLIRContext. These flags are used when constructing
86/// an MLIR context for initialization.
87void mlir::registerMLIRContextCLOptions() {
88 // Make sure that the options struct has been initialized.
89 *clOptions;
90}
91
92//===----------------------------------------------------------------------===//
93// Locking Utilities
94//===----------------------------------------------------------------------===//
95
96namespace {
97/// Utility writer lock that takes a runtime flag that specifies if we really
98/// need to lock.
99struct ScopedWriterLock {
100 ScopedWriterLock(llvm::sys::SmartRWMutex<true> &mutexParam, bool shouldLock)
101 : mutex(shouldLock ? &mutexParam : nullptr) {
102 if (mutex)
103 mutex->lock();
104 }
105 ~ScopedWriterLock() {
106 if (mutex)
107 mutex->unlock();
108 }
109 llvm::sys::SmartRWMutex<true> *mutex;
110};
111} // namespace
112
113//===----------------------------------------------------------------------===//
114// MLIRContextImpl
115//===----------------------------------------------------------------------===//
116
117namespace mlir {
118/// This is the implementation of the MLIRContext class, using the pImpl idiom.
119/// This class is completely private to this file, so everything is public.
120class MLIRContextImpl {
121public:
122 //===--------------------------------------------------------------------===//
123 // Debugging
124 //===--------------------------------------------------------------------===//
125
126 /// An action handler for handling actions that are dispatched through this
127 /// context.
128 std::function<void(function_ref<void()>, const tracing::Action &)>
129 actionHandler;
130
131 //===--------------------------------------------------------------------===//
132 // Diagnostics
133 //===--------------------------------------------------------------------===//
134 DiagnosticEngine diagEngine;
135
136 //===--------------------------------------------------------------------===//
137 // Options
138 //===--------------------------------------------------------------------===//
139
140 /// In most cases, creating operation in unregistered dialect is not desired
141 /// and indicate a misconfiguration of the compiler. This option enables to
142 /// detect such use cases
143 bool allowUnregisteredDialects = false;
144
145 /// Enable support for multi-threading within MLIR.
146 bool threadingIsEnabled = true;
147
148 /// Track if we are currently executing in a threaded execution environment
149 /// (like the pass-manager): this is only a debugging feature to help reducing
150 /// the chances of data races one some context APIs.
151#ifndef NDEBUG
152 std::atomic<int> multiThreadedExecutionContext{0};
153#endif
154
155 /// If the operation should be attached to diagnostics printed via the
156 /// Operation::emit methods.
157 bool printOpOnDiagnostic = true;
158
159 /// If the current stack trace should be attached when emitting diagnostics.
160 bool printStackTraceOnDiagnostic = false;
161
162 //===--------------------------------------------------------------------===//
163 // Other
164 //===--------------------------------------------------------------------===//
165
166 /// This points to the ThreadPool used when processing MLIR tasks in parallel.
167 /// It can't be nullptr when multi-threading is enabled. Otherwise if
168 /// multi-threading is disabled, and the threadpool wasn't externally provided
169 /// using `setThreadPool`, this will be nullptr.
170 llvm::ThreadPoolInterface *threadPool = nullptr;
171
172 /// In case where the thread pool is owned by the context, this ensures
173 /// destruction with the context.
174 std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
175
176 /// An allocator used for AbstractAttribute and AbstractType objects.
177 llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
178
179 /// This is a mapping from operation name to the operation info describing it.
180 llvm::StringMap<std::unique_ptr<OperationName::Impl>> operations;
181
182 /// A vector of operation info specifically for registered operations.
183 llvm::DenseMap<TypeID, RegisteredOperationName> registeredOperations;
184 llvm::StringMap<RegisteredOperationName> registeredOperationsByName;
185
186 /// This is a sorted container of registered operations for a deterministic
187 /// and efficient `getRegisteredOperations` implementation.
188 SmallVector<RegisteredOperationName, 0> sortedRegisteredOperations;
189
190 /// This is a list of dialects that are created referring to this context.
191 /// The MLIRContext owns the objects. These need to be declared after the
192 /// registered operations to ensure correct destruction order.
193 DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
194 DialectRegistry dialectsRegistry;
195
196 /// A mutex used when accessing operation information.
197 llvm::sys::SmartRWMutex<true> operationInfoMutex;
198
199 //===--------------------------------------------------------------------===//
200 // Affine uniquing
201 //===--------------------------------------------------------------------===//
202
203 // Affine expression, map and integer set uniquing.
204 StorageUniquer affineUniquer;
205
206 //===--------------------------------------------------------------------===//
207 // Type uniquing
208 //===--------------------------------------------------------------------===//
209
210 DenseMap<TypeID, AbstractType *> registeredTypes;
211 StorageUniquer typeUniquer;
212
213 /// This is a mapping from type name to the abstract type describing it.
214 /// It is used by `AbstractType::lookup` to get an `AbstractType` from a name.
215 /// As this map needs to be populated before `StringAttr` is loaded, we
216 /// cannot use `StringAttr` as the key. The context does not take ownership
217 /// of the key, so the `StringRef` must outlive the context.
218 llvm::DenseMap<StringRef, AbstractType *> nameToType;
219
220 /// Cached Type Instances.
221 BFloat16Type bf16Ty;
222 Float16Type f16Ty;
223 FloatTF32Type tf32Ty;
224 Float32Type f32Ty;
225 Float64Type f64Ty;
226 Float80Type f80Ty;
227 Float128Type f128Ty;
228 IndexType indexTy;
229 IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty;
230 NoneType noneType;
231
232 //===--------------------------------------------------------------------===//
233 // Attribute uniquing
234 //===--------------------------------------------------------------------===//
235
236 DenseMap<TypeID, AbstractAttribute *> registeredAttributes;
237 StorageUniquer attributeUniquer;
238
239 /// This is a mapping from attribute name to the abstract attribute describing
240 /// it. It is used by `AbstractType::lookup` to get an `AbstractType` from a
241 /// name.
242 /// As this map needs to be populated before `StringAttr` is loaded, we
243 /// cannot use `StringAttr` as the key. The context does not take ownership
244 /// of the key, so the `StringRef` must outlive the context.
245 llvm::DenseMap<StringRef, AbstractAttribute *> nameToAttribute;
246
247 /// Cached Attribute Instances.
248 BoolAttr falseAttr, trueAttr;
249 UnitAttr unitAttr;
250 UnknownLoc unknownLocAttr;
251 DictionaryAttr emptyDictionaryAttr;
252 StringAttr emptyStringAttr;
253
254 /// Map of string attributes that may reference a dialect, that are awaiting
255 /// that dialect to be loaded.
256 llvm::sys::SmartMutex<true> dialectRefStrAttrMutex;
257 DenseMap<StringRef, SmallVector<StringAttrStorage *>>
258 dialectReferencingStrAttrs;
259
260 /// A distinct attribute allocator that allocates every time since the
261 /// address of the distinct attribute storage serves as unique identifier. The
262 /// allocator is thread safe and frees the allocated storage after its
263 /// destruction.
264 DistinctAttributeAllocator distinctAttributeAllocator;
265
266public:
267 MLIRContextImpl(bool threadingIsEnabled)
268 : threadingIsEnabled(threadingIsEnabled) {
269 if (threadingIsEnabled) {
270 ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
271 threadPool = ownedThreadPool.get();
272 }
273 }
274 ~MLIRContextImpl() {
275 for (auto typeMapping : registeredTypes)
276 typeMapping.second->~AbstractType();
277 for (auto attrMapping : registeredAttributes)
278 attrMapping.second->~AbstractAttribute();
279 }
280};
281} // namespace mlir
282
283MLIRContext::MLIRContext(Threading setting)
284 : MLIRContext(DialectRegistry(), setting) {}
285
286MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)
287 : impl(new MLIRContextImpl(setting == Threading::ENABLED &&
288 !isThreadingGloballyDisabled())) {
289 // Initialize values based on the command line flags if they were provided.
290 if (clOptions.isConstructed()) {
291 printOpOnDiagnostic(enable: clOptions->printOpOnDiagnostic);
292 printStackTraceOnDiagnostic(enable: clOptions->printStackTraceOnDiagnostic);
293 }
294
295 // Pre-populate the registry.
296 registry.appendTo(destination&: impl->dialectsRegistry);
297
298 // Ensure the builtin dialect is always pre-loaded.
299 getOrLoadDialect<BuiltinDialect>();
300
301 // Initialize several common attributes and types to avoid the need to lock
302 // the context when accessing them.
303
304 //// Types.
305 /// Floating-point Types.
306 impl->bf16Ty = TypeUniquer::get<BFloat16Type>(ctx: this);
307 impl->f16Ty = TypeUniquer::get<Float16Type>(ctx: this);
308 impl->tf32Ty = TypeUniquer::get<FloatTF32Type>(ctx: this);
309 impl->f32Ty = TypeUniquer::get<Float32Type>(ctx: this);
310 impl->f64Ty = TypeUniquer::get<Float64Type>(ctx: this);
311 impl->f80Ty = TypeUniquer::get<Float80Type>(ctx: this);
312 impl->f128Ty = TypeUniquer::get<Float128Type>(ctx: this);
313 /// Index Type.
314 impl->indexTy = TypeUniquer::get<IndexType>(ctx: this);
315 /// Integer Types.
316 impl->int1Ty = TypeUniquer::get<IntegerType>(ctx: this, args: 1, args: IntegerType::Signless);
317 impl->int8Ty = TypeUniquer::get<IntegerType>(ctx: this, args: 8, args: IntegerType::Signless);
318 impl->int16Ty =
319 TypeUniquer::get<IntegerType>(ctx: this, args: 16, args: IntegerType::Signless);
320 impl->int32Ty =
321 TypeUniquer::get<IntegerType>(ctx: this, args: 32, args: IntegerType::Signless);
322 impl->int64Ty =
323 TypeUniquer::get<IntegerType>(ctx: this, args: 64, args: IntegerType::Signless);
324 impl->int128Ty =
325 TypeUniquer::get<IntegerType>(ctx: this, args: 128, args: IntegerType::Signless);
326 /// None Type.
327 impl->noneType = TypeUniquer::get<NoneType>(ctx: this);
328
329 //// Attributes.
330 //// Note: These must be registered after the types as they may generate one
331 //// of the above types internally.
332 /// Unknown Location Attribute.
333 impl->unknownLocAttr = AttributeUniquer::get<UnknownLoc>(ctx: this);
334 /// Bool Attributes.
335 impl->falseAttr = IntegerAttr::getBoolAttrUnchecked(type: impl->int1Ty, value: false);
336 impl->trueAttr = IntegerAttr::getBoolAttrUnchecked(type: impl->int1Ty, value: true);
337 /// Unit Attribute.
338 impl->unitAttr = AttributeUniquer::get<UnitAttr>(ctx: this);
339 /// The empty dictionary attribute.
340 impl->emptyDictionaryAttr = DictionaryAttr::getEmptyUnchecked(context: this);
341 /// The empty string attribute.
342 impl->emptyStringAttr = StringAttr::getEmptyStringAttrUnchecked(context: this);
343
344 // Register the affine storage objects with the uniquer.
345 impl->affineUniquer
346 .registerParametricStorageType<AffineBinaryOpExprStorage>();
347 impl->affineUniquer
348 .registerParametricStorageType<AffineConstantExprStorage>();
349 impl->affineUniquer.registerParametricStorageType<AffineDimExprStorage>();
350 impl->affineUniquer.registerParametricStorageType<AffineMapStorage>();
351 impl->affineUniquer.registerParametricStorageType<IntegerSetStorage>();
352}
353
354MLIRContext::~MLIRContext() = default;
355
356/// Copy the specified array of elements into memory managed by the provided
357/// bump pointer allocator. This assumes the elements are all PODs.
358template <typename T>
359static ArrayRef<T> copyArrayRefInto(llvm::BumpPtrAllocator &allocator,
360 ArrayRef<T> elements) {
361 auto result = allocator.Allocate<T>(elements.size());
362 llvm::uninitialized_copy(elements, result);
363 return ArrayRef<T>(result, elements.size());
364}
365
366//===----------------------------------------------------------------------===//
367// Action Handling
368//===----------------------------------------------------------------------===//
369
370void MLIRContext::registerActionHandler(HandlerTy handler) {
371 getImpl().actionHandler = std::move(handler);
372}
373
374/// Dispatch the provided action to the handler if any, or just execute it.
375void MLIRContext::executeActionInternal(function_ref<void()> actionFn,
376 const tracing::Action &action) {
377 assert(getImpl().actionHandler);
378 getImpl().actionHandler(actionFn, action);
379}
380
381bool MLIRContext::hasActionHandler() { return (bool)getImpl().actionHandler; }
382
383//===----------------------------------------------------------------------===//
384// Diagnostic Handlers
385//===----------------------------------------------------------------------===//
386
387/// Returns the diagnostic engine for this context.
388DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
389
390//===----------------------------------------------------------------------===//
391// Dialect and Operation Registration
392//===----------------------------------------------------------------------===//
393
394void MLIRContext::appendDialectRegistry(const DialectRegistry &registry) {
395 if (registry.isSubsetOf(rhs: impl->dialectsRegistry))
396 return;
397
398 assert(impl->multiThreadedExecutionContext == 0 &&
399 "appending to the MLIRContext dialect registry while in a "
400 "multi-threaded execution context");
401 registry.appendTo(destination&: impl->dialectsRegistry);
402
403 // For the already loaded dialects, apply any possible extensions immediately.
404 registry.applyExtensions(ctx: this);
405}
406
407const DialectRegistry &MLIRContext::getDialectRegistry() {
408 return impl->dialectsRegistry;
409}
410
411/// Return information about all registered IR dialects.
412std::vector<Dialect *> MLIRContext::getLoadedDialects() {
413 std::vector<Dialect *> result;
414 result.reserve(n: impl->loadedDialects.size());
415 for (auto &dialect : impl->loadedDialects)
416 result.push_back(x: dialect.second.get());
417 llvm::array_pod_sort(Start: result.begin(), End: result.end(),
418 Compare: [](Dialect *const *lhs, Dialect *const *rhs) -> int {
419 return (*lhs)->getNamespace() < (*rhs)->getNamespace();
420 });
421 return result;
422}
423std::vector<StringRef> MLIRContext::getAvailableDialects() {
424 std::vector<StringRef> result;
425 for (auto dialect : impl->dialectsRegistry.getDialectNames())
426 result.push_back(x: dialect);
427 return result;
428}
429
430/// Get a registered IR dialect with the given namespace. If none is found,
431/// then return nullptr.
432Dialect *MLIRContext::getLoadedDialect(StringRef name) {
433 // Dialects are sorted by name, so we can use binary search for lookup.
434 auto it = impl->loadedDialects.find(Val: name);
435 return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr;
436}
437
438Dialect *MLIRContext::getOrLoadDialect(StringRef name) {
439 Dialect *dialect = getLoadedDialect(name);
440 if (dialect)
441 return dialect;
442 DialectAllocatorFunctionRef allocator =
443 impl->dialectsRegistry.getDialectAllocator(name);
444 return allocator ? allocator(this) : nullptr;
445}
446
447/// Get a dialect for the provided namespace and TypeID: abort the program if a
448/// dialect exist for this namespace with different TypeID. Returns a pointer to
449/// the dialect owned by the context.
450Dialect *
451MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
452 function_ref<std::unique_ptr<Dialect>()> ctor) {
453 auto &impl = getImpl();
454 // Get the correct insertion position sorted by namespace.
455 auto dialectIt = impl.loadedDialects.try_emplace(Key: dialectNamespace, Args: nullptr);
456
457 if (dialectIt.second) {
458 LLVM_DEBUG(llvm::dbgs()
459 << "Load new dialect in Context " << dialectNamespace << "\n");
460#ifndef NDEBUG
461 if (impl.multiThreadedExecutionContext != 0)
462 llvm::report_fatal_error(
463 "Loading a dialect (" + dialectNamespace +
464 ") while in a multi-threaded execution context (maybe "
465 "the PassManager): this can indicate a "
466 "missing `dependentDialects` in a pass for example.");
467#endif // NDEBUG
468 // loadedDialects entry is initialized to nullptr, indicating that the
469 // dialect is currently being loaded. Re-lookup the address in
470 // loadedDialects because the table might have been rehashed by recursive
471 // dialect loading in ctor().
472 std::unique_ptr<Dialect> &dialectOwned =
473 impl.loadedDialects[dialectNamespace] = ctor();
474 Dialect *dialect = dialectOwned.get();
475 assert(dialect && "dialect ctor failed");
476
477 // Refresh all the identifiers dialect field, this catches cases where a
478 // dialect may be loaded after identifier prefixed with this dialect name
479 // were already created.
480 auto stringAttrsIt = impl.dialectReferencingStrAttrs.find(Val: dialectNamespace);
481 if (stringAttrsIt != impl.dialectReferencingStrAttrs.end()) {
482 for (StringAttrStorage *storage : stringAttrsIt->second)
483 storage->referencedDialect = dialect;
484 impl.dialectReferencingStrAttrs.erase(I: stringAttrsIt);
485 }
486
487 // Apply any extensions to this newly loaded dialect.
488 impl.dialectsRegistry.applyExtensions(dialect);
489 return dialect;
490 }
491
492#ifndef NDEBUG
493 if (dialectIt.first->second == nullptr)
494 llvm::report_fatal_error(
495 "Loading (and getting) a dialect (" + dialectNamespace +
496 ") while the same dialect is still loading: use loadDialect instead "
497 "of getOrLoadDialect.");
498#endif // NDEBUG
499
500 // Abort if dialect with namespace has already been registered.
501 std::unique_ptr<Dialect> &dialect = dialectIt.first->second;
502 if (dialect->getTypeID() != dialectID)
503 llvm::report_fatal_error(reason: "a dialect with namespace '" + dialectNamespace +
504 "' has already been registered");
505
506 return dialect.get();
507}
508
509bool MLIRContext::isDialectLoading(StringRef dialectNamespace) {
510 auto it = getImpl().loadedDialects.find(Val: dialectNamespace);
511 // nullptr indicates that the dialect is currently being loaded.
512 return it != getImpl().loadedDialects.end() && it->second == nullptr;
513}
514
515DynamicDialect *MLIRContext::getOrLoadDynamicDialect(
516 StringRef dialectNamespace, function_ref<void(DynamicDialect *)> ctor) {
517 auto &impl = getImpl();
518 // Get the correct insertion position sorted by namespace.
519 auto dialectIt = impl.loadedDialects.find(Val: dialectNamespace);
520
521 if (dialectIt != impl.loadedDialects.end()) {
522 if (auto *dynDialect = dyn_cast<DynamicDialect>(Val: dialectIt->second.get()))
523 return dynDialect;
524 llvm::report_fatal_error(reason: "a dialect with namespace '" + dialectNamespace +
525 "' has already been registered");
526 }
527
528 LLVM_DEBUG(llvm::dbgs() << "Load new dynamic dialect in Context "
529 << dialectNamespace << "\n");
530#ifndef NDEBUG
531 if (impl.multiThreadedExecutionContext != 0)
532 llvm::report_fatal_error(
533 "Loading a dynamic dialect (" + dialectNamespace +
534 ") while in a multi-threaded execution context (maybe "
535 "the PassManager): this can indicate a "
536 "missing `dependentDialects` in a pass for example.");
537#endif
538
539 auto name = StringAttr::get(context: this, bytes: dialectNamespace);
540 auto *dialect = new DynamicDialect(name, this);
541 (void)getOrLoadDialect(dialectNamespace: name, dialectID: dialect->getTypeID(), ctor: [dialect, ctor]() {
542 ctor(dialect);
543 return std::unique_ptr<DynamicDialect>(dialect);
544 });
545 // This is the same result as `getOrLoadDialect` (if it didn't failed),
546 // since it has the same TypeID, and TypeIDs are unique.
547 return dialect;
548}
549
550void MLIRContext::loadAllAvailableDialects() {
551 for (StringRef name : getAvailableDialects())
552 getOrLoadDialect(name);
553}
554
555llvm::hash_code MLIRContext::getRegistryHash() {
556 llvm::hash_code hash(0);
557 // Factor in number of loaded dialects, attributes, operations, types.
558 hash = llvm::hash_combine(args: hash, args: impl->loadedDialects.size());
559 hash = llvm::hash_combine(args: hash, args: impl->registeredAttributes.size());
560 hash = llvm::hash_combine(args: hash, args: impl->registeredOperations.size());
561 hash = llvm::hash_combine(args: hash, args: impl->registeredTypes.size());
562 return hash;
563}
564
565bool MLIRContext::allowsUnregisteredDialects() {
566 return impl->allowUnregisteredDialects;
567}
568
569void MLIRContext::allowUnregisteredDialects(bool allowing) {
570 assert(impl->multiThreadedExecutionContext == 0 &&
571 "changing MLIRContext `allow-unregistered-dialects` configuration "
572 "while in a multi-threaded execution context");
573 impl->allowUnregisteredDialects = allowing;
574}
575
576/// Return true if multi-threading is enabled by the context.
577bool MLIRContext::isMultithreadingEnabled() {
578 return impl->threadingIsEnabled && llvm::llvm_is_multithreaded();
579}
580
581/// Set the flag specifying if multi-threading is disabled by the context.
582void MLIRContext::disableMultithreading(bool disable) {
583 // This API can be overridden by the global debugging flag
584 // --mlir-disable-threading
585 if (isThreadingGloballyDisabled())
586 return;
587 assert(impl->multiThreadedExecutionContext == 0 &&
588 "changing MLIRContext `disable-threading` configuration while "
589 "in a multi-threaded execution context");
590
591 impl->threadingIsEnabled = !disable;
592
593 // Update the threading mode for each of the uniquers.
594 impl->affineUniquer.disableMultithreading(disable);
595 impl->attributeUniquer.disableMultithreading(disable);
596 impl->typeUniquer.disableMultithreading(disable);
597
598 // Destroy thread pool (stop all threads) if it is no longer needed, or create
599 // a new one if multithreading was re-enabled.
600 if (disable) {
601 // If the thread pool is owned, explicitly set it to nullptr to avoid
602 // keeping a dangling pointer around. If the thread pool is externally
603 // owned, we don't do anything.
604 if (impl->ownedThreadPool) {
605 assert(impl->threadPool);
606 impl->threadPool = nullptr;
607 impl->ownedThreadPool.reset();
608 }
609 } else if (!impl->threadPool) {
610 // The thread pool isn't externally provided.
611 assert(!impl->ownedThreadPool);
612 impl->ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
613 impl->threadPool = impl->ownedThreadPool.get();
614 }
615}
616
617void MLIRContext::setThreadPool(llvm::ThreadPoolInterface &pool) {
618 assert(!isMultithreadingEnabled() &&
619 "expected multi-threading to be disabled when setting a ThreadPool");
620 impl->threadPool = &pool;
621 impl->ownedThreadPool.reset();
622 enableMultithreading();
623}
624
625unsigned MLIRContext::getNumThreads() {
626 if (isMultithreadingEnabled()) {
627 assert(impl->threadPool &&
628 "multi-threading is enabled but threadpool not set");
629 return impl->threadPool->getMaxConcurrency();
630 }
631 // No multithreading or active thread pool. Return 1 thread.
632 return 1;
633}
634
635llvm::ThreadPoolInterface &MLIRContext::getThreadPool() {
636 assert(isMultithreadingEnabled() &&
637 "expected multi-threading to be enabled within the context");
638 assert(impl->threadPool &&
639 "multi-threading is enabled but threadpool not set");
640 return *impl->threadPool;
641}
642
643void MLIRContext::enterMultiThreadedExecution() {
644#ifndef NDEBUG
645 ++impl->multiThreadedExecutionContext;
646#endif
647}
648void MLIRContext::exitMultiThreadedExecution() {
649#ifndef NDEBUG
650 --impl->multiThreadedExecutionContext;
651#endif
652}
653
654/// Return true if we should attach the operation to diagnostics emitted via
655/// Operation::emit.
656bool MLIRContext::shouldPrintOpOnDiagnostic() {
657 return impl->printOpOnDiagnostic;
658}
659
660/// Set the flag specifying if we should attach the operation to diagnostics
661/// emitted via Operation::emit.
662void MLIRContext::printOpOnDiagnostic(bool enable) {
663 assert(impl->multiThreadedExecutionContext == 0 &&
664 "changing MLIRContext `print-op-on-diagnostic` configuration while in "
665 "a multi-threaded execution context");
666 impl->printOpOnDiagnostic = enable;
667}
668
669/// Return true if we should attach the current stacktrace to diagnostics when
670/// emitted.
671bool MLIRContext::shouldPrintStackTraceOnDiagnostic() {
672 return impl->printStackTraceOnDiagnostic;
673}
674
675/// Set the flag specifying if we should attach the current stacktrace when
676/// emitting diagnostics.
677void MLIRContext::printStackTraceOnDiagnostic(bool enable) {
678 assert(impl->multiThreadedExecutionContext == 0 &&
679 "changing MLIRContext `print-stacktrace-on-diagnostic` configuration "
680 "while in a multi-threaded execution context");
681 impl->printStackTraceOnDiagnostic = enable;
682}
683
684/// Return information about all registered operations.
685ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
686 return impl->sortedRegisteredOperations;
687}
688
689/// Return information for registered operations by dialect.
690ArrayRef<RegisteredOperationName>
691MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
692 auto lowerBound = llvm::lower_bound(
693 Range&: impl->sortedRegisteredOperations, Value&: dialectName, C: [](auto &lhs, auto &rhs) {
694 return lhs.getDialect().getNamespace().compare(rhs);
695 });
696
697 if (lowerBound == impl->sortedRegisteredOperations.end() ||
698 lowerBound->getDialect().getNamespace() != dialectName)
699 return ArrayRef<RegisteredOperationName>();
700
701 auto upperBound =
702 std::upper_bound(first: lowerBound, last: impl->sortedRegisteredOperations.end(),
703 val: dialectName, comp: [](auto &lhs, auto &rhs) {
704 return lhs.compare(rhs.getDialect().getNamespace());
705 });
706
707 size_t count = std::distance(first: lowerBound, last: upperBound);
708 return ArrayRef(&*lowerBound, count);
709}
710
711bool MLIRContext::isOperationRegistered(StringRef name) {
712 return RegisteredOperationName::lookup(name, ctx: this).has_value();
713}
714
715void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
716 auto &impl = context->getImpl();
717 assert(impl.multiThreadedExecutionContext == 0 &&
718 "Registering a new type kind while in a multi-threaded execution "
719 "context");
720 auto *newInfo =
721 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>())
722 AbstractType(std::move(typeInfo));
723 if (!impl.registeredTypes.insert(KV: {typeID, newInfo}).second)
724 llvm::report_fatal_error(reason: "Dialect Type already registered.");
725 if (!impl.nameToType.insert(KV: {newInfo->getName(), newInfo}).second)
726 llvm::report_fatal_error(reason: "Dialect Type with name " + newInfo->getName() +
727 " is already registered.");
728}
729
730void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
731 auto &impl = context->getImpl();
732 assert(impl.multiThreadedExecutionContext == 0 &&
733 "Registering a new attribute kind while in a multi-threaded execution "
734 "context");
735 auto *newInfo =
736 new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>())
737 AbstractAttribute(std::move(attrInfo));
738 if (!impl.registeredAttributes.insert(KV: {typeID, newInfo}).second)
739 llvm::report_fatal_error(reason: "Dialect Attribute already registered.");
740 if (!impl.nameToAttribute.insert(KV: {newInfo->getName(), newInfo}).second)
741 llvm::report_fatal_error(reason: "Dialect Attribute with name " +
742 newInfo->getName() + " is already registered.");
743}
744
745//===----------------------------------------------------------------------===//
746// AbstractAttribute
747//===----------------------------------------------------------------------===//
748
749/// Get the dialect that registered the attribute with the provided typeid.
750const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
751 MLIRContext *context) {
752 const AbstractAttribute *abstract = lookupMutable(typeID, context);
753 if (!abstract)
754 llvm::report_fatal_error(reason: "Trying to create an Attribute that was not "
755 "registered in this MLIRContext.");
756 return *abstract;
757}
758
759AbstractAttribute *AbstractAttribute::lookupMutable(TypeID typeID,
760 MLIRContext *context) {
761 auto &impl = context->getImpl();
762 return impl.registeredAttributes.lookup(Val: typeID);
763}
764
765std::optional<std::reference_wrapper<const AbstractAttribute>>
766AbstractAttribute::lookup(StringRef name, MLIRContext *context) {
767 MLIRContextImpl &impl = context->getImpl();
768 const AbstractAttribute *type = impl.nameToAttribute.lookup(Val: name);
769
770 if (!type)
771 return std::nullopt;
772 return {*type};
773}
774
775//===----------------------------------------------------------------------===//
776// OperationName
777//===----------------------------------------------------------------------===//
778
779OperationName::Impl::Impl(StringRef name, Dialect *dialect, TypeID typeID,
780 detail::InterfaceMap interfaceMap)
781 : Impl(StringAttr::get(context: dialect->getContext(), bytes: name), dialect, typeID,
782 std::move(interfaceMap)) {}
783
784OperationName::OperationName(StringRef name, MLIRContext *context) {
785 MLIRContextImpl &ctxImpl = context->getImpl();
786
787 // Check for an existing name in read-only mode.
788 bool isMultithreadingEnabled = context->isMultithreadingEnabled();
789 if (isMultithreadingEnabled) {
790 // Check the registered info map first. In the overwhelmingly common case,
791 // the entry will be in here and it also removes the need to acquire any
792 // locks.
793 auto registeredIt = ctxImpl.registeredOperationsByName.find(Key: name);
794 if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperationsByName.end())) {
795 impl = registeredIt->second.impl;
796 return;
797 }
798
799 llvm::sys::SmartScopedReader<true> contextLock(ctxImpl.operationInfoMutex);
800 auto it = ctxImpl.operations.find(Key: name);
801 if (it != ctxImpl.operations.end()) {
802 impl = it->second.get();
803 return;
804 }
805 }
806
807 // Acquire a writer-lock so that we can safely create the new instance.
808 ScopedWriterLock lock(ctxImpl.operationInfoMutex, isMultithreadingEnabled);
809
810 auto it = ctxImpl.operations.try_emplace(Key: name);
811 if (it.second) {
812 auto nameAttr = StringAttr::get(context, bytes: name);
813 it.first->second = std::make_unique<UnregisteredOpModel>(
814 args&: nameAttr, args: nameAttr.getReferencedDialect(), args: TypeID::get<void>(),
815 args: detail::InterfaceMap());
816 }
817 impl = it.first->second.get();
818}
819
820StringRef OperationName::getDialectNamespace() const {
821 if (Dialect *dialect = getDialect())
822 return dialect->getNamespace();
823 return getStringRef().split(Separator: '.').first;
824}
825
826LogicalResult
827OperationName::UnregisteredOpModel::foldHook(Operation *, ArrayRef<Attribute>,
828 SmallVectorImpl<OpFoldResult> &) {
829 return failure();
830}
831void OperationName::UnregisteredOpModel::getCanonicalizationPatterns(
832 RewritePatternSet &, MLIRContext *) {}
833bool OperationName::UnregisteredOpModel::hasTrait(TypeID) { return false; }
834
835OperationName::ParseAssemblyFn
836OperationName::UnregisteredOpModel::getParseAssemblyFn() {
837 llvm::report_fatal_error(reason: "getParseAssemblyFn hook called on unregistered op");
838}
839void OperationName::UnregisteredOpModel::populateDefaultAttrs(
840 const OperationName &, NamedAttrList &) {}
841void OperationName::UnregisteredOpModel::printAssembly(
842 Operation *op, OpAsmPrinter &p, StringRef defaultDialect) {
843 p.printGenericOp(op);
844}
845LogicalResult
846OperationName::UnregisteredOpModel::verifyInvariants(Operation *) {
847 return success();
848}
849LogicalResult
850OperationName::UnregisteredOpModel::verifyRegionInvariants(Operation *) {
851 return success();
852}
853
854std::optional<Attribute>
855OperationName::UnregisteredOpModel::getInherentAttr(Operation *op,
856 StringRef name) {
857 auto dict = dyn_cast_or_null<DictionaryAttr>(Val: getPropertiesAsAttr(op));
858 if (!dict)
859 return std::nullopt;
860 if (Attribute attr = dict.get(name))
861 return attr;
862 return std::nullopt;
863}
864void OperationName::UnregisteredOpModel::setInherentAttr(Operation *op,
865 StringAttr name,
866 Attribute value) {
867 auto dict = dyn_cast_or_null<DictionaryAttr>(Val: getPropertiesAsAttr(op));
868 assert(dict);
869 NamedAttrList attrs(dict);
870 attrs.set(name, value);
871 *op->getPropertiesStorage().as<Attribute *>() =
872 attrs.getDictionary(context: op->getContext());
873}
874void OperationName::UnregisteredOpModel::populateInherentAttrs(
875 Operation *op, NamedAttrList &attrs) {}
876LogicalResult OperationName::UnregisteredOpModel::verifyInherentAttrs(
877 OperationName opName, NamedAttrList &attributes,
878 function_ref<InFlightDiagnostic()> emitError) {
879 return success();
880}
881int OperationName::UnregisteredOpModel::getOpPropertyByteSize() {
882 return sizeof(Attribute);
883}
884void OperationName::UnregisteredOpModel::initProperties(
885 OperationName opName, OpaqueProperties storage, OpaqueProperties init) {
886 new (storage.as<Attribute *>()) Attribute();
887 if (init)
888 *storage.as<Attribute *>() = *init.as<Attribute *>();
889}
890void OperationName::UnregisteredOpModel::deleteProperties(
891 OpaqueProperties prop) {
892 prop.as<Attribute *>()->~Attribute();
893}
894void OperationName::UnregisteredOpModel::populateDefaultProperties(
895 OperationName opName, OpaqueProperties properties) {}
896LogicalResult OperationName::UnregisteredOpModel::setPropertiesFromAttr(
897 OperationName opName, OpaqueProperties properties, Attribute attr,
898 function_ref<InFlightDiagnostic()> emitError) {
899 *properties.as<Attribute *>() = attr;
900 return success();
901}
902Attribute
903OperationName::UnregisteredOpModel::getPropertiesAsAttr(Operation *op) {
904 return *op->getPropertiesStorage().as<Attribute *>();
905}
906void OperationName::UnregisteredOpModel::copyProperties(OpaqueProperties lhs,
907 OpaqueProperties rhs) {
908 *lhs.as<Attribute *>() = *rhs.as<Attribute *>();
909}
910bool OperationName::UnregisteredOpModel::compareProperties(
911 OpaqueProperties lhs, OpaqueProperties rhs) {
912 return *lhs.as<Attribute *>() == *rhs.as<Attribute *>();
913}
914llvm::hash_code
915OperationName::UnregisteredOpModel::hashProperties(OpaqueProperties prop) {
916 return llvm::hash_combine(args: *prop.as<Attribute *>());
917}
918
919//===----------------------------------------------------------------------===//
920// RegisteredOperationName
921//===----------------------------------------------------------------------===//
922
923std::optional<RegisteredOperationName>
924RegisteredOperationName::lookup(TypeID typeID, MLIRContext *ctx) {
925 auto &impl = ctx->getImpl();
926 auto it = impl.registeredOperations.find(Val: typeID);
927 if (it != impl.registeredOperations.end())
928 return it->second;
929 return std::nullopt;
930}
931
932std::optional<RegisteredOperationName>
933RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
934 auto &impl = ctx->getImpl();
935 auto it = impl.registeredOperationsByName.find(Key: name);
936 if (it != impl.registeredOperationsByName.end())
937 return it->getValue();
938 return std::nullopt;
939}
940
941void RegisteredOperationName::insert(
942 std::unique_ptr<RegisteredOperationName::Impl> ownedImpl,
943 ArrayRef<StringRef> attrNames) {
944 RegisteredOperationName::Impl *impl = ownedImpl.get();
945 MLIRContext *ctx = impl->getDialect()->getContext();
946 auto &ctxImpl = ctx->getImpl();
947 assert(ctxImpl.multiThreadedExecutionContext == 0 &&
948 "registering a new operation kind while in a multi-threaded execution "
949 "context");
950
951 // Register the attribute names of this operation.
952 MutableArrayRef<StringAttr> cachedAttrNames;
953 if (!attrNames.empty()) {
954 cachedAttrNames = MutableArrayRef<StringAttr>(
955 ctxImpl.abstractDialectSymbolAllocator.Allocate<StringAttr>(
956 Num: attrNames.size()),
957 attrNames.size());
958 for (unsigned i : llvm::seq<unsigned>(Begin: 0, End: attrNames.size()))
959 new (&cachedAttrNames[i]) StringAttr(StringAttr::get(context: ctx, bytes: attrNames[i]));
960 impl->attributeNames = cachedAttrNames;
961 }
962 StringRef name = impl->getName().strref();
963 // Insert the operation info if it doesn't exist yet.
964 ctxImpl.operations[name] = std::move(ownedImpl);
965
966 // Update the registered info for this operation.
967 auto emplaced = ctxImpl.registeredOperations.try_emplace(
968 Key: impl->getTypeID(), Args: RegisteredOperationName(impl));
969 assert(emplaced.second && "operation name registration must be successful");
970 auto emplacedByName = ctxImpl.registeredOperationsByName.try_emplace(
971 Key: name, Args: RegisteredOperationName(impl));
972 (void)emplacedByName;
973 assert(emplacedByName.second &&
974 "operation name registration must be successful");
975
976 // Add emplaced operation name to the sorted operations container.
977 RegisteredOperationName &value = emplaced.first->second;
978 ctxImpl.sortedRegisteredOperations.insert(
979 I: llvm::upper_bound(Range&: ctxImpl.sortedRegisteredOperations, Value&: value,
980 C: [](auto &lhs, auto &rhs) {
981 return lhs.getIdentifier().compare(
982 rhs.getIdentifier());
983 }),
984 Elt: value);
985}
986
987//===----------------------------------------------------------------------===//
988// AbstractType
989//===----------------------------------------------------------------------===//
990
991const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) {
992 const AbstractType *type = lookupMutable(typeID, context);
993 if (!type)
994 llvm::report_fatal_error(
995 reason: "Trying to create a Type that was not registered in this MLIRContext.");
996 return *type;
997}
998
999AbstractType *AbstractType::lookupMutable(TypeID typeID, MLIRContext *context) {
1000 auto &impl = context->getImpl();
1001 return impl.registeredTypes.lookup(Val: typeID);
1002}
1003
1004std::optional<std::reference_wrapper<const AbstractType>>
1005AbstractType::lookup(StringRef name, MLIRContext *context) {
1006 MLIRContextImpl &impl = context->getImpl();
1007 const AbstractType *type = impl.nameToType.lookup(Val: name);
1008
1009 if (!type)
1010 return std::nullopt;
1011 return {*type};
1012}
1013
1014//===----------------------------------------------------------------------===//
1015// Type uniquing
1016//===----------------------------------------------------------------------===//
1017
1018/// Returns the storage uniquer used for constructing type storage instances.
1019/// This should not be used directly.
1020StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
1021
1022BFloat16Type BFloat16Type::get(MLIRContext *context) {
1023 return context->getImpl().bf16Ty;
1024}
1025Float16Type Float16Type::get(MLIRContext *context) {
1026 return context->getImpl().f16Ty;
1027}
1028FloatTF32Type FloatTF32Type::get(MLIRContext *context) {
1029 return context->getImpl().tf32Ty;
1030}
1031Float32Type Float32Type::get(MLIRContext *context) {
1032 return context->getImpl().f32Ty;
1033}
1034Float64Type Float64Type::get(MLIRContext *context) {
1035 return context->getImpl().f64Ty;
1036}
1037Float80Type Float80Type::get(MLIRContext *context) {
1038 return context->getImpl().f80Ty;
1039}
1040Float128Type Float128Type::get(MLIRContext *context) {
1041 return context->getImpl().f128Ty;
1042}
1043
1044/// Get an instance of the IndexType.
1045IndexType IndexType::get(MLIRContext *context) {
1046 return context->getImpl().indexTy;
1047}
1048
1049/// Return an existing integer type instance if one is cached within the
1050/// context.
1051static IntegerType
1052getCachedIntegerType(unsigned width,
1053 IntegerType::SignednessSemantics signedness,
1054 MLIRContext *context) {
1055 if (signedness != IntegerType::Signless)
1056 return IntegerType();
1057
1058 switch (width) {
1059 case 1:
1060 return context->getImpl().int1Ty;
1061 case 8:
1062 return context->getImpl().int8Ty;
1063 case 16:
1064 return context->getImpl().int16Ty;
1065 case 32:
1066 return context->getImpl().int32Ty;
1067 case 64:
1068 return context->getImpl().int64Ty;
1069 case 128:
1070 return context->getImpl().int128Ty;
1071 default:
1072 return IntegerType();
1073 }
1074}
1075
1076IntegerType IntegerType::get(MLIRContext *context, unsigned width,
1077 IntegerType::SignednessSemantics signedness) {
1078 if (auto cached = getCachedIntegerType(width, signedness, context))
1079 return cached;
1080 return Base::get(ctx: context, args&: width, args&: signedness);
1081}
1082
1083IntegerType
1084IntegerType::getChecked(function_ref<InFlightDiagnostic()> emitError,
1085 MLIRContext *context, unsigned width,
1086 SignednessSemantics signedness) {
1087 if (auto cached = getCachedIntegerType(width, signedness, context))
1088 return cached;
1089 return Base::getChecked(emitErrorFn: emitError, ctx: context, args: width, args: signedness);
1090}
1091
1092/// Get an instance of the NoneType.
1093NoneType NoneType::get(MLIRContext *context) {
1094 if (NoneType cachedInst = context->getImpl().noneType)
1095 return cachedInst;
1096 // Note: May happen when initializing the singleton attributes of the builtin
1097 // dialect.
1098 return Base::get(ctx: context);
1099}
1100
1101//===----------------------------------------------------------------------===//
1102// Attribute uniquing
1103//===----------------------------------------------------------------------===//
1104
1105/// Returns the storage uniquer used for constructing attribute storage
1106/// instances. This should not be used directly.
1107StorageUniquer &MLIRContext::getAttributeUniquer() {
1108 return getImpl().attributeUniquer;
1109}
1110
1111/// Initialize the given attribute storage instance.
1112void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
1113 MLIRContext *ctx,
1114 TypeID attrID) {
1115 storage->initializeAbstractAttribute(abstractAttr: AbstractAttribute::lookup(typeID: attrID, context: ctx));
1116}
1117
1118BoolAttr BoolAttr::get(MLIRContext *context, bool value) {
1119 return value ? context->getImpl().trueAttr : context->getImpl().falseAttr;
1120}
1121
1122UnitAttr UnitAttr::get(MLIRContext *context) {
1123 return context->getImpl().unitAttr;
1124}
1125
1126UnknownLoc UnknownLoc::get(MLIRContext *context) {
1127 return context->getImpl().unknownLocAttr;
1128}
1129
1130DistinctAttrStorage *
1131detail::DistinctAttributeUniquer::allocateStorage(MLIRContext *context,
1132 Attribute referencedAttr) {
1133 return context->getImpl().distinctAttributeAllocator.allocate(referencedAttr);
1134}
1135
1136/// Return empty dictionary.
1137DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) {
1138 return context->getImpl().emptyDictionaryAttr;
1139}
1140
1141void StringAttrStorage::initialize(MLIRContext *context) {
1142 // Check for a dialect namespace prefix, if there isn't one we don't need to
1143 // do any additional initialization.
1144 auto dialectNamePair = value.split(Separator: '.');
1145 if (dialectNamePair.first.empty() || dialectNamePair.second.empty())
1146 return;
1147
1148 // If one exists, we check to see if this dialect is loaded. If it is, we set
1149 // the dialect now, if it isn't we record this storage for initialization
1150 // later if the dialect ever gets loaded.
1151 if ((referencedDialect = context->getLoadedDialect(name: dialectNamePair.first)))
1152 return;
1153
1154 MLIRContextImpl &impl = context->getImpl();
1155 llvm::sys::SmartScopedLock<true> lock(impl.dialectRefStrAttrMutex);
1156 impl.dialectReferencingStrAttrs[dialectNamePair.first].push_back(Elt: this);
1157}
1158
1159/// Return an empty string.
1160StringAttr StringAttr::get(MLIRContext *context) {
1161 return context->getImpl().emptyStringAttr;
1162}
1163
1164//===----------------------------------------------------------------------===//
1165// AffineMap uniquing
1166//===----------------------------------------------------------------------===//
1167
1168StorageUniquer &MLIRContext::getAffineUniquer() {
1169 return getImpl().affineUniquer;
1170}
1171
1172AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount,
1173 ArrayRef<AffineExpr> results,
1174 MLIRContext *context) {
1175 auto &impl = context->getImpl();
1176 auto *storage = impl.affineUniquer.get<AffineMapStorage>(
1177 initFn: [&](AffineMapStorage *storage) { storage->context = context; }, args&: dimCount,
1178 args&: symbolCount, args&: results);
1179 return AffineMap(storage);
1180}
1181
1182/// Check whether the arguments passed to the AffineMap::get() are consistent.
1183/// This method checks whether the highest index of dimensional identifier
1184/// present in result expressions is less than `dimCount` and the highest index
1185/// of symbolic identifier present in result expressions is less than
1186/// `symbolCount`.
1187LLVM_ATTRIBUTE_UNUSED static bool
1188willBeValidAffineMap(unsigned dimCount, unsigned symbolCount,
1189 ArrayRef<AffineExpr> results) {
1190 int64_t maxDimPosition = -1;
1191 int64_t maxSymbolPosition = -1;
1192 getMaxDimAndSymbol(exprsList: ArrayRef<ArrayRef<AffineExpr>>(results), maxDim&: maxDimPosition,
1193 maxSym&: maxSymbolPosition);
1194 if ((maxDimPosition >= dimCount) || (maxSymbolPosition >= symbolCount)) {
1195 LLVM_DEBUG(
1196 llvm::dbgs()
1197 << "maximum dimensional identifier position in result expression must "
1198 "be less than `dimCount` and maximum symbolic identifier position "
1199 "in result expression must be less than `symbolCount`\n");
1200 return false;
1201 }
1202 return true;
1203}
1204
1205AffineMap AffineMap::get(MLIRContext *context) {
1206 return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context);
1207}
1208
1209AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
1210 MLIRContext *context) {
1211 return getImpl(dimCount, symbolCount, /*results=*/{}, context);
1212}
1213
1214AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
1215 AffineExpr result) {
1216 assert(willBeValidAffineMap(dimCount, symbolCount, {result}));
1217 return getImpl(dimCount, symbolCount, results: {result}, context: result.getContext());
1218}
1219
1220AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
1221 ArrayRef<AffineExpr> results, MLIRContext *context) {
1222 assert(willBeValidAffineMap(dimCount, symbolCount, results));
1223 return getImpl(dimCount, symbolCount, results, context);
1224}
1225
1226//===----------------------------------------------------------------------===//
// Integer Sets: these are immutable and, like AffineMaps, are uniqued through
// the context's affine storage uniquer.
1229//===----------------------------------------------------------------------===//
1230
1231IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
1232 ArrayRef<AffineExpr> constraints,
1233 ArrayRef<bool> eqFlags) {
1234 // The number of constraints can't be zero.
1235 assert(!constraints.empty());
1236 assert(constraints.size() == eqFlags.size());
1237
1238 auto &impl = constraints[0].getContext()->getImpl();
1239 auto *storage = impl.affineUniquer.get<IntegerSetStorage>(
1240 initFn: [](IntegerSetStorage *) {}, args&: dimCount, args&: symbolCount, args&: constraints, args&: eqFlags);
1241 return IntegerSet(storage);
1242}
1243
1244//===----------------------------------------------------------------------===//
1245// StorageUniquerSupport
1246//===----------------------------------------------------------------------===//
1247
1248/// Utility method to generate a callback that can be used to generate a
1249/// diagnostic when checking the construction invariants of a storage object.
1250/// This is defined out-of-line to avoid the need to include Location.h.
1251llvm::unique_function<InFlightDiagnostic()>
1252mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) {
1253 return [ctx] { return emitError(loc: UnknownLoc::get(context: ctx)); };
1254}
1255llvm::unique_function<InFlightDiagnostic()>
1256mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) {
1257 return [=] { return emitError(loc); };
1258}
1259

source code of mlir/lib/IR/MLIRContext.cpp