| 1 | //===- TransformInterpreter.cpp -------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
// Python bindings (nanobind) for the transform dialect interpreter.
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "mlir-c/Dialect/Transform/Interpreter.h" |
| 14 | #include "mlir-c/IR.h" |
| 15 | #include "mlir-c/Support.h" |
| 16 | #include "mlir/Bindings/Python/Diagnostics.h" |
| 17 | #include "mlir/Bindings/Python/NanobindAdaptors.h" |
| 18 | #include "mlir/Bindings/Python/Nanobind.h" |
| 19 | |
| 20 | namespace nb = nanobind; |
| 21 | |
| 22 | namespace { |
| 23 | struct PyMlirTransformOptions { |
| 24 | PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); }; |
| 25 | PyMlirTransformOptions(PyMlirTransformOptions &&other) { |
| 26 | options = other.options; |
| 27 | other.options.ptr = nullptr; |
| 28 | } |
| 29 | PyMlirTransformOptions(const PyMlirTransformOptions &) = delete; |
| 30 | |
| 31 | ~PyMlirTransformOptions() { mlirTransformOptionsDestroy(transformOptions: options); } |
| 32 | |
| 33 | MlirTransformOptions options; |
| 34 | }; |
| 35 | } // namespace |
| 36 | |
| 37 | static void populateTransformInterpreterSubmodule(nb::module_ &m) { |
| 38 | nb::class_<PyMlirTransformOptions>(m, "TransformOptions" ) |
| 39 | .def(nb::init<>()) |
| 40 | .def_prop_rw( |
| 41 | "expensive_checks" , |
| 42 | [](const PyMlirTransformOptions &self) { |
| 43 | return mlirTransformOptionsGetExpensiveChecksEnabled(self.options); |
| 44 | }, |
| 45 | [](PyMlirTransformOptions &self, bool value) { |
| 46 | mlirTransformOptionsEnableExpensiveChecks(self.options, value); |
| 47 | }) |
| 48 | .def_prop_rw( |
| 49 | "enforce_single_top_level_transform_op" , |
| 50 | [](const PyMlirTransformOptions &self) { |
| 51 | return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp( |
| 52 | self.options); |
| 53 | }, |
| 54 | [](PyMlirTransformOptions &self, bool value) { |
| 55 | mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options, |
| 56 | value); |
| 57 | }); |
| 58 | |
| 59 | m.def( |
| 60 | "apply_named_sequence" , |
| 61 | [](MlirOperation payloadRoot, MlirOperation transformRoot, |
| 62 | MlirOperation transformModule, const PyMlirTransformOptions &options) { |
| 63 | mlir::python::CollectDiagnosticsToStringScope scope( |
| 64 | mlirOperationGetContext(transformRoot)); |
| 65 | |
| 66 | // Calling back into Python to invalidate everything under the payload |
| 67 | // root. This is awkward, but we don't have access to PyMlirContext |
| 68 | // object here otherwise. |
| 69 | nb::object obj = nb::cast(payloadRoot); |
| 70 | obj.attr("context" ).attr("_clear_live_operations_inside" )(payloadRoot); |
| 71 | |
| 72 | MlirLogicalResult result = mlirTransformApplyNamedSequence( |
| 73 | payloadRoot, transformRoot, transformModule, options.options); |
| 74 | if (mlirLogicalResultIsSuccess(result)) |
| 75 | return; |
| 76 | |
| 77 | throw nb::value_error( |
| 78 | ("Failed to apply named transform sequence.\nDiagnostic message " + |
| 79 | scope.takeMessage()) |
| 80 | .c_str()); |
| 81 | }, |
| 82 | nb::arg("payload_root" ), nb::arg("transform_root" ), |
| 83 | nb::arg("transform_module" ), |
| 84 | nb::arg("transform_options" ) = PyMlirTransformOptions()); |
| 85 | |
| 86 | m.def( |
| 87 | "copy_symbols_and_merge_into" , |
| 88 | [](MlirOperation target, MlirOperation other) { |
| 89 | mlir::python::CollectDiagnosticsToStringScope scope( |
| 90 | mlirOperationGetContext(target)); |
| 91 | |
| 92 | MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other); |
| 93 | if (mlirLogicalResultIsFailure(result)) { |
| 94 | throw nb::value_error( |
| 95 | ("Failed to merge symbols.\nDiagnostic message " + |
| 96 | scope.takeMessage()) |
| 97 | .c_str()); |
| 98 | } |
| 99 | }, |
| 100 | nb::arg("target" ), nb::arg("other" )); |
| 101 | } |
| 102 | |
| 103 | NB_MODULE(_mlirTransformInterpreter, m) { |
| 104 | m.doc() = "MLIR Transform dialect interpreter functionality." ; |
| 105 | populateTransformInterpreterSubmodule(m); |
| 106 | } |
| 107 | |