1//===- PythonTestModule.cpp - Python extension for the PythonTest dialect -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "PythonTestCAPI.h"
10#include "mlir-c/BuiltinAttributes.h"
11#include "mlir-c/BuiltinTypes.h"
12#include "mlir-c/IR.h"
13#include "mlir/Bindings/Python/PybindAdaptors.h"
14
15namespace py = pybind11;
16using namespace mlir::python::adaptors;
17using namespace pybind11::literals;
18
19static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
20 return mlirTypeIsARankedTensor(type: t) &&
21 mlirTypeIsAInteger(type: mlirShapedTypeGetElementType(type: t));
22}
23
24PYBIND11_MODULE(_mlirPythonTest, m) {
25 m.def(
26 "register_python_test_dialect",
27 [](MlirContext context, bool load) {
28 MlirDialectHandle pythonTestDialect =
29 mlirGetDialectHandle__python_test__();
30 mlirDialectHandleRegisterDialect(pythonTestDialect, context);
31 if (load) {
32 mlirDialectHandleLoadDialect(pythonTestDialect, context);
33 }
34 },
35 py::arg("context"), py::arg("load") = true);
36
37 m.def(
38 "register_dialect",
39 [](MlirDialectRegistry registry) {
40 MlirDialectHandle pythonTestDialect =
41 mlirGetDialectHandle__python_test__();
42 mlirDialectHandleInsertDialect(pythonTestDialect, registry);
43 },
44 py::arg("registry"));
45
46 mlir_attribute_subclass(m, "TestAttr",
47 mlirAttributeIsAPythonTestTestAttribute)
48 .def_classmethod(
49 "get",
50 [](py::object cls, MlirContext ctx) {
51 return cls(mlirPythonTestTestAttributeGet(ctx));
52 },
53 py::arg("cls"), py::arg("context") = py::none());
54
55 mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType,
56 mlirPythonTestTestTypeGetTypeID)
57 .def_classmethod(
58 "get",
59 [](py::object cls, MlirContext ctx) {
60 return cls(mlirPythonTestTestTypeGet(ctx));
61 },
62 py::arg("cls"), py::arg("context") = py::none());
63
64 auto typeCls =
65 mlir_type_subclass(m, "TestIntegerRankedTensorType",
66 mlirTypeIsARankedIntegerTensor,
67 py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
68 .attr("RankedTensorType"))
69 .def_classmethod(
70 "get",
71 [](const py::object &cls, std::vector<int64_t> shape,
72 unsigned width, MlirContext ctx) {
73 MlirAttribute encoding = mlirAttributeGetNull();
74 return cls(mlirRankedTensorTypeGet(
75 shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width),
76 encoding));
77 },
78 "cls"_a, "shape"_a, "width"_a, "context"_a = py::none());
79
80 assert(py::hasattr(typeCls.get_class(), "static_typeid") &&
81 "TestIntegerRankedTensorType has no static_typeid");
82
83 MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID();
84
85 py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
86 .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(mlirRankedTensorTypeID,
87 "replace"_a = true)(
88 pybind11::cpp_function([typeCls](const py::object &mlirType) {
89 return typeCls.get_class()(mlirType);
90 }));
91
92 auto valueCls = mlir_value_subclass(m, "TestTensorValue",
93 mlirTypeIsAPythonTestTestTensorValue)
94 .def("is_null", [](MlirValue &self) {
95 return mlirValueIsNull(self);
96 });
97
98 py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
99 .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)(
100 mlirRankedTensorTypeID)(
101 pybind11::cpp_function([valueCls](const py::object &valueObj) {
102 py::object capsule = mlirApiObjectToCapsule(valueObj);
103 MlirValue v = mlirPythonCapsuleToValue(capsule.ptr());
104 MlirType t = mlirValueGetType(v);
105 // This is hyper-specific in order to exercise/test registering a
106 // value caster from cpp (but only for a single test case; see
107 // testTensorValue python_test.py).
108 if (mlirShapedTypeHasStaticShape(t) &&
109 mlirShapedTypeGetDimSize(t, 0) == 1 &&
110 mlirShapedTypeGetDimSize(t, 1) == 2 &&
111 mlirShapedTypeGetDimSize(t, 2) == 3)
112 return valueCls.get_class()(valueObj);
113 return valueObj;
114 }));
115}
116

source code of mlir/test/python/lib/PythonTestModule.cpp