1//===- Pass.cpp - Pass Management -----------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "Pass.h"

#include "IRModule.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Pass.h"

#include <memory>
#include <string>

namespace py = pybind11;
using namespace py::literals;
using namespace mlir;
using namespace mlir::python;
19
20namespace {
21
22/// Owning Wrapper around a PassManager.
23class PyPassManager {
24public:
25 PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
26 PyPassManager(PyPassManager &&other) noexcept
27 : passManager(other.passManager) {
28 other.passManager.ptr = nullptr;
29 }
30 ~PyPassManager() {
31 if (!mlirPassManagerIsNull(passManager))
32 mlirPassManagerDestroy(passManager);
33 }
34 MlirPassManager get() { return passManager; }
35
36 void release() { passManager.ptr = nullptr; }
37 pybind11::object getCapsule() {
38 return py::reinterpret_steal<py::object>(
39 mlirPythonPassManagerToCapsule(get()));
40 }
41
42 static pybind11::object createFromCapsule(pybind11::object capsule) {
43 MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
44 if (mlirPassManagerIsNull(rawPm))
45 throw py::error_already_set();
46 return py::cast(PyPassManager(rawPm), py::return_value_policy::move);
47 }
48
49private:
50 MlirPassManager passManager;
51};
52
53} // namespace
54
55/// Create the `mlir.passmanager` here.
56void mlir::python::populatePassManagerSubmodule(py::module &m) {
57 //----------------------------------------------------------------------------
58 // Mapping of the top-level PassManager
59 //----------------------------------------------------------------------------
60 py::class_<PyPassManager>(m, "PassManager", py::module_local())
61 .def(py::init<>([](const std::string &anchorOp,
62 DefaultingPyMlirContext context) {
63 MlirPassManager passManager = mlirPassManagerCreateOnOperation(
64 context->get(),
65 mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
66 return new PyPassManager(passManager);
67 }),
68 "anchor_op"_a = py::str("any"), "context"_a = py::none(),
69 "Create a new PassManager for the current (or provided) Context.")
70 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
71 &PyPassManager::getCapsule)
72 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
73 .def("_testing_release", &PyPassManager::release,
74 "Releases (leaks) the backing pass manager (testing)")
75 .def(
76 "enable_ir_printing",
77 [](PyPassManager &passManager) {
78 mlirPassManagerEnableIRPrinting(passManager.get());
79 },
80 "Enable mlir-print-ir-after-all.")
81 .def(
82 "enable_verifier",
83 [](PyPassManager &passManager, bool enable) {
84 mlirPassManagerEnableVerifier(passManager.get(), enable);
85 },
86 "enable"_a, "Enable / disable verify-each.")
87 .def_static(
88 "parse",
89 [](const std::string &pipeline, DefaultingPyMlirContext context) {
90 MlirPassManager passManager = mlirPassManagerCreate(context->get());
91 PyPrintAccumulator errorMsg;
92 MlirLogicalResult status = mlirParsePassPipeline(
93 mlirPassManagerGetAsOpPassManager(passManager),
94 mlirStringRefCreate(pipeline.data(), pipeline.size()),
95 errorMsg.getCallback(), errorMsg.getUserData());
96 if (mlirLogicalResultIsFailure(status))
97 throw py::value_error(std::string(errorMsg.join()));
98 return new PyPassManager(passManager);
99 },
100 "pipeline"_a, "context"_a = py::none(),
101 "Parse a textual pass-pipeline and return a top-level PassManager "
102 "that can be applied on a Module. Throw a ValueError if the pipeline "
103 "can't be parsed")
104 .def(
105 "add",
106 [](PyPassManager &passManager, const std::string &pipeline) {
107 PyPrintAccumulator errorMsg;
108 MlirLogicalResult status = mlirOpPassManagerAddPipeline(
109 mlirPassManagerGetAsOpPassManager(passManager.get()),
110 mlirStringRefCreate(pipeline.data(), pipeline.size()),
111 errorMsg.getCallback(), errorMsg.getUserData());
112 if (mlirLogicalResultIsFailure(status))
113 throw py::value_error(std::string(errorMsg.join()));
114 },
115 "pipeline"_a,
116 "Add textual pipeline elements to the pass manager. Throws a "
117 "ValueError if the pipeline can't be parsed.")
118 .def(
119 "run",
120 [](PyPassManager &passManager, PyOperationBase &op,
121 bool invalidateOps) {
122 if (invalidateOps) {
123 op.getOperation().getContext()->clearOperationsInside(op);
124 }
125 // Actually run the pass manager.
126 PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
127 MlirLogicalResult status = mlirPassManagerRunOnOp(
128 passManager.get(), op.getOperation().get());
129 if (mlirLogicalResultIsFailure(status))
130 throw MLIRError("Failure while executing pass pipeline",
131 errors.take());
132 },
133 "operation"_a, "invalidate_ops"_a = true,
134 "Run the pass manager on the provided operation, raising an "
135 "MLIRError on failure.")
136 .def(
137 "__str__",
138 [](PyPassManager &self) {
139 MlirPassManager passManager = self.get();
140 PyPrintAccumulator printAccum;
141 mlirPrintPassPipeline(
142 mlirPassManagerGetAsOpPassManager(passManager),
143 printAccum.getCallback(), printAccum.getUserData());
144 return printAccum.join();
145 },
146 "Print the textual representation for this PassManager, suitable to "
147 "be passed to `parse` for round-tripping.");
148}
149

// Source: mlir/lib/Bindings/Python/Pass.cpp