1//===- TransformInterpreter.cpp -------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Pybind classes for the transform dialect interpreter.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir-c/Dialect/Transform/Interpreter.h"
14#include "mlir-c/IR.h"
15#include "mlir-c/Support.h"
16#include "mlir/Bindings/Python/PybindAdaptors.h"
17
18#include <pybind11/detail/common.h>
19#include <pybind11/pybind11.h>
20
21namespace py = pybind11;
22
23namespace {
24struct PyMlirTransformOptions {
25 PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); };
26 PyMlirTransformOptions(PyMlirTransformOptions &&other) {
27 options = other.options;
28 other.options.ptr = nullptr;
29 }
30 PyMlirTransformOptions(const PyMlirTransformOptions &) = delete;
31
32 ~PyMlirTransformOptions() { mlirTransformOptionsDestroy(transformOptions: options); }
33
34 MlirTransformOptions options;
35};
36} // namespace
37
38static void populateTransformInterpreterSubmodule(py::module &m) {
39 py::class_<PyMlirTransformOptions>(m, "TransformOptions", py::module_local())
40 .def(py::init())
41 .def_property(
42 "expensive_checks",
43 [](const PyMlirTransformOptions &self) {
44 return mlirTransformOptionsGetExpensiveChecksEnabled(self.options);
45 },
46 [](PyMlirTransformOptions &self, bool value) {
47 mlirTransformOptionsEnableExpensiveChecks(self.options, value);
48 })
49 .def_property(
50 "enforce_single_top_level_transform_op",
51 [](const PyMlirTransformOptions &self) {
52 return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
53 self.options);
54 },
55 [](PyMlirTransformOptions &self, bool value) {
56 mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options,
57 value);
58 });
59
60 m.def(
61 "apply_named_sequence",
62 [](MlirOperation payloadRoot, MlirOperation transformRoot,
63 MlirOperation transformModule, const PyMlirTransformOptions &options) {
64 mlir::python::CollectDiagnosticsToStringScope scope(
65 mlirOperationGetContext(transformRoot));
66
67 // Calling back into Python to invalidate everything under the payload
68 // root. This is awkward, but we don't have access to PyMlirContext
69 // object here otherwise.
70 py::object obj = py::cast(payloadRoot);
71 obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot);
72
73 MlirLogicalResult result = mlirTransformApplyNamedSequence(
74 payloadRoot, transformRoot, transformModule, options.options);
75 if (mlirLogicalResultIsSuccess(result))
76 return;
77
78 throw py::value_error(
79 "Failed to apply named transform sequence.\nDiagnostic message " +
80 scope.takeMessage());
81 },
82 py::arg("payload_root"), py::arg("transform_root"),
83 py::arg("transform_module"),
84 py::arg("transform_options") = PyMlirTransformOptions());
85
86 m.def(
87 "copy_symbols_and_merge_into",
88 [](MlirOperation target, MlirOperation other) {
89 mlir::python::CollectDiagnosticsToStringScope scope(
90 mlirOperationGetContext(target));
91
92 MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other);
93 if (mlirLogicalResultIsFailure(result)) {
94 throw py::value_error(
95 "Failed to merge symbols.\nDiagnostic message " +
96 scope.takeMessage());
97 }
98 },
99 py::arg("target"), py::arg("other"));
100}
101
102PYBIND11_MODULE(_mlirTransformInterpreter, m) {
103 m.doc() = "MLIR Transform dialect interpreter functionality.";
104 populateTransformInterpreterSubmodule(m);
105}
106

source code of mlir/lib/Bindings/Python/TransformInterpreter.cpp