1//===- DialectTransform.cpp - 'transform' dialect submodule ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <string>
10
11#include "mlir-c/Dialect/Transform.h"
12#include "mlir-c/IR.h"
13#include "mlir-c/Support.h"
14#include "mlir/Bindings/Python/NanobindAdaptors.h"
15#include "mlir/Bindings/Python/Nanobind.h"
16
17namespace nb = nanobind;
18using namespace mlir;
19using namespace mlir::python;
20using namespace mlir::python::nanobind_adaptors;
21
22void populateDialectTransformSubmodule(const nb::module_ &m) {
23 //===-------------------------------------------------------------------===//
24 // AnyOpType
25 //===-------------------------------------------------------------------===//
26
27 auto anyOpType =
28 mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType,
29 mlirTransformAnyOpTypeGetTypeID);
30 anyOpType.def_classmethod(
31 "get",
32 [](nb::object cls, MlirContext ctx) {
33 return cls(mlirTransformAnyOpTypeGet(ctx));
34 },
35 "Get an instance of AnyOpType in the given context.", nb::arg("cls"),
36 nb::arg("context").none() = nb::none());
37
38 //===-------------------------------------------------------------------===//
39 // AnyParamType
40 //===-------------------------------------------------------------------===//
41
42 auto anyParamType =
43 mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType,
44 mlirTransformAnyParamTypeGetTypeID);
45 anyParamType.def_classmethod(
46 "get",
47 [](nb::object cls, MlirContext ctx) {
48 return cls(mlirTransformAnyParamTypeGet(ctx));
49 },
50 "Get an instance of AnyParamType in the given context.", nb::arg("cls"),
51 nb::arg("context").none() = nb::none());
52
53 //===-------------------------------------------------------------------===//
54 // AnyValueType
55 //===-------------------------------------------------------------------===//
56
57 auto anyValueType =
58 mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType,
59 mlirTransformAnyValueTypeGetTypeID);
60 anyValueType.def_classmethod(
61 "get",
62 [](nb::object cls, MlirContext ctx) {
63 return cls(mlirTransformAnyValueTypeGet(ctx));
64 },
65 "Get an instance of AnyValueType in the given context.", nb::arg("cls"),
66 nb::arg("context").none() = nb::none());
67
68 //===-------------------------------------------------------------------===//
69 // OperationType
70 //===-------------------------------------------------------------------===//
71
72 auto operationType =
73 mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType,
74 mlirTransformOperationTypeGetTypeID);
75 operationType.def_classmethod(
76 "get",
77 [](nb::object cls, const std::string &operationName, MlirContext ctx) {
78 MlirStringRef cOperationName =
79 mlirStringRefCreate(operationName.data(), operationName.size());
80 return cls(mlirTransformOperationTypeGet(ctx, cOperationName));
81 },
82 "Get an instance of OperationType for the given kind in the given "
83 "context",
84 nb::arg("cls"), nb::arg("operation_name"),
85 nb::arg("context").none() = nb::none());
86 operationType.def_property_readonly(
87 "operation_name",
88 [](MlirType type) {
89 MlirStringRef operationName =
90 mlirTransformOperationTypeGetOperationName(type);
91 return nb::str(operationName.data, operationName.length);
92 },
93 "Get the name of the payload operation accepted by the handle.");
94
95 //===-------------------------------------------------------------------===//
96 // ParamType
97 //===-------------------------------------------------------------------===//
98
99 auto paramType =
100 mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType,
101 mlirTransformParamTypeGetTypeID);
102 paramType.def_classmethod(
103 "get",
104 [](nb::object cls, MlirType type, MlirContext ctx) {
105 return cls(mlirTransformParamTypeGet(ctx, type));
106 },
107 "Get an instance of ParamType for the given type in the given context.",
108 nb::arg("cls"), nb::arg("type"), nb::arg("context").none() = nb::none());
109 paramType.def_property_readonly(
110 "type",
111 [](MlirType type) {
112 MlirType paramType = mlirTransformParamTypeGetType(type);
113 return paramType;
114 },
115 "Get the type this ParamType is associated with.");
116}
117
// Python extension module entry point. NB_MODULE defines the native module
// `_mlirDialectsTransform`; `m` is the module object being initialized.
NB_MODULE(_mlirDialectsTransform, m) {
  // Module-level docstring (`__doc__`).
  m.doc() = "MLIR Transform dialect.";
  // Register the Transform dialect type classes as attributes of the module.
  populateDialectTransformSubmodule(m);
}
122

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Bindings/Python/DialectTransform.cpp