1 | //===- TransformInterpreter.cpp -------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Pybind classes for the transform dialect interpreter. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir-c/Dialect/Transform/Interpreter.h" |
14 | #include "mlir-c/IR.h" |
15 | #include "mlir-c/Support.h" |
16 | #include "mlir/Bindings/Python/PybindAdaptors.h" |
17 | |
18 | #include <pybind11/detail/common.h> |
19 | #include <pybind11/pybind11.h> |
20 | |
21 | namespace py = pybind11; |
22 | |
23 | namespace { |
24 | struct PyMlirTransformOptions { |
25 | PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); }; |
26 | PyMlirTransformOptions(PyMlirTransformOptions &&other) { |
27 | options = other.options; |
28 | other.options.ptr = nullptr; |
29 | } |
30 | PyMlirTransformOptions(const PyMlirTransformOptions &) = delete; |
31 | |
32 | ~PyMlirTransformOptions() { mlirTransformOptionsDestroy(transformOptions: options); } |
33 | |
34 | MlirTransformOptions options; |
35 | }; |
36 | } // namespace |
37 | |
38 | static void populateTransformInterpreterSubmodule(py::module &m) { |
39 | py::class_<PyMlirTransformOptions>(m, "TransformOptions" , py::module_local()) |
40 | .def(py::init()) |
41 | .def_property( |
42 | "expensive_checks" , |
43 | [](const PyMlirTransformOptions &self) { |
44 | return mlirTransformOptionsGetExpensiveChecksEnabled(self.options); |
45 | }, |
46 | [](PyMlirTransformOptions &self, bool value) { |
47 | mlirTransformOptionsEnableExpensiveChecks(self.options, value); |
48 | }) |
49 | .def_property( |
50 | "enforce_single_top_level_transform_op" , |
51 | [](const PyMlirTransformOptions &self) { |
52 | return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp( |
53 | self.options); |
54 | }, |
55 | [](PyMlirTransformOptions &self, bool value) { |
56 | mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options, |
57 | value); |
58 | }); |
59 | |
60 | m.def( |
61 | "apply_named_sequence" , |
62 | [](MlirOperation payloadRoot, MlirOperation transformRoot, |
63 | MlirOperation transformModule, const PyMlirTransformOptions &options) { |
64 | mlir::python::CollectDiagnosticsToStringScope scope( |
65 | mlirOperationGetContext(transformRoot)); |
66 | |
67 | // Calling back into Python to invalidate everything under the payload |
68 | // root. This is awkward, but we don't have access to PyMlirContext |
69 | // object here otherwise. |
70 | py::object obj = py::cast(payloadRoot); |
71 | obj.attr("context" ).attr("_clear_live_operations_inside" )(payloadRoot); |
72 | |
73 | MlirLogicalResult result = mlirTransformApplyNamedSequence( |
74 | payloadRoot, transformRoot, transformModule, options.options); |
75 | if (mlirLogicalResultIsSuccess(result)) |
76 | return; |
77 | |
78 | throw py::value_error( |
79 | "Failed to apply named transform sequence.\nDiagnostic message " + |
80 | scope.takeMessage()); |
81 | }, |
82 | py::arg("payload_root" ), py::arg("transform_root" ), |
83 | py::arg("transform_module" ), |
84 | py::arg("transform_options" ) = PyMlirTransformOptions()); |
85 | |
86 | m.def( |
87 | "copy_symbols_and_merge_into" , |
88 | [](MlirOperation target, MlirOperation other) { |
89 | mlir::python::CollectDiagnosticsToStringScope scope( |
90 | mlirOperationGetContext(target)); |
91 | |
92 | MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other); |
93 | if (mlirLogicalResultIsFailure(result)) { |
94 | throw py::value_error( |
95 | "Failed to merge symbols.\nDiagnostic message " + |
96 | scope.takeMessage()); |
97 | } |
98 | }, |
99 | py::arg("target" ), py::arg("other" )); |
100 | } |
101 | |
102 | PYBIND11_MODULE(_mlirTransformInterpreter, m) { |
103 | m.doc() = "MLIR Transform dialect interpreter functionality." ; |
104 | populateTransformInterpreterSubmodule(m); |
105 | } |
106 | |