1 | //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include <cstdint> |
10 | #include <optional> |
11 | #include <string> |
12 | #include <string_view> |
13 | #include <utility> |
14 | |
15 | #include "IRModule.h" |
16 | #include "NanobindUtils.h" |
17 | #include "mlir-c/BuiltinAttributes.h" |
18 | #include "mlir-c/BuiltinTypes.h" |
19 | #include "mlir/Bindings/Python/NanobindAdaptors.h" |
20 | #include "mlir/Bindings/Python/Nanobind.h" |
21 | #include "llvm/ADT/ScopeExit.h" |
22 | #include "llvm/Support/raw_ostream.h" |
23 | |
24 | namespace nb = nanobind; |
25 | using namespace nanobind::literals; |
26 | using namespace mlir; |
27 | using namespace mlir::python; |
28 | |
29 | using llvm::SmallVector; |
30 | |
31 | //------------------------------------------------------------------------------ |
32 | // Docstrings (trivial, non-duplicated docstrings are included inline). |
33 | //------------------------------------------------------------------------------ |
34 | |
/// Python docstring text describing how a DenseElementsAttr is built from a
/// Python buffer or array (type inference, bit-casting rules, i1 packing and
/// splats). The binding that attaches it is outside this view.
static const char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

* Integer types (except for i1): the buffer must be byte aligned to the
next byte boundary.
* Floating point types: Must be bit-castable to the given floating point
size.
* i1 (bool): Bit packed into 8bit words where the bit pattern matches a
row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
array: The array or buffer to convert.
signless: If inferring an appropriate MLIR type, use signless types for
integers (defaults True).
type: Skips inference of the MLIR element type and uses this instead. The
storage size must be consistent with the actual contents of the buffer.
shape: Overrides the shape of the buffer when constructing the MLIR
shaped type. This is needed when the physical and logical shape differ (as
for i1).
context: Explicit context, if not from context manager.

Returns:
DenseElementsAttr on success.

Raises:
ValueError: If the type of the buffer or array cannot be matched to an MLIR
type or if the buffer does not meet expectations.
)";
76 | |
/// Python docstring text for building a DenseElementsAttr from a Python list
/// of attributes. The binding that attaches it is outside this view.
///
/// The Raises section refers to the `type` argument documented under Args
/// (it previously named a non-existent `shaped_type` parameter).
static const char kDenseElementsAttrGetFromListDocstring[] =
    R"(Gets a DenseElementsAttr from a Python list of attributes.

Note that it can be expensive to construct attributes individually.
For a large number of elements, consider using a Python buffer or array instead.

Args:
attrs: A list of attributes.
type: The desired shape and type of the resulting DenseElementsAttr.
If not provided, the element type is determined based on the type
of the 0th attribute and the shape is `[len(attrs)]`.
context: Explicit context, if not from context manager.

Returns:
DenseElementsAttr on success.

Raises:
ValueError: If the type of the attributes does not match the type
specified by `type`.
)";
97 | |
/// Python docstring text for building a DenseResourceElementsAttr from a
/// Python buffer or array. The binding that attaches it is outside this view.
static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
    R"(Gets a DenseResourceElementsAttr from a Python buffer or array.

This function does minimal validation or massaging of the data, and it is
up to the caller to ensure that the buffer meets the characteristics
implied by the shape.

The backing buffer and any user objects will be retained for the lifetime
of the resource blob. This is typically bounded to the context but the
resource can have a shorter lifespan depending on how it is used in
subsequent processing.

Args:
buffer: The array or buffer to convert.
name: Name to provide to the resource (may be changed upon collision).
type: The explicit ShapedType to construct the attribute with.
context: Explicit context, if not from context manager.

Returns:
DenseResourceElementsAttr on success.

Raises:
ValueError: If the type of the buffer or array cannot be matched to an MLIR
type or if the buffer does not meet expectations.
)";
123 | |
124 | namespace { |
125 | |
126 | struct nb_buffer_info { |
127 | void *ptr = nullptr; |
128 | ssize_t itemsize = 0; |
129 | ssize_t size = 0; |
130 | const char *format = nullptr; |
131 | ssize_t ndim = 0; |
132 | SmallVector<ssize_t, 4> shape; |
133 | SmallVector<ssize_t, 4> strides; |
134 | bool readonly = false; |
135 | |
136 | nb_buffer_info( |
137 | void *ptr, ssize_t itemsize, const char *format, ssize_t ndim, |
138 | SmallVector<ssize_t, 4> shape_in, SmallVector<ssize_t, 4> strides_in, |
139 | bool readonly = false, |
140 | std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in = |
141 | std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(nullptr, nullptr)) |
142 | : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim), |
143 | shape(std::move(shape_in)), strides(std::move(strides_in)), |
144 | readonly(readonly), owned_view(std::move(owned_view_in)) { |
145 | size = 1; |
146 | for (ssize_t i = 0; i < ndim; ++i) { |
147 | size *= shape[i]; |
148 | } |
149 | } |
150 | |
151 | explicit nb_buffer_info(Py_buffer *view) |
152 | : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim, |
153 | {view->shape, view->shape + view->ndim}, |
154 | // TODO(phawkins): check for null strides |
155 | {view->strides, view->strides + view->ndim}, |
156 | view->readonly != 0, |
157 | std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>( |
158 | view, PyBuffer_Release)) {} |
159 | |
160 | nb_buffer_info(const nb_buffer_info &) = delete; |
161 | nb_buffer_info(nb_buffer_info &&) = default; |
162 | nb_buffer_info &operator=(const nb_buffer_info &) = delete; |
163 | nb_buffer_info &operator=(nb_buffer_info &&) = default; |
164 | |
165 | private: |
166 | std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view; |
167 | }; |
168 | |
169 | class nb_buffer : public nb::object { |
170 | NB_OBJECT_DEFAULT(nb_buffer, object, "buffer", PyObject_CheckBuffer); |
171 | |
172 | nb_buffer_info request() const { |
173 | int flags = PyBUF_STRIDES | PyBUF_FORMAT; |
174 | auto *view = new Py_buffer(); |
175 | if (PyObject_GetBuffer(ptr(), view, flags) != 0) { |
176 | delete view; |
177 | throw nb::python_error(); |
178 | } |
179 | return nb_buffer_info(view); |
180 | } |
181 | }; |
182 | |
/// Maps a C++ scalar type to its Python buffer-protocol (struct module)
/// format code via a static `format()` member. The primary template is left
/// empty, so only the specializations listed here can be queried.
template <typename T>
struct nb_format_descriptor {};

template <> struct nb_format_descriptor<bool> {
  static constexpr const char *format() { return "?"; }
};
template <> struct nb_format_descriptor<int8_t> {
  static constexpr const char *format() { return "b"; }
};
template <> struct nb_format_descriptor<uint8_t> {
  static constexpr const char *format() { return "B"; }
};
template <> struct nb_format_descriptor<int16_t> {
  static constexpr const char *format() { return "h"; }
};
template <> struct nb_format_descriptor<uint16_t> {
  static constexpr const char *format() { return "H"; }
};
template <> struct nb_format_descriptor<int32_t> {
  static constexpr const char *format() { return "i"; }
};
template <> struct nb_format_descriptor<uint32_t> {
  static constexpr const char *format() { return "I"; }
};
template <> struct nb_format_descriptor<int64_t> {
  static constexpr const char *format() { return "q"; }
};
template <> struct nb_format_descriptor<uint64_t> {
  static constexpr const char *format() { return "Q"; }
};
template <> struct nb_format_descriptor<float> {
  static constexpr const char *format() { return "f"; }
};
template <> struct nb_format_descriptor<double> {
  static constexpr const char *format() { return "d"; }
};
230 | |
231 | static MlirStringRef toMlirStringRef(const std::string &s) { |
232 | return mlirStringRefCreate(s.data(), s.size()); |
233 | } |
234 | |
235 | static MlirStringRef toMlirStringRef(const nb::bytes &s) { |
236 | return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size()); |
237 | } |
238 | |
239 | class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { |
240 | public: |
241 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; |
242 | static constexpr const char *pyClassName = "AffineMapAttr"; |
243 | using PyConcreteAttribute::PyConcreteAttribute; |
244 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
245 | mlirAffineMapAttrGetTypeID; |
246 | |
247 | static void bindDerived(ClassTy &c) { |
248 | c.def_static( |
249 | "get", |
250 | [](PyAffineMap &affineMap) { |
251 | MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); |
252 | return PyAffineMapAttribute(affineMap.getContext(), attr); |
253 | }, |
254 | nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); |
255 | c.def_prop_ro("value", mlirAffineMapAttrGetValue, |
256 | "Returns the value of the AffineMap attribute"); |
257 | } |
258 | }; |
259 | |
260 | class PyIntegerSetAttribute |
261 | : public PyConcreteAttribute<PyIntegerSetAttribute> { |
262 | public: |
263 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet; |
264 | static constexpr const char *pyClassName = "IntegerSetAttr"; |
265 | using PyConcreteAttribute::PyConcreteAttribute; |
266 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
267 | mlirIntegerSetAttrGetTypeID; |
268 | |
269 | static void bindDerived(ClassTy &c) { |
270 | c.def_static( |
271 | "get", |
272 | [](PyIntegerSet &integerSet) { |
273 | MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get()); |
274 | return PyIntegerSetAttribute(integerSet.getContext(), attr); |
275 | }, |
276 | nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet."); |
277 | } |
278 | }; |
279 | |
280 | template <typename T> |
281 | static T pyTryCast(nb::handle object) { |
282 | try { |
283 | return nb::cast<T>(object); |
284 | } catch (nb::cast_error &err) { |
285 | std::string msg = std::string("Invalid attribute when attempting to " |
286 | "create an ArrayAttribute (") + |
287 | err.what() + ")"; |
288 | throw std::runtime_error(msg.c_str()); |
289 | } catch (std::runtime_error &err) { |
290 | std::string msg = std::string("Invalid attribute (None?) when attempting " |
291 | "to create an ArrayAttribute (") + |
292 | err.what() + ")"; |
293 | throw std::runtime_error(msg.c_str()); |
294 | } |
295 | } |
296 | |
297 | /// A python-wrapped dense array attribute with an element type and a derived |
298 | /// implementation class. |
299 | template <typename EltTy, typename DerivedT> |
300 | class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> { |
301 | public: |
302 | using PyConcreteAttribute<DerivedT>::PyConcreteAttribute; |
303 | |
304 | /// Iterator over the integer elements of a dense array. |
305 | class PyDenseArrayIterator { |
306 | public: |
307 | PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {} |
308 | |
309 | /// Return a copy of the iterator. |
310 | PyDenseArrayIterator dunderIter() { return *this; } |
311 | |
312 | /// Return the next element. |
313 | EltTy dunderNext() { |
314 | // Throw if the index has reached the end. |
315 | if (nextIndex >= mlirDenseArrayGetNumElements(attr.get())) |
316 | throw nb::stop_iteration(); |
317 | return DerivedT::getElement(attr.get(), nextIndex++); |
318 | } |
319 | |
320 | /// Bind the iterator class. |
321 | static void bind(nb::module_ &m) { |
322 | nb::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName) |
323 | .def("__iter__", &PyDenseArrayIterator::dunderIter) |
324 | .def("__next__", &PyDenseArrayIterator::dunderNext); |
325 | } |
326 | |
327 | private: |
328 | /// The referenced dense array attribute. |
329 | PyAttribute attr; |
330 | /// The next index to read. |
331 | int nextIndex = 0; |
332 | }; |
333 | |
334 | /// Get the element at the given index. |
335 | EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); } |
336 | |
337 | /// Bind the attribute class. |
338 | static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) { |
339 | // Bind the constructor. |
340 | if constexpr (std::is_same_v<EltTy, bool>) { |
341 | c.def_static( |
342 | "get", |
343 | [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) { |
344 | std::vector<bool> values; |
345 | for (nb::handle py_value : py_values) { |
346 | int is_true = PyObject_IsTrue(py_value.ptr()); |
347 | if (is_true < 0) { |
348 | throw nb::python_error(); |
349 | } |
350 | values.push_back(is_true); |
351 | } |
352 | return getAttribute(values, ctx->getRef()); |
353 | }, |
354 | nb::arg("values"), nb::arg( "context").none() = nb::none(), |
355 | "Gets a uniqued dense array attribute"); |
356 | } else { |
357 | c.def_static( |
358 | "get", |
359 | [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) { |
360 | return getAttribute(values, ctx->getRef()); |
361 | }, |
362 | nb::arg("values"), nb::arg( "context").none() = nb::none(), |
363 | "Gets a uniqued dense array attribute"); |
364 | } |
365 | // Bind the array methods. |
366 | c.def("__getitem__", [](DerivedT &arr, intptr_t i) { |
367 | if (i >= mlirDenseArrayGetNumElements(arr)) |
368 | throw nb::index_error("DenseArray index out of range"); |
369 | return arr.getItem(i); |
370 | }); |
371 | c.def("__len__", [](const DerivedT &arr) { |
372 | return mlirDenseArrayGetNumElements(arr); |
373 | }); |
374 | c.def("__iter__", |
375 | [](const DerivedT &arr) { return PyDenseArrayIterator(arr); }); |
376 | c.def("__add__", [](DerivedT &arr, const nb::list &extras) { |
377 | std::vector<EltTy> values; |
378 | intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); |
379 | values.reserve(numOldElements + nb::len(extras)); |
380 | for (intptr_t i = 0; i < numOldElements; ++i) |
381 | values.push_back(arr.getItem(i)); |
382 | for (nb::handle attr : extras) |
383 | values.push_back(pyTryCast<EltTy>(attr)); |
384 | return getAttribute(values, ctx: arr.getContext()); |
385 | }); |
386 | } |
387 | |
388 | private: |
389 | static DerivedT getAttribute(const std::vector<EltTy> &values, |
390 | PyMlirContextRef ctx) { |
391 | if constexpr (std::is_same_v<EltTy, bool>) { |
392 | std::vector<int> intValues(values.begin(), values.end()); |
393 | MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(), |
394 | intValues.data()); |
395 | return DerivedT(ctx, attr); |
396 | } else { |
397 | MlirAttribute attr = |
398 | DerivedT::getAttribute(ctx->get(), values.size(), values.data()); |
399 | return DerivedT(ctx, attr); |
400 | } |
401 | } |
402 | }; |
403 | |
404 | /// Instantiate the python dense array classes. |
405 | struct PyDenseBoolArrayAttribute |
406 | : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> { |
407 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray; |
408 | static constexpr auto getAttribute = mlirDenseBoolArrayGet; |
409 | static constexpr auto getElement = mlirDenseBoolArrayGetElement; |
410 | static constexpr const char *pyClassName = "DenseBoolArrayAttr"; |
411 | static constexpr const char *pyIteratorName = "DenseBoolArrayIterator"; |
412 | using PyDenseArrayAttribute::PyDenseArrayAttribute; |
413 | }; |
414 | struct PyDenseI8ArrayAttribute |
415 | : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> { |
416 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array; |
417 | static constexpr auto getAttribute = mlirDenseI8ArrayGet; |
418 | static constexpr auto getElement = mlirDenseI8ArrayGetElement; |
419 | static constexpr const char *pyClassName = "DenseI8ArrayAttr"; |
420 | static constexpr const char *pyIteratorName = "DenseI8ArrayIterator"; |
421 | using PyDenseArrayAttribute::PyDenseArrayAttribute; |
422 | }; |
423 | struct PyDenseI16ArrayAttribute |
424 | : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> { |
425 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array; |
426 | static constexpr auto getAttribute = mlirDenseI16ArrayGet; |
427 | static constexpr auto getElement = mlirDenseI16ArrayGetElement; |
428 | static constexpr const char *pyClassName = "DenseI16ArrayAttr"; |
429 | static constexpr const char *pyIteratorName = "DenseI16ArrayIterator"; |
430 | using PyDenseArrayAttribute::PyDenseArrayAttribute; |
431 | }; |
432 | struct PyDenseI32ArrayAttribute |
433 | : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> { |
434 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array; |
435 | static constexpr auto getAttribute = mlirDenseI32ArrayGet; |
436 | static constexpr auto getElement = mlirDenseI32ArrayGetElement; |
437 | static constexpr const char *pyClassName = "DenseI32ArrayAttr"; |
438 | static constexpr const char *pyIteratorName = "DenseI32ArrayIterator"; |
439 | using PyDenseArrayAttribute::PyDenseArrayAttribute; |
440 | }; |
441 | struct PyDenseI64ArrayAttribute |
442 | : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> { |
443 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array; |
444 | static constexpr auto getAttribute = mlirDenseI64ArrayGet; |
445 | static constexpr auto getElement = mlirDenseI64ArrayGetElement; |
446 | static constexpr const char *pyClassName = "DenseI64ArrayAttr"; |
447 | static constexpr const char *pyIteratorName = "DenseI64ArrayIterator"; |
448 | using PyDenseArrayAttribute::PyDenseArrayAttribute; |
449 | }; |
450 | struct PyDenseF32ArrayAttribute |
451 | : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> { |
452 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array; |
453 | static constexpr auto getAttribute = mlirDenseF32ArrayGet; |
454 | static constexpr auto getElement = mlirDenseF32ArrayGetElement; |
455 | static constexpr const char *pyClassName = "DenseF32ArrayAttr"; |
456 | static constexpr const char *pyIteratorName = "DenseF32ArrayIterator"; |
457 | using PyDenseArrayAttribute::PyDenseArrayAttribute; |
458 | }; |
459 | struct PyDenseF64ArrayAttribute |
460 | : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> { |
461 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array; |
462 | static constexpr auto getAttribute = mlirDenseF64ArrayGet; |
463 | static constexpr auto getElement = mlirDenseF64ArrayGetElement; |
464 | static constexpr const char *pyClassName = "DenseF64ArrayAttr"; |
465 | static constexpr const char *pyIteratorName = "DenseF64ArrayIterator"; |
466 | using PyDenseArrayAttribute::PyDenseArrayAttribute; |
467 | }; |
468 | |
469 | class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { |
470 | public: |
471 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; |
472 | static constexpr const char *pyClassName = "ArrayAttr"; |
473 | using PyConcreteAttribute::PyConcreteAttribute; |
474 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
475 | mlirArrayAttrGetTypeID; |
476 | |
477 | class PyArrayAttributeIterator { |
478 | public: |
479 | PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {} |
480 | |
481 | PyArrayAttributeIterator &dunderIter() { return *this; } |
482 | |
483 | MlirAttribute dunderNext() { |
484 | // TODO: Throw is an inefficient way to stop iteration. |
485 | if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) |
486 | throw nb::stop_iteration(); |
487 | return mlirArrayAttrGetElement(attr.get(), nextIndex++); |
488 | } |
489 | |
490 | static void bind(nb::module_ &m) { |
491 | nb::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator") |
492 | .def("__iter__", &PyArrayAttributeIterator::dunderIter) |
493 | .def("__next__", &PyArrayAttributeIterator::dunderNext); |
494 | } |
495 | |
496 | private: |
497 | PyAttribute attr; |
498 | int nextIndex = 0; |
499 | }; |
500 | |
501 | MlirAttribute getItem(intptr_t i) { |
502 | return mlirArrayAttrGetElement(*this, i); |
503 | } |
504 | |
505 | static void bindDerived(ClassTy &c) { |
506 | c.def_static( |
507 | "get", |
508 | [](nb::list attributes, DefaultingPyMlirContext context) { |
509 | SmallVector<MlirAttribute> mlirAttributes; |
510 | mlirAttributes.reserve(nb::len(attributes)); |
511 | for (auto attribute : attributes) { |
512 | mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute)); |
513 | } |
514 | MlirAttribute attr = mlirArrayAttrGet( |
515 | context->get(), mlirAttributes.size(), mlirAttributes.data()); |
516 | return PyArrayAttribute(context->getRef(), attr); |
517 | }, |
518 | nb::arg("attributes"), nb::arg( "context").none() = nb::none(), |
519 | "Gets a uniqued Array attribute"); |
520 | c.def("__getitem__", |
521 | [](PyArrayAttribute &arr, intptr_t i) { |
522 | if (i >= mlirArrayAttrGetNumElements(arr)) |
523 | throw nb::index_error("ArrayAttribute index out of range"); |
524 | return arr.getItem(i); |
525 | }) |
526 | .def("__len__", |
527 | [](const PyArrayAttribute &arr) { |
528 | return mlirArrayAttrGetNumElements(arr); |
529 | }) |
530 | .def("__iter__", [](const PyArrayAttribute &arr) { |
531 | return PyArrayAttributeIterator(arr); |
532 | }); |
533 | c.def("__add__", [](PyArrayAttribute arr, nb::list extras) { |
534 | std::vector<MlirAttribute> attributes; |
535 | intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); |
536 | attributes.reserve(numOldElements + nb::len(extras)); |
537 | for (intptr_t i = 0; i < numOldElements; ++i) |
538 | attributes.push_back(arr.getItem(i)); |
539 | for (nb::handle attr : extras) |
540 | attributes.push_back(pyTryCast<PyAttribute>(attr)); |
541 | MlirAttribute arrayAttr = mlirArrayAttrGet( |
542 | arr.getContext()->get(), attributes.size(), attributes.data()); |
543 | return PyArrayAttribute(arr.getContext(), arrayAttr); |
544 | }); |
545 | } |
546 | }; |
547 | |
548 | /// Float Point Attribute subclass - FloatAttr. |
549 | class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { |
550 | public: |
551 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; |
552 | static constexpr const char *pyClassName = "FloatAttr"; |
553 | using PyConcreteAttribute::PyConcreteAttribute; |
554 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
555 | mlirFloatAttrGetTypeID; |
556 | |
557 | static void bindDerived(ClassTy &c) { |
558 | c.def_static( |
559 | "get", |
560 | [](PyType &type, double value, DefaultingPyLocation loc) { |
561 | PyMlirContext::ErrorCapture errors(loc->getContext()); |
562 | MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); |
563 | if (mlirAttributeIsNull(attr)) |
564 | throw MLIRError("Invalid attribute", errors.take()); |
565 | return PyFloatAttribute(type.getContext(), attr); |
566 | }, |
567 | nb::arg("type"), nb::arg( "value"), nb::arg( "loc").none() = nb::none(), |
568 | "Gets an uniqued float point attribute associated to a type"); |
569 | c.def_static( |
570 | "get_f32", |
571 | [](double value, DefaultingPyMlirContext context) { |
572 | MlirAttribute attr = mlirFloatAttrDoubleGet( |
573 | context->get(), mlirF32TypeGet(context->get()), value); |
574 | return PyFloatAttribute(context->getRef(), attr); |
575 | }, |
576 | nb::arg("value"), nb::arg( "context").none() = nb::none(), |
577 | "Gets an uniqued float point attribute associated to a f32 type"); |
578 | c.def_static( |
579 | "get_f64", |
580 | [](double value, DefaultingPyMlirContext context) { |
581 | MlirAttribute attr = mlirFloatAttrDoubleGet( |
582 | context->get(), mlirF64TypeGet(context->get()), value); |
583 | return PyFloatAttribute(context->getRef(), attr); |
584 | }, |
585 | nb::arg("value"), nb::arg( "context").none() = nb::none(), |
586 | "Gets an uniqued float point attribute associated to a f64 type"); |
587 | c.def_prop_ro("value", mlirFloatAttrGetValueDouble, |
588 | "Returns the value of the float attribute"); |
589 | c.def("__float__", mlirFloatAttrGetValueDouble, |
590 | "Converts the value of the float attribute to a Python float"); |
591 | } |
592 | }; |
593 | |
594 | /// Integer Attribute subclass - IntegerAttr. |
595 | class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { |
596 | public: |
597 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; |
598 | static constexpr const char *pyClassName = "IntegerAttr"; |
599 | using PyConcreteAttribute::PyConcreteAttribute; |
600 | |
601 | static void bindDerived(ClassTy &c) { |
602 | c.def_static( |
603 | "get", |
604 | [](PyType &type, int64_t value) { |
605 | MlirAttribute attr = mlirIntegerAttrGet(type, value); |
606 | return PyIntegerAttribute(type.getContext(), attr); |
607 | }, |
608 | nb::arg("type"), nb::arg( "value"), |
609 | "Gets an uniqued integer attribute associated to a type"); |
610 | c.def_prop_ro("value", toPyInt, |
611 | "Returns the value of the integer attribute"); |
612 | c.def("__int__", toPyInt, |
613 | "Converts the value of the integer attribute to a Python int"); |
614 | c.def_prop_ro_static("static_typeid", |
615 | [](nb::object & /*class*/) -> MlirTypeID { |
616 | return mlirIntegerAttrGetTypeID(); |
617 | }); |
618 | } |
619 | |
620 | private: |
621 | static int64_t toPyInt(PyIntegerAttribute &self) { |
622 | MlirType type = mlirAttributeGetType(self); |
623 | if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) |
624 | return mlirIntegerAttrGetValueInt(self); |
625 | if (mlirIntegerTypeIsSigned(type)) |
626 | return mlirIntegerAttrGetValueSInt(self); |
627 | return mlirIntegerAttrGetValueUInt(self); |
628 | } |
629 | }; |
630 | |
631 | /// Bool Attribute subclass - BoolAttr. |
632 | class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { |
633 | public: |
634 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; |
635 | static constexpr const char *pyClassName = "BoolAttr"; |
636 | using PyConcreteAttribute::PyConcreteAttribute; |
637 | |
638 | static void bindDerived(ClassTy &c) { |
639 | c.def_static( |
640 | "get", |
641 | [](bool value, DefaultingPyMlirContext context) { |
642 | MlirAttribute attr = mlirBoolAttrGet(context->get(), value); |
643 | return PyBoolAttribute(context->getRef(), attr); |
644 | }, |
645 | nb::arg("value"), nb::arg( "context").none() = nb::none(), |
646 | "Gets an uniqued bool attribute"); |
647 | c.def_prop_ro("value", mlirBoolAttrGetValue, |
648 | "Returns the value of the bool attribute"); |
649 | c.def("__bool__", mlirBoolAttrGetValue, |
650 | "Converts the value of the bool attribute to a Python bool"); |
651 | } |
652 | }; |
653 | |
654 | class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> { |
655 | public: |
656 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef; |
657 | static constexpr const char *pyClassName = "SymbolRefAttr"; |
658 | using PyConcreteAttribute::PyConcreteAttribute; |
659 | |
660 | static MlirAttribute fromList(const std::vector<std::string> &symbols, |
661 | PyMlirContext &context) { |
662 | if (symbols.empty()) |
663 | throw std::runtime_error("SymbolRefAttr must be composed of at least " |
664 | "one symbol."); |
665 | MlirStringRef rootSymbol = toMlirStringRef(symbols[0]); |
666 | SmallVector<MlirAttribute, 3> referenceAttrs; |
667 | for (size_t i = 1; i < symbols.size(); ++i) { |
668 | referenceAttrs.push_back( |
669 | mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i]))); |
670 | } |
671 | return mlirSymbolRefAttrGet(context.get(), rootSymbol, |
672 | referenceAttrs.size(), referenceAttrs.data()); |
673 | } |
674 | |
675 | static void bindDerived(ClassTy &c) { |
676 | c.def_static( |
677 | "get", |
678 | [](const std::vector<std::string> &symbols, |
679 | DefaultingPyMlirContext context) { |
680 | return PySymbolRefAttribute::fromList(symbols, context.resolve()); |
681 | }, |
682 | nb::arg("symbols"), nb::arg( "context").none() = nb::none(), |
683 | "Gets a uniqued SymbolRef attribute from a list of symbol names"); |
684 | c.def_prop_ro( |
685 | "value", |
686 | [](PySymbolRefAttribute &self) { |
687 | std::vector<std::string> symbols = { |
688 | unwrap(mlirSymbolRefAttrGetRootReference(self)).str()}; |
689 | for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self); |
690 | ++i) |
691 | symbols.push_back( |
692 | unwrap(mlirSymbolRefAttrGetRootReference( |
693 | mlirSymbolRefAttrGetNestedReference(self, i))) |
694 | .str()); |
695 | return symbols; |
696 | }, |
697 | "Returns the value of the SymbolRef attribute as a list[str]"); |
698 | } |
699 | }; |
700 | |
701 | class PyFlatSymbolRefAttribute |
702 | : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { |
703 | public: |
704 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; |
705 | static constexpr const char *pyClassName = "FlatSymbolRefAttr"; |
706 | using PyConcreteAttribute::PyConcreteAttribute; |
707 | |
708 | static void bindDerived(ClassTy &c) { |
709 | c.def_static( |
710 | "get", |
711 | [](std::string value, DefaultingPyMlirContext context) { |
712 | MlirAttribute attr = |
713 | mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); |
714 | return PyFlatSymbolRefAttribute(context->getRef(), attr); |
715 | }, |
716 | nb::arg("value"), nb::arg( "context").none() = nb::none(), |
717 | "Gets a uniqued FlatSymbolRef attribute"); |
718 | c.def_prop_ro( |
719 | "value", |
720 | [](PyFlatSymbolRefAttribute &self) { |
721 | MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); |
722 | return nb::str(stringRef.data, stringRef.length); |
723 | }, |
724 | "Returns the value of the FlatSymbolRef attribute as a string"); |
725 | } |
726 | }; |
727 | |
728 | class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> { |
729 | public: |
730 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; |
731 | static constexpr const char *pyClassName = "OpaqueAttr"; |
732 | using PyConcreteAttribute::PyConcreteAttribute; |
733 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
734 | mlirOpaqueAttrGetTypeID; |
735 | |
736 | static void bindDerived(ClassTy &c) { |
737 | c.def_static( |
738 | "get", |
739 | [](std::string dialectNamespace, nb_buffer buffer, PyType &type, |
740 | DefaultingPyMlirContext context) { |
741 | const nb_buffer_info bufferInfo = buffer.request(); |
742 | intptr_t bufferSize = bufferInfo.size; |
743 | MlirAttribute attr = mlirOpaqueAttrGet( |
744 | context->get(), toMlirStringRef(dialectNamespace), bufferSize, |
745 | static_cast<char *>(bufferInfo.ptr), type); |
746 | return PyOpaqueAttribute(context->getRef(), attr); |
747 | }, |
748 | nb::arg("dialect_namespace"), nb::arg( "buffer"), nb::arg( "type"), |
749 | nb::arg("context").none() = nb::none(), "Gets an Opaque attribute."); |
750 | c.def_prop_ro( |
751 | "dialect_namespace", |
752 | [](PyOpaqueAttribute &self) { |
753 | MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); |
754 | return nb::str(stringRef.data, stringRef.length); |
755 | }, |
756 | "Returns the dialect namespace for the Opaque attribute as a string"); |
757 | c.def_prop_ro( |
758 | "data", |
759 | [](PyOpaqueAttribute &self) { |
760 | MlirStringRef stringRef = mlirOpaqueAttrGetData(self); |
761 | return nb::bytes(stringRef.data, stringRef.length); |
762 | }, |
763 | "Returns the data for the Opaqued attributes as `bytes`"); |
764 | } |
765 | }; |
766 | |
767 | class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { |
768 | public: |
769 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; |
770 | static constexpr const char *pyClassName = "StringAttr"; |
771 | using PyConcreteAttribute::PyConcreteAttribute; |
772 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
773 | mlirStringAttrGetTypeID; |
774 | |
775 | static void bindDerived(ClassTy &c) { |
776 | c.def_static( |
777 | "get", |
778 | [](std::string value, DefaultingPyMlirContext context) { |
779 | MlirAttribute attr = |
780 | mlirStringAttrGet(context->get(), toMlirStringRef(value)); |
781 | return PyStringAttribute(context->getRef(), attr); |
782 | }, |
783 | nb::arg("value"), nb::arg( "context").none() = nb::none(), |
784 | "Gets a uniqued string attribute"); |
785 | c.def_static( |
786 | "get", |
787 | [](nb::bytes value, DefaultingPyMlirContext context) { |
788 | MlirAttribute attr = |
789 | mlirStringAttrGet(context->get(), toMlirStringRef(value)); |
790 | return PyStringAttribute(context->getRef(), attr); |
791 | }, |
792 | nb::arg("value"), nb::arg( "context").none() = nb::none(), |
793 | "Gets a uniqued string attribute"); |
794 | c.def_static( |
795 | "get_typed", |
796 | [](PyType &type, std::string value) { |
797 | MlirAttribute attr = |
798 | mlirStringAttrTypedGet(type, toMlirStringRef(value)); |
799 | return PyStringAttribute(type.getContext(), attr); |
800 | }, |
801 | nb::arg("type"), nb::arg( "value"), |
802 | "Gets a uniqued string attribute associated to a type"); |
803 | c.def_prop_ro( |
804 | "value", |
805 | [](PyStringAttribute &self) { |
806 | MlirStringRef stringRef = mlirStringAttrGetValue(self); |
807 | return nb::str(stringRef.data, stringRef.length); |
808 | }, |
809 | "Returns the value of the string attribute"); |
810 | c.def_prop_ro( |
811 | "value_bytes", |
812 | [](PyStringAttribute &self) { |
813 | MlirStringRef stringRef = mlirStringAttrGetValue(self); |
814 | return nb::bytes(stringRef.data, stringRef.length); |
815 | }, |
816 | "Returns the value of the string attribute as `bytes`"); |
817 | } |
818 | }; |
819 | |
820 | // TODO: Support construction of string elements. |
821 | class PyDenseElementsAttribute |
822 | : public PyConcreteAttribute<PyDenseElementsAttribute> { |
823 | public: |
824 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; |
825 | static constexpr const char *pyClassName = "DenseElementsAttr"; |
826 | using PyConcreteAttribute::PyConcreteAttribute; |
827 | |
828 | static PyDenseElementsAttribute |
829 | getFromList(nb::list attributes, std::optional<PyType> explicitType, |
830 | DefaultingPyMlirContext contextWrapper) { |
831 | const size_t numAttributes = nb::len(attributes); |
832 | if (numAttributes == 0) |
833 | throw nb::value_error("Attributes list must be non-empty."); |
834 | |
835 | MlirType shapedType; |
836 | if (explicitType) { |
837 | if ((!mlirTypeIsAShaped(*explicitType) || |
838 | !mlirShapedTypeHasStaticShape(*explicitType))) { |
839 | |
840 | std::string message; |
841 | llvm::raw_string_ostream os(message); |
842 | os << "Expected a static ShapedType for the shaped_type parameter: " |
843 | << nb::cast<std::string>(nb::repr(nb::cast(*explicitType))); |
844 | throw nb::value_error(message.c_str()); |
845 | } |
846 | shapedType = *explicitType; |
847 | } else { |
848 | SmallVector<int64_t> shape = {static_cast<int64_t>(numAttributes)}; |
849 | shapedType = mlirRankedTensorTypeGet( |
850 | shape.size(), shape.data(), |
851 | mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])), |
852 | mlirAttributeGetNull()); |
853 | } |
854 | |
855 | SmallVector<MlirAttribute> mlirAttributes; |
856 | mlirAttributes.reserve(numAttributes); |
857 | for (const nb::handle &attribute : attributes) { |
858 | MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute); |
859 | MlirType attrType = mlirAttributeGetType(mlirAttribute); |
860 | mlirAttributes.push_back(mlirAttribute); |
861 | |
862 | if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) { |
863 | std::string message; |
864 | llvm::raw_string_ostream os(message); |
865 | os << "All attributes must be of the same type and match " |
866 | << "the type parameter: expected=" |
867 | << nb::cast<std::string>(nb::repr(nb::cast(shapedType))) |
868 | << ", but got=" |
869 | << nb::cast<std::string>(nb::repr(nb::cast(attrType))); |
870 | throw nb::value_error(message.c_str()); |
871 | } |
872 | } |
873 | |
874 | MlirAttribute elements = mlirDenseElementsAttrGet( |
875 | shapedType, mlirAttributes.size(), mlirAttributes.data()); |
876 | |
877 | return PyDenseElementsAttribute(contextWrapper->getRef(), elements); |
878 | } |
879 | |
880 | static PyDenseElementsAttribute |
881 | getFromBuffer(nb_buffer array, bool signless, |
882 | std::optional<PyType> explicitType, |
883 | std::optional<std::vector<int64_t>> explicitShape, |
884 | DefaultingPyMlirContext contextWrapper) { |
885 | // Request a contiguous view. In exotic cases, this will cause a copy. |
886 | int flags = PyBUF_ND; |
887 | if (!explicitType) { |
888 | flags |= PyBUF_FORMAT; |
889 | } |
890 | Py_buffer view; |
891 | if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) { |
892 | throw nb::python_error(); |
893 | } |
894 | auto freeBuffer = llvm::make_scope_exit(F: [&]() { PyBuffer_Release(&view); }); |
895 | |
896 | MlirContext context = contextWrapper->get(); |
897 | MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType, |
898 | explicitShape, context); |
899 | if (mlirAttributeIsNull(attr)) { |
900 | throw std::invalid_argument( |
901 | "DenseElementsAttr could not be constructed from the given buffer. " |
902 | "This may mean that the Python buffer layout does not match that " |
903 | "MLIR expected layout and is a bug."); |
904 | } |
905 | return PyDenseElementsAttribute(contextWrapper->getRef(), attr); |
906 | } |
907 | |
908 | static PyDenseElementsAttribute getSplat(const PyType &shapedType, |
909 | PyAttribute &elementAttr) { |
910 | auto contextWrapper = |
911 | PyMlirContext::forContext(mlirTypeGetContext(shapedType)); |
912 | if (!mlirAttributeIsAInteger(elementAttr) && |
913 | !mlirAttributeIsAFloat(elementAttr)) { |
914 | std::string message = "Illegal element type for DenseElementsAttr: "; |
915 | message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr)))); |
916 | throw nb::value_error(message.c_str()); |
917 | } |
918 | if (!mlirTypeIsAShaped(shapedType) || |
919 | !mlirShapedTypeHasStaticShape(shapedType)) { |
920 | std::string message = |
921 | "Expected a static ShapedType for the shaped_type parameter: "; |
922 | message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType)))); |
923 | throw nb::value_error(message.c_str()); |
924 | } |
925 | MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); |
926 | MlirType attrType = mlirAttributeGetType(elementAttr); |
927 | if (!mlirTypeEqual(shapedElementType, attrType)) { |
928 | std::string message = |
929 | "Shaped element type and attribute type must be equal: shaped="; |
930 | message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType)))); |
931 | message.append(s: ", element="); |
932 | message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr)))); |
933 | throw nb::value_error(message.c_str()); |
934 | } |
935 | |
936 | MlirAttribute elements = |
937 | mlirDenseElementsAttrSplatGet(shapedType, elementAttr); |
938 | return PyDenseElementsAttribute(contextWrapper->getRef(), elements); |
939 | } |
940 | |
941 | intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } |
942 | |
943 | std::unique_ptr<nb_buffer_info> accessBuffer() { |
944 | MlirType shapedType = mlirAttributeGetType(*this); |
945 | MlirType elementType = mlirShapedTypeGetElementType(shapedType); |
946 | std::string format; |
947 | |
948 | if (mlirTypeIsAF32(elementType)) { |
949 | // f32 |
950 | return bufferInfo<float>(shapedType); |
951 | } |
952 | if (mlirTypeIsAF64(elementType)) { |
953 | // f64 |
954 | return bufferInfo<double>(shapedType); |
955 | } |
956 | if (mlirTypeIsAF16(elementType)) { |
957 | // f16 |
958 | return bufferInfo<uint16_t>(shapedType, "e"); |
959 | } |
960 | if (mlirTypeIsAIndex(elementType)) { |
961 | // Same as IndexType::kInternalStorageBitWidth |
962 | return bufferInfo<int64_t>(shapedType); |
963 | } |
964 | if (mlirTypeIsAInteger(elementType) && |
965 | mlirIntegerTypeGetWidth(elementType) == 32) { |
966 | if (mlirIntegerTypeIsSignless(elementType) || |
967 | mlirIntegerTypeIsSigned(elementType)) { |
968 | // i32 |
969 | return bufferInfo<int32_t>(shapedType); |
970 | } |
971 | if (mlirIntegerTypeIsUnsigned(elementType)) { |
972 | // unsigned i32 |
973 | return bufferInfo<uint32_t>(shapedType); |
974 | } |
975 | } else if (mlirTypeIsAInteger(elementType) && |
976 | mlirIntegerTypeGetWidth(elementType) == 64) { |
977 | if (mlirIntegerTypeIsSignless(elementType) || |
978 | mlirIntegerTypeIsSigned(elementType)) { |
979 | // i64 |
980 | return bufferInfo<int64_t>(shapedType); |
981 | } |
982 | if (mlirIntegerTypeIsUnsigned(elementType)) { |
983 | // unsigned i64 |
984 | return bufferInfo<uint64_t>(shapedType); |
985 | } |
986 | } else if (mlirTypeIsAInteger(elementType) && |
987 | mlirIntegerTypeGetWidth(elementType) == 8) { |
988 | if (mlirIntegerTypeIsSignless(elementType) || |
989 | mlirIntegerTypeIsSigned(elementType)) { |
990 | // i8 |
991 | return bufferInfo<int8_t>(shapedType); |
992 | } |
993 | if (mlirIntegerTypeIsUnsigned(elementType)) { |
994 | // unsigned i8 |
995 | return bufferInfo<uint8_t>(shapedType); |
996 | } |
997 | } else if (mlirTypeIsAInteger(elementType) && |
998 | mlirIntegerTypeGetWidth(elementType) == 16) { |
999 | if (mlirIntegerTypeIsSignless(elementType) || |
1000 | mlirIntegerTypeIsSigned(elementType)) { |
1001 | // i16 |
1002 | return bufferInfo<int16_t>(shapedType); |
1003 | } |
1004 | if (mlirIntegerTypeIsUnsigned(elementType)) { |
1005 | // unsigned i16 |
1006 | return bufferInfo<uint16_t>(shapedType); |
1007 | } |
1008 | } else if (mlirTypeIsAInteger(elementType) && |
1009 | mlirIntegerTypeGetWidth(elementType) == 1) { |
1010 | // i1 / bool |
1011 | // We can not send the buffer directly back to Python, because the i1 |
1012 | // values are bitpacked within MLIR. We call numpy's unpackbits function |
1013 | // to convert the bytes. |
1014 | return getBooleanBufferFromBitpackedAttribute(); |
1015 | } |
1016 | |
1017 | // TODO: Currently crashes the program. |
1018 | // Reported as https://github.com/pybind/pybind11/issues/3336 |
1019 | throw std::invalid_argument( |
1020 | "unsupported data type for conversion to Python buffer"); |
1021 | } |
1022 | |
1023 | static void bindDerived(ClassTy &c) { |
1024 | #if PY_VERSION_HEX < 0x03090000 |
1025 | PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(c.ptr()); |
1026 | tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer; |
1027 | tp->tp_as_buffer->bf_releasebuffer = |
1028 | PyDenseElementsAttribute::bf_releasebuffer; |
1029 | #endif |
1030 | c.def("__len__", &PyDenseElementsAttribute::dunderLen) |
1031 | .def_static("get", PyDenseElementsAttribute::getFromBuffer, |
1032 | nb::arg("array"), nb::arg( "signless") = true, |
1033 | nb::arg("type").none() = nb::none(), |
1034 | nb::arg("shape").none() = nb::none(), |
1035 | nb::arg("context").none() = nb::none(), |
1036 | kDenseElementsAttrGetDocstring) |
1037 | .def_static("get", PyDenseElementsAttribute::getFromList, |
1038 | nb::arg("attrs"), nb::arg( "type").none() = nb::none(), |
1039 | nb::arg("context").none() = nb::none(), |
1040 | kDenseElementsAttrGetFromListDocstring) |
1041 | .def_static("get_splat", PyDenseElementsAttribute::getSplat, |
1042 | nb::arg("shaped_type"), nb::arg( "element_attr"), |
1043 | "Gets a DenseElementsAttr where all values are the same") |
1044 | .def_prop_ro("is_splat", |
1045 | [](PyDenseElementsAttribute &self) -> bool { |
1046 | return mlirDenseElementsAttrIsSplat(self); |
1047 | }) |
1048 | .def("get_splat_value", [](PyDenseElementsAttribute &self) { |
1049 | if (!mlirDenseElementsAttrIsSplat(self)) |
1050 | throw nb::value_error( |
1051 | "get_splat_value called on a non-splat attribute"); |
1052 | return mlirDenseElementsAttrGetSplatValue(self); |
1053 | }); |
1054 | } |
1055 | |
1056 | static PyType_Slot slots[]; |
1057 | |
1058 | private: |
1059 | static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags); |
1060 | static void bf_releasebuffer(PyObject *, Py_buffer *buffer); |
1061 | |
1062 | static bool isUnsignedIntegerFormat(std::string_view format) { |
1063 | if (format.empty()) |
1064 | return false; |
1065 | char code = format[0]; |
1066 | return code == 'I' || code == 'B' || code == 'H' || code == 'L' || |
1067 | code == 'Q'; |
1068 | } |
1069 | |
1070 | static bool isSignedIntegerFormat(std::string_view format) { |
1071 | if (format.empty()) |
1072 | return false; |
1073 | char code = format[0]; |
1074 | return code == 'i' || code == 'b' || code == 'h' || code == 'l' || |
1075 | code == 'q'; |
1076 | } |
1077 | |
1078 | static MlirType |
1079 | getShapedType(std::optional<MlirType> bulkLoadElementType, |
1080 | std::optional<std::vector<int64_t>> explicitShape, |
1081 | Py_buffer &view) { |
1082 | SmallVector<int64_t> shape; |
1083 | if (explicitShape) { |
1084 | shape.append(in_start: explicitShape->begin(), in_end: explicitShape->end()); |
1085 | } else { |
1086 | shape.append(view.shape, view.shape + view.ndim); |
1087 | } |
1088 | |
1089 | if (mlirTypeIsAShaped(*bulkLoadElementType)) { |
1090 | if (explicitShape) { |
1091 | throw std::invalid_argument("Shape can only be specified explicitly " |
1092 | "when the type is not a shaped type."); |
1093 | } |
1094 | return *bulkLoadElementType; |
1095 | } else { |
1096 | MlirAttribute encodingAttr = mlirAttributeGetNull(); |
1097 | return mlirRankedTensorTypeGet(shape.size(), shape.data(), |
1098 | *bulkLoadElementType, encodingAttr); |
1099 | } |
1100 | } |
1101 | |
1102 | static MlirAttribute getAttributeFromBuffer( |
1103 | Py_buffer &view, bool signless, std::optional<PyType> explicitType, |
1104 | std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) { |
1105 | // Detect format codes that are suitable for bulk loading. This includes |
1106 | // all byte aligned integer and floating point types up to 8 bytes. |
1107 | // Notably, this excludes exotics types which do not have a direct |
1108 | // representation in the buffer protocol (i.e. complex, etc). |
1109 | std::optional<MlirType> bulkLoadElementType; |
1110 | if (explicitType) { |
1111 | bulkLoadElementType = *explicitType; |
1112 | } else { |
1113 | std::string_view format(view.format); |
1114 | if (format == "f") { |
1115 | // f32 |
1116 | assert(view.itemsize == 4 && "mismatched array itemsize"); |
1117 | bulkLoadElementType = mlirF32TypeGet(context); |
1118 | } else if (format == "d") { |
1119 | // f64 |
1120 | assert(view.itemsize == 8 && "mismatched array itemsize"); |
1121 | bulkLoadElementType = mlirF64TypeGet(context); |
1122 | } else if (format == "e") { |
1123 | // f16 |
1124 | assert(view.itemsize == 2 && "mismatched array itemsize"); |
1125 | bulkLoadElementType = mlirF16TypeGet(context); |
1126 | } else if (format == "?") { |
1127 | // i1 |
1128 | // The i1 type needs to be bit-packed, so we will handle it seperately |
1129 | return getBitpackedAttributeFromBooleanBuffer(view, explicitShape, |
1130 | context); |
1131 | } else if (isSignedIntegerFormat(format)) { |
1132 | if (view.itemsize == 4) { |
1133 | // i32 |
1134 | bulkLoadElementType = signless |
1135 | ? mlirIntegerTypeGet(context, 32) |
1136 | : mlirIntegerTypeSignedGet(context, 32); |
1137 | } else if (view.itemsize == 8) { |
1138 | // i64 |
1139 | bulkLoadElementType = signless |
1140 | ? mlirIntegerTypeGet(context, 64) |
1141 | : mlirIntegerTypeSignedGet(context, 64); |
1142 | } else if (view.itemsize == 1) { |
1143 | // i8 |
1144 | bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) |
1145 | : mlirIntegerTypeSignedGet(context, 8); |
1146 | } else if (view.itemsize == 2) { |
1147 | // i16 |
1148 | bulkLoadElementType = signless |
1149 | ? mlirIntegerTypeGet(context, 16) |
1150 | : mlirIntegerTypeSignedGet(context, 16); |
1151 | } |
1152 | } else if (isUnsignedIntegerFormat(format)) { |
1153 | if (view.itemsize == 4) { |
1154 | // unsigned i32 |
1155 | bulkLoadElementType = signless |
1156 | ? mlirIntegerTypeGet(context, 32) |
1157 | : mlirIntegerTypeUnsignedGet(context, 32); |
1158 | } else if (view.itemsize == 8) { |
1159 | // unsigned i64 |
1160 | bulkLoadElementType = signless |
1161 | ? mlirIntegerTypeGet(context, 64) |
1162 | : mlirIntegerTypeUnsignedGet(context, 64); |
1163 | } else if (view.itemsize == 1) { |
1164 | // i8 |
1165 | bulkLoadElementType = signless |
1166 | ? mlirIntegerTypeGet(context, 8) |
1167 | : mlirIntegerTypeUnsignedGet(context, 8); |
1168 | } else if (view.itemsize == 2) { |
1169 | // i16 |
1170 | bulkLoadElementType = signless |
1171 | ? mlirIntegerTypeGet(context, 16) |
1172 | : mlirIntegerTypeUnsignedGet(context, 16); |
1173 | } |
1174 | } |
1175 | if (!bulkLoadElementType) { |
1176 | throw std::invalid_argument( |
1177 | std::string("unimplemented array format conversion from format: ") + |
1178 | std::string(format)); |
1179 | } |
1180 | } |
1181 | |
1182 | MlirType type = getShapedType(bulkLoadElementType, explicitShape, view); |
1183 | return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf); |
1184 | } |
1185 | |
1186 | // There is a complication for boolean numpy arrays, as numpy represents |
1187 | // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 |
1188 | // booleans per byte. |
1189 | static MlirAttribute getBitpackedAttributeFromBooleanBuffer( |
1190 | Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape, |
1191 | MlirContext &context) { |
1192 | if (llvm::endianness::native != llvm::endianness::little) { |
1193 | // Given we have no good way of testing the behavior on big-endian |
1194 | // systems we will throw |
1195 | throw nb::type_error("Constructing a bit-packed MLIR attribute is " |
1196 | "unsupported on big-endian systems"); |
1197 | } |
1198 | nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> unpackedArray( |
1199 | /*data=*/static_cast<uint8_t *>(view.buf), |
1200 | /*shape=*/{static_cast<size_t>(view.len)}); |
1201 | |
1202 | nb::module_ numpy = nb::module_::import_("numpy"); |
1203 | nb::object packbitsFunc = numpy.attr("packbits"); |
1204 | nb::object packedBooleans = |
1205 | packbitsFunc(nb::cast(unpackedArray), "bitorder"_a= "little"); |
1206 | nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request(); |
1207 | |
1208 | MlirType bitpackedType = |
1209 | getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view); |
1210 | assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8"); |
1211 | // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of |
1212 | // packedBooleans, hence the MlirAttribute will remain valid even when |
1213 | // packedBooleans get reclaimed by the end of the function. |
1214 | return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size, |
1215 | pythonBuffer.ptr); |
1216 | } |
1217 | |
1218 | // This does the opposite transformation of |
1219 | // `getBitpackedAttributeFromBooleanBuffer` |
1220 | std::unique_ptr<nb_buffer_info> getBooleanBufferFromBitpackedAttribute() { |
1221 | if (llvm::endianness::native != llvm::endianness::little) { |
1222 | // Given we have no good way of testing the behavior on big-endian |
1223 | // systems we will throw |
1224 | throw nb::type_error("Constructing a numpy array from a MLIR attribute " |
1225 | "is unsupported on big-endian systems"); |
1226 | } |
1227 | |
1228 | int64_t numBooleans = mlirElementsAttrGetNumElements(*this); |
1229 | int64_t numBitpackedBytes = llvm::divideCeil(Numerator: numBooleans, Denominator: 8); |
1230 | uint8_t *bitpackedData = static_cast<uint8_t *>( |
1231 | const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); |
1232 | nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> packedArray( |
1233 | /*data=*/bitpackedData, |
1234 | /*shape=*/{static_cast<size_t>(numBitpackedBytes)}); |
1235 | |
1236 | nb::module_ numpy = nb::module_::import_("numpy"); |
1237 | nb::object unpackbitsFunc = numpy.attr("unpackbits"); |
1238 | nb::object equalFunc = numpy.attr("equal"); |
1239 | nb::object reshapeFunc = numpy.attr("reshape"); |
1240 | nb::object unpackedBooleans = |
1241 | unpackbitsFunc(nb::cast(packedArray), "bitorder"_a= "little"); |
1242 | |
1243 | // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array. |
1244 | // We need to: |
1245 | // 1. Slice away the padded bits |
1246 | // 2. Make the boolean array have the correct shape |
1247 | // 3. Convert the array to a boolean array |
1248 | unpackedBooleans = unpackedBooleans[nb::slice( |
1249 | nb::int_(0), nb::int_(numBooleans), nb::int_(1))]; |
1250 | unpackedBooleans = equalFunc(unpackedBooleans, 1); |
1251 | |
1252 | MlirType shapedType = mlirAttributeGetType(*this); |
1253 | intptr_t rank = mlirShapedTypeGetRank(shapedType); |
1254 | std::vector<intptr_t> shape(rank); |
1255 | for (intptr_t i = 0; i < rank; ++i) { |
1256 | shape[i] = mlirShapedTypeGetDimSize(shapedType, i); |
1257 | } |
1258 | unpackedBooleans = reshapeFunc(unpackedBooleans, shape); |
1259 | |
1260 | // Make sure the returned nb::buffer_view claims ownership of the data in |
1261 | // `pythonBuffer` so it remains valid when Python reads it |
1262 | nb_buffer pythonBuffer = nb::cast<nb_buffer>(unpackedBooleans); |
1263 | return std::make_unique<nb_buffer_info>(args: pythonBuffer.request()); |
1264 | } |
1265 | |
1266 | template <typename Type> |
1267 | std::unique_ptr<nb_buffer_info> |
1268 | bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) { |
1269 | intptr_t rank = mlirShapedTypeGetRank(shapedType); |
1270 | // Prepare the data for the buffer_info. |
1271 | // Buffer is configured for read-only access below. |
1272 | Type *data = static_cast<Type *>( |
1273 | const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); |
1274 | // Prepare the shape for the buffer_info. |
1275 | SmallVector<intptr_t, 4> shape; |
1276 | for (intptr_t i = 0; i < rank; ++i) |
1277 | shape.push_back(Elt: mlirShapedTypeGetDimSize(shapedType, i)); |
1278 | // Prepare the strides for the buffer_info. |
1279 | SmallVector<intptr_t, 4> strides; |
1280 | if (mlirDenseElementsAttrIsSplat(*this)) { |
1281 | // Splats are special, only the single value is stored. |
1282 | strides.assign(NumElts: rank, Elt: 0); |
1283 | } else { |
1284 | for (intptr_t i = 1; i < rank; ++i) { |
1285 | intptr_t strideFactor = 1; |
1286 | for (intptr_t j = i; j < rank; ++j) |
1287 | strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); |
1288 | strides.push_back(Elt: sizeof(Type) * strideFactor); |
1289 | } |
1290 | strides.push_back(Elt: sizeof(Type)); |
1291 | } |
1292 | const char *format; |
1293 | if (explicitFormat) { |
1294 | format = explicitFormat; |
1295 | } else { |
1296 | format = nb_format_descriptor<Type>::format(); |
1297 | } |
1298 | return std::make_unique<nb_buffer_info>( |
1299 | data, sizeof(Type), format, rank, std::move(shape), std::move(strides), |
1300 | /*readonly=*/true); |
1301 | } |
1302 | }; // namespace |
1303 | |
1304 | PyType_Slot PyDenseElementsAttribute::slots[] = { |
1305 | // Python 3.8 doesn't allow setting the buffer protocol slots from a type spec. |
1306 | #if PY_VERSION_HEX >= 0x03090000 |
1307 | {Py_bf_getbuffer, |
1308 | reinterpret_cast<void *>(PyDenseElementsAttribute::bf_getbuffer)}, |
1309 | {Py_bf_releasebuffer, |
1310 | reinterpret_cast<void *>(PyDenseElementsAttribute::bf_releasebuffer)}, |
1311 | #endif |
1312 | {0, nullptr}, |
1313 | }; |
1314 | |
1315 | /*static*/ int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj, |
1316 | Py_buffer *view, |
1317 | int flags) { |
1318 | view->obj = nullptr; |
1319 | std::unique_ptr<nb_buffer_info> info; |
1320 | try { |
1321 | auto *attr = nb::cast<PyDenseElementsAttribute *>(nb::handle(obj)); |
1322 | info = attr->accessBuffer(); |
1323 | } catch (nb::python_error &e) { |
1324 | e.restore(); |
1325 | nb::chain_error(PyExc_BufferError, "Error converting attribute to buffer"); |
1326 | return -1; |
1327 | } |
1328 | view->obj = obj; |
1329 | view->ndim = 1; |
1330 | view->buf = info->ptr; |
1331 | view->itemsize = info->itemsize; |
1332 | view->len = info->itemsize; |
1333 | for (auto s : info->shape) { |
1334 | view->len *= s; |
1335 | } |
1336 | view->readonly = info->readonly; |
1337 | if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { |
1338 | view->format = const_cast<char *>(info->format); |
1339 | } |
1340 | if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { |
1341 | view->ndim = static_cast<int>(info->ndim); |
1342 | view->strides = info->strides.data(); |
1343 | view->shape = info->shape.data(); |
1344 | } |
1345 | view->suboffsets = nullptr; |
1346 | view->internal = info.release(); |
1347 | Py_INCREF(obj); |
1348 | return 0; |
1349 | } |
1350 | |
1351 | /*static*/ void PyDenseElementsAttribute::bf_releasebuffer(PyObject *, |
1352 | Py_buffer *view) { |
1353 | delete reinterpret_cast<nb_buffer_info *>(view->internal); |
1354 | } |
1355 | |
1356 | /// Refinement of the PyDenseElementsAttribute for attributes containing |
1357 | /// integer (and boolean) values. Supports element access. |
1358 | class PyDenseIntElementsAttribute |
1359 | : public PyConcreteAttribute<PyDenseIntElementsAttribute, |
1360 | PyDenseElementsAttribute> { |
1361 | public: |
1362 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; |
1363 | static constexpr const char *pyClassName = "DenseIntElementsAttr"; |
1364 | using PyConcreteAttribute::PyConcreteAttribute; |
1365 | |
1366 | /// Returns the element at the given linear position. Asserts if the index |
1367 | /// is out of range. |
1368 | nb::object dunderGetItem(intptr_t pos) { |
1369 | if (pos < 0 || pos >= dunderLen()) { |
1370 | throw nb::index_error("attempt to access out of bounds element"); |
1371 | } |
1372 | |
1373 | MlirType type = mlirAttributeGetType(*this); |
1374 | type = mlirShapedTypeGetElementType(type); |
1375 | // Index type can also appear as a DenseIntElementsAttr and therefore can be |
1376 | // casted to integer. |
1377 | assert(mlirTypeIsAInteger(type) || |
1378 | mlirTypeIsAIndex(type) && "expected integer/index element type in " |
1379 | "dense int elements attribute"); |
1380 | // Dispatch element extraction to an appropriate C function based on the |
1381 | // elemental type of the attribute. nb::int_ is implicitly constructible |
1382 | // from any C++ integral type and handles bitwidth correctly. |
1383 | // TODO: consider caching the type properties in the constructor to avoid |
1384 | // querying them on each element access. |
1385 | if (mlirTypeIsAIndex(type)) { |
1386 | return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos)); |
1387 | } |
1388 | unsigned width = mlirIntegerTypeGetWidth(type); |
1389 | bool isUnsigned = mlirIntegerTypeIsUnsigned(type); |
1390 | if (isUnsigned) { |
1391 | if (width == 1) { |
1392 | return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); |
1393 | } |
1394 | if (width == 8) { |
1395 | return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos)); |
1396 | } |
1397 | if (width == 16) { |
1398 | return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos)); |
1399 | } |
1400 | if (width == 32) { |
1401 | return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos)); |
1402 | } |
1403 | if (width == 64) { |
1404 | return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos)); |
1405 | } |
1406 | } else { |
1407 | if (width == 1) { |
1408 | return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); |
1409 | } |
1410 | if (width == 8) { |
1411 | return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos)); |
1412 | } |
1413 | if (width == 16) { |
1414 | return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos)); |
1415 | } |
1416 | if (width == 32) { |
1417 | return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos)); |
1418 | } |
1419 | if (width == 64) { |
1420 | return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos)); |
1421 | } |
1422 | } |
1423 | throw nb::type_error("Unsupported integer type"); |
1424 | } |
1425 | |
1426 | static void bindDerived(ClassTy &c) { |
1427 | c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); |
1428 | } |
1429 | }; |
1430 | |
1431 | class PyDenseResourceElementsAttribute |
1432 | : public PyConcreteAttribute<PyDenseResourceElementsAttribute> { |
1433 | public: |
1434 | static constexpr IsAFunctionTy isaFunction = |
1435 | mlirAttributeIsADenseResourceElements; |
1436 | static constexpr const char *pyClassName = "DenseResourceElementsAttr"; |
1437 | using PyConcreteAttribute::PyConcreteAttribute; |
1438 | |
1439 | static PyDenseResourceElementsAttribute |
1440 | getFromBuffer(nb_buffer buffer, const std::string &name, const PyType &type, |
1441 | std::optional<size_t> alignment, bool isMutable, |
1442 | DefaultingPyMlirContext contextWrapper) { |
1443 | if (!mlirTypeIsAShaped(type)) { |
1444 | throw std::invalid_argument( |
1445 | "Constructing a DenseResourceElementsAttr requires a ShapedType."); |
1446 | } |
1447 | |
1448 | // Do not request any conversions as we must ensure to use caller |
1449 | // managed memory. |
1450 | int flags = PyBUF_STRIDES; |
1451 | std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>(); |
1452 | if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) { |
1453 | throw nb::python_error(); |
1454 | } |
1455 | |
1456 | // This scope releaser will only release if we haven't yet transferred |
1457 | // ownership. |
1458 | auto freeBuffer = llvm::make_scope_exit(F: [&]() { |
1459 | if (view) |
1460 | PyBuffer_Release(view.get()); |
1461 | }); |
1462 | |
1463 | if (!PyBuffer_IsContiguous(view.get(), 'A')) { |
1464 | throw std::invalid_argument("Contiguous buffer is required."); |
1465 | } |
1466 | |
1467 | // Infer alignment to be the stride of one element if not explicit. |
1468 | size_t inferredAlignment; |
1469 | if (alignment) |
1470 | inferredAlignment = *alignment; |
1471 | else |
1472 | inferredAlignment = view->strides[view->ndim - 1]; |
1473 | |
1474 | // The userData is a Py_buffer* that the deleter owns. |
1475 | auto deleter = [](void *userData, const void *data, size_t size, |
1476 | size_t align) { |
1477 | if (!Py_IsInitialized()) |
1478 | Py_Initialize(); |
1479 | Py_buffer *ownedView = static_cast<Py_buffer *>(userData); |
1480 | nb::gil_scoped_acquire gil; |
1481 | PyBuffer_Release(ownedView); |
1482 | delete ownedView; |
1483 | }; |
1484 | |
1485 | size_t rawBufferSize = view->len; |
1486 | MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet( |
1487 | type, toMlirStringRef(name), view->buf, rawBufferSize, |
1488 | inferredAlignment, isMutable, deleter, static_cast<void *>(view.get())); |
1489 | if (mlirAttributeIsNull(attr)) { |
1490 | throw std::invalid_argument( |
1491 | "DenseResourceElementsAttr could not be constructed from the given " |
1492 | "buffer. " |
1493 | "This may mean that the Python buffer layout does not match that " |
1494 | "MLIR expected layout and is a bug."); |
1495 | } |
1496 | view.release(); |
1497 | return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr); |
1498 | } |
1499 | |
1500 | static void bindDerived(ClassTy &c) { |
1501 | c.def_static( |
1502 | "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer, |
1503 | nb::arg("array"), nb::arg( "name"), nb::arg( "type"), |
1504 | nb::arg("alignment").none() = nb::none(), nb::arg( "is_mutable") = false, |
1505 | nb::arg("context").none() = nb::none(), |
1506 | kDenseResourceElementsAttrGetFromBufferDocstring); |
1507 | } |
1508 | }; |
1509 | |
1510 | class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { |
1511 | public: |
1512 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; |
1513 | static constexpr const char *pyClassName = "DictAttr"; |
1514 | using PyConcreteAttribute::PyConcreteAttribute; |
1515 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
1516 | mlirDictionaryAttrGetTypeID; |
1517 | |
1518 | intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } |
1519 | |
1520 | bool dunderContains(const std::string &name) { |
1521 | return !mlirAttributeIsNull( |
1522 | mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name))); |
1523 | } |
1524 | |
1525 | static void bindDerived(ClassTy &c) { |
1526 | c.def("__contains__", &PyDictAttribute::dunderContains); |
1527 | c.def("__len__", &PyDictAttribute::dunderLen); |
1528 | c.def_static( |
1529 | "get", |
1530 | [](nb::dict attributes, DefaultingPyMlirContext context) { |
1531 | SmallVector<MlirNamedAttribute> mlirNamedAttributes; |
1532 | mlirNamedAttributes.reserve(attributes.size()); |
1533 | for (std::pair<nb::handle, nb::handle> it : attributes) { |
1534 | auto &mlirAttr = nb::cast<PyAttribute &>(it.second); |
1535 | auto name = nb::cast<std::string>(it.first); |
1536 | mlirNamedAttributes.push_back(mlirNamedAttributeGet( |
1537 | mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), |
1538 | toMlirStringRef(name)), |
1539 | mlirAttr)); |
1540 | } |
1541 | MlirAttribute attr = |
1542 | mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), |
1543 | mlirNamedAttributes.data()); |
1544 | return PyDictAttribute(context->getRef(), attr); |
1545 | }, |
1546 | nb::arg("value") = nb::dict(), nb::arg( "context").none() = nb::none(), |
1547 | "Gets an uniqued dict attribute"); |
1548 | c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { |
1549 | MlirAttribute attr = |
1550 | mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); |
1551 | if (mlirAttributeIsNull(attr)) |
1552 | throw nb::key_error("attempt to access a non-existent attribute"); |
1553 | return attr; |
1554 | }); |
1555 | c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { |
1556 | if (index < 0 || index >= self.dunderLen()) { |
1557 | throw nb::index_error("attempt to access out of bounds attribute"); |
1558 | } |
1559 | MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); |
1560 | return PyNamedAttribute( |
1561 | namedAttr.attribute, |
1562 | std::string(mlirIdentifierStr(namedAttr.name).data)); |
1563 | }); |
1564 | } |
1565 | }; |
1566 | |
1567 | /// Refinement of PyDenseElementsAttribute for attributes containing |
1568 | /// floating-point values. Supports element access. |
1569 | class PyDenseFPElementsAttribute |
1570 | : public PyConcreteAttribute<PyDenseFPElementsAttribute, |
1571 | PyDenseElementsAttribute> { |
1572 | public: |
1573 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; |
1574 | static constexpr const char *pyClassName = "DenseFPElementsAttr"; |
1575 | using PyConcreteAttribute::PyConcreteAttribute; |
1576 | |
1577 | nb::float_ dunderGetItem(intptr_t pos) { |
1578 | if (pos < 0 || pos >= dunderLen()) { |
1579 | throw nb::index_error("attempt to access out of bounds element"); |
1580 | } |
1581 | |
1582 | MlirType type = mlirAttributeGetType(*this); |
1583 | type = mlirShapedTypeGetElementType(type); |
1584 | // Dispatch element extraction to an appropriate C function based on the |
1585 | // elemental type of the attribute. nb::float_ is implicitly constructible |
1586 | // from float and double. |
1587 | // TODO: consider caching the type properties in the constructor to avoid |
1588 | // querying them on each element access. |
1589 | if (mlirTypeIsAF32(type)) { |
1590 | return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos)); |
1591 | } |
1592 | if (mlirTypeIsAF64(type)) { |
1593 | return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos)); |
1594 | } |
1595 | throw nb::type_error("Unsupported floating-point type"); |
1596 | } |
1597 | |
1598 | static void bindDerived(ClassTy &c) { |
1599 | c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); |
1600 | } |
1601 | }; |
1602 | |
1603 | class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { |
1604 | public: |
1605 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; |
1606 | static constexpr const char *pyClassName = "TypeAttr"; |
1607 | using PyConcreteAttribute::PyConcreteAttribute; |
1608 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
1609 | mlirTypeAttrGetTypeID; |
1610 | |
1611 | static void bindDerived(ClassTy &c) { |
1612 | c.def_static( |
1613 | "get", |
1614 | [](PyType value, DefaultingPyMlirContext context) { |
1615 | MlirAttribute attr = mlirTypeAttrGet(value.get()); |
1616 | return PyTypeAttribute(context->getRef(), attr); |
1617 | }, |
1618 | nb::arg("value"), nb::arg( "context").none() = nb::none(), |
1619 | "Gets a uniqued Type attribute"); |
1620 | c.def_prop_ro("value", [](PyTypeAttribute &self) { |
1621 | return mlirTypeAttrGetValue(self.get()); |
1622 | }); |
1623 | } |
1624 | }; |
1625 | |
1626 | /// Unit Attribute subclass. Unit attributes don't have values. |
1627 | class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { |
1628 | public: |
1629 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; |
1630 | static constexpr const char *pyClassName = "UnitAttr"; |
1631 | using PyConcreteAttribute::PyConcreteAttribute; |
1632 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
1633 | mlirUnitAttrGetTypeID; |
1634 | |
1635 | static void bindDerived(ClassTy &c) { |
1636 | c.def_static( |
1637 | "get", |
1638 | [](DefaultingPyMlirContext context) { |
1639 | return PyUnitAttribute(context->getRef(), |
1640 | mlirUnitAttrGet(context->get())); |
1641 | }, |
1642 | nb::arg("context").none() = nb::none(), "Create a Unit attribute."); |
1643 | } |
1644 | }; |
1645 | |
1646 | /// Strided layout attribute subclass. |
1647 | class PyStridedLayoutAttribute |
1648 | : public PyConcreteAttribute<PyStridedLayoutAttribute> { |
1649 | public: |
1650 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout; |
1651 | static constexpr const char *pyClassName = "StridedLayoutAttr"; |
1652 | using PyConcreteAttribute::PyConcreteAttribute; |
1653 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
1654 | mlirStridedLayoutAttrGetTypeID; |
1655 | |
1656 | static void bindDerived(ClassTy &c) { |
1657 | c.def_static( |
1658 | "get", |
1659 | [](int64_t offset, const std::vector<int64_t> strides, |
1660 | DefaultingPyMlirContext ctx) { |
1661 | MlirAttribute attr = mlirStridedLayoutAttrGet( |
1662 | ctx->get(), offset, strides.size(), strides.data()); |
1663 | return PyStridedLayoutAttribute(ctx->getRef(), attr); |
1664 | }, |
1665 | nb::arg("offset"), nb::arg( "strides"), |
1666 | nb::arg("context").none() = nb::none(), |
1667 | "Gets a strided layout attribute."); |
1668 | c.def_static( |
1669 | "get_fully_dynamic", |
1670 | [](int64_t rank, DefaultingPyMlirContext ctx) { |
1671 | auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset(); |
1672 | std::vector<int64_t> strides(rank); |
1673 | std::fill(strides.begin(), strides.end(), dynamic); |
1674 | MlirAttribute attr = mlirStridedLayoutAttrGet( |
1675 | ctx->get(), dynamic, strides.size(), strides.data()); |
1676 | return PyStridedLayoutAttribute(ctx->getRef(), attr); |
1677 | }, |
1678 | nb::arg("rank"), nb::arg( "context").none() = nb::none(), |
1679 | "Gets a strided layout attribute with dynamic offset and strides of " |
1680 | "a " |
1681 | "given rank."); |
1682 | c.def_prop_ro( |
1683 | "offset", |
1684 | [](PyStridedLayoutAttribute &self) { |
1685 | return mlirStridedLayoutAttrGetOffset(self); |
1686 | }, |
1687 | "Returns the value of the float point attribute"); |
1688 | c.def_prop_ro( |
1689 | "strides", |
1690 | [](PyStridedLayoutAttribute &self) { |
1691 | intptr_t size = mlirStridedLayoutAttrGetNumStrides(self); |
1692 | std::vector<int64_t> strides(size); |
1693 | for (intptr_t i = 0; i < size; i++) { |
1694 | strides[i] = mlirStridedLayoutAttrGetStride(self, i); |
1695 | } |
1696 | return strides; |
1697 | }, |
1698 | "Returns the value of the float point attribute"); |
1699 | } |
1700 | }; |
1701 | |
1702 | nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { |
1703 | if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute)) |
1704 | return nb::cast(PyDenseBoolArrayAttribute(pyAttribute)); |
1705 | if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute)) |
1706 | return nb::cast(PyDenseI8ArrayAttribute(pyAttribute)); |
1707 | if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute)) |
1708 | return nb::cast(PyDenseI16ArrayAttribute(pyAttribute)); |
1709 | if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute)) |
1710 | return nb::cast(PyDenseI32ArrayAttribute(pyAttribute)); |
1711 | if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute)) |
1712 | return nb::cast(PyDenseI64ArrayAttribute(pyAttribute)); |
1713 | if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute)) |
1714 | return nb::cast(PyDenseF32ArrayAttribute(pyAttribute)); |
1715 | if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute)) |
1716 | return nb::cast(PyDenseF64ArrayAttribute(pyAttribute)); |
1717 | std::string msg = |
1718 | std::string("Can't cast unknown element type DenseArrayAttr (") + |
1719 | nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")"; |
1720 | throw nb::type_error(msg.c_str()); |
1721 | } |
1722 | |
1723 | nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { |
1724 | if (PyDenseFPElementsAttribute::isaFunction(pyAttribute)) |
1725 | return nb::cast(PyDenseFPElementsAttribute(pyAttribute)); |
1726 | if (PyDenseIntElementsAttribute::isaFunction(pyAttribute)) |
1727 | return nb::cast(PyDenseIntElementsAttribute(pyAttribute)); |
1728 | std::string msg = |
1729 | std::string( |
1730 | "Can't cast unknown element type DenseIntOrFPElementsAttr (") + |
1731 | nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")"; |
1732 | throw nb::type_error(msg.c_str()); |
1733 | } |
1734 | |
1735 | nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { |
1736 | if (PyBoolAttribute::isaFunction(pyAttribute)) |
1737 | return nb::cast(PyBoolAttribute(pyAttribute)); |
1738 | if (PyIntegerAttribute::isaFunction(pyAttribute)) |
1739 | return nb::cast(PyIntegerAttribute(pyAttribute)); |
1740 | std::string msg = |
1741 | std::string("Can't cast unknown element type DenseArrayAttr (") + |
1742 | nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")"; |
1743 | throw nb::type_error(msg.c_str()); |
1744 | } |
1745 | |
1746 | nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { |
1747 | if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute)) |
1748 | return nb::cast(PyFlatSymbolRefAttribute(pyAttribute)); |
1749 | if (PySymbolRefAttribute::isaFunction(pyAttribute)) |
1750 | return nb::cast(PySymbolRefAttribute(pyAttribute)); |
1751 | std::string msg = std::string("Can't cast unknown SymbolRef attribute (") + |
1752 | nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + |
1753 | ")"; |
1754 | throw nb::type_error(msg.c_str()); |
1755 | } |
1756 | |
1757 | } // namespace |
1758 | |
1759 | void mlir::python::populateIRAttributes(nb::module_ &m) { |
1760 | PyAffineMapAttribute::bind(m); |
1761 | PyDenseBoolArrayAttribute::bind(m); |
1762 | PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m); |
1763 | PyDenseI8ArrayAttribute::bind(m); |
1764 | PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m); |
1765 | PyDenseI16ArrayAttribute::bind(m); |
1766 | PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m); |
1767 | PyDenseI32ArrayAttribute::bind(m); |
1768 | PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m); |
1769 | PyDenseI64ArrayAttribute::bind(m); |
1770 | PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m); |
1771 | PyDenseF32ArrayAttribute::bind(m); |
1772 | PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m); |
1773 | PyDenseF64ArrayAttribute::bind(m); |
1774 | PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m); |
1775 | PyGlobals::get().registerTypeCaster( |
1776 | mlirDenseArrayAttrGetTypeID(), |
1777 | nb::cast<nb::callable>(nb::cpp_function(denseArrayAttributeCaster))); |
1778 | |
1779 | PyArrayAttribute::bind(m); |
1780 | PyArrayAttribute::PyArrayAttributeIterator::bind(m); |
1781 | PyBoolAttribute::bind(m); |
1782 | PyDenseElementsAttribute::bind(m, PyDenseElementsAttribute::slots); |
1783 | PyDenseFPElementsAttribute::bind(m); |
1784 | PyDenseIntElementsAttribute::bind(m); |
1785 | PyGlobals::get().registerTypeCaster( |
1786 | mlirDenseIntOrFPElementsAttrGetTypeID(), |
1787 | nb::cast<nb::callable>( |
1788 | nb::cpp_function(denseIntOrFPElementsAttributeCaster))); |
1789 | PyDenseResourceElementsAttribute::bind(m); |
1790 | |
1791 | PyDictAttribute::bind(m); |
1792 | PySymbolRefAttribute::bind(m); |
1793 | PyGlobals::get().registerTypeCaster( |
1794 | mlirSymbolRefAttrGetTypeID(), |
1795 | nb::cast<nb::callable>( |
1796 | nb::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster))); |
1797 | |
1798 | PyFlatSymbolRefAttribute::bind(m); |
1799 | PyOpaqueAttribute::bind(m); |
1800 | PyFloatAttribute::bind(m); |
1801 | PyIntegerAttribute::bind(m); |
1802 | PyIntegerSetAttribute::bind(m); |
1803 | PyStringAttribute::bind(m); |
1804 | PyTypeAttribute::bind(m); |
1805 | PyGlobals::get().registerTypeCaster( |
1806 | mlirIntegerAttrGetTypeID(), |
1807 | nb::cast<nb::callable>(nb::cpp_function(integerOrBoolAttributeCaster))); |
1808 | PyUnitAttribute::bind(m); |
1809 | |
1810 | PyStridedLayoutAttribute::bind(m); |
1811 | } |
1812 |
Definitions
- kDenseElementsAttrGetDocstring
- kDenseElementsAttrGetFromListDocstring
- kDenseResourceElementsAttrGetFromBufferDocstring
- nb_buffer_info
- nb_buffer_info
- nb_buffer_info
- nb_buffer_info
- nb_buffer_info
- operator=
- operator=
- nb_buffer
- request
- nb_format_descriptor
- nb_format_descriptor
- format
- nb_format_descriptor
- format
- nb_format_descriptor
- format
- nb_format_descriptor
- format
- nb_format_descriptor
- format
- nb_format_descriptor
- format
- nb_format_descriptor
- format
- nb_format_descriptor
- format
- nb_format_descriptor
- format
- nb_format_descriptor
- format
- nb_format_descriptor
- format
- toMlirStringRef
- toMlirStringRef
- PyAffineMapAttribute
- isaFunction
- pyClassName
- getTypeIdFunction
- bindDerived
- PyIntegerSetAttribute
- isaFunction
- pyClassName
- getTypeIdFunction
- bindDerived
- pyTryCast
- PyDenseArrayAttribute
- PyDenseArrayIterator
- PyDenseArrayIterator
- dunderIter
- dunderNext
- bind
- getItem
- bindDerived
- getAttribute
- PyDenseBoolArrayAttribute
- isaFunction
- getAttribute
- getElement
- pyClassName
- pyIteratorName
- PyDenseI8ArrayAttribute
- isaFunction
- getAttribute
- getElement
- pyClassName
- pyIteratorName
- PyDenseI16ArrayAttribute
- isaFunction
- getAttribute
- getElement
- pyClassName
- pyIteratorName
- PyDenseI32ArrayAttribute
- isaFunction
- getAttribute
- getElement
- pyClassName
- pyIteratorName
- PyDenseI64ArrayAttribute
- isaFunction
- getAttribute
- getElement
- pyClassName
- pyIteratorName
- PyDenseF32ArrayAttribute
- isaFunction
- getAttribute
- getElement
- pyClassName
- pyIteratorName
- PyDenseF64ArrayAttribute
- isaFunction
- getAttribute
- getElement
- pyClassName
- pyIteratorName
- PyArrayAttribute
- isaFunction
- pyClassName
- getTypeIdFunction
- PyArrayAttributeIterator
- PyArrayAttributeIterator
- dunderIter
- dunderNext
- bind
- getItem
- bindDerived
- PyFloatAttribute
- isaFunction
- pyClassName
- getTypeIdFunction
- bindDerived
- PyIntegerAttribute
- isaFunction
- pyClassName
- bindDerived
- toPyInt
- PyBoolAttribute
- isaFunction
- pyClassName
- bindDerived
- PySymbolRefAttribute
- isaFunction
- pyClassName
- fromList
- bindDerived
- PyFlatSymbolRefAttribute
- isaFunction
- pyClassName
- bindDerived
- PyOpaqueAttribute
- isaFunction
- pyClassName
- getTypeIdFunction
- bindDerived
- PyStringAttribute
- isaFunction
- pyClassName
- getTypeIdFunction
- bindDerived
- PyDenseElementsAttribute
- isaFunction
- pyClassName
- getFromList
- getFromBuffer
- getSplat
- dunderLen
- accessBuffer
- bindDerived
- isUnsignedIntegerFormat
- isSignedIntegerFormat
- getShapedType
- getAttributeFromBuffer
- getBitpackedAttributeFromBooleanBuffer
- getBooleanBufferFromBitpackedAttribute
- bufferInfo
- slots
- bf_getbuffer
- bf_releasebuffer
- PyDenseIntElementsAttribute
- isaFunction
- pyClassName
- dunderGetItem
- bindDerived
- PyDenseResourceElementsAttribute
- isaFunction
- pyClassName
- getFromBuffer
- bindDerived
- PyDictAttribute
- isaFunction
- pyClassName
- getTypeIdFunction
- dunderLen
- dunderContains
- bindDerived
- PyDenseFPElementsAttribute
- isaFunction
- pyClassName
- dunderGetItem
- bindDerived
- PyTypeAttribute
- isaFunction
- pyClassName
- getTypeIdFunction
- bindDerived
- PyUnitAttribute
- isaFunction
- pyClassName
- getTypeIdFunction
- bindDerived
- PyStridedLayoutAttribute
- isaFunction
- pyClassName
- getTypeIdFunction
- bindDerived
- denseArrayAttributeCaster
- denseIntOrFPElementsAttributeCaster
- integerOrBoolAttributeCaster
- symbolRefOrFlatSymbolRefAttributeCaster
Learn to use CMake with our Intro Training
Find out more