1//===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <cstdint>
10#include <optional>
11#include <string>
12#include <string_view>
13#include <utility>
14
15#include "IRModule.h"
16#include "NanobindUtils.h"
17#include "mlir-c/BuiltinAttributes.h"
18#include "mlir-c/BuiltinTypes.h"
19#include "mlir/Bindings/Python/NanobindAdaptors.h"
20#include "mlir/Bindings/Python/Nanobind.h"
21#include "llvm/ADT/ScopeExit.h"
22#include "llvm/Support/raw_ostream.h"
23
24namespace nb = nanobind;
25using namespace nanobind::literals;
26using namespace mlir;
27using namespace mlir::python;
28
29using llvm::SmallVector;
30
31//------------------------------------------------------------------------------
32// Docstrings (trivial, non-duplicated docstrings are included inline).
33//------------------------------------------------------------------------------
34
static const char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";

// The argument names below must match the nanobind argument names used when
// binding DenseElementsAttr.get for lists (`attrs`, `type`, `context`).
static const char kDenseElementsAttrGetFromListDocstring[] =
    R"(Gets a DenseElementsAttr from a Python list of attributes.

Note that it can be expensive to construct attributes individually.
For a large number of elements, consider using a Python buffer or array instead.

Args:
  attrs: A list of attributes.
  type: The desired shape and type of the resulting DenseElementsAttr.
    If not provided, the element type is determined based on the type
    of the 0th attribute and the shape is `[len(attrs)]`.
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If `attrs` is empty, or if the type of the attributes does not
    match the type specified by `type`.
)";

static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
    R"(Gets a DenseResourceElementsAttr from a Python buffer or array.

This function does minimal validation or massaging of the data, and it is
up to the caller to ensure that the buffer meets the characteristics
implied by the shape.

The backing buffer and any user objects will be retained for the lifetime
of the resource blob. This is typically bounded to the context but the
resource can have a shorter lifespan depending on how it is used in
subsequent processing.

Args:
  buffer: The array or buffer to convert.
  name: Name to provide to the resource (may be changed upon collision).
  type: The explicit ShapedType to construct the attribute with.
  context: Explicit context, if not from context manager.

Returns:
  DenseResourceElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
123
124namespace {
125
126struct nb_buffer_info {
127 void *ptr = nullptr;
128 ssize_t itemsize = 0;
129 ssize_t size = 0;
130 const char *format = nullptr;
131 ssize_t ndim = 0;
132 SmallVector<ssize_t, 4> shape;
133 SmallVector<ssize_t, 4> strides;
134 bool readonly = false;
135
136 nb_buffer_info(
137 void *ptr, ssize_t itemsize, const char *format, ssize_t ndim,
138 SmallVector<ssize_t, 4> shape_in, SmallVector<ssize_t, 4> strides_in,
139 bool readonly = false,
140 std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in =
141 std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(nullptr, nullptr))
142 : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim),
143 shape(std::move(shape_in)), strides(std::move(strides_in)),
144 readonly(readonly), owned_view(std::move(owned_view_in)) {
145 size = 1;
146 for (ssize_t i = 0; i < ndim; ++i) {
147 size *= shape[i];
148 }
149 }
150
151 explicit nb_buffer_info(Py_buffer *view)
152 : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim,
153 {view->shape, view->shape + view->ndim},
154 // TODO(phawkins): check for null strides
155 {view->strides, view->strides + view->ndim},
156 view->readonly != 0,
157 std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(
158 view, PyBuffer_Release)) {}
159
160 nb_buffer_info(const nb_buffer_info &) = delete;
161 nb_buffer_info(nb_buffer_info &&) = default;
162 nb_buffer_info &operator=(const nb_buffer_info &) = delete;
163 nb_buffer_info &operator=(nb_buffer_info &&) = default;
164
165private:
166 std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view;
167};
168
169class nb_buffer : public nb::object {
170 NB_OBJECT_DEFAULT(nb_buffer, object, "buffer", PyObject_CheckBuffer);
171
172 nb_buffer_info request() const {
173 int flags = PyBUF_STRIDES | PyBUF_FORMAT;
174 auto *view = new Py_buffer();
175 if (PyObject_GetBuffer(ptr(), view, flags) != 0) {
176 delete view;
177 throw nb::python_error();
178 }
179 return nb_buffer_info(view);
180 }
181};
182
/// Maps a C++ scalar type to the Python buffer-protocol format string (the
/// `struct` module notation) describing it. The primary template is empty on
/// purpose: requesting `format()` for an unsupported type does not compile.
template <typename T>
struct nb_format_descriptor {};

// clang-format off
template <> struct nb_format_descriptor<bool>     { static const char *format() { return "?"; } };
template <> struct nb_format_descriptor<int8_t>   { static const char *format() { return "b"; } };
template <> struct nb_format_descriptor<uint8_t>  { static const char *format() { return "B"; } };
template <> struct nb_format_descriptor<int16_t>  { static const char *format() { return "h"; } };
template <> struct nb_format_descriptor<uint16_t> { static const char *format() { return "H"; } };
template <> struct nb_format_descriptor<int32_t>  { static const char *format() { return "i"; } };
template <> struct nb_format_descriptor<uint32_t> { static const char *format() { return "I"; } };
template <> struct nb_format_descriptor<int64_t>  { static const char *format() { return "q"; } };
template <> struct nb_format_descriptor<uint64_t> { static const char *format() { return "Q"; } };
template <> struct nb_format_descriptor<float>    { static const char *format() { return "f"; } };
template <> struct nb_format_descriptor<double>   { static const char *format() { return "d"; } };
// clang-format on
230
231static MlirStringRef toMlirStringRef(const std::string &s) {
232 return mlirStringRefCreate(s.data(), s.size());
233}
234
235static MlirStringRef toMlirStringRef(const nb::bytes &s) {
236 return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
237}
238
239class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
240public:
241 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
242 static constexpr const char *pyClassName = "AffineMapAttr";
243 using PyConcreteAttribute::PyConcreteAttribute;
244 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
245 mlirAffineMapAttrGetTypeID;
246
247 static void bindDerived(ClassTy &c) {
248 c.def_static(
249 "get",
250 [](PyAffineMap &affineMap) {
251 MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
252 return PyAffineMapAttribute(affineMap.getContext(), attr);
253 },
254 nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
255 c.def_prop_ro("value", mlirAffineMapAttrGetValue,
256 "Returns the value of the AffineMap attribute");
257 }
258};
259
260class PyIntegerSetAttribute
261 : public PyConcreteAttribute<PyIntegerSetAttribute> {
262public:
263 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet;
264 static constexpr const char *pyClassName = "IntegerSetAttr";
265 using PyConcreteAttribute::PyConcreteAttribute;
266 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
267 mlirIntegerSetAttrGetTypeID;
268
269 static void bindDerived(ClassTy &c) {
270 c.def_static(
271 "get",
272 [](PyIntegerSet &integerSet) {
273 MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get());
274 return PyIntegerSetAttribute(integerSet.getContext(), attr);
275 },
276 nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet.");
277 }
278};
279
280template <typename T>
281static T pyTryCast(nb::handle object) {
282 try {
283 return nb::cast<T>(object);
284 } catch (nb::cast_error &err) {
285 std::string msg = std::string("Invalid attribute when attempting to "
286 "create an ArrayAttribute (") +
287 err.what() + ")";
288 throw std::runtime_error(msg.c_str());
289 } catch (std::runtime_error &err) {
290 std::string msg = std::string("Invalid attribute (None?) when attempting "
291 "to create an ArrayAttribute (") +
292 err.what() + ")";
293 throw std::runtime_error(msg.c_str());
294 }
295}
296
297/// A python-wrapped dense array attribute with an element type and a derived
298/// implementation class.
299template <typename EltTy, typename DerivedT>
300class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
301public:
302 using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
303
304 /// Iterator over the integer elements of a dense array.
305 class PyDenseArrayIterator {
306 public:
307 PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
308
309 /// Return a copy of the iterator.
310 PyDenseArrayIterator dunderIter() { return *this; }
311
312 /// Return the next element.
313 EltTy dunderNext() {
314 // Throw if the index has reached the end.
315 if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
316 throw nb::stop_iteration();
317 return DerivedT::getElement(attr.get(), nextIndex++);
318 }
319
320 /// Bind the iterator class.
321 static void bind(nb::module_ &m) {
322 nb::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName)
323 .def("__iter__", &PyDenseArrayIterator::dunderIter)
324 .def("__next__", &PyDenseArrayIterator::dunderNext);
325 }
326
327 private:
328 /// The referenced dense array attribute.
329 PyAttribute attr;
330 /// The next index to read.
331 int nextIndex = 0;
332 };
333
334 /// Get the element at the given index.
335 EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
336
337 /// Bind the attribute class.
338 static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
339 // Bind the constructor.
340 if constexpr (std::is_same_v<EltTy, bool>) {
341 c.def_static(
342 "get",
343 [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) {
344 std::vector<bool> values;
345 for (nb::handle py_value : py_values) {
346 int is_true = PyObject_IsTrue(py_value.ptr());
347 if (is_true < 0) {
348 throw nb::python_error();
349 }
350 values.push_back(is_true);
351 }
352 return getAttribute(values, ctx->getRef());
353 },
354 nb::arg("values"), nb::arg("context").none() = nb::none(),
355 "Gets a uniqued dense array attribute");
356 } else {
357 c.def_static(
358 "get",
359 [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
360 return getAttribute(values, ctx->getRef());
361 },
362 nb::arg("values"), nb::arg("context").none() = nb::none(),
363 "Gets a uniqued dense array attribute");
364 }
365 // Bind the array methods.
366 c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
367 if (i >= mlirDenseArrayGetNumElements(arr))
368 throw nb::index_error("DenseArray index out of range");
369 return arr.getItem(i);
370 });
371 c.def("__len__", [](const DerivedT &arr) {
372 return mlirDenseArrayGetNumElements(arr);
373 });
374 c.def("__iter__",
375 [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
376 c.def("__add__", [](DerivedT &arr, const nb::list &extras) {
377 std::vector<EltTy> values;
378 intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
379 values.reserve(numOldElements + nb::len(extras));
380 for (intptr_t i = 0; i < numOldElements; ++i)
381 values.push_back(arr.getItem(i));
382 for (nb::handle attr : extras)
383 values.push_back(pyTryCast<EltTy>(attr));
384 return getAttribute(values, ctx: arr.getContext());
385 });
386 }
387
388private:
389 static DerivedT getAttribute(const std::vector<EltTy> &values,
390 PyMlirContextRef ctx) {
391 if constexpr (std::is_same_v<EltTy, bool>) {
392 std::vector<int> intValues(values.begin(), values.end());
393 MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
394 intValues.data());
395 return DerivedT(ctx, attr);
396 } else {
397 MlirAttribute attr =
398 DerivedT::getAttribute(ctx->get(), values.size(), values.data());
399 return DerivedT(ctx, attr);
400 }
401 }
402};
403
404/// Instantiate the python dense array classes.
405struct PyDenseBoolArrayAttribute
406 : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
407 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
408 static constexpr auto getAttribute = mlirDenseBoolArrayGet;
409 static constexpr auto getElement = mlirDenseBoolArrayGetElement;
410 static constexpr const char *pyClassName = "DenseBoolArrayAttr";
411 static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
412 using PyDenseArrayAttribute::PyDenseArrayAttribute;
413};
414struct PyDenseI8ArrayAttribute
415 : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
416 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
417 static constexpr auto getAttribute = mlirDenseI8ArrayGet;
418 static constexpr auto getElement = mlirDenseI8ArrayGetElement;
419 static constexpr const char *pyClassName = "DenseI8ArrayAttr";
420 static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
421 using PyDenseArrayAttribute::PyDenseArrayAttribute;
422};
423struct PyDenseI16ArrayAttribute
424 : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
425 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
426 static constexpr auto getAttribute = mlirDenseI16ArrayGet;
427 static constexpr auto getElement = mlirDenseI16ArrayGetElement;
428 static constexpr const char *pyClassName = "DenseI16ArrayAttr";
429 static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
430 using PyDenseArrayAttribute::PyDenseArrayAttribute;
431};
432struct PyDenseI32ArrayAttribute
433 : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
434 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
435 static constexpr auto getAttribute = mlirDenseI32ArrayGet;
436 static constexpr auto getElement = mlirDenseI32ArrayGetElement;
437 static constexpr const char *pyClassName = "DenseI32ArrayAttr";
438 static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
439 using PyDenseArrayAttribute::PyDenseArrayAttribute;
440};
441struct PyDenseI64ArrayAttribute
442 : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
443 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
444 static constexpr auto getAttribute = mlirDenseI64ArrayGet;
445 static constexpr auto getElement = mlirDenseI64ArrayGetElement;
446 static constexpr const char *pyClassName = "DenseI64ArrayAttr";
447 static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
448 using PyDenseArrayAttribute::PyDenseArrayAttribute;
449};
450struct PyDenseF32ArrayAttribute
451 : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
452 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
453 static constexpr auto getAttribute = mlirDenseF32ArrayGet;
454 static constexpr auto getElement = mlirDenseF32ArrayGetElement;
455 static constexpr const char *pyClassName = "DenseF32ArrayAttr";
456 static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
457 using PyDenseArrayAttribute::PyDenseArrayAttribute;
458};
459struct PyDenseF64ArrayAttribute
460 : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
461 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
462 static constexpr auto getAttribute = mlirDenseF64ArrayGet;
463 static constexpr auto getElement = mlirDenseF64ArrayGetElement;
464 static constexpr const char *pyClassName = "DenseF64ArrayAttr";
465 static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
466 using PyDenseArrayAttribute::PyDenseArrayAttribute;
467};
468
469class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
470public:
471 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
472 static constexpr const char *pyClassName = "ArrayAttr";
473 using PyConcreteAttribute::PyConcreteAttribute;
474 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
475 mlirArrayAttrGetTypeID;
476
477 class PyArrayAttributeIterator {
478 public:
479 PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
480
481 PyArrayAttributeIterator &dunderIter() { return *this; }
482
483 MlirAttribute dunderNext() {
484 // TODO: Throw is an inefficient way to stop iteration.
485 if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
486 throw nb::stop_iteration();
487 return mlirArrayAttrGetElement(attr.get(), nextIndex++);
488 }
489
490 static void bind(nb::module_ &m) {
491 nb::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator")
492 .def("__iter__", &PyArrayAttributeIterator::dunderIter)
493 .def("__next__", &PyArrayAttributeIterator::dunderNext);
494 }
495
496 private:
497 PyAttribute attr;
498 int nextIndex = 0;
499 };
500
501 MlirAttribute getItem(intptr_t i) {
502 return mlirArrayAttrGetElement(*this, i);
503 }
504
505 static void bindDerived(ClassTy &c) {
506 c.def_static(
507 "get",
508 [](nb::list attributes, DefaultingPyMlirContext context) {
509 SmallVector<MlirAttribute> mlirAttributes;
510 mlirAttributes.reserve(nb::len(attributes));
511 for (auto attribute : attributes) {
512 mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
513 }
514 MlirAttribute attr = mlirArrayAttrGet(
515 context->get(), mlirAttributes.size(), mlirAttributes.data());
516 return PyArrayAttribute(context->getRef(), attr);
517 },
518 nb::arg("attributes"), nb::arg("context").none() = nb::none(),
519 "Gets a uniqued Array attribute");
520 c.def("__getitem__",
521 [](PyArrayAttribute &arr, intptr_t i) {
522 if (i >= mlirArrayAttrGetNumElements(arr))
523 throw nb::index_error("ArrayAttribute index out of range");
524 return arr.getItem(i);
525 })
526 .def("__len__",
527 [](const PyArrayAttribute &arr) {
528 return mlirArrayAttrGetNumElements(arr);
529 })
530 .def("__iter__", [](const PyArrayAttribute &arr) {
531 return PyArrayAttributeIterator(arr);
532 });
533 c.def("__add__", [](PyArrayAttribute arr, nb::list extras) {
534 std::vector<MlirAttribute> attributes;
535 intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
536 attributes.reserve(numOldElements + nb::len(extras));
537 for (intptr_t i = 0; i < numOldElements; ++i)
538 attributes.push_back(arr.getItem(i));
539 for (nb::handle attr : extras)
540 attributes.push_back(pyTryCast<PyAttribute>(attr));
541 MlirAttribute arrayAttr = mlirArrayAttrGet(
542 arr.getContext()->get(), attributes.size(), attributes.data());
543 return PyArrayAttribute(arr.getContext(), arrayAttr);
544 });
545 }
546};
547
548/// Float Point Attribute subclass - FloatAttr.
549class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
550public:
551 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
552 static constexpr const char *pyClassName = "FloatAttr";
553 using PyConcreteAttribute::PyConcreteAttribute;
554 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
555 mlirFloatAttrGetTypeID;
556
557 static void bindDerived(ClassTy &c) {
558 c.def_static(
559 "get",
560 [](PyType &type, double value, DefaultingPyLocation loc) {
561 PyMlirContext::ErrorCapture errors(loc->getContext());
562 MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
563 if (mlirAttributeIsNull(attr))
564 throw MLIRError("Invalid attribute", errors.take());
565 return PyFloatAttribute(type.getContext(), attr);
566 },
567 nb::arg("type"), nb::arg("value"), nb::arg("loc").none() = nb::none(),
568 "Gets an uniqued float point attribute associated to a type");
569 c.def_static(
570 "get_f32",
571 [](double value, DefaultingPyMlirContext context) {
572 MlirAttribute attr = mlirFloatAttrDoubleGet(
573 context->get(), mlirF32TypeGet(context->get()), value);
574 return PyFloatAttribute(context->getRef(), attr);
575 },
576 nb::arg("value"), nb::arg("context").none() = nb::none(),
577 "Gets an uniqued float point attribute associated to a f32 type");
578 c.def_static(
579 "get_f64",
580 [](double value, DefaultingPyMlirContext context) {
581 MlirAttribute attr = mlirFloatAttrDoubleGet(
582 context->get(), mlirF64TypeGet(context->get()), value);
583 return PyFloatAttribute(context->getRef(), attr);
584 },
585 nb::arg("value"), nb::arg("context").none() = nb::none(),
586 "Gets an uniqued float point attribute associated to a f64 type");
587 c.def_prop_ro("value", mlirFloatAttrGetValueDouble,
588 "Returns the value of the float attribute");
589 c.def("__float__", mlirFloatAttrGetValueDouble,
590 "Converts the value of the float attribute to a Python float");
591 }
592};
593
594/// Integer Attribute subclass - IntegerAttr.
595class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
596public:
597 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
598 static constexpr const char *pyClassName = "IntegerAttr";
599 using PyConcreteAttribute::PyConcreteAttribute;
600
601 static void bindDerived(ClassTy &c) {
602 c.def_static(
603 "get",
604 [](PyType &type, int64_t value) {
605 MlirAttribute attr = mlirIntegerAttrGet(type, value);
606 return PyIntegerAttribute(type.getContext(), attr);
607 },
608 nb::arg("type"), nb::arg("value"),
609 "Gets an uniqued integer attribute associated to a type");
610 c.def_prop_ro("value", toPyInt,
611 "Returns the value of the integer attribute");
612 c.def("__int__", toPyInt,
613 "Converts the value of the integer attribute to a Python int");
614 c.def_prop_ro_static("static_typeid",
615 [](nb::object & /*class*/) -> MlirTypeID {
616 return mlirIntegerAttrGetTypeID();
617 });
618 }
619
620private:
621 static int64_t toPyInt(PyIntegerAttribute &self) {
622 MlirType type = mlirAttributeGetType(self);
623 if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
624 return mlirIntegerAttrGetValueInt(self);
625 if (mlirIntegerTypeIsSigned(type))
626 return mlirIntegerAttrGetValueSInt(self);
627 return mlirIntegerAttrGetValueUInt(self);
628 }
629};
630
631/// Bool Attribute subclass - BoolAttr.
632class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
633public:
634 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
635 static constexpr const char *pyClassName = "BoolAttr";
636 using PyConcreteAttribute::PyConcreteAttribute;
637
638 static void bindDerived(ClassTy &c) {
639 c.def_static(
640 "get",
641 [](bool value, DefaultingPyMlirContext context) {
642 MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
643 return PyBoolAttribute(context->getRef(), attr);
644 },
645 nb::arg("value"), nb::arg("context").none() = nb::none(),
646 "Gets an uniqued bool attribute");
647 c.def_prop_ro("value", mlirBoolAttrGetValue,
648 "Returns the value of the bool attribute");
649 c.def("__bool__", mlirBoolAttrGetValue,
650 "Converts the value of the bool attribute to a Python bool");
651 }
652};
653
654class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
655public:
656 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
657 static constexpr const char *pyClassName = "SymbolRefAttr";
658 using PyConcreteAttribute::PyConcreteAttribute;
659
660 static MlirAttribute fromList(const std::vector<std::string> &symbols,
661 PyMlirContext &context) {
662 if (symbols.empty())
663 throw std::runtime_error("SymbolRefAttr must be composed of at least "
664 "one symbol.");
665 MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
666 SmallVector<MlirAttribute, 3> referenceAttrs;
667 for (size_t i = 1; i < symbols.size(); ++i) {
668 referenceAttrs.push_back(
669 mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
670 }
671 return mlirSymbolRefAttrGet(context.get(), rootSymbol,
672 referenceAttrs.size(), referenceAttrs.data());
673 }
674
675 static void bindDerived(ClassTy &c) {
676 c.def_static(
677 "get",
678 [](const std::vector<std::string> &symbols,
679 DefaultingPyMlirContext context) {
680 return PySymbolRefAttribute::fromList(symbols, context.resolve());
681 },
682 nb::arg("symbols"), nb::arg("context").none() = nb::none(),
683 "Gets a uniqued SymbolRef attribute from a list of symbol names");
684 c.def_prop_ro(
685 "value",
686 [](PySymbolRefAttribute &self) {
687 std::vector<std::string> symbols = {
688 unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
689 for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
690 ++i)
691 symbols.push_back(
692 unwrap(mlirSymbolRefAttrGetRootReference(
693 mlirSymbolRefAttrGetNestedReference(self, i)))
694 .str());
695 return symbols;
696 },
697 "Returns the value of the SymbolRef attribute as a list[str]");
698 }
699};
700
701class PyFlatSymbolRefAttribute
702 : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
703public:
704 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
705 static constexpr const char *pyClassName = "FlatSymbolRefAttr";
706 using PyConcreteAttribute::PyConcreteAttribute;
707
708 static void bindDerived(ClassTy &c) {
709 c.def_static(
710 "get",
711 [](std::string value, DefaultingPyMlirContext context) {
712 MlirAttribute attr =
713 mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
714 return PyFlatSymbolRefAttribute(context->getRef(), attr);
715 },
716 nb::arg("value"), nb::arg("context").none() = nb::none(),
717 "Gets a uniqued FlatSymbolRef attribute");
718 c.def_prop_ro(
719 "value",
720 [](PyFlatSymbolRefAttribute &self) {
721 MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
722 return nb::str(stringRef.data, stringRef.length);
723 },
724 "Returns the value of the FlatSymbolRef attribute as a string");
725 }
726};
727
728class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
729public:
730 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
731 static constexpr const char *pyClassName = "OpaqueAttr";
732 using PyConcreteAttribute::PyConcreteAttribute;
733 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
734 mlirOpaqueAttrGetTypeID;
735
736 static void bindDerived(ClassTy &c) {
737 c.def_static(
738 "get",
739 [](std::string dialectNamespace, nb_buffer buffer, PyType &type,
740 DefaultingPyMlirContext context) {
741 const nb_buffer_info bufferInfo = buffer.request();
742 intptr_t bufferSize = bufferInfo.size;
743 MlirAttribute attr = mlirOpaqueAttrGet(
744 context->get(), toMlirStringRef(dialectNamespace), bufferSize,
745 static_cast<char *>(bufferInfo.ptr), type);
746 return PyOpaqueAttribute(context->getRef(), attr);
747 },
748 nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"),
749 nb::arg("context").none() = nb::none(), "Gets an Opaque attribute.");
750 c.def_prop_ro(
751 "dialect_namespace",
752 [](PyOpaqueAttribute &self) {
753 MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
754 return nb::str(stringRef.data, stringRef.length);
755 },
756 "Returns the dialect namespace for the Opaque attribute as a string");
757 c.def_prop_ro(
758 "data",
759 [](PyOpaqueAttribute &self) {
760 MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
761 return nb::bytes(stringRef.data, stringRef.length);
762 },
763 "Returns the data for the Opaqued attributes as `bytes`");
764 }
765};
766
767class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
768public:
769 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
770 static constexpr const char *pyClassName = "StringAttr";
771 using PyConcreteAttribute::PyConcreteAttribute;
772 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
773 mlirStringAttrGetTypeID;
774
775 static void bindDerived(ClassTy &c) {
776 c.def_static(
777 "get",
778 [](std::string value, DefaultingPyMlirContext context) {
779 MlirAttribute attr =
780 mlirStringAttrGet(context->get(), toMlirStringRef(value));
781 return PyStringAttribute(context->getRef(), attr);
782 },
783 nb::arg("value"), nb::arg("context").none() = nb::none(),
784 "Gets a uniqued string attribute");
785 c.def_static(
786 "get",
787 [](nb::bytes value, DefaultingPyMlirContext context) {
788 MlirAttribute attr =
789 mlirStringAttrGet(context->get(), toMlirStringRef(value));
790 return PyStringAttribute(context->getRef(), attr);
791 },
792 nb::arg("value"), nb::arg("context").none() = nb::none(),
793 "Gets a uniqued string attribute");
794 c.def_static(
795 "get_typed",
796 [](PyType &type, std::string value) {
797 MlirAttribute attr =
798 mlirStringAttrTypedGet(type, toMlirStringRef(value));
799 return PyStringAttribute(type.getContext(), attr);
800 },
801 nb::arg("type"), nb::arg("value"),
802 "Gets a uniqued string attribute associated to a type");
803 c.def_prop_ro(
804 "value",
805 [](PyStringAttribute &self) {
806 MlirStringRef stringRef = mlirStringAttrGetValue(self);
807 return nb::str(stringRef.data, stringRef.length);
808 },
809 "Returns the value of the string attribute");
810 c.def_prop_ro(
811 "value_bytes",
812 [](PyStringAttribute &self) {
813 MlirStringRef stringRef = mlirStringAttrGetValue(self);
814 return nb::bytes(stringRef.data, stringRef.length);
815 },
816 "Returns the value of the string attribute as `bytes`");
817 }
818};
819
820// TODO: Support construction of string elements.
821class PyDenseElementsAttribute
822 : public PyConcreteAttribute<PyDenseElementsAttribute> {
823public:
824 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
825 static constexpr const char *pyClassName = "DenseElementsAttr";
826 using PyConcreteAttribute::PyConcreteAttribute;
827
828 static PyDenseElementsAttribute
829 getFromList(nb::list attributes, std::optional<PyType> explicitType,
830 DefaultingPyMlirContext contextWrapper) {
831 const size_t numAttributes = nb::len(attributes);
832 if (numAttributes == 0)
833 throw nb::value_error("Attributes list must be non-empty.");
834
835 MlirType shapedType;
836 if (explicitType) {
837 if ((!mlirTypeIsAShaped(*explicitType) ||
838 !mlirShapedTypeHasStaticShape(*explicitType))) {
839
840 std::string message;
841 llvm::raw_string_ostream os(message);
842 os << "Expected a static ShapedType for the shaped_type parameter: "
843 << nb::cast<std::string>(nb::repr(nb::cast(*explicitType)));
844 throw nb::value_error(message.c_str());
845 }
846 shapedType = *explicitType;
847 } else {
848 SmallVector<int64_t> shape = {static_cast<int64_t>(numAttributes)};
849 shapedType = mlirRankedTensorTypeGet(
850 shape.size(), shape.data(),
851 mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])),
852 mlirAttributeGetNull());
853 }
854
855 SmallVector<MlirAttribute> mlirAttributes;
856 mlirAttributes.reserve(numAttributes);
857 for (const nb::handle &attribute : attributes) {
858 MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute);
859 MlirType attrType = mlirAttributeGetType(mlirAttribute);
860 mlirAttributes.push_back(mlirAttribute);
861
862 if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) {
863 std::string message;
864 llvm::raw_string_ostream os(message);
865 os << "All attributes must be of the same type and match "
866 << "the type parameter: expected="
867 << nb::cast<std::string>(nb::repr(nb::cast(shapedType)))
868 << ", but got="
869 << nb::cast<std::string>(nb::repr(nb::cast(attrType)));
870 throw nb::value_error(message.c_str());
871 }
872 }
873
874 MlirAttribute elements = mlirDenseElementsAttrGet(
875 shapedType, mlirAttributes.size(), mlirAttributes.data());
876
877 return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
878 }
879
880 static PyDenseElementsAttribute
881 getFromBuffer(nb_buffer array, bool signless,
882 std::optional<PyType> explicitType,
883 std::optional<std::vector<int64_t>> explicitShape,
884 DefaultingPyMlirContext contextWrapper) {
885 // Request a contiguous view. In exotic cases, this will cause a copy.
886 int flags = PyBUF_ND;
887 if (!explicitType) {
888 flags |= PyBUF_FORMAT;
889 }
890 Py_buffer view;
891 if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
892 throw nb::python_error();
893 }
894 auto freeBuffer = llvm::make_scope_exit(F: [&]() { PyBuffer_Release(&view); });
895
896 MlirContext context = contextWrapper->get();
897 MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType,
898 explicitShape, context);
899 if (mlirAttributeIsNull(attr)) {
900 throw std::invalid_argument(
901 "DenseElementsAttr could not be constructed from the given buffer. "
902 "This may mean that the Python buffer layout does not match that "
903 "MLIR expected layout and is a bug.");
904 }
905 return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
906 }
907
908 static PyDenseElementsAttribute getSplat(const PyType &shapedType,
909 PyAttribute &elementAttr) {
910 auto contextWrapper =
911 PyMlirContext::forContext(mlirTypeGetContext(shapedType));
912 if (!mlirAttributeIsAInteger(elementAttr) &&
913 !mlirAttributeIsAFloat(elementAttr)) {
914 std::string message = "Illegal element type for DenseElementsAttr: ";
915 message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
916 throw nb::value_error(message.c_str());
917 }
918 if (!mlirTypeIsAShaped(shapedType) ||
919 !mlirShapedTypeHasStaticShape(shapedType)) {
920 std::string message =
921 "Expected a static ShapedType for the shaped_type parameter: ";
922 message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
923 throw nb::value_error(message.c_str());
924 }
925 MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
926 MlirType attrType = mlirAttributeGetType(elementAttr);
927 if (!mlirTypeEqual(shapedElementType, attrType)) {
928 std::string message =
929 "Shaped element type and attribute type must be equal: shaped=";
930 message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType))));
931 message.append(s: ", element=");
932 message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr))));
933 throw nb::value_error(message.c_str());
934 }
935
936 MlirAttribute elements =
937 mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
938 return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
939 }
940
941 intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
942
943 std::unique_ptr<nb_buffer_info> accessBuffer() {
944 MlirType shapedType = mlirAttributeGetType(*this);
945 MlirType elementType = mlirShapedTypeGetElementType(shapedType);
946 std::string format;
947
948 if (mlirTypeIsAF32(elementType)) {
949 // f32
950 return bufferInfo<float>(shapedType);
951 }
952 if (mlirTypeIsAF64(elementType)) {
953 // f64
954 return bufferInfo<double>(shapedType);
955 }
956 if (mlirTypeIsAF16(elementType)) {
957 // f16
958 return bufferInfo<uint16_t>(shapedType, "e");
959 }
960 if (mlirTypeIsAIndex(elementType)) {
961 // Same as IndexType::kInternalStorageBitWidth
962 return bufferInfo<int64_t>(shapedType);
963 }
964 if (mlirTypeIsAInteger(elementType) &&
965 mlirIntegerTypeGetWidth(elementType) == 32) {
966 if (mlirIntegerTypeIsSignless(elementType) ||
967 mlirIntegerTypeIsSigned(elementType)) {
968 // i32
969 return bufferInfo<int32_t>(shapedType);
970 }
971 if (mlirIntegerTypeIsUnsigned(elementType)) {
972 // unsigned i32
973 return bufferInfo<uint32_t>(shapedType);
974 }
975 } else if (mlirTypeIsAInteger(elementType) &&
976 mlirIntegerTypeGetWidth(elementType) == 64) {
977 if (mlirIntegerTypeIsSignless(elementType) ||
978 mlirIntegerTypeIsSigned(elementType)) {
979 // i64
980 return bufferInfo<int64_t>(shapedType);
981 }
982 if (mlirIntegerTypeIsUnsigned(elementType)) {
983 // unsigned i64
984 return bufferInfo<uint64_t>(shapedType);
985 }
986 } else if (mlirTypeIsAInteger(elementType) &&
987 mlirIntegerTypeGetWidth(elementType) == 8) {
988 if (mlirIntegerTypeIsSignless(elementType) ||
989 mlirIntegerTypeIsSigned(elementType)) {
990 // i8
991 return bufferInfo<int8_t>(shapedType);
992 }
993 if (mlirIntegerTypeIsUnsigned(elementType)) {
994 // unsigned i8
995 return bufferInfo<uint8_t>(shapedType);
996 }
997 } else if (mlirTypeIsAInteger(elementType) &&
998 mlirIntegerTypeGetWidth(elementType) == 16) {
999 if (mlirIntegerTypeIsSignless(elementType) ||
1000 mlirIntegerTypeIsSigned(elementType)) {
1001 // i16
1002 return bufferInfo<int16_t>(shapedType);
1003 }
1004 if (mlirIntegerTypeIsUnsigned(elementType)) {
1005 // unsigned i16
1006 return bufferInfo<uint16_t>(shapedType);
1007 }
1008 } else if (mlirTypeIsAInteger(elementType) &&
1009 mlirIntegerTypeGetWidth(elementType) == 1) {
1010 // i1 / bool
1011 // We can not send the buffer directly back to Python, because the i1
1012 // values are bitpacked within MLIR. We call numpy's unpackbits function
1013 // to convert the bytes.
1014 return getBooleanBufferFromBitpackedAttribute();
1015 }
1016
1017 // TODO: Currently crashes the program.
1018 // Reported as https://github.com/pybind/pybind11/issues/3336
1019 throw std::invalid_argument(
1020 "unsupported data type for conversion to Python buffer");
1021 }
1022
1023 static void bindDerived(ClassTy &c) {
1024#if PY_VERSION_HEX < 0x03090000
1025 PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(c.ptr());
1026 tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer;
1027 tp->tp_as_buffer->bf_releasebuffer =
1028 PyDenseElementsAttribute::bf_releasebuffer;
1029#endif
1030 c.def("__len__", &PyDenseElementsAttribute::dunderLen)
1031 .def_static("get", PyDenseElementsAttribute::getFromBuffer,
1032 nb::arg("array"), nb::arg("signless") = true,
1033 nb::arg("type").none() = nb::none(),
1034 nb::arg("shape").none() = nb::none(),
1035 nb::arg("context").none() = nb::none(),
1036 kDenseElementsAttrGetDocstring)
1037 .def_static("get", PyDenseElementsAttribute::getFromList,
1038 nb::arg("attrs"), nb::arg("type").none() = nb::none(),
1039 nb::arg("context").none() = nb::none(),
1040 kDenseElementsAttrGetFromListDocstring)
1041 .def_static("get_splat", PyDenseElementsAttribute::getSplat,
1042 nb::arg("shaped_type"), nb::arg("element_attr"),
1043 "Gets a DenseElementsAttr where all values are the same")
1044 .def_prop_ro("is_splat",
1045 [](PyDenseElementsAttribute &self) -> bool {
1046 return mlirDenseElementsAttrIsSplat(self);
1047 })
1048 .def("get_splat_value", [](PyDenseElementsAttribute &self) {
1049 if (!mlirDenseElementsAttrIsSplat(self))
1050 throw nb::value_error(
1051 "get_splat_value called on a non-splat attribute");
1052 return mlirDenseElementsAttrGetSplatValue(self);
1053 });
1054 }
1055
1056 static PyType_Slot slots[];
1057
1058private:
1059 static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags);
1060 static void bf_releasebuffer(PyObject *, Py_buffer *buffer);
1061
1062 static bool isUnsignedIntegerFormat(std::string_view format) {
1063 if (format.empty())
1064 return false;
1065 char code = format[0];
1066 return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
1067 code == 'Q';
1068 }
1069
1070 static bool isSignedIntegerFormat(std::string_view format) {
1071 if (format.empty())
1072 return false;
1073 char code = format[0];
1074 return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
1075 code == 'q';
1076 }
1077
1078 static MlirType
1079 getShapedType(std::optional<MlirType> bulkLoadElementType,
1080 std::optional<std::vector<int64_t>> explicitShape,
1081 Py_buffer &view) {
1082 SmallVector<int64_t> shape;
1083 if (explicitShape) {
1084 shape.append(in_start: explicitShape->begin(), in_end: explicitShape->end());
1085 } else {
1086 shape.append(view.shape, view.shape + view.ndim);
1087 }
1088
1089 if (mlirTypeIsAShaped(*bulkLoadElementType)) {
1090 if (explicitShape) {
1091 throw std::invalid_argument("Shape can only be specified explicitly "
1092 "when the type is not a shaped type.");
1093 }
1094 return *bulkLoadElementType;
1095 } else {
1096 MlirAttribute encodingAttr = mlirAttributeGetNull();
1097 return mlirRankedTensorTypeGet(shape.size(), shape.data(),
1098 *bulkLoadElementType, encodingAttr);
1099 }
1100 }
1101
1102 static MlirAttribute getAttributeFromBuffer(
1103 Py_buffer &view, bool signless, std::optional<PyType> explicitType,
1104 std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
1105 // Detect format codes that are suitable for bulk loading. This includes
1106 // all byte aligned integer and floating point types up to 8 bytes.
1107 // Notably, this excludes exotics types which do not have a direct
1108 // representation in the buffer protocol (i.e. complex, etc).
1109 std::optional<MlirType> bulkLoadElementType;
1110 if (explicitType) {
1111 bulkLoadElementType = *explicitType;
1112 } else {
1113 std::string_view format(view.format);
1114 if (format == "f") {
1115 // f32
1116 assert(view.itemsize == 4 && "mismatched array itemsize");
1117 bulkLoadElementType = mlirF32TypeGet(context);
1118 } else if (format == "d") {
1119 // f64
1120 assert(view.itemsize == 8 && "mismatched array itemsize");
1121 bulkLoadElementType = mlirF64TypeGet(context);
1122 } else if (format == "e") {
1123 // f16
1124 assert(view.itemsize == 2 && "mismatched array itemsize");
1125 bulkLoadElementType = mlirF16TypeGet(context);
1126 } else if (format == "?") {
1127 // i1
1128 // The i1 type needs to be bit-packed, so we will handle it seperately
1129 return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
1130 context);
1131 } else if (isSignedIntegerFormat(format)) {
1132 if (view.itemsize == 4) {
1133 // i32
1134 bulkLoadElementType = signless
1135 ? mlirIntegerTypeGet(context, 32)
1136 : mlirIntegerTypeSignedGet(context, 32);
1137 } else if (view.itemsize == 8) {
1138 // i64
1139 bulkLoadElementType = signless
1140 ? mlirIntegerTypeGet(context, 64)
1141 : mlirIntegerTypeSignedGet(context, 64);
1142 } else if (view.itemsize == 1) {
1143 // i8
1144 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
1145 : mlirIntegerTypeSignedGet(context, 8);
1146 } else if (view.itemsize == 2) {
1147 // i16
1148 bulkLoadElementType = signless
1149 ? mlirIntegerTypeGet(context, 16)
1150 : mlirIntegerTypeSignedGet(context, 16);
1151 }
1152 } else if (isUnsignedIntegerFormat(format)) {
1153 if (view.itemsize == 4) {
1154 // unsigned i32
1155 bulkLoadElementType = signless
1156 ? mlirIntegerTypeGet(context, 32)
1157 : mlirIntegerTypeUnsignedGet(context, 32);
1158 } else if (view.itemsize == 8) {
1159 // unsigned i64
1160 bulkLoadElementType = signless
1161 ? mlirIntegerTypeGet(context, 64)
1162 : mlirIntegerTypeUnsignedGet(context, 64);
1163 } else if (view.itemsize == 1) {
1164 // i8
1165 bulkLoadElementType = signless
1166 ? mlirIntegerTypeGet(context, 8)
1167 : mlirIntegerTypeUnsignedGet(context, 8);
1168 } else if (view.itemsize == 2) {
1169 // i16
1170 bulkLoadElementType = signless
1171 ? mlirIntegerTypeGet(context, 16)
1172 : mlirIntegerTypeUnsignedGet(context, 16);
1173 }
1174 }
1175 if (!bulkLoadElementType) {
1176 throw std::invalid_argument(
1177 std::string("unimplemented array format conversion from format: ") +
1178 std::string(format));
1179 }
1180 }
1181
1182 MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
1183 return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
1184 }
1185
1186 // There is a complication for boolean numpy arrays, as numpy represents
1187 // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8
1188 // booleans per byte.
1189 static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
1190 Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
1191 MlirContext &context) {
1192 if (llvm::endianness::native != llvm::endianness::little) {
1193 // Given we have no good way of testing the behavior on big-endian
1194 // systems we will throw
1195 throw nb::type_error("Constructing a bit-packed MLIR attribute is "
1196 "unsupported on big-endian systems");
1197 }
1198 nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> unpackedArray(
1199 /*data=*/static_cast<uint8_t *>(view.buf),
1200 /*shape=*/{static_cast<size_t>(view.len)});
1201
1202 nb::module_ numpy = nb::module_::import_("numpy");
1203 nb::object packbitsFunc = numpy.attr("packbits");
1204 nb::object packedBooleans =
1205 packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little");
1206 nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request();
1207
1208 MlirType bitpackedType =
1209 getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
1210 assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
1211 // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
1212 // packedBooleans, hence the MlirAttribute will remain valid even when
1213 // packedBooleans get reclaimed by the end of the function.
1214 return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
1215 pythonBuffer.ptr);
1216 }
1217
1218 // This does the opposite transformation of
1219 // `getBitpackedAttributeFromBooleanBuffer`
1220 std::unique_ptr<nb_buffer_info> getBooleanBufferFromBitpackedAttribute() {
1221 if (llvm::endianness::native != llvm::endianness::little) {
1222 // Given we have no good way of testing the behavior on big-endian
1223 // systems we will throw
1224 throw nb::type_error("Constructing a numpy array from a MLIR attribute "
1225 "is unsupported on big-endian systems");
1226 }
1227
1228 int64_t numBooleans = mlirElementsAttrGetNumElements(*this);
1229 int64_t numBitpackedBytes = llvm::divideCeil(Numerator: numBooleans, Denominator: 8);
1230 uint8_t *bitpackedData = static_cast<uint8_t *>(
1231 const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1232 nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> packedArray(
1233 /*data=*/bitpackedData,
1234 /*shape=*/{static_cast<size_t>(numBitpackedBytes)});
1235
1236 nb::module_ numpy = nb::module_::import_("numpy");
1237 nb::object unpackbitsFunc = numpy.attr("unpackbits");
1238 nb::object equalFunc = numpy.attr("equal");
1239 nb::object reshapeFunc = numpy.attr("reshape");
1240 nb::object unpackedBooleans =
1241 unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little");
1242
1243 // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
1244 // We need to:
1245 // 1. Slice away the padded bits
1246 // 2. Make the boolean array have the correct shape
1247 // 3. Convert the array to a boolean array
1248 unpackedBooleans = unpackedBooleans[nb::slice(
1249 nb::int_(0), nb::int_(numBooleans), nb::int_(1))];
1250 unpackedBooleans = equalFunc(unpackedBooleans, 1);
1251
1252 MlirType shapedType = mlirAttributeGetType(*this);
1253 intptr_t rank = mlirShapedTypeGetRank(shapedType);
1254 std::vector<intptr_t> shape(rank);
1255 for (intptr_t i = 0; i < rank; ++i) {
1256 shape[i] = mlirShapedTypeGetDimSize(shapedType, i);
1257 }
1258 unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
1259
1260 // Make sure the returned nb::buffer_view claims ownership of the data in
1261 // `pythonBuffer` so it remains valid when Python reads it
1262 nb_buffer pythonBuffer = nb::cast<nb_buffer>(unpackedBooleans);
1263 return std::make_unique<nb_buffer_info>(args: pythonBuffer.request());
1264 }
1265
1266 template <typename Type>
1267 std::unique_ptr<nb_buffer_info>
1268 bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) {
1269 intptr_t rank = mlirShapedTypeGetRank(shapedType);
1270 // Prepare the data for the buffer_info.
1271 // Buffer is configured for read-only access below.
1272 Type *data = static_cast<Type *>(
1273 const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1274 // Prepare the shape for the buffer_info.
1275 SmallVector<intptr_t, 4> shape;
1276 for (intptr_t i = 0; i < rank; ++i)
1277 shape.push_back(Elt: mlirShapedTypeGetDimSize(shapedType, i));
1278 // Prepare the strides for the buffer_info.
1279 SmallVector<intptr_t, 4> strides;
1280 if (mlirDenseElementsAttrIsSplat(*this)) {
1281 // Splats are special, only the single value is stored.
1282 strides.assign(NumElts: rank, Elt: 0);
1283 } else {
1284 for (intptr_t i = 1; i < rank; ++i) {
1285 intptr_t strideFactor = 1;
1286 for (intptr_t j = i; j < rank; ++j)
1287 strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
1288 strides.push_back(Elt: sizeof(Type) * strideFactor);
1289 }
1290 strides.push_back(Elt: sizeof(Type));
1291 }
1292 const char *format;
1293 if (explicitFormat) {
1294 format = explicitFormat;
1295 } else {
1296 format = nb_format_descriptor<Type>::format();
1297 }
1298 return std::make_unique<nb_buffer_info>(
1299 data, sizeof(Type), format, rank, std::move(shape), std::move(strides),
1300 /*readonly=*/true);
1301 }
1302}; // namespace
1303
1304PyType_Slot PyDenseElementsAttribute::slots[] = {
1305// Python 3.8 doesn't allow setting the buffer protocol slots from a type spec.
1306#if PY_VERSION_HEX >= 0x03090000
1307 {Py_bf_getbuffer,
1308 reinterpret_cast<void *>(PyDenseElementsAttribute::bf_getbuffer)},
1309 {Py_bf_releasebuffer,
1310 reinterpret_cast<void *>(PyDenseElementsAttribute::bf_releasebuffer)},
1311#endif
1312 {0, nullptr},
1313};
1314
1315/*static*/ int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj,
1316 Py_buffer *view,
1317 int flags) {
1318 view->obj = nullptr;
1319 std::unique_ptr<nb_buffer_info> info;
1320 try {
1321 auto *attr = nb::cast<PyDenseElementsAttribute *>(nb::handle(obj));
1322 info = attr->accessBuffer();
1323 } catch (nb::python_error &e) {
1324 e.restore();
1325 nb::chain_error(PyExc_BufferError, "Error converting attribute to buffer");
1326 return -1;
1327 }
1328 view->obj = obj;
1329 view->ndim = 1;
1330 view->buf = info->ptr;
1331 view->itemsize = info->itemsize;
1332 view->len = info->itemsize;
1333 for (auto s : info->shape) {
1334 view->len *= s;
1335 }
1336 view->readonly = info->readonly;
1337 if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) {
1338 view->format = const_cast<char *>(info->format);
1339 }
1340 if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) {
1341 view->ndim = static_cast<int>(info->ndim);
1342 view->strides = info->strides.data();
1343 view->shape = info->shape.data();
1344 }
1345 view->suboffsets = nullptr;
1346 view->internal = info.release();
1347 Py_INCREF(obj);
1348 return 0;
1349}
1350
1351/*static*/ void PyDenseElementsAttribute::bf_releasebuffer(PyObject *,
1352 Py_buffer *view) {
1353 delete reinterpret_cast<nb_buffer_info *>(view->internal);
1354}
1355
1356/// Refinement of the PyDenseElementsAttribute for attributes containing
1357/// integer (and boolean) values. Supports element access.
1358class PyDenseIntElementsAttribute
1359 : public PyConcreteAttribute<PyDenseIntElementsAttribute,
1360 PyDenseElementsAttribute> {
1361public:
1362 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
1363 static constexpr const char *pyClassName = "DenseIntElementsAttr";
1364 using PyConcreteAttribute::PyConcreteAttribute;
1365
1366 /// Returns the element at the given linear position. Asserts if the index
1367 /// is out of range.
1368 nb::object dunderGetItem(intptr_t pos) {
1369 if (pos < 0 || pos >= dunderLen()) {
1370 throw nb::index_error("attempt to access out of bounds element");
1371 }
1372
1373 MlirType type = mlirAttributeGetType(*this);
1374 type = mlirShapedTypeGetElementType(type);
1375 // Index type can also appear as a DenseIntElementsAttr and therefore can be
1376 // casted to integer.
1377 assert(mlirTypeIsAInteger(type) ||
1378 mlirTypeIsAIndex(type) && "expected integer/index element type in "
1379 "dense int elements attribute");
1380 // Dispatch element extraction to an appropriate C function based on the
1381 // elemental type of the attribute. nb::int_ is implicitly constructible
1382 // from any C++ integral type and handles bitwidth correctly.
1383 // TODO: consider caching the type properties in the constructor to avoid
1384 // querying them on each element access.
1385 if (mlirTypeIsAIndex(type)) {
1386 return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos));
1387 }
1388 unsigned width = mlirIntegerTypeGetWidth(type);
1389 bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
1390 if (isUnsigned) {
1391 if (width == 1) {
1392 return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
1393 }
1394 if (width == 8) {
1395 return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos));
1396 }
1397 if (width == 16) {
1398 return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos));
1399 }
1400 if (width == 32) {
1401 return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos));
1402 }
1403 if (width == 64) {
1404 return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos));
1405 }
1406 } else {
1407 if (width == 1) {
1408 return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos)));
1409 }
1410 if (width == 8) {
1411 return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos));
1412 }
1413 if (width == 16) {
1414 return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos));
1415 }
1416 if (width == 32) {
1417 return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos));
1418 }
1419 if (width == 64) {
1420 return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos));
1421 }
1422 }
1423 throw nb::type_error("Unsupported integer type");
1424 }
1425
1426 static void bindDerived(ClassTy &c) {
1427 c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
1428 }
1429};
1430
1431class PyDenseResourceElementsAttribute
1432 : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
1433public:
1434 static constexpr IsAFunctionTy isaFunction =
1435 mlirAttributeIsADenseResourceElements;
1436 static constexpr const char *pyClassName = "DenseResourceElementsAttr";
1437 using PyConcreteAttribute::PyConcreteAttribute;
1438
1439 static PyDenseResourceElementsAttribute
1440 getFromBuffer(nb_buffer buffer, const std::string &name, const PyType &type,
1441 std::optional<size_t> alignment, bool isMutable,
1442 DefaultingPyMlirContext contextWrapper) {
1443 if (!mlirTypeIsAShaped(type)) {
1444 throw std::invalid_argument(
1445 "Constructing a DenseResourceElementsAttr requires a ShapedType.");
1446 }
1447
1448 // Do not request any conversions as we must ensure to use caller
1449 // managed memory.
1450 int flags = PyBUF_STRIDES;
1451 std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
1452 if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
1453 throw nb::python_error();
1454 }
1455
1456 // This scope releaser will only release if we haven't yet transferred
1457 // ownership.
1458 auto freeBuffer = llvm::make_scope_exit(F: [&]() {
1459 if (view)
1460 PyBuffer_Release(view.get());
1461 });
1462
1463 if (!PyBuffer_IsContiguous(view.get(), 'A')) {
1464 throw std::invalid_argument("Contiguous buffer is required.");
1465 }
1466
1467 // Infer alignment to be the stride of one element if not explicit.
1468 size_t inferredAlignment;
1469 if (alignment)
1470 inferredAlignment = *alignment;
1471 else
1472 inferredAlignment = view->strides[view->ndim - 1];
1473
1474 // The userData is a Py_buffer* that the deleter owns.
1475 auto deleter = [](void *userData, const void *data, size_t size,
1476 size_t align) {
1477 if (!Py_IsInitialized())
1478 Py_Initialize();
1479 Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
1480 nb::gil_scoped_acquire gil;
1481 PyBuffer_Release(ownedView);
1482 delete ownedView;
1483 };
1484
1485 size_t rawBufferSize = view->len;
1486 MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
1487 type, toMlirStringRef(name), view->buf, rawBufferSize,
1488 inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
1489 if (mlirAttributeIsNull(attr)) {
1490 throw std::invalid_argument(
1491 "DenseResourceElementsAttr could not be constructed from the given "
1492 "buffer. "
1493 "This may mean that the Python buffer layout does not match that "
1494 "MLIR expected layout and is a bug.");
1495 }
1496 view.release();
1497 return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
1498 }
1499
1500 static void bindDerived(ClassTy &c) {
1501 c.def_static(
1502 "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer,
1503 nb::arg("array"), nb::arg("name"), nb::arg("type"),
1504 nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false,
1505 nb::arg("context").none() = nb::none(),
1506 kDenseResourceElementsAttrGetFromBufferDocstring);
1507 }
1508};
1509
1510class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
1511public:
1512 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
1513 static constexpr const char *pyClassName = "DictAttr";
1514 using PyConcreteAttribute::PyConcreteAttribute;
1515 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1516 mlirDictionaryAttrGetTypeID;
1517
1518 intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
1519
1520 bool dunderContains(const std::string &name) {
1521 return !mlirAttributeIsNull(
1522 mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
1523 }
1524
1525 static void bindDerived(ClassTy &c) {
1526 c.def("__contains__", &PyDictAttribute::dunderContains);
1527 c.def("__len__", &PyDictAttribute::dunderLen);
1528 c.def_static(
1529 "get",
1530 [](nb::dict attributes, DefaultingPyMlirContext context) {
1531 SmallVector<MlirNamedAttribute> mlirNamedAttributes;
1532 mlirNamedAttributes.reserve(attributes.size());
1533 for (std::pair<nb::handle, nb::handle> it : attributes) {
1534 auto &mlirAttr = nb::cast<PyAttribute &>(it.second);
1535 auto name = nb::cast<std::string>(it.first);
1536 mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1537 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
1538 toMlirStringRef(name)),
1539 mlirAttr));
1540 }
1541 MlirAttribute attr =
1542 mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
1543 mlirNamedAttributes.data());
1544 return PyDictAttribute(context->getRef(), attr);
1545 },
1546 nb::arg("value") = nb::dict(), nb::arg("context").none() = nb::none(),
1547 "Gets an uniqued dict attribute");
1548 c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
1549 MlirAttribute attr =
1550 mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
1551 if (mlirAttributeIsNull(attr))
1552 throw nb::key_error("attempt to access a non-existent attribute");
1553 return attr;
1554 });
1555 c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
1556 if (index < 0 || index >= self.dunderLen()) {
1557 throw nb::index_error("attempt to access out of bounds attribute");
1558 }
1559 MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
1560 return PyNamedAttribute(
1561 namedAttr.attribute,
1562 std::string(mlirIdentifierStr(namedAttr.name).data));
1563 });
1564 }
1565};
1566
1567/// Refinement of PyDenseElementsAttribute for attributes containing
1568/// floating-point values. Supports element access.
1569class PyDenseFPElementsAttribute
1570 : public PyConcreteAttribute<PyDenseFPElementsAttribute,
1571 PyDenseElementsAttribute> {
1572public:
1573 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
1574 static constexpr const char *pyClassName = "DenseFPElementsAttr";
1575 using PyConcreteAttribute::PyConcreteAttribute;
1576
1577 nb::float_ dunderGetItem(intptr_t pos) {
1578 if (pos < 0 || pos >= dunderLen()) {
1579 throw nb::index_error("attempt to access out of bounds element");
1580 }
1581
1582 MlirType type = mlirAttributeGetType(*this);
1583 type = mlirShapedTypeGetElementType(type);
1584 // Dispatch element extraction to an appropriate C function based on the
1585 // elemental type of the attribute. nb::float_ is implicitly constructible
1586 // from float and double.
1587 // TODO: consider caching the type properties in the constructor to avoid
1588 // querying them on each element access.
1589 if (mlirTypeIsAF32(type)) {
1590 return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos));
1591 }
1592 if (mlirTypeIsAF64(type)) {
1593 return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos));
1594 }
1595 throw nb::type_error("Unsupported floating-point type");
1596 }
1597
1598 static void bindDerived(ClassTy &c) {
1599 c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
1600 }
1601};
1602
1603class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
1604public:
1605 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
1606 static constexpr const char *pyClassName = "TypeAttr";
1607 using PyConcreteAttribute::PyConcreteAttribute;
1608 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1609 mlirTypeAttrGetTypeID;
1610
1611 static void bindDerived(ClassTy &c) {
1612 c.def_static(
1613 "get",
1614 [](PyType value, DefaultingPyMlirContext context) {
1615 MlirAttribute attr = mlirTypeAttrGet(value.get());
1616 return PyTypeAttribute(context->getRef(), attr);
1617 },
1618 nb::arg("value"), nb::arg("context").none() = nb::none(),
1619 "Gets a uniqued Type attribute");
1620 c.def_prop_ro("value", [](PyTypeAttribute &self) {
1621 return mlirTypeAttrGetValue(self.get());
1622 });
1623 }
1624};
1625
1626/// Unit Attribute subclass. Unit attributes don't have values.
1627class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1628public:
1629 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1630 static constexpr const char *pyClassName = "UnitAttr";
1631 using PyConcreteAttribute::PyConcreteAttribute;
1632 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1633 mlirUnitAttrGetTypeID;
1634
1635 static void bindDerived(ClassTy &c) {
1636 c.def_static(
1637 "get",
1638 [](DefaultingPyMlirContext context) {
1639 return PyUnitAttribute(context->getRef(),
1640 mlirUnitAttrGet(context->get()));
1641 },
1642 nb::arg("context").none() = nb::none(), "Create a Unit attribute.");
1643 }
1644};
1645
1646/// Strided layout attribute subclass.
1647class PyStridedLayoutAttribute
1648 : public PyConcreteAttribute<PyStridedLayoutAttribute> {
1649public:
1650 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
1651 static constexpr const char *pyClassName = "StridedLayoutAttr";
1652 using PyConcreteAttribute::PyConcreteAttribute;
1653 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1654 mlirStridedLayoutAttrGetTypeID;
1655
1656 static void bindDerived(ClassTy &c) {
1657 c.def_static(
1658 "get",
1659 [](int64_t offset, const std::vector<int64_t> strides,
1660 DefaultingPyMlirContext ctx) {
1661 MlirAttribute attr = mlirStridedLayoutAttrGet(
1662 ctx->get(), offset, strides.size(), strides.data());
1663 return PyStridedLayoutAttribute(ctx->getRef(), attr);
1664 },
1665 nb::arg("offset"), nb::arg("strides"),
1666 nb::arg("context").none() = nb::none(),
1667 "Gets a strided layout attribute.");
1668 c.def_static(
1669 "get_fully_dynamic",
1670 [](int64_t rank, DefaultingPyMlirContext ctx) {
1671 auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
1672 std::vector<int64_t> strides(rank);
1673 std::fill(strides.begin(), strides.end(), dynamic);
1674 MlirAttribute attr = mlirStridedLayoutAttrGet(
1675 ctx->get(), dynamic, strides.size(), strides.data());
1676 return PyStridedLayoutAttribute(ctx->getRef(), attr);
1677 },
1678 nb::arg("rank"), nb::arg("context").none() = nb::none(),
1679 "Gets a strided layout attribute with dynamic offset and strides of "
1680 "a "
1681 "given rank.");
1682 c.def_prop_ro(
1683 "offset",
1684 [](PyStridedLayoutAttribute &self) {
1685 return mlirStridedLayoutAttrGetOffset(self);
1686 },
1687 "Returns the value of the float point attribute");
1688 c.def_prop_ro(
1689 "strides",
1690 [](PyStridedLayoutAttribute &self) {
1691 intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
1692 std::vector<int64_t> strides(size);
1693 for (intptr_t i = 0; i < size; i++) {
1694 strides[i] = mlirStridedLayoutAttrGetStride(self, i);
1695 }
1696 return strides;
1697 },
1698 "Returns the value of the float point attribute");
1699 }
1700};
1701
1702nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
1703 if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
1704 return nb::cast(PyDenseBoolArrayAttribute(pyAttribute));
1705 if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
1706 return nb::cast(PyDenseI8ArrayAttribute(pyAttribute));
1707 if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
1708 return nb::cast(PyDenseI16ArrayAttribute(pyAttribute));
1709 if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
1710 return nb::cast(PyDenseI32ArrayAttribute(pyAttribute));
1711 if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
1712 return nb::cast(PyDenseI64ArrayAttribute(pyAttribute));
1713 if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
1714 return nb::cast(PyDenseF32ArrayAttribute(pyAttribute));
1715 if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
1716 return nb::cast(PyDenseF64ArrayAttribute(pyAttribute));
1717 std::string msg =
1718 std::string("Can't cast unknown element type DenseArrayAttr (") +
1719 nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
1720 throw nb::type_error(msg.c_str());
1721}
1722
1723nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
1724 if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
1725 return nb::cast(PyDenseFPElementsAttribute(pyAttribute));
1726 if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
1727 return nb::cast(PyDenseIntElementsAttribute(pyAttribute));
1728 std::string msg =
1729 std::string(
1730 "Can't cast unknown element type DenseIntOrFPElementsAttr (") +
1731 nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
1732 throw nb::type_error(msg.c_str());
1733}
1734
1735nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
1736 if (PyBoolAttribute::isaFunction(pyAttribute))
1737 return nb::cast(PyBoolAttribute(pyAttribute));
1738 if (PyIntegerAttribute::isaFunction(pyAttribute))
1739 return nb::cast(PyIntegerAttribute(pyAttribute));
1740 std::string msg =
1741 std::string("Can't cast unknown element type DenseArrayAttr (") +
1742 nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")";
1743 throw nb::type_error(msg.c_str());
1744}
1745
1746nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
1747 if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
1748 return nb::cast(PyFlatSymbolRefAttribute(pyAttribute));
1749 if (PySymbolRefAttribute::isaFunction(pyAttribute))
1750 return nb::cast(PySymbolRefAttribute(pyAttribute));
1751 std::string msg = std::string("Can't cast unknown SymbolRef attribute (") +
1752 nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) +
1753 ")";
1754 throw nb::type_error(msg.c_str());
1755}
1756
1757} // namespace
1758
1759void mlir::python::populateIRAttributes(nb::module_ &m) {
1760 PyAffineMapAttribute::bind(m);
1761 PyDenseBoolArrayAttribute::bind(m);
1762 PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
1763 PyDenseI8ArrayAttribute::bind(m);
1764 PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
1765 PyDenseI16ArrayAttribute::bind(m);
1766 PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
1767 PyDenseI32ArrayAttribute::bind(m);
1768 PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
1769 PyDenseI64ArrayAttribute::bind(m);
1770 PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
1771 PyDenseF32ArrayAttribute::bind(m);
1772 PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
1773 PyDenseF64ArrayAttribute::bind(m);
1774 PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
1775 PyGlobals::get().registerTypeCaster(
1776 mlirDenseArrayAttrGetTypeID(),
1777 nb::cast<nb::callable>(nb::cpp_function(denseArrayAttributeCaster)));
1778
1779 PyArrayAttribute::bind(m);
1780 PyArrayAttribute::PyArrayAttributeIterator::bind(m);
1781 PyBoolAttribute::bind(m);
1782 PyDenseElementsAttribute::bind(m, PyDenseElementsAttribute::slots);
1783 PyDenseFPElementsAttribute::bind(m);
1784 PyDenseIntElementsAttribute::bind(m);
1785 PyGlobals::get().registerTypeCaster(
1786 mlirDenseIntOrFPElementsAttrGetTypeID(),
1787 nb::cast<nb::callable>(
1788 nb::cpp_function(denseIntOrFPElementsAttributeCaster)));
1789 PyDenseResourceElementsAttribute::bind(m);
1790
1791 PyDictAttribute::bind(m);
1792 PySymbolRefAttribute::bind(m);
1793 PyGlobals::get().registerTypeCaster(
1794 mlirSymbolRefAttrGetTypeID(),
1795 nb::cast<nb::callable>(
1796 nb::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster)));
1797
1798 PyFlatSymbolRefAttribute::bind(m);
1799 PyOpaqueAttribute::bind(m);
1800 PyFloatAttribute::bind(m);
1801 PyIntegerAttribute::bind(m);
1802 PyIntegerSetAttribute::bind(m);
1803 PyStringAttribute::bind(m);
1804 PyTypeAttribute::bind(m);
1805 PyGlobals::get().registerTypeCaster(
1806 mlirIntegerAttrGetTypeID(),
1807 nb::cast<nb::callable>(nb::cpp_function(integerOrBoolAttributeCaster)));
1808 PyUnitAttribute::bind(m);
1809
1810 PyStridedLayoutAttribute::bind(m);
1811}
1812

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Bindings/Python/IRAttributes.cpp