1//===- PythonTestModuleNanobind.cpp - PythonTest dialect extension --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This is the nanobind edition of the PythonTest dialect module.
9//===----------------------------------------------------------------------===//
10
11#include "PythonTestCAPI.h"
12#include "mlir-c/BuiltinAttributes.h"
13#include "mlir-c/BuiltinTypes.h"
14#include "mlir-c/Diagnostics.h"
15#include "mlir-c/IR.h"
16#include "mlir/Bindings/Python/Diagnostics.h"
17#include "mlir/Bindings/Python/Nanobind.h"
18#include "mlir/Bindings/Python/NanobindAdaptors.h"
19#include "nanobind/nanobind.h"
20
21namespace nb = nanobind;
22using namespace mlir::python::nanobind_adaptors;
23
24static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
25 return mlirTypeIsARankedTensor(type: t) &&
26 mlirTypeIsAInteger(type: mlirShapedTypeGetElementType(type: t));
27}
28
29NB_MODULE(_mlirPythonTestNanobind, m) {
30 m.def(
31 "register_python_test_dialect",
32 [](MlirContext context, bool load) {
33 MlirDialectHandle pythonTestDialect =
34 mlirGetDialectHandle__python_test__();
35 mlirDialectHandleRegisterDialect(pythonTestDialect, context);
36 if (load) {
37 mlirDialectHandleLoadDialect(pythonTestDialect, context);
38 }
39 },
40 nb::arg("context"), nb::arg("load") = true);
41
42 m.def(
43 "register_dialect",
44 [](MlirDialectRegistry registry) {
45 MlirDialectHandle pythonTestDialect =
46 mlirGetDialectHandle__python_test__();
47 mlirDialectHandleInsertDialect(pythonTestDialect, registry);
48 },
49 nb::arg("registry"));
50
51 m.def("test_diagnostics_with_errors_and_notes", [](MlirContext ctx) {
52 mlir::python::CollectDiagnosticsToStringScope handler(ctx);
53
54 mlirPythonTestEmitDiagnosticWithNote(ctx);
55 throw nb::value_error(handler.takeMessage().c_str());
56 });
57
58 mlir_attribute_subclass(m, "TestAttr",
59 mlirAttributeIsAPythonTestTestAttribute,
60 mlirPythonTestTestAttributeGetTypeID)
61 .def_classmethod(
62 "get",
63 [](const nb::object &cls, MlirContext ctx) {
64 return cls(mlirPythonTestTestAttributeGet(ctx));
65 },
66 nb::arg("cls"), nb::arg("context").none() = nb::none());
67
68 mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType,
69 mlirPythonTestTestTypeGetTypeID)
70 .def_classmethod(
71 "get",
72 [](const nb::object &cls, MlirContext ctx) {
73 return cls(mlirPythonTestTestTypeGet(ctx));
74 },
75 nb::arg("cls"), nb::arg("context").none() = nb::none());
76
77 auto typeCls =
78 mlir_type_subclass(m, "TestIntegerRankedTensorType",
79 mlirTypeIsARankedIntegerTensor,
80 nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
81 .attr("RankedTensorType"))
82 .def_classmethod(
83 "get",
84 [](const nb::object &cls, std::vector<int64_t> shape,
85 unsigned width, MlirContext ctx) {
86 MlirAttribute encoding = mlirAttributeGetNull();
87 return cls(mlirRankedTensorTypeGet(
88 shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width),
89 encoding));
90 },
91 nb::arg("cls"), nb::arg("shape"), nb::arg("width"),
92 nb::arg("context").none() = nb::none());
93
94 assert(nb::hasattr(typeCls.get_class(), "static_typeid") &&
95 "TestIntegerRankedTensorType has no static_typeid");
96
97 MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID();
98
99 nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
100 .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
101 mlirRankedTensorTypeID, nb::arg("replace") = true)(
102 nanobind::cpp_function([typeCls](const nb::object &mlirType) {
103 return typeCls.get_class()(mlirType);
104 }));
105
106 auto valueCls = mlir_value_subclass(m, "TestTensorValue",
107 mlirTypeIsAPythonTestTestTensorValue)
108 .def("is_null", [](MlirValue &self) {
109 return mlirValueIsNull(self);
110 });
111
112 nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
113 .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)(
114 mlirRankedTensorTypeID)(
115 nanobind::cpp_function([valueCls](const nb::object &valueObj) {
116 nb::object capsule = mlirApiObjectToCapsule(valueObj);
117 MlirValue v = mlirPythonCapsuleToValue(capsule.ptr());
118 MlirType t = mlirValueGetType(v);
119 // This is hyper-specific in order to exercise/test registering a
120 // value caster from cpp (but only for a single test case; see
121 // testTensorValue python_test.py).
122 if (mlirShapedTypeHasStaticShape(t) &&
123 mlirShapedTypeGetDimSize(t, 0) == 1 &&
124 mlirShapedTypeGetDimSize(t, 1) == 2 &&
125 mlirShapedTypeGetDimSize(t, 2) == 3)
126 return valueCls.get_class()(valueObj);
127 return valueObj;
128 }));
129}
130

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/test/python/lib/PythonTestModuleNanobind.cpp