1//===- Rewrite.cpp - Rewrite ----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Rewrite.h"
10
11#include "IRModule.h"
12#include "mlir-c/Rewrite.h"
13#include "mlir/Bindings/Python/Nanobind.h"
14#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
15#include "mlir/Config/mlir-config.h"
16
17namespace nb = nanobind;
18using namespace mlir;
19using namespace nb::literals;
20using namespace mlir::python;
21
22namespace {
23
24#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
25/// Owning Wrapper around a PDLPatternModule.
26class PyPDLPatternModule {
27public:
28 PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
29 PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
30 : module(other.module) {
31 other.module.ptr = nullptr;
32 }
33 ~PyPDLPatternModule() {
34 if (module.ptr != nullptr)
35 mlirPDLPatternModuleDestroy(op: module);
36 }
37 MlirPDLPatternModule get() { return module; }
38
39private:
40 MlirPDLPatternModule module;
41};
42#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
43
44/// Owning Wrapper around a FrozenRewritePatternSet.
45class PyFrozenRewritePatternSet {
46public:
47 PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
48 PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
49 : set(other.set) {
50 other.set.ptr = nullptr;
51 }
52 ~PyFrozenRewritePatternSet() {
53 if (set.ptr != nullptr)
54 mlirFrozenRewritePatternSetDestroy(op: set);
55 }
56 MlirFrozenRewritePatternSet get() { return set; }
57
58 nb::object getCapsule() {
59 return nb::steal<nb::object>(
60 mlirPythonFrozenRewritePatternSetToCapsule(get()));
61 }
62
63 static nb::object createFromCapsule(nb::object capsule) {
64 MlirFrozenRewritePatternSet rawPm =
65 mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
66 if (rawPm.ptr == nullptr)
67 throw nb::python_error();
68 return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move);
69 }
70
71private:
72 MlirFrozenRewritePatternSet set;
73};
74
75} // namespace
76
77/// Create the `mlir.rewrite` here.
78void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
79 //----------------------------------------------------------------------------
80 // Mapping of the top-level PassManager
81 //----------------------------------------------------------------------------
82#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
83 nb::class_<PyPDLPatternModule>(m, "PDLModule")
84 .def(
85 "__init__",
86 [](PyPDLPatternModule &self, MlirModule module) {
87 new (&self)
88 PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
89 },
90 "module"_a, "Create a PDL module from the given module.")
91 .def("freeze", [](PyPDLPatternModule &self) {
92 return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
93 mlirRewritePatternSetFromPDLPatternModule(self.get())));
94 });
95#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
96 nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
97 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
98 &PyFrozenRewritePatternSet::getCapsule)
99 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
100 &PyFrozenRewritePatternSet::createFromCapsule);
101 m.def(
102 "apply_patterns_and_fold_greedily",
103 [](MlirModule module, MlirFrozenRewritePatternSet set) {
104 auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
105 if (mlirLogicalResultIsFailure(status))
106 // FIXME: Not sure this is the right error to throw here.
107 throw nb::value_error("pattern application failed to converge");
108 },
109 "module"_a, "set"_a,
110 "Applys the given patterns to the given module greedily while folding "
111 "results.");
112}
113

source code of mlir/lib/Bindings/Python/Rewrite.cpp