1//===- IRTypes.cpp - Exports builtin and standard types -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "IRModule.h"
10
11#include "PybindUtils.h"
12
13#include "mlir-c/BuiltinAttributes.h"
14#include "mlir-c/BuiltinTypes.h"
15#include "mlir-c/Support.h"
16
17#include <optional>
18
19namespace py = pybind11;
20using namespace mlir;
21using namespace mlir::python;
22
23using llvm::SmallVector;
24using llvm::Twine;
25
26namespace {
27
28/// Checks whether the given type is an integer or float type.
29static int mlirTypeIsAIntegerOrFloat(MlirType type) {
30 return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
31 mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
32}
33
34class PyIntegerType : public PyConcreteType<PyIntegerType> {
35public:
36 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
37 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
38 mlirIntegerTypeGetTypeID;
39 static constexpr const char *pyClassName = "IntegerType";
40 using PyConcreteType::PyConcreteType;
41
42 static void bindDerived(ClassTy &c) {
43 c.def_static(
44 "get_signless",
45 [](unsigned width, DefaultingPyMlirContext context) {
46 MlirType t = mlirIntegerTypeGet(context->get(), width);
47 return PyIntegerType(context->getRef(), t);
48 },
49 py::arg("width"), py::arg("context") = py::none(),
50 "Create a signless integer type");
51 c.def_static(
52 "get_signed",
53 [](unsigned width, DefaultingPyMlirContext context) {
54 MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
55 return PyIntegerType(context->getRef(), t);
56 },
57 py::arg("width"), py::arg("context") = py::none(),
58 "Create a signed integer type");
59 c.def_static(
60 "get_unsigned",
61 [](unsigned width, DefaultingPyMlirContext context) {
62 MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
63 return PyIntegerType(context->getRef(), t);
64 },
65 py::arg("width"), py::arg("context") = py::none(),
66 "Create an unsigned integer type");
67 c.def_property_readonly(
68 "width",
69 [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
70 "Returns the width of the integer type");
71 c.def_property_readonly(
72 "is_signless",
73 [](PyIntegerType &self) -> bool {
74 return mlirIntegerTypeIsSignless(self);
75 },
76 "Returns whether this is a signless integer");
77 c.def_property_readonly(
78 "is_signed",
79 [](PyIntegerType &self) -> bool {
80 return mlirIntegerTypeIsSigned(self);
81 },
82 "Returns whether this is a signed integer");
83 c.def_property_readonly(
84 "is_unsigned",
85 [](PyIntegerType &self) -> bool {
86 return mlirIntegerTypeIsUnsigned(self);
87 },
88 "Returns whether this is an unsigned integer");
89 }
90};
91
92/// Index Type subclass - IndexType.
93class PyIndexType : public PyConcreteType<PyIndexType> {
94public:
95 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
96 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
97 mlirIndexTypeGetTypeID;
98 static constexpr const char *pyClassName = "IndexType";
99 using PyConcreteType::PyConcreteType;
100
101 static void bindDerived(ClassTy &c) {
102 c.def_static(
103 "get",
104 [](DefaultingPyMlirContext context) {
105 MlirType t = mlirIndexTypeGet(context->get());
106 return PyIndexType(context->getRef(), t);
107 },
108 py::arg("context") = py::none(), "Create a index type.");
109 }
110};
111
112class PyFloatType : public PyConcreteType<PyFloatType> {
113public:
114 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
115 static constexpr const char *pyClassName = "FloatType";
116 using PyConcreteType::PyConcreteType;
117
118 static void bindDerived(ClassTy &c) {
119 c.def_property_readonly(
120 "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
121 "Returns the width of the floating-point type");
122 }
123};
124
125/// Floating Point Type subclass - Float8E4M3FNType.
126class PyFloat8E4M3FNType
127 : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
128public:
129 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
130 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
131 mlirFloat8E4M3FNTypeGetTypeID;
132 static constexpr const char *pyClassName = "Float8E4M3FNType";
133 using PyConcreteType::PyConcreteType;
134
135 static void bindDerived(ClassTy &c) {
136 c.def_static(
137 "get",
138 [](DefaultingPyMlirContext context) {
139 MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
140 return PyFloat8E4M3FNType(context->getRef(), t);
141 },
142 py::arg("context") = py::none(), "Create a float8_e4m3fn type.");
143 }
144};
145
146/// Floating Point Type subclass - Float8M5E2Type.
147class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
148public:
149 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
150 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
151 mlirFloat8E5M2TypeGetTypeID;
152 static constexpr const char *pyClassName = "Float8E5M2Type";
153 using PyConcreteType::PyConcreteType;
154
155 static void bindDerived(ClassTy &c) {
156 c.def_static(
157 "get",
158 [](DefaultingPyMlirContext context) {
159 MlirType t = mlirFloat8E5M2TypeGet(context->get());
160 return PyFloat8E5M2Type(context->getRef(), t);
161 },
162 py::arg("context") = py::none(), "Create a float8_e5m2 type.");
163 }
164};
165
166/// Floating Point Type subclass - Float8E4M3FNUZ.
167class PyFloat8E4M3FNUZType
168 : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
169public:
170 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
171 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
172 mlirFloat8E4M3FNUZTypeGetTypeID;
173 static constexpr const char *pyClassName = "Float8E4M3FNUZType";
174 using PyConcreteType::PyConcreteType;
175
176 static void bindDerived(ClassTy &c) {
177 c.def_static(
178 "get",
179 [](DefaultingPyMlirContext context) {
180 MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
181 return PyFloat8E4M3FNUZType(context->getRef(), t);
182 },
183 py::arg("context") = py::none(), "Create a float8_e4m3fnuz type.");
184 }
185};
186
187/// Floating Point Type subclass - Float8E4M3B11FNUZ.
188class PyFloat8E4M3B11FNUZType
189 : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
190public:
191 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
192 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
193 mlirFloat8E4M3B11FNUZTypeGetTypeID;
194 static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
195 using PyConcreteType::PyConcreteType;
196
197 static void bindDerived(ClassTy &c) {
198 c.def_static(
199 "get",
200 [](DefaultingPyMlirContext context) {
201 MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
202 return PyFloat8E4M3B11FNUZType(context->getRef(), t);
203 },
204 py::arg("context") = py::none(), "Create a float8_e4m3b11fnuz type.");
205 }
206};
207
208/// Floating Point Type subclass - Float8E5M2FNUZ.
209class PyFloat8E5M2FNUZType
210 : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
211public:
212 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
213 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
214 mlirFloat8E5M2FNUZTypeGetTypeID;
215 static constexpr const char *pyClassName = "Float8E5M2FNUZType";
216 using PyConcreteType::PyConcreteType;
217
218 static void bindDerived(ClassTy &c) {
219 c.def_static(
220 "get",
221 [](DefaultingPyMlirContext context) {
222 MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
223 return PyFloat8E5M2FNUZType(context->getRef(), t);
224 },
225 py::arg("context") = py::none(), "Create a float8_e5m2fnuz type.");
226 }
227};
228
229/// Floating Point Type subclass - BF16Type.
230class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
231public:
232 static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
233 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
234 mlirBFloat16TypeGetTypeID;
235 static constexpr const char *pyClassName = "BF16Type";
236 using PyConcreteType::PyConcreteType;
237
238 static void bindDerived(ClassTy &c) {
239 c.def_static(
240 "get",
241 [](DefaultingPyMlirContext context) {
242 MlirType t = mlirBF16TypeGet(context->get());
243 return PyBF16Type(context->getRef(), t);
244 },
245 py::arg("context") = py::none(), "Create a bf16 type.");
246 }
247};
248
249/// Floating Point Type subclass - F16Type.
250class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
251public:
252 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
253 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
254 mlirFloat16TypeGetTypeID;
255 static constexpr const char *pyClassName = "F16Type";
256 using PyConcreteType::PyConcreteType;
257
258 static void bindDerived(ClassTy &c) {
259 c.def_static(
260 "get",
261 [](DefaultingPyMlirContext context) {
262 MlirType t = mlirF16TypeGet(context->get());
263 return PyF16Type(context->getRef(), t);
264 },
265 py::arg("context") = py::none(), "Create a f16 type.");
266 }
267};
268
269/// Floating Point Type subclass - TF32Type.
270class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
271public:
272 static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
273 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
274 mlirFloatTF32TypeGetTypeID;
275 static constexpr const char *pyClassName = "FloatTF32Type";
276 using PyConcreteType::PyConcreteType;
277
278 static void bindDerived(ClassTy &c) {
279 c.def_static(
280 "get",
281 [](DefaultingPyMlirContext context) {
282 MlirType t = mlirTF32TypeGet(context->get());
283 return PyTF32Type(context->getRef(), t);
284 },
285 py::arg("context") = py::none(), "Create a tf32 type.");
286 }
287};
288
289/// Floating Point Type subclass - F32Type.
290class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
291public:
292 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
293 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
294 mlirFloat32TypeGetTypeID;
295 static constexpr const char *pyClassName = "F32Type";
296 using PyConcreteType::PyConcreteType;
297
298 static void bindDerived(ClassTy &c) {
299 c.def_static(
300 "get",
301 [](DefaultingPyMlirContext context) {
302 MlirType t = mlirF32TypeGet(context->get());
303 return PyF32Type(context->getRef(), t);
304 },
305 py::arg("context") = py::none(), "Create a f32 type.");
306 }
307};
308
309/// Floating Point Type subclass - F64Type.
310class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
311public:
312 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
313 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
314 mlirFloat64TypeGetTypeID;
315 static constexpr const char *pyClassName = "F64Type";
316 using PyConcreteType::PyConcreteType;
317
318 static void bindDerived(ClassTy &c) {
319 c.def_static(
320 "get",
321 [](DefaultingPyMlirContext context) {
322 MlirType t = mlirF64TypeGet(context->get());
323 return PyF64Type(context->getRef(), t);
324 },
325 py::arg("context") = py::none(), "Create a f64 type.");
326 }
327};
328
329/// None Type subclass - NoneType.
330class PyNoneType : public PyConcreteType<PyNoneType> {
331public:
332 static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
333 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
334 mlirNoneTypeGetTypeID;
335 static constexpr const char *pyClassName = "NoneType";
336 using PyConcreteType::PyConcreteType;
337
338 static void bindDerived(ClassTy &c) {
339 c.def_static(
340 "get",
341 [](DefaultingPyMlirContext context) {
342 MlirType t = mlirNoneTypeGet(context->get());
343 return PyNoneType(context->getRef(), t);
344 },
345 py::arg("context") = py::none(), "Create a none type.");
346 }
347};
348
349/// Complex Type subclass - ComplexType.
350class PyComplexType : public PyConcreteType<PyComplexType> {
351public:
352 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
353 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
354 mlirComplexTypeGetTypeID;
355 static constexpr const char *pyClassName = "ComplexType";
356 using PyConcreteType::PyConcreteType;
357
358 static void bindDerived(ClassTy &c) {
359 c.def_static(
360 "get",
361 [](PyType &elementType) {
362 // The element must be a floating point or integer scalar type.
363 if (mlirTypeIsAIntegerOrFloat(elementType)) {
364 MlirType t = mlirComplexTypeGet(elementType);
365 return PyComplexType(elementType.getContext(), t);
366 }
367 throw py::value_error(
368 (Twine("invalid '") +
369 py::repr(py::cast(elementType)).cast<std::string>() +
370 "' and expected floating point or integer type.")
371 .str());
372 },
373 "Create a complex type");
374 c.def_property_readonly(
375 "element_type",
376 [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); },
377 "Returns element type.");
378 }
379};
380
381class PyShapedType : public PyConcreteType<PyShapedType> {
382public:
383 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped;
384 static constexpr const char *pyClassName = "ShapedType";
385 using PyConcreteType::PyConcreteType;
386
387 static void bindDerived(ClassTy &c) {
388 c.def_property_readonly(
389 "element_type",
390 [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); },
391 "Returns the element type of the shaped type.");
392 c.def_property_readonly(
393 "has_rank",
394 [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
395 "Returns whether the given shaped type is ranked.");
396 c.def_property_readonly(
397 "rank",
398 [](PyShapedType &self) {
399 self.requireHasRank();
400 return mlirShapedTypeGetRank(self);
401 },
402 "Returns the rank of the given ranked shaped type.");
403 c.def_property_readonly(
404 "has_static_shape",
405 [](PyShapedType &self) -> bool {
406 return mlirShapedTypeHasStaticShape(self);
407 },
408 "Returns whether the given shaped type has a static shape.");
409 c.def(
410 "is_dynamic_dim",
411 [](PyShapedType &self, intptr_t dim) -> bool {
412 self.requireHasRank();
413 return mlirShapedTypeIsDynamicDim(self, dim);
414 },
415 py::arg("dim"),
416 "Returns whether the dim-th dimension of the given shaped type is "
417 "dynamic.");
418 c.def(
419 "get_dim_size",
420 [](PyShapedType &self, intptr_t dim) {
421 self.requireHasRank();
422 return mlirShapedTypeGetDimSize(self, dim);
423 },
424 py::arg("dim"),
425 "Returns the dim-th dimension of the given ranked shaped type.");
426 c.def_static(
427 "is_dynamic_size",
428 [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
429 py::arg("dim_size"),
430 "Returns whether the given dimension size indicates a dynamic "
431 "dimension.");
432 c.def(
433 "is_dynamic_stride_or_offset",
434 [](PyShapedType &self, int64_t val) -> bool {
435 self.requireHasRank();
436 return mlirShapedTypeIsDynamicStrideOrOffset(val);
437 },
438 py::arg("dim_size"),
439 "Returns whether the given value is used as a placeholder for dynamic "
440 "strides and offsets in shaped types.");
441 c.def_property_readonly(
442 "shape",
443 [](PyShapedType &self) {
444 self.requireHasRank();
445
446 std::vector<int64_t> shape;
447 int64_t rank = mlirShapedTypeGetRank(self);
448 shape.reserve(n: rank);
449 for (int64_t i = 0; i < rank; ++i)
450 shape.push_back(mlirShapedTypeGetDimSize(self, i));
451 return shape;
452 },
453 "Returns the shape of the ranked shaped type as a list of integers.");
454 c.def_static(
455 "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
456 "Returns the value used to indicate dynamic dimensions in shaped "
457 "types.");
458 c.def_static(
459 "get_dynamic_stride_or_offset",
460 []() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
461 "Returns the value used to indicate dynamic strides or offsets in "
462 "shaped types.");
463 }
464
465private:
466 void requireHasRank() {
467 if (!mlirShapedTypeHasRank(*this)) {
468 throw py::value_error(
469 "calling this method requires that the type has a rank.");
470 }
471 }
472};
473
474/// Vector Type subclass - VectorType.
475class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
476public:
477 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
478 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
479 mlirVectorTypeGetTypeID;
480 static constexpr const char *pyClassName = "VectorType";
481 using PyConcreteType::PyConcreteType;
482
483 static void bindDerived(ClassTy &c) {
484 c.def_static("get", &PyVectorType::get, py::arg("shape"),
485 py::arg("element_type"), py::kw_only(),
486 py::arg("scalable") = py::none(),
487 py::arg("scalable_dims") = py::none(),
488 py::arg("loc") = py::none(), "Create a vector type")
489 .def_property_readonly(
490 "scalable",
491 [](MlirType self) { return mlirVectorTypeIsScalable(self); })
492 .def_property_readonly("scalable_dims", [](MlirType self) {
493 std::vector<bool> scalableDims;
494 size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
495 scalableDims.reserve(rank);
496 for (size_t i = 0; i < rank; ++i)
497 scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
498 return scalableDims;
499 });
500 }
501
502private:
503 static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
504 std::optional<py::list> scalable,
505 std::optional<std::vector<int64_t>> scalableDims,
506 DefaultingPyLocation loc) {
507 if (scalable && scalableDims) {
508 throw py::value_error("'scalable' and 'scalable_dims' kwargs "
509 "are mutually exclusive.");
510 }
511
512 PyMlirContext::ErrorCapture errors(loc->getContext());
513 MlirType type;
514 if (scalable) {
515 if (scalable->size() != shape.size())
516 throw py::value_error("Expected len(scalable) == len(shape).");
517
518 SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
519 *scalable, [](const py::handle &h) { return h.cast<bool>(); }));
520 type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
521 scalableDimFlags.data(),
522 elementType);
523 } else if (scalableDims) {
524 SmallVector<bool> scalableDimFlags(shape.size(), false);
525 for (int64_t dim : *scalableDims) {
526 if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
527 throw py::value_error("Scalable dimension index out of bounds.");
528 scalableDimFlags[dim] = true;
529 }
530 type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
531 scalableDimFlags.data(),
532 elementType);
533 } else {
534 type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
535 elementType);
536 }
537 if (mlirTypeIsNull(type))
538 throw MLIRError("Invalid type", errors.take());
539 return PyVectorType(elementType.getContext(), type);
540 }
541};
542
543/// Ranked Tensor Type subclass - RankedTensorType.
544class PyRankedTensorType
545 : public PyConcreteType<PyRankedTensorType, PyShapedType> {
546public:
547 static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
548 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
549 mlirRankedTensorTypeGetTypeID;
550 static constexpr const char *pyClassName = "RankedTensorType";
551 using PyConcreteType::PyConcreteType;
552
553 static void bindDerived(ClassTy &c) {
554 c.def_static(
555 "get",
556 [](std::vector<int64_t> shape, PyType &elementType,
557 std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
558 PyMlirContext::ErrorCapture errors(loc->getContext());
559 MlirType t = mlirRankedTensorTypeGetChecked(
560 loc, shape.size(), shape.data(), elementType,
561 encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
562 if (mlirTypeIsNull(t))
563 throw MLIRError("Invalid type", errors.take());
564 return PyRankedTensorType(elementType.getContext(), t);
565 },
566 py::arg("shape"), py::arg("element_type"),
567 py::arg("encoding") = py::none(), py::arg("loc") = py::none(),
568 "Create a ranked tensor type");
569 c.def_property_readonly(
570 "encoding",
571 [](PyRankedTensorType &self) -> std::optional<MlirAttribute> {
572 MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
573 if (mlirAttributeIsNull(encoding))
574 return std::nullopt;
575 return encoding;
576 });
577 }
578};
579
580/// Unranked Tensor Type subclass - UnrankedTensorType.
581class PyUnrankedTensorType
582 : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
583public:
584 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
585 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
586 mlirUnrankedTensorTypeGetTypeID;
587 static constexpr const char *pyClassName = "UnrankedTensorType";
588 using PyConcreteType::PyConcreteType;
589
590 static void bindDerived(ClassTy &c) {
591 c.def_static(
592 "get",
593 [](PyType &elementType, DefaultingPyLocation loc) {
594 PyMlirContext::ErrorCapture errors(loc->getContext());
595 MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
596 if (mlirTypeIsNull(t))
597 throw MLIRError("Invalid type", errors.take());
598 return PyUnrankedTensorType(elementType.getContext(), t);
599 },
600 py::arg("element_type"), py::arg("loc") = py::none(),
601 "Create a unranked tensor type");
602 }
603};
604
605/// Ranked MemRef Type subclass - MemRefType.
606class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
607public:
608 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
609 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
610 mlirMemRefTypeGetTypeID;
611 static constexpr const char *pyClassName = "MemRefType";
612 using PyConcreteType::PyConcreteType;
613
614 static void bindDerived(ClassTy &c) {
615 c.def_static(
616 "get",
617 [](std::vector<int64_t> shape, PyType &elementType,
618 PyAttribute *layout, PyAttribute *memorySpace,
619 DefaultingPyLocation loc) {
620 PyMlirContext::ErrorCapture errors(loc->getContext());
621 MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
622 MlirAttribute memSpaceAttr =
623 memorySpace ? *memorySpace : mlirAttributeGetNull();
624 MlirType t =
625 mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
626 shape.data(), layoutAttr, memSpaceAttr);
627 if (mlirTypeIsNull(t))
628 throw MLIRError("Invalid type", errors.take());
629 return PyMemRefType(elementType.getContext(), t);
630 },
631 py::arg("shape"), py::arg("element_type"),
632 py::arg("layout") = py::none(), py::arg("memory_space") = py::none(),
633 py::arg("loc") = py::none(), "Create a memref type")
634 .def_property_readonly(
635 "layout",
636 [](PyMemRefType &self) -> MlirAttribute {
637 return mlirMemRefTypeGetLayout(self);
638 },
639 "The layout of the MemRef type.")
640 .def(
641 "get_strides_and_offset",
642 [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
643 std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
644 int64_t offset;
645 if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
646 self, strides.data(), &offset)))
647 throw std::runtime_error(
648 "Failed to extract strides and offset from memref.");
649 return {strides, offset};
650 },
651 "The strides and offset of the MemRef type.")
652 .def_property_readonly(
653 "affine_map",
654 [](PyMemRefType &self) -> PyAffineMap {
655 MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
656 return PyAffineMap(self.getContext(), map);
657 },
658 "The layout of the MemRef type as an affine map.")
659 .def_property_readonly(
660 "memory_space",
661 [](PyMemRefType &self) -> std::optional<MlirAttribute> {
662 MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
663 if (mlirAttributeIsNull(a))
664 return std::nullopt;
665 return a;
666 },
667 "Returns the memory space of the given MemRef type.");
668 }
669};
670
671/// Unranked MemRef Type subclass - UnrankedMemRefType.
672class PyUnrankedMemRefType
673 : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
674public:
675 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
676 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
677 mlirUnrankedMemRefTypeGetTypeID;
678 static constexpr const char *pyClassName = "UnrankedMemRefType";
679 using PyConcreteType::PyConcreteType;
680
681 static void bindDerived(ClassTy &c) {
682 c.def_static(
683 "get",
684 [](PyType &elementType, PyAttribute *memorySpace,
685 DefaultingPyLocation loc) {
686 PyMlirContext::ErrorCapture errors(loc->getContext());
687 MlirAttribute memSpaceAttr = {};
688 if (memorySpace)
689 memSpaceAttr = *memorySpace;
690
691 MlirType t =
692 mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
693 if (mlirTypeIsNull(t))
694 throw MLIRError("Invalid type", errors.take());
695 return PyUnrankedMemRefType(elementType.getContext(), t);
696 },
697 py::arg("element_type"), py::arg("memory_space"),
698 py::arg("loc") = py::none(), "Create a unranked memref type")
699 .def_property_readonly(
700 "memory_space",
701 [](PyUnrankedMemRefType &self) -> std::optional<MlirAttribute> {
702 MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
703 if (mlirAttributeIsNull(a))
704 return std::nullopt;
705 return a;
706 },
707 "Returns the memory space of the given Unranked MemRef type.");
708 }
709};
710
711/// Tuple Type subclass - TupleType.
712class PyTupleType : public PyConcreteType<PyTupleType> {
713public:
714 static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
715 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
716 mlirTupleTypeGetTypeID;
717 static constexpr const char *pyClassName = "TupleType";
718 using PyConcreteType::PyConcreteType;
719
720 static void bindDerived(ClassTy &c) {
721 c.def_static(
722 "get_tuple",
723 [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
724 MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
725 elements.data());
726 return PyTupleType(context->getRef(), t);
727 },
728 py::arg("elements"), py::arg("context") = py::none(),
729 "Create a tuple type");
730 c.def(
731 "get_type",
732 [](PyTupleType &self, intptr_t pos) {
733 return mlirTupleTypeGetType(self, pos);
734 },
735 py::arg("pos"), "Returns the pos-th type in the tuple type.");
736 c.def_property_readonly(
737 "num_types",
738 [](PyTupleType &self) -> intptr_t {
739 return mlirTupleTypeGetNumTypes(self);
740 },
741 "Returns the number of types contained in a tuple.");
742 }
743};
744
745/// Function type.
746class PyFunctionType : public PyConcreteType<PyFunctionType> {
747public:
748 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
749 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
750 mlirFunctionTypeGetTypeID;
751 static constexpr const char *pyClassName = "FunctionType";
752 using PyConcreteType::PyConcreteType;
753
754 static void bindDerived(ClassTy &c) {
755 c.def_static(
756 "get",
757 [](std::vector<MlirType> inputs, std::vector<MlirType> results,
758 DefaultingPyMlirContext context) {
759 MlirType t =
760 mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
761 results.size(), results.data());
762 return PyFunctionType(context->getRef(), t);
763 },
764 py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
765 "Gets a FunctionType from a list of input and result types");
766 c.def_property_readonly(
767 "inputs",
768 [](PyFunctionType &self) {
769 MlirType t = self;
770 py::list types;
771 for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
772 ++i) {
773 types.append(mlirFunctionTypeGetInput(t, i));
774 }
775 return types;
776 },
777 "Returns the list of input types in the FunctionType.");
778 c.def_property_readonly(
779 "results",
780 [](PyFunctionType &self) {
781 py::list types;
782 for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
783 ++i) {
784 types.append(mlirFunctionTypeGetResult(self, i));
785 }
786 return types;
787 },
788 "Returns the list of result types in the FunctionType.");
789 }
790};
791
792static MlirStringRef toMlirStringRef(const std::string &s) {
793 return mlirStringRefCreate(s.data(), s.size());
794}
795
796/// Opaque Type subclass - OpaqueType.
797class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
798public:
799 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
800 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
801 mlirOpaqueTypeGetTypeID;
802 static constexpr const char *pyClassName = "OpaqueType";
803 using PyConcreteType::PyConcreteType;
804
805 static void bindDerived(ClassTy &c) {
806 c.def_static(
807 "get",
808 [](std::string dialectNamespace, std::string typeData,
809 DefaultingPyMlirContext context) {
810 MlirType type = mlirOpaqueTypeGet(context->get(),
811 toMlirStringRef(dialectNamespace),
812 toMlirStringRef(typeData));
813 return PyOpaqueType(context->getRef(), type);
814 },
815 py::arg("dialect_namespace"), py::arg("buffer"),
816 py::arg("context") = py::none(),
817 "Create an unregistered (opaque) dialect type.");
818 c.def_property_readonly(
819 "dialect_namespace",
820 [](PyOpaqueType &self) {
821 MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
822 return py::str(stringRef.data, stringRef.length);
823 },
824 "Returns the dialect namespace for the Opaque type as a string.");
825 c.def_property_readonly(
826 "data",
827 [](PyOpaqueType &self) {
828 MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
829 return py::str(stringRef.data, stringRef.length);
830 },
831 "Returns the data for the Opaque type as a string.");
832 }
833};
834
835} // namespace
836
/// Registers every builtin type binding defined in this file on module `m`.
///
/// NOTE(review): the order below looks deliberate and should be preserved.
/// PyFloatType is bound before the float bindings that name it as their
/// PyConcreteType base, and PyShapedType before the vector/tensor/memref
/// bindings that name it as theirs. Presumably the base Python class must
/// already be registered when a derived class is bound. Confirm before
/// reordering.
void mlir::python::populateIRTypes(py::module &m) {
  PyIntegerType::bind(m);
  // Base for all of the floating-point bindings below.
  PyFloatType::bind(m);
  PyIndexType::bind(m);
  PyFloat8E4M3FNType::bind(m);
  PyFloat8E5M2Type::bind(m);
  PyFloat8E4M3FNUZType::bind(m);
  PyFloat8E4M3B11FNUZType::bind(m);
  PyFloat8E5M2FNUZType::bind(m);
  PyBF16Type::bind(m);
  PyF16Type::bind(m);
  PyTF32Type::bind(m);
  PyF32Type::bind(m);
  PyF64Type::bind(m);
  PyNoneType::bind(m);
  PyComplexType::bind(m);
  // Base for the vector / tensor / memref bindings that follow.
  PyShapedType::bind(m);
  PyVectorType::bind(m);
  PyRankedTensorType::bind(m);
  PyUnrankedTensorType::bind(m);
  PyMemRefType::bind(m);
  PyUnrankedMemRefType::bind(m);
  PyTupleType::bind(m);
  PyFunctionType::bind(m);
  PyOpaqueType::bind(m);
}
863

source code of mlir/lib/Bindings/Python/IRTypes.cpp