| 1 | //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include <cstdint> |
| 10 | #include <optional> |
| 11 | #include <string> |
| 12 | #include <string_view> |
| 13 | #include <utility> |
| 14 | |
| 15 | #include "IRModule.h" |
| 16 | #include "NanobindUtils.h" |
| 17 | #include "mlir-c/BuiltinAttributes.h" |
| 18 | #include "mlir-c/BuiltinTypes.h" |
| 19 | #include "mlir/Bindings/Python/NanobindAdaptors.h" |
| 20 | #include "mlir/Bindings/Python/Nanobind.h" |
| 21 | #include "llvm/ADT/ScopeExit.h" |
| 22 | #include "llvm/Support/raw_ostream.h" |
| 23 | |
| 24 | namespace nb = nanobind; |
| 25 | using namespace nanobind::literals; |
| 26 | using namespace mlir; |
| 27 | using namespace mlir::python; |
| 28 | |
| 29 | using llvm::SmallVector; |
| 30 | |
| 31 | //------------------------------------------------------------------------------ |
| 32 | // Docstrings (trivial, non-duplicated docstrings are included inline). |
| 33 | //------------------------------------------------------------------------------ |
| 34 | |
/// User-facing docstring for constructing a DenseElementsAttr from a Python
/// buffer or array. It covers element type inference, the bit-casting rules
/// for explicitly typed buffers, splat creation, the accepted arguments and
/// the raised error.
/// NOTE(review): the binding that consumes this constant is not visible in
/// this part of the file; the text is a string literal, so any edit changes
/// what users see.
static const char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

* Integer types (except for i1): the buffer must be byte aligned to the
next byte boundary.
* Floating point types: Must be bit-castable to the given floating point
size.
* i1 (bool): Bit packed into 8bit words where the bit pattern matches a
row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
array: The array or buffer to convert.
signless: If inferring an appropriate MLIR type, use signless types for
integers (defaults True).
type: Skips inference of the MLIR element type and uses this instead. The
storage size must be consistent with the actual contents of the buffer.
shape: Overrides the shape of the buffer when constructing the MLIR
shaped type. This is needed when the physical and logical shape differ (as
for i1).
context: Explicit context, if not from context manager.

Returns:
DenseElementsAttr on success.

Raises:
ValueError: If the type of the buffer or array cannot be matched to an MLIR
type or if the buffer does not meet expectations.
)";
| 76 | |
/// User-facing docstring for constructing a DenseElementsAttr from a Python
/// list of attributes.
/// The `Raises` entry names the `type` argument, matching the `Args` section
/// (it previously referred to a non-existent `shaped_type` parameter).
static const char kDenseElementsAttrGetFromListDocstring[] =
    R"(Gets a DenseElementsAttr from a Python list of attributes.

Note that it can be expensive to construct attributes individually.
For a large number of elements, consider using a Python buffer or array instead.

Args:
attrs: A list of attributes.
type: The desired shape and type of the resulting DenseElementsAttr.
If not provided, the element type is determined based on the type
of the 0th attribute and the shape is `[len(attrs)]`.
context: Explicit context, if not from context manager.

Returns:
DenseElementsAttr on success.

Raises:
ValueError: If the type of the attributes does not match the type
specified by `type`.
)";
| 97 | |
/// User-facing docstring for constructing a DenseResourceElementsAttr from a
/// Python buffer or array. It states that little validation is performed and
/// that the backing buffer is retained for the lifetime of the resource blob.
/// NOTE(review): the binding that consumes this constant is not visible in
/// this part of the file; the text is a string literal, so any edit changes
/// what users see.
static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
    R"(Gets a DenseResourceElementsAttr from a Python buffer or array.

This function does minimal validation or massaging of the data, and it is
up to the caller to ensure that the buffer meets the characteristics
implied by the shape.

The backing buffer and any user objects will be retained for the lifetime
of the resource blob. This is typically bounded to the context but the
resource can have a shorter lifespan depending on how it is used in
subsequent processing.

Args:
buffer: The array or buffer to convert.
name: Name to provide to the resource (may be changed upon collision).
type: The explicit ShapedType to construct the attribute with.
context: Explicit context, if not from context manager.

Returns:
DenseResourceElementsAttr on success.

Raises:
ValueError: If the type of the buffer or array cannot be matched to an MLIR
type or if the buffer does not meet expectations.
)";
| 123 | |
| 124 | namespace { |
| 125 | |
| 126 | struct nb_buffer_info { |
| 127 | void *ptr = nullptr; |
| 128 | ssize_t itemsize = 0; |
| 129 | ssize_t size = 0; |
| 130 | const char *format = nullptr; |
| 131 | ssize_t ndim = 0; |
| 132 | SmallVector<ssize_t, 4> shape; |
| 133 | SmallVector<ssize_t, 4> strides; |
| 134 | bool readonly = false; |
| 135 | |
| 136 | nb_buffer_info( |
| 137 | void *ptr, ssize_t itemsize, const char *format, ssize_t ndim, |
| 138 | SmallVector<ssize_t, 4> shape_in, SmallVector<ssize_t, 4> strides_in, |
| 139 | bool readonly = false, |
| 140 | std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view_in = |
| 141 | std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>(nullptr, nullptr)) |
| 142 | : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim), |
| 143 | shape(std::move(shape_in)), strides(std::move(strides_in)), |
| 144 | readonly(readonly), owned_view(std::move(owned_view_in)) { |
| 145 | size = 1; |
| 146 | for (ssize_t i = 0; i < ndim; ++i) { |
| 147 | size *= shape[i]; |
| 148 | } |
| 149 | } |
| 150 | |
| 151 | explicit nb_buffer_info(Py_buffer *view) |
| 152 | : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim, |
| 153 | {view->shape, view->shape + view->ndim}, |
| 154 | // TODO(phawkins): check for null strides |
| 155 | {view->strides, view->strides + view->ndim}, |
| 156 | view->readonly != 0, |
| 157 | std::unique_ptr<Py_buffer, void (*)(Py_buffer *)>( |
| 158 | view, PyBuffer_Release)) {} |
| 159 | |
| 160 | nb_buffer_info(const nb_buffer_info &) = delete; |
| 161 | nb_buffer_info(nb_buffer_info &&) = default; |
| 162 | nb_buffer_info &operator=(const nb_buffer_info &) = delete; |
| 163 | nb_buffer_info &operator=(nb_buffer_info &&) = default; |
| 164 | |
| 165 | private: |
| 166 | std::unique_ptr<Py_buffer, void (*)(Py_buffer *)> owned_view; |
| 167 | }; |
| 168 | |
| 169 | class nb_buffer : public nb::object { |
| 170 | NB_OBJECT_DEFAULT(nb_buffer, object, "buffer" , PyObject_CheckBuffer); |
| 171 | |
| 172 | nb_buffer_info request() const { |
| 173 | int flags = PyBUF_STRIDES | PyBUF_FORMAT; |
| 174 | auto *view = new Py_buffer(); |
| 175 | if (PyObject_GetBuffer(ptr(), view, flags) != 0) { |
| 176 | delete view; |
| 177 | throw nb::python_error(); |
| 178 | } |
| 179 | return nb_buffer_info(view); |
| 180 | } |
| 181 | }; |
| 182 | |
/// Maps a C++ scalar type to the one-character format code used to describe
/// it in a Python buffer format string. The primary template is intentionally
/// empty, so unsupported types have no `format()`.
template <typename T>
struct nb_format_descriptor {};

// Boolean.
template <> struct nb_format_descriptor<bool> {
  static auto format() -> const char * { return "?"; }
};

// 8-bit integers.
template <> struct nb_format_descriptor<int8_t> {
  static auto format() -> const char * { return "b"; }
};
template <> struct nb_format_descriptor<uint8_t> {
  static auto format() -> const char * { return "B"; }
};

// 16-bit integers.
template <> struct nb_format_descriptor<int16_t> {
  static auto format() -> const char * { return "h"; }
};
template <> struct nb_format_descriptor<uint16_t> {
  static auto format() -> const char * { return "H"; }
};

// 32-bit integers.
template <> struct nb_format_descriptor<int32_t> {
  static auto format() -> const char * { return "i"; }
};
template <> struct nb_format_descriptor<uint32_t> {
  static auto format() -> const char * { return "I"; }
};

// 64-bit integers.
template <> struct nb_format_descriptor<int64_t> {
  static auto format() -> const char * { return "q"; }
};
template <> struct nb_format_descriptor<uint64_t> {
  static auto format() -> const char * { return "Q"; }
};

// Floating point.
template <> struct nb_format_descriptor<float> {
  static auto format() -> const char * { return "f"; }
};
template <> struct nb_format_descriptor<double> {
  static auto format() -> const char * { return "d"; }
};
| 230 | |
| 231 | static MlirStringRef toMlirStringRef(const std::string &s) { |
| 232 | return mlirStringRefCreate(s.data(), s.size()); |
| 233 | } |
| 234 | |
| 235 | static MlirStringRef toMlirStringRef(const nb::bytes &s) { |
| 236 | return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size()); |
| 237 | } |
| 238 | |
| 239 | class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { |
| 240 | public: |
| 241 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; |
| 242 | static constexpr const char *pyClassName = "AffineMapAttr" ; |
| 243 | using PyConcreteAttribute::PyConcreteAttribute; |
| 244 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 245 | mlirAffineMapAttrGetTypeID; |
| 246 | |
| 247 | static void bindDerived(ClassTy &c) { |
| 248 | c.def_static( |
| 249 | "get" , |
| 250 | [](PyAffineMap &affineMap) { |
| 251 | MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); |
| 252 | return PyAffineMapAttribute(affineMap.getContext(), attr); |
| 253 | }, |
| 254 | nb::arg("affine_map" ), "Gets an attribute wrapping an AffineMap." ); |
| 255 | c.def_prop_ro("value" , mlirAffineMapAttrGetValue, |
| 256 | "Returns the value of the AffineMap attribute" ); |
| 257 | } |
| 258 | }; |
| 259 | |
| 260 | class PyIntegerSetAttribute |
| 261 | : public PyConcreteAttribute<PyIntegerSetAttribute> { |
| 262 | public: |
| 263 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet; |
| 264 | static constexpr const char *pyClassName = "IntegerSetAttr" ; |
| 265 | using PyConcreteAttribute::PyConcreteAttribute; |
| 266 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 267 | mlirIntegerSetAttrGetTypeID; |
| 268 | |
| 269 | static void bindDerived(ClassTy &c) { |
| 270 | c.def_static( |
| 271 | "get" , |
| 272 | [](PyIntegerSet &integerSet) { |
| 273 | MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get()); |
| 274 | return PyIntegerSetAttribute(integerSet.getContext(), attr); |
| 275 | }, |
| 276 | nb::arg("integer_set" ), "Gets an attribute wrapping an IntegerSet." ); |
| 277 | } |
| 278 | }; |
| 279 | |
| 280 | template <typename T> |
| 281 | static T pyTryCast(nb::handle object) { |
| 282 | try { |
| 283 | return nb::cast<T>(object); |
| 284 | } catch (nb::cast_error &err) { |
| 285 | std::string msg = std::string("Invalid attribute when attempting to " |
| 286 | "create an ArrayAttribute (" ) + |
| 287 | err.what() + ")" ; |
| 288 | throw std::runtime_error(msg.c_str()); |
| 289 | } catch (std::runtime_error &err) { |
| 290 | std::string msg = std::string("Invalid attribute (None?) when attempting " |
| 291 | "to create an ArrayAttribute (" ) + |
| 292 | err.what() + ")" ; |
| 293 | throw std::runtime_error(msg.c_str()); |
| 294 | } |
| 295 | } |
| 296 | |
| 297 | /// A python-wrapped dense array attribute with an element type and a derived |
| 298 | /// implementation class. |
| 299 | template <typename EltTy, typename DerivedT> |
| 300 | class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> { |
| 301 | public: |
| 302 | using PyConcreteAttribute<DerivedT>::PyConcreteAttribute; |
| 303 | |
| 304 | /// Iterator over the integer elements of a dense array. |
| 305 | class PyDenseArrayIterator { |
| 306 | public: |
| 307 | PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {} |
| 308 | |
| 309 | /// Return a copy of the iterator. |
| 310 | PyDenseArrayIterator dunderIter() { return *this; } |
| 311 | |
| 312 | /// Return the next element. |
| 313 | EltTy dunderNext() { |
| 314 | // Throw if the index has reached the end. |
| 315 | if (nextIndex >= mlirDenseArrayGetNumElements(attr.get())) |
| 316 | throw nb::stop_iteration(); |
| 317 | return DerivedT::getElement(attr.get(), nextIndex++); |
| 318 | } |
| 319 | |
| 320 | /// Bind the iterator class. |
| 321 | static void bind(nb::module_ &m) { |
| 322 | nb::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName) |
| 323 | .def("__iter__" , &PyDenseArrayIterator::dunderIter) |
| 324 | .def("__next__" , &PyDenseArrayIterator::dunderNext); |
| 325 | } |
| 326 | |
| 327 | private: |
| 328 | /// The referenced dense array attribute. |
| 329 | PyAttribute attr; |
| 330 | /// The next index to read. |
| 331 | int nextIndex = 0; |
| 332 | }; |
| 333 | |
| 334 | /// Get the element at the given index. |
| 335 | EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); } |
| 336 | |
| 337 | /// Bind the attribute class. |
| 338 | static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) { |
| 339 | // Bind the constructor. |
| 340 | if constexpr (std::is_same_v<EltTy, bool>) { |
| 341 | c.def_static( |
| 342 | "get" , |
| 343 | [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) { |
| 344 | std::vector<bool> values; |
| 345 | for (nb::handle py_value : py_values) { |
| 346 | int is_true = PyObject_IsTrue(py_value.ptr()); |
| 347 | if (is_true < 0) { |
| 348 | throw nb::python_error(); |
| 349 | } |
| 350 | values.push_back(is_true); |
| 351 | } |
| 352 | return getAttribute(values, ctx->getRef()); |
| 353 | }, |
| 354 | nb::arg("values" ), nb::arg("context" ).none() = nb::none(), |
| 355 | "Gets a uniqued dense array attribute" ); |
| 356 | } else { |
| 357 | c.def_static( |
| 358 | "get" , |
| 359 | [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) { |
| 360 | return getAttribute(values, ctx->getRef()); |
| 361 | }, |
| 362 | nb::arg("values" ), nb::arg("context" ).none() = nb::none(), |
| 363 | "Gets a uniqued dense array attribute" ); |
| 364 | } |
| 365 | // Bind the array methods. |
| 366 | c.def("__getitem__" , [](DerivedT &arr, intptr_t i) { |
| 367 | if (i >= mlirDenseArrayGetNumElements(arr)) |
| 368 | throw nb::index_error("DenseArray index out of range" ); |
| 369 | return arr.getItem(i); |
| 370 | }); |
| 371 | c.def("__len__" , [](const DerivedT &arr) { |
| 372 | return mlirDenseArrayGetNumElements(arr); |
| 373 | }); |
| 374 | c.def("__iter__" , |
| 375 | [](const DerivedT &arr) { return PyDenseArrayIterator(arr); }); |
| 376 | c.def("__add__" , [](DerivedT &arr, const nb::list &) { |
| 377 | std::vector<EltTy> values; |
| 378 | intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); |
| 379 | values.reserve(numOldElements + nb::len(extras)); |
| 380 | for (intptr_t i = 0; i < numOldElements; ++i) |
| 381 | values.push_back(arr.getItem(i)); |
| 382 | for (nb::handle attr : extras) |
| 383 | values.push_back(pyTryCast<EltTy>(attr)); |
| 384 | return getAttribute(values, ctx: arr.getContext()); |
| 385 | }); |
| 386 | } |
| 387 | |
| 388 | private: |
| 389 | static DerivedT getAttribute(const std::vector<EltTy> &values, |
| 390 | PyMlirContextRef ctx) { |
| 391 | if constexpr (std::is_same_v<EltTy, bool>) { |
| 392 | std::vector<int> intValues(values.begin(), values.end()); |
| 393 | MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(), |
| 394 | intValues.data()); |
| 395 | return DerivedT(ctx, attr); |
| 396 | } else { |
| 397 | MlirAttribute attr = |
| 398 | DerivedT::getAttribute(ctx->get(), values.size(), values.data()); |
| 399 | return DerivedT(ctx, attr); |
| 400 | } |
| 401 | } |
| 402 | }; |
| 403 | |
| 404 | /// Instantiate the python dense array classes. |
| 405 | struct PyDenseBoolArrayAttribute |
| 406 | : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> { |
| 407 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray; |
| 408 | static constexpr auto getAttribute = mlirDenseBoolArrayGet; |
| 409 | static constexpr auto getElement = mlirDenseBoolArrayGetElement; |
| 410 | static constexpr const char *pyClassName = "DenseBoolArrayAttr" ; |
| 411 | static constexpr const char *pyIteratorName = "DenseBoolArrayIterator" ; |
| 412 | using PyDenseArrayAttribute::PyDenseArrayAttribute; |
| 413 | }; |
| 414 | struct PyDenseI8ArrayAttribute |
| 415 | : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> { |
| 416 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array; |
| 417 | static constexpr auto getAttribute = mlirDenseI8ArrayGet; |
| 418 | static constexpr auto getElement = mlirDenseI8ArrayGetElement; |
| 419 | static constexpr const char *pyClassName = "DenseI8ArrayAttr" ; |
| 420 | static constexpr const char *pyIteratorName = "DenseI8ArrayIterator" ; |
| 421 | using PyDenseArrayAttribute::PyDenseArrayAttribute; |
| 422 | }; |
| 423 | struct PyDenseI16ArrayAttribute |
| 424 | : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> { |
| 425 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array; |
| 426 | static constexpr auto getAttribute = mlirDenseI16ArrayGet; |
| 427 | static constexpr auto getElement = mlirDenseI16ArrayGetElement; |
| 428 | static constexpr const char *pyClassName = "DenseI16ArrayAttr" ; |
| 429 | static constexpr const char *pyIteratorName = "DenseI16ArrayIterator" ; |
| 430 | using PyDenseArrayAttribute::PyDenseArrayAttribute; |
| 431 | }; |
| 432 | struct PyDenseI32ArrayAttribute |
| 433 | : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> { |
| 434 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array; |
| 435 | static constexpr auto getAttribute = mlirDenseI32ArrayGet; |
| 436 | static constexpr auto getElement = mlirDenseI32ArrayGetElement; |
| 437 | static constexpr const char *pyClassName = "DenseI32ArrayAttr" ; |
| 438 | static constexpr const char *pyIteratorName = "DenseI32ArrayIterator" ; |
| 439 | using PyDenseArrayAttribute::PyDenseArrayAttribute; |
| 440 | }; |
| 441 | struct PyDenseI64ArrayAttribute |
| 442 | : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> { |
| 443 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array; |
| 444 | static constexpr auto getAttribute = mlirDenseI64ArrayGet; |
| 445 | static constexpr auto getElement = mlirDenseI64ArrayGetElement; |
| 446 | static constexpr const char *pyClassName = "DenseI64ArrayAttr" ; |
| 447 | static constexpr const char *pyIteratorName = "DenseI64ArrayIterator" ; |
| 448 | using PyDenseArrayAttribute::PyDenseArrayAttribute; |
| 449 | }; |
| 450 | struct PyDenseF32ArrayAttribute |
| 451 | : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> { |
| 452 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array; |
| 453 | static constexpr auto getAttribute = mlirDenseF32ArrayGet; |
| 454 | static constexpr auto getElement = mlirDenseF32ArrayGetElement; |
| 455 | static constexpr const char *pyClassName = "DenseF32ArrayAttr" ; |
| 456 | static constexpr const char *pyIteratorName = "DenseF32ArrayIterator" ; |
| 457 | using PyDenseArrayAttribute::PyDenseArrayAttribute; |
| 458 | }; |
| 459 | struct PyDenseF64ArrayAttribute |
| 460 | : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> { |
| 461 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array; |
| 462 | static constexpr auto getAttribute = mlirDenseF64ArrayGet; |
| 463 | static constexpr auto getElement = mlirDenseF64ArrayGetElement; |
| 464 | static constexpr const char *pyClassName = "DenseF64ArrayAttr" ; |
| 465 | static constexpr const char *pyIteratorName = "DenseF64ArrayIterator" ; |
| 466 | using PyDenseArrayAttribute::PyDenseArrayAttribute; |
| 467 | }; |
| 468 | |
| 469 | class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { |
| 470 | public: |
| 471 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; |
| 472 | static constexpr const char *pyClassName = "ArrayAttr" ; |
| 473 | using PyConcreteAttribute::PyConcreteAttribute; |
| 474 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 475 | mlirArrayAttrGetTypeID; |
| 476 | |
| 477 | class PyArrayAttributeIterator { |
| 478 | public: |
| 479 | PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {} |
| 480 | |
| 481 | PyArrayAttributeIterator &dunderIter() { return *this; } |
| 482 | |
| 483 | MlirAttribute dunderNext() { |
| 484 | // TODO: Throw is an inefficient way to stop iteration. |
| 485 | if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) |
| 486 | throw nb::stop_iteration(); |
| 487 | return mlirArrayAttrGetElement(attr.get(), nextIndex++); |
| 488 | } |
| 489 | |
| 490 | static void bind(nb::module_ &m) { |
| 491 | nb::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator" ) |
| 492 | .def("__iter__" , &PyArrayAttributeIterator::dunderIter) |
| 493 | .def("__next__" , &PyArrayAttributeIterator::dunderNext); |
| 494 | } |
| 495 | |
| 496 | private: |
| 497 | PyAttribute attr; |
| 498 | int nextIndex = 0; |
| 499 | }; |
| 500 | |
| 501 | MlirAttribute getItem(intptr_t i) { |
| 502 | return mlirArrayAttrGetElement(*this, i); |
| 503 | } |
| 504 | |
| 505 | static void bindDerived(ClassTy &c) { |
| 506 | c.def_static( |
| 507 | "get" , |
| 508 | [](nb::list attributes, DefaultingPyMlirContext context) { |
| 509 | SmallVector<MlirAttribute> mlirAttributes; |
| 510 | mlirAttributes.reserve(nb::len(attributes)); |
| 511 | for (auto attribute : attributes) { |
| 512 | mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute)); |
| 513 | } |
| 514 | MlirAttribute attr = mlirArrayAttrGet( |
| 515 | context->get(), mlirAttributes.size(), mlirAttributes.data()); |
| 516 | return PyArrayAttribute(context->getRef(), attr); |
| 517 | }, |
| 518 | nb::arg("attributes" ), nb::arg("context" ).none() = nb::none(), |
| 519 | "Gets a uniqued Array attribute" ); |
| 520 | c.def("__getitem__" , |
| 521 | [](PyArrayAttribute &arr, intptr_t i) { |
| 522 | if (i >= mlirArrayAttrGetNumElements(arr)) |
| 523 | throw nb::index_error("ArrayAttribute index out of range" ); |
| 524 | return arr.getItem(i); |
| 525 | }) |
| 526 | .def("__len__" , |
| 527 | [](const PyArrayAttribute &arr) { |
| 528 | return mlirArrayAttrGetNumElements(arr); |
| 529 | }) |
| 530 | .def("__iter__" , [](const PyArrayAttribute &arr) { |
| 531 | return PyArrayAttributeIterator(arr); |
| 532 | }); |
| 533 | c.def("__add__" , [](PyArrayAttribute arr, nb::list ) { |
| 534 | std::vector<MlirAttribute> attributes; |
| 535 | intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); |
| 536 | attributes.reserve(numOldElements + nb::len(extras)); |
| 537 | for (intptr_t i = 0; i < numOldElements; ++i) |
| 538 | attributes.push_back(arr.getItem(i)); |
| 539 | for (nb::handle attr : extras) |
| 540 | attributes.push_back(pyTryCast<PyAttribute>(attr)); |
| 541 | MlirAttribute arrayAttr = mlirArrayAttrGet( |
| 542 | arr.getContext()->get(), attributes.size(), attributes.data()); |
| 543 | return PyArrayAttribute(arr.getContext(), arrayAttr); |
| 544 | }); |
| 545 | } |
| 546 | }; |
| 547 | |
| 548 | /// Float Point Attribute subclass - FloatAttr. |
| 549 | class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { |
| 550 | public: |
| 551 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; |
| 552 | static constexpr const char *pyClassName = "FloatAttr" ; |
| 553 | using PyConcreteAttribute::PyConcreteAttribute; |
| 554 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 555 | mlirFloatAttrGetTypeID; |
| 556 | |
| 557 | static void bindDerived(ClassTy &c) { |
| 558 | c.def_static( |
| 559 | "get" , |
| 560 | [](PyType &type, double value, DefaultingPyLocation loc) { |
| 561 | PyMlirContext::ErrorCapture errors(loc->getContext()); |
| 562 | MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); |
| 563 | if (mlirAttributeIsNull(attr)) |
| 564 | throw MLIRError("Invalid attribute" , errors.take()); |
| 565 | return PyFloatAttribute(type.getContext(), attr); |
| 566 | }, |
| 567 | nb::arg("type" ), nb::arg("value" ), nb::arg("loc" ).none() = nb::none(), |
| 568 | "Gets an uniqued float point attribute associated to a type" ); |
| 569 | c.def_static( |
| 570 | "get_f32" , |
| 571 | [](double value, DefaultingPyMlirContext context) { |
| 572 | MlirAttribute attr = mlirFloatAttrDoubleGet( |
| 573 | context->get(), mlirF32TypeGet(context->get()), value); |
| 574 | return PyFloatAttribute(context->getRef(), attr); |
| 575 | }, |
| 576 | nb::arg("value" ), nb::arg("context" ).none() = nb::none(), |
| 577 | "Gets an uniqued float point attribute associated to a f32 type" ); |
| 578 | c.def_static( |
| 579 | "get_f64" , |
| 580 | [](double value, DefaultingPyMlirContext context) { |
| 581 | MlirAttribute attr = mlirFloatAttrDoubleGet( |
| 582 | context->get(), mlirF64TypeGet(context->get()), value); |
| 583 | return PyFloatAttribute(context->getRef(), attr); |
| 584 | }, |
| 585 | nb::arg("value" ), nb::arg("context" ).none() = nb::none(), |
| 586 | "Gets an uniqued float point attribute associated to a f64 type" ); |
| 587 | c.def_prop_ro("value" , mlirFloatAttrGetValueDouble, |
| 588 | "Returns the value of the float attribute" ); |
| 589 | c.def("__float__" , mlirFloatAttrGetValueDouble, |
| 590 | "Converts the value of the float attribute to a Python float" ); |
| 591 | } |
| 592 | }; |
| 593 | |
| 594 | /// Integer Attribute subclass - IntegerAttr. |
| 595 | class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { |
| 596 | public: |
| 597 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; |
| 598 | static constexpr const char *pyClassName = "IntegerAttr" ; |
| 599 | using PyConcreteAttribute::PyConcreteAttribute; |
| 600 | |
| 601 | static void bindDerived(ClassTy &c) { |
| 602 | c.def_static( |
| 603 | "get" , |
| 604 | [](PyType &type, int64_t value) { |
| 605 | MlirAttribute attr = mlirIntegerAttrGet(type, value); |
| 606 | return PyIntegerAttribute(type.getContext(), attr); |
| 607 | }, |
| 608 | nb::arg("type" ), nb::arg("value" ), |
| 609 | "Gets an uniqued integer attribute associated to a type" ); |
| 610 | c.def_prop_ro("value" , toPyInt, |
| 611 | "Returns the value of the integer attribute" ); |
| 612 | c.def("__int__" , toPyInt, |
| 613 | "Converts the value of the integer attribute to a Python int" ); |
| 614 | c.def_prop_ro_static("static_typeid" , |
| 615 | [](nb::object & /*class*/) -> MlirTypeID { |
| 616 | return mlirIntegerAttrGetTypeID(); |
| 617 | }); |
| 618 | } |
| 619 | |
| 620 | private: |
| 621 | static int64_t toPyInt(PyIntegerAttribute &self) { |
| 622 | MlirType type = mlirAttributeGetType(self); |
| 623 | if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) |
| 624 | return mlirIntegerAttrGetValueInt(self); |
| 625 | if (mlirIntegerTypeIsSigned(type)) |
| 626 | return mlirIntegerAttrGetValueSInt(self); |
| 627 | return mlirIntegerAttrGetValueUInt(self); |
| 628 | } |
| 629 | }; |
| 630 | |
| 631 | /// Bool Attribute subclass - BoolAttr. |
| 632 | class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { |
| 633 | public: |
| 634 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; |
| 635 | static constexpr const char *pyClassName = "BoolAttr" ; |
| 636 | using PyConcreteAttribute::PyConcreteAttribute; |
| 637 | |
| 638 | static void bindDerived(ClassTy &c) { |
| 639 | c.def_static( |
| 640 | "get" , |
| 641 | [](bool value, DefaultingPyMlirContext context) { |
| 642 | MlirAttribute attr = mlirBoolAttrGet(context->get(), value); |
| 643 | return PyBoolAttribute(context->getRef(), attr); |
| 644 | }, |
| 645 | nb::arg("value" ), nb::arg("context" ).none() = nb::none(), |
| 646 | "Gets an uniqued bool attribute" ); |
| 647 | c.def_prop_ro("value" , mlirBoolAttrGetValue, |
| 648 | "Returns the value of the bool attribute" ); |
| 649 | c.def("__bool__" , mlirBoolAttrGetValue, |
| 650 | "Converts the value of the bool attribute to a Python bool" ); |
| 651 | } |
| 652 | }; |
| 653 | |
| 654 | class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> { |
| 655 | public: |
| 656 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef; |
| 657 | static constexpr const char *pyClassName = "SymbolRefAttr" ; |
| 658 | using PyConcreteAttribute::PyConcreteAttribute; |
| 659 | |
| 660 | static MlirAttribute fromList(const std::vector<std::string> &symbols, |
| 661 | PyMlirContext &context) { |
| 662 | if (symbols.empty()) |
| 663 | throw std::runtime_error("SymbolRefAttr must be composed of at least " |
| 664 | "one symbol." ); |
| 665 | MlirStringRef rootSymbol = toMlirStringRef(symbols[0]); |
| 666 | SmallVector<MlirAttribute, 3> referenceAttrs; |
| 667 | for (size_t i = 1; i < symbols.size(); ++i) { |
| 668 | referenceAttrs.push_back( |
| 669 | mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i]))); |
| 670 | } |
| 671 | return mlirSymbolRefAttrGet(context.get(), rootSymbol, |
| 672 | referenceAttrs.size(), referenceAttrs.data()); |
| 673 | } |
| 674 | |
| 675 | static void bindDerived(ClassTy &c) { |
| 676 | c.def_static( |
| 677 | "get" , |
| 678 | [](const std::vector<std::string> &symbols, |
| 679 | DefaultingPyMlirContext context) { |
| 680 | return PySymbolRefAttribute::fromList(symbols, context.resolve()); |
| 681 | }, |
| 682 | nb::arg("symbols" ), nb::arg("context" ).none() = nb::none(), |
| 683 | "Gets a uniqued SymbolRef attribute from a list of symbol names" ); |
| 684 | c.def_prop_ro( |
| 685 | "value" , |
| 686 | [](PySymbolRefAttribute &self) { |
| 687 | std::vector<std::string> symbols = { |
| 688 | unwrap(mlirSymbolRefAttrGetRootReference(self)).str()}; |
| 689 | for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self); |
| 690 | ++i) |
| 691 | symbols.push_back( |
| 692 | unwrap(mlirSymbolRefAttrGetRootReference( |
| 693 | mlirSymbolRefAttrGetNestedReference(self, i))) |
| 694 | .str()); |
| 695 | return symbols; |
| 696 | }, |
| 697 | "Returns the value of the SymbolRef attribute as a list[str]" ); |
| 698 | } |
| 699 | }; |
| 700 | |
| 701 | class PyFlatSymbolRefAttribute |
| 702 | : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { |
| 703 | public: |
| 704 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; |
| 705 | static constexpr const char *pyClassName = "FlatSymbolRefAttr" ; |
| 706 | using PyConcreteAttribute::PyConcreteAttribute; |
| 707 | |
| 708 | static void bindDerived(ClassTy &c) { |
| 709 | c.def_static( |
| 710 | "get" , |
| 711 | [](std::string value, DefaultingPyMlirContext context) { |
| 712 | MlirAttribute attr = |
| 713 | mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); |
| 714 | return PyFlatSymbolRefAttribute(context->getRef(), attr); |
| 715 | }, |
| 716 | nb::arg("value" ), nb::arg("context" ).none() = nb::none(), |
| 717 | "Gets a uniqued FlatSymbolRef attribute" ); |
| 718 | c.def_prop_ro( |
| 719 | "value" , |
| 720 | [](PyFlatSymbolRefAttribute &self) { |
| 721 | MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); |
| 722 | return nb::str(stringRef.data, stringRef.length); |
| 723 | }, |
| 724 | "Returns the value of the FlatSymbolRef attribute as a string" ); |
| 725 | } |
| 726 | }; |
| 727 | |
| 728 | class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> { |
| 729 | public: |
| 730 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; |
| 731 | static constexpr const char *pyClassName = "OpaqueAttr" ; |
| 732 | using PyConcreteAttribute::PyConcreteAttribute; |
| 733 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 734 | mlirOpaqueAttrGetTypeID; |
| 735 | |
| 736 | static void bindDerived(ClassTy &c) { |
| 737 | c.def_static( |
| 738 | "get" , |
| 739 | [](std::string dialectNamespace, nb_buffer buffer, PyType &type, |
| 740 | DefaultingPyMlirContext context) { |
| 741 | const nb_buffer_info bufferInfo = buffer.request(); |
| 742 | intptr_t bufferSize = bufferInfo.size; |
| 743 | MlirAttribute attr = mlirOpaqueAttrGet( |
| 744 | context->get(), toMlirStringRef(dialectNamespace), bufferSize, |
| 745 | static_cast<char *>(bufferInfo.ptr), type); |
| 746 | return PyOpaqueAttribute(context->getRef(), attr); |
| 747 | }, |
| 748 | nb::arg("dialect_namespace" ), nb::arg("buffer" ), nb::arg("type" ), |
| 749 | nb::arg("context" ).none() = nb::none(), "Gets an Opaque attribute." ); |
| 750 | c.def_prop_ro( |
| 751 | "dialect_namespace" , |
| 752 | [](PyOpaqueAttribute &self) { |
| 753 | MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); |
| 754 | return nb::str(stringRef.data, stringRef.length); |
| 755 | }, |
| 756 | "Returns the dialect namespace for the Opaque attribute as a string" ); |
| 757 | c.def_prop_ro( |
| 758 | "data" , |
| 759 | [](PyOpaqueAttribute &self) { |
| 760 | MlirStringRef stringRef = mlirOpaqueAttrGetData(self); |
| 761 | return nb::bytes(stringRef.data, stringRef.length); |
| 762 | }, |
| 763 | "Returns the data for the Opaqued attributes as `bytes`" ); |
| 764 | } |
| 765 | }; |
| 766 | |
| 767 | class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { |
| 768 | public: |
| 769 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; |
| 770 | static constexpr const char *pyClassName = "StringAttr" ; |
| 771 | using PyConcreteAttribute::PyConcreteAttribute; |
| 772 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 773 | mlirStringAttrGetTypeID; |
| 774 | |
| 775 | static void bindDerived(ClassTy &c) { |
| 776 | c.def_static( |
| 777 | "get" , |
| 778 | [](std::string value, DefaultingPyMlirContext context) { |
| 779 | MlirAttribute attr = |
| 780 | mlirStringAttrGet(context->get(), toMlirStringRef(value)); |
| 781 | return PyStringAttribute(context->getRef(), attr); |
| 782 | }, |
| 783 | nb::arg("value" ), nb::arg("context" ).none() = nb::none(), |
| 784 | "Gets a uniqued string attribute" ); |
| 785 | c.def_static( |
| 786 | "get" , |
| 787 | [](nb::bytes value, DefaultingPyMlirContext context) { |
| 788 | MlirAttribute attr = |
| 789 | mlirStringAttrGet(context->get(), toMlirStringRef(value)); |
| 790 | return PyStringAttribute(context->getRef(), attr); |
| 791 | }, |
| 792 | nb::arg("value" ), nb::arg("context" ).none() = nb::none(), |
| 793 | "Gets a uniqued string attribute" ); |
| 794 | c.def_static( |
| 795 | "get_typed" , |
| 796 | [](PyType &type, std::string value) { |
| 797 | MlirAttribute attr = |
| 798 | mlirStringAttrTypedGet(type, toMlirStringRef(value)); |
| 799 | return PyStringAttribute(type.getContext(), attr); |
| 800 | }, |
| 801 | nb::arg("type" ), nb::arg("value" ), |
| 802 | "Gets a uniqued string attribute associated to a type" ); |
| 803 | c.def_prop_ro( |
| 804 | "value" , |
| 805 | [](PyStringAttribute &self) { |
| 806 | MlirStringRef stringRef = mlirStringAttrGetValue(self); |
| 807 | return nb::str(stringRef.data, stringRef.length); |
| 808 | }, |
| 809 | "Returns the value of the string attribute" ); |
| 810 | c.def_prop_ro( |
| 811 | "value_bytes" , |
| 812 | [](PyStringAttribute &self) { |
| 813 | MlirStringRef stringRef = mlirStringAttrGetValue(self); |
| 814 | return nb::bytes(stringRef.data, stringRef.length); |
| 815 | }, |
| 816 | "Returns the value of the string attribute as `bytes`" ); |
| 817 | } |
| 818 | }; |
| 819 | |
| 820 | // TODO: Support construction of string elements. |
| 821 | class PyDenseElementsAttribute |
| 822 | : public PyConcreteAttribute<PyDenseElementsAttribute> { |
| 823 | public: |
| 824 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; |
| 825 | static constexpr const char *pyClassName = "DenseElementsAttr" ; |
| 826 | using PyConcreteAttribute::PyConcreteAttribute; |
| 827 | |
| 828 | static PyDenseElementsAttribute |
| 829 | getFromList(nb::list attributes, std::optional<PyType> explicitType, |
| 830 | DefaultingPyMlirContext contextWrapper) { |
| 831 | const size_t numAttributes = nb::len(attributes); |
| 832 | if (numAttributes == 0) |
| 833 | throw nb::value_error("Attributes list must be non-empty." ); |
| 834 | |
| 835 | MlirType shapedType; |
| 836 | if (explicitType) { |
| 837 | if ((!mlirTypeIsAShaped(*explicitType) || |
| 838 | !mlirShapedTypeHasStaticShape(*explicitType))) { |
| 839 | |
| 840 | std::string message; |
| 841 | llvm::raw_string_ostream os(message); |
| 842 | os << "Expected a static ShapedType for the shaped_type parameter: " |
| 843 | << nb::cast<std::string>(nb::repr(nb::cast(*explicitType))); |
| 844 | throw nb::value_error(message.c_str()); |
| 845 | } |
| 846 | shapedType = *explicitType; |
| 847 | } else { |
| 848 | SmallVector<int64_t> shape = {static_cast<int64_t>(numAttributes)}; |
| 849 | shapedType = mlirRankedTensorTypeGet( |
| 850 | shape.size(), shape.data(), |
| 851 | mlirAttributeGetType(pyTryCast<PyAttribute>(attributes[0])), |
| 852 | mlirAttributeGetNull()); |
| 853 | } |
| 854 | |
| 855 | SmallVector<MlirAttribute> mlirAttributes; |
| 856 | mlirAttributes.reserve(numAttributes); |
| 857 | for (const nb::handle &attribute : attributes) { |
| 858 | MlirAttribute mlirAttribute = pyTryCast<PyAttribute>(attribute); |
| 859 | MlirType attrType = mlirAttributeGetType(mlirAttribute); |
| 860 | mlirAttributes.push_back(mlirAttribute); |
| 861 | |
| 862 | if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) { |
| 863 | std::string message; |
| 864 | llvm::raw_string_ostream os(message); |
| 865 | os << "All attributes must be of the same type and match " |
| 866 | << "the type parameter: expected=" |
| 867 | << nb::cast<std::string>(nb::repr(nb::cast(shapedType))) |
| 868 | << ", but got=" |
| 869 | << nb::cast<std::string>(nb::repr(nb::cast(attrType))); |
| 870 | throw nb::value_error(message.c_str()); |
| 871 | } |
| 872 | } |
| 873 | |
| 874 | MlirAttribute elements = mlirDenseElementsAttrGet( |
| 875 | shapedType, mlirAttributes.size(), mlirAttributes.data()); |
| 876 | |
| 877 | return PyDenseElementsAttribute(contextWrapper->getRef(), elements); |
| 878 | } |
| 879 | |
| 880 | static PyDenseElementsAttribute |
| 881 | getFromBuffer(nb_buffer array, bool signless, |
| 882 | std::optional<PyType> explicitType, |
| 883 | std::optional<std::vector<int64_t>> explicitShape, |
| 884 | DefaultingPyMlirContext contextWrapper) { |
| 885 | // Request a contiguous view. In exotic cases, this will cause a copy. |
| 886 | int flags = PyBUF_ND; |
| 887 | if (!explicitType) { |
| 888 | flags |= PyBUF_FORMAT; |
| 889 | } |
| 890 | Py_buffer view; |
| 891 | if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) { |
| 892 | throw nb::python_error(); |
| 893 | } |
| 894 | auto freeBuffer = llvm::make_scope_exit(F: [&]() { PyBuffer_Release(&view); }); |
| 895 | |
| 896 | MlirContext context = contextWrapper->get(); |
| 897 | MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType, |
| 898 | explicitShape, context); |
| 899 | if (mlirAttributeIsNull(attr)) { |
| 900 | throw std::invalid_argument( |
| 901 | "DenseElementsAttr could not be constructed from the given buffer. " |
| 902 | "This may mean that the Python buffer layout does not match that " |
| 903 | "MLIR expected layout and is a bug." ); |
| 904 | } |
| 905 | return PyDenseElementsAttribute(contextWrapper->getRef(), attr); |
| 906 | } |
| 907 | |
| 908 | static PyDenseElementsAttribute getSplat(const PyType &shapedType, |
| 909 | PyAttribute &elementAttr) { |
| 910 | auto contextWrapper = |
| 911 | PyMlirContext::forContext(mlirTypeGetContext(shapedType)); |
| 912 | if (!mlirAttributeIsAInteger(elementAttr) && |
| 913 | !mlirAttributeIsAFloat(elementAttr)) { |
| 914 | std::string message = "Illegal element type for DenseElementsAttr: " ; |
| 915 | message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr)))); |
| 916 | throw nb::value_error(message.c_str()); |
| 917 | } |
| 918 | if (!mlirTypeIsAShaped(shapedType) || |
| 919 | !mlirShapedTypeHasStaticShape(shapedType)) { |
| 920 | std::string message = |
| 921 | "Expected a static ShapedType for the shaped_type parameter: " ; |
| 922 | message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType)))); |
| 923 | throw nb::value_error(message.c_str()); |
| 924 | } |
| 925 | MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); |
| 926 | MlirType attrType = mlirAttributeGetType(elementAttr); |
| 927 | if (!mlirTypeEqual(shapedElementType, attrType)) { |
| 928 | std::string message = |
| 929 | "Shaped element type and attribute type must be equal: shaped=" ; |
| 930 | message.append(nb::cast<std::string>(nb::repr(nb::cast(shapedType)))); |
| 931 | message.append(s: ", element=" ); |
| 932 | message.append(nb::cast<std::string>(nb::repr(nb::cast(elementAttr)))); |
| 933 | throw nb::value_error(message.c_str()); |
| 934 | } |
| 935 | |
| 936 | MlirAttribute elements = |
| 937 | mlirDenseElementsAttrSplatGet(shapedType, elementAttr); |
| 938 | return PyDenseElementsAttribute(contextWrapper->getRef(), elements); |
| 939 | } |
| 940 | |
| 941 | intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } |
| 942 | |
| 943 | std::unique_ptr<nb_buffer_info> accessBuffer() { |
| 944 | MlirType shapedType = mlirAttributeGetType(*this); |
| 945 | MlirType elementType = mlirShapedTypeGetElementType(shapedType); |
| 946 | std::string format; |
| 947 | |
| 948 | if (mlirTypeIsAF32(elementType)) { |
| 949 | // f32 |
| 950 | return bufferInfo<float>(shapedType); |
| 951 | } |
| 952 | if (mlirTypeIsAF64(elementType)) { |
| 953 | // f64 |
| 954 | return bufferInfo<double>(shapedType); |
| 955 | } |
| 956 | if (mlirTypeIsAF16(elementType)) { |
| 957 | // f16 |
| 958 | return bufferInfo<uint16_t>(shapedType, "e" ); |
| 959 | } |
| 960 | if (mlirTypeIsAIndex(elementType)) { |
| 961 | // Same as IndexType::kInternalStorageBitWidth |
| 962 | return bufferInfo<int64_t>(shapedType); |
| 963 | } |
| 964 | if (mlirTypeIsAInteger(elementType) && |
| 965 | mlirIntegerTypeGetWidth(elementType) == 32) { |
| 966 | if (mlirIntegerTypeIsSignless(elementType) || |
| 967 | mlirIntegerTypeIsSigned(elementType)) { |
| 968 | // i32 |
| 969 | return bufferInfo<int32_t>(shapedType); |
| 970 | } |
| 971 | if (mlirIntegerTypeIsUnsigned(elementType)) { |
| 972 | // unsigned i32 |
| 973 | return bufferInfo<uint32_t>(shapedType); |
| 974 | } |
| 975 | } else if (mlirTypeIsAInteger(elementType) && |
| 976 | mlirIntegerTypeGetWidth(elementType) == 64) { |
| 977 | if (mlirIntegerTypeIsSignless(elementType) || |
| 978 | mlirIntegerTypeIsSigned(elementType)) { |
| 979 | // i64 |
| 980 | return bufferInfo<int64_t>(shapedType); |
| 981 | } |
| 982 | if (mlirIntegerTypeIsUnsigned(elementType)) { |
| 983 | // unsigned i64 |
| 984 | return bufferInfo<uint64_t>(shapedType); |
| 985 | } |
| 986 | } else if (mlirTypeIsAInteger(elementType) && |
| 987 | mlirIntegerTypeGetWidth(elementType) == 8) { |
| 988 | if (mlirIntegerTypeIsSignless(elementType) || |
| 989 | mlirIntegerTypeIsSigned(elementType)) { |
| 990 | // i8 |
| 991 | return bufferInfo<int8_t>(shapedType); |
| 992 | } |
| 993 | if (mlirIntegerTypeIsUnsigned(elementType)) { |
| 994 | // unsigned i8 |
| 995 | return bufferInfo<uint8_t>(shapedType); |
| 996 | } |
| 997 | } else if (mlirTypeIsAInteger(elementType) && |
| 998 | mlirIntegerTypeGetWidth(elementType) == 16) { |
| 999 | if (mlirIntegerTypeIsSignless(elementType) || |
| 1000 | mlirIntegerTypeIsSigned(elementType)) { |
| 1001 | // i16 |
| 1002 | return bufferInfo<int16_t>(shapedType); |
| 1003 | } |
| 1004 | if (mlirIntegerTypeIsUnsigned(elementType)) { |
| 1005 | // unsigned i16 |
| 1006 | return bufferInfo<uint16_t>(shapedType); |
| 1007 | } |
| 1008 | } else if (mlirTypeIsAInteger(elementType) && |
| 1009 | mlirIntegerTypeGetWidth(elementType) == 1) { |
| 1010 | // i1 / bool |
| 1011 | // We can not send the buffer directly back to Python, because the i1 |
| 1012 | // values are bitpacked within MLIR. We call numpy's unpackbits function |
| 1013 | // to convert the bytes. |
| 1014 | return getBooleanBufferFromBitpackedAttribute(); |
| 1015 | } |
| 1016 | |
| 1017 | // TODO: Currently crashes the program. |
| 1018 | // Reported as https://github.com/pybind/pybind11/issues/3336 |
| 1019 | throw std::invalid_argument( |
| 1020 | "unsupported data type for conversion to Python buffer" ); |
| 1021 | } |
| 1022 | |
| 1023 | static void bindDerived(ClassTy &c) { |
| 1024 | #if PY_VERSION_HEX < 0x03090000 |
| 1025 | PyTypeObject *tp = reinterpret_cast<PyTypeObject *>(c.ptr()); |
| 1026 | tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer; |
| 1027 | tp->tp_as_buffer->bf_releasebuffer = |
| 1028 | PyDenseElementsAttribute::bf_releasebuffer; |
| 1029 | #endif |
| 1030 | c.def("__len__" , &PyDenseElementsAttribute::dunderLen) |
| 1031 | .def_static("get" , PyDenseElementsAttribute::getFromBuffer, |
| 1032 | nb::arg("array" ), nb::arg("signless" ) = true, |
| 1033 | nb::arg("type" ).none() = nb::none(), |
| 1034 | nb::arg("shape" ).none() = nb::none(), |
| 1035 | nb::arg("context" ).none() = nb::none(), |
| 1036 | kDenseElementsAttrGetDocstring) |
| 1037 | .def_static("get" , PyDenseElementsAttribute::getFromList, |
| 1038 | nb::arg("attrs" ), nb::arg("type" ).none() = nb::none(), |
| 1039 | nb::arg("context" ).none() = nb::none(), |
| 1040 | kDenseElementsAttrGetFromListDocstring) |
| 1041 | .def_static("get_splat" , PyDenseElementsAttribute::getSplat, |
| 1042 | nb::arg("shaped_type" ), nb::arg("element_attr" ), |
| 1043 | "Gets a DenseElementsAttr where all values are the same" ) |
| 1044 | .def_prop_ro("is_splat" , |
| 1045 | [](PyDenseElementsAttribute &self) -> bool { |
| 1046 | return mlirDenseElementsAttrIsSplat(self); |
| 1047 | }) |
| 1048 | .def("get_splat_value" , [](PyDenseElementsAttribute &self) { |
| 1049 | if (!mlirDenseElementsAttrIsSplat(self)) |
| 1050 | throw nb::value_error( |
| 1051 | "get_splat_value called on a non-splat attribute" ); |
| 1052 | return mlirDenseElementsAttrGetSplatValue(self); |
| 1053 | }); |
| 1054 | } |
| 1055 | |
| 1056 | static PyType_Slot slots[]; |
| 1057 | |
| 1058 | private: |
| 1059 | static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags); |
| 1060 | static void bf_releasebuffer(PyObject *, Py_buffer *buffer); |
| 1061 | |
/// Returns true when the first character of `format` is one of the
/// buffer-format codes 'I', 'B', 'H', 'L' or 'Q'. An empty format yields
/// false; characters after the first are ignored.
static bool isUnsignedIntegerFormat(std::string_view format) {
  constexpr std::string_view kUnsignedCodes = "IBHLQ";
  return !format.empty() &&
         kUnsignedCodes.find(format.front()) != std::string_view::npos;
}
| 1069 | |
/// Returns true when the first character of `format` is one of the
/// buffer-format codes 'i', 'b', 'h', 'l' or 'q'. An empty format yields
/// false; characters after the first are ignored.
static bool isSignedIntegerFormat(std::string_view format) {
  constexpr std::string_view kSignedCodes = "ibhlq";
  return !format.empty() &&
         kSignedCodes.find(format.front()) != std::string_view::npos;
}
| 1077 | |
| 1078 | static MlirType |
| 1079 | getShapedType(std::optional<MlirType> bulkLoadElementType, |
| 1080 | std::optional<std::vector<int64_t>> explicitShape, |
| 1081 | Py_buffer &view) { |
| 1082 | SmallVector<int64_t> shape; |
| 1083 | if (explicitShape) { |
| 1084 | shape.append(in_start: explicitShape->begin(), in_end: explicitShape->end()); |
| 1085 | } else { |
| 1086 | shape.append(view.shape, view.shape + view.ndim); |
| 1087 | } |
| 1088 | |
| 1089 | if (mlirTypeIsAShaped(*bulkLoadElementType)) { |
| 1090 | if (explicitShape) { |
| 1091 | throw std::invalid_argument("Shape can only be specified explicitly " |
| 1092 | "when the type is not a shaped type." ); |
| 1093 | } |
| 1094 | return *bulkLoadElementType; |
| 1095 | } else { |
| 1096 | MlirAttribute encodingAttr = mlirAttributeGetNull(); |
| 1097 | return mlirRankedTensorTypeGet(shape.size(), shape.data(), |
| 1098 | *bulkLoadElementType, encodingAttr); |
| 1099 | } |
| 1100 | } |
| 1101 | |
| 1102 | static MlirAttribute getAttributeFromBuffer( |
| 1103 | Py_buffer &view, bool signless, std::optional<PyType> explicitType, |
| 1104 | std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) { |
| 1105 | // Detect format codes that are suitable for bulk loading. This includes |
| 1106 | // all byte aligned integer and floating point types up to 8 bytes. |
| 1107 | // Notably, this excludes exotics types which do not have a direct |
| 1108 | // representation in the buffer protocol (i.e. complex, etc). |
| 1109 | std::optional<MlirType> bulkLoadElementType; |
| 1110 | if (explicitType) { |
| 1111 | bulkLoadElementType = *explicitType; |
| 1112 | } else { |
| 1113 | std::string_view format(view.format); |
| 1114 | if (format == "f" ) { |
| 1115 | // f32 |
| 1116 | assert(view.itemsize == 4 && "mismatched array itemsize" ); |
| 1117 | bulkLoadElementType = mlirF32TypeGet(context); |
| 1118 | } else if (format == "d" ) { |
| 1119 | // f64 |
| 1120 | assert(view.itemsize == 8 && "mismatched array itemsize" ); |
| 1121 | bulkLoadElementType = mlirF64TypeGet(context); |
| 1122 | } else if (format == "e" ) { |
| 1123 | // f16 |
| 1124 | assert(view.itemsize == 2 && "mismatched array itemsize" ); |
| 1125 | bulkLoadElementType = mlirF16TypeGet(context); |
| 1126 | } else if (format == "?" ) { |
| 1127 | // i1 |
| 1128 | // The i1 type needs to be bit-packed, so we will handle it seperately |
| 1129 | return getBitpackedAttributeFromBooleanBuffer(view, explicitShape, |
| 1130 | context); |
| 1131 | } else if (isSignedIntegerFormat(format)) { |
| 1132 | if (view.itemsize == 4) { |
| 1133 | // i32 |
| 1134 | bulkLoadElementType = signless |
| 1135 | ? mlirIntegerTypeGet(context, 32) |
| 1136 | : mlirIntegerTypeSignedGet(context, 32); |
| 1137 | } else if (view.itemsize == 8) { |
| 1138 | // i64 |
| 1139 | bulkLoadElementType = signless |
| 1140 | ? mlirIntegerTypeGet(context, 64) |
| 1141 | : mlirIntegerTypeSignedGet(context, 64); |
| 1142 | } else if (view.itemsize == 1) { |
| 1143 | // i8 |
| 1144 | bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) |
| 1145 | : mlirIntegerTypeSignedGet(context, 8); |
| 1146 | } else if (view.itemsize == 2) { |
| 1147 | // i16 |
| 1148 | bulkLoadElementType = signless |
| 1149 | ? mlirIntegerTypeGet(context, 16) |
| 1150 | : mlirIntegerTypeSignedGet(context, 16); |
| 1151 | } |
| 1152 | } else if (isUnsignedIntegerFormat(format)) { |
| 1153 | if (view.itemsize == 4) { |
| 1154 | // unsigned i32 |
| 1155 | bulkLoadElementType = signless |
| 1156 | ? mlirIntegerTypeGet(context, 32) |
| 1157 | : mlirIntegerTypeUnsignedGet(context, 32); |
| 1158 | } else if (view.itemsize == 8) { |
| 1159 | // unsigned i64 |
| 1160 | bulkLoadElementType = signless |
| 1161 | ? mlirIntegerTypeGet(context, 64) |
| 1162 | : mlirIntegerTypeUnsignedGet(context, 64); |
| 1163 | } else if (view.itemsize == 1) { |
| 1164 | // i8 |
| 1165 | bulkLoadElementType = signless |
| 1166 | ? mlirIntegerTypeGet(context, 8) |
| 1167 | : mlirIntegerTypeUnsignedGet(context, 8); |
| 1168 | } else if (view.itemsize == 2) { |
| 1169 | // i16 |
| 1170 | bulkLoadElementType = signless |
| 1171 | ? mlirIntegerTypeGet(context, 16) |
| 1172 | : mlirIntegerTypeUnsignedGet(context, 16); |
| 1173 | } |
| 1174 | } |
| 1175 | if (!bulkLoadElementType) { |
| 1176 | throw std::invalid_argument( |
| 1177 | std::string("unimplemented array format conversion from format: " ) + |
| 1178 | std::string(format)); |
| 1179 | } |
| 1180 | } |
| 1181 | |
| 1182 | MlirType type = getShapedType(bulkLoadElementType, explicitShape, view); |
| 1183 | return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf); |
| 1184 | } |
| 1185 | |
| 1186 | // There is a complication for boolean numpy arrays, as numpy represents |
| 1187 | // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 |
| 1188 | // booleans per byte. |
| 1189 | static MlirAttribute getBitpackedAttributeFromBooleanBuffer( |
| 1190 | Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape, |
| 1191 | MlirContext &context) { |
| 1192 | if (llvm::endianness::native != llvm::endianness::little) { |
| 1193 | // Given we have no good way of testing the behavior on big-endian |
| 1194 | // systems we will throw |
| 1195 | throw nb::type_error("Constructing a bit-packed MLIR attribute is " |
| 1196 | "unsupported on big-endian systems" ); |
| 1197 | } |
| 1198 | nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> unpackedArray( |
| 1199 | /*data=*/static_cast<uint8_t *>(view.buf), |
| 1200 | /*shape=*/{static_cast<size_t>(view.len)}); |
| 1201 | |
| 1202 | nb::module_ numpy = nb::module_::import_("numpy" ); |
| 1203 | nb::object packbitsFunc = numpy.attr("packbits" ); |
| 1204 | nb::object packedBooleans = |
| 1205 | packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little" ); |
| 1206 | nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request(); |
| 1207 | |
| 1208 | MlirType bitpackedType = |
| 1209 | getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view); |
| 1210 | assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8" ); |
| 1211 | // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of |
| 1212 | // packedBooleans, hence the MlirAttribute will remain valid even when |
| 1213 | // packedBooleans get reclaimed by the end of the function. |
| 1214 | return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size, |
| 1215 | pythonBuffer.ptr); |
| 1216 | } |
| 1217 | |
| 1218 | // This does the opposite transformation of |
| 1219 | // `getBitpackedAttributeFromBooleanBuffer` |
| 1220 | std::unique_ptr<nb_buffer_info> getBooleanBufferFromBitpackedAttribute() { |
| 1221 | if (llvm::endianness::native != llvm::endianness::little) { |
| 1222 | // Given we have no good way of testing the behavior on big-endian |
| 1223 | // systems we will throw |
| 1224 | throw nb::type_error("Constructing a numpy array from a MLIR attribute " |
| 1225 | "is unsupported on big-endian systems" ); |
| 1226 | } |
| 1227 | |
| 1228 | int64_t numBooleans = mlirElementsAttrGetNumElements(*this); |
| 1229 | int64_t numBitpackedBytes = llvm::divideCeil(Numerator: numBooleans, Denominator: 8); |
| 1230 | uint8_t *bitpackedData = static_cast<uint8_t *>( |
| 1231 | const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); |
| 1232 | nb::ndarray<uint8_t, nb::numpy, nb::ndim<1>, nb::c_contig> packedArray( |
| 1233 | /*data=*/bitpackedData, |
| 1234 | /*shape=*/{static_cast<size_t>(numBitpackedBytes)}); |
| 1235 | |
| 1236 | nb::module_ numpy = nb::module_::import_("numpy" ); |
| 1237 | nb::object unpackbitsFunc = numpy.attr("unpackbits" ); |
| 1238 | nb::object equalFunc = numpy.attr("equal" ); |
| 1239 | nb::object reshapeFunc = numpy.attr("reshape" ); |
| 1240 | nb::object unpackedBooleans = |
| 1241 | unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little" ); |
| 1242 | |
| 1243 | // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array. |
| 1244 | // We need to: |
| 1245 | // 1. Slice away the padded bits |
| 1246 | // 2. Make the boolean array have the correct shape |
| 1247 | // 3. Convert the array to a boolean array |
| 1248 | unpackedBooleans = unpackedBooleans[nb::slice( |
| 1249 | nb::int_(0), nb::int_(numBooleans), nb::int_(1))]; |
| 1250 | unpackedBooleans = equalFunc(unpackedBooleans, 1); |
| 1251 | |
| 1252 | MlirType shapedType = mlirAttributeGetType(*this); |
| 1253 | intptr_t rank = mlirShapedTypeGetRank(shapedType); |
| 1254 | std::vector<intptr_t> shape(rank); |
| 1255 | for (intptr_t i = 0; i < rank; ++i) { |
| 1256 | shape[i] = mlirShapedTypeGetDimSize(shapedType, i); |
| 1257 | } |
| 1258 | unpackedBooleans = reshapeFunc(unpackedBooleans, shape); |
| 1259 | |
| 1260 | // Make sure the returned nb::buffer_view claims ownership of the data in |
| 1261 | // `pythonBuffer` so it remains valid when Python reads it |
| 1262 | nb_buffer pythonBuffer = nb::cast<nb_buffer>(unpackedBooleans); |
| 1263 | return std::make_unique<nb_buffer_info>(args: pythonBuffer.request()); |
| 1264 | } |
| 1265 | |
| 1266 | template <typename Type> |
| 1267 | std::unique_ptr<nb_buffer_info> |
| 1268 | bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) { |
| 1269 | intptr_t rank = mlirShapedTypeGetRank(shapedType); |
| 1270 | // Prepare the data for the buffer_info. |
| 1271 | // Buffer is configured for read-only access below. |
| 1272 | Type *data = static_cast<Type *>( |
| 1273 | const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); |
| 1274 | // Prepare the shape for the buffer_info. |
| 1275 | SmallVector<intptr_t, 4> shape; |
| 1276 | for (intptr_t i = 0; i < rank; ++i) |
| 1277 | shape.push_back(Elt: mlirShapedTypeGetDimSize(shapedType, i)); |
| 1278 | // Prepare the strides for the buffer_info. |
| 1279 | SmallVector<intptr_t, 4> strides; |
| 1280 | if (mlirDenseElementsAttrIsSplat(*this)) { |
| 1281 | // Splats are special, only the single value is stored. |
| 1282 | strides.assign(NumElts: rank, Elt: 0); |
| 1283 | } else { |
| 1284 | for (intptr_t i = 1; i < rank; ++i) { |
| 1285 | intptr_t strideFactor = 1; |
| 1286 | for (intptr_t j = i; j < rank; ++j) |
| 1287 | strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); |
| 1288 | strides.push_back(Elt: sizeof(Type) * strideFactor); |
| 1289 | } |
| 1290 | strides.push_back(Elt: sizeof(Type)); |
| 1291 | } |
| 1292 | const char *format; |
| 1293 | if (explicitFormat) { |
| 1294 | format = explicitFormat; |
| 1295 | } else { |
| 1296 | format = nb_format_descriptor<Type>::format(); |
| 1297 | } |
| 1298 | return std::make_unique<nb_buffer_info>( |
| 1299 | data, sizeof(Type), format, rank, std::move(shape), std::move(strides), |
| 1300 | /*readonly=*/true); |
| 1301 | } |
| 1302 | }; // namespace |
| 1303 | |
// Type slots for PyDenseElementsAttribute. When compiled against Python 3.9
// or newer the table carries the buffer-protocol hooks (getbuffer /
// releasebuffer); for older versions it holds only the terminating
// `{0, nullptr}` entry and the hooks are installed elsewhere. How this table
// is consumed is not visible in this part of the file.
PyType_Slot PyDenseElementsAttribute::slots[] = {
// Python 3.8 doesn't allow setting the buffer protocol slots from a type spec.
#if PY_VERSION_HEX >= 0x03090000
    {Py_bf_getbuffer,
     reinterpret_cast<void *>(PyDenseElementsAttribute::bf_getbuffer)},
    {Py_bf_releasebuffer,
     reinterpret_cast<void *>(PyDenseElementsAttribute::bf_releasebuffer)},
#endif
    // Sentinel entry marking the end of the table.
    {0, nullptr},
};
| 1314 | |
| 1315 | /*static*/ int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj, |
| 1316 | Py_buffer *view, |
| 1317 | int flags) { |
| 1318 | view->obj = nullptr; |
| 1319 | std::unique_ptr<nb_buffer_info> info; |
| 1320 | try { |
| 1321 | auto *attr = nb::cast<PyDenseElementsAttribute *>(nb::handle(obj)); |
| 1322 | info = attr->accessBuffer(); |
| 1323 | } catch (nb::python_error &e) { |
| 1324 | e.restore(); |
| 1325 | nb::chain_error(PyExc_BufferError, "Error converting attribute to buffer" ); |
| 1326 | return -1; |
| 1327 | } |
| 1328 | view->obj = obj; |
| 1329 | view->ndim = 1; |
| 1330 | view->buf = info->ptr; |
| 1331 | view->itemsize = info->itemsize; |
| 1332 | view->len = info->itemsize; |
| 1333 | for (auto s : info->shape) { |
| 1334 | view->len *= s; |
| 1335 | } |
| 1336 | view->readonly = info->readonly; |
| 1337 | if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { |
| 1338 | view->format = const_cast<char *>(info->format); |
| 1339 | } |
| 1340 | if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { |
| 1341 | view->ndim = static_cast<int>(info->ndim); |
| 1342 | view->strides = info->strides.data(); |
| 1343 | view->shape = info->shape.data(); |
| 1344 | } |
| 1345 | view->suboffsets = nullptr; |
| 1346 | view->internal = info.release(); |
| 1347 | Py_INCREF(obj); |
| 1348 | return 0; |
| 1349 | } |
| 1350 | |
| 1351 | /*static*/ void PyDenseElementsAttribute::bf_releasebuffer(PyObject *, |
| 1352 | Py_buffer *view) { |
| 1353 | delete reinterpret_cast<nb_buffer_info *>(view->internal); |
| 1354 | } |
| 1355 | |
| 1356 | /// Refinement of the PyDenseElementsAttribute for attributes containing |
| 1357 | /// integer (and boolean) values. Supports element access. |
| 1358 | class PyDenseIntElementsAttribute |
| 1359 | : public PyConcreteAttribute<PyDenseIntElementsAttribute, |
| 1360 | PyDenseElementsAttribute> { |
| 1361 | public: |
| 1362 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; |
| 1363 | static constexpr const char *pyClassName = "DenseIntElementsAttr" ; |
| 1364 | using PyConcreteAttribute::PyConcreteAttribute; |
| 1365 | |
| 1366 | /// Returns the element at the given linear position. Asserts if the index |
| 1367 | /// is out of range. |
| 1368 | nb::object dunderGetItem(intptr_t pos) { |
| 1369 | if (pos < 0 || pos >= dunderLen()) { |
| 1370 | throw nb::index_error("attempt to access out of bounds element" ); |
| 1371 | } |
| 1372 | |
| 1373 | MlirType type = mlirAttributeGetType(*this); |
| 1374 | type = mlirShapedTypeGetElementType(type); |
| 1375 | // Index type can also appear as a DenseIntElementsAttr and therefore can be |
| 1376 | // casted to integer. |
| 1377 | assert(mlirTypeIsAInteger(type) || |
| 1378 | mlirTypeIsAIndex(type) && "expected integer/index element type in " |
| 1379 | "dense int elements attribute" ); |
| 1380 | // Dispatch element extraction to an appropriate C function based on the |
| 1381 | // elemental type of the attribute. nb::int_ is implicitly constructible |
| 1382 | // from any C++ integral type and handles bitwidth correctly. |
| 1383 | // TODO: consider caching the type properties in the constructor to avoid |
| 1384 | // querying them on each element access. |
| 1385 | if (mlirTypeIsAIndex(type)) { |
| 1386 | return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos)); |
| 1387 | } |
| 1388 | unsigned width = mlirIntegerTypeGetWidth(type); |
| 1389 | bool isUnsigned = mlirIntegerTypeIsUnsigned(type); |
| 1390 | if (isUnsigned) { |
| 1391 | if (width == 1) { |
| 1392 | return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); |
| 1393 | } |
| 1394 | if (width == 8) { |
| 1395 | return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos)); |
| 1396 | } |
| 1397 | if (width == 16) { |
| 1398 | return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos)); |
| 1399 | } |
| 1400 | if (width == 32) { |
| 1401 | return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos)); |
| 1402 | } |
| 1403 | if (width == 64) { |
| 1404 | return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos)); |
| 1405 | } |
| 1406 | } else { |
| 1407 | if (width == 1) { |
| 1408 | return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); |
| 1409 | } |
| 1410 | if (width == 8) { |
| 1411 | return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos)); |
| 1412 | } |
| 1413 | if (width == 16) { |
| 1414 | return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos)); |
| 1415 | } |
| 1416 | if (width == 32) { |
| 1417 | return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos)); |
| 1418 | } |
| 1419 | if (width == 64) { |
| 1420 | return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos)); |
| 1421 | } |
| 1422 | } |
| 1423 | throw nb::type_error("Unsupported integer type" ); |
| 1424 | } |
| 1425 | |
| 1426 | static void bindDerived(ClassTy &c) { |
| 1427 | c.def("__getitem__" , &PyDenseIntElementsAttribute::dunderGetItem); |
| 1428 | } |
| 1429 | }; |
| 1430 | |
| 1431 | class PyDenseResourceElementsAttribute |
| 1432 | : public PyConcreteAttribute<PyDenseResourceElementsAttribute> { |
| 1433 | public: |
| 1434 | static constexpr IsAFunctionTy isaFunction = |
| 1435 | mlirAttributeIsADenseResourceElements; |
| 1436 | static constexpr const char *pyClassName = "DenseResourceElementsAttr" ; |
| 1437 | using PyConcreteAttribute::PyConcreteAttribute; |
| 1438 | |
| 1439 | static PyDenseResourceElementsAttribute |
| 1440 | getFromBuffer(nb_buffer buffer, const std::string &name, const PyType &type, |
| 1441 | std::optional<size_t> alignment, bool isMutable, |
| 1442 | DefaultingPyMlirContext contextWrapper) { |
| 1443 | if (!mlirTypeIsAShaped(type)) { |
| 1444 | throw std::invalid_argument( |
| 1445 | "Constructing a DenseResourceElementsAttr requires a ShapedType." ); |
| 1446 | } |
| 1447 | |
| 1448 | // Do not request any conversions as we must ensure to use caller |
| 1449 | // managed memory. |
| 1450 | int flags = PyBUF_STRIDES; |
| 1451 | std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>(); |
| 1452 | if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) { |
| 1453 | throw nb::python_error(); |
| 1454 | } |
| 1455 | |
| 1456 | // This scope releaser will only release if we haven't yet transferred |
| 1457 | // ownership. |
| 1458 | auto freeBuffer = llvm::make_scope_exit(F: [&]() { |
| 1459 | if (view) |
| 1460 | PyBuffer_Release(view.get()); |
| 1461 | }); |
| 1462 | |
| 1463 | if (!PyBuffer_IsContiguous(view.get(), 'A')) { |
| 1464 | throw std::invalid_argument("Contiguous buffer is required." ); |
| 1465 | } |
| 1466 | |
| 1467 | // Infer alignment to be the stride of one element if not explicit. |
| 1468 | size_t inferredAlignment; |
| 1469 | if (alignment) |
| 1470 | inferredAlignment = *alignment; |
| 1471 | else |
| 1472 | inferredAlignment = view->strides[view->ndim - 1]; |
| 1473 | |
| 1474 | // The userData is a Py_buffer* that the deleter owns. |
| 1475 | auto deleter = [](void *userData, const void *data, size_t size, |
| 1476 | size_t align) { |
| 1477 | if (!Py_IsInitialized()) |
| 1478 | Py_Initialize(); |
| 1479 | Py_buffer *ownedView = static_cast<Py_buffer *>(userData); |
| 1480 | nb::gil_scoped_acquire gil; |
| 1481 | PyBuffer_Release(ownedView); |
| 1482 | delete ownedView; |
| 1483 | }; |
| 1484 | |
| 1485 | size_t rawBufferSize = view->len; |
| 1486 | MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet( |
| 1487 | type, toMlirStringRef(name), view->buf, rawBufferSize, |
| 1488 | inferredAlignment, isMutable, deleter, static_cast<void *>(view.get())); |
| 1489 | if (mlirAttributeIsNull(attr)) { |
| 1490 | throw std::invalid_argument( |
| 1491 | "DenseResourceElementsAttr could not be constructed from the given " |
| 1492 | "buffer. " |
| 1493 | "This may mean that the Python buffer layout does not match that " |
| 1494 | "MLIR expected layout and is a bug." ); |
| 1495 | } |
| 1496 | view.release(); |
| 1497 | return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr); |
| 1498 | } |
| 1499 | |
| 1500 | static void bindDerived(ClassTy &c) { |
| 1501 | c.def_static( |
| 1502 | "get_from_buffer" , PyDenseResourceElementsAttribute::getFromBuffer, |
| 1503 | nb::arg("array" ), nb::arg("name" ), nb::arg("type" ), |
| 1504 | nb::arg("alignment" ).none() = nb::none(), nb::arg("is_mutable" ) = false, |
| 1505 | nb::arg("context" ).none() = nb::none(), |
| 1506 | kDenseResourceElementsAttrGetFromBufferDocstring); |
| 1507 | } |
| 1508 | }; |
| 1509 | |
| 1510 | class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { |
| 1511 | public: |
| 1512 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; |
| 1513 | static constexpr const char *pyClassName = "DictAttr" ; |
| 1514 | using PyConcreteAttribute::PyConcreteAttribute; |
| 1515 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 1516 | mlirDictionaryAttrGetTypeID; |
| 1517 | |
| 1518 | intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } |
| 1519 | |
| 1520 | bool dunderContains(const std::string &name) { |
| 1521 | return !mlirAttributeIsNull( |
| 1522 | mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name))); |
| 1523 | } |
| 1524 | |
| 1525 | static void bindDerived(ClassTy &c) { |
| 1526 | c.def("__contains__" , &PyDictAttribute::dunderContains); |
| 1527 | c.def("__len__" , &PyDictAttribute::dunderLen); |
| 1528 | c.def_static( |
| 1529 | "get" , |
| 1530 | [](nb::dict attributes, DefaultingPyMlirContext context) { |
| 1531 | SmallVector<MlirNamedAttribute> mlirNamedAttributes; |
| 1532 | mlirNamedAttributes.reserve(attributes.size()); |
| 1533 | for (std::pair<nb::handle, nb::handle> it : attributes) { |
| 1534 | auto &mlirAttr = nb::cast<PyAttribute &>(it.second); |
| 1535 | auto name = nb::cast<std::string>(it.first); |
| 1536 | mlirNamedAttributes.push_back(mlirNamedAttributeGet( |
| 1537 | mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), |
| 1538 | toMlirStringRef(name)), |
| 1539 | mlirAttr)); |
| 1540 | } |
| 1541 | MlirAttribute attr = |
| 1542 | mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), |
| 1543 | mlirNamedAttributes.data()); |
| 1544 | return PyDictAttribute(context->getRef(), attr); |
| 1545 | }, |
| 1546 | nb::arg("value" ) = nb::dict(), nb::arg("context" ).none() = nb::none(), |
| 1547 | "Gets an uniqued dict attribute" ); |
| 1548 | c.def("__getitem__" , [](PyDictAttribute &self, const std::string &name) { |
| 1549 | MlirAttribute attr = |
| 1550 | mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); |
| 1551 | if (mlirAttributeIsNull(attr)) |
| 1552 | throw nb::key_error("attempt to access a non-existent attribute" ); |
| 1553 | return attr; |
| 1554 | }); |
| 1555 | c.def("__getitem__" , [](PyDictAttribute &self, intptr_t index) { |
| 1556 | if (index < 0 || index >= self.dunderLen()) { |
| 1557 | throw nb::index_error("attempt to access out of bounds attribute" ); |
| 1558 | } |
| 1559 | MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); |
| 1560 | return PyNamedAttribute( |
| 1561 | namedAttr.attribute, |
| 1562 | std::string(mlirIdentifierStr(namedAttr.name).data)); |
| 1563 | }); |
| 1564 | } |
| 1565 | }; |
| 1566 | |
| 1567 | /// Refinement of PyDenseElementsAttribute for attributes containing |
| 1568 | /// floating-point values. Supports element access. |
| 1569 | class PyDenseFPElementsAttribute |
| 1570 | : public PyConcreteAttribute<PyDenseFPElementsAttribute, |
| 1571 | PyDenseElementsAttribute> { |
| 1572 | public: |
| 1573 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; |
| 1574 | static constexpr const char *pyClassName = "DenseFPElementsAttr" ; |
| 1575 | using PyConcreteAttribute::PyConcreteAttribute; |
| 1576 | |
| 1577 | nb::float_ dunderGetItem(intptr_t pos) { |
| 1578 | if (pos < 0 || pos >= dunderLen()) { |
| 1579 | throw nb::index_error("attempt to access out of bounds element" ); |
| 1580 | } |
| 1581 | |
| 1582 | MlirType type = mlirAttributeGetType(*this); |
| 1583 | type = mlirShapedTypeGetElementType(type); |
| 1584 | // Dispatch element extraction to an appropriate C function based on the |
| 1585 | // elemental type of the attribute. nb::float_ is implicitly constructible |
| 1586 | // from float and double. |
| 1587 | // TODO: consider caching the type properties in the constructor to avoid |
| 1588 | // querying them on each element access. |
| 1589 | if (mlirTypeIsAF32(type)) { |
| 1590 | return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos)); |
| 1591 | } |
| 1592 | if (mlirTypeIsAF64(type)) { |
| 1593 | return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos)); |
| 1594 | } |
| 1595 | throw nb::type_error("Unsupported floating-point type" ); |
| 1596 | } |
| 1597 | |
| 1598 | static void bindDerived(ClassTy &c) { |
| 1599 | c.def("__getitem__" , &PyDenseFPElementsAttribute::dunderGetItem); |
| 1600 | } |
| 1601 | }; |
| 1602 | |
| 1603 | class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { |
| 1604 | public: |
| 1605 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; |
| 1606 | static constexpr const char *pyClassName = "TypeAttr" ; |
| 1607 | using PyConcreteAttribute::PyConcreteAttribute; |
| 1608 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 1609 | mlirTypeAttrGetTypeID; |
| 1610 | |
| 1611 | static void bindDerived(ClassTy &c) { |
| 1612 | c.def_static( |
| 1613 | "get" , |
| 1614 | [](PyType value, DefaultingPyMlirContext context) { |
| 1615 | MlirAttribute attr = mlirTypeAttrGet(value.get()); |
| 1616 | return PyTypeAttribute(context->getRef(), attr); |
| 1617 | }, |
| 1618 | nb::arg("value" ), nb::arg("context" ).none() = nb::none(), |
| 1619 | "Gets a uniqued Type attribute" ); |
| 1620 | c.def_prop_ro("value" , [](PyTypeAttribute &self) { |
| 1621 | return mlirTypeAttrGetValue(self.get()); |
| 1622 | }); |
| 1623 | } |
| 1624 | }; |
| 1625 | |
| 1626 | /// Unit Attribute subclass. Unit attributes don't have values. |
| 1627 | class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { |
| 1628 | public: |
| 1629 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; |
| 1630 | static constexpr const char *pyClassName = "UnitAttr" ; |
| 1631 | using PyConcreteAttribute::PyConcreteAttribute; |
| 1632 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 1633 | mlirUnitAttrGetTypeID; |
| 1634 | |
| 1635 | static void bindDerived(ClassTy &c) { |
| 1636 | c.def_static( |
| 1637 | "get" , |
| 1638 | [](DefaultingPyMlirContext context) { |
| 1639 | return PyUnitAttribute(context->getRef(), |
| 1640 | mlirUnitAttrGet(context->get())); |
| 1641 | }, |
| 1642 | nb::arg("context" ).none() = nb::none(), "Create a Unit attribute." ); |
| 1643 | } |
| 1644 | }; |
| 1645 | |
| 1646 | /// Strided layout attribute subclass. |
| 1647 | class PyStridedLayoutAttribute |
| 1648 | : public PyConcreteAttribute<PyStridedLayoutAttribute> { |
| 1649 | public: |
| 1650 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout; |
| 1651 | static constexpr const char *pyClassName = "StridedLayoutAttr" ; |
| 1652 | using PyConcreteAttribute::PyConcreteAttribute; |
| 1653 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
| 1654 | mlirStridedLayoutAttrGetTypeID; |
| 1655 | |
| 1656 | static void bindDerived(ClassTy &c) { |
| 1657 | c.def_static( |
| 1658 | "get" , |
| 1659 | [](int64_t offset, const std::vector<int64_t> strides, |
| 1660 | DefaultingPyMlirContext ctx) { |
| 1661 | MlirAttribute attr = mlirStridedLayoutAttrGet( |
| 1662 | ctx->get(), offset, strides.size(), strides.data()); |
| 1663 | return PyStridedLayoutAttribute(ctx->getRef(), attr); |
| 1664 | }, |
| 1665 | nb::arg("offset" ), nb::arg("strides" ), |
| 1666 | nb::arg("context" ).none() = nb::none(), |
| 1667 | "Gets a strided layout attribute." ); |
| 1668 | c.def_static( |
| 1669 | "get_fully_dynamic" , |
| 1670 | [](int64_t rank, DefaultingPyMlirContext ctx) { |
| 1671 | auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset(); |
| 1672 | std::vector<int64_t> strides(rank); |
| 1673 | std::fill(strides.begin(), strides.end(), dynamic); |
| 1674 | MlirAttribute attr = mlirStridedLayoutAttrGet( |
| 1675 | ctx->get(), dynamic, strides.size(), strides.data()); |
| 1676 | return PyStridedLayoutAttribute(ctx->getRef(), attr); |
| 1677 | }, |
| 1678 | nb::arg("rank" ), nb::arg("context" ).none() = nb::none(), |
| 1679 | "Gets a strided layout attribute with dynamic offset and strides of " |
| 1680 | "a " |
| 1681 | "given rank." ); |
| 1682 | c.def_prop_ro( |
| 1683 | "offset" , |
| 1684 | [](PyStridedLayoutAttribute &self) { |
| 1685 | return mlirStridedLayoutAttrGetOffset(self); |
| 1686 | }, |
| 1687 | "Returns the value of the float point attribute" ); |
| 1688 | c.def_prop_ro( |
| 1689 | "strides" , |
| 1690 | [](PyStridedLayoutAttribute &self) { |
| 1691 | intptr_t size = mlirStridedLayoutAttrGetNumStrides(self); |
| 1692 | std::vector<int64_t> strides(size); |
| 1693 | for (intptr_t i = 0; i < size; i++) { |
| 1694 | strides[i] = mlirStridedLayoutAttrGetStride(self, i); |
| 1695 | } |
| 1696 | return strides; |
| 1697 | }, |
| 1698 | "Returns the value of the float point attribute" ); |
| 1699 | } |
| 1700 | }; |
| 1701 | |
| 1702 | nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { |
| 1703 | if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute)) |
| 1704 | return nb::cast(PyDenseBoolArrayAttribute(pyAttribute)); |
| 1705 | if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute)) |
| 1706 | return nb::cast(PyDenseI8ArrayAttribute(pyAttribute)); |
| 1707 | if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute)) |
| 1708 | return nb::cast(PyDenseI16ArrayAttribute(pyAttribute)); |
| 1709 | if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute)) |
| 1710 | return nb::cast(PyDenseI32ArrayAttribute(pyAttribute)); |
| 1711 | if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute)) |
| 1712 | return nb::cast(PyDenseI64ArrayAttribute(pyAttribute)); |
| 1713 | if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute)) |
| 1714 | return nb::cast(PyDenseF32ArrayAttribute(pyAttribute)); |
| 1715 | if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute)) |
| 1716 | return nb::cast(PyDenseF64ArrayAttribute(pyAttribute)); |
| 1717 | std::string msg = |
| 1718 | std::string("Can't cast unknown element type DenseArrayAttr (" ) + |
| 1719 | nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")" ; |
| 1720 | throw nb::type_error(msg.c_str()); |
| 1721 | } |
| 1722 | |
| 1723 | nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { |
| 1724 | if (PyDenseFPElementsAttribute::isaFunction(pyAttribute)) |
| 1725 | return nb::cast(PyDenseFPElementsAttribute(pyAttribute)); |
| 1726 | if (PyDenseIntElementsAttribute::isaFunction(pyAttribute)) |
| 1727 | return nb::cast(PyDenseIntElementsAttribute(pyAttribute)); |
| 1728 | std::string msg = |
| 1729 | std::string( |
| 1730 | "Can't cast unknown element type DenseIntOrFPElementsAttr (" ) + |
| 1731 | nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")" ; |
| 1732 | throw nb::type_error(msg.c_str()); |
| 1733 | } |
| 1734 | |
| 1735 | nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { |
| 1736 | if (PyBoolAttribute::isaFunction(pyAttribute)) |
| 1737 | return nb::cast(PyBoolAttribute(pyAttribute)); |
| 1738 | if (PyIntegerAttribute::isaFunction(pyAttribute)) |
| 1739 | return nb::cast(PyIntegerAttribute(pyAttribute)); |
| 1740 | std::string msg = |
| 1741 | std::string("Can't cast unknown element type DenseArrayAttr (" ) + |
| 1742 | nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + ")" ; |
| 1743 | throw nb::type_error(msg.c_str()); |
| 1744 | } |
| 1745 | |
| 1746 | nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { |
| 1747 | if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute)) |
| 1748 | return nb::cast(PyFlatSymbolRefAttribute(pyAttribute)); |
| 1749 | if (PySymbolRefAttribute::isaFunction(pyAttribute)) |
| 1750 | return nb::cast(PySymbolRefAttribute(pyAttribute)); |
| 1751 | std::string msg = std::string("Can't cast unknown SymbolRef attribute (" ) + |
| 1752 | nb::cast<std::string>(nb::repr(nb::cast(pyAttribute))) + |
| 1753 | ")" ; |
| 1754 | throw nb::type_error(msg.c_str()); |
| 1755 | } |
| 1756 | |
| 1757 | } // namespace |
| 1758 | |
| 1759 | void mlir::python::populateIRAttributes(nb::module_ &m) { |
| 1760 | PyAffineMapAttribute::bind(m); |
| 1761 | PyDenseBoolArrayAttribute::bind(m); |
| 1762 | PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m); |
| 1763 | PyDenseI8ArrayAttribute::bind(m); |
| 1764 | PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m); |
| 1765 | PyDenseI16ArrayAttribute::bind(m); |
| 1766 | PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m); |
| 1767 | PyDenseI32ArrayAttribute::bind(m); |
| 1768 | PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m); |
| 1769 | PyDenseI64ArrayAttribute::bind(m); |
| 1770 | PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m); |
| 1771 | PyDenseF32ArrayAttribute::bind(m); |
| 1772 | PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m); |
| 1773 | PyDenseF64ArrayAttribute::bind(m); |
| 1774 | PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m); |
| 1775 | PyGlobals::get().registerTypeCaster( |
| 1776 | mlirDenseArrayAttrGetTypeID(), |
| 1777 | nb::cast<nb::callable>(nb::cpp_function(denseArrayAttributeCaster))); |
| 1778 | |
| 1779 | PyArrayAttribute::bind(m); |
| 1780 | PyArrayAttribute::PyArrayAttributeIterator::bind(m); |
| 1781 | PyBoolAttribute::bind(m); |
| 1782 | PyDenseElementsAttribute::bind(m, PyDenseElementsAttribute::slots); |
| 1783 | PyDenseFPElementsAttribute::bind(m); |
| 1784 | PyDenseIntElementsAttribute::bind(m); |
| 1785 | PyGlobals::get().registerTypeCaster( |
| 1786 | mlirDenseIntOrFPElementsAttrGetTypeID(), |
| 1787 | nb::cast<nb::callable>( |
| 1788 | nb::cpp_function(denseIntOrFPElementsAttributeCaster))); |
| 1789 | PyDenseResourceElementsAttribute::bind(m); |
| 1790 | |
| 1791 | PyDictAttribute::bind(m); |
| 1792 | PySymbolRefAttribute::bind(m); |
| 1793 | PyGlobals::get().registerTypeCaster( |
| 1794 | mlirSymbolRefAttrGetTypeID(), |
| 1795 | nb::cast<nb::callable>( |
| 1796 | nb::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster))); |
| 1797 | |
| 1798 | PyFlatSymbolRefAttribute::bind(m); |
| 1799 | PyOpaqueAttribute::bind(m); |
| 1800 | PyFloatAttribute::bind(m); |
| 1801 | PyIntegerAttribute::bind(m); |
| 1802 | PyIntegerSetAttribute::bind(m); |
| 1803 | PyStringAttribute::bind(m); |
| 1804 | PyTypeAttribute::bind(m); |
| 1805 | PyGlobals::get().registerTypeCaster( |
| 1806 | mlirIntegerAttrGetTypeID(), |
| 1807 | nb::cast<nb::callable>(nb::cpp_function(integerOrBoolAttributeCaster))); |
| 1808 | PyUnitAttribute::bind(m); |
| 1809 | |
| 1810 | PyStridedLayoutAttribute::bind(m); |
| 1811 | } |
| 1812 | |