1 | //===- IRModule.cpp - IR pybind module ------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "IRModule.h" |
10 | #include "Globals.h" |
11 | #include "PybindUtils.h" |
12 | |
13 | #include "mlir-c/Bindings/Python/Interop.h" |
14 | #include "mlir-c/Support.h" |
15 | |
16 | #include <optional> |
17 | #include <vector> |
18 | |
19 | namespace py = pybind11; |
20 | using namespace mlir; |
21 | using namespace mlir::python; |
22 | |
23 | // ----------------------------------------------------------------------------- |
24 | // PyGlobals |
25 | // ----------------------------------------------------------------------------- |
26 | |
27 | PyGlobals *PyGlobals::instance = nullptr; |
28 | |
29 | PyGlobals::PyGlobals() { |
30 | assert(!instance && "PyGlobals already constructed" ); |
31 | instance = this; |
32 | // The default search path include {mlir.}dialects, where {mlir.} is the |
33 | // package prefix configured at compile time. |
34 | dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects" )); |
35 | } |
36 | |
37 | PyGlobals::~PyGlobals() { instance = nullptr; } |
38 | |
39 | bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { |
40 | if (loadedDialectModules.contains(key: dialectNamespace)) |
41 | return true; |
42 | // Since re-entrancy is possible, make a copy of the search prefixes. |
43 | std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes; |
44 | py::object loaded = py::none(); |
45 | for (std::string moduleName : localSearchPrefixes) { |
46 | moduleName.push_back(c: '.'); |
47 | moduleName.append(s: dialectNamespace.data(), n: dialectNamespace.size()); |
48 | |
49 | try { |
50 | loaded = py::module::import(moduleName.c_str()); |
51 | } catch (py::error_already_set &e) { |
52 | if (e.matches(PyExc_ModuleNotFoundError)) { |
53 | continue; |
54 | } |
55 | throw; |
56 | } |
57 | break; |
58 | } |
59 | |
60 | if (loaded.is_none()) |
61 | return false; |
62 | // Note: Iterator cannot be shared from prior to loading, since re-entrancy |
63 | // may have occurred, which may do anything. |
64 | loadedDialectModules.insert(key: dialectNamespace); |
65 | return true; |
66 | } |
67 | |
68 | void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, |
69 | py::function pyFunc, bool replace) { |
70 | py::object &found = attributeBuilderMap[attributeKind]; |
71 | if (found && !replace) { |
72 | throw std::runtime_error((llvm::Twine("Attribute builder for '" ) + |
73 | attributeKind + |
74 | "' is already registered with func: " + |
75 | py::str(found).operator std::string()) |
76 | .str()); |
77 | } |
78 | found = std::move(pyFunc); |
79 | } |
80 | |
81 | void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, |
82 | pybind11::function typeCaster, |
83 | bool replace) { |
84 | pybind11::object &found = typeCasterMap[mlirTypeID]; |
85 | if (found && !replace) |
86 | throw std::runtime_error("Type caster is already registered with caster: " + |
87 | py::str(found).operator std::string()); |
88 | found = std::move(typeCaster); |
89 | } |
90 | |
91 | void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, |
92 | pybind11::function valueCaster, |
93 | bool replace) { |
94 | pybind11::object &found = valueCasterMap[mlirTypeID]; |
95 | if (found && !replace) |
96 | throw std::runtime_error("Value caster is already registered: " + |
97 | py::repr(found).cast<std::string>()); |
98 | found = std::move(valueCaster); |
99 | } |
100 | |
101 | void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, |
102 | py::object pyClass) { |
103 | py::object &found = dialectClassMap[dialectNamespace]; |
104 | if (found) { |
105 | throw std::runtime_error((llvm::Twine("Dialect namespace '" ) + |
106 | dialectNamespace + "' is already registered." ) |
107 | .str()); |
108 | } |
109 | found = std::move(pyClass); |
110 | } |
111 | |
112 | void PyGlobals::registerOperationImpl(const std::string &operationName, |
113 | py::object pyClass, bool replace) { |
114 | py::object &found = operationClassMap[operationName]; |
115 | if (found && !replace) { |
116 | throw std::runtime_error((llvm::Twine("Operation '" ) + operationName + |
117 | "' is already registered." ) |
118 | .str()); |
119 | } |
120 | found = std::move(pyClass); |
121 | } |
122 | |
123 | std::optional<py::function> |
124 | PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { |
125 | const auto foundIt = attributeBuilderMap.find(attributeKind); |
126 | if (foundIt != attributeBuilderMap.end()) { |
127 | assert(foundIt->second && "attribute builder is defined" ); |
128 | return foundIt->second; |
129 | } |
130 | return std::nullopt; |
131 | } |
132 | |
133 | std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, |
134 | MlirDialect dialect) { |
135 | // Try to load dialect module. |
136 | (void)loadDialectModule(dialectNamespace: unwrap(mlirDialectGetNamespace(dialect))); |
137 | const auto foundIt = typeCasterMap.find(mlirTypeID); |
138 | if (foundIt != typeCasterMap.end()) { |
139 | assert(foundIt->second && "type caster is defined" ); |
140 | return foundIt->second; |
141 | } |
142 | return std::nullopt; |
143 | } |
144 | |
145 | std::optional<py::function> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, |
146 | MlirDialect dialect) { |
147 | // Try to load dialect module. |
148 | (void)loadDialectModule(dialectNamespace: unwrap(mlirDialectGetNamespace(dialect))); |
149 | const auto foundIt = valueCasterMap.find(mlirTypeID); |
150 | if (foundIt != valueCasterMap.end()) { |
151 | assert(foundIt->second && "value caster is defined" ); |
152 | return foundIt->second; |
153 | } |
154 | return std::nullopt; |
155 | } |
156 | |
157 | std::optional<py::object> |
158 | PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { |
159 | // Make sure dialect module is loaded. |
160 | if (!loadDialectModule(dialectNamespace)) |
161 | return std::nullopt; |
162 | const auto foundIt = dialectClassMap.find(dialectNamespace); |
163 | if (foundIt != dialectClassMap.end()) { |
164 | assert(foundIt->second && "dialect class is defined" ); |
165 | return foundIt->second; |
166 | } |
167 | // Not found and loading did not yield a registration. |
168 | return std::nullopt; |
169 | } |
170 | |
171 | std::optional<pybind11::object> |
172 | PyGlobals::lookupOperationClass(llvm::StringRef operationName) { |
173 | // Make sure dialect module is loaded. |
174 | auto split = operationName.split(Separator: '.'); |
175 | llvm::StringRef dialectNamespace = split.first; |
176 | if (!loadDialectModule(dialectNamespace)) |
177 | return std::nullopt; |
178 | |
179 | auto foundIt = operationClassMap.find(operationName); |
180 | if (foundIt != operationClassMap.end()) { |
181 | assert(foundIt->second && "OpView is defined" ); |
182 | return foundIt->second; |
183 | } |
184 | // Not found and loading did not yield a registration. |
185 | return std::nullopt; |
186 | } |
187 | |