1//===- Globals.h - MLIR Python extension globals --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
10#define MLIR_BINDINGS_PYTHON_GLOBALS_H
11
#include "PybindUtils.h"

#include "mlir-c/IR.h"
#include "mlir/CAPI/Support.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"

#include <cassert>
#include <optional>
#include <string>
#include <vector>
23
24namespace mlir {
25namespace python {
26
27/// Globals that are always accessible once the extension has been initialized.
28class PyGlobals {
29public:
30 PyGlobals();
31 ~PyGlobals();
32
33 /// Most code should get the globals via this static accessor.
34 static PyGlobals &get() {
35 assert(instance && "PyGlobals is null");
36 return *instance;
37 }
38
39 /// Get and set the list of parent modules to search for dialect
40 /// implementation classes.
41 std::vector<std::string> &getDialectSearchPrefixes() {
42 return dialectSearchPrefixes;
43 }
44 void setDialectSearchPrefixes(std::vector<std::string> newValues) {
45 dialectSearchPrefixes.swap(x&: newValues);
46 }
47
48 /// Loads a python module corresponding to the given dialect namespace.
49 /// No-ops if the module has already been loaded or is not found. Raises
50 /// an error on any evaluation issues.
51 /// Note that this returns void because it is expected that the module
52 /// contains calls to decorators and helpers that register the salient
53 /// entities. Returns true if dialect is successfully loaded.
54 bool loadDialectModule(llvm::StringRef dialectNamespace);
55
56 /// Adds a user-friendly Attribute builder.
57 /// Raises an exception if the mapping already exists and replace == false.
58 /// This is intended to be called by implementation code.
59 void registerAttributeBuilder(const std::string &attributeKind,
60 pybind11::function pyFunc,
61 bool replace = false);
62
63 /// Adds a user-friendly type caster. Raises an exception if the mapping
64 /// already exists and replace == false. This is intended to be called by
65 /// implementation code.
66 void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
67 bool replace = false);
68
69 /// Adds a user-friendly value caster. Raises an exception if the mapping
70 /// already exists and replace == false. This is intended to be called by
71 /// implementation code.
72 void registerValueCaster(MlirTypeID mlirTypeID,
73 pybind11::function valueCaster,
74 bool replace = false);
75
76 /// Adds a concrete implementation dialect class.
77 /// Raises an exception if the mapping already exists.
78 /// This is intended to be called by implementation code.
79 void registerDialectImpl(const std::string &dialectNamespace,
80 pybind11::object pyClass);
81
82 /// Adds a concrete implementation operation class.
83 /// Raises an exception if the mapping already exists and replace == false.
84 /// This is intended to be called by implementation code.
85 void registerOperationImpl(const std::string &operationName,
86 pybind11::object pyClass, bool replace = false);
87
88 /// Returns the custom Attribute builder for Attribute kind.
89 std::optional<pybind11::function>
90 lookupAttributeBuilder(const std::string &attributeKind);
91
92 /// Returns the custom type caster for MlirTypeID mlirTypeID.
93 std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
94 MlirDialect dialect);
95
96 /// Returns the custom value caster for MlirTypeID mlirTypeID.
97 std::optional<pybind11::function> lookupValueCaster(MlirTypeID mlirTypeID,
98 MlirDialect dialect);
99
100 /// Looks up a registered dialect class by namespace. Note that this may
101 /// trigger loading of the defining module and can arbitrarily re-enter.
102 std::optional<pybind11::object>
103 lookupDialectClass(const std::string &dialectNamespace);
104
105 /// Looks up a registered operation class (deriving from OpView) by operation
106 /// name. Note that this may trigger a load of the dialect, which can
107 /// arbitrarily re-enter.
108 std::optional<pybind11::object>
109 lookupOperationClass(llvm::StringRef operationName);
110
111private:
112 static PyGlobals *instance;
113 /// Module name prefixes to search under for dialect implementation modules.
114 std::vector<std::string> dialectSearchPrefixes;
115 /// Map of dialect namespace to external dialect class object.
116 llvm::StringMap<pybind11::object> dialectClassMap;
117 /// Map of full operation name to external operation class object.
118 llvm::StringMap<pybind11::object> operationClassMap;
119 /// Map of attribute ODS name to custom builder.
120 llvm::StringMap<pybind11::object> attributeBuilderMap;
121 /// Map of MlirTypeID to custom type caster.
122 llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
123 /// Map of MlirTypeID to custom value caster.
124 llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMap;
125 /// Set of dialect namespaces that we have attempted to import implementation
126 /// modules for.
127 llvm::StringSet<> loadedDialectModules;
128};
129
130} // namespace python
131} // namespace mlir
132
133#endif // MLIR_BINDINGS_PYTHON_GLOBALS_H
134

source code of mlir/lib/Bindings/Python/Globals.h