1 | //===- Globals.h - MLIR Python extension globals --------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H |
10 | #define MLIR_BINDINGS_PYTHON_GLOBALS_H |
11 | |
12 | #include <optional> |
13 | #include <string> |
14 | #include <vector> |
15 | |
16 | #include "NanobindUtils.h" |
17 | #include "mlir-c/IR.h" |
18 | #include "mlir/CAPI/Support.h" |
19 | #include "llvm/ADT/DenseMap.h" |
20 | #include "llvm/ADT/StringRef.h" |
21 | #include "llvm/ADT/StringSet.h" |
22 | |
23 | namespace mlir { |
24 | namespace python { |
25 | |
26 | /// Globals that are always accessible once the extension has been initialized. |
27 | /// Methods of this class are thread-safe. |
28 | class PyGlobals { |
29 | public: |
30 | PyGlobals(); |
31 | ~PyGlobals(); |
32 | |
33 | /// Most code should get the globals via this static accessor. |
34 | static PyGlobals &get() { |
35 | assert(instance && "PyGlobals is null" ); |
36 | return *instance; |
37 | } |
38 | |
39 | /// Get and set the list of parent modules to search for dialect |
40 | /// implementation classes. |
41 | std::vector<std::string> getDialectSearchPrefixes() { |
42 | nanobind::ft_lock_guard lock(mutex); |
43 | return dialectSearchPrefixes; |
44 | } |
45 | void setDialectSearchPrefixes(std::vector<std::string> newValues) { |
46 | nanobind::ft_lock_guard lock(mutex); |
47 | dialectSearchPrefixes.swap(x&: newValues); |
48 | } |
49 | void addDialectSearchPrefix(std::string value) { |
50 | nanobind::ft_lock_guard lock(mutex); |
51 | dialectSearchPrefixes.push_back(x: std::move(value)); |
52 | } |
53 | |
54 | /// Loads a python module corresponding to the given dialect namespace. |
55 | /// No-ops if the module has already been loaded or is not found. Raises |
56 | /// an error on any evaluation issues. |
57 | /// Note that this returns void because it is expected that the module |
58 | /// contains calls to decorators and helpers that register the salient |
59 | /// entities. Returns true if dialect is successfully loaded. |
60 | bool loadDialectModule(llvm::StringRef dialectNamespace); |
61 | |
62 | /// Adds a user-friendly Attribute builder. |
63 | /// Raises an exception if the mapping already exists and replace == false. |
64 | /// This is intended to be called by implementation code. |
65 | void registerAttributeBuilder(const std::string &attributeKind, |
66 | nanobind::callable pyFunc, |
67 | bool replace = false); |
68 | |
69 | /// Adds a user-friendly type caster. Raises an exception if the mapping |
70 | /// already exists and replace == false. This is intended to be called by |
71 | /// implementation code. |
72 | void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster, |
73 | bool replace = false); |
74 | |
75 | /// Adds a user-friendly value caster. Raises an exception if the mapping |
76 | /// already exists and replace == false. This is intended to be called by |
77 | /// implementation code. |
78 | void registerValueCaster(MlirTypeID mlirTypeID, |
79 | nanobind::callable valueCaster, |
80 | bool replace = false); |
81 | |
82 | /// Adds a concrete implementation dialect class. |
83 | /// Raises an exception if the mapping already exists. |
84 | /// This is intended to be called by implementation code. |
85 | void registerDialectImpl(const std::string &dialectNamespace, |
86 | nanobind::object pyClass); |
87 | |
88 | /// Adds a concrete implementation operation class. |
89 | /// Raises an exception if the mapping already exists and replace == false. |
90 | /// This is intended to be called by implementation code. |
91 | void registerOperationImpl(const std::string &operationName, |
92 | nanobind::object pyClass, bool replace = false); |
93 | |
94 | /// Returns the custom Attribute builder for Attribute kind. |
95 | std::optional<nanobind::callable> |
96 | lookupAttributeBuilder(const std::string &attributeKind); |
97 | |
98 | /// Returns the custom type caster for MlirTypeID mlirTypeID. |
99 | std::optional<nanobind::callable> lookupTypeCaster(MlirTypeID mlirTypeID, |
100 | MlirDialect dialect); |
101 | |
102 | /// Returns the custom value caster for MlirTypeID mlirTypeID. |
103 | std::optional<nanobind::callable> lookupValueCaster(MlirTypeID mlirTypeID, |
104 | MlirDialect dialect); |
105 | |
106 | /// Looks up a registered dialect class by namespace. Note that this may |
107 | /// trigger loading of the defining module and can arbitrarily re-enter. |
108 | std::optional<nanobind::object> |
109 | lookupDialectClass(const std::string &dialectNamespace); |
110 | |
111 | /// Looks up a registered operation class (deriving from OpView) by operation |
112 | /// name. Note that this may trigger a load of the dialect, which can |
113 | /// arbitrarily re-enter. |
114 | std::optional<nanobind::object> |
115 | lookupOperationClass(llvm::StringRef operationName); |
116 | |
117 | private: |
118 | static PyGlobals *instance; |
119 | |
120 | nanobind::ft_mutex mutex; |
121 | |
122 | /// Module name prefixes to search under for dialect implementation modules. |
123 | std::vector<std::string> dialectSearchPrefixes; |
124 | /// Map of dialect namespace to external dialect class object. |
125 | llvm::StringMap<nanobind::object> dialectClassMap; |
126 | /// Map of full operation name to external operation class object. |
127 | llvm::StringMap<nanobind::object> operationClassMap; |
128 | /// Map of attribute ODS name to custom builder. |
129 | llvm::StringMap<nanobind::callable> attributeBuilderMap; |
130 | /// Map of MlirTypeID to custom type caster. |
131 | llvm::DenseMap<MlirTypeID, nanobind::callable> typeCasterMap; |
132 | /// Map of MlirTypeID to custom value caster. |
133 | llvm::DenseMap<MlirTypeID, nanobind::callable> valueCasterMap; |
134 | /// Set of dialect namespaces that we have attempted to import implementation |
135 | /// modules for. |
136 | llvm::StringSet<> loadedDialectModules; |
137 | }; |
138 | |
139 | } // namespace python |
140 | } // namespace mlir |
141 | |
142 | #endif // MLIR_BINDINGS_PYTHON_GLOBALS_H |
143 | |