1 | //===- IRTypes.cpp - Exports builtin and standard types -------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "IRModule.h" |
10 | |
11 | #include "PybindUtils.h" |
12 | |
13 | #include "mlir-c/BuiltinAttributes.h" |
14 | #include "mlir-c/BuiltinTypes.h" |
15 | #include "mlir-c/Support.h" |
16 | |
17 | #include <optional> |
18 | |
19 | namespace py = pybind11; |
20 | using namespace mlir; |
21 | using namespace mlir::python; |
22 | |
23 | using llvm::SmallVector; |
24 | using llvm::Twine; |
25 | |
26 | namespace { |
27 | |
28 | /// Checks whether the given type is an integer or float type. |
29 | static int mlirTypeIsAIntegerOrFloat(MlirType type) { |
30 | return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) || |
31 | mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type); |
32 | } |
33 | |
34 | class PyIntegerType : public PyConcreteType<PyIntegerType> { |
35 | public: |
36 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; |
37 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
38 | mlirIntegerTypeGetTypeID; |
39 | static constexpr const char *pyClassName = "IntegerType" ; |
40 | using PyConcreteType::PyConcreteType; |
41 | |
42 | static void bindDerived(ClassTy &c) { |
43 | c.def_static( |
44 | "get_signless" , |
45 | [](unsigned width, DefaultingPyMlirContext context) { |
46 | MlirType t = mlirIntegerTypeGet(context->get(), width); |
47 | return PyIntegerType(context->getRef(), t); |
48 | }, |
49 | py::arg("width" ), py::arg("context" ) = py::none(), |
50 | "Create a signless integer type" ); |
51 | c.def_static( |
52 | "get_signed" , |
53 | [](unsigned width, DefaultingPyMlirContext context) { |
54 | MlirType t = mlirIntegerTypeSignedGet(context->get(), width); |
55 | return PyIntegerType(context->getRef(), t); |
56 | }, |
57 | py::arg("width" ), py::arg("context" ) = py::none(), |
58 | "Create a signed integer type" ); |
59 | c.def_static( |
60 | "get_unsigned" , |
61 | [](unsigned width, DefaultingPyMlirContext context) { |
62 | MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); |
63 | return PyIntegerType(context->getRef(), t); |
64 | }, |
65 | py::arg("width" ), py::arg("context" ) = py::none(), |
66 | "Create an unsigned integer type" ); |
67 | c.def_property_readonly( |
68 | "width" , |
69 | [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, |
70 | "Returns the width of the integer type" ); |
71 | c.def_property_readonly( |
72 | "is_signless" , |
73 | [](PyIntegerType &self) -> bool { |
74 | return mlirIntegerTypeIsSignless(self); |
75 | }, |
76 | "Returns whether this is a signless integer" ); |
77 | c.def_property_readonly( |
78 | "is_signed" , |
79 | [](PyIntegerType &self) -> bool { |
80 | return mlirIntegerTypeIsSigned(self); |
81 | }, |
82 | "Returns whether this is a signed integer" ); |
83 | c.def_property_readonly( |
84 | "is_unsigned" , |
85 | [](PyIntegerType &self) -> bool { |
86 | return mlirIntegerTypeIsUnsigned(self); |
87 | }, |
88 | "Returns whether this is an unsigned integer" ); |
89 | } |
90 | }; |
91 | |
92 | /// Index Type subclass - IndexType. |
93 | class PyIndexType : public PyConcreteType<PyIndexType> { |
94 | public: |
95 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; |
96 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
97 | mlirIndexTypeGetTypeID; |
98 | static constexpr const char *pyClassName = "IndexType" ; |
99 | using PyConcreteType::PyConcreteType; |
100 | |
101 | static void bindDerived(ClassTy &c) { |
102 | c.def_static( |
103 | "get" , |
104 | [](DefaultingPyMlirContext context) { |
105 | MlirType t = mlirIndexTypeGet(context->get()); |
106 | return PyIndexType(context->getRef(), t); |
107 | }, |
108 | py::arg("context" ) = py::none(), "Create a index type." ); |
109 | } |
110 | }; |
111 | |
112 | class PyFloatType : public PyConcreteType<PyFloatType> { |
113 | public: |
114 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat; |
115 | static constexpr const char *pyClassName = "FloatType" ; |
116 | using PyConcreteType::PyConcreteType; |
117 | |
118 | static void bindDerived(ClassTy &c) { |
119 | c.def_property_readonly( |
120 | "width" , [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); }, |
121 | "Returns the width of the floating-point type" ); |
122 | } |
123 | }; |
124 | |
125 | /// Floating Point Type subclass - Float8E4M3FNType. |
126 | class PyFloat8E4M3FNType |
127 | : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> { |
128 | public: |
129 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; |
130 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
131 | mlirFloat8E4M3FNTypeGetTypeID; |
132 | static constexpr const char *pyClassName = "Float8E4M3FNType" ; |
133 | using PyConcreteType::PyConcreteType; |
134 | |
135 | static void bindDerived(ClassTy &c) { |
136 | c.def_static( |
137 | "get" , |
138 | [](DefaultingPyMlirContext context) { |
139 | MlirType t = mlirFloat8E4M3FNTypeGet(context->get()); |
140 | return PyFloat8E4M3FNType(context->getRef(), t); |
141 | }, |
142 | py::arg("context" ) = py::none(), "Create a float8_e4m3fn type." ); |
143 | } |
144 | }; |
145 | |
146 | /// Floating Point Type subclass - Float8M5E2Type. |
147 | class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> { |
148 | public: |
149 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; |
150 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
151 | mlirFloat8E5M2TypeGetTypeID; |
152 | static constexpr const char *pyClassName = "Float8E5M2Type" ; |
153 | using PyConcreteType::PyConcreteType; |
154 | |
155 | static void bindDerived(ClassTy &c) { |
156 | c.def_static( |
157 | "get" , |
158 | [](DefaultingPyMlirContext context) { |
159 | MlirType t = mlirFloat8E5M2TypeGet(context->get()); |
160 | return PyFloat8E5M2Type(context->getRef(), t); |
161 | }, |
162 | py::arg("context" ) = py::none(), "Create a float8_e5m2 type." ); |
163 | } |
164 | }; |
165 | |
166 | /// Floating Point Type subclass - Float8E4M3FNUZ. |
167 | class PyFloat8E4M3FNUZType |
168 | : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> { |
169 | public: |
170 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; |
171 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
172 | mlirFloat8E4M3FNUZTypeGetTypeID; |
173 | static constexpr const char *pyClassName = "Float8E4M3FNUZType" ; |
174 | using PyConcreteType::PyConcreteType; |
175 | |
176 | static void bindDerived(ClassTy &c) { |
177 | c.def_static( |
178 | "get" , |
179 | [](DefaultingPyMlirContext context) { |
180 | MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); |
181 | return PyFloat8E4M3FNUZType(context->getRef(), t); |
182 | }, |
183 | py::arg("context" ) = py::none(), "Create a float8_e4m3fnuz type." ); |
184 | } |
185 | }; |
186 | |
187 | /// Floating Point Type subclass - Float8E4M3B11FNUZ. |
188 | class PyFloat8E4M3B11FNUZType |
189 | : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> { |
190 | public: |
191 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; |
192 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
193 | mlirFloat8E4M3B11FNUZTypeGetTypeID; |
194 | static constexpr const char *pyClassName = "Float8E4M3B11FNUZType" ; |
195 | using PyConcreteType::PyConcreteType; |
196 | |
197 | static void bindDerived(ClassTy &c) { |
198 | c.def_static( |
199 | "get" , |
200 | [](DefaultingPyMlirContext context) { |
201 | MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); |
202 | return PyFloat8E4M3B11FNUZType(context->getRef(), t); |
203 | }, |
204 | py::arg("context" ) = py::none(), "Create a float8_e4m3b11fnuz type." ); |
205 | } |
206 | }; |
207 | |
208 | /// Floating Point Type subclass - Float8E5M2FNUZ. |
209 | class PyFloat8E5M2FNUZType |
210 | : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> { |
211 | public: |
212 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; |
213 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
214 | mlirFloat8E5M2FNUZTypeGetTypeID; |
215 | static constexpr const char *pyClassName = "Float8E5M2FNUZType" ; |
216 | using PyConcreteType::PyConcreteType; |
217 | |
218 | static void bindDerived(ClassTy &c) { |
219 | c.def_static( |
220 | "get" , |
221 | [](DefaultingPyMlirContext context) { |
222 | MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); |
223 | return PyFloat8E5M2FNUZType(context->getRef(), t); |
224 | }, |
225 | py::arg("context" ) = py::none(), "Create a float8_e5m2fnuz type." ); |
226 | } |
227 | }; |
228 | |
229 | /// Floating Point Type subclass - BF16Type. |
230 | class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> { |
231 | public: |
232 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; |
233 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
234 | mlirBFloat16TypeGetTypeID; |
235 | static constexpr const char *pyClassName = "BF16Type" ; |
236 | using PyConcreteType::PyConcreteType; |
237 | |
238 | static void bindDerived(ClassTy &c) { |
239 | c.def_static( |
240 | "get" , |
241 | [](DefaultingPyMlirContext context) { |
242 | MlirType t = mlirBF16TypeGet(context->get()); |
243 | return PyBF16Type(context->getRef(), t); |
244 | }, |
245 | py::arg("context" ) = py::none(), "Create a bf16 type." ); |
246 | } |
247 | }; |
248 | |
249 | /// Floating Point Type subclass - F16Type. |
250 | class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> { |
251 | public: |
252 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; |
253 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
254 | mlirFloat16TypeGetTypeID; |
255 | static constexpr const char *pyClassName = "F16Type" ; |
256 | using PyConcreteType::PyConcreteType; |
257 | |
258 | static void bindDerived(ClassTy &c) { |
259 | c.def_static( |
260 | "get" , |
261 | [](DefaultingPyMlirContext context) { |
262 | MlirType t = mlirF16TypeGet(context->get()); |
263 | return PyF16Type(context->getRef(), t); |
264 | }, |
265 | py::arg("context" ) = py::none(), "Create a f16 type." ); |
266 | } |
267 | }; |
268 | |
269 | /// Floating Point Type subclass - TF32Type. |
270 | class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> { |
271 | public: |
272 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32; |
273 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
274 | mlirFloatTF32TypeGetTypeID; |
275 | static constexpr const char *pyClassName = "FloatTF32Type" ; |
276 | using PyConcreteType::PyConcreteType; |
277 | |
278 | static void bindDerived(ClassTy &c) { |
279 | c.def_static( |
280 | "get" , |
281 | [](DefaultingPyMlirContext context) { |
282 | MlirType t = mlirTF32TypeGet(context->get()); |
283 | return PyTF32Type(context->getRef(), t); |
284 | }, |
285 | py::arg("context" ) = py::none(), "Create a tf32 type." ); |
286 | } |
287 | }; |
288 | |
289 | /// Floating Point Type subclass - F32Type. |
290 | class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> { |
291 | public: |
292 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; |
293 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
294 | mlirFloat32TypeGetTypeID; |
295 | static constexpr const char *pyClassName = "F32Type" ; |
296 | using PyConcreteType::PyConcreteType; |
297 | |
298 | static void bindDerived(ClassTy &c) { |
299 | c.def_static( |
300 | "get" , |
301 | [](DefaultingPyMlirContext context) { |
302 | MlirType t = mlirF32TypeGet(context->get()); |
303 | return PyF32Type(context->getRef(), t); |
304 | }, |
305 | py::arg("context" ) = py::none(), "Create a f32 type." ); |
306 | } |
307 | }; |
308 | |
309 | /// Floating Point Type subclass - F64Type. |
310 | class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> { |
311 | public: |
312 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; |
313 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
314 | mlirFloat64TypeGetTypeID; |
315 | static constexpr const char *pyClassName = "F64Type" ; |
316 | using PyConcreteType::PyConcreteType; |
317 | |
318 | static void bindDerived(ClassTy &c) { |
319 | c.def_static( |
320 | "get" , |
321 | [](DefaultingPyMlirContext context) { |
322 | MlirType t = mlirF64TypeGet(context->get()); |
323 | return PyF64Type(context->getRef(), t); |
324 | }, |
325 | py::arg("context" ) = py::none(), "Create a f64 type." ); |
326 | } |
327 | }; |
328 | |
329 | /// None Type subclass - NoneType. |
330 | class PyNoneType : public PyConcreteType<PyNoneType> { |
331 | public: |
332 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; |
333 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
334 | mlirNoneTypeGetTypeID; |
335 | static constexpr const char *pyClassName = "NoneType" ; |
336 | using PyConcreteType::PyConcreteType; |
337 | |
338 | static void bindDerived(ClassTy &c) { |
339 | c.def_static( |
340 | "get" , |
341 | [](DefaultingPyMlirContext context) { |
342 | MlirType t = mlirNoneTypeGet(context->get()); |
343 | return PyNoneType(context->getRef(), t); |
344 | }, |
345 | py::arg("context" ) = py::none(), "Create a none type." ); |
346 | } |
347 | }; |
348 | |
349 | /// Complex Type subclass - ComplexType. |
350 | class PyComplexType : public PyConcreteType<PyComplexType> { |
351 | public: |
352 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; |
353 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
354 | mlirComplexTypeGetTypeID; |
355 | static constexpr const char *pyClassName = "ComplexType" ; |
356 | using PyConcreteType::PyConcreteType; |
357 | |
358 | static void bindDerived(ClassTy &c) { |
359 | c.def_static( |
360 | "get" , |
361 | [](PyType &elementType) { |
362 | // The element must be a floating point or integer scalar type. |
363 | if (mlirTypeIsAIntegerOrFloat(elementType)) { |
364 | MlirType t = mlirComplexTypeGet(elementType); |
365 | return PyComplexType(elementType.getContext(), t); |
366 | } |
367 | throw py::value_error( |
368 | (Twine("invalid '" ) + |
369 | py::repr(py::cast(elementType)).cast<std::string>() + |
370 | "' and expected floating point or integer type." ) |
371 | .str()); |
372 | }, |
373 | "Create a complex type" ); |
374 | c.def_property_readonly( |
375 | "element_type" , |
376 | [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); }, |
377 | "Returns element type." ); |
378 | } |
379 | }; |
380 | |
381 | class PyShapedType : public PyConcreteType<PyShapedType> { |
382 | public: |
383 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped; |
384 | static constexpr const char *pyClassName = "ShapedType" ; |
385 | using PyConcreteType::PyConcreteType; |
386 | |
387 | static void bindDerived(ClassTy &c) { |
388 | c.def_property_readonly( |
389 | "element_type" , |
390 | [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); }, |
391 | "Returns the element type of the shaped type." ); |
392 | c.def_property_readonly( |
393 | "has_rank" , |
394 | [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, |
395 | "Returns whether the given shaped type is ranked." ); |
396 | c.def_property_readonly( |
397 | "rank" , |
398 | [](PyShapedType &self) { |
399 | self.requireHasRank(); |
400 | return mlirShapedTypeGetRank(self); |
401 | }, |
402 | "Returns the rank of the given ranked shaped type." ); |
403 | c.def_property_readonly( |
404 | "has_static_shape" , |
405 | [](PyShapedType &self) -> bool { |
406 | return mlirShapedTypeHasStaticShape(self); |
407 | }, |
408 | "Returns whether the given shaped type has a static shape." ); |
409 | c.def( |
410 | "is_dynamic_dim" , |
411 | [](PyShapedType &self, intptr_t dim) -> bool { |
412 | self.requireHasRank(); |
413 | return mlirShapedTypeIsDynamicDim(self, dim); |
414 | }, |
415 | py::arg("dim" ), |
416 | "Returns whether the dim-th dimension of the given shaped type is " |
417 | "dynamic." ); |
418 | c.def( |
419 | "get_dim_size" , |
420 | [](PyShapedType &self, intptr_t dim) { |
421 | self.requireHasRank(); |
422 | return mlirShapedTypeGetDimSize(self, dim); |
423 | }, |
424 | py::arg("dim" ), |
425 | "Returns the dim-th dimension of the given ranked shaped type." ); |
426 | c.def_static( |
427 | "is_dynamic_size" , |
428 | [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, |
429 | py::arg("dim_size" ), |
430 | "Returns whether the given dimension size indicates a dynamic " |
431 | "dimension." ); |
432 | c.def( |
433 | "is_dynamic_stride_or_offset" , |
434 | [](PyShapedType &self, int64_t val) -> bool { |
435 | self.requireHasRank(); |
436 | return mlirShapedTypeIsDynamicStrideOrOffset(val); |
437 | }, |
438 | py::arg("dim_size" ), |
439 | "Returns whether the given value is used as a placeholder for dynamic " |
440 | "strides and offsets in shaped types." ); |
441 | c.def_property_readonly( |
442 | "shape" , |
443 | [](PyShapedType &self) { |
444 | self.requireHasRank(); |
445 | |
446 | std::vector<int64_t> shape; |
447 | int64_t rank = mlirShapedTypeGetRank(self); |
448 | shape.reserve(n: rank); |
449 | for (int64_t i = 0; i < rank; ++i) |
450 | shape.push_back(mlirShapedTypeGetDimSize(self, i)); |
451 | return shape; |
452 | }, |
453 | "Returns the shape of the ranked shaped type as a list of integers." ); |
454 | c.def_static( |
455 | "get_dynamic_size" , []() { return mlirShapedTypeGetDynamicSize(); }, |
456 | "Returns the value used to indicate dynamic dimensions in shaped " |
457 | "types." ); |
458 | c.def_static( |
459 | "get_dynamic_stride_or_offset" , |
460 | []() { return mlirShapedTypeGetDynamicStrideOrOffset(); }, |
461 | "Returns the value used to indicate dynamic strides or offsets in " |
462 | "shaped types." ); |
463 | } |
464 | |
465 | private: |
466 | void requireHasRank() { |
467 | if (!mlirShapedTypeHasRank(*this)) { |
468 | throw py::value_error( |
469 | "calling this method requires that the type has a rank." ); |
470 | } |
471 | } |
472 | }; |
473 | |
474 | /// Vector Type subclass - VectorType. |
475 | class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> { |
476 | public: |
477 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; |
478 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
479 | mlirVectorTypeGetTypeID; |
480 | static constexpr const char *pyClassName = "VectorType" ; |
481 | using PyConcreteType::PyConcreteType; |
482 | |
483 | static void bindDerived(ClassTy &c) { |
484 | c.def_static("get" , &PyVectorType::get, py::arg("shape" ), |
485 | py::arg("element_type" ), py::kw_only(), |
486 | py::arg("scalable" ) = py::none(), |
487 | py::arg("scalable_dims" ) = py::none(), |
488 | py::arg("loc" ) = py::none(), "Create a vector type" ) |
489 | .def_property_readonly( |
490 | "scalable" , |
491 | [](MlirType self) { return mlirVectorTypeIsScalable(self); }) |
492 | .def_property_readonly("scalable_dims" , [](MlirType self) { |
493 | std::vector<bool> scalableDims; |
494 | size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self)); |
495 | scalableDims.reserve(rank); |
496 | for (size_t i = 0; i < rank; ++i) |
497 | scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i)); |
498 | return scalableDims; |
499 | }); |
500 | } |
501 | |
502 | private: |
503 | static PyVectorType get(std::vector<int64_t> shape, PyType &elementType, |
504 | std::optional<py::list> scalable, |
505 | std::optional<std::vector<int64_t>> scalableDims, |
506 | DefaultingPyLocation loc) { |
507 | if (scalable && scalableDims) { |
508 | throw py::value_error("'scalable' and 'scalable_dims' kwargs " |
509 | "are mutually exclusive." ); |
510 | } |
511 | |
512 | PyMlirContext::ErrorCapture errors(loc->getContext()); |
513 | MlirType type; |
514 | if (scalable) { |
515 | if (scalable->size() != shape.size()) |
516 | throw py::value_error("Expected len(scalable) == len(shape)." ); |
517 | |
518 | SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range( |
519 | *scalable, [](const py::handle &h) { return h.cast<bool>(); })); |
520 | type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), |
521 | scalableDimFlags.data(), |
522 | elementType); |
523 | } else if (scalableDims) { |
524 | SmallVector<bool> scalableDimFlags(shape.size(), false); |
525 | for (int64_t dim : *scalableDims) { |
526 | if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0) |
527 | throw py::value_error("Scalable dimension index out of bounds." ); |
528 | scalableDimFlags[dim] = true; |
529 | } |
530 | type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), |
531 | scalableDimFlags.data(), |
532 | elementType); |
533 | } else { |
534 | type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), |
535 | elementType); |
536 | } |
537 | if (mlirTypeIsNull(type)) |
538 | throw MLIRError("Invalid type" , errors.take()); |
539 | return PyVectorType(elementType.getContext(), type); |
540 | } |
541 | }; |
542 | |
543 | /// Ranked Tensor Type subclass - RankedTensorType. |
544 | class PyRankedTensorType |
545 | : public PyConcreteType<PyRankedTensorType, PyShapedType> { |
546 | public: |
547 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; |
548 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
549 | mlirRankedTensorTypeGetTypeID; |
550 | static constexpr const char *pyClassName = "RankedTensorType" ; |
551 | using PyConcreteType::PyConcreteType; |
552 | |
553 | static void bindDerived(ClassTy &c) { |
554 | c.def_static( |
555 | "get" , |
556 | [](std::vector<int64_t> shape, PyType &elementType, |
557 | std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) { |
558 | PyMlirContext::ErrorCapture errors(loc->getContext()); |
559 | MlirType t = mlirRankedTensorTypeGetChecked( |
560 | loc, shape.size(), shape.data(), elementType, |
561 | encodingAttr ? encodingAttr->get() : mlirAttributeGetNull()); |
562 | if (mlirTypeIsNull(t)) |
563 | throw MLIRError("Invalid type" , errors.take()); |
564 | return PyRankedTensorType(elementType.getContext(), t); |
565 | }, |
566 | py::arg("shape" ), py::arg("element_type" ), |
567 | py::arg("encoding" ) = py::none(), py::arg("loc" ) = py::none(), |
568 | "Create a ranked tensor type" ); |
569 | c.def_property_readonly( |
570 | "encoding" , |
571 | [](PyRankedTensorType &self) -> std::optional<MlirAttribute> { |
572 | MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get()); |
573 | if (mlirAttributeIsNull(encoding)) |
574 | return std::nullopt; |
575 | return encoding; |
576 | }); |
577 | } |
578 | }; |
579 | |
580 | /// Unranked Tensor Type subclass - UnrankedTensorType. |
581 | class PyUnrankedTensorType |
582 | : public PyConcreteType<PyUnrankedTensorType, PyShapedType> { |
583 | public: |
584 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; |
585 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
586 | mlirUnrankedTensorTypeGetTypeID; |
587 | static constexpr const char *pyClassName = "UnrankedTensorType" ; |
588 | using PyConcreteType::PyConcreteType; |
589 | |
590 | static void bindDerived(ClassTy &c) { |
591 | c.def_static( |
592 | "get" , |
593 | [](PyType &elementType, DefaultingPyLocation loc) { |
594 | PyMlirContext::ErrorCapture errors(loc->getContext()); |
595 | MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); |
596 | if (mlirTypeIsNull(t)) |
597 | throw MLIRError("Invalid type" , errors.take()); |
598 | return PyUnrankedTensorType(elementType.getContext(), t); |
599 | }, |
600 | py::arg("element_type" ), py::arg("loc" ) = py::none(), |
601 | "Create a unranked tensor type" ); |
602 | } |
603 | }; |
604 | |
605 | /// Ranked MemRef Type subclass - MemRefType. |
606 | class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> { |
607 | public: |
608 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; |
609 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
610 | mlirMemRefTypeGetTypeID; |
611 | static constexpr const char *pyClassName = "MemRefType" ; |
612 | using PyConcreteType::PyConcreteType; |
613 | |
614 | static void bindDerived(ClassTy &c) { |
615 | c.def_static( |
616 | "get" , |
617 | [](std::vector<int64_t> shape, PyType &elementType, |
618 | PyAttribute *layout, PyAttribute *memorySpace, |
619 | DefaultingPyLocation loc) { |
620 | PyMlirContext::ErrorCapture errors(loc->getContext()); |
621 | MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull(); |
622 | MlirAttribute memSpaceAttr = |
623 | memorySpace ? *memorySpace : mlirAttributeGetNull(); |
624 | MlirType t = |
625 | mlirMemRefTypeGetChecked(loc, elementType, shape.size(), |
626 | shape.data(), layoutAttr, memSpaceAttr); |
627 | if (mlirTypeIsNull(t)) |
628 | throw MLIRError("Invalid type" , errors.take()); |
629 | return PyMemRefType(elementType.getContext(), t); |
630 | }, |
631 | py::arg("shape" ), py::arg("element_type" ), |
632 | py::arg("layout" ) = py::none(), py::arg("memory_space" ) = py::none(), |
633 | py::arg("loc" ) = py::none(), "Create a memref type" ) |
634 | .def_property_readonly( |
635 | "layout" , |
636 | [](PyMemRefType &self) -> MlirAttribute { |
637 | return mlirMemRefTypeGetLayout(self); |
638 | }, |
639 | "The layout of the MemRef type." ) |
640 | .def( |
641 | "get_strides_and_offset" , |
642 | [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> { |
643 | std::vector<int64_t> strides(mlirShapedTypeGetRank(self)); |
644 | int64_t offset; |
645 | if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset( |
646 | self, strides.data(), &offset))) |
647 | throw std::runtime_error( |
648 | "Failed to extract strides and offset from memref." ); |
649 | return {strides, offset}; |
650 | }, |
651 | "The strides and offset of the MemRef type." ) |
652 | .def_property_readonly( |
653 | "affine_map" , |
654 | [](PyMemRefType &self) -> PyAffineMap { |
655 | MlirAffineMap map = mlirMemRefTypeGetAffineMap(self); |
656 | return PyAffineMap(self.getContext(), map); |
657 | }, |
658 | "The layout of the MemRef type as an affine map." ) |
659 | .def_property_readonly( |
660 | "memory_space" , |
661 | [](PyMemRefType &self) -> std::optional<MlirAttribute> { |
662 | MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); |
663 | if (mlirAttributeIsNull(a)) |
664 | return std::nullopt; |
665 | return a; |
666 | }, |
667 | "Returns the memory space of the given MemRef type." ); |
668 | } |
669 | }; |
670 | |
671 | /// Unranked MemRef Type subclass - UnrankedMemRefType. |
672 | class PyUnrankedMemRefType |
673 | : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> { |
674 | public: |
675 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; |
676 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
677 | mlirUnrankedMemRefTypeGetTypeID; |
678 | static constexpr const char *pyClassName = "UnrankedMemRefType" ; |
679 | using PyConcreteType::PyConcreteType; |
680 | |
681 | static void bindDerived(ClassTy &c) { |
682 | c.def_static( |
683 | "get" , |
684 | [](PyType &elementType, PyAttribute *memorySpace, |
685 | DefaultingPyLocation loc) { |
686 | PyMlirContext::ErrorCapture errors(loc->getContext()); |
687 | MlirAttribute memSpaceAttr = {}; |
688 | if (memorySpace) |
689 | memSpaceAttr = *memorySpace; |
690 | |
691 | MlirType t = |
692 | mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); |
693 | if (mlirTypeIsNull(t)) |
694 | throw MLIRError("Invalid type" , errors.take()); |
695 | return PyUnrankedMemRefType(elementType.getContext(), t); |
696 | }, |
697 | py::arg("element_type" ), py::arg("memory_space" ), |
698 | py::arg("loc" ) = py::none(), "Create a unranked memref type" ) |
699 | .def_property_readonly( |
700 | "memory_space" , |
701 | [](PyUnrankedMemRefType &self) -> std::optional<MlirAttribute> { |
702 | MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self); |
703 | if (mlirAttributeIsNull(a)) |
704 | return std::nullopt; |
705 | return a; |
706 | }, |
707 | "Returns the memory space of the given Unranked MemRef type." ); |
708 | } |
709 | }; |
710 | |
711 | /// Tuple Type subclass - TupleType. |
712 | class PyTupleType : public PyConcreteType<PyTupleType> { |
713 | public: |
714 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; |
715 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
716 | mlirTupleTypeGetTypeID; |
717 | static constexpr const char *pyClassName = "TupleType" ; |
718 | using PyConcreteType::PyConcreteType; |
719 | |
720 | static void bindDerived(ClassTy &c) { |
721 | c.def_static( |
722 | "get_tuple" , |
723 | [](std::vector<MlirType> elements, DefaultingPyMlirContext context) { |
724 | MlirType t = mlirTupleTypeGet(context->get(), elements.size(), |
725 | elements.data()); |
726 | return PyTupleType(context->getRef(), t); |
727 | }, |
728 | py::arg("elements" ), py::arg("context" ) = py::none(), |
729 | "Create a tuple type" ); |
730 | c.def( |
731 | "get_type" , |
732 | [](PyTupleType &self, intptr_t pos) { |
733 | return mlirTupleTypeGetType(self, pos); |
734 | }, |
735 | py::arg("pos" ), "Returns the pos-th type in the tuple type." ); |
736 | c.def_property_readonly( |
737 | "num_types" , |
738 | [](PyTupleType &self) -> intptr_t { |
739 | return mlirTupleTypeGetNumTypes(self); |
740 | }, |
741 | "Returns the number of types contained in a tuple." ); |
742 | } |
743 | }; |
744 | |
745 | /// Function type. |
746 | class PyFunctionType : public PyConcreteType<PyFunctionType> { |
747 | public: |
748 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; |
749 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
750 | mlirFunctionTypeGetTypeID; |
751 | static constexpr const char *pyClassName = "FunctionType" ; |
752 | using PyConcreteType::PyConcreteType; |
753 | |
754 | static void bindDerived(ClassTy &c) { |
755 | c.def_static( |
756 | "get" , |
757 | [](std::vector<MlirType> inputs, std::vector<MlirType> results, |
758 | DefaultingPyMlirContext context) { |
759 | MlirType t = |
760 | mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(), |
761 | results.size(), results.data()); |
762 | return PyFunctionType(context->getRef(), t); |
763 | }, |
764 | py::arg("inputs" ), py::arg("results" ), py::arg("context" ) = py::none(), |
765 | "Gets a FunctionType from a list of input and result types" ); |
766 | c.def_property_readonly( |
767 | "inputs" , |
768 | [](PyFunctionType &self) { |
769 | MlirType t = self; |
770 | py::list types; |
771 | for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; |
772 | ++i) { |
773 | types.append(mlirFunctionTypeGetInput(t, i)); |
774 | } |
775 | return types; |
776 | }, |
777 | "Returns the list of input types in the FunctionType." ); |
778 | c.def_property_readonly( |
779 | "results" , |
780 | [](PyFunctionType &self) { |
781 | py::list types; |
782 | for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; |
783 | ++i) { |
784 | types.append(mlirFunctionTypeGetResult(self, i)); |
785 | } |
786 | return types; |
787 | }, |
788 | "Returns the list of result types in the FunctionType." ); |
789 | } |
790 | }; |
791 | |
792 | static MlirStringRef toMlirStringRef(const std::string &s) { |
793 | return mlirStringRefCreate(s.data(), s.size()); |
794 | } |
795 | |
796 | /// Opaque Type subclass - OpaqueType. |
797 | class PyOpaqueType : public PyConcreteType<PyOpaqueType> { |
798 | public: |
799 | static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque; |
800 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
801 | mlirOpaqueTypeGetTypeID; |
802 | static constexpr const char *pyClassName = "OpaqueType" ; |
803 | using PyConcreteType::PyConcreteType; |
804 | |
805 | static void bindDerived(ClassTy &c) { |
806 | c.def_static( |
807 | "get" , |
808 | [](std::string dialectNamespace, std::string typeData, |
809 | DefaultingPyMlirContext context) { |
810 | MlirType type = mlirOpaqueTypeGet(context->get(), |
811 | toMlirStringRef(dialectNamespace), |
812 | toMlirStringRef(typeData)); |
813 | return PyOpaqueType(context->getRef(), type); |
814 | }, |
815 | py::arg("dialect_namespace" ), py::arg("buffer" ), |
816 | py::arg("context" ) = py::none(), |
817 | "Create an unregistered (opaque) dialect type." ); |
818 | c.def_property_readonly( |
819 | "dialect_namespace" , |
820 | [](PyOpaqueType &self) { |
821 | MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self); |
822 | return py::str(stringRef.data, stringRef.length); |
823 | }, |
824 | "Returns the dialect namespace for the Opaque type as a string." ); |
825 | c.def_property_readonly( |
826 | "data" , |
827 | [](PyOpaqueType &self) { |
828 | MlirStringRef stringRef = mlirOpaqueTypeGetData(self); |
829 | return py::str(stringRef.data, stringRef.length); |
830 | }, |
831 | "Returns the data for the Opaque type as a string." ); |
832 | } |
833 | }; |
834 | |
835 | } // namespace |
836 | |
/// Registers every builtin type binding defined in this file on module `m`.
///
/// Order is significant. Each class declared with a base (the second
/// PyConcreteType template argument) is bound after that base:
/// - PyFloatType precedes the float subclasses;
/// - PyShapedType precedes the vector/tensor/memref classes.
/// NOTE(review): pybind11 needs a base class registered before any class
/// derived from it — preserve this ordering when adding new types.
void mlir::python::populateIRTypes(py::module &m) {
  PyIntegerType::bind(m);
  // Base for all float bindings below; must come before them.
  PyFloatType::bind(m);
  PyIndexType::bind(m);
  PyFloat8E4M3FNType::bind(m);
  PyFloat8E5M2Type::bind(m);
  PyFloat8E4M3FNUZType::bind(m);
  PyFloat8E4M3B11FNUZType::bind(m);
  PyFloat8E5M2FNUZType::bind(m);
  PyBF16Type::bind(m);
  PyF16Type::bind(m);
  PyTF32Type::bind(m);
  PyF32Type::bind(m);
  PyF64Type::bind(m);
  PyNoneType::bind(m);
  PyComplexType::bind(m);
  // Base for the vector/tensor/memref bindings below; must come before them.
  PyShapedType::bind(m);
  PyVectorType::bind(m);
  PyRankedTensorType::bind(m);
  PyUnrankedTensorType::bind(m);
  PyMemRefType::bind(m);
  PyUnrankedMemRefType::bind(m);
  PyTupleType::bind(m);
  PyFunctionType::bind(m);
  PyOpaqueType::bind(m);
}
863 | |