1//===- MLIRContext.h - MLIR Global Context Class ----------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_IR_MLIRCONTEXT_H
10#define MLIR_IR_MLIRCONTEXT_H
11
12#include "mlir/Support/LLVM.h"
13#include "mlir/Support/TypeID.h"
14#include "llvm/ADT/ArrayRef.h"
15#include <functional>
16#include <memory>
17#include <vector>
18
// The thread pool is only ever passed and returned by reference in this
// header, so a forward declaration is sufficient here.
namespace llvm {
class ThreadPoolInterface;
} // namespace llvm
22
23namespace mlir {
namespace tracing {
class Action;
} // namespace tracing

// Forward declarations.
class DiagnosticEngine;
class Dialect;
class DialectRegistry;
class DynamicDialect;
class InFlightDiagnostic;
class IRUnit;
class Location;
class MLIRContextImpl;
class RegisteredOperationName;
class StorageUniquer;
37
38/// MLIRContext is the top-level object for a collection of MLIR operations. It
39/// holds immortal uniqued objects like types, and the tables used to unique
40/// them.
41///
42/// MLIRContext gets a redundant "MLIR" prefix because otherwise it ends up with
43/// a very generic name ("Context") and because it is uncommon for clients to
44/// interact with it.
45///
46/// The context wrap some multi-threading facilities, and in particular by
47/// default it will implicitly create a thread pool.
48/// This can be undesirable if multiple context exists at the same time or if a
49/// process will be long-lived and create and destroy contexts.
50/// To control better thread spawning, an externally owned ThreadPool can be
51/// injected in the context. For example:
52///
53/// llvm::DefaultThreadPool myThreadPool;
54/// while (auto *request = nextCompilationRequests()) {
55/// MLIRContext ctx(registry, MLIRContext::Threading::DISABLED);
56/// ctx.setThreadPool(myThreadPool);
57/// processRequest(request, cxt);
58/// }
59///
60class MLIRContext {
61public:
62 enum class Threading { DISABLED, ENABLED };
63 /// Create a new Context.
64 explicit MLIRContext(Threading multithreading = Threading::ENABLED);
65 explicit MLIRContext(const DialectRegistry &registry,
66 Threading multithreading = Threading::ENABLED);
67 ~MLIRContext();
68
69 /// Return information about all IR dialects loaded in the context.
70 std::vector<Dialect *> getLoadedDialects();
71
72 /// Return the dialect registry associated with this context.
73 const DialectRegistry &getDialectRegistry();
74
75 /// Append the contents of the given dialect registry to the registry
76 /// associated with this context.
77 void appendDialectRegistry(const DialectRegistry &registry);
78
79 /// Return information about all available dialects in the registry in this
80 /// context.
81 std::vector<StringRef> getAvailableDialects();
82
83 /// Get a registered IR dialect with the given namespace. If an exact match is
84 /// not found, then return nullptr.
85 Dialect *getLoadedDialect(StringRef name);
86
87 /// Get a registered IR dialect for the given derived dialect type. The
88 /// derived type must provide a static 'getDialectNamespace' method.
89 template <typename T>
90 T *getLoadedDialect() {
91 return static_cast<T *>(getLoadedDialect(T::getDialectNamespace()));
92 }
93
94 /// Get (or create) a dialect for the given derived dialect type. The derived
95 /// type must provide a static 'getDialectNamespace' method.
96 template <typename T>
97 T *getOrLoadDialect() {
98 return static_cast<T *>(
99 getOrLoadDialect(T::getDialectNamespace(), TypeID::get<T>(), [this]() {
100 std::unique_ptr<T> dialect(new T(this));
101 return dialect;
102 }));
103 }
104
105 /// Load a dialect in the context.
106 template <typename Dialect>
107 void loadDialect() {
108 // Do not load the dialect if it is currently loading. This can happen if a
109 // dialect initializer triggers loading the same dialect recursively.
110 if (!isDialectLoading(dialectNamespace: Dialect::getDialectNamespace()))
111 getOrLoadDialect<Dialect>();
112 }
113
114 /// Load a list dialects in the context.
115 template <typename Dialect, typename OtherDialect, typename... MoreDialects>
116 void loadDialect() {
117 loadDialect<Dialect>();
118 loadDialect<OtherDialect, MoreDialects...>();
119 }
120
121 /// Get (or create) a dynamic dialect for the given name.
122 DynamicDialect *
123 getOrLoadDynamicDialect(StringRef dialectNamespace,
124 function_ref<void(DynamicDialect *)> ctor);
125
126 /// Load all dialects available in the registry in this context.
127 void loadAllAvailableDialects();
128
129 /// Get (or create) a dialect for the given derived dialect name.
130 /// The dialect will be loaded from the registry if no dialect is found.
131 /// If no dialect is loaded for this name and none is available in the
132 /// registry, returns nullptr.
133 Dialect *getOrLoadDialect(StringRef name);
134
135 /// Return true if we allow to create operation for unregistered dialects.
136 bool allowsUnregisteredDialects();
137
138 /// Enables creating operations in unregistered dialects.
139 /// This option is **heavily discouraged**: it is convenient during testing
140 /// but it is not a good practice to use it in production code. Some system
141 /// invariants can be broken (like loading a dialect after creating
142 /// operations) without being caught by assertions or other means.
143 void allowUnregisteredDialects(bool allow = true);
144
145 /// Return true if multi-threading is enabled by the context.
146 bool isMultithreadingEnabled();
147
148 /// Set the flag specifying if multi-threading is disabled by the context.
149 /// The command line debugging flag `--mlir-disable-threading` is overriding
150 /// this call and making it a no-op!
151 void disableMultithreading(bool disable = true);
152 void enableMultithreading(bool enable = true) {
153 disableMultithreading(disable: !enable);
154 }
155
156 /// Set a new thread pool to be used in this context. This method requires
157 /// that multithreading is disabled for this context prior to the call. This
158 /// allows to share a thread pool across multiple contexts, as well as
159 /// decoupling the lifetime of the threads from the contexts. The thread pool
160 /// must outlive the context. Multi-threading will be enabled as part of this
161 /// method.
162 /// The command line debugging flag `--mlir-disable-threading` will still
163 /// prevent threading from being enabled and threading won't be enabled after
164 /// this call in this case.
165 void setThreadPool(llvm::ThreadPoolInterface &pool);
166
167 /// Return the number of threads used by the thread pool in this context. The
168 /// number of computed hardware threads can change over the lifetime of a
169 /// process based on affinity changes, so users should use the number of
170 /// threads actually in the thread pool for dispatching work. Returns 1 if
171 /// multithreading is disabled.
172 unsigned getNumThreads();
173
174 /// Return the thread pool used by this context. This method requires that
175 /// multithreading be enabled within the context, and should generally not be
176 /// used directly. Users should instead prefer the threading utilities within
177 /// Threading.h.
178 llvm::ThreadPoolInterface &getThreadPool();
179
180 /// Return true if we should attach the operation to diagnostics emitted via
181 /// Operation::emit.
182 bool shouldPrintOpOnDiagnostic();
183
184 /// Set the flag specifying if we should attach the operation to diagnostics
185 /// emitted via Operation::emit.
186 void printOpOnDiagnostic(bool enable);
187
188 /// Return true if we should attach the current stacktrace to diagnostics when
189 /// emitted.
190 bool shouldPrintStackTraceOnDiagnostic();
191
192 /// Set the flag specifying if we should attach the current stacktrace when
193 /// emitting diagnostics.
194 void printStackTraceOnDiagnostic(bool enable);
195
196 /// Return a sorted array containing the information about all registered
197 /// operations.
198 ArrayRef<RegisteredOperationName> getRegisteredOperations();
199
200 /// Return true if this operation name is registered in this context.
201 bool isOperationRegistered(StringRef name);
202
203 // This is effectively private given that only MLIRContext.cpp can see the
204 // MLIRContextImpl type.
205 MLIRContextImpl &getImpl() { return *impl; }
206
207 /// Returns the diagnostic engine for this context.
208 DiagnosticEngine &getDiagEngine();
209
210 /// Returns the storage uniquer used for creating affine constructs.
211 StorageUniquer &getAffineUniquer();
212
213 /// Returns the storage uniquer used for constructing type storage instances.
214 /// This should not be used directly.
215 StorageUniquer &getTypeUniquer();
216
217 /// Returns the storage uniquer used for constructing attribute storage
218 /// instances. This should not be used directly.
219 StorageUniquer &getAttributeUniquer();
220
221 /// These APIs are tracking whether the context will be used in a
222 /// multithreading environment: this has no effect other than enabling
223 /// assertions on misuses of some APIs.
224 void enterMultiThreadedExecution();
225 void exitMultiThreadedExecution();
226
227 /// Get a dialect for the provided namespace and TypeID: abort the program if
228 /// a dialect exist for this namespace with different TypeID. If a dialect has
229 /// not been loaded for this namespace/TypeID yet, use the provided ctor to
230 /// create one on the fly and load it. Returns a pointer to the dialect owned
231 /// by the context.
232 /// The use of this method is in general discouraged in favor of
233 /// 'getOrLoadDialect<DialectClass>()'.
234 Dialect *getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
235 function_ref<std::unique_ptr<Dialect>()> ctor);
236
237 /// Returns a hash of the registry of the context that may be used to give
238 /// a rough indicator of if the state of the context registry has changed. The
239 /// context registry correlates to loaded dialects and their entities
240 /// (attributes, operations, types, etc.).
241 llvm::hash_code getRegistryHash();
242
243 //===--------------------------------------------------------------------===//
244 // Action API
245 //===--------------------------------------------------------------------===//
246
247 /// Signatures for the action handler that can be registered with the context.
248 using HandlerTy =
249 std::function<void(function_ref<void()>, const tracing::Action &)>;
250
251 /// Register a handler for handling actions that are dispatched through this
252 /// context. A nullptr handler can be set to disable a previously set handler.
253 void registerActionHandler(HandlerTy handler);
254
255 /// Return true if a valid ActionHandler is set.
256 bool hasActionHandler();
257
258 /// Dispatch the provided action to the handler if any, or just execute it.
259 void executeAction(function_ref<void()> actionFn,
260 const tracing::Action &action) {
261 if (LLVM_UNLIKELY(hasActionHandler()))
262 executeActionInternal(actionFn, action);
263 else
264 actionFn();
265 }
266
267 /// Dispatch the provided action to the handler if any, or just execute it.
268 template <typename ActionTy, typename... Args>
269 void executeAction(function_ref<void()> actionFn, ArrayRef<IRUnit> irUnits,
270 Args &&...args) {
271 if (LLVM_UNLIKELY(hasActionHandler()))
272 executeActionInternal<ActionTy, Args...>(actionFn, irUnits,
273 std::forward<Args>(args)...);
274 else
275 actionFn();
276 }
277
278private:
279 /// Return true if the given dialect is currently loading.
280 bool isDialectLoading(StringRef dialectNamespace);
281
282 /// Internal helper for the dispatch method.
283 void executeActionInternal(function_ref<void()> actionFn,
284 const tracing::Action &action);
285
286 /// Internal helper for the dispatch method. We get here after checking that
287 /// there is a handler, for the purpose of keeping this code out-of-line. and
288 /// avoid calling the ctor for the Action unnecessarily.
289 template <typename ActionTy, typename... Args>
290 LLVM_ATTRIBUTE_NOINLINE void
291 executeActionInternal(function_ref<void()> actionFn, ArrayRef<IRUnit> irUnits,
292 Args &&...args) {
293 executeActionInternal(actionFn,
294 ActionTy(irUnits, std::forward<Args>(args)...));
295 }
296
297 const std::unique_ptr<MLIRContextImpl> impl;
298
299 MLIRContext(const MLIRContext &) = delete;
300 void operator=(const MLIRContext &) = delete;
301};
302
303//===----------------------------------------------------------------------===//
304// MLIRContext CommandLine Options
305//===----------------------------------------------------------------------===//
306
/// Registers the command-line options that configure flags of MLIRContext.
///
/// The option values are read when an MLIRContext is constructed and are used
/// to initialize that context.
void registerMLIRContextCLOptions();
311
312} // namespace mlir
313
314#endif // MLIR_IR_MLIRCONTEXT_H
315

source code of mlir/include/mlir/IR/MLIRContext.h