//===- IRModule.cpp - IR pybind module ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "IRModule.h"
10
11#include <optional>
12#include <vector>
13
14#include "Globals.h"
15#include "NanobindUtils.h"
16#include "mlir-c/Support.h"
17#include "mlir/Bindings/Python/Nanobind.h"
18#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
19
20namespace nb = nanobind;
21using namespace mlir;
22using namespace mlir::python;
23
// -----------------------------------------------------------------------------
// PyGlobals
// -----------------------------------------------------------------------------
27
28PyGlobals *PyGlobals::instance = nullptr;
29
30PyGlobals::PyGlobals() {
31 assert(!instance && "PyGlobals already constructed");
32 instance = this;
33 // The default search path include {mlir.}dialects, where {mlir.} is the
34 // package prefix configured at compile time.
35 dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
36}
37
38PyGlobals::~PyGlobals() { instance = nullptr; }
39
40bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
41 {
42 nb::ft_lock_guard lock(mutex);
43 if (loadedDialectModules.contains(key: dialectNamespace))
44 return true;
45 }
46 // Since re-entrancy is possible, make a copy of the search prefixes.
47 std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
48 nb::object loaded = nb::none();
49 for (std::string moduleName : localSearchPrefixes) {
50 moduleName.push_back(c: '.');
51 moduleName.append(s: dialectNamespace.data(), n: dialectNamespace.size());
52
53 try {
54 loaded = nb::module_::import_(moduleName.c_str());
55 } catch (nb::python_error &e) {
56 if (e.matches(PyExc_ModuleNotFoundError)) {
57 continue;
58 }
59 throw;
60 }
61 break;
62 }
63
64 if (loaded.is_none())
65 return false;
66 // Note: Iterator cannot be shared from prior to loading, since re-entrancy
67 // may have occurred, which may do anything.
68 nb::ft_lock_guard lock(mutex);
69 loadedDialectModules.insert(key: dialectNamespace);
70 return true;
71}
72
73void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
74 nb::callable pyFunc, bool replace) {
75 nb::ft_lock_guard lock(mutex);
76 nb::object &found = attributeBuilderMap[attributeKind];
77 if (found && !replace) {
78 throw std::runtime_error((llvm::Twine("Attribute builder for '") +
79 attributeKind +
80 "' is already registered with func: " +
81 nb::cast<std::string>(nb::str(found)))
82 .str());
83 }
84 found = std::move(pyFunc);
85}
86
87void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
88 nb::callable typeCaster, bool replace) {
89 nb::ft_lock_guard lock(mutex);
90 nb::object &found = typeCasterMap[mlirTypeID];
91 if (found && !replace)
92 throw std::runtime_error("Type caster is already registered with caster: " +
93 nb::cast<std::string>(nb::str(found)));
94 found = std::move(typeCaster);
95}
96
97void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
98 nb::callable valueCaster, bool replace) {
99 nb::ft_lock_guard lock(mutex);
100 nb::object &found = valueCasterMap[mlirTypeID];
101 if (found && !replace)
102 throw std::runtime_error("Value caster is already registered: " +
103 nb::cast<std::string>(nb::repr(found)));
104 found = std::move(valueCaster);
105}
106
107void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
108 nb::object pyClass) {
109 nb::ft_lock_guard lock(mutex);
110 nb::object &found = dialectClassMap[dialectNamespace];
111 if (found) {
112 throw std::runtime_error((llvm::Twine("Dialect namespace '") +
113 dialectNamespace + "' is already registered.")
114 .str());
115 }
116 found = std::move(pyClass);
117}
118
119void PyGlobals::registerOperationImpl(const std::string &operationName,
120 nb::object pyClass, bool replace) {
121 nb::ft_lock_guard lock(mutex);
122 nb::object &found = operationClassMap[operationName];
123 if (found && !replace) {
124 throw std::runtime_error((llvm::Twine("Operation '") + operationName +
125 "' is already registered.")
126 .str());
127 }
128 found = std::move(pyClass);
129}
130
131std::optional<nb::callable>
132PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
133 nb::ft_lock_guard lock(mutex);
134 const auto foundIt = attributeBuilderMap.find(attributeKind);
135 if (foundIt != attributeBuilderMap.end()) {
136 assert(foundIt->second && "attribute builder is defined");
137 return foundIt->second;
138 }
139 return std::nullopt;
140}
141
142std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
143 MlirDialect dialect) {
144 // Try to load dialect module.
145 (void)loadDialectModule(dialectNamespace: unwrap(mlirDialectGetNamespace(dialect)));
146 nb::ft_lock_guard lock(mutex);
147 const auto foundIt = typeCasterMap.find(mlirTypeID);
148 if (foundIt != typeCasterMap.end()) {
149 assert(foundIt->second && "type caster is defined");
150 return foundIt->second;
151 }
152 return std::nullopt;
153}
154
155std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
156 MlirDialect dialect) {
157 // Try to load dialect module.
158 (void)loadDialectModule(dialectNamespace: unwrap(mlirDialectGetNamespace(dialect)));
159 nb::ft_lock_guard lock(mutex);
160 const auto foundIt = valueCasterMap.find(mlirTypeID);
161 if (foundIt != valueCasterMap.end()) {
162 assert(foundIt->second && "value caster is defined");
163 return foundIt->second;
164 }
165 return std::nullopt;
166}
167
168std::optional<nb::object>
169PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
170 // Make sure dialect module is loaded.
171 if (!loadDialectModule(dialectNamespace))
172 return std::nullopt;
173 nb::ft_lock_guard lock(mutex);
174 const auto foundIt = dialectClassMap.find(dialectNamespace);
175 if (foundIt != dialectClassMap.end()) {
176 assert(foundIt->second && "dialect class is defined");
177 return foundIt->second;
178 }
179 // Not found and loading did not yield a registration.
180 return std::nullopt;
181}
182
183std::optional<nb::object>
184PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
185 // Make sure dialect module is loaded.
186 auto split = operationName.split(Separator: '.');
187 llvm::StringRef dialectNamespace = split.first;
188 if (!loadDialectModule(dialectNamespace))
189 return std::nullopt;
190
191 nb::ft_lock_guard lock(mutex);
192 auto foundIt = operationClassMap.find(operationName);
193 if (foundIt != operationClassMap.end()) {
194 assert(foundIt->second && "OpView is defined");
195 return foundIt->second;
196 }
197 // Not found and loading did not yield a registration.
198 return std::nullopt;
199}
200

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Bindings/Python/IRModule.cpp