1 | //===- DialectTransform.cpp - 'transform' dialect submodule ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir-c/Dialect/Transform.h" |
10 | #include "mlir-c/IR.h" |
11 | #include "mlir-c/Support.h" |
12 | #include "mlir/Bindings/Python/PybindAdaptors.h" |
13 | #include <pybind11/cast.h> |
14 | #include <pybind11/detail/common.h> |
15 | #include <pybind11/pybind11.h> |
16 | #include <pybind11/pytypes.h> |
17 | #include <string> |
18 | |
19 | namespace py = pybind11; |
20 | using namespace mlir; |
21 | using namespace mlir::python; |
22 | using namespace mlir::python::adaptors; |
23 | |
24 | void populateDialectTransformSubmodule(const pybind11::module &m) { |
25 | //===-------------------------------------------------------------------===// |
26 | // AnyOpType |
27 | //===-------------------------------------------------------------------===// |
28 | |
29 | auto anyOpType = |
30 | mlir_type_subclass(m, "AnyOpType" , mlirTypeIsATransformAnyOpType, |
31 | mlirTransformAnyOpTypeGetTypeID); |
32 | anyOpType.def_classmethod( |
33 | "get" , |
34 | [](py::object cls, MlirContext ctx) { |
35 | return cls(mlirTransformAnyOpTypeGet(ctx)); |
36 | }, |
37 | "Get an instance of AnyOpType in the given context." , py::arg("cls" ), |
38 | py::arg("context" ) = py::none()); |
39 | |
40 | //===-------------------------------------------------------------------===// |
41 | // AnyParamType |
42 | //===-------------------------------------------------------------------===// |
43 | |
44 | auto anyParamType = |
45 | mlir_type_subclass(m, "AnyParamType" , mlirTypeIsATransformAnyParamType, |
46 | mlirTransformAnyParamTypeGetTypeID); |
47 | anyParamType.def_classmethod( |
48 | "get" , |
49 | [](py::object cls, MlirContext ctx) { |
50 | return cls(mlirTransformAnyParamTypeGet(ctx)); |
51 | }, |
52 | "Get an instance of AnyParamType in the given context." , py::arg("cls" ), |
53 | py::arg("context" ) = py::none()); |
54 | |
55 | //===-------------------------------------------------------------------===// |
56 | // AnyValueType |
57 | //===-------------------------------------------------------------------===// |
58 | |
59 | auto anyValueType = |
60 | mlir_type_subclass(m, "AnyValueType" , mlirTypeIsATransformAnyValueType, |
61 | mlirTransformAnyValueTypeGetTypeID); |
62 | anyValueType.def_classmethod( |
63 | "get" , |
64 | [](py::object cls, MlirContext ctx) { |
65 | return cls(mlirTransformAnyValueTypeGet(ctx)); |
66 | }, |
67 | "Get an instance of AnyValueType in the given context." , py::arg("cls" ), |
68 | py::arg("context" ) = py::none()); |
69 | |
70 | //===-------------------------------------------------------------------===// |
71 | // OperationType |
72 | //===-------------------------------------------------------------------===// |
73 | |
74 | auto operationType = |
75 | mlir_type_subclass(m, "OperationType" , mlirTypeIsATransformOperationType, |
76 | mlirTransformOperationTypeGetTypeID); |
77 | operationType.def_classmethod( |
78 | "get" , |
79 | [](py::object cls, const std::string &operationName, MlirContext ctx) { |
80 | MlirStringRef cOperationName = |
81 | mlirStringRefCreate(operationName.data(), operationName.size()); |
82 | return cls(mlirTransformOperationTypeGet(ctx, cOperationName)); |
83 | }, |
84 | "Get an instance of OperationType for the given kind in the given " |
85 | "context" , |
86 | py::arg("cls" ), py::arg("operation_name" ), |
87 | py::arg("context" ) = py::none()); |
88 | operationType.def_property_readonly( |
89 | "operation_name" , |
90 | [](MlirType type) { |
91 | MlirStringRef operationName = |
92 | mlirTransformOperationTypeGetOperationName(type); |
93 | return py::str(operationName.data, operationName.length); |
94 | }, |
95 | "Get the name of the payload operation accepted by the handle." ); |
96 | |
97 | //===-------------------------------------------------------------------===// |
98 | // ParamType |
99 | //===-------------------------------------------------------------------===// |
100 | |
101 | auto paramType = |
102 | mlir_type_subclass(m, "ParamType" , mlirTypeIsATransformParamType, |
103 | mlirTransformParamTypeGetTypeID); |
104 | paramType.def_classmethod( |
105 | "get" , |
106 | [](py::object cls, MlirType type, MlirContext ctx) { |
107 | return cls(mlirTransformParamTypeGet(ctx, type)); |
108 | }, |
109 | "Get an instance of ParamType for the given type in the given context." , |
110 | py::arg("cls" ), py::arg("type" ), py::arg("context" ) = py::none()); |
111 | paramType.def_property_readonly( |
112 | "type" , |
113 | [](MlirType type) { |
114 | MlirType paramType = mlirTransformParamTypeGetType(type); |
115 | return paramType; |
116 | }, |
117 | "Get the type this ParamType is associated with." ); |
118 | } |
119 | |
120 | PYBIND11_MODULE(_mlirDialectsTransform, m) { |
121 | m.doc() = "MLIR Transform dialect." ; |
122 | populateDialectTransformSubmodule(m); |
123 | } |
124 | |