1 | //===- ExecutionEngineModule.cpp - Python module for execution engine -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
#include "mlir-c/ExecutionEngine.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"

#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
12 | |
13 | namespace nb = nanobind; |
14 | using namespace mlir; |
15 | using namespace mlir::python; |
16 | |
17 | namespace { |
18 | |
19 | /// Owning Wrapper around an ExecutionEngine. |
20 | class PyExecutionEngine { |
21 | public: |
22 | PyExecutionEngine(MlirExecutionEngine executionEngine) |
23 | : executionEngine(executionEngine) {} |
24 | PyExecutionEngine(PyExecutionEngine &&other) noexcept |
25 | : executionEngine(other.executionEngine) { |
26 | other.executionEngine.ptr = nullptr; |
27 | } |
28 | ~PyExecutionEngine() { |
29 | if (!mlirExecutionEngineIsNull(executionEngine)) |
30 | mlirExecutionEngineDestroy(executionEngine); |
31 | } |
32 | MlirExecutionEngine get() { return executionEngine; } |
33 | |
34 | void release() { |
35 | executionEngine.ptr = nullptr; |
36 | referencedObjects.clear(); |
37 | } |
38 | nb::object getCapsule() { |
39 | return nb::steal<nb::object>(mlirPythonExecutionEngineToCapsule(get())); |
40 | } |
41 | |
42 | // Add an object to the list of referenced objects whose lifetime must exceed |
43 | // those of the ExecutionEngine. |
44 | void addReferencedObject(const nb::object &obj) { |
45 | referencedObjects.push_back(obj); |
46 | } |
47 | |
48 | static nb::object createFromCapsule(nb::object capsule) { |
49 | MlirExecutionEngine rawPm = |
50 | mlirPythonCapsuleToExecutionEngine(capsule.ptr()); |
51 | if (mlirExecutionEngineIsNull(rawPm)) |
52 | throw nb::python_error(); |
53 | return nb::cast(PyExecutionEngine(rawPm), nb::rv_policy::move); |
54 | } |
55 | |
56 | private: |
57 | MlirExecutionEngine executionEngine; |
58 | // We support Python ctypes closures as callbacks. Keep a list of the objects |
59 | // so that they don't get garbage collected. (The ExecutionEngine itself |
60 | // just holds raw pointers with no lifetime semantics). |
61 | std::vector<nb::object> referencedObjects; |
62 | }; |
63 | |
64 | } // namespace |
65 | |
66 | /// Create the `mlir.execution_engine` module here. |
67 | NB_MODULE(_mlirExecutionEngine, m) { |
68 | m.doc() = "MLIR Execution Engine" ; |
69 | |
70 | //---------------------------------------------------------------------------- |
71 | // Mapping of the top-level PassManager |
72 | //---------------------------------------------------------------------------- |
73 | nb::class_<PyExecutionEngine>(m, "ExecutionEngine" ) |
74 | .def( |
75 | "__init__" , |
76 | [](PyExecutionEngine &self, MlirModule module, int optLevel, |
77 | const std::vector<std::string> &sharedLibPaths, |
78 | bool enableObjectDump) { |
79 | llvm::SmallVector<MlirStringRef, 4> libPaths; |
80 | for (const std::string &path : sharedLibPaths) |
81 | libPaths.push_back({path.c_str(), path.length()}); |
82 | MlirExecutionEngine executionEngine = |
83 | mlirExecutionEngineCreate(module, optLevel, libPaths.size(), |
84 | libPaths.data(), enableObjectDump); |
85 | if (mlirExecutionEngineIsNull(executionEngine)) |
86 | throw std::runtime_error( |
87 | "Failure while creating the ExecutionEngine." ); |
88 | new (&self) PyExecutionEngine(executionEngine); |
89 | }, |
90 | nb::arg("module" ), nb::arg("opt_level" ) = 2, |
91 | nb::arg("shared_libs" ) = nb::list(), |
92 | nb::arg("enable_object_dump" ) = true, |
93 | "Create a new ExecutionEngine instance for the given Module. The " |
94 | "module must contain only dialects that can be translated to LLVM. " |
95 | "Perform transformations and code generation at the optimization " |
96 | "level `opt_level` if specified, or otherwise at the default " |
97 | "level of two (-O2). Load a list of libraries specified in " |
98 | "`shared_libs`." ) |
99 | .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyExecutionEngine::getCapsule) |
100 | .def("_testing_release" , &PyExecutionEngine::release, |
101 | "Releases (leaks) the backing ExecutionEngine (for testing purpose)" ) |
102 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyExecutionEngine::createFromCapsule) |
103 | .def( |
104 | "raw_lookup" , |
105 | [](PyExecutionEngine &executionEngine, const std::string &func) { |
106 | auto *res = mlirExecutionEngineLookupPacked( |
107 | executionEngine.get(), |
108 | mlirStringRefCreate(func.c_str(), func.size())); |
109 | return reinterpret_cast<uintptr_t>(res); |
110 | }, |
111 | nb::arg("func_name" ), |
112 | "Lookup function `func` in the ExecutionEngine." ) |
113 | .def( |
114 | "raw_register_runtime" , |
115 | [](PyExecutionEngine &executionEngine, const std::string &name, |
116 | nb::object callbackObj) { |
117 | executionEngine.addReferencedObject(callbackObj); |
118 | uintptr_t rawSym = |
119 | nb::cast<uintptr_t>(nb::getattr(callbackObj, "value" )); |
120 | mlirExecutionEngineRegisterSymbol( |
121 | executionEngine.get(), |
122 | mlirStringRefCreate(name.c_str(), name.size()), |
123 | reinterpret_cast<void *>(rawSym)); |
124 | }, |
125 | nb::arg("name" ), nb::arg("callback" ), |
126 | "Register `callback` as the runtime symbol `name`." ) |
127 | .def( |
128 | "dump_to_object_file" , |
129 | [](PyExecutionEngine &executionEngine, const std::string &fileName) { |
130 | mlirExecutionEngineDumpToObjectFile( |
131 | executionEngine.get(), |
132 | mlirStringRefCreate(fileName.c_str(), fileName.size())); |
133 | }, |
134 | nb::arg("file_name" ), "Dump ExecutionEngine to an object file." ); |
135 | } |
136 | |