| 1 | //===- IRTypes.cpp - Exports builtin and standard types -------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | // clang-format off |
| 10 | #include "IRModule.h" |
| 11 | #include "mlir/Bindings/Python/IRTypes.h" |
| 12 | // clang-format on |
| 13 | |
| 14 | #include <optional> |
| 15 | |
| 16 | #include "IRModule.h" |
| 17 | #include "NanobindUtils.h" |
| 18 | #include "mlir-c/BuiltinAttributes.h" |
| 19 | #include "mlir-c/BuiltinTypes.h" |
| 20 | #include "mlir-c/Support.h" |
| 21 | |
| 22 | namespace nb = nanobind; |
| 23 | using namespace mlir; |
| 24 | using namespace mlir::python; |
| 25 | |
| 26 | using llvm::SmallVector; |
| 27 | using llvm::Twine; |
| 28 | |
| 29 | namespace { |
| 30 | |
| 31 | /// Checks whether the given type is an integer or float type. |
| 32 | static int mlirTypeIsAIntegerOrFloat(MlirType type) { |
| 33 | return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) || |
| 34 | mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type); |
| 35 | } |
| 36 | |
| 37 | class PyIntegerType : public PyConcreteType<PyIntegerType> { |
| 38 | public: |
| 39 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; |
| 40 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 41 | mlirIntegerTypeGetTypeID; |
| 42 | static constexpr const char *pyClassName = "IntegerType" ; |
| 43 | using PyConcreteType::PyConcreteType; |
| 44 | |
| 45 | static void bindDerived(ClassTy &c) { |
| 46 | c.def_static( |
| 47 | "get_signless" , |
| 48 | [](unsigned width, DefaultingPyMlirContext context) { |
| 49 | MlirType t = mlirIntegerTypeGet(context->get(), width); |
| 50 | return PyIntegerType(context->getRef(), t); |
| 51 | }, |
| 52 | nb::arg("width" ), nb::arg("context" ).none() = nb::none(), |
| 53 | "Create a signless integer type" ); |
| 54 | c.def_static( |
| 55 | "get_signed" , |
| 56 | [](unsigned width, DefaultingPyMlirContext context) { |
| 57 | MlirType t = mlirIntegerTypeSignedGet(context->get(), width); |
| 58 | return PyIntegerType(context->getRef(), t); |
| 59 | }, |
| 60 | nb::arg("width" ), nb::arg("context" ).none() = nb::none(), |
| 61 | "Create a signed integer type" ); |
| 62 | c.def_static( |
| 63 | "get_unsigned" , |
| 64 | [](unsigned width, DefaultingPyMlirContext context) { |
| 65 | MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); |
| 66 | return PyIntegerType(context->getRef(), t); |
| 67 | }, |
| 68 | nb::arg("width" ), nb::arg("context" ).none() = nb::none(), |
| 69 | "Create an unsigned integer type" ); |
| 70 | c.def_prop_ro( |
| 71 | "width" , |
| 72 | [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, |
| 73 | "Returns the width of the integer type" ); |
| 74 | c.def_prop_ro( |
| 75 | "is_signless" , |
| 76 | [](PyIntegerType &self) -> bool { |
| 77 | return mlirIntegerTypeIsSignless(self); |
| 78 | }, |
| 79 | "Returns whether this is a signless integer" ); |
| 80 | c.def_prop_ro( |
| 81 | "is_signed" , |
| 82 | [](PyIntegerType &self) -> bool { |
| 83 | return mlirIntegerTypeIsSigned(self); |
| 84 | }, |
| 85 | "Returns whether this is a signed integer" ); |
| 86 | c.def_prop_ro( |
| 87 | "is_unsigned" , |
| 88 | [](PyIntegerType &self) -> bool { |
| 89 | return mlirIntegerTypeIsUnsigned(self); |
| 90 | }, |
| 91 | "Returns whether this is an unsigned integer" ); |
| 92 | } |
| 93 | }; |
| 94 | |
| 95 | /// Index Type subclass - IndexType. |
| 96 | class PyIndexType : public PyConcreteType<PyIndexType> { |
| 97 | public: |
| 98 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; |
| 99 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 100 | mlirIndexTypeGetTypeID; |
| 101 | static constexpr const char *pyClassName = "IndexType" ; |
| 102 | using PyConcreteType::PyConcreteType; |
| 103 | |
| 104 | static void bindDerived(ClassTy &c) { |
| 105 | c.def_static( |
| 106 | "get" , |
| 107 | [](DefaultingPyMlirContext context) { |
| 108 | MlirType t = mlirIndexTypeGet(context->get()); |
| 109 | return PyIndexType(context->getRef(), t); |
| 110 | }, |
| 111 | nb::arg("context" ).none() = nb::none(), "Create a index type." ); |
| 112 | } |
| 113 | }; |
| 114 | |
| 115 | class PyFloatType : public PyConcreteType<PyFloatType> { |
| 116 | public: |
| 117 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat; |
| 118 | static constexpr const char *pyClassName = "FloatType" ; |
| 119 | using PyConcreteType::PyConcreteType; |
| 120 | |
| 121 | static void bindDerived(ClassTy &c) { |
| 122 | c.def_prop_ro( |
| 123 | "width" , [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); }, |
| 124 | "Returns the width of the floating-point type" ); |
| 125 | } |
| 126 | }; |
| 127 | |
| 128 | /// Floating Point Type subclass - Float4E2M1FNType. |
| 129 | class PyFloat4E2M1FNType |
| 130 | : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> { |
| 131 | public: |
| 132 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN; |
| 133 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 134 | mlirFloat4E2M1FNTypeGetTypeID; |
| 135 | static constexpr const char *pyClassName = "Float4E2M1FNType" ; |
| 136 | using PyConcreteType::PyConcreteType; |
| 137 | |
| 138 | static void bindDerived(ClassTy &c) { |
| 139 | c.def_static( |
| 140 | "get" , |
| 141 | [](DefaultingPyMlirContext context) { |
| 142 | MlirType t = mlirFloat4E2M1FNTypeGet(context->get()); |
| 143 | return PyFloat4E2M1FNType(context->getRef(), t); |
| 144 | }, |
| 145 | nb::arg("context" ).none() = nb::none(), "Create a float4_e2m1fn type." ); |
| 146 | } |
| 147 | }; |
| 148 | |
| 149 | /// Floating Point Type subclass - Float6E2M3FNType. |
| 150 | class PyFloat6E2M3FNType |
| 151 | : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> { |
| 152 | public: |
| 153 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN; |
| 154 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 155 | mlirFloat6E2M3FNTypeGetTypeID; |
| 156 | static constexpr const char *pyClassName = "Float6E2M3FNType" ; |
| 157 | using PyConcreteType::PyConcreteType; |
| 158 | |
| 159 | static void bindDerived(ClassTy &c) { |
| 160 | c.def_static( |
| 161 | "get" , |
| 162 | [](DefaultingPyMlirContext context) { |
| 163 | MlirType t = mlirFloat6E2M3FNTypeGet(context->get()); |
| 164 | return PyFloat6E2M3FNType(context->getRef(), t); |
| 165 | }, |
| 166 | nb::arg("context" ).none() = nb::none(), "Create a float6_e2m3fn type." ); |
| 167 | } |
| 168 | }; |
| 169 | |
| 170 | /// Floating Point Type subclass - Float6E3M2FNType. |
| 171 | class PyFloat6E3M2FNType |
| 172 | : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> { |
| 173 | public: |
| 174 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN; |
| 175 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 176 | mlirFloat6E3M2FNTypeGetTypeID; |
| 177 | static constexpr const char *pyClassName = "Float6E3M2FNType" ; |
| 178 | using PyConcreteType::PyConcreteType; |
| 179 | |
| 180 | static void bindDerived(ClassTy &c) { |
| 181 | c.def_static( |
| 182 | "get" , |
| 183 | [](DefaultingPyMlirContext context) { |
| 184 | MlirType t = mlirFloat6E3M2FNTypeGet(context->get()); |
| 185 | return PyFloat6E3M2FNType(context->getRef(), t); |
| 186 | }, |
| 187 | nb::arg("context" ).none() = nb::none(), "Create a float6_e3m2fn type." ); |
| 188 | } |
| 189 | }; |
| 190 | |
| 191 | /// Floating Point Type subclass - Float8E4M3FNType. |
| 192 | class PyFloat8E4M3FNType |
| 193 | : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> { |
| 194 | public: |
| 195 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; |
| 196 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 197 | mlirFloat8E4M3FNTypeGetTypeID; |
| 198 | static constexpr const char *pyClassName = "Float8E4M3FNType" ; |
| 199 | using PyConcreteType::PyConcreteType; |
| 200 | |
| 201 | static void bindDerived(ClassTy &c) { |
| 202 | c.def_static( |
| 203 | "get" , |
| 204 | [](DefaultingPyMlirContext context) { |
| 205 | MlirType t = mlirFloat8E4M3FNTypeGet(context->get()); |
| 206 | return PyFloat8E4M3FNType(context->getRef(), t); |
| 207 | }, |
| 208 | nb::arg("context" ).none() = nb::none(), "Create a float8_e4m3fn type." ); |
| 209 | } |
| 210 | }; |
| 211 | |
| 212 | /// Floating Point Type subclass - Float8E5M2Type. |
| 213 | class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> { |
| 214 | public: |
| 215 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; |
| 216 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 217 | mlirFloat8E5M2TypeGetTypeID; |
| 218 | static constexpr const char *pyClassName = "Float8E5M2Type" ; |
| 219 | using PyConcreteType::PyConcreteType; |
| 220 | |
| 221 | static void bindDerived(ClassTy &c) { |
| 222 | c.def_static( |
| 223 | "get" , |
| 224 | [](DefaultingPyMlirContext context) { |
| 225 | MlirType t = mlirFloat8E5M2TypeGet(context->get()); |
| 226 | return PyFloat8E5M2Type(context->getRef(), t); |
| 227 | }, |
| 228 | nb::arg("context" ).none() = nb::none(), "Create a float8_e5m2 type." ); |
| 229 | } |
| 230 | }; |
| 231 | |
| 232 | /// Floating Point Type subclass - Float8E4M3Type. |
| 233 | class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> { |
| 234 | public: |
| 235 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3; |
| 236 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 237 | mlirFloat8E4M3TypeGetTypeID; |
| 238 | static constexpr const char *pyClassName = "Float8E4M3Type" ; |
| 239 | using PyConcreteType::PyConcreteType; |
| 240 | |
| 241 | static void bindDerived(ClassTy &c) { |
| 242 | c.def_static( |
| 243 | "get" , |
| 244 | [](DefaultingPyMlirContext context) { |
| 245 | MlirType t = mlirFloat8E4M3TypeGet(context->get()); |
| 246 | return PyFloat8E4M3Type(context->getRef(), t); |
| 247 | }, |
| 248 | nb::arg("context" ).none() = nb::none(), "Create a float8_e4m3 type." ); |
| 249 | } |
| 250 | }; |
| 251 | |
| 252 | /// Floating Point Type subclass - Float8E4M3FNUZ. |
| 253 | class PyFloat8E4M3FNUZType |
| 254 | : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> { |
| 255 | public: |
| 256 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; |
| 257 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 258 | mlirFloat8E4M3FNUZTypeGetTypeID; |
| 259 | static constexpr const char *pyClassName = "Float8E4M3FNUZType" ; |
| 260 | using PyConcreteType::PyConcreteType; |
| 261 | |
| 262 | static void bindDerived(ClassTy &c) { |
| 263 | c.def_static( |
| 264 | "get" , |
| 265 | [](DefaultingPyMlirContext context) { |
| 266 | MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); |
| 267 | return PyFloat8E4M3FNUZType(context->getRef(), t); |
| 268 | }, |
| 269 | nb::arg("context" ).none() = nb::none(), |
| 270 | "Create a float8_e4m3fnuz type." ); |
| 271 | } |
| 272 | }; |
| 273 | |
| 274 | /// Floating Point Type subclass - Float8E4M3B11FNUZ. |
| 275 | class PyFloat8E4M3B11FNUZType |
| 276 | : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> { |
| 277 | public: |
| 278 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; |
| 279 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 280 | mlirFloat8E4M3B11FNUZTypeGetTypeID; |
| 281 | static constexpr const char *pyClassName = "Float8E4M3B11FNUZType" ; |
| 282 | using PyConcreteType::PyConcreteType; |
| 283 | |
| 284 | static void bindDerived(ClassTy &c) { |
| 285 | c.def_static( |
| 286 | "get" , |
| 287 | [](DefaultingPyMlirContext context) { |
| 288 | MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); |
| 289 | return PyFloat8E4M3B11FNUZType(context->getRef(), t); |
| 290 | }, |
| 291 | nb::arg("context" ).none() = nb::none(), |
| 292 | "Create a float8_e4m3b11fnuz type." ); |
| 293 | } |
| 294 | }; |
| 295 | |
| 296 | /// Floating Point Type subclass - Float8E5M2FNUZ. |
| 297 | class PyFloat8E5M2FNUZType |
| 298 | : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> { |
| 299 | public: |
| 300 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; |
| 301 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 302 | mlirFloat8E5M2FNUZTypeGetTypeID; |
| 303 | static constexpr const char *pyClassName = "Float8E5M2FNUZType" ; |
| 304 | using PyConcreteType::PyConcreteType; |
| 305 | |
| 306 | static void bindDerived(ClassTy &c) { |
| 307 | c.def_static( |
| 308 | "get" , |
| 309 | [](DefaultingPyMlirContext context) { |
| 310 | MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); |
| 311 | return PyFloat8E5M2FNUZType(context->getRef(), t); |
| 312 | }, |
| 313 | nb::arg("context" ).none() = nb::none(), |
| 314 | "Create a float8_e5m2fnuz type." ); |
| 315 | } |
| 316 | }; |
| 317 | |
| 318 | /// Floating Point Type subclass - Float8E3M4Type. |
| 319 | class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> { |
| 320 | public: |
| 321 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4; |
| 322 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 323 | mlirFloat8E3M4TypeGetTypeID; |
| 324 | static constexpr const char *pyClassName = "Float8E3M4Type" ; |
| 325 | using PyConcreteType::PyConcreteType; |
| 326 | |
| 327 | static void bindDerived(ClassTy &c) { |
| 328 | c.def_static( |
| 329 | "get" , |
| 330 | [](DefaultingPyMlirContext context) { |
| 331 | MlirType t = mlirFloat8E3M4TypeGet(context->get()); |
| 332 | return PyFloat8E3M4Type(context->getRef(), t); |
| 333 | }, |
| 334 | nb::arg("context" ).none() = nb::none(), "Create a float8_e3m4 type." ); |
| 335 | } |
| 336 | }; |
| 337 | |
| 338 | /// Floating Point Type subclass - Float8E8M0FNUType. |
| 339 | class PyFloat8E8M0FNUType |
| 340 | : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> { |
| 341 | public: |
| 342 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU; |
| 343 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 344 | mlirFloat8E8M0FNUTypeGetTypeID; |
| 345 | static constexpr const char *pyClassName = "Float8E8M0FNUType" ; |
| 346 | using PyConcreteType::PyConcreteType; |
| 347 | |
| 348 | static void bindDerived(ClassTy &c) { |
| 349 | c.def_static( |
| 350 | "get" , |
| 351 | [](DefaultingPyMlirContext context) { |
| 352 | MlirType t = mlirFloat8E8M0FNUTypeGet(context->get()); |
| 353 | return PyFloat8E8M0FNUType(context->getRef(), t); |
| 354 | }, |
| 355 | nb::arg("context" ).none() = nb::none(), |
| 356 | "Create a float8_e8m0fnu type." ); |
| 357 | } |
| 358 | }; |
| 359 | |
| 360 | /// Floating Point Type subclass - BF16Type. |
| 361 | class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> { |
| 362 | public: |
| 363 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; |
| 364 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 365 | mlirBFloat16TypeGetTypeID; |
| 366 | static constexpr const char *pyClassName = "BF16Type" ; |
| 367 | using PyConcreteType::PyConcreteType; |
| 368 | |
| 369 | static void bindDerived(ClassTy &c) { |
| 370 | c.def_static( |
| 371 | "get" , |
| 372 | [](DefaultingPyMlirContext context) { |
| 373 | MlirType t = mlirBF16TypeGet(context->get()); |
| 374 | return PyBF16Type(context->getRef(), t); |
| 375 | }, |
| 376 | nb::arg("context" ).none() = nb::none(), "Create a bf16 type." ); |
| 377 | } |
| 378 | }; |
| 379 | |
| 380 | /// Floating Point Type subclass - F16Type. |
| 381 | class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> { |
| 382 | public: |
| 383 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; |
| 384 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 385 | mlirFloat16TypeGetTypeID; |
| 386 | static constexpr const char *pyClassName = "F16Type" ; |
| 387 | using PyConcreteType::PyConcreteType; |
| 388 | |
| 389 | static void bindDerived(ClassTy &c) { |
| 390 | c.def_static( |
| 391 | "get" , |
| 392 | [](DefaultingPyMlirContext context) { |
| 393 | MlirType t = mlirF16TypeGet(context->get()); |
| 394 | return PyF16Type(context->getRef(), t); |
| 395 | }, |
| 396 | nb::arg("context" ).none() = nb::none(), "Create a f16 type." ); |
| 397 | } |
| 398 | }; |
| 399 | |
| 400 | /// Floating Point Type subclass - TF32Type. |
| 401 | class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> { |
| 402 | public: |
| 403 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32; |
| 404 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 405 | mlirFloatTF32TypeGetTypeID; |
| 406 | static constexpr const char *pyClassName = "FloatTF32Type" ; |
| 407 | using PyConcreteType::PyConcreteType; |
| 408 | |
| 409 | static void bindDerived(ClassTy &c) { |
| 410 | c.def_static( |
| 411 | "get" , |
| 412 | [](DefaultingPyMlirContext context) { |
| 413 | MlirType t = mlirTF32TypeGet(context->get()); |
| 414 | return PyTF32Type(context->getRef(), t); |
| 415 | }, |
| 416 | nb::arg("context" ).none() = nb::none(), "Create a tf32 type." ); |
| 417 | } |
| 418 | }; |
| 419 | |
| 420 | /// Floating Point Type subclass - F32Type. |
| 421 | class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> { |
| 422 | public: |
| 423 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; |
| 424 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 425 | mlirFloat32TypeGetTypeID; |
| 426 | static constexpr const char *pyClassName = "F32Type" ; |
| 427 | using PyConcreteType::PyConcreteType; |
| 428 | |
| 429 | static void bindDerived(ClassTy &c) { |
| 430 | c.def_static( |
| 431 | "get" , |
| 432 | [](DefaultingPyMlirContext context) { |
| 433 | MlirType t = mlirF32TypeGet(context->get()); |
| 434 | return PyF32Type(context->getRef(), t); |
| 435 | }, |
| 436 | nb::arg("context" ).none() = nb::none(), "Create a f32 type." ); |
| 437 | } |
| 438 | }; |
| 439 | |
| 440 | /// Floating Point Type subclass - F64Type. |
| 441 | class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> { |
| 442 | public: |
| 443 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; |
| 444 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 445 | mlirFloat64TypeGetTypeID; |
| 446 | static constexpr const char *pyClassName = "F64Type" ; |
| 447 | using PyConcreteType::PyConcreteType; |
| 448 | |
| 449 | static void bindDerived(ClassTy &c) { |
| 450 | c.def_static( |
| 451 | "get" , |
| 452 | [](DefaultingPyMlirContext context) { |
| 453 | MlirType t = mlirF64TypeGet(context->get()); |
| 454 | return PyF64Type(context->getRef(), t); |
| 455 | }, |
| 456 | nb::arg("context" ).none() = nb::none(), "Create a f64 type." ); |
| 457 | } |
| 458 | }; |
| 459 | |
| 460 | /// None Type subclass - NoneType. |
| 461 | class PyNoneType : public PyConcreteType<PyNoneType> { |
| 462 | public: |
| 463 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; |
| 464 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 465 | mlirNoneTypeGetTypeID; |
| 466 | static constexpr const char *pyClassName = "NoneType" ; |
| 467 | using PyConcreteType::PyConcreteType; |
| 468 | |
| 469 | static void bindDerived(ClassTy &c) { |
| 470 | c.def_static( |
| 471 | "get" , |
| 472 | [](DefaultingPyMlirContext context) { |
| 473 | MlirType t = mlirNoneTypeGet(context->get()); |
| 474 | return PyNoneType(context->getRef(), t); |
| 475 | }, |
| 476 | nb::arg("context" ).none() = nb::none(), "Create a none type." ); |
| 477 | } |
| 478 | }; |
| 479 | |
| 480 | /// Complex Type subclass - ComplexType. |
| 481 | class PyComplexType : public PyConcreteType<PyComplexType> { |
| 482 | public: |
| 483 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; |
| 484 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 485 | mlirComplexTypeGetTypeID; |
| 486 | static constexpr const char *pyClassName = "ComplexType" ; |
| 487 | using PyConcreteType::PyConcreteType; |
| 488 | |
| 489 | static void bindDerived(ClassTy &c) { |
| 490 | c.def_static( |
| 491 | "get" , |
| 492 | [](PyType &elementType) { |
| 493 | // The element must be a floating point or integer scalar type. |
| 494 | if (mlirTypeIsAIntegerOrFloat(elementType)) { |
| 495 | MlirType t = mlirComplexTypeGet(elementType); |
| 496 | return PyComplexType(elementType.getContext(), t); |
| 497 | } |
| 498 | throw nb::value_error( |
| 499 | (Twine("invalid '" ) + |
| 500 | nb::cast<std::string>(nb::repr(nb::cast(elementType))) + |
| 501 | "' and expected floating point or integer type." ) |
| 502 | .str() |
| 503 | .c_str()); |
| 504 | }, |
| 505 | "Create a complex type" ); |
| 506 | c.def_prop_ro( |
| 507 | "element_type" , |
| 508 | [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); }, |
| 509 | "Returns element type." ); |
| 510 | } |
| 511 | }; |
| 512 | |
| 513 | } // namespace |
| 514 | |
| 515 | // Shaped Type Interface - ShapedType |
| 516 | void mlir::PyShapedType::bindDerived(ClassTy &c) { |
| 517 | c.def_prop_ro( |
| 518 | "element_type" , |
| 519 | [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); }, |
| 520 | "Returns the element type of the shaped type." ); |
| 521 | c.def_prop_ro( |
| 522 | "has_rank" , |
| 523 | [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, |
| 524 | "Returns whether the given shaped type is ranked." ); |
| 525 | c.def_prop_ro( |
| 526 | "rank" , |
| 527 | [](PyShapedType &self) { |
| 528 | self.requireHasRank(); |
| 529 | return mlirShapedTypeGetRank(self); |
| 530 | }, |
| 531 | "Returns the rank of the given ranked shaped type." ); |
| 532 | c.def_prop_ro( |
| 533 | "has_static_shape" , |
| 534 | [](PyShapedType &self) -> bool { |
| 535 | return mlirShapedTypeHasStaticShape(self); |
| 536 | }, |
| 537 | "Returns whether the given shaped type has a static shape." ); |
| 538 | c.def( |
| 539 | "is_dynamic_dim" , |
| 540 | [](PyShapedType &self, intptr_t dim) -> bool { |
| 541 | self.requireHasRank(); |
| 542 | return mlirShapedTypeIsDynamicDim(self, dim); |
| 543 | }, |
| 544 | nb::arg("dim" ), |
| 545 | "Returns whether the dim-th dimension of the given shaped type is " |
| 546 | "dynamic." ); |
| 547 | c.def( |
| 548 | "is_static_dim" , |
| 549 | [](PyShapedType &self, intptr_t dim) -> bool { |
| 550 | self.requireHasRank(); |
| 551 | return mlirShapedTypeIsStaticDim(self, dim); |
| 552 | }, |
| 553 | nb::arg("dim" ), |
| 554 | "Returns whether the dim-th dimension of the given shaped type is " |
| 555 | "static." ); |
| 556 | c.def( |
| 557 | "get_dim_size" , |
| 558 | [](PyShapedType &self, intptr_t dim) { |
| 559 | self.requireHasRank(); |
| 560 | return mlirShapedTypeGetDimSize(self, dim); |
| 561 | }, |
| 562 | nb::arg("dim" ), |
| 563 | "Returns the dim-th dimension of the given ranked shaped type." ); |
| 564 | c.def_static( |
| 565 | "is_dynamic_size" , |
| 566 | [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, |
| 567 | nb::arg("dim_size" ), |
| 568 | "Returns whether the given dimension size indicates a dynamic " |
| 569 | "dimension." ); |
| 570 | c.def_static( |
| 571 | "is_static_size" , |
| 572 | [](int64_t size) -> bool { return mlirShapedTypeIsStaticSize(size); }, |
| 573 | nb::arg("dim_size" ), |
| 574 | "Returns whether the given dimension size indicates a static " |
| 575 | "dimension." ); |
| 576 | c.def( |
| 577 | "is_dynamic_stride_or_offset" , |
| 578 | [](PyShapedType &self, int64_t val) -> bool { |
| 579 | self.requireHasRank(); |
| 580 | return mlirShapedTypeIsDynamicStrideOrOffset(val); |
| 581 | }, |
| 582 | nb::arg("dim_size" ), |
| 583 | "Returns whether the given value is used as a placeholder for dynamic " |
| 584 | "strides and offsets in shaped types." ); |
| 585 | c.def( |
| 586 | "is_static_stride_or_offset" , |
| 587 | [](PyShapedType &self, int64_t val) -> bool { |
| 588 | self.requireHasRank(); |
| 589 | return mlirShapedTypeIsStaticStrideOrOffset(val); |
| 590 | }, |
| 591 | nb::arg("dim_size" ), |
| 592 | "Returns whether the given shaped type stride or offset value is " |
| 593 | "statically-sized." ); |
| 594 | c.def_prop_ro( |
| 595 | "shape" , |
| 596 | [](PyShapedType &self) { |
| 597 | self.requireHasRank(); |
| 598 | |
| 599 | std::vector<int64_t> shape; |
| 600 | int64_t rank = mlirShapedTypeGetRank(self); |
| 601 | shape.reserve(rank); |
| 602 | for (int64_t i = 0; i < rank; ++i) |
| 603 | shape.push_back(mlirShapedTypeGetDimSize(self, i)); |
| 604 | return shape; |
| 605 | }, |
| 606 | "Returns the shape of the ranked shaped type as a list of integers." ); |
| 607 | c.def_static( |
| 608 | "get_dynamic_size" , []() { return mlirShapedTypeGetDynamicSize(); }, |
| 609 | "Returns the value used to indicate dynamic dimensions in shaped " |
| 610 | "types." ); |
| 611 | c.def_static( |
| 612 | "get_dynamic_stride_or_offset" , |
| 613 | []() { return mlirShapedTypeGetDynamicStrideOrOffset(); }, |
| 614 | "Returns the value used to indicate dynamic strides or offsets in " |
| 615 | "shaped types." ); |
| 616 | } |
| 617 | |
| 618 | void mlir::PyShapedType::requireHasRank() { |
| 619 | if (!mlirShapedTypeHasRank(*this)) { |
| 620 | throw nb::value_error( |
| 621 | "calling this method requires that the type has a rank." ); |
| 622 | } |
| 623 | } |
| 624 | |
| 625 | const mlir::PyShapedType::IsAFunctionTy mlir::PyShapedType::isaFunction = |
| 626 | mlirTypeIsAShaped; |
| 627 | |
| 628 | namespace { |
| 629 | |
| 630 | /// Vector Type subclass - VectorType. |
| 631 | class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> { |
| 632 | public: |
| 633 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; |
| 634 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 635 | mlirVectorTypeGetTypeID; |
| 636 | static constexpr const char *pyClassName = "VectorType" ; |
| 637 | using PyConcreteType::PyConcreteType; |
| 638 | |
| 639 | static void bindDerived(ClassTy &c) { |
| 640 | c.def_static("get" , &PyVectorType::get, nb::arg("shape" ), |
| 641 | nb::arg("element_type" ), nb::kw_only(), |
| 642 | nb::arg("scalable" ).none() = nb::none(), |
| 643 | nb::arg("scalable_dims" ).none() = nb::none(), |
| 644 | nb::arg("loc" ).none() = nb::none(), "Create a vector type" ) |
| 645 | .def_prop_ro( |
| 646 | "scalable" , |
| 647 | [](MlirType self) { return mlirVectorTypeIsScalable(self); }) |
| 648 | .def_prop_ro("scalable_dims" , [](MlirType self) { |
| 649 | std::vector<bool> scalableDims; |
| 650 | size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self)); |
| 651 | scalableDims.reserve(rank); |
| 652 | for (size_t i = 0; i < rank; ++i) |
| 653 | scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i)); |
| 654 | return scalableDims; |
| 655 | }); |
| 656 | } |
| 657 | |
| 658 | private: |
| 659 | static PyVectorType get(std::vector<int64_t> shape, PyType &elementType, |
| 660 | std::optional<nb::list> scalable, |
| 661 | std::optional<std::vector<int64_t>> scalableDims, |
| 662 | DefaultingPyLocation loc) { |
| 663 | if (scalable && scalableDims) { |
| 664 | throw nb::value_error("'scalable' and 'scalable_dims' kwargs " |
| 665 | "are mutually exclusive." ); |
| 666 | } |
| 667 | |
| 668 | PyMlirContext::ErrorCapture errors(loc->getContext()); |
| 669 | MlirType type; |
| 670 | if (scalable) { |
| 671 | if (scalable->size() != shape.size()) |
| 672 | throw nb::value_error("Expected len(scalable) == len(shape)." ); |
| 673 | |
| 674 | SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range( |
| 675 | *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); })); |
| 676 | type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), |
| 677 | scalableDimFlags.data(), |
| 678 | elementType); |
| 679 | } else if (scalableDims) { |
| 680 | SmallVector<bool> scalableDimFlags(shape.size(), false); |
| 681 | for (int64_t dim : *scalableDims) { |
| 682 | if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0) |
| 683 | throw nb::value_error("Scalable dimension index out of bounds." ); |
| 684 | scalableDimFlags[dim] = true; |
| 685 | } |
| 686 | type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), |
| 687 | scalableDimFlags.data(), |
| 688 | elementType); |
| 689 | } else { |
| 690 | type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), |
| 691 | elementType); |
| 692 | } |
| 693 | if (mlirTypeIsNull(type)) |
| 694 | throw MLIRError("Invalid type" , errors.take()); |
| 695 | return PyVectorType(elementType.getContext(), type); |
| 696 | } |
| 697 | }; |
| 698 | |
| 699 | /// Ranked Tensor Type subclass - RankedTensorType. |
| 700 | class PyRankedTensorType |
| 701 | : public PyConcreteType<PyRankedTensorType, PyShapedType> { |
| 702 | public: |
| 703 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; |
| 704 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 705 | mlirRankedTensorTypeGetTypeID; |
| 706 | static constexpr const char *pyClassName = "RankedTensorType" ; |
| 707 | using PyConcreteType::PyConcreteType; |
| 708 | |
| 709 | static void bindDerived(ClassTy &c) { |
| 710 | c.def_static( |
| 711 | "get" , |
| 712 | [](std::vector<int64_t> shape, PyType &elementType, |
| 713 | std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) { |
| 714 | PyMlirContext::ErrorCapture errors(loc->getContext()); |
| 715 | MlirType t = mlirRankedTensorTypeGetChecked( |
| 716 | loc, shape.size(), shape.data(), elementType, |
| 717 | encodingAttr ? encodingAttr->get() : mlirAttributeGetNull()); |
| 718 | if (mlirTypeIsNull(t)) |
| 719 | throw MLIRError("Invalid type" , errors.take()); |
| 720 | return PyRankedTensorType(elementType.getContext(), t); |
| 721 | }, |
| 722 | nb::arg("shape" ), nb::arg("element_type" ), |
| 723 | nb::arg("encoding" ).none() = nb::none(), |
| 724 | nb::arg("loc" ).none() = nb::none(), "Create a ranked tensor type" ); |
| 725 | c.def_prop_ro("encoding" , |
| 726 | [](PyRankedTensorType &self) -> std::optional<MlirAttribute> { |
| 727 | MlirAttribute encoding = |
| 728 | mlirRankedTensorTypeGetEncoding(self.get()); |
| 729 | if (mlirAttributeIsNull(encoding)) |
| 730 | return std::nullopt; |
| 731 | return encoding; |
| 732 | }); |
| 733 | } |
| 734 | }; |
| 735 | |
| 736 | /// Unranked Tensor Type subclass - UnrankedTensorType. |
| 737 | class PyUnrankedTensorType |
| 738 | : public PyConcreteType<PyUnrankedTensorType, PyShapedType> { |
| 739 | public: |
| 740 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; |
| 741 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 742 | mlirUnrankedTensorTypeGetTypeID; |
| 743 | static constexpr const char *pyClassName = "UnrankedTensorType" ; |
| 744 | using PyConcreteType::PyConcreteType; |
| 745 | |
| 746 | static void bindDerived(ClassTy &c) { |
| 747 | c.def_static( |
| 748 | "get" , |
| 749 | [](PyType &elementType, DefaultingPyLocation loc) { |
| 750 | PyMlirContext::ErrorCapture errors(loc->getContext()); |
| 751 | MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); |
| 752 | if (mlirTypeIsNull(t)) |
| 753 | throw MLIRError("Invalid type" , errors.take()); |
| 754 | return PyUnrankedTensorType(elementType.getContext(), t); |
| 755 | }, |
| 756 | nb::arg("element_type" ), nb::arg("loc" ).none() = nb::none(), |
| 757 | "Create a unranked tensor type" ); |
| 758 | } |
| 759 | }; |
| 760 | |
| 761 | /// Ranked MemRef Type subclass - MemRefType. |
| 762 | class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> { |
| 763 | public: |
| 764 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; |
| 765 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 766 | mlirMemRefTypeGetTypeID; |
| 767 | static constexpr const char *pyClassName = "MemRefType" ; |
| 768 | using PyConcreteType::PyConcreteType; |
| 769 | |
| 770 | static void bindDerived(ClassTy &c) { |
| 771 | c.def_static( |
| 772 | "get" , |
| 773 | [](std::vector<int64_t> shape, PyType &elementType, |
| 774 | PyAttribute *layout, PyAttribute *memorySpace, |
| 775 | DefaultingPyLocation loc) { |
| 776 | PyMlirContext::ErrorCapture errors(loc->getContext()); |
| 777 | MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull(); |
| 778 | MlirAttribute memSpaceAttr = |
| 779 | memorySpace ? *memorySpace : mlirAttributeGetNull(); |
| 780 | MlirType t = |
| 781 | mlirMemRefTypeGetChecked(loc, elementType, shape.size(), |
| 782 | shape.data(), layoutAttr, memSpaceAttr); |
| 783 | if (mlirTypeIsNull(t)) |
| 784 | throw MLIRError("Invalid type" , errors.take()); |
| 785 | return PyMemRefType(elementType.getContext(), t); |
| 786 | }, |
| 787 | nb::arg("shape" ), nb::arg("element_type" ), |
| 788 | nb::arg("layout" ).none() = nb::none(), |
| 789 | nb::arg("memory_space" ).none() = nb::none(), |
| 790 | nb::arg("loc" ).none() = nb::none(), "Create a memref type" ) |
| 791 | .def_prop_ro( |
| 792 | "layout" , |
| 793 | [](PyMemRefType &self) -> MlirAttribute { |
| 794 | return mlirMemRefTypeGetLayout(self); |
| 795 | }, |
| 796 | "The layout of the MemRef type." ) |
| 797 | .def( |
| 798 | "get_strides_and_offset" , |
| 799 | [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> { |
| 800 | std::vector<int64_t> strides(mlirShapedTypeGetRank(self)); |
| 801 | int64_t offset; |
| 802 | if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset( |
| 803 | self, strides.data(), &offset))) |
| 804 | throw std::runtime_error( |
| 805 | "Failed to extract strides and offset from memref." ); |
| 806 | return {strides, offset}; |
| 807 | }, |
| 808 | "The strides and offset of the MemRef type." ) |
| 809 | .def_prop_ro( |
| 810 | "affine_map" , |
| 811 | [](PyMemRefType &self) -> PyAffineMap { |
| 812 | MlirAffineMap map = mlirMemRefTypeGetAffineMap(self); |
| 813 | return PyAffineMap(self.getContext(), map); |
| 814 | }, |
| 815 | "The layout of the MemRef type as an affine map." ) |
| 816 | .def_prop_ro( |
| 817 | "memory_space" , |
| 818 | [](PyMemRefType &self) -> std::optional<MlirAttribute> { |
| 819 | MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); |
| 820 | if (mlirAttributeIsNull(a)) |
| 821 | return std::nullopt; |
| 822 | return a; |
| 823 | }, |
| 824 | "Returns the memory space of the given MemRef type." ); |
| 825 | } |
| 826 | }; |
| 827 | |
| 828 | /// Unranked MemRef Type subclass - UnrankedMemRefType. |
| 829 | class PyUnrankedMemRefType |
| 830 | : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> { |
| 831 | public: |
| 832 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; |
| 833 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 834 | mlirUnrankedMemRefTypeGetTypeID; |
| 835 | static constexpr const char *pyClassName = "UnrankedMemRefType" ; |
| 836 | using PyConcreteType::PyConcreteType; |
| 837 | |
| 838 | static void bindDerived(ClassTy &c) { |
| 839 | c.def_static( |
| 840 | "get" , |
| 841 | [](PyType &elementType, PyAttribute *memorySpace, |
| 842 | DefaultingPyLocation loc) { |
| 843 | PyMlirContext::ErrorCapture errors(loc->getContext()); |
| 844 | MlirAttribute memSpaceAttr = {}; |
| 845 | if (memorySpace) |
| 846 | memSpaceAttr = *memorySpace; |
| 847 | |
| 848 | MlirType t = |
| 849 | mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); |
| 850 | if (mlirTypeIsNull(t)) |
| 851 | throw MLIRError("Invalid type" , errors.take()); |
| 852 | return PyUnrankedMemRefType(elementType.getContext(), t); |
| 853 | }, |
| 854 | nb::arg("element_type" ), nb::arg("memory_space" ).none(), |
| 855 | nb::arg("loc" ).none() = nb::none(), "Create a unranked memref type" ) |
| 856 | .def_prop_ro( |
| 857 | "memory_space" , |
| 858 | [](PyUnrankedMemRefType &self) -> std::optional<MlirAttribute> { |
| 859 | MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self); |
| 860 | if (mlirAttributeIsNull(a)) |
| 861 | return std::nullopt; |
| 862 | return a; |
| 863 | }, |
| 864 | "Returns the memory space of the given Unranked MemRef type." ); |
| 865 | } |
| 866 | }; |
| 867 | |
| 868 | /// Tuple Type subclass - TupleType. |
| 869 | class PyTupleType : public PyConcreteType<PyTupleType> { |
| 870 | public: |
| 871 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; |
| 872 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 873 | mlirTupleTypeGetTypeID; |
| 874 | static constexpr const char *pyClassName = "TupleType" ; |
| 875 | using PyConcreteType::PyConcreteType; |
| 876 | |
| 877 | static void bindDerived(ClassTy &c) { |
| 878 | c.def_static( |
| 879 | "get_tuple" , |
| 880 | [](std::vector<MlirType> elements, DefaultingPyMlirContext context) { |
| 881 | MlirType t = mlirTupleTypeGet(context->get(), elements.size(), |
| 882 | elements.data()); |
| 883 | return PyTupleType(context->getRef(), t); |
| 884 | }, |
| 885 | nb::arg("elements" ), nb::arg("context" ).none() = nb::none(), |
| 886 | "Create a tuple type" ); |
| 887 | c.def( |
| 888 | "get_type" , |
| 889 | [](PyTupleType &self, intptr_t pos) { |
| 890 | return mlirTupleTypeGetType(self, pos); |
| 891 | }, |
| 892 | nb::arg("pos" ), "Returns the pos-th type in the tuple type." ); |
| 893 | c.def_prop_ro( |
| 894 | "num_types" , |
| 895 | [](PyTupleType &self) -> intptr_t { |
| 896 | return mlirTupleTypeGetNumTypes(self); |
| 897 | }, |
| 898 | "Returns the number of types contained in a tuple." ); |
| 899 | } |
| 900 | }; |
| 901 | |
| 902 | /// Function type. |
| 903 | class PyFunctionType : public PyConcreteType<PyFunctionType> { |
| 904 | public: |
| 905 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; |
| 906 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 907 | mlirFunctionTypeGetTypeID; |
| 908 | static constexpr const char *pyClassName = "FunctionType" ; |
| 909 | using PyConcreteType::PyConcreteType; |
| 910 | |
| 911 | static void bindDerived(ClassTy &c) { |
| 912 | c.def_static( |
| 913 | "get" , |
| 914 | [](std::vector<MlirType> inputs, std::vector<MlirType> results, |
| 915 | DefaultingPyMlirContext context) { |
| 916 | MlirType t = |
| 917 | mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(), |
| 918 | results.size(), results.data()); |
| 919 | return PyFunctionType(context->getRef(), t); |
| 920 | }, |
| 921 | nb::arg("inputs" ), nb::arg("results" ), |
| 922 | nb::arg("context" ).none() = nb::none(), |
| 923 | "Gets a FunctionType from a list of input and result types" ); |
| 924 | c.def_prop_ro( |
| 925 | "inputs" , |
| 926 | [](PyFunctionType &self) { |
| 927 | MlirType t = self; |
| 928 | nb::list types; |
| 929 | for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; |
| 930 | ++i) { |
| 931 | types.append(mlirFunctionTypeGetInput(t, i)); |
| 932 | } |
| 933 | return types; |
| 934 | }, |
| 935 | "Returns the list of input types in the FunctionType." ); |
| 936 | c.def_prop_ro( |
| 937 | "results" , |
| 938 | [](PyFunctionType &self) { |
| 939 | nb::list types; |
| 940 | for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; |
| 941 | ++i) { |
| 942 | types.append(mlirFunctionTypeGetResult(self, i)); |
| 943 | } |
| 944 | return types; |
| 945 | }, |
| 946 | "Returns the list of result types in the FunctionType." ); |
| 947 | } |
| 948 | }; |
| 949 | |
| 950 | static MlirStringRef toMlirStringRef(const std::string &s) { |
| 951 | return mlirStringRefCreate(s.data(), s.size()); |
| 952 | } |
| 953 | |
| 954 | /// Opaque Type subclass - OpaqueType. |
| 955 | class PyOpaqueType : public PyConcreteType<PyOpaqueType> { |
| 956 | public: |
| 957 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque; |
| 958 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 959 | mlirOpaqueTypeGetTypeID; |
| 960 | static constexpr const char *pyClassName = "OpaqueType" ; |
| 961 | using PyConcreteType::PyConcreteType; |
| 962 | |
| 963 | static void bindDerived(ClassTy &c) { |
| 964 | c.def_static( |
| 965 | "get" , |
| 966 | [](std::string dialectNamespace, std::string typeData, |
| 967 | DefaultingPyMlirContext context) { |
| 968 | MlirType type = mlirOpaqueTypeGet(context->get(), |
| 969 | toMlirStringRef(dialectNamespace), |
| 970 | toMlirStringRef(typeData)); |
| 971 | return PyOpaqueType(context->getRef(), type); |
| 972 | }, |
| 973 | nb::arg("dialect_namespace" ), nb::arg("buffer" ), |
| 974 | nb::arg("context" ).none() = nb::none(), |
| 975 | "Create an unregistered (opaque) dialect type." ); |
| 976 | c.def_prop_ro( |
| 977 | "dialect_namespace" , |
| 978 | [](PyOpaqueType &self) { |
| 979 | MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self); |
| 980 | return nb::str(stringRef.data, stringRef.length); |
| 981 | }, |
| 982 | "Returns the dialect namespace for the Opaque type as a string." ); |
| 983 | c.def_prop_ro( |
| 984 | "data" , |
| 985 | [](PyOpaqueType &self) { |
| 986 | MlirStringRef stringRef = mlirOpaqueTypeGetData(self); |
| 987 | return nb::str(stringRef.data, stringRef.length); |
| 988 | }, |
| 989 | "Returns the data for the Opaque type as a string." ); |
| 990 | } |
| 991 | }; |
| 992 | |
| 993 | } // namespace |
| 994 | |
/// Registers the Python classes for the builtin types listed below on module
/// `m`, by calling each binding class's static `bind`.
///
/// NOTE(review): the order of these calls looks significant. PyShapedType is
/// bound before the tensor/memref classes, which use it as their base
/// (PyConcreteType<..., PyShapedType>). Presumably nanobind needs a base class
/// registered before its subclasses, so keep bases ahead of derived classes
/// when adding entries.
void mlir::python::populateIRTypes(nb::module_ &m) {
  PyIntegerType::bind(m);
  PyFloatType::bind(m);
  PyIndexType::bind(m);
  PyFloat4E2M1FNType::bind(m);
  PyFloat6E2M3FNType::bind(m);
  PyFloat6E3M2FNType::bind(m);
  PyFloat8E4M3FNType::bind(m);
  PyFloat8E5M2Type::bind(m);
  PyFloat8E4M3Type::bind(m);
  PyFloat8E4M3FNUZType::bind(m);
  PyFloat8E4M3B11FNUZType::bind(m);
  PyFloat8E5M2FNUZType::bind(m);
  PyFloat8E3M4Type::bind(m);
  PyFloat8E8M0FNUType::bind(m);
  PyBF16Type::bind(m);
  PyF16Type::bind(m);
  PyTF32Type::bind(m);
  PyF32Type::bind(m);
  PyF64Type::bind(m);
  PyNoneType::bind(m);
  PyComplexType::bind(m);
  // Shaped-type base class; the shaped subclasses follow it.
  PyShapedType::bind(m);
  PyVectorType::bind(m);
  PyRankedTensorType::bind(m);
  PyUnrankedTensorType::bind(m);
  PyMemRefType::bind(m);
  PyUnrankedMemRefType::bind(m);
  PyTupleType::bind(m);
  PyFunctionType::bind(m);
  PyOpaqueType::bind(m);
}
| 1027 | |