1//===- IRTypes.cpp - Exports builtin and standard types -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9// clang-format off
10#include "IRModule.h"
11#include "mlir/Bindings/Python/IRTypes.h"
12// clang-format on
13
14#include <optional>
15
16#include "IRModule.h"
17#include "NanobindUtils.h"
18#include "mlir-c/BuiltinAttributes.h"
19#include "mlir-c/BuiltinTypes.h"
20#include "mlir-c/Support.h"
21
22namespace nb = nanobind;
23using namespace mlir;
24using namespace mlir::python;
25
26using llvm::SmallVector;
27using llvm::Twine;
28
29namespace {
30
31/// Checks whether the given type is an integer or float type.
32static int mlirTypeIsAIntegerOrFloat(MlirType type) {
33 return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
34 mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
35}
36
37class PyIntegerType : public PyConcreteType<PyIntegerType> {
38public:
39 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
40 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
41 mlirIntegerTypeGetTypeID;
42 static constexpr const char *pyClassName = "IntegerType";
43 using PyConcreteType::PyConcreteType;
44
45 static void bindDerived(ClassTy &c) {
46 c.def_static(
47 "get_signless",
48 [](unsigned width, DefaultingPyMlirContext context) {
49 MlirType t = mlirIntegerTypeGet(context->get(), width);
50 return PyIntegerType(context->getRef(), t);
51 },
52 nb::arg("width"), nb::arg("context").none() = nb::none(),
53 "Create a signless integer type");
54 c.def_static(
55 "get_signed",
56 [](unsigned width, DefaultingPyMlirContext context) {
57 MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
58 return PyIntegerType(context->getRef(), t);
59 },
60 nb::arg("width"), nb::arg("context").none() = nb::none(),
61 "Create a signed integer type");
62 c.def_static(
63 "get_unsigned",
64 [](unsigned width, DefaultingPyMlirContext context) {
65 MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
66 return PyIntegerType(context->getRef(), t);
67 },
68 nb::arg("width"), nb::arg("context").none() = nb::none(),
69 "Create an unsigned integer type");
70 c.def_prop_ro(
71 "width",
72 [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
73 "Returns the width of the integer type");
74 c.def_prop_ro(
75 "is_signless",
76 [](PyIntegerType &self) -> bool {
77 return mlirIntegerTypeIsSignless(self);
78 },
79 "Returns whether this is a signless integer");
80 c.def_prop_ro(
81 "is_signed",
82 [](PyIntegerType &self) -> bool {
83 return mlirIntegerTypeIsSigned(self);
84 },
85 "Returns whether this is a signed integer");
86 c.def_prop_ro(
87 "is_unsigned",
88 [](PyIntegerType &self) -> bool {
89 return mlirIntegerTypeIsUnsigned(self);
90 },
91 "Returns whether this is an unsigned integer");
92 }
93};
94
95/// Index Type subclass - IndexType.
96class PyIndexType : public PyConcreteType<PyIndexType> {
97public:
98 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
99 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
100 mlirIndexTypeGetTypeID;
101 static constexpr const char *pyClassName = "IndexType";
102 using PyConcreteType::PyConcreteType;
103
104 static void bindDerived(ClassTy &c) {
105 c.def_static(
106 "get",
107 [](DefaultingPyMlirContext context) {
108 MlirType t = mlirIndexTypeGet(context->get());
109 return PyIndexType(context->getRef(), t);
110 },
111 nb::arg("context").none() = nb::none(), "Create a index type.");
112 }
113};
114
115class PyFloatType : public PyConcreteType<PyFloatType> {
116public:
117 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
118 static constexpr const char *pyClassName = "FloatType";
119 using PyConcreteType::PyConcreteType;
120
121 static void bindDerived(ClassTy &c) {
122 c.def_prop_ro(
123 "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
124 "Returns the width of the floating-point type");
125 }
126};
127
128/// Floating Point Type subclass - Float4E2M1FNType.
129class PyFloat4E2M1FNType
130 : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
131public:
132 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
133 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
134 mlirFloat4E2M1FNTypeGetTypeID;
135 static constexpr const char *pyClassName = "Float4E2M1FNType";
136 using PyConcreteType::PyConcreteType;
137
138 static void bindDerived(ClassTy &c) {
139 c.def_static(
140 "get",
141 [](DefaultingPyMlirContext context) {
142 MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
143 return PyFloat4E2M1FNType(context->getRef(), t);
144 },
145 nb::arg("context").none() = nb::none(), "Create a float4_e2m1fn type.");
146 }
147};
148
149/// Floating Point Type subclass - Float6E2M3FNType.
150class PyFloat6E2M3FNType
151 : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
152public:
153 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
154 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
155 mlirFloat6E2M3FNTypeGetTypeID;
156 static constexpr const char *pyClassName = "Float6E2M3FNType";
157 using PyConcreteType::PyConcreteType;
158
159 static void bindDerived(ClassTy &c) {
160 c.def_static(
161 "get",
162 [](DefaultingPyMlirContext context) {
163 MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
164 return PyFloat6E2M3FNType(context->getRef(), t);
165 },
166 nb::arg("context").none() = nb::none(), "Create a float6_e2m3fn type.");
167 }
168};
169
170/// Floating Point Type subclass - Float6E3M2FNType.
171class PyFloat6E3M2FNType
172 : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
173public:
174 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
175 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
176 mlirFloat6E3M2FNTypeGetTypeID;
177 static constexpr const char *pyClassName = "Float6E3M2FNType";
178 using PyConcreteType::PyConcreteType;
179
180 static void bindDerived(ClassTy &c) {
181 c.def_static(
182 "get",
183 [](DefaultingPyMlirContext context) {
184 MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
185 return PyFloat6E3M2FNType(context->getRef(), t);
186 },
187 nb::arg("context").none() = nb::none(), "Create a float6_e3m2fn type.");
188 }
189};
190
191/// Floating Point Type subclass - Float8E4M3FNType.
192class PyFloat8E4M3FNType
193 : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
194public:
195 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
196 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
197 mlirFloat8E4M3FNTypeGetTypeID;
198 static constexpr const char *pyClassName = "Float8E4M3FNType";
199 using PyConcreteType::PyConcreteType;
200
201 static void bindDerived(ClassTy &c) {
202 c.def_static(
203 "get",
204 [](DefaultingPyMlirContext context) {
205 MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
206 return PyFloat8E4M3FNType(context->getRef(), t);
207 },
208 nb::arg("context").none() = nb::none(), "Create a float8_e4m3fn type.");
209 }
210};
211
212/// Floating Point Type subclass - Float8E5M2Type.
213class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
214public:
215 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
216 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
217 mlirFloat8E5M2TypeGetTypeID;
218 static constexpr const char *pyClassName = "Float8E5M2Type";
219 using PyConcreteType::PyConcreteType;
220
221 static void bindDerived(ClassTy &c) {
222 c.def_static(
223 "get",
224 [](DefaultingPyMlirContext context) {
225 MlirType t = mlirFloat8E5M2TypeGet(context->get());
226 return PyFloat8E5M2Type(context->getRef(), t);
227 },
228 nb::arg("context").none() = nb::none(), "Create a float8_e5m2 type.");
229 }
230};
231
232/// Floating Point Type subclass - Float8E4M3Type.
233class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
234public:
235 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
236 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
237 mlirFloat8E4M3TypeGetTypeID;
238 static constexpr const char *pyClassName = "Float8E4M3Type";
239 using PyConcreteType::PyConcreteType;
240
241 static void bindDerived(ClassTy &c) {
242 c.def_static(
243 "get",
244 [](DefaultingPyMlirContext context) {
245 MlirType t = mlirFloat8E4M3TypeGet(context->get());
246 return PyFloat8E4M3Type(context->getRef(), t);
247 },
248 nb::arg("context").none() = nb::none(), "Create a float8_e4m3 type.");
249 }
250};
251
252/// Floating Point Type subclass - Float8E4M3FNUZ.
253class PyFloat8E4M3FNUZType
254 : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
255public:
256 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
257 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
258 mlirFloat8E4M3FNUZTypeGetTypeID;
259 static constexpr const char *pyClassName = "Float8E4M3FNUZType";
260 using PyConcreteType::PyConcreteType;
261
262 static void bindDerived(ClassTy &c) {
263 c.def_static(
264 "get",
265 [](DefaultingPyMlirContext context) {
266 MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
267 return PyFloat8E4M3FNUZType(context->getRef(), t);
268 },
269 nb::arg("context").none() = nb::none(),
270 "Create a float8_e4m3fnuz type.");
271 }
272};
273
274/// Floating Point Type subclass - Float8E4M3B11FNUZ.
275class PyFloat8E4M3B11FNUZType
276 : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
277public:
278 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
279 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
280 mlirFloat8E4M3B11FNUZTypeGetTypeID;
281 static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
282 using PyConcreteType::PyConcreteType;
283
284 static void bindDerived(ClassTy &c) {
285 c.def_static(
286 "get",
287 [](DefaultingPyMlirContext context) {
288 MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
289 return PyFloat8E4M3B11FNUZType(context->getRef(), t);
290 },
291 nb::arg("context").none() = nb::none(),
292 "Create a float8_e4m3b11fnuz type.");
293 }
294};
295
296/// Floating Point Type subclass - Float8E5M2FNUZ.
297class PyFloat8E5M2FNUZType
298 : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
299public:
300 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
301 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
302 mlirFloat8E5M2FNUZTypeGetTypeID;
303 static constexpr const char *pyClassName = "Float8E5M2FNUZType";
304 using PyConcreteType::PyConcreteType;
305
306 static void bindDerived(ClassTy &c) {
307 c.def_static(
308 "get",
309 [](DefaultingPyMlirContext context) {
310 MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
311 return PyFloat8E5M2FNUZType(context->getRef(), t);
312 },
313 nb::arg("context").none() = nb::none(),
314 "Create a float8_e5m2fnuz type.");
315 }
316};
317
318/// Floating Point Type subclass - Float8E3M4Type.
319class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
320public:
321 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
322 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
323 mlirFloat8E3M4TypeGetTypeID;
324 static constexpr const char *pyClassName = "Float8E3M4Type";
325 using PyConcreteType::PyConcreteType;
326
327 static void bindDerived(ClassTy &c) {
328 c.def_static(
329 "get",
330 [](DefaultingPyMlirContext context) {
331 MlirType t = mlirFloat8E3M4TypeGet(context->get());
332 return PyFloat8E3M4Type(context->getRef(), t);
333 },
334 nb::arg("context").none() = nb::none(), "Create a float8_e3m4 type.");
335 }
336};
337
338/// Floating Point Type subclass - Float8E8M0FNUType.
339class PyFloat8E8M0FNUType
340 : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
341public:
342 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
343 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
344 mlirFloat8E8M0FNUTypeGetTypeID;
345 static constexpr const char *pyClassName = "Float8E8M0FNUType";
346 using PyConcreteType::PyConcreteType;
347
348 static void bindDerived(ClassTy &c) {
349 c.def_static(
350 "get",
351 [](DefaultingPyMlirContext context) {
352 MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
353 return PyFloat8E8M0FNUType(context->getRef(), t);
354 },
355 nb::arg("context").none() = nb::none(),
356 "Create a float8_e8m0fnu type.");
357 }
358};
359
360/// Floating Point Type subclass - BF16Type.
361class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
362public:
363 static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
364 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
365 mlirBFloat16TypeGetTypeID;
366 static constexpr const char *pyClassName = "BF16Type";
367 using PyConcreteType::PyConcreteType;
368
369 static void bindDerived(ClassTy &c) {
370 c.def_static(
371 "get",
372 [](DefaultingPyMlirContext context) {
373 MlirType t = mlirBF16TypeGet(context->get());
374 return PyBF16Type(context->getRef(), t);
375 },
376 nb::arg("context").none() = nb::none(), "Create a bf16 type.");
377 }
378};
379
380/// Floating Point Type subclass - F16Type.
381class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
382public:
383 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
384 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
385 mlirFloat16TypeGetTypeID;
386 static constexpr const char *pyClassName = "F16Type";
387 using PyConcreteType::PyConcreteType;
388
389 static void bindDerived(ClassTy &c) {
390 c.def_static(
391 "get",
392 [](DefaultingPyMlirContext context) {
393 MlirType t = mlirF16TypeGet(context->get());
394 return PyF16Type(context->getRef(), t);
395 },
396 nb::arg("context").none() = nb::none(), "Create a f16 type.");
397 }
398};
399
400/// Floating Point Type subclass - TF32Type.
401class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
402public:
403 static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
404 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
405 mlirFloatTF32TypeGetTypeID;
406 static constexpr const char *pyClassName = "FloatTF32Type";
407 using PyConcreteType::PyConcreteType;
408
409 static void bindDerived(ClassTy &c) {
410 c.def_static(
411 "get",
412 [](DefaultingPyMlirContext context) {
413 MlirType t = mlirTF32TypeGet(context->get());
414 return PyTF32Type(context->getRef(), t);
415 },
416 nb::arg("context").none() = nb::none(), "Create a tf32 type.");
417 }
418};
419
420/// Floating Point Type subclass - F32Type.
421class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
422public:
423 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
424 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
425 mlirFloat32TypeGetTypeID;
426 static constexpr const char *pyClassName = "F32Type";
427 using PyConcreteType::PyConcreteType;
428
429 static void bindDerived(ClassTy &c) {
430 c.def_static(
431 "get",
432 [](DefaultingPyMlirContext context) {
433 MlirType t = mlirF32TypeGet(context->get());
434 return PyF32Type(context->getRef(), t);
435 },
436 nb::arg("context").none() = nb::none(), "Create a f32 type.");
437 }
438};
439
440/// Floating Point Type subclass - F64Type.
441class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
442public:
443 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
444 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
445 mlirFloat64TypeGetTypeID;
446 static constexpr const char *pyClassName = "F64Type";
447 using PyConcreteType::PyConcreteType;
448
449 static void bindDerived(ClassTy &c) {
450 c.def_static(
451 "get",
452 [](DefaultingPyMlirContext context) {
453 MlirType t = mlirF64TypeGet(context->get());
454 return PyF64Type(context->getRef(), t);
455 },
456 nb::arg("context").none() = nb::none(), "Create a f64 type.");
457 }
458};
459
460/// None Type subclass - NoneType.
461class PyNoneType : public PyConcreteType<PyNoneType> {
462public:
463 static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
464 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
465 mlirNoneTypeGetTypeID;
466 static constexpr const char *pyClassName = "NoneType";
467 using PyConcreteType::PyConcreteType;
468
469 static void bindDerived(ClassTy &c) {
470 c.def_static(
471 "get",
472 [](DefaultingPyMlirContext context) {
473 MlirType t = mlirNoneTypeGet(context->get());
474 return PyNoneType(context->getRef(), t);
475 },
476 nb::arg("context").none() = nb::none(), "Create a none type.");
477 }
478};
479
480/// Complex Type subclass - ComplexType.
481class PyComplexType : public PyConcreteType<PyComplexType> {
482public:
483 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
484 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
485 mlirComplexTypeGetTypeID;
486 static constexpr const char *pyClassName = "ComplexType";
487 using PyConcreteType::PyConcreteType;
488
489 static void bindDerived(ClassTy &c) {
490 c.def_static(
491 "get",
492 [](PyType &elementType) {
493 // The element must be a floating point or integer scalar type.
494 if (mlirTypeIsAIntegerOrFloat(elementType)) {
495 MlirType t = mlirComplexTypeGet(elementType);
496 return PyComplexType(elementType.getContext(), t);
497 }
498 throw nb::value_error(
499 (Twine("invalid '") +
500 nb::cast<std::string>(nb::repr(nb::cast(elementType))) +
501 "' and expected floating point or integer type.")
502 .str()
503 .c_str());
504 },
505 "Create a complex type");
506 c.def_prop_ro(
507 "element_type",
508 [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); },
509 "Returns element type.");
510 }
511};
512
513} // namespace
514
515// Shaped Type Interface - ShapedType
516void mlir::PyShapedType::bindDerived(ClassTy &c) {
517 c.def_prop_ro(
518 "element_type",
519 [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); },
520 "Returns the element type of the shaped type.");
521 c.def_prop_ro(
522 "has_rank",
523 [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
524 "Returns whether the given shaped type is ranked.");
525 c.def_prop_ro(
526 "rank",
527 [](PyShapedType &self) {
528 self.requireHasRank();
529 return mlirShapedTypeGetRank(self);
530 },
531 "Returns the rank of the given ranked shaped type.");
532 c.def_prop_ro(
533 "has_static_shape",
534 [](PyShapedType &self) -> bool {
535 return mlirShapedTypeHasStaticShape(self);
536 },
537 "Returns whether the given shaped type has a static shape.");
538 c.def(
539 "is_dynamic_dim",
540 [](PyShapedType &self, intptr_t dim) -> bool {
541 self.requireHasRank();
542 return mlirShapedTypeIsDynamicDim(self, dim);
543 },
544 nb::arg("dim"),
545 "Returns whether the dim-th dimension of the given shaped type is "
546 "dynamic.");
547 c.def(
548 "is_static_dim",
549 [](PyShapedType &self, intptr_t dim) -> bool {
550 self.requireHasRank();
551 return mlirShapedTypeIsStaticDim(self, dim);
552 },
553 nb::arg("dim"),
554 "Returns whether the dim-th dimension of the given shaped type is "
555 "static.");
556 c.def(
557 "get_dim_size",
558 [](PyShapedType &self, intptr_t dim) {
559 self.requireHasRank();
560 return mlirShapedTypeGetDimSize(self, dim);
561 },
562 nb::arg("dim"),
563 "Returns the dim-th dimension of the given ranked shaped type.");
564 c.def_static(
565 "is_dynamic_size",
566 [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
567 nb::arg("dim_size"),
568 "Returns whether the given dimension size indicates a dynamic "
569 "dimension.");
570 c.def_static(
571 "is_static_size",
572 [](int64_t size) -> bool { return mlirShapedTypeIsStaticSize(size); },
573 nb::arg("dim_size"),
574 "Returns whether the given dimension size indicates a static "
575 "dimension.");
576 c.def(
577 "is_dynamic_stride_or_offset",
578 [](PyShapedType &self, int64_t val) -> bool {
579 self.requireHasRank();
580 return mlirShapedTypeIsDynamicStrideOrOffset(val);
581 },
582 nb::arg("dim_size"),
583 "Returns whether the given value is used as a placeholder for dynamic "
584 "strides and offsets in shaped types.");
585 c.def(
586 "is_static_stride_or_offset",
587 [](PyShapedType &self, int64_t val) -> bool {
588 self.requireHasRank();
589 return mlirShapedTypeIsStaticStrideOrOffset(val);
590 },
591 nb::arg("dim_size"),
592 "Returns whether the given shaped type stride or offset value is "
593 "statically-sized.");
594 c.def_prop_ro(
595 "shape",
596 [](PyShapedType &self) {
597 self.requireHasRank();
598
599 std::vector<int64_t> shape;
600 int64_t rank = mlirShapedTypeGetRank(self);
601 shape.reserve(rank);
602 for (int64_t i = 0; i < rank; ++i)
603 shape.push_back(mlirShapedTypeGetDimSize(self, i));
604 return shape;
605 },
606 "Returns the shape of the ranked shaped type as a list of integers.");
607 c.def_static(
608 "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
609 "Returns the value used to indicate dynamic dimensions in shaped "
610 "types.");
611 c.def_static(
612 "get_dynamic_stride_or_offset",
613 []() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
614 "Returns the value used to indicate dynamic strides or offsets in "
615 "shaped types.");
616}
617
618void mlir::PyShapedType::requireHasRank() {
619 if (!mlirShapedTypeHasRank(*this)) {
620 throw nb::value_error(
621 "calling this method requires that the type has a rank.");
622 }
623}
624
625const mlir::PyShapedType::IsAFunctionTy mlir::PyShapedType::isaFunction =
626 mlirTypeIsAShaped;
627
628namespace {
629
630/// Vector Type subclass - VectorType.
631class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
632public:
633 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
634 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
635 mlirVectorTypeGetTypeID;
636 static constexpr const char *pyClassName = "VectorType";
637 using PyConcreteType::PyConcreteType;
638
639 static void bindDerived(ClassTy &c) {
640 c.def_static("get", &PyVectorType::get, nb::arg("shape"),
641 nb::arg("element_type"), nb::kw_only(),
642 nb::arg("scalable").none() = nb::none(),
643 nb::arg("scalable_dims").none() = nb::none(),
644 nb::arg("loc").none() = nb::none(), "Create a vector type")
645 .def_prop_ro(
646 "scalable",
647 [](MlirType self) { return mlirVectorTypeIsScalable(self); })
648 .def_prop_ro("scalable_dims", [](MlirType self) {
649 std::vector<bool> scalableDims;
650 size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
651 scalableDims.reserve(rank);
652 for (size_t i = 0; i < rank; ++i)
653 scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
654 return scalableDims;
655 });
656 }
657
658private:
659 static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
660 std::optional<nb::list> scalable,
661 std::optional<std::vector<int64_t>> scalableDims,
662 DefaultingPyLocation loc) {
663 if (scalable && scalableDims) {
664 throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
665 "are mutually exclusive.");
666 }
667
668 PyMlirContext::ErrorCapture errors(loc->getContext());
669 MlirType type;
670 if (scalable) {
671 if (scalable->size() != shape.size())
672 throw nb::value_error("Expected len(scalable) == len(shape).");
673
674 SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
675 *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
676 type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
677 scalableDimFlags.data(),
678 elementType);
679 } else if (scalableDims) {
680 SmallVector<bool> scalableDimFlags(shape.size(), false);
681 for (int64_t dim : *scalableDims) {
682 if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
683 throw nb::value_error("Scalable dimension index out of bounds.");
684 scalableDimFlags[dim] = true;
685 }
686 type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
687 scalableDimFlags.data(),
688 elementType);
689 } else {
690 type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
691 elementType);
692 }
693 if (mlirTypeIsNull(type))
694 throw MLIRError("Invalid type", errors.take());
695 return PyVectorType(elementType.getContext(), type);
696 }
697};
698
699/// Ranked Tensor Type subclass - RankedTensorType.
700class PyRankedTensorType
701 : public PyConcreteType<PyRankedTensorType, PyShapedType> {
702public:
703 static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
704 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
705 mlirRankedTensorTypeGetTypeID;
706 static constexpr const char *pyClassName = "RankedTensorType";
707 using PyConcreteType::PyConcreteType;
708
709 static void bindDerived(ClassTy &c) {
710 c.def_static(
711 "get",
712 [](std::vector<int64_t> shape, PyType &elementType,
713 std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
714 PyMlirContext::ErrorCapture errors(loc->getContext());
715 MlirType t = mlirRankedTensorTypeGetChecked(
716 loc, shape.size(), shape.data(), elementType,
717 encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
718 if (mlirTypeIsNull(t))
719 throw MLIRError("Invalid type", errors.take());
720 return PyRankedTensorType(elementType.getContext(), t);
721 },
722 nb::arg("shape"), nb::arg("element_type"),
723 nb::arg("encoding").none() = nb::none(),
724 nb::arg("loc").none() = nb::none(), "Create a ranked tensor type");
725 c.def_prop_ro("encoding",
726 [](PyRankedTensorType &self) -> std::optional<MlirAttribute> {
727 MlirAttribute encoding =
728 mlirRankedTensorTypeGetEncoding(self.get());
729 if (mlirAttributeIsNull(encoding))
730 return std::nullopt;
731 return encoding;
732 });
733 }
734};
735
736/// Unranked Tensor Type subclass - UnrankedTensorType.
737class PyUnrankedTensorType
738 : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
739public:
740 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
741 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
742 mlirUnrankedTensorTypeGetTypeID;
743 static constexpr const char *pyClassName = "UnrankedTensorType";
744 using PyConcreteType::PyConcreteType;
745
746 static void bindDerived(ClassTy &c) {
747 c.def_static(
748 "get",
749 [](PyType &elementType, DefaultingPyLocation loc) {
750 PyMlirContext::ErrorCapture errors(loc->getContext());
751 MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
752 if (mlirTypeIsNull(t))
753 throw MLIRError("Invalid type", errors.take());
754 return PyUnrankedTensorType(elementType.getContext(), t);
755 },
756 nb::arg("element_type"), nb::arg("loc").none() = nb::none(),
757 "Create a unranked tensor type");
758 }
759};
760
761/// Ranked MemRef Type subclass - MemRefType.
762class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
763public:
764 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
765 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
766 mlirMemRefTypeGetTypeID;
767 static constexpr const char *pyClassName = "MemRefType";
768 using PyConcreteType::PyConcreteType;
769
770 static void bindDerived(ClassTy &c) {
771 c.def_static(
772 "get",
773 [](std::vector<int64_t> shape, PyType &elementType,
774 PyAttribute *layout, PyAttribute *memorySpace,
775 DefaultingPyLocation loc) {
776 PyMlirContext::ErrorCapture errors(loc->getContext());
777 MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
778 MlirAttribute memSpaceAttr =
779 memorySpace ? *memorySpace : mlirAttributeGetNull();
780 MlirType t =
781 mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
782 shape.data(), layoutAttr, memSpaceAttr);
783 if (mlirTypeIsNull(t))
784 throw MLIRError("Invalid type", errors.take());
785 return PyMemRefType(elementType.getContext(), t);
786 },
787 nb::arg("shape"), nb::arg("element_type"),
788 nb::arg("layout").none() = nb::none(),
789 nb::arg("memory_space").none() = nb::none(),
790 nb::arg("loc").none() = nb::none(), "Create a memref type")
791 .def_prop_ro(
792 "layout",
793 [](PyMemRefType &self) -> MlirAttribute {
794 return mlirMemRefTypeGetLayout(self);
795 },
796 "The layout of the MemRef type.")
797 .def(
798 "get_strides_and_offset",
799 [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
800 std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
801 int64_t offset;
802 if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
803 self, strides.data(), &offset)))
804 throw std::runtime_error(
805 "Failed to extract strides and offset from memref.");
806 return {strides, offset};
807 },
808 "The strides and offset of the MemRef type.")
809 .def_prop_ro(
810 "affine_map",
811 [](PyMemRefType &self) -> PyAffineMap {
812 MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
813 return PyAffineMap(self.getContext(), map);
814 },
815 "The layout of the MemRef type as an affine map.")
816 .def_prop_ro(
817 "memory_space",
818 [](PyMemRefType &self) -> std::optional<MlirAttribute> {
819 MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
820 if (mlirAttributeIsNull(a))
821 return std::nullopt;
822 return a;
823 },
824 "Returns the memory space of the given MemRef type.");
825 }
826};
827
828/// Unranked MemRef Type subclass - UnrankedMemRefType.
829class PyUnrankedMemRefType
830 : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
831public:
832 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
833 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
834 mlirUnrankedMemRefTypeGetTypeID;
835 static constexpr const char *pyClassName = "UnrankedMemRefType";
836 using PyConcreteType::PyConcreteType;
837
838 static void bindDerived(ClassTy &c) {
839 c.def_static(
840 "get",
841 [](PyType &elementType, PyAttribute *memorySpace,
842 DefaultingPyLocation loc) {
843 PyMlirContext::ErrorCapture errors(loc->getContext());
844 MlirAttribute memSpaceAttr = {};
845 if (memorySpace)
846 memSpaceAttr = *memorySpace;
847
848 MlirType t =
849 mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
850 if (mlirTypeIsNull(t))
851 throw MLIRError("Invalid type", errors.take());
852 return PyUnrankedMemRefType(elementType.getContext(), t);
853 },
854 nb::arg("element_type"), nb::arg("memory_space").none(),
855 nb::arg("loc").none() = nb::none(), "Create a unranked memref type")
856 .def_prop_ro(
857 "memory_space",
858 [](PyUnrankedMemRefType &self) -> std::optional<MlirAttribute> {
859 MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
860 if (mlirAttributeIsNull(a))
861 return std::nullopt;
862 return a;
863 },
864 "Returns the memory space of the given Unranked MemRef type.");
865 }
866};
867
868/// Tuple Type subclass - TupleType.
869class PyTupleType : public PyConcreteType<PyTupleType> {
870public:
871 static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
872 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
873 mlirTupleTypeGetTypeID;
874 static constexpr const char *pyClassName = "TupleType";
875 using PyConcreteType::PyConcreteType;
876
877 static void bindDerived(ClassTy &c) {
878 c.def_static(
879 "get_tuple",
880 [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
881 MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
882 elements.data());
883 return PyTupleType(context->getRef(), t);
884 },
885 nb::arg("elements"), nb::arg("context").none() = nb::none(),
886 "Create a tuple type");
887 c.def(
888 "get_type",
889 [](PyTupleType &self, intptr_t pos) {
890 return mlirTupleTypeGetType(self, pos);
891 },
892 nb::arg("pos"), "Returns the pos-th type in the tuple type.");
893 c.def_prop_ro(
894 "num_types",
895 [](PyTupleType &self) -> intptr_t {
896 return mlirTupleTypeGetNumTypes(self);
897 },
898 "Returns the number of types contained in a tuple.");
899 }
900};
901
902/// Function type.
903class PyFunctionType : public PyConcreteType<PyFunctionType> {
904public:
905 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
906 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
907 mlirFunctionTypeGetTypeID;
908 static constexpr const char *pyClassName = "FunctionType";
909 using PyConcreteType::PyConcreteType;
910
911 static void bindDerived(ClassTy &c) {
912 c.def_static(
913 "get",
914 [](std::vector<MlirType> inputs, std::vector<MlirType> results,
915 DefaultingPyMlirContext context) {
916 MlirType t =
917 mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
918 results.size(), results.data());
919 return PyFunctionType(context->getRef(), t);
920 },
921 nb::arg("inputs"), nb::arg("results"),
922 nb::arg("context").none() = nb::none(),
923 "Gets a FunctionType from a list of input and result types");
924 c.def_prop_ro(
925 "inputs",
926 [](PyFunctionType &self) {
927 MlirType t = self;
928 nb::list types;
929 for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
930 ++i) {
931 types.append(mlirFunctionTypeGetInput(t, i));
932 }
933 return types;
934 },
935 "Returns the list of input types in the FunctionType.");
936 c.def_prop_ro(
937 "results",
938 [](PyFunctionType &self) {
939 nb::list types;
940 for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
941 ++i) {
942 types.append(mlirFunctionTypeGetResult(self, i));
943 }
944 return types;
945 },
946 "Returns the list of result types in the FunctionType.");
947 }
948};
949
950static MlirStringRef toMlirStringRef(const std::string &s) {
951 return mlirStringRefCreate(s.data(), s.size());
952}
953
954/// Opaque Type subclass - OpaqueType.
955class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
956public:
957 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
958 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
959 mlirOpaqueTypeGetTypeID;
960 static constexpr const char *pyClassName = "OpaqueType";
961 using PyConcreteType::PyConcreteType;
962
963 static void bindDerived(ClassTy &c) {
964 c.def_static(
965 "get",
966 [](std::string dialectNamespace, std::string typeData,
967 DefaultingPyMlirContext context) {
968 MlirType type = mlirOpaqueTypeGet(context->get(),
969 toMlirStringRef(dialectNamespace),
970 toMlirStringRef(typeData));
971 return PyOpaqueType(context->getRef(), type);
972 },
973 nb::arg("dialect_namespace"), nb::arg("buffer"),
974 nb::arg("context").none() = nb::none(),
975 "Create an unregistered (opaque) dialect type.");
976 c.def_prop_ro(
977 "dialect_namespace",
978 [](PyOpaqueType &self) {
979 MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
980 return nb::str(stringRef.data, stringRef.length);
981 },
982 "Returns the dialect namespace for the Opaque type as a string.");
983 c.def_prop_ro(
984 "data",
985 [](PyOpaqueType &self) {
986 MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
987 return nb::str(stringRef.data, stringRef.length);
988 },
989 "Returns the data for the Opaque type as a string.");
990 }
991};
992
993} // namespace
994
/// Registers every builtin type binding defined in this file on the Python
/// module `m`, one `bind` call per class.
// NOTE(review): the call order is presumably significant. nanobind expects a
// base class to be registered before any subclass that names it, e.g. the
// concrete float types relative to PyFloatType, and the vector/tensor/memref
// types relative to PyShapedType. Keep this order unless the class hierarchy
// has been verified.
void mlir::python::populateIRTypes(nb::module_ &m) {
  // Integer, generic float and index types.
  PyIntegerType::bind(m);
  PyFloatType::bind(m);
  PyIndexType::bind(m);
  // Concrete floating-point formats.
  PyFloat4E2M1FNType::bind(m);
  PyFloat6E2M3FNType::bind(m);
  PyFloat6E3M2FNType::bind(m);
  PyFloat8E4M3FNType::bind(m);
  PyFloat8E5M2Type::bind(m);
  PyFloat8E4M3Type::bind(m);
  PyFloat8E4M3FNUZType::bind(m);
  PyFloat8E4M3B11FNUZType::bind(m);
  PyFloat8E5M2FNUZType::bind(m);
  PyFloat8E3M4Type::bind(m);
  PyFloat8E8M0FNUType::bind(m);
  PyBF16Type::bind(m);
  PyF16Type::bind(m);
  PyTF32Type::bind(m);
  PyF32Type::bind(m);
  PyF64Type::bind(m);
  PyNoneType::bind(m);
  PyComplexType::bind(m);
  // Shaped types; PyShapedType is bound before the vector/tensor/memref
  // classes that follow it.
  PyShapedType::bind(m);
  PyVectorType::bind(m);
  PyRankedTensorType::bind(m);
  PyUnrankedTensorType::bind(m);
  PyMemRefType::bind(m);
  PyUnrankedMemRefType::bind(m);
  // Aggregate, function and opaque (unregistered-dialect) types.
  PyTupleType::bind(m);
  PyFunctionType::bind(m);
  PyOpaqueType::bind(m);
}
1027

source code of mlir/lib/Bindings/Python/IRTypes.cpp