1 | //===- TransformInterpreter.cpp -------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// Python (nanobind) bindings for the transform dialect interpreter.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir-c/Dialect/Transform/Interpreter.h" |
14 | #include "mlir-c/IR.h" |
15 | #include "mlir-c/Support.h" |
16 | #include "mlir/Bindings/Python/Diagnostics.h" |
17 | #include "mlir/Bindings/Python/NanobindAdaptors.h" |
18 | #include "mlir/Bindings/Python/Nanobind.h" |
19 | |
20 | namespace nb = nanobind; |
21 | |
22 | namespace { |
23 | struct PyMlirTransformOptions { |
24 | PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); }; |
25 | PyMlirTransformOptions(PyMlirTransformOptions &&other) { |
26 | options = other.options; |
27 | other.options.ptr = nullptr; |
28 | } |
29 | PyMlirTransformOptions(const PyMlirTransformOptions &) = delete; |
30 | |
31 | ~PyMlirTransformOptions() { mlirTransformOptionsDestroy(transformOptions: options); } |
32 | |
33 | MlirTransformOptions options; |
34 | }; |
35 | } // namespace |
36 | |
37 | static void populateTransformInterpreterSubmodule(nb::module_ &m) { |
38 | nb::class_<PyMlirTransformOptions>(m, "TransformOptions" ) |
39 | .def(nb::init<>()) |
40 | .def_prop_rw( |
41 | "expensive_checks" , |
42 | [](const PyMlirTransformOptions &self) { |
43 | return mlirTransformOptionsGetExpensiveChecksEnabled(self.options); |
44 | }, |
45 | [](PyMlirTransformOptions &self, bool value) { |
46 | mlirTransformOptionsEnableExpensiveChecks(self.options, value); |
47 | }) |
48 | .def_prop_rw( |
49 | "enforce_single_top_level_transform_op" , |
50 | [](const PyMlirTransformOptions &self) { |
51 | return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp( |
52 | self.options); |
53 | }, |
54 | [](PyMlirTransformOptions &self, bool value) { |
55 | mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options, |
56 | value); |
57 | }); |
58 | |
59 | m.def( |
60 | "apply_named_sequence" , |
61 | [](MlirOperation payloadRoot, MlirOperation transformRoot, |
62 | MlirOperation transformModule, const PyMlirTransformOptions &options) { |
63 | mlir::python::CollectDiagnosticsToStringScope scope( |
64 | mlirOperationGetContext(transformRoot)); |
65 | |
66 | // Calling back into Python to invalidate everything under the payload |
67 | // root. This is awkward, but we don't have access to PyMlirContext |
68 | // object here otherwise. |
69 | nb::object obj = nb::cast(payloadRoot); |
70 | obj.attr("context" ).attr("_clear_live_operations_inside" )(payloadRoot); |
71 | |
72 | MlirLogicalResult result = mlirTransformApplyNamedSequence( |
73 | payloadRoot, transformRoot, transformModule, options.options); |
74 | if (mlirLogicalResultIsSuccess(result)) |
75 | return; |
76 | |
77 | throw nb::value_error( |
78 | ("Failed to apply named transform sequence.\nDiagnostic message " + |
79 | scope.takeMessage()) |
80 | .c_str()); |
81 | }, |
82 | nb::arg("payload_root" ), nb::arg("transform_root" ), |
83 | nb::arg("transform_module" ), |
84 | nb::arg("transform_options" ) = PyMlirTransformOptions()); |
85 | |
86 | m.def( |
87 | "copy_symbols_and_merge_into" , |
88 | [](MlirOperation target, MlirOperation other) { |
89 | mlir::python::CollectDiagnosticsToStringScope scope( |
90 | mlirOperationGetContext(target)); |
91 | |
92 | MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other); |
93 | if (mlirLogicalResultIsFailure(result)) { |
94 | throw nb::value_error( |
95 | ("Failed to merge symbols.\nDiagnostic message " + |
96 | scope.takeMessage()) |
97 | .c_str()); |
98 | } |
99 | }, |
100 | nb::arg("target" ), nb::arg("other" )); |
101 | } |
102 | |
103 | NB_MODULE(_mlirTransformInterpreter, m) { |
104 | m.doc() = "MLIR Transform dialect interpreter functionality." ; |
105 | populateTransformInterpreterSubmodule(m); |
106 | } |
107 | |