1//===- IRAttributes.cpp - Exports builtin and standard attributes ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <optional>
10#include <string_view>
11#include <utility>
12
13#include "IRModule.h"
14
15#include "PybindUtils.h"
16
17#include "llvm/ADT/ScopeExit.h"
18
19#include "mlir-c/BuiltinAttributes.h"
20#include "mlir-c/BuiltinTypes.h"
21#include "mlir/Bindings/Python/PybindAdaptors.h"
22
23namespace py = pybind11;
24using namespace mlir;
25using namespace mlir::python;
26
27using llvm::SmallVector;
28
29//------------------------------------------------------------------------------
30// Docstrings (trivial, non-duplicated docstrings are included inline).
31//------------------------------------------------------------------------------
32
/// User-facing Python docstring for building a DenseElementsAttr from an
/// object that implements the Python buffer protocol (presumably attached to
/// the `DenseElementsAttr.get` binding, which is outside this view).
static constexpr char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

  * Integer types (except for i1): the buffer must be byte aligned to the
    next byte boundary.
  * Floating point types: Must be bit-castable to the given floating point
    size.
  * i1 (bool): Bit packed into 8bit words where the bit pattern matches a
    row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
    this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
  array: The array or buffer to convert.
  signless: If inferring an appropriate MLIR type, use signless types for
    integers (defaults True).
  type: Skips inference of the MLIR element type and uses this instead. The
    storage size must be consistent with the actual contents of the buffer.
  shape: Overrides the shape of the buffer when constructing the MLIR
    shaped type. This is needed when the physical and logical shape differ (as
    for i1).
  context: Explicit context, if not from context manager.

Returns:
  DenseElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
74
/// User-facing Python docstring for building a DenseResourceElementsAttr from
/// a Python buffer (the binding that uses it is outside this view).
static constexpr char kDenseResourceElementsAttrGetFromBufferDocstring[] =
    R"(Gets a DenseResourceElementsAttr from a Python buffer or array.

This function does minimal validation or massaging of the data, and it is
up to the caller to ensure that the buffer meets the characteristics
implied by the shape.

The backing buffer and any user objects will be retained for the lifetime
of the resource blob. This is typically bounded to the context but the
resource can have a shorter lifespan depending on how it is used in
subsequent processing.

Args:
  buffer: The array or buffer to convert.
  name: Name to provide to the resource (may be changed upon collision).
  type: The explicit ShapedType to construct the attribute with.
  context: Explicit context, if not from context manager.

Returns:
  DenseResourceElementsAttr on success.

Raises:
  ValueError: If the type of the buffer or array cannot be matched to an MLIR
    type or if the buffer does not meet expectations.
)";
100
101namespace {
102
103static MlirStringRef toMlirStringRef(const std::string &s) {
104 return mlirStringRefCreate(s.data(), s.size());
105}
106
107class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
108public:
109 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap;
110 static constexpr const char *pyClassName = "AffineMapAttr";
111 using PyConcreteAttribute::PyConcreteAttribute;
112 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
113 mlirAffineMapAttrGetTypeID;
114
115 static void bindDerived(ClassTy &c) {
116 c.def_static(
117 "get",
118 [](PyAffineMap &affineMap) {
119 MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get());
120 return PyAffineMapAttribute(affineMap.getContext(), attr);
121 },
122 py::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
123 }
124};
125
126template <typename T>
127static T pyTryCast(py::handle object) {
128 try {
129 return object.cast<T>();
130 } catch (py::cast_error &err) {
131 std::string msg =
132 std::string(
133 "Invalid attribute when attempting to create an ArrayAttribute (") +
134 err.what() + ")";
135 throw py::cast_error(msg);
136 } catch (py::reference_cast_error &err) {
137 std::string msg = std::string("Invalid attribute (None?) when attempting "
138 "to create an ArrayAttribute (") +
139 err.what() + ")";
140 throw py::cast_error(msg);
141 }
142}
143
144/// A python-wrapped dense array attribute with an element type and a derived
145/// implementation class.
146template <typename EltTy, typename DerivedT>
147class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> {
148public:
149 using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
150
151 /// Iterator over the integer elements of a dense array.
152 class PyDenseArrayIterator {
153 public:
154 PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {}
155
156 /// Return a copy of the iterator.
157 PyDenseArrayIterator dunderIter() { return *this; }
158
159 /// Return the next element.
160 EltTy dunderNext() {
161 // Throw if the index has reached the end.
162 if (nextIndex >= mlirDenseArrayGetNumElements(attr.get()))
163 throw py::stop_iteration();
164 return DerivedT::getElement(attr.get(), nextIndex++);
165 }
166
167 /// Bind the iterator class.
168 static void bind(py::module &m) {
169 py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName,
170 py::module_local())
171 .def("__iter__", &PyDenseArrayIterator::dunderIter)
172 .def("__next__", &PyDenseArrayIterator::dunderNext);
173 }
174
175 private:
176 /// The referenced dense array attribute.
177 PyAttribute attr;
178 /// The next index to read.
179 int nextIndex = 0;
180 };
181
182 /// Get the element at the given index.
183 EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); }
184
185 /// Bind the attribute class.
186 static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
187 // Bind the constructor.
188 c.def_static(
189 "get",
190 [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
191 return getAttribute(values, ctx->getRef());
192 },
193 py::arg("values"), py::arg("context") = py::none(),
194 "Gets a uniqued dense array attribute");
195 // Bind the array methods.
196 c.def("__getitem__", [](DerivedT &arr, intptr_t i) {
197 if (i >= mlirDenseArrayGetNumElements(arr))
198 throw py::index_error("DenseArray index out of range");
199 return arr.getItem(i);
200 });
201 c.def("__len__", [](const DerivedT &arr) {
202 return mlirDenseArrayGetNumElements(arr);
203 });
204 c.def("__iter__",
205 [](const DerivedT &arr) { return PyDenseArrayIterator(arr); });
206 c.def("__add__", [](DerivedT &arr, const py::list &extras) {
207 std::vector<EltTy> values;
208 intptr_t numOldElements = mlirDenseArrayGetNumElements(arr);
209 values.reserve(numOldElements + py::len(extras));
210 for (intptr_t i = 0; i < numOldElements; ++i)
211 values.push_back(arr.getItem(i));
212 for (py::handle attr : extras)
213 values.push_back(pyTryCast<EltTy>(attr));
214 return getAttribute(values, ctx: arr.getContext());
215 });
216 }
217
218private:
219 static DerivedT getAttribute(const std::vector<EltTy> &values,
220 PyMlirContextRef ctx) {
221 if constexpr (std::is_same_v<EltTy, bool>) {
222 std::vector<int> intValues(values.begin(), values.end());
223 MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(),
224 intValues.data());
225 return DerivedT(ctx, attr);
226 } else {
227 MlirAttribute attr =
228 DerivedT::getAttribute(ctx->get(), values.size(), values.data());
229 return DerivedT(ctx, attr);
230 }
231 }
232};
233
234/// Instantiate the python dense array classes.
235struct PyDenseBoolArrayAttribute
236 : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> {
237 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray;
238 static constexpr auto getAttribute = mlirDenseBoolArrayGet;
239 static constexpr auto getElement = mlirDenseBoolArrayGetElement;
240 static constexpr const char *pyClassName = "DenseBoolArrayAttr";
241 static constexpr const char *pyIteratorName = "DenseBoolArrayIterator";
242 using PyDenseArrayAttribute::PyDenseArrayAttribute;
243};
244struct PyDenseI8ArrayAttribute
245 : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> {
246 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array;
247 static constexpr auto getAttribute = mlirDenseI8ArrayGet;
248 static constexpr auto getElement = mlirDenseI8ArrayGetElement;
249 static constexpr const char *pyClassName = "DenseI8ArrayAttr";
250 static constexpr const char *pyIteratorName = "DenseI8ArrayIterator";
251 using PyDenseArrayAttribute::PyDenseArrayAttribute;
252};
253struct PyDenseI16ArrayAttribute
254 : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> {
255 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array;
256 static constexpr auto getAttribute = mlirDenseI16ArrayGet;
257 static constexpr auto getElement = mlirDenseI16ArrayGetElement;
258 static constexpr const char *pyClassName = "DenseI16ArrayAttr";
259 static constexpr const char *pyIteratorName = "DenseI16ArrayIterator";
260 using PyDenseArrayAttribute::PyDenseArrayAttribute;
261};
262struct PyDenseI32ArrayAttribute
263 : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> {
264 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array;
265 static constexpr auto getAttribute = mlirDenseI32ArrayGet;
266 static constexpr auto getElement = mlirDenseI32ArrayGetElement;
267 static constexpr const char *pyClassName = "DenseI32ArrayAttr";
268 static constexpr const char *pyIteratorName = "DenseI32ArrayIterator";
269 using PyDenseArrayAttribute::PyDenseArrayAttribute;
270};
271struct PyDenseI64ArrayAttribute
272 : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> {
273 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array;
274 static constexpr auto getAttribute = mlirDenseI64ArrayGet;
275 static constexpr auto getElement = mlirDenseI64ArrayGetElement;
276 static constexpr const char *pyClassName = "DenseI64ArrayAttr";
277 static constexpr const char *pyIteratorName = "DenseI64ArrayIterator";
278 using PyDenseArrayAttribute::PyDenseArrayAttribute;
279};
280struct PyDenseF32ArrayAttribute
281 : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> {
282 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array;
283 static constexpr auto getAttribute = mlirDenseF32ArrayGet;
284 static constexpr auto getElement = mlirDenseF32ArrayGetElement;
285 static constexpr const char *pyClassName = "DenseF32ArrayAttr";
286 static constexpr const char *pyIteratorName = "DenseF32ArrayIterator";
287 using PyDenseArrayAttribute::PyDenseArrayAttribute;
288};
289struct PyDenseF64ArrayAttribute
290 : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> {
291 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array;
292 static constexpr auto getAttribute = mlirDenseF64ArrayGet;
293 static constexpr auto getElement = mlirDenseF64ArrayGetElement;
294 static constexpr const char *pyClassName = "DenseF64ArrayAttr";
295 static constexpr const char *pyIteratorName = "DenseF64ArrayIterator";
296 using PyDenseArrayAttribute::PyDenseArrayAttribute;
297};
298
299class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
300public:
301 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray;
302 static constexpr const char *pyClassName = "ArrayAttr";
303 using PyConcreteAttribute::PyConcreteAttribute;
304 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
305 mlirArrayAttrGetTypeID;
306
307 class PyArrayAttributeIterator {
308 public:
309 PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {}
310
311 PyArrayAttributeIterator &dunderIter() { return *this; }
312
313 MlirAttribute dunderNext() {
314 // TODO: Throw is an inefficient way to stop iteration.
315 if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
316 throw py::stop_iteration();
317 return mlirArrayAttrGetElement(attr.get(), nextIndex++);
318 }
319
320 static void bind(py::module &m) {
321 py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator",
322 py::module_local())
323 .def("__iter__", &PyArrayAttributeIterator::dunderIter)
324 .def("__next__", &PyArrayAttributeIterator::dunderNext);
325 }
326
327 private:
328 PyAttribute attr;
329 int nextIndex = 0;
330 };
331
332 MlirAttribute getItem(intptr_t i) {
333 return mlirArrayAttrGetElement(*this, i);
334 }
335
336 static void bindDerived(ClassTy &c) {
337 c.def_static(
338 "get",
339 [](py::list attributes, DefaultingPyMlirContext context) {
340 SmallVector<MlirAttribute> mlirAttributes;
341 mlirAttributes.reserve(py::len(attributes));
342 for (auto attribute : attributes) {
343 mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute));
344 }
345 MlirAttribute attr = mlirArrayAttrGet(
346 context->get(), mlirAttributes.size(), mlirAttributes.data());
347 return PyArrayAttribute(context->getRef(), attr);
348 },
349 py::arg("attributes"), py::arg("context") = py::none(),
350 "Gets a uniqued Array attribute");
351 c.def("__getitem__",
352 [](PyArrayAttribute &arr, intptr_t i) {
353 if (i >= mlirArrayAttrGetNumElements(arr))
354 throw py::index_error("ArrayAttribute index out of range");
355 return arr.getItem(i);
356 })
357 .def("__len__",
358 [](const PyArrayAttribute &arr) {
359 return mlirArrayAttrGetNumElements(arr);
360 })
361 .def("__iter__", [](const PyArrayAttribute &arr) {
362 return PyArrayAttributeIterator(arr);
363 });
364 c.def("__add__", [](PyArrayAttribute arr, py::list extras) {
365 std::vector<MlirAttribute> attributes;
366 intptr_t numOldElements = mlirArrayAttrGetNumElements(arr);
367 attributes.reserve(numOldElements + py::len(extras));
368 for (intptr_t i = 0; i < numOldElements; ++i)
369 attributes.push_back(arr.getItem(i));
370 for (py::handle attr : extras)
371 attributes.push_back(pyTryCast<PyAttribute>(attr));
372 MlirAttribute arrayAttr = mlirArrayAttrGet(
373 arr.getContext()->get(), attributes.size(), attributes.data());
374 return PyArrayAttribute(arr.getContext(), arrayAttr);
375 });
376 }
377};
378
379/// Float Point Attribute subclass - FloatAttr.
380class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
381public:
382 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
383 static constexpr const char *pyClassName = "FloatAttr";
384 using PyConcreteAttribute::PyConcreteAttribute;
385 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
386 mlirFloatAttrGetTypeID;
387
388 static void bindDerived(ClassTy &c) {
389 c.def_static(
390 "get",
391 [](PyType &type, double value, DefaultingPyLocation loc) {
392 PyMlirContext::ErrorCapture errors(loc->getContext());
393 MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
394 if (mlirAttributeIsNull(attr))
395 throw MLIRError("Invalid attribute", errors.take());
396 return PyFloatAttribute(type.getContext(), attr);
397 },
398 py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
399 "Gets an uniqued float point attribute associated to a type");
400 c.def_static(
401 "get_f32",
402 [](double value, DefaultingPyMlirContext context) {
403 MlirAttribute attr = mlirFloatAttrDoubleGet(
404 context->get(), mlirF32TypeGet(context->get()), value);
405 return PyFloatAttribute(context->getRef(), attr);
406 },
407 py::arg("value"), py::arg("context") = py::none(),
408 "Gets an uniqued float point attribute associated to a f32 type");
409 c.def_static(
410 "get_f64",
411 [](double value, DefaultingPyMlirContext context) {
412 MlirAttribute attr = mlirFloatAttrDoubleGet(
413 context->get(), mlirF64TypeGet(context->get()), value);
414 return PyFloatAttribute(context->getRef(), attr);
415 },
416 py::arg("value"), py::arg("context") = py::none(),
417 "Gets an uniqued float point attribute associated to a f64 type");
418 c.def_property_readonly("value", mlirFloatAttrGetValueDouble,
419 "Returns the value of the float attribute");
420 c.def("__float__", mlirFloatAttrGetValueDouble,
421 "Converts the value of the float attribute to a Python float");
422 }
423};
424
425/// Integer Attribute subclass - IntegerAttr.
426class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
427public:
428 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
429 static constexpr const char *pyClassName = "IntegerAttr";
430 using PyConcreteAttribute::PyConcreteAttribute;
431
432 static void bindDerived(ClassTy &c) {
433 c.def_static(
434 "get",
435 [](PyType &type, int64_t value) {
436 MlirAttribute attr = mlirIntegerAttrGet(type, value);
437 return PyIntegerAttribute(type.getContext(), attr);
438 },
439 py::arg("type"), py::arg("value"),
440 "Gets an uniqued integer attribute associated to a type");
441 c.def_property_readonly("value", toPyInt,
442 "Returns the value of the integer attribute");
443 c.def("__int__", toPyInt,
444 "Converts the value of the integer attribute to a Python int");
445 c.def_property_readonly_static("static_typeid",
446 [](py::object & /*class*/) -> MlirTypeID {
447 return mlirIntegerAttrGetTypeID();
448 });
449 }
450
451private:
452 static py::int_ toPyInt(PyIntegerAttribute &self) {
453 MlirType type = mlirAttributeGetType(self);
454 if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
455 return mlirIntegerAttrGetValueInt(self);
456 if (mlirIntegerTypeIsSigned(type))
457 return mlirIntegerAttrGetValueSInt(self);
458 return mlirIntegerAttrGetValueUInt(self);
459 }
460};
461
462/// Bool Attribute subclass - BoolAttr.
463class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
464public:
465 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
466 static constexpr const char *pyClassName = "BoolAttr";
467 using PyConcreteAttribute::PyConcreteAttribute;
468
469 static void bindDerived(ClassTy &c) {
470 c.def_static(
471 "get",
472 [](bool value, DefaultingPyMlirContext context) {
473 MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
474 return PyBoolAttribute(context->getRef(), attr);
475 },
476 py::arg("value"), py::arg("context") = py::none(),
477 "Gets an uniqued bool attribute");
478 c.def_property_readonly("value", mlirBoolAttrGetValue,
479 "Returns the value of the bool attribute");
480 c.def("__bool__", mlirBoolAttrGetValue,
481 "Converts the value of the bool attribute to a Python bool");
482 }
483};
484
485class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
486public:
487 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef;
488 static constexpr const char *pyClassName = "SymbolRefAttr";
489 using PyConcreteAttribute::PyConcreteAttribute;
490
491 static MlirAttribute fromList(const std::vector<std::string> &symbols,
492 PyMlirContext &context) {
493 if (symbols.empty())
494 throw std::runtime_error("SymbolRefAttr must be composed of at least "
495 "one symbol.");
496 MlirStringRef rootSymbol = toMlirStringRef(symbols[0]);
497 SmallVector<MlirAttribute, 3> referenceAttrs;
498 for (size_t i = 1; i < symbols.size(); ++i) {
499 referenceAttrs.push_back(
500 mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
501 }
502 return mlirSymbolRefAttrGet(context.get(), rootSymbol,
503 referenceAttrs.size(), referenceAttrs.data());
504 }
505
506 static void bindDerived(ClassTy &c) {
507 c.def_static(
508 "get",
509 [](const std::vector<std::string> &symbols,
510 DefaultingPyMlirContext context) {
511 return PySymbolRefAttribute::fromList(symbols, context.resolve());
512 },
513 py::arg("symbols"), py::arg("context") = py::none(),
514 "Gets a uniqued SymbolRef attribute from a list of symbol names");
515 c.def_property_readonly(
516 "value",
517 [](PySymbolRefAttribute &self) {
518 std::vector<std::string> symbols = {
519 unwrap(mlirSymbolRefAttrGetRootReference(self)).str()};
520 for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self);
521 ++i)
522 symbols.push_back(
523 unwrap(mlirSymbolRefAttrGetRootReference(
524 mlirSymbolRefAttrGetNestedReference(self, i)))
525 .str());
526 return symbols;
527 },
528 "Returns the value of the SymbolRef attribute as a list[str]");
529 }
530};
531
532class PyFlatSymbolRefAttribute
533 : public PyConcreteAttribute<PyFlatSymbolRefAttribute> {
534public:
535 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef;
536 static constexpr const char *pyClassName = "FlatSymbolRefAttr";
537 using PyConcreteAttribute::PyConcreteAttribute;
538
539 static void bindDerived(ClassTy &c) {
540 c.def_static(
541 "get",
542 [](std::string value, DefaultingPyMlirContext context) {
543 MlirAttribute attr =
544 mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value));
545 return PyFlatSymbolRefAttribute(context->getRef(), attr);
546 },
547 py::arg("value"), py::arg("context") = py::none(),
548 "Gets a uniqued FlatSymbolRef attribute");
549 c.def_property_readonly(
550 "value",
551 [](PyFlatSymbolRefAttribute &self) {
552 MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self);
553 return py::str(stringRef.data, stringRef.length);
554 },
555 "Returns the value of the FlatSymbolRef attribute as a string");
556 }
557};
558
559class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
560public:
561 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque;
562 static constexpr const char *pyClassName = "OpaqueAttr";
563 using PyConcreteAttribute::PyConcreteAttribute;
564 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
565 mlirOpaqueAttrGetTypeID;
566
567 static void bindDerived(ClassTy &c) {
568 c.def_static(
569 "get",
570 [](std::string dialectNamespace, py::buffer buffer, PyType &type,
571 DefaultingPyMlirContext context) {
572 const py::buffer_info bufferInfo = buffer.request();
573 intptr_t bufferSize = bufferInfo.size;
574 MlirAttribute attr = mlirOpaqueAttrGet(
575 context->get(), toMlirStringRef(dialectNamespace), bufferSize,
576 static_cast<char *>(bufferInfo.ptr), type);
577 return PyOpaqueAttribute(context->getRef(), attr);
578 },
579 py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"),
580 py::arg("context") = py::none(), "Gets an Opaque attribute.");
581 c.def_property_readonly(
582 "dialect_namespace",
583 [](PyOpaqueAttribute &self) {
584 MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self);
585 return py::str(stringRef.data, stringRef.length);
586 },
587 "Returns the dialect namespace for the Opaque attribute as a string");
588 c.def_property_readonly(
589 "data",
590 [](PyOpaqueAttribute &self) {
591 MlirStringRef stringRef = mlirOpaqueAttrGetData(self);
592 return py::bytes(stringRef.data, stringRef.length);
593 },
594 "Returns the data for the Opaqued attributes as `bytes`");
595 }
596};
597
598class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
599public:
600 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
601 static constexpr const char *pyClassName = "StringAttr";
602 using PyConcreteAttribute::PyConcreteAttribute;
603 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
604 mlirStringAttrGetTypeID;
605
606 static void bindDerived(ClassTy &c) {
607 c.def_static(
608 "get",
609 [](std::string value, DefaultingPyMlirContext context) {
610 MlirAttribute attr =
611 mlirStringAttrGet(context->get(), toMlirStringRef(value));
612 return PyStringAttribute(context->getRef(), attr);
613 },
614 py::arg("value"), py::arg("context") = py::none(),
615 "Gets a uniqued string attribute");
616 c.def_static(
617 "get_typed",
618 [](PyType &type, std::string value) {
619 MlirAttribute attr =
620 mlirStringAttrTypedGet(type, toMlirStringRef(value));
621 return PyStringAttribute(type.getContext(), attr);
622 },
623 py::arg("type"), py::arg("value"),
624 "Gets a uniqued string attribute associated to a type");
625 c.def_property_readonly(
626 "value",
627 [](PyStringAttribute &self) {
628 MlirStringRef stringRef = mlirStringAttrGetValue(self);
629 return py::str(stringRef.data, stringRef.length);
630 },
631 "Returns the value of the string attribute");
632 c.def_property_readonly(
633 "value_bytes",
634 [](PyStringAttribute &self) {
635 MlirStringRef stringRef = mlirStringAttrGetValue(self);
636 return py::bytes(stringRef.data, stringRef.length);
637 },
638 "Returns the value of the string attribute as `bytes`");
639 }
640};
641
642// TODO: Support construction of string elements.
643class PyDenseElementsAttribute
644 : public PyConcreteAttribute<PyDenseElementsAttribute> {
645public:
646 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
647 static constexpr const char *pyClassName = "DenseElementsAttr";
648 using PyConcreteAttribute::PyConcreteAttribute;
649
650 static PyDenseElementsAttribute
651 getFromBuffer(py::buffer array, bool signless,
652 std::optional<PyType> explicitType,
653 std::optional<std::vector<int64_t>> explicitShape,
654 DefaultingPyMlirContext contextWrapper) {
655 // Request a contiguous view. In exotic cases, this will cause a copy.
656 int flags = PyBUF_ND;
657 if (!explicitType) {
658 flags |= PyBUF_FORMAT;
659 }
660 Py_buffer view;
661 if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) {
662 throw py::error_already_set();
663 }
664 auto freeBuffer = llvm::make_scope_exit(F: [&]() { PyBuffer_Release(&view); });
665 SmallVector<int64_t> shape;
666 if (explicitShape) {
667 shape.append(in_start: explicitShape->begin(), in_end: explicitShape->end());
668 } else {
669 shape.append(view.shape, view.shape + view.ndim);
670 }
671
672 MlirAttribute encodingAttr = mlirAttributeGetNull();
673 MlirContext context = contextWrapper->get();
674
675 // Detect format codes that are suitable for bulk loading. This includes
676 // all byte aligned integer and floating point types up to 8 bytes.
677 // Notably, this excludes, bool (which needs to be bit-packed) and
678 // other exotics which do not have a direct representation in the buffer
679 // protocol (i.e. complex, etc).
680 std::optional<MlirType> bulkLoadElementType;
681 if (explicitType) {
682 bulkLoadElementType = *explicitType;
683 } else {
684 std::string_view format(view.format);
685 if (format == "f") {
686 // f32
687 assert(view.itemsize == 4 && "mismatched array itemsize");
688 bulkLoadElementType = mlirF32TypeGet(context);
689 } else if (format == "d") {
690 // f64
691 assert(view.itemsize == 8 && "mismatched array itemsize");
692 bulkLoadElementType = mlirF64TypeGet(context);
693 } else if (format == "e") {
694 // f16
695 assert(view.itemsize == 2 && "mismatched array itemsize");
696 bulkLoadElementType = mlirF16TypeGet(context);
697 } else if (isSignedIntegerFormat(format)) {
698 if (view.itemsize == 4) {
699 // i32
700 bulkLoadElementType = signless
701 ? mlirIntegerTypeGet(context, 32)
702 : mlirIntegerTypeSignedGet(context, 32);
703 } else if (view.itemsize == 8) {
704 // i64
705 bulkLoadElementType = signless
706 ? mlirIntegerTypeGet(context, 64)
707 : mlirIntegerTypeSignedGet(context, 64);
708 } else if (view.itemsize == 1) {
709 // i8
710 bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
711 : mlirIntegerTypeSignedGet(context, 8);
712 } else if (view.itemsize == 2) {
713 // i16
714 bulkLoadElementType = signless
715 ? mlirIntegerTypeGet(context, 16)
716 : mlirIntegerTypeSignedGet(context, 16);
717 }
718 } else if (isUnsignedIntegerFormat(format)) {
719 if (view.itemsize == 4) {
720 // unsigned i32
721 bulkLoadElementType = signless
722 ? mlirIntegerTypeGet(context, 32)
723 : mlirIntegerTypeUnsignedGet(context, 32);
724 } else if (view.itemsize == 8) {
725 // unsigned i64
726 bulkLoadElementType = signless
727 ? mlirIntegerTypeGet(context, 64)
728 : mlirIntegerTypeUnsignedGet(context, 64);
729 } else if (view.itemsize == 1) {
730 // i8
731 bulkLoadElementType = signless
732 ? mlirIntegerTypeGet(context, 8)
733 : mlirIntegerTypeUnsignedGet(context, 8);
734 } else if (view.itemsize == 2) {
735 // i16
736 bulkLoadElementType = signless
737 ? mlirIntegerTypeGet(context, 16)
738 : mlirIntegerTypeUnsignedGet(context, 16);
739 }
740 }
741 if (!bulkLoadElementType) {
742 throw std::invalid_argument(
743 std::string("unimplemented array format conversion from format: ") +
744 std::string(format));
745 }
746 }
747
748 MlirType shapedType;
749 if (mlirTypeIsAShaped(*bulkLoadElementType)) {
750 if (explicitShape) {
751 throw std::invalid_argument("Shape can only be specified explicitly "
752 "when the type is not a shaped type.");
753 }
754 shapedType = *bulkLoadElementType;
755 } else {
756 shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
757 *bulkLoadElementType, encodingAttr);
758 }
759 size_t rawBufferSize = view.len;
760 MlirAttribute attr =
761 mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
762 if (mlirAttributeIsNull(attr)) {
763 throw std::invalid_argument(
764 "DenseElementsAttr could not be constructed from the given buffer. "
765 "This may mean that the Python buffer layout does not match that "
766 "MLIR expected layout and is a bug.");
767 }
768 return PyDenseElementsAttribute(contextWrapper->getRef(), attr);
769 }
770
771 static PyDenseElementsAttribute getSplat(const PyType &shapedType,
772 PyAttribute &elementAttr) {
773 auto contextWrapper =
774 PyMlirContext::forContext(mlirTypeGetContext(shapedType));
775 if (!mlirAttributeIsAInteger(elementAttr) &&
776 !mlirAttributeIsAFloat(elementAttr)) {
777 std::string message = "Illegal element type for DenseElementsAttr: ";
778 message.append(py::repr(py::cast(elementAttr)));
779 throw py::value_error(message);
780 }
781 if (!mlirTypeIsAShaped(shapedType) ||
782 !mlirShapedTypeHasStaticShape(shapedType)) {
783 std::string message =
784 "Expected a static ShapedType for the shaped_type parameter: ";
785 message.append(py::repr(py::cast(shapedType)));
786 throw py::value_error(message);
787 }
788 MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
789 MlirType attrType = mlirAttributeGetType(elementAttr);
790 if (!mlirTypeEqual(shapedElementType, attrType)) {
791 std::string message =
792 "Shaped element type and attribute type must be equal: shaped=";
793 message.append(py::repr(py::cast(shapedType)));
794 message.append(s: ", element=");
795 message.append(py::repr(py::cast(elementAttr)));
796 throw py::value_error(message);
797 }
798
799 MlirAttribute elements =
800 mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
801 return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
802 }
803
804 intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
805
806 py::buffer_info accessBuffer() {
807 MlirType shapedType = mlirAttributeGetType(*this);
808 MlirType elementType = mlirShapedTypeGetElementType(shapedType);
809 std::string format;
810
811 if (mlirTypeIsAF32(elementType)) {
812 // f32
813 return bufferInfo<float>(shapedType);
814 }
815 if (mlirTypeIsAF64(elementType)) {
816 // f64
817 return bufferInfo<double>(shapedType);
818 }
819 if (mlirTypeIsAF16(elementType)) {
820 // f16
821 return bufferInfo<uint16_t>(shapedType, "e");
822 }
823 if (mlirTypeIsAIndex(elementType)) {
824 // Same as IndexType::kInternalStorageBitWidth
825 return bufferInfo<int64_t>(shapedType);
826 }
827 if (mlirTypeIsAInteger(elementType) &&
828 mlirIntegerTypeGetWidth(elementType) == 32) {
829 if (mlirIntegerTypeIsSignless(elementType) ||
830 mlirIntegerTypeIsSigned(elementType)) {
831 // i32
832 return bufferInfo<int32_t>(shapedType);
833 }
834 if (mlirIntegerTypeIsUnsigned(elementType)) {
835 // unsigned i32
836 return bufferInfo<uint32_t>(shapedType);
837 }
838 } else if (mlirTypeIsAInteger(elementType) &&
839 mlirIntegerTypeGetWidth(elementType) == 64) {
840 if (mlirIntegerTypeIsSignless(elementType) ||
841 mlirIntegerTypeIsSigned(elementType)) {
842 // i64
843 return bufferInfo<int64_t>(shapedType);
844 }
845 if (mlirIntegerTypeIsUnsigned(elementType)) {
846 // unsigned i64
847 return bufferInfo<uint64_t>(shapedType);
848 }
849 } else if (mlirTypeIsAInteger(elementType) &&
850 mlirIntegerTypeGetWidth(elementType) == 8) {
851 if (mlirIntegerTypeIsSignless(elementType) ||
852 mlirIntegerTypeIsSigned(elementType)) {
853 // i8
854 return bufferInfo<int8_t>(shapedType);
855 }
856 if (mlirIntegerTypeIsUnsigned(elementType)) {
857 // unsigned i8
858 return bufferInfo<uint8_t>(shapedType);
859 }
860 } else if (mlirTypeIsAInteger(elementType) &&
861 mlirIntegerTypeGetWidth(elementType) == 16) {
862 if (mlirIntegerTypeIsSignless(elementType) ||
863 mlirIntegerTypeIsSigned(elementType)) {
864 // i16
865 return bufferInfo<int16_t>(shapedType);
866 }
867 if (mlirIntegerTypeIsUnsigned(elementType)) {
868 // unsigned i16
869 return bufferInfo<uint16_t>(shapedType);
870 }
871 }
872
873 // TODO: Currently crashes the program.
874 // Reported as https://github.com/pybind/pybind11/issues/3336
875 throw std::invalid_argument(
876 "unsupported data type for conversion to Python buffer");
877 }
878
879 static void bindDerived(ClassTy &c) {
880 c.def("__len__", &PyDenseElementsAttribute::dunderLen)
881 .def_static("get", PyDenseElementsAttribute::getFromBuffer,
882 py::arg("array"), py::arg("signless") = true,
883 py::arg("type") = py::none(), py::arg("shape") = py::none(),
884 py::arg("context") = py::none(),
885 kDenseElementsAttrGetDocstring)
886 .def_static("get_splat", PyDenseElementsAttribute::getSplat,
887 py::arg("shaped_type"), py::arg("element_attr"),
888 "Gets a DenseElementsAttr where all values are the same")
889 .def_property_readonly("is_splat",
890 [](PyDenseElementsAttribute &self) -> bool {
891 return mlirDenseElementsAttrIsSplat(self);
892 })
893 .def("get_splat_value",
894 [](PyDenseElementsAttribute &self) {
895 if (!mlirDenseElementsAttrIsSplat(self))
896 throw py::value_error(
897 "get_splat_value called on a non-splat attribute");
898 return mlirDenseElementsAttrGetSplatValue(self);
899 })
900 .def_buffer(&PyDenseElementsAttribute::accessBuffer);
901 }
902
903private:
/// True when the leading character of a Python buffer format string is one of
/// the unsigned integer type codes I, B, H, L or Q. Empty formats never match.
static bool isUnsignedIntegerFormat(std::string_view format) {
  constexpr std::string_view kUnsignedCodes = "IBHLQ";
  return !format.empty() &&
         kUnsignedCodes.find(format.front()) != std::string_view::npos;
}
911
/// True when the leading character of a Python buffer format string is one of
/// the signed integer type codes i, b, h, l or q. Empty formats never match.
static bool isSignedIntegerFormat(std::string_view format) {
  constexpr std::string_view kSignedCodes = "ibhlq";
  return !format.empty() &&
         kSignedCodes.find(format.front()) != std::string_view::npos;
}
919
920 template <typename Type>
921 py::buffer_info bufferInfo(MlirType shapedType,
922 const char *explicitFormat = nullptr) {
923 intptr_t rank = mlirShapedTypeGetRank(shapedType);
924 // Prepare the data for the buffer_info.
925 // Buffer is configured for read-only access below.
926 Type *data = static_cast<Type *>(
927 const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
928 // Prepare the shape for the buffer_info.
929 SmallVector<intptr_t, 4> shape;
930 for (intptr_t i = 0; i < rank; ++i)
931 shape.push_back(Elt: mlirShapedTypeGetDimSize(shapedType, i));
932 // Prepare the strides for the buffer_info.
933 SmallVector<intptr_t, 4> strides;
934 if (mlirDenseElementsAttrIsSplat(*this)) {
935 // Splats are special, only the single value is stored.
936 strides.assign(NumElts: rank, Elt: 0);
937 } else {
938 for (intptr_t i = 1; i < rank; ++i) {
939 intptr_t strideFactor = 1;
940 for (intptr_t j = i; j < rank; ++j)
941 strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
942 strides.push_back(Elt: sizeof(Type) * strideFactor);
943 }
944 strides.push_back(Elt: sizeof(Type));
945 }
946 std::string format;
947 if (explicitFormat) {
948 format = explicitFormat;
949 } else {
950 format = py::format_descriptor<Type>::format();
951 }
952 return py::buffer_info(data, sizeof(Type), format, rank, shape, strides,
953 /*readonly=*/true);
954 }
955}; // namespace
956
957/// Refinement of the PyDenseElementsAttribute for attributes containing integer
958/// (and boolean) values. Supports element access.
959class PyDenseIntElementsAttribute
960 : public PyConcreteAttribute<PyDenseIntElementsAttribute,
961 PyDenseElementsAttribute> {
962public:
963 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
964 static constexpr const char *pyClassName = "DenseIntElementsAttr";
965 using PyConcreteAttribute::PyConcreteAttribute;
966
967 /// Returns the element at the given linear position. Asserts if the index is
968 /// out of range.
969 py::int_ dunderGetItem(intptr_t pos) {
970 if (pos < 0 || pos >= dunderLen()) {
971 throw py::index_error("attempt to access out of bounds element");
972 }
973
974 MlirType type = mlirAttributeGetType(*this);
975 type = mlirShapedTypeGetElementType(type);
976 assert(mlirTypeIsAInteger(type) &&
977 "expected integer element type in dense int elements attribute");
978 // Dispatch element extraction to an appropriate C function based on the
979 // elemental type of the attribute. py::int_ is implicitly constructible
980 // from any C++ integral type and handles bitwidth correctly.
981 // TODO: consider caching the type properties in the constructor to avoid
982 // querying them on each element access.
983 unsigned width = mlirIntegerTypeGetWidth(type);
984 bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
985 if (isUnsigned) {
986 if (width == 1) {
987 return mlirDenseElementsAttrGetBoolValue(*this, pos);
988 }
989 if (width == 8) {
990 return mlirDenseElementsAttrGetUInt8Value(*this, pos);
991 }
992 if (width == 16) {
993 return mlirDenseElementsAttrGetUInt16Value(*this, pos);
994 }
995 if (width == 32) {
996 return mlirDenseElementsAttrGetUInt32Value(*this, pos);
997 }
998 if (width == 64) {
999 return mlirDenseElementsAttrGetUInt64Value(*this, pos);
1000 }
1001 } else {
1002 if (width == 1) {
1003 return mlirDenseElementsAttrGetBoolValue(*this, pos);
1004 }
1005 if (width == 8) {
1006 return mlirDenseElementsAttrGetInt8Value(*this, pos);
1007 }
1008 if (width == 16) {
1009 return mlirDenseElementsAttrGetInt16Value(*this, pos);
1010 }
1011 if (width == 32) {
1012 return mlirDenseElementsAttrGetInt32Value(*this, pos);
1013 }
1014 if (width == 64) {
1015 return mlirDenseElementsAttrGetInt64Value(*this, pos);
1016 }
1017 }
1018 throw py::type_error("Unsupported integer type");
1019 }
1020
1021 static void bindDerived(ClassTy &c) {
1022 c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
1023 }
1024};
1025
1026class PyDenseResourceElementsAttribute
1027 : public PyConcreteAttribute<PyDenseResourceElementsAttribute> {
1028public:
1029 static constexpr IsAFunctionTy isaFunction =
1030 mlirAttributeIsADenseResourceElements;
1031 static constexpr const char *pyClassName = "DenseResourceElementsAttr";
1032 using PyConcreteAttribute::PyConcreteAttribute;
1033
1034 static PyDenseResourceElementsAttribute
1035 getFromBuffer(py::buffer buffer, const std::string &name, const PyType &type,
1036 std::optional<size_t> alignment, bool isMutable,
1037 DefaultingPyMlirContext contextWrapper) {
1038 if (!mlirTypeIsAShaped(type)) {
1039 throw std::invalid_argument(
1040 "Constructing a DenseResourceElementsAttr requires a ShapedType.");
1041 }
1042
1043 // Do not request any conversions as we must ensure to use caller
1044 // managed memory.
1045 int flags = PyBUF_STRIDES;
1046 std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>();
1047 if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
1048 throw py::error_already_set();
1049 }
1050
1051 // This scope releaser will only release if we haven't yet transferred
1052 // ownership.
1053 auto freeBuffer = llvm::make_scope_exit(F: [&]() {
1054 if (view)
1055 PyBuffer_Release(view.get());
1056 });
1057
1058 if (!PyBuffer_IsContiguous(view.get(), 'A')) {
1059 throw std::invalid_argument("Contiguous buffer is required.");
1060 }
1061
1062 // Infer alignment to be the stride of one element if not explicit.
1063 size_t inferredAlignment;
1064 if (alignment)
1065 inferredAlignment = *alignment;
1066 else
1067 inferredAlignment = view->strides[view->ndim - 1];
1068
1069 // The userData is a Py_buffer* that the deleter owns.
1070 auto deleter = [](void *userData, const void *data, size_t size,
1071 size_t align) {
1072 Py_buffer *ownedView = static_cast<Py_buffer *>(userData);
1073 PyBuffer_Release(ownedView);
1074 delete ownedView;
1075 };
1076
1077 size_t rawBufferSize = view->len;
1078 MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet(
1079 type, toMlirStringRef(name), view->buf, rawBufferSize,
1080 inferredAlignment, isMutable, deleter, static_cast<void *>(view.get()));
1081 if (mlirAttributeIsNull(attr)) {
1082 throw std::invalid_argument(
1083 "DenseResourceElementsAttr could not be constructed from the given "
1084 "buffer. "
1085 "This may mean that the Python buffer layout does not match that "
1086 "MLIR expected layout and is a bug.");
1087 }
1088 view.release();
1089 return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr);
1090 }
1091
1092 static void bindDerived(ClassTy &c) {
1093 c.def_static("get_from_buffer",
1094 PyDenseResourceElementsAttribute::getFromBuffer,
1095 py::arg("array"), py::arg("name"), py::arg("type"),
1096 py::arg("alignment") = py::none(),
1097 py::arg("is_mutable") = false, py::arg("context") = py::none(),
1098 kDenseResourceElementsAttrGetFromBufferDocstring);
1099 }
1100};
1101
1102class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
1103public:
1104 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary;
1105 static constexpr const char *pyClassName = "DictAttr";
1106 using PyConcreteAttribute::PyConcreteAttribute;
1107 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1108 mlirDictionaryAttrGetTypeID;
1109
1110 intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); }
1111
1112 bool dunderContains(const std::string &name) {
1113 return !mlirAttributeIsNull(
1114 mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name)));
1115 }
1116
1117 static void bindDerived(ClassTy &c) {
1118 c.def("__contains__", &PyDictAttribute::dunderContains);
1119 c.def("__len__", &PyDictAttribute::dunderLen);
1120 c.def_static(
1121 "get",
1122 [](py::dict attributes, DefaultingPyMlirContext context) {
1123 SmallVector<MlirNamedAttribute> mlirNamedAttributes;
1124 mlirNamedAttributes.reserve(attributes.size());
1125 for (auto &it : attributes) {
1126 auto &mlirAttr = it.second.cast<PyAttribute &>();
1127 auto name = it.first.cast<std::string>();
1128 mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1129 mlirIdentifierGet(mlirAttributeGetContext(mlirAttr),
1130 toMlirStringRef(name)),
1131 mlirAttr));
1132 }
1133 MlirAttribute attr =
1134 mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(),
1135 mlirNamedAttributes.data());
1136 return PyDictAttribute(context->getRef(), attr);
1137 },
1138 py::arg("value") = py::dict(), py::arg("context") = py::none(),
1139 "Gets an uniqued dict attribute");
1140 c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
1141 MlirAttribute attr =
1142 mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
1143 if (mlirAttributeIsNull(attr))
1144 throw py::key_error("attempt to access a non-existent attribute");
1145 return attr;
1146 });
1147 c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
1148 if (index < 0 || index >= self.dunderLen()) {
1149 throw py::index_error("attempt to access out of bounds attribute");
1150 }
1151 MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index);
1152 return PyNamedAttribute(
1153 namedAttr.attribute,
1154 std::string(mlirIdentifierStr(namedAttr.name).data));
1155 });
1156 }
1157};
1158
1159/// Refinement of PyDenseElementsAttribute for attributes containing
1160/// floating-point values. Supports element access.
1161class PyDenseFPElementsAttribute
1162 : public PyConcreteAttribute<PyDenseFPElementsAttribute,
1163 PyDenseElementsAttribute> {
1164public:
1165 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
1166 static constexpr const char *pyClassName = "DenseFPElementsAttr";
1167 using PyConcreteAttribute::PyConcreteAttribute;
1168
1169 py::float_ dunderGetItem(intptr_t pos) {
1170 if (pos < 0 || pos >= dunderLen()) {
1171 throw py::index_error("attempt to access out of bounds element");
1172 }
1173
1174 MlirType type = mlirAttributeGetType(*this);
1175 type = mlirShapedTypeGetElementType(type);
1176 // Dispatch element extraction to an appropriate C function based on the
1177 // elemental type of the attribute. py::float_ is implicitly constructible
1178 // from float and double.
1179 // TODO: consider caching the type properties in the constructor to avoid
1180 // querying them on each element access.
1181 if (mlirTypeIsAF32(type)) {
1182 return mlirDenseElementsAttrGetFloatValue(*this, pos);
1183 }
1184 if (mlirTypeIsAF64(type)) {
1185 return mlirDenseElementsAttrGetDoubleValue(*this, pos);
1186 }
1187 throw py::type_error("Unsupported floating-point type");
1188 }
1189
1190 static void bindDerived(ClassTy &c) {
1191 c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
1192 }
1193};
1194
1195class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
1196public:
1197 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
1198 static constexpr const char *pyClassName = "TypeAttr";
1199 using PyConcreteAttribute::PyConcreteAttribute;
1200 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1201 mlirTypeAttrGetTypeID;
1202
1203 static void bindDerived(ClassTy &c) {
1204 c.def_static(
1205 "get",
1206 [](PyType value, DefaultingPyMlirContext context) {
1207 MlirAttribute attr = mlirTypeAttrGet(value.get());
1208 return PyTypeAttribute(context->getRef(), attr);
1209 },
1210 py::arg("value"), py::arg("context") = py::none(),
1211 "Gets a uniqued Type attribute");
1212 c.def_property_readonly("value", [](PyTypeAttribute &self) {
1213 return mlirTypeAttrGetValue(self.get());
1214 });
1215 }
1216};
1217
1218/// Unit Attribute subclass. Unit attributes don't have values.
1219class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1220public:
1221 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1222 static constexpr const char *pyClassName = "UnitAttr";
1223 using PyConcreteAttribute::PyConcreteAttribute;
1224 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1225 mlirUnitAttrGetTypeID;
1226
1227 static void bindDerived(ClassTy &c) {
1228 c.def_static(
1229 "get",
1230 [](DefaultingPyMlirContext context) {
1231 return PyUnitAttribute(context->getRef(),
1232 mlirUnitAttrGet(context->get()));
1233 },
1234 py::arg("context") = py::none(), "Create a Unit attribute.");
1235 }
1236};
1237
1238/// Strided layout attribute subclass.
1239class PyStridedLayoutAttribute
1240 : public PyConcreteAttribute<PyStridedLayoutAttribute> {
1241public:
1242 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
1243 static constexpr const char *pyClassName = "StridedLayoutAttr";
1244 using PyConcreteAttribute::PyConcreteAttribute;
1245 static constexpr GetTypeIDFunctionTy getTypeIdFunction =
1246 mlirStridedLayoutAttrGetTypeID;
1247
1248 static void bindDerived(ClassTy &c) {
1249 c.def_static(
1250 "get",
1251 [](int64_t offset, const std::vector<int64_t> strides,
1252 DefaultingPyMlirContext ctx) {
1253 MlirAttribute attr = mlirStridedLayoutAttrGet(
1254 ctx->get(), offset, strides.size(), strides.data());
1255 return PyStridedLayoutAttribute(ctx->getRef(), attr);
1256 },
1257 py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(),
1258 "Gets a strided layout attribute.");
1259 c.def_static(
1260 "get_fully_dynamic",
1261 [](int64_t rank, DefaultingPyMlirContext ctx) {
1262 auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset();
1263 std::vector<int64_t> strides(rank);
1264 std::fill(strides.begin(), strides.end(), dynamic);
1265 MlirAttribute attr = mlirStridedLayoutAttrGet(
1266 ctx->get(), dynamic, strides.size(), strides.data());
1267 return PyStridedLayoutAttribute(ctx->getRef(), attr);
1268 },
1269 py::arg("rank"), py::arg("context") = py::none(),
1270 "Gets a strided layout attribute with dynamic offset and strides of a "
1271 "given rank.");
1272 c.def_property_readonly(
1273 "offset",
1274 [](PyStridedLayoutAttribute &self) {
1275 return mlirStridedLayoutAttrGetOffset(self);
1276 },
1277 "Returns the value of the float point attribute");
1278 c.def_property_readonly(
1279 "strides",
1280 [](PyStridedLayoutAttribute &self) {
1281 intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
1282 std::vector<int64_t> strides(size);
1283 for (intptr_t i = 0; i < size; i++) {
1284 strides[i] = mlirStridedLayoutAttrGetStride(self, i);
1285 }
1286 return strides;
1287 },
1288 "Returns the value of the float point attribute");
1289 }
1290};
1291
1292py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) {
1293 if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute))
1294 return py::cast(PyDenseBoolArrayAttribute(pyAttribute));
1295 if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute))
1296 return py::cast(PyDenseI8ArrayAttribute(pyAttribute));
1297 if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute))
1298 return py::cast(PyDenseI16ArrayAttribute(pyAttribute));
1299 if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute))
1300 return py::cast(PyDenseI32ArrayAttribute(pyAttribute));
1301 if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute))
1302 return py::cast(PyDenseI64ArrayAttribute(pyAttribute));
1303 if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute))
1304 return py::cast(PyDenseF32ArrayAttribute(pyAttribute));
1305 if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute))
1306 return py::cast(PyDenseF64ArrayAttribute(pyAttribute));
1307 std::string msg =
1308 std::string("Can't cast unknown element type DenseArrayAttr (") +
1309 std::string(py::repr(py::cast(pyAttribute))) + ")";
1310 throw py::cast_error(msg);
1311}
1312
1313py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) {
1314 if (PyDenseFPElementsAttribute::isaFunction(pyAttribute))
1315 return py::cast(PyDenseFPElementsAttribute(pyAttribute));
1316 if (PyDenseIntElementsAttribute::isaFunction(pyAttribute))
1317 return py::cast(PyDenseIntElementsAttribute(pyAttribute));
1318 std::string msg =
1319 std::string(
1320 "Can't cast unknown element type DenseIntOrFPElementsAttr (") +
1321 std::string(py::repr(py::cast(pyAttribute))) + ")";
1322 throw py::cast_error(msg);
1323}
1324
1325py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) {
1326 if (PyBoolAttribute::isaFunction(pyAttribute))
1327 return py::cast(PyBoolAttribute(pyAttribute));
1328 if (PyIntegerAttribute::isaFunction(pyAttribute))
1329 return py::cast(PyIntegerAttribute(pyAttribute));
1330 std::string msg =
1331 std::string("Can't cast unknown element type DenseArrayAttr (") +
1332 std::string(py::repr(py::cast(pyAttribute))) + ")";
1333 throw py::cast_error(msg);
1334}
1335
1336py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
1337 if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute))
1338 return py::cast(PyFlatSymbolRefAttribute(pyAttribute));
1339 if (PySymbolRefAttribute::isaFunction(pyAttribute))
1340 return py::cast(PySymbolRefAttribute(pyAttribute));
1341 std::string msg = std::string("Can't cast unknown SymbolRef attribute (") +
1342 std::string(py::repr(py::cast(pyAttribute))) + ")";
1343 throw py::cast_error(msg);
1344}
1345
1346} // namespace
1347
1348void mlir::python::populateIRAttributes(py::module &m) {
1349 PyAffineMapAttribute::bind(m);
1350
1351 PyDenseBoolArrayAttribute::bind(m);
1352 PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
1353 PyDenseI8ArrayAttribute::bind(m);
1354 PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
1355 PyDenseI16ArrayAttribute::bind(m);
1356 PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
1357 PyDenseI32ArrayAttribute::bind(m);
1358 PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
1359 PyDenseI64ArrayAttribute::bind(m);
1360 PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
1361 PyDenseF32ArrayAttribute::bind(m);
1362 PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
1363 PyDenseF64ArrayAttribute::bind(m);
1364 PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
1365 PyGlobals::get().registerTypeCaster(
1366 mlirDenseArrayAttrGetTypeID(),
1367 pybind11::cpp_function(denseArrayAttributeCaster));
1368
1369 PyArrayAttribute::bind(m);
1370 PyArrayAttribute::PyArrayAttributeIterator::bind(m);
1371 PyBoolAttribute::bind(m);
1372 PyDenseElementsAttribute::bind(m);
1373 PyDenseFPElementsAttribute::bind(m);
1374 PyDenseIntElementsAttribute::bind(m);
1375 PyGlobals::get().registerTypeCaster(
1376 mlirDenseIntOrFPElementsAttrGetTypeID(),
1377 pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
1378 PyDenseResourceElementsAttribute::bind(m);
1379
1380 PyDictAttribute::bind(m);
1381 PySymbolRefAttribute::bind(m);
1382 PyGlobals::get().registerTypeCaster(
1383 mlirSymbolRefAttrGetTypeID(),
1384 pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster));
1385
1386 PyFlatSymbolRefAttribute::bind(m);
1387 PyOpaqueAttribute::bind(m);
1388 PyFloatAttribute::bind(m);
1389 PyIntegerAttribute::bind(m);
1390 PyStringAttribute::bind(m);
1391 PyTypeAttribute::bind(m);
1392 PyGlobals::get().registerTypeCaster(
1393 mlirIntegerAttrGetTypeID(),
1394 pybind11::cpp_function(integerOrBoolAttributeCaster));
1395 PyUnitAttribute::bind(m);
1396
1397 PyStridedLayoutAttribute::bind(m);
1398}
1399

source code of mlir/lib/Bindings/Python/IRAttributes.cpp