1 | //===- Pass.cpp - Pass Management -----------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
#include "Pass.h"

#include "IRModule.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Pass.h"

#include <string>
#include <utility>
14 | |
15 | namespace py = pybind11; |
16 | using namespace py::literals; |
17 | using namespace mlir; |
18 | using namespace mlir::python; |
19 | |
20 | namespace { |
21 | |
22 | /// Owning Wrapper around a PassManager. |
23 | class PyPassManager { |
24 | public: |
25 | PyPassManager(MlirPassManager passManager) : passManager(passManager) {} |
26 | PyPassManager(PyPassManager &&other) noexcept |
27 | : passManager(other.passManager) { |
28 | other.passManager.ptr = nullptr; |
29 | } |
30 | ~PyPassManager() { |
31 | if (!mlirPassManagerIsNull(passManager)) |
32 | mlirPassManagerDestroy(passManager); |
33 | } |
34 | MlirPassManager get() { return passManager; } |
35 | |
36 | void release() { passManager.ptr = nullptr; } |
37 | pybind11::object getCapsule() { |
38 | return py::reinterpret_steal<py::object>( |
39 | mlirPythonPassManagerToCapsule(get())); |
40 | } |
41 | |
42 | static pybind11::object createFromCapsule(pybind11::object capsule) { |
43 | MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr()); |
44 | if (mlirPassManagerIsNull(rawPm)) |
45 | throw py::error_already_set(); |
46 | return py::cast(PyPassManager(rawPm), py::return_value_policy::move); |
47 | } |
48 | |
49 | private: |
50 | MlirPassManager passManager; |
51 | }; |
52 | |
53 | } // namespace |
54 | |
55 | /// Create the `mlir.passmanager` here. |
56 | void mlir::python::populatePassManagerSubmodule(py::module &m) { |
57 | //---------------------------------------------------------------------------- |
58 | // Mapping of the top-level PassManager |
59 | //---------------------------------------------------------------------------- |
60 | py::class_<PyPassManager>(m, "PassManager" , py::module_local()) |
61 | .def(py::init<>([](const std::string &anchorOp, |
62 | DefaultingPyMlirContext context) { |
63 | MlirPassManager passManager = mlirPassManagerCreateOnOperation( |
64 | context->get(), |
65 | mlirStringRefCreate(anchorOp.data(), anchorOp.size())); |
66 | return new PyPassManager(passManager); |
67 | }), |
68 | "anchor_op"_a = py::str("any" ), "context"_a = py::none(), |
69 | "Create a new PassManager for the current (or provided) Context." ) |
70 | .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, |
71 | &PyPassManager::getCapsule) |
72 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule) |
73 | .def("_testing_release" , &PyPassManager::release, |
74 | "Releases (leaks) the backing pass manager (testing)" ) |
75 | .def( |
76 | "enable_ir_printing" , |
77 | [](PyPassManager &passManager) { |
78 | mlirPassManagerEnableIRPrinting(passManager.get()); |
79 | }, |
80 | "Enable mlir-print-ir-after-all." ) |
81 | .def( |
82 | "enable_verifier" , |
83 | [](PyPassManager &passManager, bool enable) { |
84 | mlirPassManagerEnableVerifier(passManager.get(), enable); |
85 | }, |
86 | "enable"_a , "Enable / disable verify-each." ) |
87 | .def_static( |
88 | "parse" , |
89 | [](const std::string &pipeline, DefaultingPyMlirContext context) { |
90 | MlirPassManager passManager = mlirPassManagerCreate(context->get()); |
91 | PyPrintAccumulator errorMsg; |
92 | MlirLogicalResult status = mlirParsePassPipeline( |
93 | mlirPassManagerGetAsOpPassManager(passManager), |
94 | mlirStringRefCreate(pipeline.data(), pipeline.size()), |
95 | errorMsg.getCallback(), errorMsg.getUserData()); |
96 | if (mlirLogicalResultIsFailure(status)) |
97 | throw py::value_error(std::string(errorMsg.join())); |
98 | return new PyPassManager(passManager); |
99 | }, |
100 | "pipeline"_a , "context"_a = py::none(), |
101 | "Parse a textual pass-pipeline and return a top-level PassManager " |
102 | "that can be applied on a Module. Throw a ValueError if the pipeline " |
103 | "can't be parsed" ) |
104 | .def( |
105 | "add" , |
106 | [](PyPassManager &passManager, const std::string &pipeline) { |
107 | PyPrintAccumulator errorMsg; |
108 | MlirLogicalResult status = mlirOpPassManagerAddPipeline( |
109 | mlirPassManagerGetAsOpPassManager(passManager.get()), |
110 | mlirStringRefCreate(pipeline.data(), pipeline.size()), |
111 | errorMsg.getCallback(), errorMsg.getUserData()); |
112 | if (mlirLogicalResultIsFailure(status)) |
113 | throw py::value_error(std::string(errorMsg.join())); |
114 | }, |
115 | "pipeline"_a , |
116 | "Add textual pipeline elements to the pass manager. Throws a " |
117 | "ValueError if the pipeline can't be parsed." ) |
118 | .def( |
119 | "run" , |
120 | [](PyPassManager &passManager, PyOperationBase &op, |
121 | bool invalidateOps) { |
122 | if (invalidateOps) { |
123 | op.getOperation().getContext()->clearOperationsInside(op); |
124 | } |
125 | // Actually run the pass manager. |
126 | PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); |
127 | MlirLogicalResult status = mlirPassManagerRunOnOp( |
128 | passManager.get(), op.getOperation().get()); |
129 | if (mlirLogicalResultIsFailure(status)) |
130 | throw MLIRError("Failure while executing pass pipeline" , |
131 | errors.take()); |
132 | }, |
133 | "operation"_a , "invalidate_ops"_a = true, |
134 | "Run the pass manager on the provided operation, raising an " |
135 | "MLIRError on failure." ) |
136 | .def( |
137 | "__str__" , |
138 | [](PyPassManager &self) { |
139 | MlirPassManager passManager = self.get(); |
140 | PyPrintAccumulator printAccum; |
141 | mlirPrintPassPipeline( |
142 | mlirPassManagerGetAsOpPassManager(passManager), |
143 | printAccum.getCallback(), printAccum.getUserData()); |
144 | return printAccum.join(); |
145 | }, |
146 | "Print the textual representation for this PassManager, suitable to " |
147 | "be passed to `parse` for round-tripping." ); |
148 | } |
149 | |