1 | //===- Pass.cpp - Pass Management -----------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "Pass.h" |
10 | |
11 | #include "IRModule.h" |
12 | #include "mlir-c/Pass.h" |
13 | #include "mlir/Bindings/Python/Nanobind.h" |
14 | #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. |
15 | |
16 | namespace nb = nanobind; |
17 | using namespace nb::literals; |
18 | using namespace mlir; |
19 | using namespace mlir::python; |
20 | |
21 | namespace { |
22 | |
23 | /// Owning Wrapper around a PassManager. |
24 | class PyPassManager { |
25 | public: |
26 | PyPassManager(MlirPassManager passManager) : passManager(passManager) {} |
27 | PyPassManager(PyPassManager &&other) noexcept |
28 | : passManager(other.passManager) { |
29 | other.passManager.ptr = nullptr; |
30 | } |
31 | ~PyPassManager() { |
32 | if (!mlirPassManagerIsNull(passManager)) |
33 | mlirPassManagerDestroy(passManager); |
34 | } |
35 | MlirPassManager get() { return passManager; } |
36 | |
37 | void release() { passManager.ptr = nullptr; } |
38 | nb::object getCapsule() { |
39 | return nb::steal<nb::object>(mlirPythonPassManagerToCapsule(get())); |
40 | } |
41 | |
42 | static nb::object createFromCapsule(nb::object capsule) { |
43 | MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr()); |
44 | if (mlirPassManagerIsNull(rawPm)) |
45 | throw nb::python_error(); |
46 | return nb::cast(PyPassManager(rawPm), nb::rv_policy::move); |
47 | } |
48 | |
49 | private: |
50 | MlirPassManager passManager; |
51 | }; |
52 | |
53 | } // namespace |
54 | |
55 | /// Create the `mlir.passmanager` here. |
56 | void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { |
57 | //---------------------------------------------------------------------------- |
58 | // Mapping of the top-level PassManager |
59 | //---------------------------------------------------------------------------- |
60 | nb::class_<PyPassManager>(m, "PassManager" ) |
61 | .def( |
62 | "__init__" , |
63 | [](PyPassManager &self, const std::string &anchorOp, |
64 | DefaultingPyMlirContext context) { |
65 | MlirPassManager passManager = mlirPassManagerCreateOnOperation( |
66 | context->get(), |
67 | mlirStringRefCreate(anchorOp.data(), anchorOp.size())); |
68 | new (&self) PyPassManager(passManager); |
69 | }, |
70 | "anchor_op"_a = nb::str("any" ), "context"_a .none() = nb::none(), |
71 | "Create a new PassManager for the current (or provided) Context." ) |
72 | .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule) |
73 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule) |
74 | .def("_testing_release" , &PyPassManager::release, |
75 | "Releases (leaks) the backing pass manager (testing)" ) |
76 | .def( |
77 | "enable_ir_printing" , |
78 | [](PyPassManager &passManager, bool printBeforeAll, |
79 | bool printAfterAll, bool printModuleScope, bool printAfterChange, |
80 | bool printAfterFailure, std::optional<int64_t> largeElementsLimit, |
81 | bool enableDebugInfo, bool printGenericOpForm, |
82 | std::optional<std::string> optionalTreePrintingPath) { |
83 | MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); |
84 | if (largeElementsLimit) |
85 | mlirOpPrintingFlagsElideLargeElementsAttrs(flags, |
86 | *largeElementsLimit); |
87 | if (enableDebugInfo) |
88 | mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, |
89 | /*prettyForm=*/false); |
90 | if (printGenericOpForm) |
91 | mlirOpPrintingFlagsPrintGenericOpForm(flags); |
92 | std::string treePrintingPath = "" ; |
93 | if (optionalTreePrintingPath.has_value()) |
94 | treePrintingPath = optionalTreePrintingPath.value(); |
95 | mlirPassManagerEnableIRPrinting( |
96 | passManager.get(), printBeforeAll, printAfterAll, |
97 | printModuleScope, printAfterChange, printAfterFailure, flags, |
98 | mlirStringRefCreate(treePrintingPath.data(), |
99 | treePrintingPath.size())); |
100 | mlirOpPrintingFlagsDestroy(flags); |
101 | }, |
102 | "print_before_all"_a = false, "print_after_all"_a = true, |
103 | "print_module_scope"_a = false, "print_after_change"_a = false, |
104 | "print_after_failure"_a = false, |
105 | "large_elements_limit"_a .none() = nb::none(), |
106 | "enable_debug_info"_a = false, "print_generic_op_form"_a = false, |
107 | "tree_printing_dir_path"_a .none() = nb::none(), |
108 | "Enable IR printing, default as mlir-print-ir-after-all." ) |
109 | .def( |
110 | "enable_verifier" , |
111 | [](PyPassManager &passManager, bool enable) { |
112 | mlirPassManagerEnableVerifier(passManager.get(), enable); |
113 | }, |
114 | "enable"_a , "Enable / disable verify-each." ) |
115 | .def_static( |
116 | "parse" , |
117 | [](const std::string &pipeline, DefaultingPyMlirContext context) { |
118 | MlirPassManager passManager = mlirPassManagerCreate(context->get()); |
119 | PyPrintAccumulator errorMsg; |
120 | MlirLogicalResult status = mlirParsePassPipeline( |
121 | mlirPassManagerGetAsOpPassManager(passManager), |
122 | mlirStringRefCreate(pipeline.data(), pipeline.size()), |
123 | errorMsg.getCallback(), errorMsg.getUserData()); |
124 | if (mlirLogicalResultIsFailure(status)) |
125 | throw nb::value_error(errorMsg.join().c_str()); |
126 | return new PyPassManager(passManager); |
127 | }, |
128 | "pipeline"_a , "context"_a .none() = nb::none(), |
129 | "Parse a textual pass-pipeline and return a top-level PassManager " |
130 | "that can be applied on a Module. Throw a ValueError if the pipeline " |
131 | "can't be parsed" ) |
132 | .def( |
133 | "add" , |
134 | [](PyPassManager &passManager, const std::string &pipeline) { |
135 | PyPrintAccumulator errorMsg; |
136 | MlirLogicalResult status = mlirOpPassManagerAddPipeline( |
137 | mlirPassManagerGetAsOpPassManager(passManager.get()), |
138 | mlirStringRefCreate(pipeline.data(), pipeline.size()), |
139 | errorMsg.getCallback(), errorMsg.getUserData()); |
140 | if (mlirLogicalResultIsFailure(status)) |
141 | throw nb::value_error(errorMsg.join().c_str()); |
142 | }, |
143 | "pipeline"_a , |
144 | "Add textual pipeline elements to the pass manager. Throws a " |
145 | "ValueError if the pipeline can't be parsed." ) |
146 | .def( |
147 | "run" , |
148 | [](PyPassManager &passManager, PyOperationBase &op, |
149 | bool invalidateOps) { |
150 | if (invalidateOps) { |
151 | op.getOperation().getContext()->clearOperationsInside(op); |
152 | } |
153 | // Actually run the pass manager. |
154 | PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); |
155 | MlirLogicalResult status = mlirPassManagerRunOnOp( |
156 | passManager.get(), op.getOperation().get()); |
157 | if (mlirLogicalResultIsFailure(status)) |
158 | throw MLIRError("Failure while executing pass pipeline" , |
159 | errors.take()); |
160 | }, |
161 | "operation"_a , "invalidate_ops"_a = true, |
162 | "Run the pass manager on the provided operation, raising an " |
163 | "MLIRError on failure." ) |
164 | .def( |
165 | "__str__" , |
166 | [](PyPassManager &self) { |
167 | MlirPassManager passManager = self.get(); |
168 | PyPrintAccumulator printAccum; |
169 | mlirPrintPassPipeline( |
170 | mlirPassManagerGetAsOpPassManager(passManager), |
171 | printAccum.getCallback(), printAccum.getUserData()); |
172 | return printAccum.join(); |
173 | }, |
174 | "Print the textual representation for this PassManager, suitable to " |
175 | "be passed to `parse` for round-tripping." ); |
176 | } |
177 | |