1//===- DialectTransform.cpp - 'transform' dialect submodule ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir-c/Dialect/Transform.h"
10#include "mlir-c/IR.h"
11#include "mlir-c/Support.h"
12#include "mlir/Bindings/Python/PybindAdaptors.h"
13#include <pybind11/cast.h>
14#include <pybind11/detail/common.h>
15#include <pybind11/pybind11.h>
16#include <pybind11/pytypes.h>
17#include <string>
18
19namespace py = pybind11;
20using namespace mlir;
21using namespace mlir::python;
22using namespace mlir::python::adaptors;
23
24void populateDialectTransformSubmodule(const pybind11::module &m) {
25 //===-------------------------------------------------------------------===//
26 // AnyOpType
27 //===-------------------------------------------------------------------===//
28
29 auto anyOpType =
30 mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType,
31 mlirTransformAnyOpTypeGetTypeID);
32 anyOpType.def_classmethod(
33 "get",
34 [](py::object cls, MlirContext ctx) {
35 return cls(mlirTransformAnyOpTypeGet(ctx));
36 },
37 "Get an instance of AnyOpType in the given context.", py::arg("cls"),
38 py::arg("context") = py::none());
39
40 //===-------------------------------------------------------------------===//
41 // AnyParamType
42 //===-------------------------------------------------------------------===//
43
44 auto anyParamType =
45 mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType,
46 mlirTransformAnyParamTypeGetTypeID);
47 anyParamType.def_classmethod(
48 "get",
49 [](py::object cls, MlirContext ctx) {
50 return cls(mlirTransformAnyParamTypeGet(ctx));
51 },
52 "Get an instance of AnyParamType in the given context.", py::arg("cls"),
53 py::arg("context") = py::none());
54
55 //===-------------------------------------------------------------------===//
56 // AnyValueType
57 //===-------------------------------------------------------------------===//
58
59 auto anyValueType =
60 mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType,
61 mlirTransformAnyValueTypeGetTypeID);
62 anyValueType.def_classmethod(
63 "get",
64 [](py::object cls, MlirContext ctx) {
65 return cls(mlirTransformAnyValueTypeGet(ctx));
66 },
67 "Get an instance of AnyValueType in the given context.", py::arg("cls"),
68 py::arg("context") = py::none());
69
70 //===-------------------------------------------------------------------===//
71 // OperationType
72 //===-------------------------------------------------------------------===//
73
74 auto operationType =
75 mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType,
76 mlirTransformOperationTypeGetTypeID);
77 operationType.def_classmethod(
78 "get",
79 [](py::object cls, const std::string &operationName, MlirContext ctx) {
80 MlirStringRef cOperationName =
81 mlirStringRefCreate(operationName.data(), operationName.size());
82 return cls(mlirTransformOperationTypeGet(ctx, cOperationName));
83 },
84 "Get an instance of OperationType for the given kind in the given "
85 "context",
86 py::arg("cls"), py::arg("operation_name"),
87 py::arg("context") = py::none());
88 operationType.def_property_readonly(
89 "operation_name",
90 [](MlirType type) {
91 MlirStringRef operationName =
92 mlirTransformOperationTypeGetOperationName(type);
93 return py::str(operationName.data, operationName.length);
94 },
95 "Get the name of the payload operation accepted by the handle.");
96
97 //===-------------------------------------------------------------------===//
98 // ParamType
99 //===-------------------------------------------------------------------===//
100
101 auto paramType =
102 mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType,
103 mlirTransformParamTypeGetTypeID);
104 paramType.def_classmethod(
105 "get",
106 [](py::object cls, MlirType type, MlirContext ctx) {
107 return cls(mlirTransformParamTypeGet(ctx, type));
108 },
109 "Get an instance of ParamType for the given type in the given context.",
110 py::arg("cls"), py::arg("type"), py::arg("context") = py::none());
111 paramType.def_property_readonly(
112 "type",
113 [](MlirType type) {
114 MlirType paramType = mlirTransformParamTypeGetType(type);
115 return paramType;
116 },
117 "Get the type this ParamType is associated with.");
118}
119
120PYBIND11_MODULE(_mlirDialectsTransform, m) {
121 m.doc() = "MLIR Transform dialect.";
122 populateDialectTransformSubmodule(m);
123}
124

source code of mlir/lib/Bindings/Python/DialectTransform.cpp