1//===- IRTypes.cpp - Exports builtin and standard types -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9// clang-format off
10#include "IRModule.h"
11#include "mlir/Bindings/Python/IRTypes.h"
12// clang-format on
13
14#include <optional>
15
16#include "IRModule.h"
17#include "NanobindUtils.h"
18#include "mlir-c/BuiltinAttributes.h"
19#include "mlir-c/BuiltinTypes.h"
20#include "mlir-c/Support.h"
21
22namespace nb = nanobind;
23using namespace mlir;
24using namespace mlir::python;
25
26using llvm::SmallVector;
27using llvm::Twine;
28
29namespace {
30
31/// Checks whether the given type is an integer or float type.
32static int mlirTypeIsAIntegerOrFloat(MlirType type) {
33 return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
34 mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
35}
36
37class PyIntegerType : public PyConcreteType<PyIntegerType> {
38public:
39 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
40 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
41 mlirIntegerTypeGetTypeID;
42 static constexpr const char *pyClassName = "IntegerType";
43 using PyConcreteType::PyConcreteType;
44
45 static void bindDerived(ClassTy &c) {
46 c.def_static(
47 "get_signless",
48 [](unsigned width, DefaultingPyMlirContext context) {
49 MlirType t = mlirIntegerTypeGet(context->get(), width);
50 return PyIntegerType(context->getRef(), t);
51 },
52 nb::arg("width"), nb::arg("context").none() = nb::none(),
53 "Create a signless integer type");
54 c.def_static(
55 "get_signed",
56 [](unsigned width, DefaultingPyMlirContext context) {
57 MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
58 return PyIntegerType(context->getRef(), t);
59 },
60 nb::arg("width"), nb::arg("context").none() = nb::none(),
61 "Create a signed integer type");
62 c.def_static(
63 "get_unsigned",
64 [](unsigned width, DefaultingPyMlirContext context) {
65 MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
66 return PyIntegerType(context->getRef(), t);
67 },
68 nb::arg("width"), nb::arg("context").none() = nb::none(),
69 "Create an unsigned integer type");
70 c.def_prop_ro(
71 "width",
72 [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
73 "Returns the width of the integer type");
74 c.def_prop_ro(
75 "is_signless",
76 [](PyIntegerType &self) -> bool {
77 return mlirIntegerTypeIsSignless(self);
78 },
79 "Returns whether this is a signless integer");
80 c.def_prop_ro(
81 "is_signed",
82 [](PyIntegerType &self) -> bool {
83 return mlirIntegerTypeIsSigned(self);
84 },
85 "Returns whether this is a signed integer");
86 c.def_prop_ro(
87 "is_unsigned",
88 [](PyIntegerType &self) -> bool {
89 return mlirIntegerTypeIsUnsigned(self);
90 },
91 "Returns whether this is an unsigned integer");
92 }
93};
94
95/// Index Type subclass - IndexType.
96class PyIndexType : public PyConcreteType<PyIndexType> {
97public:
98 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
99 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
100 mlirIndexTypeGetTypeID;
101 static constexpr const char *pyClassName = "IndexType";
102 using PyConcreteType::PyConcreteType;
103
104 static void bindDerived(ClassTy &c) {
105 c.def_static(
106 "get",
107 [](DefaultingPyMlirContext context) {
108 MlirType t = mlirIndexTypeGet(context->get());
109 return PyIndexType(context->getRef(), t);
110 },
111 nb::arg("context").none() = nb::none(), "Create a index type.");
112 }
113};
114
115class PyFloatType : public PyConcreteType<PyFloatType> {
116public:
117 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
118 static constexpr const char *pyClassName = "FloatType";
119 using PyConcreteType::PyConcreteType;
120
121 static void bindDerived(ClassTy &c) {
122 c.def_prop_ro(
123 "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
124 "Returns the width of the floating-point type");
125 }
126};
127
128/// Floating Point Type subclass - Float4E2M1FNType.
129class PyFloat4E2M1FNType
130 : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
131public:
132 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
133 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
134 mlirFloat4E2M1FNTypeGetTypeID;
135 static constexpr const char *pyClassName = "Float4E2M1FNType";
136 using PyConcreteType::PyConcreteType;
137
138 static void bindDerived(ClassTy &c) {
139 c.def_static(
140 "get",
141 [](DefaultingPyMlirContext context) {
142 MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
143 return PyFloat4E2M1FNType(context->getRef(), t);
144 },
145 nb::arg("context").none() = nb::none(), "Create a float4_e2m1fn type.");
146 }
147};
148
149/// Floating Point Type subclass - Float6E2M3FNType.
150class PyFloat6E2M3FNType
151 : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
152public:
153 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
154 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
155 mlirFloat6E2M3FNTypeGetTypeID;
156 static constexpr const char *pyClassName = "Float6E2M3FNType";
157 using PyConcreteType::PyConcreteType;
158
159 static void bindDerived(ClassTy &c) {
160 c.def_static(
161 "get",
162 [](DefaultingPyMlirContext context) {
163 MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
164 return PyFloat6E2M3FNType(context->getRef(), t);
165 },
166 nb::arg("context").none() = nb::none(), "Create a float6_e2m3fn type.");
167 }
168};
169
170/// Floating Point Type subclass - Float6E3M2FNType.
171class PyFloat6E3M2FNType
172 : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
173public:
174 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
175 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
176 mlirFloat6E3M2FNTypeGetTypeID;
177 static constexpr const char *pyClassName = "Float6E3M2FNType";
178 using PyConcreteType::PyConcreteType;
179
180 static void bindDerived(ClassTy &c) {
181 c.def_static(
182 "get",
183 [](DefaultingPyMlirContext context) {
184 MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
185 return PyFloat6E3M2FNType(context->getRef(), t);
186 },
187 nb::arg("context").none() = nb::none(), "Create a float6_e3m2fn type.");
188 }
189};
190
191/// Floating Point Type subclass - Float8E4M3FNType.
192class PyFloat8E4M3FNType
193 : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
194public:
195 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
196 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
197 mlirFloat8E4M3FNTypeGetTypeID;
198 static constexpr const char *pyClassName = "Float8E4M3FNType";
199 using PyConcreteType::PyConcreteType;
200
201 static void bindDerived(ClassTy &c) {
202 c.def_static(
203 "get",
204 [](DefaultingPyMlirContext context) {
205 MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
206 return PyFloat8E4M3FNType(context->getRef(), t);
207 },
208 nb::arg("context").none() = nb::none(), "Create a float8_e4m3fn type.");
209 }
210};
211
212/// Floating Point Type subclass - Float8E5M2Type.
213class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
214public:
215 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
216 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
217 mlirFloat8E5M2TypeGetTypeID;
218 static constexpr const char *pyClassName = "Float8E5M2Type";
219 using PyConcreteType::PyConcreteType;
220
221 static void bindDerived(ClassTy &c) {
222 c.def_static(
223 "get",
224 [](DefaultingPyMlirContext context) {
225 MlirType t = mlirFloat8E5M2TypeGet(context->get());
226 return PyFloat8E5M2Type(context->getRef(), t);
227 },
228 nb::arg("context").none() = nb::none(), "Create a float8_e5m2 type.");
229 }
230};
231
232/// Floating Point Type subclass - Float8E4M3Type.
233class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
234public:
235 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
236 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
237 mlirFloat8E4M3TypeGetTypeID;
238 static constexpr const char *pyClassName = "Float8E4M3Type";
239 using PyConcreteType::PyConcreteType;
240
241 static void bindDerived(ClassTy &c) {
242 c.def_static(
243 "get",
244 [](DefaultingPyMlirContext context) {
245 MlirType t = mlirFloat8E4M3TypeGet(context->get());
246 return PyFloat8E4M3Type(context->getRef(), t);
247 },
248 nb::arg("context").none() = nb::none(), "Create a float8_e4m3 type.");
249 }
250};
251
252/// Floating Point Type subclass - Float8E4M3FNUZ.
253class PyFloat8E4M3FNUZType
254 : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
255public:
256 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
257 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
258 mlirFloat8E4M3FNUZTypeGetTypeID;
259 static constexpr const char *pyClassName = "Float8E4M3FNUZType";
260 using PyConcreteType::PyConcreteType;
261
262 static void bindDerived(ClassTy &c) {
263 c.def_static(
264 "get",
265 [](DefaultingPyMlirContext context) {
266 MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
267 return PyFloat8E4M3FNUZType(context->getRef(), t);
268 },
269 nb::arg("context").none() = nb::none(),
270 "Create a float8_e4m3fnuz type.");
271 }
272};
273
274/// Floating Point Type subclass - Float8E4M3B11FNUZ.
275class PyFloat8E4M3B11FNUZType
276 : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
277public:
278 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
279 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
280 mlirFloat8E4M3B11FNUZTypeGetTypeID;
281 static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
282 using PyConcreteType::PyConcreteType;
283
284 static void bindDerived(ClassTy &c) {
285 c.def_static(
286 "get",
287 [](DefaultingPyMlirContext context) {
288 MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
289 return PyFloat8E4M3B11FNUZType(context->getRef(), t);
290 },
291 nb::arg("context").none() = nb::none(),
292 "Create a float8_e4m3b11fnuz type.");
293 }
294};
295
296/// Floating Point Type subclass - Float8E5M2FNUZ.
297class PyFloat8E5M2FNUZType
298 : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
299public:
300 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
301 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
302 mlirFloat8E5M2FNUZTypeGetTypeID;
303 static constexpr const char *pyClassName = "Float8E5M2FNUZType";
304 using PyConcreteType::PyConcreteType;
305
306 static void bindDerived(ClassTy &c) {
307 c.def_static(
308 "get",
309 [](DefaultingPyMlirContext context) {
310 MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
311 return PyFloat8E5M2FNUZType(context->getRef(), t);
312 },
313 nb::arg("context").none() = nb::none(),
314 "Create a float8_e5m2fnuz type.");
315 }
316};
317
318/// Floating Point Type subclass - Float8E3M4Type.
319class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
320public:
321 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
322 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
323 mlirFloat8E3M4TypeGetTypeID;
324 static constexpr const char *pyClassName = "Float8E3M4Type";
325 using PyConcreteType::PyConcreteType;
326
327 static void bindDerived(ClassTy &c) {
328 c.def_static(
329 "get",
330 [](DefaultingPyMlirContext context) {
331 MlirType t = mlirFloat8E3M4TypeGet(context->get());
332 return PyFloat8E3M4Type(context->getRef(), t);
333 },
334 nb::arg("context").none() = nb::none(), "Create a float8_e3m4 type.");
335 }
336};
337
338/// Floating Point Type subclass - Float8E8M0FNUType.
339class PyFloat8E8M0FNUType
340 : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
341public:
342 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
343 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
344 mlirFloat8E8M0FNUTypeGetTypeID;
345 static constexpr const char *pyClassName = "Float8E8M0FNUType";
346 using PyConcreteType::PyConcreteType;
347
348 static void bindDerived(ClassTy &c) {
349 c.def_static(
350 "get",
351 [](DefaultingPyMlirContext context) {
352 MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
353 return PyFloat8E8M0FNUType(context->getRef(), t);
354 },
355 nb::arg("context").none() = nb::none(),
356 "Create a float8_e8m0fnu type.");
357 }
358};
359
360/// Floating Point Type subclass - BF16Type.
361class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
362public:
363 static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
364 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
365 mlirBFloat16TypeGetTypeID;
366 static constexpr const char *pyClassName = "BF16Type";
367 using PyConcreteType::PyConcreteType;
368
369 static void bindDerived(ClassTy &c) {
370 c.def_static(
371 "get",
372 [](DefaultingPyMlirContext context) {
373 MlirType t = mlirBF16TypeGet(context->get());
374 return PyBF16Type(context->getRef(), t);
375 },
376 nb::arg("context").none() = nb::none(), "Create a bf16 type.");
377 }
378};
379
380/// Floating Point Type subclass - F16Type.
381class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
382public:
383 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
384 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
385 mlirFloat16TypeGetTypeID;
386 static constexpr const char *pyClassName = "F16Type";
387 using PyConcreteType::PyConcreteType;
388
389 static void bindDerived(ClassTy &c) {
390 c.def_static(
391 "get",
392 [](DefaultingPyMlirContext context) {
393 MlirType t = mlirF16TypeGet(context->get());
394 return PyF16Type(context->getRef(), t);
395 },
396 nb::arg("context").none() = nb::none(), "Create a f16 type.");
397 }
398};
399
400/// Floating Point Type subclass - TF32Type.
401class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
402public:
403 static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
404 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
405 mlirFloatTF32TypeGetTypeID;
406 static constexpr const char *pyClassName = "FloatTF32Type";
407 using PyConcreteType::PyConcreteType;
408
409 static void bindDerived(ClassTy &c) {
410 c.def_static(
411 "get",
412 [](DefaultingPyMlirContext context) {
413 MlirType t = mlirTF32TypeGet(context->get());
414 return PyTF32Type(context->getRef(), t);
415 },
416 nb::arg("context").none() = nb::none(), "Create a tf32 type.");
417 }
418};
419
420/// Floating Point Type subclass - F32Type.
421class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
422public:
423 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
424 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
425 mlirFloat32TypeGetTypeID;
426 static constexpr const char *pyClassName = "F32Type";
427 using PyConcreteType::PyConcreteType;
428
429 static void bindDerived(ClassTy &c) {
430 c.def_static(
431 "get",
432 [](DefaultingPyMlirContext context) {
433 MlirType t = mlirF32TypeGet(context->get());
434 return PyF32Type(context->getRef(), t);
435 },
436 nb::arg("context").none() = nb::none(), "Create a f32 type.");
437 }
438};
439
440/// Floating Point Type subclass - F64Type.
441class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
442public:
443 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
444 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
445 mlirFloat64TypeGetTypeID;
446 static constexpr const char *pyClassName = "F64Type";
447 using PyConcreteType::PyConcreteType;
448
449 static void bindDerived(ClassTy &c) {
450 c.def_static(
451 "get",
452 [](DefaultingPyMlirContext context) {
453 MlirType t = mlirF64TypeGet(context->get());
454 return PyF64Type(context->getRef(), t);
455 },
456 nb::arg("context").none() = nb::none(), "Create a f64 type.");
457 }
458};
459
460/// None Type subclass - NoneType.
461class PyNoneType : public PyConcreteType<PyNoneType> {
462public:
463 static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
464 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
465 mlirNoneTypeGetTypeID;
466 static constexpr const char *pyClassName = "NoneType";
467 using PyConcreteType::PyConcreteType;
468
469 static void bindDerived(ClassTy &c) {
470 c.def_static(
471 "get",
472 [](DefaultingPyMlirContext context) {
473 MlirType t = mlirNoneTypeGet(context->get());
474 return PyNoneType(context->getRef(), t);
475 },
476 nb::arg("context").none() = nb::none(), "Create a none type.");
477 }
478};
479
480/// Complex Type subclass - ComplexType.
481class PyComplexType : public PyConcreteType<PyComplexType> {
482public:
483 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
484 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
485 mlirComplexTypeGetTypeID;
486 static constexpr const char *pyClassName = "ComplexType";
487 using PyConcreteType::PyConcreteType;
488
489 static void bindDerived(ClassTy &c) {
490 c.def_static(
491 "get",
492 [](PyType &elementType) {
493 // The element must be a floating point or integer scalar type.
494 if (mlirTypeIsAIntegerOrFloat(elementType)) {
495 MlirType t = mlirComplexTypeGet(elementType);
496 return PyComplexType(elementType.getContext(), t);
497 }
498 throw nb::value_error(
499 (Twine("invalid '") +
500 nb::cast<std::string>(nb::repr(nb::cast(elementType))) +
501 "' and expected floating point or integer type.")
502 .str()
503 .c_str());
504 },
505 "Create a complex type");
506 c.def_prop_ro(
507 "element_type",
508 [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); },
509 "Returns element type.");
510 }
511};
512
513} // namespace
514
515// Shaped Type Interface - ShapedType
516void mlir::PyShapedType::bindDerived(ClassTy &c) {
517 c.def_prop_ro(
518 "element_type",
519 [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); },
520 "Returns the element type of the shaped type.");
521 c.def_prop_ro(
522 "has_rank",
523 [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
524 "Returns whether the given shaped type is ranked.");
525 c.def_prop_ro(
526 "rank",
527 [](PyShapedType &self) {
528 self.requireHasRank();
529 return mlirShapedTypeGetRank(self);
530 },
531 "Returns the rank of the given ranked shaped type.");
532 c.def_prop_ro(
533 "has_static_shape",
534 [](PyShapedType &self) -> bool {
535 return mlirShapedTypeHasStaticShape(self);
536 },
537 "Returns whether the given shaped type has a static shape.");
538 c.def(
539 "is_dynamic_dim",
540 [](PyShapedType &self, intptr_t dim) -> bool {
541 self.requireHasRank();
542 return mlirShapedTypeIsDynamicDim(self, dim);
543 },
544 nb::arg("dim"),
545 "Returns whether the dim-th dimension of the given shaped type is "
546 "dynamic.");
547 c.def(
548 "get_dim_size",
549 [](PyShapedType &self, intptr_t dim) {
550 self.requireHasRank();
551 return mlirShapedTypeGetDimSize(self, dim);
552 },
553 nb::arg("dim"),
554 "Returns the dim-th dimension of the given ranked shaped type.");
555 c.def_static(
556 "is_dynamic_size",
557 [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
558 nb::arg("dim_size"),
559 "Returns whether the given dimension size indicates a dynamic "
560 "dimension.");
561 c.def(
562 "is_dynamic_stride_or_offset",
563 [](PyShapedType &self, int64_t val) -> bool {
564 self.requireHasRank();
565 return mlirShapedTypeIsDynamicStrideOrOffset(val);
566 },
567 nb::arg("dim_size"),
568 "Returns whether the given value is used as a placeholder for dynamic "
569 "strides and offsets in shaped types.");
570 c.def_prop_ro(
571 "shape",
572 [](PyShapedType &self) {
573 self.requireHasRank();
574
575 std::vector<int64_t> shape;
576 int64_t rank = mlirShapedTypeGetRank(self);
577 shape.reserve(rank);
578 for (int64_t i = 0; i < rank; ++i)
579 shape.push_back(mlirShapedTypeGetDimSize(self, i));
580 return shape;
581 },
582 "Returns the shape of the ranked shaped type as a list of integers.");
583 c.def_static(
584 "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
585 "Returns the value used to indicate dynamic dimensions in shaped "
586 "types.");
587 c.def_static(
588 "get_dynamic_stride_or_offset",
589 []() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
590 "Returns the value used to indicate dynamic strides or offsets in "
591 "shaped types.");
592}
593
594void mlir::PyShapedType::requireHasRank() {
595 if (!mlirShapedTypeHasRank(*this)) {
596 throw nb::value_error(
597 "calling this method requires that the type has a rank.");
598 }
599}
600
601const mlir::PyShapedType::IsAFunctionTy mlir::PyShapedType::isaFunction =
602 mlirTypeIsAShaped;
603
604namespace {
605
606/// Vector Type subclass - VectorType.
607class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
608public:
609 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
610 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
611 mlirVectorTypeGetTypeID;
612 static constexpr const char *pyClassName = "VectorType";
613 using PyConcreteType::PyConcreteType;
614
615 static void bindDerived(ClassTy &c) {
616 c.def_static("get", &PyVectorType::get, nb::arg("shape"),
617 nb::arg("element_type"), nb::kw_only(),
618 nb::arg("scalable").none() = nb::none(),
619 nb::arg("scalable_dims").none() = nb::none(),
620 nb::arg("loc").none() = nb::none(), "Create a vector type")
621 .def_prop_ro(
622 "scalable",
623 [](MlirType self) { return mlirVectorTypeIsScalable(self); })
624 .def_prop_ro("scalable_dims", [](MlirType self) {
625 std::vector<bool> scalableDims;
626 size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
627 scalableDims.reserve(rank);
628 for (size_t i = 0; i < rank; ++i)
629 scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
630 return scalableDims;
631 });
632 }
633
634private:
635 static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
636 std::optional<nb::list> scalable,
637 std::optional<std::vector<int64_t>> scalableDims,
638 DefaultingPyLocation loc) {
639 if (scalable && scalableDims) {
640 throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
641 "are mutually exclusive.");
642 }
643
644 PyMlirContext::ErrorCapture errors(loc->getContext());
645 MlirType type;
646 if (scalable) {
647 if (scalable->size() != shape.size())
648 throw nb::value_error("Expected len(scalable) == len(shape).");
649
650 SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
651 *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
652 type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
653 scalableDimFlags.data(),
654 elementType);
655 } else if (scalableDims) {
656 SmallVector<bool> scalableDimFlags(shape.size(), false);
657 for (int64_t dim : *scalableDims) {
658 if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
659 throw nb::value_error("Scalable dimension index out of bounds.");
660 scalableDimFlags[dim] = true;
661 }
662 type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
663 scalableDimFlags.data(),
664 elementType);
665 } else {
666 type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
667 elementType);
668 }
669 if (mlirTypeIsNull(type))
670 throw MLIRError("Invalid type", errors.take());
671 return PyVectorType(elementType.getContext(), type);
672 }
673};
674
675/// Ranked Tensor Type subclass - RankedTensorType.
676class PyRankedTensorType
677 : public PyConcreteType<PyRankedTensorType, PyShapedType> {
678public:
679 static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
680 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
681 mlirRankedTensorTypeGetTypeID;
682 static constexpr const char *pyClassName = "RankedTensorType";
683 using PyConcreteType::PyConcreteType;
684
685 static void bindDerived(ClassTy &c) {
686 c.def_static(
687 "get",
688 [](std::vector<int64_t> shape, PyType &elementType,
689 std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
690 PyMlirContext::ErrorCapture errors(loc->getContext());
691 MlirType t = mlirRankedTensorTypeGetChecked(
692 loc, shape.size(), shape.data(), elementType,
693 encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
694 if (mlirTypeIsNull(t))
695 throw MLIRError("Invalid type", errors.take());
696 return PyRankedTensorType(elementType.getContext(), t);
697 },
698 nb::arg("shape"), nb::arg("element_type"),
699 nb::arg("encoding").none() = nb::none(),
700 nb::arg("loc").none() = nb::none(), "Create a ranked tensor type");
701 c.def_prop_ro("encoding",
702 [](PyRankedTensorType &self) -> std::optional<MlirAttribute> {
703 MlirAttribute encoding =
704 mlirRankedTensorTypeGetEncoding(self.get());
705 if (mlirAttributeIsNull(encoding))
706 return std::nullopt;
707 return encoding;
708 });
709 }
710};
711
712/// Unranked Tensor Type subclass - UnrankedTensorType.
713class PyUnrankedTensorType
714 : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
715public:
716 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
717 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
718 mlirUnrankedTensorTypeGetTypeID;
719 static constexpr const char *pyClassName = "UnrankedTensorType";
720 using PyConcreteType::PyConcreteType;
721
722 static void bindDerived(ClassTy &c) {
723 c.def_static(
724 "get",
725 [](PyType &elementType, DefaultingPyLocation loc) {
726 PyMlirContext::ErrorCapture errors(loc->getContext());
727 MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
728 if (mlirTypeIsNull(t))
729 throw MLIRError("Invalid type", errors.take());
730 return PyUnrankedTensorType(elementType.getContext(), t);
731 },
732 nb::arg("element_type"), nb::arg("loc").none() = nb::none(),
733 "Create a unranked tensor type");
734 }
735};
736
737/// Ranked MemRef Type subclass - MemRefType.
738class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
739public:
740 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
741 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
742 mlirMemRefTypeGetTypeID;
743 static constexpr const char *pyClassName = "MemRefType";
744 using PyConcreteType::PyConcreteType;
745
746 static void bindDerived(ClassTy &c) {
747 c.def_static(
748 "get",
749 [](std::vector<int64_t> shape, PyType &elementType,
750 PyAttribute *layout, PyAttribute *memorySpace,
751 DefaultingPyLocation loc) {
752 PyMlirContext::ErrorCapture errors(loc->getContext());
753 MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
754 MlirAttribute memSpaceAttr =
755 memorySpace ? *memorySpace : mlirAttributeGetNull();
756 MlirType t =
757 mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
758 shape.data(), layoutAttr, memSpaceAttr);
759 if (mlirTypeIsNull(t))
760 throw MLIRError("Invalid type", errors.take());
761 return PyMemRefType(elementType.getContext(), t);
762 },
763 nb::arg("shape"), nb::arg("element_type"),
764 nb::arg("layout").none() = nb::none(),
765 nb::arg("memory_space").none() = nb::none(),
766 nb::arg("loc").none() = nb::none(), "Create a memref type")
767 .def_prop_ro(
768 "layout",
769 [](PyMemRefType &self) -> MlirAttribute {
770 return mlirMemRefTypeGetLayout(self);
771 },
772 "The layout of the MemRef type.")
773 .def(
774 "get_strides_and_offset",
775 [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
776 std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
777 int64_t offset;
778 if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
779 self, strides.data(), &offset)))
780 throw std::runtime_error(
781 "Failed to extract strides and offset from memref.");
782 return {strides, offset};
783 },
784 "The strides and offset of the MemRef type.")
785 .def_prop_ro(
786 "affine_map",
787 [](PyMemRefType &self) -> PyAffineMap {
788 MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
789 return PyAffineMap(self.getContext(), map);
790 },
791 "The layout of the MemRef type as an affine map.")
792 .def_prop_ro(
793 "memory_space",
794 [](PyMemRefType &self) -> std::optional<MlirAttribute> {
795 MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
796 if (mlirAttributeIsNull(a))
797 return std::nullopt;
798 return a;
799 },
800 "Returns the memory space of the given MemRef type.");
801 }
802};
803
804/// Unranked MemRef Type subclass - UnrankedMemRefType.
805class PyUnrankedMemRefType
806 : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
807public:
808 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
809 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
810 mlirUnrankedMemRefTypeGetTypeID;
811 static constexpr const char *pyClassName = "UnrankedMemRefType";
812 using PyConcreteType::PyConcreteType;
813
814 static void bindDerived(ClassTy &c) {
815 c.def_static(
816 "get",
817 [](PyType &elementType, PyAttribute *memorySpace,
818 DefaultingPyLocation loc) {
819 PyMlirContext::ErrorCapture errors(loc->getContext());
820 MlirAttribute memSpaceAttr = {};
821 if (memorySpace)
822 memSpaceAttr = *memorySpace;
823
824 MlirType t =
825 mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
826 if (mlirTypeIsNull(t))
827 throw MLIRError("Invalid type", errors.take());
828 return PyUnrankedMemRefType(elementType.getContext(), t);
829 },
830 nb::arg("element_type"), nb::arg("memory_space").none(),
831 nb::arg("loc").none() = nb::none(), "Create a unranked memref type")
832 .def_prop_ro(
833 "memory_space",
834 [](PyUnrankedMemRefType &self) -> std::optional<MlirAttribute> {
835 MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
836 if (mlirAttributeIsNull(a))
837 return std::nullopt;
838 return a;
839 },
840 "Returns the memory space of the given Unranked MemRef type.");
841 }
842};
843
844/// Tuple Type subclass - TupleType.
845class PyTupleType : public PyConcreteType<PyTupleType> {
846public:
847 static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
848 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
849 mlirTupleTypeGetTypeID;
850 static constexpr const char *pyClassName = "TupleType";
851 using PyConcreteType::PyConcreteType;
852
853 static void bindDerived(ClassTy &c) {
854 c.def_static(
855 "get_tuple",
856 [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
857 MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
858 elements.data());
859 return PyTupleType(context->getRef(), t);
860 },
861 nb::arg("elements"), nb::arg("context").none() = nb::none(),
862 "Create a tuple type");
863 c.def(
864 "get_type",
865 [](PyTupleType &self, intptr_t pos) {
866 return mlirTupleTypeGetType(self, pos);
867 },
868 nb::arg("pos"), "Returns the pos-th type in the tuple type.");
869 c.def_prop_ro(
870 "num_types",
871 [](PyTupleType &self) -> intptr_t {
872 return mlirTupleTypeGetNumTypes(self);
873 },
874 "Returns the number of types contained in a tuple.");
875 }
876};
877
878/// Function type.
879class PyFunctionType : public PyConcreteType<PyFunctionType> {
880public:
881 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
882 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
883 mlirFunctionTypeGetTypeID;
884 static constexpr const char *pyClassName = "FunctionType";
885 using PyConcreteType::PyConcreteType;
886
887 static void bindDerived(ClassTy &c) {
888 c.def_static(
889 "get",
890 [](std::vector<MlirType> inputs, std::vector<MlirType> results,
891 DefaultingPyMlirContext context) {
892 MlirType t =
893 mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
894 results.size(), results.data());
895 return PyFunctionType(context->getRef(), t);
896 },
897 nb::arg("inputs"), nb::arg("results"),
898 nb::arg("context").none() = nb::none(),
899 "Gets a FunctionType from a list of input and result types");
900 c.def_prop_ro(
901 "inputs",
902 [](PyFunctionType &self) {
903 MlirType t = self;
904 nb::list types;
905 for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
906 ++i) {
907 types.append(mlirFunctionTypeGetInput(t, i));
908 }
909 return types;
910 },
911 "Returns the list of input types in the FunctionType.");
912 c.def_prop_ro(
913 "results",
914 [](PyFunctionType &self) {
915 nb::list types;
916 for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
917 ++i) {
918 types.append(mlirFunctionTypeGetResult(self, i));
919 }
920 return types;
921 },
922 "Returns the list of result types in the FunctionType.");
923 }
924};
925
926static MlirStringRef toMlirStringRef(const std::string &s) {
927 return mlirStringRefCreate(s.data(), s.size());
928}
929
930/// Opaque Type subclass - OpaqueType.
931class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
932public:
933 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
934 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
935 mlirOpaqueTypeGetTypeID;
936 static constexpr const char *pyClassName = "OpaqueType";
937 using PyConcreteType::PyConcreteType;
938
939 static void bindDerived(ClassTy &c) {
940 c.def_static(
941 "get",
942 [](std::string dialectNamespace, std::string typeData,
943 DefaultingPyMlirContext context) {
944 MlirType type = mlirOpaqueTypeGet(context->get(),
945 toMlirStringRef(dialectNamespace),
946 toMlirStringRef(typeData));
947 return PyOpaqueType(context->getRef(), type);
948 },
949 nb::arg("dialect_namespace"), nb::arg("buffer"),
950 nb::arg("context").none() = nb::none(),
951 "Create an unregistered (opaque) dialect type.");
952 c.def_prop_ro(
953 "dialect_namespace",
954 [](PyOpaqueType &self) {
955 MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
956 return nb::str(stringRef.data, stringRef.length);
957 },
958 "Returns the dialect namespace for the Opaque type as a string.");
959 c.def_prop_ro(
960 "data",
961 [](PyOpaqueType &self) {
962 MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
963 return nb::str(stringRef.data, stringRef.length);
964 },
965 "Returns the data for the Opaque type as a string.");
966 }
967};
968
969} // namespace
970
971void mlir::python::populateIRTypes(nb::module_ &m) {
972 PyIntegerType::bind(m);
973 PyFloatType::bind(m);
974 PyIndexType::bind(m);
975 PyFloat4E2M1FNType::bind(m);
976 PyFloat6E2M3FNType::bind(m);
977 PyFloat6E3M2FNType::bind(m);
978 PyFloat8E4M3FNType::bind(m);
979 PyFloat8E5M2Type::bind(m);
980 PyFloat8E4M3Type::bind(m);
981 PyFloat8E4M3FNUZType::bind(m);
982 PyFloat8E4M3B11FNUZType::bind(m);
983 PyFloat8E5M2FNUZType::bind(m);
984 PyFloat8E3M4Type::bind(m);
985 PyFloat8E8M0FNUType::bind(m);
986 PyBF16Type::bind(m);
987 PyF16Type::bind(m);
988 PyTF32Type::bind(m);
989 PyF32Type::bind(m);
990 PyF64Type::bind(m);
991 PyNoneType::bind(m);
992 PyComplexType::bind(m);
993 PyShapedType::bind(m);
994 PyVectorType::bind(m);
995 PyRankedTensorType::bind(m);
996 PyUnrankedTensorType::bind(m);
997 PyMemRefType::bind(m);
998 PyUnrankedMemRefType::bind(m);
999 PyTupleType::bind(m);
1000 PyFunctionType::bind(m);
1001 PyOpaqueType::bind(m);
1002}
1003

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Bindings/Python/IRTypes.cpp