1 | //===- ExecutionEngineModule.cpp - Python module for execution engine -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
#include "mlir-c/ExecutionEngine.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"

#include <utility>
11 | |
12 | namespace py = pybind11; |
13 | using namespace mlir; |
14 | using namespace mlir::python; |
15 | |
16 | namespace { |
17 | |
18 | /// Owning Wrapper around an ExecutionEngine. |
19 | class PyExecutionEngine { |
20 | public: |
21 | PyExecutionEngine(MlirExecutionEngine executionEngine) |
22 | : executionEngine(executionEngine) {} |
23 | PyExecutionEngine(PyExecutionEngine &&other) noexcept |
24 | : executionEngine(other.executionEngine) { |
25 | other.executionEngine.ptr = nullptr; |
26 | } |
27 | ~PyExecutionEngine() { |
28 | if (!mlirExecutionEngineIsNull(executionEngine)) |
29 | mlirExecutionEngineDestroy(executionEngine); |
30 | } |
31 | MlirExecutionEngine get() { return executionEngine; } |
32 | |
33 | void release() { |
34 | executionEngine.ptr = nullptr; |
35 | referencedObjects.clear(); |
36 | } |
37 | pybind11::object getCapsule() { |
38 | return py::reinterpret_steal<py::object>( |
39 | mlirPythonExecutionEngineToCapsule(get())); |
40 | } |
41 | |
42 | // Add an object to the list of referenced objects whose lifetime must exceed |
43 | // those of the ExecutionEngine. |
44 | void addReferencedObject(const pybind11::object &obj) { |
45 | referencedObjects.push_back(obj); |
46 | } |
47 | |
48 | static pybind11::object createFromCapsule(pybind11::object capsule) { |
49 | MlirExecutionEngine rawPm = |
50 | mlirPythonCapsuleToExecutionEngine(capsule.ptr()); |
51 | if (mlirExecutionEngineIsNull(rawPm)) |
52 | throw py::error_already_set(); |
53 | return py::cast(PyExecutionEngine(rawPm), py::return_value_policy::move); |
54 | } |
55 | |
56 | private: |
57 | MlirExecutionEngine executionEngine; |
58 | // We support Python ctypes closures as callbacks. Keep a list of the objects |
59 | // so that they don't get garbage collected. (The ExecutionEngine itself |
60 | // just holds raw pointers with no lifetime semantics). |
61 | std::vector<py::object> referencedObjects; |
62 | }; |
63 | |
64 | } // namespace |
65 | |
66 | /// Create the `mlir.execution_engine` module here. |
67 | PYBIND11_MODULE(_mlirExecutionEngine, m) { |
68 | m.doc() = "MLIR Execution Engine" ; |
69 | |
70 | //---------------------------------------------------------------------------- |
71 | // Mapping of the top-level PassManager |
72 | //---------------------------------------------------------------------------- |
73 | py::class_<PyExecutionEngine>(m, "ExecutionEngine" , py::module_local()) |
74 | .def(py::init<>([](MlirModule module, int optLevel, |
75 | const std::vector<std::string> &sharedLibPaths, |
76 | bool enableObjectDump) { |
77 | llvm::SmallVector<MlirStringRef, 4> libPaths; |
78 | for (const std::string &path : sharedLibPaths) |
79 | libPaths.push_back({path.c_str(), path.length()}); |
80 | MlirExecutionEngine executionEngine = |
81 | mlirExecutionEngineCreate(module, optLevel, libPaths.size(), |
82 | libPaths.data(), enableObjectDump); |
83 | if (mlirExecutionEngineIsNull(executionEngine)) |
84 | throw std::runtime_error( |
85 | "Failure while creating the ExecutionEngine." ); |
86 | return new PyExecutionEngine(executionEngine); |
87 | }), |
88 | py::arg("module" ), py::arg("opt_level" ) = 2, |
89 | py::arg("shared_libs" ) = py::list(), |
90 | py::arg("enable_object_dump" ) = true, |
91 | "Create a new ExecutionEngine instance for the given Module. The " |
92 | "module must contain only dialects that can be translated to LLVM. " |
93 | "Perform transformations and code generation at the optimization " |
94 | "level `opt_level` if specified, or otherwise at the default " |
95 | "level of two (-O2). Load a list of libraries specified in " |
96 | "`shared_libs`." ) |
97 | .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, |
98 | &PyExecutionEngine::getCapsule) |
99 | .def("_testing_release" , &PyExecutionEngine::release, |
100 | "Releases (leaks) the backing ExecutionEngine (for testing purpose)" ) |
101 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyExecutionEngine::createFromCapsule) |
102 | .def( |
103 | "raw_lookup" , |
104 | [](PyExecutionEngine &executionEngine, const std::string &func) { |
105 | auto *res = mlirExecutionEngineLookupPacked( |
106 | executionEngine.get(), |
107 | mlirStringRefCreate(func.c_str(), func.size())); |
108 | return reinterpret_cast<uintptr_t>(res); |
109 | }, |
110 | py::arg("func_name" ), |
111 | "Lookup function `func` in the ExecutionEngine." ) |
112 | .def( |
113 | "raw_register_runtime" , |
114 | [](PyExecutionEngine &executionEngine, const std::string &name, |
115 | py::object callbackObj) { |
116 | executionEngine.addReferencedObject(callbackObj); |
117 | uintptr_t rawSym = |
118 | py::cast<uintptr_t>(py::getattr(callbackObj, "value" )); |
119 | mlirExecutionEngineRegisterSymbol( |
120 | executionEngine.get(), |
121 | mlirStringRefCreate(name.c_str(), name.size()), |
122 | reinterpret_cast<void *>(rawSym)); |
123 | }, |
124 | py::arg("name" ), py::arg("callback" ), |
125 | "Register `callback` as the runtime symbol `name`." ) |
126 | .def( |
127 | "dump_to_object_file" , |
128 | [](PyExecutionEngine &executionEngine, const std::string &fileName) { |
129 | mlirExecutionEngineDumpToObjectFile( |
130 | executionEngine.get(), |
131 | mlirStringRefCreate(fileName.c_str(), fileName.size())); |
132 | }, |
133 | py::arg("file_name" ), "Dump ExecutionEngine to an object file." ); |
134 | } |
135 | |