| 1 | //===- IRTypes.cpp - Exports builtin and standard types -------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | // clang-format off |
| 10 | #include "IRModule.h" |
| 11 | #include "mlir/Bindings/Python/IRTypes.h" |
| 12 | // clang-format on |
| 13 | |
| 14 | #include <optional> |
| 15 | |
| 16 | #include "IRModule.h" |
| 17 | #include "NanobindUtils.h" |
| 18 | #include "mlir-c/BuiltinAttributes.h" |
| 19 | #include "mlir-c/BuiltinTypes.h" |
| 20 | #include "mlir-c/Support.h" |
| 21 | |
| 22 | namespace nb = nanobind; |
| 23 | using namespace mlir; |
| 24 | using namespace mlir::python; |
| 25 | |
| 26 | using llvm::SmallVector; |
| 27 | using llvm::Twine; |
| 28 | |
| 29 | namespace { |
| 30 | |
| 31 | /// Checks whether the given type is an integer or float type. |
| 32 | static int mlirTypeIsAIntegerOrFloat(MlirType type) { |
| 33 | return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) || |
| 34 | mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type); |
| 35 | } |
| 36 | |
| 37 | class PyIntegerType : public PyConcreteType<PyIntegerType> { |
| 38 | public: |
| 39 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; |
| 40 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 41 | mlirIntegerTypeGetTypeID; |
| 42 | static constexpr const char *pyClassName = "IntegerType" ; |
| 43 | using PyConcreteType::PyConcreteType; |
| 44 | |
| 45 | static void bindDerived(ClassTy &c) { |
| 46 | c.def_static( |
| 47 | "get_signless" , |
| 48 | [](unsigned width, DefaultingPyMlirContext context) { |
| 49 | MlirType t = mlirIntegerTypeGet(context->get(), width); |
| 50 | return PyIntegerType(context->getRef(), t); |
| 51 | }, |
| 52 | nb::arg("width" ), nb::arg("context" ).none() = nb::none(), |
| 53 | "Create a signless integer type" ); |
| 54 | c.def_static( |
| 55 | "get_signed" , |
| 56 | [](unsigned width, DefaultingPyMlirContext context) { |
| 57 | MlirType t = mlirIntegerTypeSignedGet(context->get(), width); |
| 58 | return PyIntegerType(context->getRef(), t); |
| 59 | }, |
| 60 | nb::arg("width" ), nb::arg("context" ).none() = nb::none(), |
| 61 | "Create a signed integer type" ); |
| 62 | c.def_static( |
| 63 | "get_unsigned" , |
| 64 | [](unsigned width, DefaultingPyMlirContext context) { |
| 65 | MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); |
| 66 | return PyIntegerType(context->getRef(), t); |
| 67 | }, |
| 68 | nb::arg("width" ), nb::arg("context" ).none() = nb::none(), |
| 69 | "Create an unsigned integer type" ); |
| 70 | c.def_prop_ro( |
| 71 | "width" , |
| 72 | [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, |
| 73 | "Returns the width of the integer type" ); |
| 74 | c.def_prop_ro( |
| 75 | "is_signless" , |
| 76 | [](PyIntegerType &self) -> bool { |
| 77 | return mlirIntegerTypeIsSignless(self); |
| 78 | }, |
| 79 | "Returns whether this is a signless integer" ); |
| 80 | c.def_prop_ro( |
| 81 | "is_signed" , |
| 82 | [](PyIntegerType &self) -> bool { |
| 83 | return mlirIntegerTypeIsSigned(self); |
| 84 | }, |
| 85 | "Returns whether this is a signed integer" ); |
| 86 | c.def_prop_ro( |
| 87 | "is_unsigned" , |
| 88 | [](PyIntegerType &self) -> bool { |
| 89 | return mlirIntegerTypeIsUnsigned(self); |
| 90 | }, |
| 91 | "Returns whether this is an unsigned integer" ); |
| 92 | } |
| 93 | }; |
| 94 | |
| 95 | /// Index Type subclass - IndexType. |
| 96 | class PyIndexType : public PyConcreteType<PyIndexType> { |
| 97 | public: |
| 98 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; |
| 99 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 100 | mlirIndexTypeGetTypeID; |
| 101 | static constexpr const char *pyClassName = "IndexType" ; |
| 102 | using PyConcreteType::PyConcreteType; |
| 103 | |
| 104 | static void bindDerived(ClassTy &c) { |
| 105 | c.def_static( |
| 106 | "get" , |
| 107 | [](DefaultingPyMlirContext context) { |
| 108 | MlirType t = mlirIndexTypeGet(context->get()); |
| 109 | return PyIndexType(context->getRef(), t); |
| 110 | }, |
| 111 | nb::arg("context" ).none() = nb::none(), "Create a index type." ); |
| 112 | } |
| 113 | }; |
| 114 | |
| 115 | class PyFloatType : public PyConcreteType<PyFloatType> { |
| 116 | public: |
| 117 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat; |
| 118 | static constexpr const char *pyClassName = "FloatType" ; |
| 119 | using PyConcreteType::PyConcreteType; |
| 120 | |
| 121 | static void bindDerived(ClassTy &c) { |
| 122 | c.def_prop_ro( |
| 123 | "width" , [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); }, |
| 124 | "Returns the width of the floating-point type" ); |
| 125 | } |
| 126 | }; |
| 127 | |
| 128 | /// Floating Point Type subclass - Float4E2M1FNType. |
| 129 | class PyFloat4E2M1FNType |
| 130 | : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> { |
| 131 | public: |
| 132 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN; |
| 133 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 134 | mlirFloat4E2M1FNTypeGetTypeID; |
| 135 | static constexpr const char *pyClassName = "Float4E2M1FNType" ; |
| 136 | using PyConcreteType::PyConcreteType; |
| 137 | |
| 138 | static void bindDerived(ClassTy &c) { |
| 139 | c.def_static( |
| 140 | "get" , |
| 141 | [](DefaultingPyMlirContext context) { |
| 142 | MlirType t = mlirFloat4E2M1FNTypeGet(context->get()); |
| 143 | return PyFloat4E2M1FNType(context->getRef(), t); |
| 144 | }, |
| 145 | nb::arg("context" ).none() = nb::none(), "Create a float4_e2m1fn type." ); |
| 146 | } |
| 147 | }; |
| 148 | |
| 149 | /// Floating Point Type subclass - Float6E2M3FNType. |
| 150 | class PyFloat6E2M3FNType |
| 151 | : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> { |
| 152 | public: |
| 153 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN; |
| 154 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 155 | mlirFloat6E2M3FNTypeGetTypeID; |
| 156 | static constexpr const char *pyClassName = "Float6E2M3FNType" ; |
| 157 | using PyConcreteType::PyConcreteType; |
| 158 | |
| 159 | static void bindDerived(ClassTy &c) { |
| 160 | c.def_static( |
| 161 | "get" , |
| 162 | [](DefaultingPyMlirContext context) { |
| 163 | MlirType t = mlirFloat6E2M3FNTypeGet(context->get()); |
| 164 | return PyFloat6E2M3FNType(context->getRef(), t); |
| 165 | }, |
| 166 | nb::arg("context" ).none() = nb::none(), "Create a float6_e2m3fn type." ); |
| 167 | } |
| 168 | }; |
| 169 | |
| 170 | /// Floating Point Type subclass - Float6E3M2FNType. |
| 171 | class PyFloat6E3M2FNType |
| 172 | : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> { |
| 173 | public: |
| 174 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN; |
| 175 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 176 | mlirFloat6E3M2FNTypeGetTypeID; |
| 177 | static constexpr const char *pyClassName = "Float6E3M2FNType" ; |
| 178 | using PyConcreteType::PyConcreteType; |
| 179 | |
| 180 | static void bindDerived(ClassTy &c) { |
| 181 | c.def_static( |
| 182 | "get" , |
| 183 | [](DefaultingPyMlirContext context) { |
| 184 | MlirType t = mlirFloat6E3M2FNTypeGet(context->get()); |
| 185 | return PyFloat6E3M2FNType(context->getRef(), t); |
| 186 | }, |
| 187 | nb::arg("context" ).none() = nb::none(), "Create a float6_e3m2fn type." ); |
| 188 | } |
| 189 | }; |
| 190 | |
| 191 | /// Floating Point Type subclass - Float8E4M3FNType. |
| 192 | class PyFloat8E4M3FNType |
| 193 | : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> { |
| 194 | public: |
| 195 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; |
| 196 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 197 | mlirFloat8E4M3FNTypeGetTypeID; |
| 198 | static constexpr const char *pyClassName = "Float8E4M3FNType" ; |
| 199 | using PyConcreteType::PyConcreteType; |
| 200 | |
| 201 | static void bindDerived(ClassTy &c) { |
| 202 | c.def_static( |
| 203 | "get" , |
| 204 | [](DefaultingPyMlirContext context) { |
| 205 | MlirType t = mlirFloat8E4M3FNTypeGet(context->get()); |
| 206 | return PyFloat8E4M3FNType(context->getRef(), t); |
| 207 | }, |
| 208 | nb::arg("context" ).none() = nb::none(), "Create a float8_e4m3fn type." ); |
| 209 | } |
| 210 | }; |
| 211 | |
| 212 | /// Floating Point Type subclass - Float8E5M2Type. |
| 213 | class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> { |
| 214 | public: |
| 215 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; |
| 216 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 217 | mlirFloat8E5M2TypeGetTypeID; |
| 218 | static constexpr const char *pyClassName = "Float8E5M2Type" ; |
| 219 | using PyConcreteType::PyConcreteType; |
| 220 | |
| 221 | static void bindDerived(ClassTy &c) { |
| 222 | c.def_static( |
| 223 | "get" , |
| 224 | [](DefaultingPyMlirContext context) { |
| 225 | MlirType t = mlirFloat8E5M2TypeGet(context->get()); |
| 226 | return PyFloat8E5M2Type(context->getRef(), t); |
| 227 | }, |
| 228 | nb::arg("context" ).none() = nb::none(), "Create a float8_e5m2 type." ); |
| 229 | } |
| 230 | }; |
| 231 | |
| 232 | /// Floating Point Type subclass - Float8E4M3Type. |
| 233 | class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> { |
| 234 | public: |
| 235 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3; |
| 236 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 237 | mlirFloat8E4M3TypeGetTypeID; |
| 238 | static constexpr const char *pyClassName = "Float8E4M3Type" ; |
| 239 | using PyConcreteType::PyConcreteType; |
| 240 | |
| 241 | static void bindDerived(ClassTy &c) { |
| 242 | c.def_static( |
| 243 | "get" , |
| 244 | [](DefaultingPyMlirContext context) { |
| 245 | MlirType t = mlirFloat8E4M3TypeGet(context->get()); |
| 246 | return PyFloat8E4M3Type(context->getRef(), t); |
| 247 | }, |
| 248 | nb::arg("context" ).none() = nb::none(), "Create a float8_e4m3 type." ); |
| 249 | } |
| 250 | }; |
| 251 | |
| 252 | /// Floating Point Type subclass - Float8E4M3FNUZ. |
| 253 | class PyFloat8E4M3FNUZType |
| 254 | : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> { |
| 255 | public: |
| 256 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; |
| 257 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 258 | mlirFloat8E4M3FNUZTypeGetTypeID; |
| 259 | static constexpr const char *pyClassName = "Float8E4M3FNUZType" ; |
| 260 | using PyConcreteType::PyConcreteType; |
| 261 | |
| 262 | static void bindDerived(ClassTy &c) { |
| 263 | c.def_static( |
| 264 | "get" , |
| 265 | [](DefaultingPyMlirContext context) { |
| 266 | MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); |
| 267 | return PyFloat8E4M3FNUZType(context->getRef(), t); |
| 268 | }, |
| 269 | nb::arg("context" ).none() = nb::none(), |
| 270 | "Create a float8_e4m3fnuz type." ); |
| 271 | } |
| 272 | }; |
| 273 | |
| 274 | /// Floating Point Type subclass - Float8E4M3B11FNUZ. |
| 275 | class PyFloat8E4M3B11FNUZType |
| 276 | : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> { |
| 277 | public: |
| 278 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; |
| 279 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 280 | mlirFloat8E4M3B11FNUZTypeGetTypeID; |
| 281 | static constexpr const char *pyClassName = "Float8E4M3B11FNUZType" ; |
| 282 | using PyConcreteType::PyConcreteType; |
| 283 | |
| 284 | static void bindDerived(ClassTy &c) { |
| 285 | c.def_static( |
| 286 | "get" , |
| 287 | [](DefaultingPyMlirContext context) { |
| 288 | MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); |
| 289 | return PyFloat8E4M3B11FNUZType(context->getRef(), t); |
| 290 | }, |
| 291 | nb::arg("context" ).none() = nb::none(), |
| 292 | "Create a float8_e4m3b11fnuz type." ); |
| 293 | } |
| 294 | }; |
| 295 | |
| 296 | /// Floating Point Type subclass - Float8E5M2FNUZ. |
| 297 | class PyFloat8E5M2FNUZType |
| 298 | : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> { |
| 299 | public: |
| 300 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; |
| 301 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 302 | mlirFloat8E5M2FNUZTypeGetTypeID; |
| 303 | static constexpr const char *pyClassName = "Float8E5M2FNUZType" ; |
| 304 | using PyConcreteType::PyConcreteType; |
| 305 | |
| 306 | static void bindDerived(ClassTy &c) { |
| 307 | c.def_static( |
| 308 | "get" , |
| 309 | [](DefaultingPyMlirContext context) { |
| 310 | MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); |
| 311 | return PyFloat8E5M2FNUZType(context->getRef(), t); |
| 312 | }, |
| 313 | nb::arg("context" ).none() = nb::none(), |
| 314 | "Create a float8_e5m2fnuz type." ); |
| 315 | } |
| 316 | }; |
| 317 | |
| 318 | /// Floating Point Type subclass - Float8E3M4Type. |
| 319 | class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> { |
| 320 | public: |
| 321 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4; |
| 322 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 323 | mlirFloat8E3M4TypeGetTypeID; |
| 324 | static constexpr const char *pyClassName = "Float8E3M4Type" ; |
| 325 | using PyConcreteType::PyConcreteType; |
| 326 | |
| 327 | static void bindDerived(ClassTy &c) { |
| 328 | c.def_static( |
| 329 | "get" , |
| 330 | [](DefaultingPyMlirContext context) { |
| 331 | MlirType t = mlirFloat8E3M4TypeGet(context->get()); |
| 332 | return PyFloat8E3M4Type(context->getRef(), t); |
| 333 | }, |
| 334 | nb::arg("context" ).none() = nb::none(), "Create a float8_e3m4 type." ); |
| 335 | } |
| 336 | }; |
| 337 | |
| 338 | /// Floating Point Type subclass - Float8E8M0FNUType. |
| 339 | class PyFloat8E8M0FNUType |
| 340 | : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> { |
| 341 | public: |
| 342 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU; |
| 343 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 344 | mlirFloat8E8M0FNUTypeGetTypeID; |
| 345 | static constexpr const char *pyClassName = "Float8E8M0FNUType" ; |
| 346 | using PyConcreteType::PyConcreteType; |
| 347 | |
| 348 | static void bindDerived(ClassTy &c) { |
| 349 | c.def_static( |
| 350 | "get" , |
| 351 | [](DefaultingPyMlirContext context) { |
| 352 | MlirType t = mlirFloat8E8M0FNUTypeGet(context->get()); |
| 353 | return PyFloat8E8M0FNUType(context->getRef(), t); |
| 354 | }, |
| 355 | nb::arg("context" ).none() = nb::none(), |
| 356 | "Create a float8_e8m0fnu type." ); |
| 357 | } |
| 358 | }; |
| 359 | |
| 360 | /// Floating Point Type subclass - BF16Type. |
| 361 | class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> { |
| 362 | public: |
| 363 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; |
| 364 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 365 | mlirBFloat16TypeGetTypeID; |
| 366 | static constexpr const char *pyClassName = "BF16Type" ; |
| 367 | using PyConcreteType::PyConcreteType; |
| 368 | |
| 369 | static void bindDerived(ClassTy &c) { |
| 370 | c.def_static( |
| 371 | "get" , |
| 372 | [](DefaultingPyMlirContext context) { |
| 373 | MlirType t = mlirBF16TypeGet(context->get()); |
| 374 | return PyBF16Type(context->getRef(), t); |
| 375 | }, |
| 376 | nb::arg("context" ).none() = nb::none(), "Create a bf16 type." ); |
| 377 | } |
| 378 | }; |
| 379 | |
| 380 | /// Floating Point Type subclass - F16Type. |
| 381 | class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> { |
| 382 | public: |
| 383 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; |
| 384 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 385 | mlirFloat16TypeGetTypeID; |
| 386 | static constexpr const char *pyClassName = "F16Type" ; |
| 387 | using PyConcreteType::PyConcreteType; |
| 388 | |
| 389 | static void bindDerived(ClassTy &c) { |
| 390 | c.def_static( |
| 391 | "get" , |
| 392 | [](DefaultingPyMlirContext context) { |
| 393 | MlirType t = mlirF16TypeGet(context->get()); |
| 394 | return PyF16Type(context->getRef(), t); |
| 395 | }, |
| 396 | nb::arg("context" ).none() = nb::none(), "Create a f16 type." ); |
| 397 | } |
| 398 | }; |
| 399 | |
| 400 | /// Floating Point Type subclass - TF32Type. |
| 401 | class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> { |
| 402 | public: |
| 403 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32; |
| 404 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 405 | mlirFloatTF32TypeGetTypeID; |
| 406 | static constexpr const char *pyClassName = "FloatTF32Type" ; |
| 407 | using PyConcreteType::PyConcreteType; |
| 408 | |
| 409 | static void bindDerived(ClassTy &c) { |
| 410 | c.def_static( |
| 411 | "get" , |
| 412 | [](DefaultingPyMlirContext context) { |
| 413 | MlirType t = mlirTF32TypeGet(context->get()); |
| 414 | return PyTF32Type(context->getRef(), t); |
| 415 | }, |
| 416 | nb::arg("context" ).none() = nb::none(), "Create a tf32 type." ); |
| 417 | } |
| 418 | }; |
| 419 | |
| 420 | /// Floating Point Type subclass - F32Type. |
| 421 | class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> { |
| 422 | public: |
| 423 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; |
| 424 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 425 | mlirFloat32TypeGetTypeID; |
| 426 | static constexpr const char *pyClassName = "F32Type" ; |
| 427 | using PyConcreteType::PyConcreteType; |
| 428 | |
| 429 | static void bindDerived(ClassTy &c) { |
| 430 | c.def_static( |
| 431 | "get" , |
| 432 | [](DefaultingPyMlirContext context) { |
| 433 | MlirType t = mlirF32TypeGet(context->get()); |
| 434 | return PyF32Type(context->getRef(), t); |
| 435 | }, |
| 436 | nb::arg("context" ).none() = nb::none(), "Create a f32 type." ); |
| 437 | } |
| 438 | }; |
| 439 | |
| 440 | /// Floating Point Type subclass - F64Type. |
| 441 | class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> { |
| 442 | public: |
| 443 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; |
| 444 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 445 | mlirFloat64TypeGetTypeID; |
| 446 | static constexpr const char *pyClassName = "F64Type" ; |
| 447 | using PyConcreteType::PyConcreteType; |
| 448 | |
| 449 | static void bindDerived(ClassTy &c) { |
| 450 | c.def_static( |
| 451 | "get" , |
| 452 | [](DefaultingPyMlirContext context) { |
| 453 | MlirType t = mlirF64TypeGet(context->get()); |
| 454 | return PyF64Type(context->getRef(), t); |
| 455 | }, |
| 456 | nb::arg("context" ).none() = nb::none(), "Create a f64 type." ); |
| 457 | } |
| 458 | }; |
| 459 | |
| 460 | /// None Type subclass - NoneType. |
| 461 | class PyNoneType : public PyConcreteType<PyNoneType> { |
| 462 | public: |
| 463 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; |
| 464 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 465 | mlirNoneTypeGetTypeID; |
| 466 | static constexpr const char *pyClassName = "NoneType" ; |
| 467 | using PyConcreteType::PyConcreteType; |
| 468 | |
| 469 | static void bindDerived(ClassTy &c) { |
| 470 | c.def_static( |
| 471 | "get" , |
| 472 | [](DefaultingPyMlirContext context) { |
| 473 | MlirType t = mlirNoneTypeGet(context->get()); |
| 474 | return PyNoneType(context->getRef(), t); |
| 475 | }, |
| 476 | nb::arg("context" ).none() = nb::none(), "Create a none type." ); |
| 477 | } |
| 478 | }; |
| 479 | |
| 480 | /// Complex Type subclass - ComplexType. |
| 481 | class PyComplexType : public PyConcreteType<PyComplexType> { |
| 482 | public: |
| 483 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; |
| 484 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 485 | mlirComplexTypeGetTypeID; |
| 486 | static constexpr const char *pyClassName = "ComplexType" ; |
| 487 | using PyConcreteType::PyConcreteType; |
| 488 | |
| 489 | static void bindDerived(ClassTy &c) { |
| 490 | c.def_static( |
| 491 | "get" , |
| 492 | [](PyType &elementType) { |
| 493 | // The element must be a floating point or integer scalar type. |
| 494 | if (mlirTypeIsAIntegerOrFloat(elementType)) { |
| 495 | MlirType t = mlirComplexTypeGet(elementType); |
| 496 | return PyComplexType(elementType.getContext(), t); |
| 497 | } |
| 498 | throw nb::value_error( |
| 499 | (Twine("invalid '" ) + |
| 500 | nb::cast<std::string>(nb::repr(nb::cast(elementType))) + |
| 501 | "' and expected floating point or integer type." ) |
| 502 | .str() |
| 503 | .c_str()); |
| 504 | }, |
| 505 | "Create a complex type" ); |
| 506 | c.def_prop_ro( |
| 507 | "element_type" , |
| 508 | [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); }, |
| 509 | "Returns element type." ); |
| 510 | } |
| 511 | }; |
| 512 | |
| 513 | } // namespace |
| 514 | |
| 515 | // Shaped Type Interface - ShapedType |
| 516 | void mlir::PyShapedType::bindDerived(ClassTy &c) { |
| 517 | c.def_prop_ro( |
| 518 | "element_type" , |
| 519 | [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); }, |
| 520 | "Returns the element type of the shaped type." ); |
| 521 | c.def_prop_ro( |
| 522 | "has_rank" , |
| 523 | [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, |
| 524 | "Returns whether the given shaped type is ranked." ); |
| 525 | c.def_prop_ro( |
| 526 | "rank" , |
| 527 | [](PyShapedType &self) { |
| 528 | self.requireHasRank(); |
| 529 | return mlirShapedTypeGetRank(self); |
| 530 | }, |
| 531 | "Returns the rank of the given ranked shaped type." ); |
| 532 | c.def_prop_ro( |
| 533 | "has_static_shape" , |
| 534 | [](PyShapedType &self) -> bool { |
| 535 | return mlirShapedTypeHasStaticShape(self); |
| 536 | }, |
| 537 | "Returns whether the given shaped type has a static shape." ); |
| 538 | c.def( |
| 539 | "is_dynamic_dim" , |
| 540 | [](PyShapedType &self, intptr_t dim) -> bool { |
| 541 | self.requireHasRank(); |
| 542 | return mlirShapedTypeIsDynamicDim(self, dim); |
| 543 | }, |
| 544 | nb::arg("dim" ), |
| 545 | "Returns whether the dim-th dimension of the given shaped type is " |
| 546 | "dynamic." ); |
| 547 | c.def( |
| 548 | "get_dim_size" , |
| 549 | [](PyShapedType &self, intptr_t dim) { |
| 550 | self.requireHasRank(); |
| 551 | return mlirShapedTypeGetDimSize(self, dim); |
| 552 | }, |
| 553 | nb::arg("dim" ), |
| 554 | "Returns the dim-th dimension of the given ranked shaped type." ); |
| 555 | c.def_static( |
| 556 | "is_dynamic_size" , |
| 557 | [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, |
| 558 | nb::arg("dim_size" ), |
| 559 | "Returns whether the given dimension size indicates a dynamic " |
| 560 | "dimension." ); |
| 561 | c.def( |
| 562 | "is_dynamic_stride_or_offset" , |
| 563 | [](PyShapedType &self, int64_t val) -> bool { |
| 564 | self.requireHasRank(); |
| 565 | return mlirShapedTypeIsDynamicStrideOrOffset(val); |
| 566 | }, |
| 567 | nb::arg("dim_size" ), |
| 568 | "Returns whether the given value is used as a placeholder for dynamic " |
| 569 | "strides and offsets in shaped types." ); |
| 570 | c.def_prop_ro( |
| 571 | "shape" , |
| 572 | [](PyShapedType &self) { |
| 573 | self.requireHasRank(); |
| 574 | |
| 575 | std::vector<int64_t> shape; |
| 576 | int64_t rank = mlirShapedTypeGetRank(self); |
| 577 | shape.reserve(rank); |
| 578 | for (int64_t i = 0; i < rank; ++i) |
| 579 | shape.push_back(mlirShapedTypeGetDimSize(self, i)); |
| 580 | return shape; |
| 581 | }, |
| 582 | "Returns the shape of the ranked shaped type as a list of integers." ); |
| 583 | c.def_static( |
| 584 | "get_dynamic_size" , []() { return mlirShapedTypeGetDynamicSize(); }, |
| 585 | "Returns the value used to indicate dynamic dimensions in shaped " |
| 586 | "types." ); |
| 587 | c.def_static( |
| 588 | "get_dynamic_stride_or_offset" , |
| 589 | []() { return mlirShapedTypeGetDynamicStrideOrOffset(); }, |
| 590 | "Returns the value used to indicate dynamic strides or offsets in " |
| 591 | "shaped types." ); |
| 592 | } |
| 593 | |
| 594 | void mlir::PyShapedType::requireHasRank() { |
| 595 | if (!mlirShapedTypeHasRank(*this)) { |
| 596 | throw nb::value_error( |
| 597 | "calling this method requires that the type has a rank." ); |
| 598 | } |
| 599 | } |
| 600 | |
| 601 | const mlir::PyShapedType::IsAFunctionTy mlir::PyShapedType::isaFunction = |
| 602 | mlirTypeIsAShaped; |
| 603 | |
| 604 | namespace { |
| 605 | |
| 606 | /// Vector Type subclass - VectorType. |
| 607 | class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> { |
| 608 | public: |
| 609 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; |
| 610 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 611 | mlirVectorTypeGetTypeID; |
| 612 | static constexpr const char *pyClassName = "VectorType" ; |
| 613 | using PyConcreteType::PyConcreteType; |
| 614 | |
| 615 | static void bindDerived(ClassTy &c) { |
| 616 | c.def_static("get" , &PyVectorType::get, nb::arg("shape" ), |
| 617 | nb::arg("element_type" ), nb::kw_only(), |
| 618 | nb::arg("scalable" ).none() = nb::none(), |
| 619 | nb::arg("scalable_dims" ).none() = nb::none(), |
| 620 | nb::arg("loc" ).none() = nb::none(), "Create a vector type" ) |
| 621 | .def_prop_ro( |
| 622 | "scalable" , |
| 623 | [](MlirType self) { return mlirVectorTypeIsScalable(self); }) |
| 624 | .def_prop_ro("scalable_dims" , [](MlirType self) { |
| 625 | std::vector<bool> scalableDims; |
| 626 | size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self)); |
| 627 | scalableDims.reserve(rank); |
| 628 | for (size_t i = 0; i < rank; ++i) |
| 629 | scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i)); |
| 630 | return scalableDims; |
| 631 | }); |
| 632 | } |
| 633 | |
| 634 | private: |
| 635 | static PyVectorType get(std::vector<int64_t> shape, PyType &elementType, |
| 636 | std::optional<nb::list> scalable, |
| 637 | std::optional<std::vector<int64_t>> scalableDims, |
| 638 | DefaultingPyLocation loc) { |
| 639 | if (scalable && scalableDims) { |
| 640 | throw nb::value_error("'scalable' and 'scalable_dims' kwargs " |
| 641 | "are mutually exclusive." ); |
| 642 | } |
| 643 | |
| 644 | PyMlirContext::ErrorCapture errors(loc->getContext()); |
| 645 | MlirType type; |
| 646 | if (scalable) { |
| 647 | if (scalable->size() != shape.size()) |
| 648 | throw nb::value_error("Expected len(scalable) == len(shape)." ); |
| 649 | |
| 650 | SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range( |
| 651 | *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); })); |
| 652 | type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), |
| 653 | scalableDimFlags.data(), |
| 654 | elementType); |
| 655 | } else if (scalableDims) { |
| 656 | SmallVector<bool> scalableDimFlags(shape.size(), false); |
| 657 | for (int64_t dim : *scalableDims) { |
| 658 | if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0) |
| 659 | throw nb::value_error("Scalable dimension index out of bounds." ); |
| 660 | scalableDimFlags[dim] = true; |
| 661 | } |
| 662 | type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), |
| 663 | scalableDimFlags.data(), |
| 664 | elementType); |
| 665 | } else { |
| 666 | type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), |
| 667 | elementType); |
| 668 | } |
| 669 | if (mlirTypeIsNull(type)) |
| 670 | throw MLIRError("Invalid type" , errors.take()); |
| 671 | return PyVectorType(elementType.getContext(), type); |
| 672 | } |
| 673 | }; |
| 674 | |
| 675 | /// Ranked Tensor Type subclass - RankedTensorType. |
| 676 | class PyRankedTensorType |
| 677 | : public PyConcreteType<PyRankedTensorType, PyShapedType> { |
| 678 | public: |
| 679 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; |
| 680 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 681 | mlirRankedTensorTypeGetTypeID; |
| 682 | static constexpr const char *pyClassName = "RankedTensorType" ; |
| 683 | using PyConcreteType::PyConcreteType; |
| 684 | |
| 685 | static void bindDerived(ClassTy &c) { |
| 686 | c.def_static( |
| 687 | "get" , |
| 688 | [](std::vector<int64_t> shape, PyType &elementType, |
| 689 | std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) { |
| 690 | PyMlirContext::ErrorCapture errors(loc->getContext()); |
| 691 | MlirType t = mlirRankedTensorTypeGetChecked( |
| 692 | loc, shape.size(), shape.data(), elementType, |
| 693 | encodingAttr ? encodingAttr->get() : mlirAttributeGetNull()); |
| 694 | if (mlirTypeIsNull(t)) |
| 695 | throw MLIRError("Invalid type" , errors.take()); |
| 696 | return PyRankedTensorType(elementType.getContext(), t); |
| 697 | }, |
| 698 | nb::arg("shape" ), nb::arg("element_type" ), |
| 699 | nb::arg("encoding" ).none() = nb::none(), |
| 700 | nb::arg("loc" ).none() = nb::none(), "Create a ranked tensor type" ); |
| 701 | c.def_prop_ro("encoding" , |
| 702 | [](PyRankedTensorType &self) -> std::optional<MlirAttribute> { |
| 703 | MlirAttribute encoding = |
| 704 | mlirRankedTensorTypeGetEncoding(self.get()); |
| 705 | if (mlirAttributeIsNull(encoding)) |
| 706 | return std::nullopt; |
| 707 | return encoding; |
| 708 | }); |
| 709 | } |
| 710 | }; |
| 711 | |
| 712 | /// Unranked Tensor Type subclass - UnrankedTensorType. |
| 713 | class PyUnrankedTensorType |
| 714 | : public PyConcreteType<PyUnrankedTensorType, PyShapedType> { |
| 715 | public: |
| 716 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; |
| 717 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 718 | mlirUnrankedTensorTypeGetTypeID; |
| 719 | static constexpr const char *pyClassName = "UnrankedTensorType" ; |
| 720 | using PyConcreteType::PyConcreteType; |
| 721 | |
| 722 | static void bindDerived(ClassTy &c) { |
| 723 | c.def_static( |
| 724 | "get" , |
| 725 | [](PyType &elementType, DefaultingPyLocation loc) { |
| 726 | PyMlirContext::ErrorCapture errors(loc->getContext()); |
| 727 | MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); |
| 728 | if (mlirTypeIsNull(t)) |
| 729 | throw MLIRError("Invalid type" , errors.take()); |
| 730 | return PyUnrankedTensorType(elementType.getContext(), t); |
| 731 | }, |
| 732 | nb::arg("element_type" ), nb::arg("loc" ).none() = nb::none(), |
| 733 | "Create a unranked tensor type" ); |
| 734 | } |
| 735 | }; |
| 736 | |
| 737 | /// Ranked MemRef Type subclass - MemRefType. |
| 738 | class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> { |
| 739 | public: |
| 740 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; |
| 741 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 742 | mlirMemRefTypeGetTypeID; |
| 743 | static constexpr const char *pyClassName = "MemRefType" ; |
| 744 | using PyConcreteType::PyConcreteType; |
| 745 | |
| 746 | static void bindDerived(ClassTy &c) { |
| 747 | c.def_static( |
| 748 | "get" , |
| 749 | [](std::vector<int64_t> shape, PyType &elementType, |
| 750 | PyAttribute *layout, PyAttribute *memorySpace, |
| 751 | DefaultingPyLocation loc) { |
| 752 | PyMlirContext::ErrorCapture errors(loc->getContext()); |
| 753 | MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull(); |
| 754 | MlirAttribute memSpaceAttr = |
| 755 | memorySpace ? *memorySpace : mlirAttributeGetNull(); |
| 756 | MlirType t = |
| 757 | mlirMemRefTypeGetChecked(loc, elementType, shape.size(), |
| 758 | shape.data(), layoutAttr, memSpaceAttr); |
| 759 | if (mlirTypeIsNull(t)) |
| 760 | throw MLIRError("Invalid type" , errors.take()); |
| 761 | return PyMemRefType(elementType.getContext(), t); |
| 762 | }, |
| 763 | nb::arg("shape" ), nb::arg("element_type" ), |
| 764 | nb::arg("layout" ).none() = nb::none(), |
| 765 | nb::arg("memory_space" ).none() = nb::none(), |
| 766 | nb::arg("loc" ).none() = nb::none(), "Create a memref type" ) |
| 767 | .def_prop_ro( |
| 768 | "layout" , |
| 769 | [](PyMemRefType &self) -> MlirAttribute { |
| 770 | return mlirMemRefTypeGetLayout(self); |
| 771 | }, |
| 772 | "The layout of the MemRef type." ) |
| 773 | .def( |
| 774 | "get_strides_and_offset" , |
| 775 | [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> { |
| 776 | std::vector<int64_t> strides(mlirShapedTypeGetRank(self)); |
| 777 | int64_t offset; |
| 778 | if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset( |
| 779 | self, strides.data(), &offset))) |
| 780 | throw std::runtime_error( |
| 781 | "Failed to extract strides and offset from memref." ); |
| 782 | return {strides, offset}; |
| 783 | }, |
| 784 | "The strides and offset of the MemRef type." ) |
| 785 | .def_prop_ro( |
| 786 | "affine_map" , |
| 787 | [](PyMemRefType &self) -> PyAffineMap { |
| 788 | MlirAffineMap map = mlirMemRefTypeGetAffineMap(self); |
| 789 | return PyAffineMap(self.getContext(), map); |
| 790 | }, |
| 791 | "The layout of the MemRef type as an affine map." ) |
| 792 | .def_prop_ro( |
| 793 | "memory_space" , |
| 794 | [](PyMemRefType &self) -> std::optional<MlirAttribute> { |
| 795 | MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); |
| 796 | if (mlirAttributeIsNull(a)) |
| 797 | return std::nullopt; |
| 798 | return a; |
| 799 | }, |
| 800 | "Returns the memory space of the given MemRef type." ); |
| 801 | } |
| 802 | }; |
| 803 | |
| 804 | /// Unranked MemRef Type subclass - UnrankedMemRefType. |
| 805 | class PyUnrankedMemRefType |
| 806 | : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> { |
| 807 | public: |
| 808 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; |
| 809 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 810 | mlirUnrankedMemRefTypeGetTypeID; |
| 811 | static constexpr const char *pyClassName = "UnrankedMemRefType" ; |
| 812 | using PyConcreteType::PyConcreteType; |
| 813 | |
| 814 | static void bindDerived(ClassTy &c) { |
| 815 | c.def_static( |
| 816 | "get" , |
| 817 | [](PyType &elementType, PyAttribute *memorySpace, |
| 818 | DefaultingPyLocation loc) { |
| 819 | PyMlirContext::ErrorCapture errors(loc->getContext()); |
| 820 | MlirAttribute memSpaceAttr = {}; |
| 821 | if (memorySpace) |
| 822 | memSpaceAttr = *memorySpace; |
| 823 | |
| 824 | MlirType t = |
| 825 | mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); |
| 826 | if (mlirTypeIsNull(t)) |
| 827 | throw MLIRError("Invalid type" , errors.take()); |
| 828 | return PyUnrankedMemRefType(elementType.getContext(), t); |
| 829 | }, |
| 830 | nb::arg("element_type" ), nb::arg("memory_space" ).none(), |
| 831 | nb::arg("loc" ).none() = nb::none(), "Create a unranked memref type" ) |
| 832 | .def_prop_ro( |
| 833 | "memory_space" , |
| 834 | [](PyUnrankedMemRefType &self) -> std::optional<MlirAttribute> { |
| 835 | MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self); |
| 836 | if (mlirAttributeIsNull(a)) |
| 837 | return std::nullopt; |
| 838 | return a; |
| 839 | }, |
| 840 | "Returns the memory space of the given Unranked MemRef type." ); |
| 841 | } |
| 842 | }; |
| 843 | |
| 844 | /// Tuple Type subclass - TupleType. |
| 845 | class PyTupleType : public PyConcreteType<PyTupleType> { |
| 846 | public: |
| 847 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; |
| 848 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 849 | mlirTupleTypeGetTypeID; |
| 850 | static constexpr const char *pyClassName = "TupleType" ; |
| 851 | using PyConcreteType::PyConcreteType; |
| 852 | |
| 853 | static void bindDerived(ClassTy &c) { |
| 854 | c.def_static( |
| 855 | "get_tuple" , |
| 856 | [](std::vector<MlirType> elements, DefaultingPyMlirContext context) { |
| 857 | MlirType t = mlirTupleTypeGet(context->get(), elements.size(), |
| 858 | elements.data()); |
| 859 | return PyTupleType(context->getRef(), t); |
| 860 | }, |
| 861 | nb::arg("elements" ), nb::arg("context" ).none() = nb::none(), |
| 862 | "Create a tuple type" ); |
| 863 | c.def( |
| 864 | "get_type" , |
| 865 | [](PyTupleType &self, intptr_t pos) { |
| 866 | return mlirTupleTypeGetType(self, pos); |
| 867 | }, |
| 868 | nb::arg("pos" ), "Returns the pos-th type in the tuple type." ); |
| 869 | c.def_prop_ro( |
| 870 | "num_types" , |
| 871 | [](PyTupleType &self) -> intptr_t { |
| 872 | return mlirTupleTypeGetNumTypes(self); |
| 873 | }, |
| 874 | "Returns the number of types contained in a tuple." ); |
| 875 | } |
| 876 | }; |
| 877 | |
| 878 | /// Function type. |
| 879 | class PyFunctionType : public PyConcreteType<PyFunctionType> { |
| 880 | public: |
| 881 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; |
| 882 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 883 | mlirFunctionTypeGetTypeID; |
| 884 | static constexpr const char *pyClassName = "FunctionType" ; |
| 885 | using PyConcreteType::PyConcreteType; |
| 886 | |
| 887 | static void bindDerived(ClassTy &c) { |
| 888 | c.def_static( |
| 889 | "get" , |
| 890 | [](std::vector<MlirType> inputs, std::vector<MlirType> results, |
| 891 | DefaultingPyMlirContext context) { |
| 892 | MlirType t = |
| 893 | mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(), |
| 894 | results.size(), results.data()); |
| 895 | return PyFunctionType(context->getRef(), t); |
| 896 | }, |
| 897 | nb::arg("inputs" ), nb::arg("results" ), |
| 898 | nb::arg("context" ).none() = nb::none(), |
| 899 | "Gets a FunctionType from a list of input and result types" ); |
| 900 | c.def_prop_ro( |
| 901 | "inputs" , |
| 902 | [](PyFunctionType &self) { |
| 903 | MlirType t = self; |
| 904 | nb::list types; |
| 905 | for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; |
| 906 | ++i) { |
| 907 | types.append(mlirFunctionTypeGetInput(t, i)); |
| 908 | } |
| 909 | return types; |
| 910 | }, |
| 911 | "Returns the list of input types in the FunctionType." ); |
| 912 | c.def_prop_ro( |
| 913 | "results" , |
| 914 | [](PyFunctionType &self) { |
| 915 | nb::list types; |
| 916 | for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; |
| 917 | ++i) { |
| 918 | types.append(mlirFunctionTypeGetResult(self, i)); |
| 919 | } |
| 920 | return types; |
| 921 | }, |
| 922 | "Returns the list of result types in the FunctionType." ); |
| 923 | } |
| 924 | }; |
| 925 | |
| 926 | static MlirStringRef toMlirStringRef(const std::string &s) { |
| 927 | return mlirStringRefCreate(s.data(), s.size()); |
| 928 | } |
| 929 | |
| 930 | /// Opaque Type subclass - OpaqueType. |
| 931 | class PyOpaqueType : public PyConcreteType<PyOpaqueType> { |
| 932 | public: |
| 933 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque; |
| 934 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 935 | mlirOpaqueTypeGetTypeID; |
| 936 | static constexpr const char *pyClassName = "OpaqueType" ; |
| 937 | using PyConcreteType::PyConcreteType; |
| 938 | |
| 939 | static void bindDerived(ClassTy &c) { |
| 940 | c.def_static( |
| 941 | "get" , |
| 942 | [](std::string dialectNamespace, std::string typeData, |
| 943 | DefaultingPyMlirContext context) { |
| 944 | MlirType type = mlirOpaqueTypeGet(context->get(), |
| 945 | toMlirStringRef(dialectNamespace), |
| 946 | toMlirStringRef(typeData)); |
| 947 | return PyOpaqueType(context->getRef(), type); |
| 948 | }, |
| 949 | nb::arg("dialect_namespace" ), nb::arg("buffer" ), |
| 950 | nb::arg("context" ).none() = nb::none(), |
| 951 | "Create an unregistered (opaque) dialect type." ); |
| 952 | c.def_prop_ro( |
| 953 | "dialect_namespace" , |
| 954 | [](PyOpaqueType &self) { |
| 955 | MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self); |
| 956 | return nb::str(stringRef.data, stringRef.length); |
| 957 | }, |
| 958 | "Returns the dialect namespace for the Opaque type as a string." ); |
| 959 | c.def_prop_ro( |
| 960 | "data" , |
| 961 | [](PyOpaqueType &self) { |
| 962 | MlirStringRef stringRef = mlirOpaqueTypeGetData(self); |
| 963 | return nb::str(stringRef.data, stringRef.length); |
| 964 | }, |
| 965 | "Returns the data for the Opaque type as a string." ); |
| 966 | } |
| 967 | }; |
| 968 | |
| 969 | } // namespace |
| 970 | |
/// Registers the Python classes for the builtin MLIR types wrapped in this
/// file on module `m`, by calling each wrapper's static `bind(m)`.
///
/// NOTE(review): the call order looks significant. PyShapedType is bound
/// before the shaped types that follow it; e.g. PyMemRefType and
/// PyUnrankedMemRefType are declared as PyConcreteType<..., PyShapedType>.
/// Presumably the base Python class must be registered before its
/// subclasses, so preserve this ordering when adding new types.
void mlir::python::populateIRTypes(nb::module_ &m) {
  // Integer, index and floating-point element types.
  PyIntegerType::bind(m);
  PyFloatType::bind(m);
  PyIndexType::bind(m);
  PyFloat4E2M1FNType::bind(m);
  PyFloat6E2M3FNType::bind(m);
  PyFloat6E3M2FNType::bind(m);
  PyFloat8E4M3FNType::bind(m);
  PyFloat8E5M2Type::bind(m);
  PyFloat8E4M3Type::bind(m);
  PyFloat8E4M3FNUZType::bind(m);
  PyFloat8E4M3B11FNUZType::bind(m);
  PyFloat8E5M2FNUZType::bind(m);
  PyFloat8E3M4Type::bind(m);
  PyFloat8E8M0FNUType::bind(m);
  PyBF16Type::bind(m);
  PyF16Type::bind(m);
  PyTF32Type::bind(m);
  PyF32Type::bind(m);
  PyF64Type::bind(m);
  PyNoneType::bind(m);
  PyComplexType::bind(m);
  // Shaped base class first, then the shaped container types.
  PyShapedType::bind(m);
  PyVectorType::bind(m);
  PyRankedTensorType::bind(m);
  PyUnrankedTensorType::bind(m);
  PyMemRefType::bind(m);
  PyUnrankedMemRefType::bind(m);
  // Remaining composite and opaque types.
  PyTupleType::bind(m);
  PyFunctionType::bind(m);
  PyOpaqueType::bind(m);
}
| 1003 | |