1 | //===- Rewrite.cpp - Rewrite ----------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "Rewrite.h" |
10 | |
11 | #include "IRModule.h" |
12 | #include "mlir-c/Rewrite.h" |
13 | #include "mlir/Bindings/Python/Nanobind.h" |
14 | #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. |
15 | #include "mlir/Config/mlir-config.h" |
16 | |
17 | namespace nb = nanobind; |
18 | using namespace mlir; |
19 | using namespace nb::literals; |
20 | using namespace mlir::python; |
21 | |
22 | namespace { |
23 | |
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
/// Owning wrapper around an MlirPDLPatternModule handle.
///
/// The wrapper is move-only. On destruction it calls
/// mlirPDLPatternModuleDestroy on the wrapped handle, unless ownership has
/// been moved out. A moved-from instance holds a null handle and destroys
/// nothing.
class PyPDLPatternModule {
public:
  /// Takes ownership of `module`.
  explicit PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}

  /// Transfers ownership from `other`, leaving it with a null handle so its
  /// destructor becomes a no-op.
  PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
      : module(other.module) {
    other.module.ptr = nullptr;
  }

  // A copy would destroy the same handle twice. No assignment is used in this
  // file, so only move construction is provided.
  PyPDLPatternModule(const PyPDLPatternModule &) = delete;
  PyPDLPatternModule &operator=(const PyPDLPatternModule &) = delete;
  PyPDLPatternModule &operator=(PyPDLPatternModule &&) = delete;

  ~PyPDLPatternModule() {
    if (module.ptr != nullptr)
      mlirPDLPatternModuleDestroy(module);
  }

  /// Returns the wrapped handle without transferring ownership.
  MlirPDLPatternModule get() const { return module; }

private:
  MlirPDLPatternModule module;
};
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
43 | |
44 | /// Owning Wrapper around a FrozenRewritePatternSet. |
45 | class PyFrozenRewritePatternSet { |
46 | public: |
47 | PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {} |
48 | PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept |
49 | : set(other.set) { |
50 | other.set.ptr = nullptr; |
51 | } |
52 | ~PyFrozenRewritePatternSet() { |
53 | if (set.ptr != nullptr) |
54 | mlirFrozenRewritePatternSetDestroy(op: set); |
55 | } |
56 | MlirFrozenRewritePatternSet get() { return set; } |
57 | |
58 | nb::object getCapsule() { |
59 | return nb::steal<nb::object>( |
60 | mlirPythonFrozenRewritePatternSetToCapsule(get())); |
61 | } |
62 | |
63 | static nb::object createFromCapsule(nb::object capsule) { |
64 | MlirFrozenRewritePatternSet rawPm = |
65 | mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); |
66 | if (rawPm.ptr == nullptr) |
67 | throw nb::python_error(); |
68 | return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move); |
69 | } |
70 | |
71 | private: |
72 | MlirFrozenRewritePatternSet set; |
73 | }; |
74 | |
75 | } // namespace |
76 | |
77 | /// Create the `mlir.rewrite` here. |
78 | void mlir::python::populateRewriteSubmodule(nb::module_ &m) { |
79 | //---------------------------------------------------------------------------- |
80 | // Mapping of the top-level PassManager |
81 | //---------------------------------------------------------------------------- |
82 | #if MLIR_ENABLE_PDL_IN_PATTERNMATCH |
83 | nb::class_<PyPDLPatternModule>(m, "PDLModule" ) |
84 | .def( |
85 | "__init__" , |
86 | [](PyPDLPatternModule &self, MlirModule module) { |
87 | new (&self) |
88 | PyPDLPatternModule(mlirPDLPatternModuleFromModule(module)); |
89 | }, |
90 | "module"_a , "Create a PDL module from the given module." ) |
91 | .def("freeze" , [](PyPDLPatternModule &self) { |
92 | return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( |
93 | mlirRewritePatternSetFromPDLPatternModule(self.get()))); |
94 | }); |
95 | #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH |
96 | nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet" ) |
97 | .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, |
98 | &PyFrozenRewritePatternSet::getCapsule) |
99 | .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, |
100 | &PyFrozenRewritePatternSet::createFromCapsule); |
101 | m.def( |
102 | "apply_patterns_and_fold_greedily" , |
103 | [](MlirModule module, MlirFrozenRewritePatternSet set) { |
104 | auto status = mlirApplyPatternsAndFoldGreedily(module, set, {}); |
105 | if (mlirLogicalResultIsFailure(status)) |
106 | // FIXME: Not sure this is the right error to throw here. |
107 | throw nb::value_error("pattern application failed to converge" ); |
108 | }, |
109 | "module"_a , "set"_a , |
110 | "Applys the given patterns to the given module greedily while folding " |
111 | "results." ); |
112 | } |
113 | |