1 | //===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include <optional> |
10 | #include <string_view> |
11 | #include <utility> |
12 | |
13 | #include "IRModule.h" |
14 | |
15 | #include "PybindUtils.h" |
16 | |
17 | #include "llvm/ADT/ScopeExit.h" |
18 | |
19 | #include "mlir-c/BuiltinAttributes.h" |
20 | #include "mlir-c/BuiltinTypes.h" |
21 | #include "mlir/Bindings/Python/PybindAdaptors.h" |
22 | |
23 | namespace py = pybind11; |
24 | using namespace mlir; |
25 | using namespace mlir::python; |
26 | |
27 | using llvm::SmallVector; |
28 | |
29 | //------------------------------------------------------------------------------ |
30 | // Docstrings (trivial, non-duplicated docstrings are included inline). |
31 | //------------------------------------------------------------------------------ |
32 | |
// Python docstring for building a DenseElementsAttr from a Python buffer or
// array. NOTE(review): the binding that attaches this text is outside this
// view; presumably it is the `DenseElementsAttr.get` static method.
static const char kDenseElementsAttrGetDocstring[] =
    R"(Gets a DenseElementsAttr from a Python buffer or array.

When `type` is not provided, then some limited type inferencing is done based
on the buffer format. Support presently exists for 8/16/32/64 signed and
unsigned integers and float16/float32/float64. DenseElementsAttrs of these
types can also be converted back to a corresponding buffer.

For conversions outside of these types, a `type=` must be explicitly provided
and the buffer contents must be bit-castable to the MLIR internal
representation:

* Integer types (except for i1): the buffer must be byte aligned to the
next byte boundary.
* Floating point types: Must be bit-castable to the given floating point
size.
* i1 (bool): Bit packed into 8bit words where the bit pattern matches a
row major ordering. An arbitrary Numpy `bool_` array can be bit packed to
this specification with: `np.packbits(ary, axis=None, bitorder='little')`.

If a single element buffer is passed (or for i1, a single byte with value 0
or 255), then a splat will be created.

Args:
array: The array or buffer to convert.
signless: If inferring an appropriate MLIR type, use signless types for
integers (defaults True).
type: Skips inference of the MLIR element type and uses this instead. The
storage size must be consistent with the actual contents of the buffer.
shape: Overrides the shape of the buffer when constructing the MLIR
shaped type. This is needed when the physical and logical shape differ (as
for i1).
context: Explicit context, if not from context manager.

Returns:
DenseElementsAttr on success.

Raises:
ValueError: If the type of the buffer or array cannot be matched to an MLIR
type or if the buffer does not meet expectations.
)";
74 | |
// Python docstring for building a DenseResourceElementsAttr from a Python
// buffer. NOTE(review): the binding that attaches this text is outside this
// view; presumably a `get_from_buffer`-style static method.
static const char kDenseResourceElementsAttrGetFromBufferDocstring[] =
    R"(Gets a DenseResourceElementsAttr from a Python buffer or array.

This function does minimal validation or massaging of the data, and it is
up to the caller to ensure that the buffer meets the characteristics
implied by the shape.

The backing buffer and any user objects will be retained for the lifetime
of the resource blob. This is typically bounded to the context but the
resource can have a shorter lifespan depending on how it is used in
subsequent processing.

Args:
buffer: The array or buffer to convert.
name: Name to provide to the resource (may be changed upon collision).
type: The explicit ShapedType to construct the attribute with.
context: Explicit context, if not from context manager.

Returns:
DenseResourceElementsAttr on success.

Raises:
ValueError: If the type of the buffer or array cannot be matched to an MLIR
type or if the buffer does not meet expectations.
)";
100 | |
101 | namespace { |
102 | |
103 | static MlirStringRef toMlirStringRef(const std::string &s) { |
104 | return mlirStringRefCreate(s.data(), s.size()); |
105 | } |
106 | |
107 | class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> { |
108 | public: |
109 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; |
110 | static constexpr const char *pyClassName = "AffineMapAttr" ; |
111 | using PyConcreteAttribute::PyConcreteAttribute; |
112 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
113 | mlirAffineMapAttrGetTypeID; |
114 | |
115 | static void bindDerived(ClassTy &c) { |
116 | c.def_static( |
117 | "get" , |
118 | [](PyAffineMap &affineMap) { |
119 | MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); |
120 | return PyAffineMapAttribute(affineMap.getContext(), attr); |
121 | }, |
122 | py::arg("affine_map" ), "Gets an attribute wrapping an AffineMap." ); |
123 | } |
124 | }; |
125 | |
126 | template <typename T> |
127 | static T pyTryCast(py::handle object) { |
128 | try { |
129 | return object.cast<T>(); |
130 | } catch (py::cast_error &err) { |
131 | std::string msg = |
132 | std::string( |
133 | "Invalid attribute when attempting to create an ArrayAttribute (" ) + |
134 | err.what() + ")" ; |
135 | throw py::cast_error(msg); |
136 | } catch (py::reference_cast_error &err) { |
137 | std::string msg = std::string("Invalid attribute (None?) when attempting " |
138 | "to create an ArrayAttribute (" ) + |
139 | err.what() + ")" ; |
140 | throw py::cast_error(msg); |
141 | } |
142 | } |
143 | |
144 | /// A python-wrapped dense array attribute with an element type and a derived |
145 | /// implementation class. |
146 | template <typename EltTy, typename DerivedT> |
147 | class PyDenseArrayAttribute : public PyConcreteAttribute<DerivedT> { |
148 | public: |
149 | using PyConcreteAttribute<DerivedT>::PyConcreteAttribute; |
150 | |
151 | /// Iterator over the integer elements of a dense array. |
152 | class PyDenseArrayIterator { |
153 | public: |
154 | PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {} |
155 | |
156 | /// Return a copy of the iterator. |
157 | PyDenseArrayIterator dunderIter() { return *this; } |
158 | |
159 | /// Return the next element. |
160 | EltTy dunderNext() { |
161 | // Throw if the index has reached the end. |
162 | if (nextIndex >= mlirDenseArrayGetNumElements(attr.get())) |
163 | throw py::stop_iteration(); |
164 | return DerivedT::getElement(attr.get(), nextIndex++); |
165 | } |
166 | |
167 | /// Bind the iterator class. |
168 | static void bind(py::module &m) { |
169 | py::class_<PyDenseArrayIterator>(m, DerivedT::pyIteratorName, |
170 | py::module_local()) |
171 | .def("__iter__" , &PyDenseArrayIterator::dunderIter) |
172 | .def("__next__" , &PyDenseArrayIterator::dunderNext); |
173 | } |
174 | |
175 | private: |
176 | /// The referenced dense array attribute. |
177 | PyAttribute attr; |
178 | /// The next index to read. |
179 | int nextIndex = 0; |
180 | }; |
181 | |
182 | /// Get the element at the given index. |
183 | EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); } |
184 | |
185 | /// Bind the attribute class. |
186 | static void bindDerived(typename PyConcreteAttribute<DerivedT>::ClassTy &c) { |
187 | // Bind the constructor. |
188 | c.def_static( |
189 | "get" , |
190 | [](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) { |
191 | return getAttribute(values, ctx->getRef()); |
192 | }, |
193 | py::arg("values" ), py::arg("context" ) = py::none(), |
194 | "Gets a uniqued dense array attribute" ); |
195 | // Bind the array methods. |
196 | c.def("__getitem__" , [](DerivedT &arr, intptr_t i) { |
197 | if (i >= mlirDenseArrayGetNumElements(arr)) |
198 | throw py::index_error("DenseArray index out of range" ); |
199 | return arr.getItem(i); |
200 | }); |
201 | c.def("__len__" , [](const DerivedT &arr) { |
202 | return mlirDenseArrayGetNumElements(arr); |
203 | }); |
204 | c.def("__iter__" , |
205 | [](const DerivedT &arr) { return PyDenseArrayIterator(arr); }); |
206 | c.def("__add__" , [](DerivedT &arr, const py::list &) { |
207 | std::vector<EltTy> values; |
208 | intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); |
209 | values.reserve(numOldElements + py::len(extras)); |
210 | for (intptr_t i = 0; i < numOldElements; ++i) |
211 | values.push_back(arr.getItem(i)); |
212 | for (py::handle attr : extras) |
213 | values.push_back(pyTryCast<EltTy>(attr)); |
214 | return getAttribute(values, ctx: arr.getContext()); |
215 | }); |
216 | } |
217 | |
218 | private: |
219 | static DerivedT getAttribute(const std::vector<EltTy> &values, |
220 | PyMlirContextRef ctx) { |
221 | if constexpr (std::is_same_v<EltTy, bool>) { |
222 | std::vector<int> intValues(values.begin(), values.end()); |
223 | MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(), |
224 | intValues.data()); |
225 | return DerivedT(ctx, attr); |
226 | } else { |
227 | MlirAttribute attr = |
228 | DerivedT::getAttribute(ctx->get(), values.size(), values.data()); |
229 | return DerivedT(ctx, attr); |
230 | } |
231 | } |
232 | }; |
233 | |
234 | /// Instantiate the python dense array classes. |
235 | struct PyDenseBoolArrayAttribute |
236 | : public PyDenseArrayAttribute<bool, PyDenseBoolArrayAttribute> { |
237 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray; |
238 | static constexpr auto getAttribute = mlirDenseBoolArrayGet; |
239 | static constexpr auto getElement = mlirDenseBoolArrayGetElement; |
240 | static constexpr const char *pyClassName = "DenseBoolArrayAttr" ; |
241 | static constexpr const char *pyIteratorName = "DenseBoolArrayIterator" ; |
242 | using PyDenseArrayAttribute::PyDenseArrayAttribute; |
243 | }; |
244 | struct PyDenseI8ArrayAttribute |
245 | : public PyDenseArrayAttribute<int8_t, PyDenseI8ArrayAttribute> { |
246 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array; |
247 | static constexpr auto getAttribute = mlirDenseI8ArrayGet; |
248 | static constexpr auto getElement = mlirDenseI8ArrayGetElement; |
249 | static constexpr const char *pyClassName = "DenseI8ArrayAttr" ; |
250 | static constexpr const char *pyIteratorName = "DenseI8ArrayIterator" ; |
251 | using PyDenseArrayAttribute::PyDenseArrayAttribute; |
252 | }; |
253 | struct PyDenseI16ArrayAttribute |
254 | : public PyDenseArrayAttribute<int16_t, PyDenseI16ArrayAttribute> { |
255 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array; |
256 | static constexpr auto getAttribute = mlirDenseI16ArrayGet; |
257 | static constexpr auto getElement = mlirDenseI16ArrayGetElement; |
258 | static constexpr const char *pyClassName = "DenseI16ArrayAttr" ; |
259 | static constexpr const char *pyIteratorName = "DenseI16ArrayIterator" ; |
260 | using PyDenseArrayAttribute::PyDenseArrayAttribute; |
261 | }; |
262 | struct PyDenseI32ArrayAttribute |
263 | : public PyDenseArrayAttribute<int32_t, PyDenseI32ArrayAttribute> { |
264 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array; |
265 | static constexpr auto getAttribute = mlirDenseI32ArrayGet; |
266 | static constexpr auto getElement = mlirDenseI32ArrayGetElement; |
267 | static constexpr const char *pyClassName = "DenseI32ArrayAttr" ; |
268 | static constexpr const char *pyIteratorName = "DenseI32ArrayIterator" ; |
269 | using PyDenseArrayAttribute::PyDenseArrayAttribute; |
270 | }; |
271 | struct PyDenseI64ArrayAttribute |
272 | : public PyDenseArrayAttribute<int64_t, PyDenseI64ArrayAttribute> { |
273 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI64Array; |
274 | static constexpr auto getAttribute = mlirDenseI64ArrayGet; |
275 | static constexpr auto getElement = mlirDenseI64ArrayGetElement; |
276 | static constexpr const char *pyClassName = "DenseI64ArrayAttr" ; |
277 | static constexpr const char *pyIteratorName = "DenseI64ArrayIterator" ; |
278 | using PyDenseArrayAttribute::PyDenseArrayAttribute; |
279 | }; |
280 | struct PyDenseF32ArrayAttribute |
281 | : public PyDenseArrayAttribute<float, PyDenseF32ArrayAttribute> { |
282 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array; |
283 | static constexpr auto getAttribute = mlirDenseF32ArrayGet; |
284 | static constexpr auto getElement = mlirDenseF32ArrayGetElement; |
285 | static constexpr const char *pyClassName = "DenseF32ArrayAttr" ; |
286 | static constexpr const char *pyIteratorName = "DenseF32ArrayIterator" ; |
287 | using PyDenseArrayAttribute::PyDenseArrayAttribute; |
288 | }; |
289 | struct PyDenseF64ArrayAttribute |
290 | : public PyDenseArrayAttribute<double, PyDenseF64ArrayAttribute> { |
291 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array; |
292 | static constexpr auto getAttribute = mlirDenseF64ArrayGet; |
293 | static constexpr auto getElement = mlirDenseF64ArrayGetElement; |
294 | static constexpr const char *pyClassName = "DenseF64ArrayAttr" ; |
295 | static constexpr const char *pyIteratorName = "DenseF64ArrayIterator" ; |
296 | using PyDenseArrayAttribute::PyDenseArrayAttribute; |
297 | }; |
298 | |
299 | class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> { |
300 | public: |
301 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; |
302 | static constexpr const char *pyClassName = "ArrayAttr" ; |
303 | using PyConcreteAttribute::PyConcreteAttribute; |
304 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
305 | mlirArrayAttrGetTypeID; |
306 | |
307 | class PyArrayAttributeIterator { |
308 | public: |
309 | PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {} |
310 | |
311 | PyArrayAttributeIterator &dunderIter() { return *this; } |
312 | |
313 | MlirAttribute dunderNext() { |
314 | // TODO: Throw is an inefficient way to stop iteration. |
315 | if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) |
316 | throw py::stop_iteration(); |
317 | return mlirArrayAttrGetElement(attr.get(), nextIndex++); |
318 | } |
319 | |
320 | static void bind(py::module &m) { |
321 | py::class_<PyArrayAttributeIterator>(m, "ArrayAttributeIterator" , |
322 | py::module_local()) |
323 | .def("__iter__" , &PyArrayAttributeIterator::dunderIter) |
324 | .def("__next__" , &PyArrayAttributeIterator::dunderNext); |
325 | } |
326 | |
327 | private: |
328 | PyAttribute attr; |
329 | int nextIndex = 0; |
330 | }; |
331 | |
332 | MlirAttribute getItem(intptr_t i) { |
333 | return mlirArrayAttrGetElement(*this, i); |
334 | } |
335 | |
336 | static void bindDerived(ClassTy &c) { |
337 | c.def_static( |
338 | "get" , |
339 | [](py::list attributes, DefaultingPyMlirContext context) { |
340 | SmallVector<MlirAttribute> mlirAttributes; |
341 | mlirAttributes.reserve(py::len(attributes)); |
342 | for (auto attribute : attributes) { |
343 | mlirAttributes.push_back(pyTryCast<PyAttribute>(attribute)); |
344 | } |
345 | MlirAttribute attr = mlirArrayAttrGet( |
346 | context->get(), mlirAttributes.size(), mlirAttributes.data()); |
347 | return PyArrayAttribute(context->getRef(), attr); |
348 | }, |
349 | py::arg("attributes" ), py::arg("context" ) = py::none(), |
350 | "Gets a uniqued Array attribute" ); |
351 | c.def("__getitem__" , |
352 | [](PyArrayAttribute &arr, intptr_t i) { |
353 | if (i >= mlirArrayAttrGetNumElements(arr)) |
354 | throw py::index_error("ArrayAttribute index out of range" ); |
355 | return arr.getItem(i); |
356 | }) |
357 | .def("__len__" , |
358 | [](const PyArrayAttribute &arr) { |
359 | return mlirArrayAttrGetNumElements(arr); |
360 | }) |
361 | .def("__iter__" , [](const PyArrayAttribute &arr) { |
362 | return PyArrayAttributeIterator(arr); |
363 | }); |
364 | c.def("__add__" , [](PyArrayAttribute arr, py::list ) { |
365 | std::vector<MlirAttribute> attributes; |
366 | intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); |
367 | attributes.reserve(numOldElements + py::len(extras)); |
368 | for (intptr_t i = 0; i < numOldElements; ++i) |
369 | attributes.push_back(arr.getItem(i)); |
370 | for (py::handle attr : extras) |
371 | attributes.push_back(pyTryCast<PyAttribute>(attr)); |
372 | MlirAttribute arrayAttr = mlirArrayAttrGet( |
373 | arr.getContext()->get(), attributes.size(), attributes.data()); |
374 | return PyArrayAttribute(arr.getContext(), arrayAttr); |
375 | }); |
376 | } |
377 | }; |
378 | |
379 | /// Float Point Attribute subclass - FloatAttr. |
380 | class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { |
381 | public: |
382 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; |
383 | static constexpr const char *pyClassName = "FloatAttr" ; |
384 | using PyConcreteAttribute::PyConcreteAttribute; |
385 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
386 | mlirFloatAttrGetTypeID; |
387 | |
388 | static void bindDerived(ClassTy &c) { |
389 | c.def_static( |
390 | "get" , |
391 | [](PyType &type, double value, DefaultingPyLocation loc) { |
392 | PyMlirContext::ErrorCapture errors(loc->getContext()); |
393 | MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); |
394 | if (mlirAttributeIsNull(attr)) |
395 | throw MLIRError("Invalid attribute" , errors.take()); |
396 | return PyFloatAttribute(type.getContext(), attr); |
397 | }, |
398 | py::arg("type" ), py::arg("value" ), py::arg("loc" ) = py::none(), |
399 | "Gets an uniqued float point attribute associated to a type" ); |
400 | c.def_static( |
401 | "get_f32" , |
402 | [](double value, DefaultingPyMlirContext context) { |
403 | MlirAttribute attr = mlirFloatAttrDoubleGet( |
404 | context->get(), mlirF32TypeGet(context->get()), value); |
405 | return PyFloatAttribute(context->getRef(), attr); |
406 | }, |
407 | py::arg("value" ), py::arg("context" ) = py::none(), |
408 | "Gets an uniqued float point attribute associated to a f32 type" ); |
409 | c.def_static( |
410 | "get_f64" , |
411 | [](double value, DefaultingPyMlirContext context) { |
412 | MlirAttribute attr = mlirFloatAttrDoubleGet( |
413 | context->get(), mlirF64TypeGet(context->get()), value); |
414 | return PyFloatAttribute(context->getRef(), attr); |
415 | }, |
416 | py::arg("value" ), py::arg("context" ) = py::none(), |
417 | "Gets an uniqued float point attribute associated to a f64 type" ); |
418 | c.def_property_readonly("value" , mlirFloatAttrGetValueDouble, |
419 | "Returns the value of the float attribute" ); |
420 | c.def("__float__" , mlirFloatAttrGetValueDouble, |
421 | "Converts the value of the float attribute to a Python float" ); |
422 | } |
423 | }; |
424 | |
425 | /// Integer Attribute subclass - IntegerAttr. |
426 | class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { |
427 | public: |
428 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; |
429 | static constexpr const char *pyClassName = "IntegerAttr" ; |
430 | using PyConcreteAttribute::PyConcreteAttribute; |
431 | |
432 | static void bindDerived(ClassTy &c) { |
433 | c.def_static( |
434 | "get" , |
435 | [](PyType &type, int64_t value) { |
436 | MlirAttribute attr = mlirIntegerAttrGet(type, value); |
437 | return PyIntegerAttribute(type.getContext(), attr); |
438 | }, |
439 | py::arg("type" ), py::arg("value" ), |
440 | "Gets an uniqued integer attribute associated to a type" ); |
441 | c.def_property_readonly("value" , toPyInt, |
442 | "Returns the value of the integer attribute" ); |
443 | c.def("__int__" , toPyInt, |
444 | "Converts the value of the integer attribute to a Python int" ); |
445 | c.def_property_readonly_static("static_typeid" , |
446 | [](py::object & /*class*/) -> MlirTypeID { |
447 | return mlirIntegerAttrGetTypeID(); |
448 | }); |
449 | } |
450 | |
451 | private: |
452 | static py::int_ toPyInt(PyIntegerAttribute &self) { |
453 | MlirType type = mlirAttributeGetType(self); |
454 | if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) |
455 | return mlirIntegerAttrGetValueInt(self); |
456 | if (mlirIntegerTypeIsSigned(type)) |
457 | return mlirIntegerAttrGetValueSInt(self); |
458 | return mlirIntegerAttrGetValueUInt(self); |
459 | } |
460 | }; |
461 | |
462 | /// Bool Attribute subclass - BoolAttr. |
463 | class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { |
464 | public: |
465 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; |
466 | static constexpr const char *pyClassName = "BoolAttr" ; |
467 | using PyConcreteAttribute::PyConcreteAttribute; |
468 | |
469 | static void bindDerived(ClassTy &c) { |
470 | c.def_static( |
471 | "get" , |
472 | [](bool value, DefaultingPyMlirContext context) { |
473 | MlirAttribute attr = mlirBoolAttrGet(context->get(), value); |
474 | return PyBoolAttribute(context->getRef(), attr); |
475 | }, |
476 | py::arg("value" ), py::arg("context" ) = py::none(), |
477 | "Gets an uniqued bool attribute" ); |
478 | c.def_property_readonly("value" , mlirBoolAttrGetValue, |
479 | "Returns the value of the bool attribute" ); |
480 | c.def("__bool__" , mlirBoolAttrGetValue, |
481 | "Converts the value of the bool attribute to a Python bool" ); |
482 | } |
483 | }; |
484 | |
485 | class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> { |
486 | public: |
487 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef; |
488 | static constexpr const char *pyClassName = "SymbolRefAttr" ; |
489 | using PyConcreteAttribute::PyConcreteAttribute; |
490 | |
491 | static MlirAttribute fromList(const std::vector<std::string> &symbols, |
492 | PyMlirContext &context) { |
493 | if (symbols.empty()) |
494 | throw std::runtime_error("SymbolRefAttr must be composed of at least " |
495 | "one symbol." ); |
496 | MlirStringRef rootSymbol = toMlirStringRef(symbols[0]); |
497 | SmallVector<MlirAttribute, 3> referenceAttrs; |
498 | for (size_t i = 1; i < symbols.size(); ++i) { |
499 | referenceAttrs.push_back( |
500 | mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i]))); |
501 | } |
502 | return mlirSymbolRefAttrGet(context.get(), rootSymbol, |
503 | referenceAttrs.size(), referenceAttrs.data()); |
504 | } |
505 | |
506 | static void bindDerived(ClassTy &c) { |
507 | c.def_static( |
508 | "get" , |
509 | [](const std::vector<std::string> &symbols, |
510 | DefaultingPyMlirContext context) { |
511 | return PySymbolRefAttribute::fromList(symbols, context.resolve()); |
512 | }, |
513 | py::arg("symbols" ), py::arg("context" ) = py::none(), |
514 | "Gets a uniqued SymbolRef attribute from a list of symbol names" ); |
515 | c.def_property_readonly( |
516 | "value" , |
517 | [](PySymbolRefAttribute &self) { |
518 | std::vector<std::string> symbols = { |
519 | unwrap(mlirSymbolRefAttrGetRootReference(self)).str()}; |
520 | for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self); |
521 | ++i) |
522 | symbols.push_back( |
523 | unwrap(mlirSymbolRefAttrGetRootReference( |
524 | mlirSymbolRefAttrGetNestedReference(self, i))) |
525 | .str()); |
526 | return symbols; |
527 | }, |
528 | "Returns the value of the SymbolRef attribute as a list[str]" ); |
529 | } |
530 | }; |
531 | |
532 | class PyFlatSymbolRefAttribute |
533 | : public PyConcreteAttribute<PyFlatSymbolRefAttribute> { |
534 | public: |
535 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; |
536 | static constexpr const char *pyClassName = "FlatSymbolRefAttr" ; |
537 | using PyConcreteAttribute::PyConcreteAttribute; |
538 | |
539 | static void bindDerived(ClassTy &c) { |
540 | c.def_static( |
541 | "get" , |
542 | [](std::string value, DefaultingPyMlirContext context) { |
543 | MlirAttribute attr = |
544 | mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); |
545 | return PyFlatSymbolRefAttribute(context->getRef(), attr); |
546 | }, |
547 | py::arg("value" ), py::arg("context" ) = py::none(), |
548 | "Gets a uniqued FlatSymbolRef attribute" ); |
549 | c.def_property_readonly( |
550 | "value" , |
551 | [](PyFlatSymbolRefAttribute &self) { |
552 | MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); |
553 | return py::str(stringRef.data, stringRef.length); |
554 | }, |
555 | "Returns the value of the FlatSymbolRef attribute as a string" ); |
556 | } |
557 | }; |
558 | |
559 | class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> { |
560 | public: |
561 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; |
562 | static constexpr const char *pyClassName = "OpaqueAttr" ; |
563 | using PyConcreteAttribute::PyConcreteAttribute; |
564 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
565 | mlirOpaqueAttrGetTypeID; |
566 | |
567 | static void bindDerived(ClassTy &c) { |
568 | c.def_static( |
569 | "get" , |
570 | [](std::string dialectNamespace, py::buffer buffer, PyType &type, |
571 | DefaultingPyMlirContext context) { |
572 | const py::buffer_info bufferInfo = buffer.request(); |
573 | intptr_t bufferSize = bufferInfo.size; |
574 | MlirAttribute attr = mlirOpaqueAttrGet( |
575 | context->get(), toMlirStringRef(dialectNamespace), bufferSize, |
576 | static_cast<char *>(bufferInfo.ptr), type); |
577 | return PyOpaqueAttribute(context->getRef(), attr); |
578 | }, |
579 | py::arg("dialect_namespace" ), py::arg("buffer" ), py::arg("type" ), |
580 | py::arg("context" ) = py::none(), "Gets an Opaque attribute." ); |
581 | c.def_property_readonly( |
582 | "dialect_namespace" , |
583 | [](PyOpaqueAttribute &self) { |
584 | MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); |
585 | return py::str(stringRef.data, stringRef.length); |
586 | }, |
587 | "Returns the dialect namespace for the Opaque attribute as a string" ); |
588 | c.def_property_readonly( |
589 | "data" , |
590 | [](PyOpaqueAttribute &self) { |
591 | MlirStringRef stringRef = mlirOpaqueAttrGetData(self); |
592 | return py::bytes(stringRef.data, stringRef.length); |
593 | }, |
594 | "Returns the data for the Opaqued attributes as `bytes`" ); |
595 | } |
596 | }; |
597 | |
598 | class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { |
599 | public: |
600 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; |
601 | static constexpr const char *pyClassName = "StringAttr" ; |
602 | using PyConcreteAttribute::PyConcreteAttribute; |
603 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
604 | mlirStringAttrGetTypeID; |
605 | |
606 | static void bindDerived(ClassTy &c) { |
607 | c.def_static( |
608 | "get" , |
609 | [](std::string value, DefaultingPyMlirContext context) { |
610 | MlirAttribute attr = |
611 | mlirStringAttrGet(context->get(), toMlirStringRef(value)); |
612 | return PyStringAttribute(context->getRef(), attr); |
613 | }, |
614 | py::arg("value" ), py::arg("context" ) = py::none(), |
615 | "Gets a uniqued string attribute" ); |
616 | c.def_static( |
617 | "get_typed" , |
618 | [](PyType &type, std::string value) { |
619 | MlirAttribute attr = |
620 | mlirStringAttrTypedGet(type, toMlirStringRef(value)); |
621 | return PyStringAttribute(type.getContext(), attr); |
622 | }, |
623 | py::arg("type" ), py::arg("value" ), |
624 | "Gets a uniqued string attribute associated to a type" ); |
625 | c.def_property_readonly( |
626 | "value" , |
627 | [](PyStringAttribute &self) { |
628 | MlirStringRef stringRef = mlirStringAttrGetValue(self); |
629 | return py::str(stringRef.data, stringRef.length); |
630 | }, |
631 | "Returns the value of the string attribute" ); |
632 | c.def_property_readonly( |
633 | "value_bytes" , |
634 | [](PyStringAttribute &self) { |
635 | MlirStringRef stringRef = mlirStringAttrGetValue(self); |
636 | return py::bytes(stringRef.data, stringRef.length); |
637 | }, |
638 | "Returns the value of the string attribute as `bytes`" ); |
639 | } |
640 | }; |
641 | |
642 | // TODO: Support construction of string elements. |
643 | class PyDenseElementsAttribute |
644 | : public PyConcreteAttribute<PyDenseElementsAttribute> { |
645 | public: |
646 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; |
647 | static constexpr const char *pyClassName = "DenseElementsAttr" ; |
648 | using PyConcreteAttribute::PyConcreteAttribute; |
649 | |
650 | static PyDenseElementsAttribute |
651 | getFromBuffer(py::buffer array, bool signless, |
652 | std::optional<PyType> explicitType, |
653 | std::optional<std::vector<int64_t>> explicitShape, |
654 | DefaultingPyMlirContext contextWrapper) { |
655 | // Request a contiguous view. In exotic cases, this will cause a copy. |
656 | int flags = PyBUF_ND; |
657 | if (!explicitType) { |
658 | flags |= PyBUF_FORMAT; |
659 | } |
660 | Py_buffer view; |
661 | if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) { |
662 | throw py::error_already_set(); |
663 | } |
664 | auto freeBuffer = llvm::make_scope_exit(F: [&]() { PyBuffer_Release(&view); }); |
665 | SmallVector<int64_t> shape; |
666 | if (explicitShape) { |
667 | shape.append(in_start: explicitShape->begin(), in_end: explicitShape->end()); |
668 | } else { |
669 | shape.append(view.shape, view.shape + view.ndim); |
670 | } |
671 | |
672 | MlirAttribute encodingAttr = mlirAttributeGetNull(); |
673 | MlirContext context = contextWrapper->get(); |
674 | |
675 | // Detect format codes that are suitable for bulk loading. This includes |
676 | // all byte aligned integer and floating point types up to 8 bytes. |
677 | // Notably, this excludes, bool (which needs to be bit-packed) and |
678 | // other exotics which do not have a direct representation in the buffer |
679 | // protocol (i.e. complex, etc). |
680 | std::optional<MlirType> bulkLoadElementType; |
681 | if (explicitType) { |
682 | bulkLoadElementType = *explicitType; |
683 | } else { |
684 | std::string_view format(view.format); |
685 | if (format == "f" ) { |
686 | // f32 |
687 | assert(view.itemsize == 4 && "mismatched array itemsize" ); |
688 | bulkLoadElementType = mlirF32TypeGet(context); |
689 | } else if (format == "d" ) { |
690 | // f64 |
691 | assert(view.itemsize == 8 && "mismatched array itemsize" ); |
692 | bulkLoadElementType = mlirF64TypeGet(context); |
693 | } else if (format == "e" ) { |
694 | // f16 |
695 | assert(view.itemsize == 2 && "mismatched array itemsize" ); |
696 | bulkLoadElementType = mlirF16TypeGet(context); |
697 | } else if (isSignedIntegerFormat(format)) { |
698 | if (view.itemsize == 4) { |
699 | // i32 |
700 | bulkLoadElementType = signless |
701 | ? mlirIntegerTypeGet(context, 32) |
702 | : mlirIntegerTypeSignedGet(context, 32); |
703 | } else if (view.itemsize == 8) { |
704 | // i64 |
705 | bulkLoadElementType = signless |
706 | ? mlirIntegerTypeGet(context, 64) |
707 | : mlirIntegerTypeSignedGet(context, 64); |
708 | } else if (view.itemsize == 1) { |
709 | // i8 |
710 | bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) |
711 | : mlirIntegerTypeSignedGet(context, 8); |
712 | } else if (view.itemsize == 2) { |
713 | // i16 |
714 | bulkLoadElementType = signless |
715 | ? mlirIntegerTypeGet(context, 16) |
716 | : mlirIntegerTypeSignedGet(context, 16); |
717 | } |
718 | } else if (isUnsignedIntegerFormat(format)) { |
719 | if (view.itemsize == 4) { |
720 | // unsigned i32 |
721 | bulkLoadElementType = signless |
722 | ? mlirIntegerTypeGet(context, 32) |
723 | : mlirIntegerTypeUnsignedGet(context, 32); |
724 | } else if (view.itemsize == 8) { |
725 | // unsigned i64 |
726 | bulkLoadElementType = signless |
727 | ? mlirIntegerTypeGet(context, 64) |
728 | : mlirIntegerTypeUnsignedGet(context, 64); |
729 | } else if (view.itemsize == 1) { |
730 | // i8 |
731 | bulkLoadElementType = signless |
732 | ? mlirIntegerTypeGet(context, 8) |
733 | : mlirIntegerTypeUnsignedGet(context, 8); |
734 | } else if (view.itemsize == 2) { |
735 | // i16 |
736 | bulkLoadElementType = signless |
737 | ? mlirIntegerTypeGet(context, 16) |
738 | : mlirIntegerTypeUnsignedGet(context, 16); |
739 | } |
740 | } |
741 | if (!bulkLoadElementType) { |
742 | throw std::invalid_argument( |
743 | std::string("unimplemented array format conversion from format: " ) + |
744 | std::string(format)); |
745 | } |
746 | } |
747 | |
748 | MlirType shapedType; |
749 | if (mlirTypeIsAShaped(*bulkLoadElementType)) { |
750 | if (explicitShape) { |
751 | throw std::invalid_argument("Shape can only be specified explicitly " |
752 | "when the type is not a shaped type." ); |
753 | } |
754 | shapedType = *bulkLoadElementType; |
755 | } else { |
756 | shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), |
757 | *bulkLoadElementType, encodingAttr); |
758 | } |
759 | size_t rawBufferSize = view.len; |
760 | MlirAttribute attr = |
761 | mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf); |
762 | if (mlirAttributeIsNull(attr)) { |
763 | throw std::invalid_argument( |
764 | "DenseElementsAttr could not be constructed from the given buffer. " |
765 | "This may mean that the Python buffer layout does not match that " |
766 | "MLIR expected layout and is a bug." ); |
767 | } |
768 | return PyDenseElementsAttribute(contextWrapper->getRef(), attr); |
769 | } |
770 | |
771 | static PyDenseElementsAttribute getSplat(const PyType &shapedType, |
772 | PyAttribute &elementAttr) { |
773 | auto contextWrapper = |
774 | PyMlirContext::forContext(mlirTypeGetContext(shapedType)); |
775 | if (!mlirAttributeIsAInteger(elementAttr) && |
776 | !mlirAttributeIsAFloat(elementAttr)) { |
777 | std::string message = "Illegal element type for DenseElementsAttr: " ; |
778 | message.append(py::repr(py::cast(elementAttr))); |
779 | throw py::value_error(message); |
780 | } |
781 | if (!mlirTypeIsAShaped(shapedType) || |
782 | !mlirShapedTypeHasStaticShape(shapedType)) { |
783 | std::string message = |
784 | "Expected a static ShapedType for the shaped_type parameter: " ; |
785 | message.append(py::repr(py::cast(shapedType))); |
786 | throw py::value_error(message); |
787 | } |
788 | MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); |
789 | MlirType attrType = mlirAttributeGetType(elementAttr); |
790 | if (!mlirTypeEqual(shapedElementType, attrType)) { |
791 | std::string message = |
792 | "Shaped element type and attribute type must be equal: shaped=" ; |
793 | message.append(py::repr(py::cast(shapedType))); |
794 | message.append(s: ", element=" ); |
795 | message.append(py::repr(py::cast(elementAttr))); |
796 | throw py::value_error(message); |
797 | } |
798 | |
799 | MlirAttribute elements = |
800 | mlirDenseElementsAttrSplatGet(shapedType, elementAttr); |
801 | return PyDenseElementsAttribute(contextWrapper->getRef(), elements); |
802 | } |
803 | |
804 | intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } |
805 | |
806 | py::buffer_info accessBuffer() { |
807 | MlirType shapedType = mlirAttributeGetType(*this); |
808 | MlirType elementType = mlirShapedTypeGetElementType(shapedType); |
809 | std::string format; |
810 | |
811 | if (mlirTypeIsAF32(elementType)) { |
812 | // f32 |
813 | return bufferInfo<float>(shapedType); |
814 | } |
815 | if (mlirTypeIsAF64(elementType)) { |
816 | // f64 |
817 | return bufferInfo<double>(shapedType); |
818 | } |
819 | if (mlirTypeIsAF16(elementType)) { |
820 | // f16 |
821 | return bufferInfo<uint16_t>(shapedType, "e" ); |
822 | } |
823 | if (mlirTypeIsAIndex(elementType)) { |
824 | // Same as IndexType::kInternalStorageBitWidth |
825 | return bufferInfo<int64_t>(shapedType); |
826 | } |
827 | if (mlirTypeIsAInteger(elementType) && |
828 | mlirIntegerTypeGetWidth(elementType) == 32) { |
829 | if (mlirIntegerTypeIsSignless(elementType) || |
830 | mlirIntegerTypeIsSigned(elementType)) { |
831 | // i32 |
832 | return bufferInfo<int32_t>(shapedType); |
833 | } |
834 | if (mlirIntegerTypeIsUnsigned(elementType)) { |
835 | // unsigned i32 |
836 | return bufferInfo<uint32_t>(shapedType); |
837 | } |
838 | } else if (mlirTypeIsAInteger(elementType) && |
839 | mlirIntegerTypeGetWidth(elementType) == 64) { |
840 | if (mlirIntegerTypeIsSignless(elementType) || |
841 | mlirIntegerTypeIsSigned(elementType)) { |
842 | // i64 |
843 | return bufferInfo<int64_t>(shapedType); |
844 | } |
845 | if (mlirIntegerTypeIsUnsigned(elementType)) { |
846 | // unsigned i64 |
847 | return bufferInfo<uint64_t>(shapedType); |
848 | } |
849 | } else if (mlirTypeIsAInteger(elementType) && |
850 | mlirIntegerTypeGetWidth(elementType) == 8) { |
851 | if (mlirIntegerTypeIsSignless(elementType) || |
852 | mlirIntegerTypeIsSigned(elementType)) { |
853 | // i8 |
854 | return bufferInfo<int8_t>(shapedType); |
855 | } |
856 | if (mlirIntegerTypeIsUnsigned(elementType)) { |
857 | // unsigned i8 |
858 | return bufferInfo<uint8_t>(shapedType); |
859 | } |
860 | } else if (mlirTypeIsAInteger(elementType) && |
861 | mlirIntegerTypeGetWidth(elementType) == 16) { |
862 | if (mlirIntegerTypeIsSignless(elementType) || |
863 | mlirIntegerTypeIsSigned(elementType)) { |
864 | // i16 |
865 | return bufferInfo<int16_t>(shapedType); |
866 | } |
867 | if (mlirIntegerTypeIsUnsigned(elementType)) { |
868 | // unsigned i16 |
869 | return bufferInfo<uint16_t>(shapedType); |
870 | } |
871 | } |
872 | |
873 | // TODO: Currently crashes the program. |
874 | // Reported as https://github.com/pybind/pybind11/issues/3336 |
875 | throw std::invalid_argument( |
876 | "unsupported data type for conversion to Python buffer" ); |
877 | } |
878 | |
879 | static void bindDerived(ClassTy &c) { |
880 | c.def("__len__" , &PyDenseElementsAttribute::dunderLen) |
881 | .def_static("get" , PyDenseElementsAttribute::getFromBuffer, |
882 | py::arg("array" ), py::arg("signless" ) = true, |
883 | py::arg("type" ) = py::none(), py::arg("shape" ) = py::none(), |
884 | py::arg("context" ) = py::none(), |
885 | kDenseElementsAttrGetDocstring) |
886 | .def_static("get_splat" , PyDenseElementsAttribute::getSplat, |
887 | py::arg("shaped_type" ), py::arg("element_attr" ), |
888 | "Gets a DenseElementsAttr where all values are the same" ) |
889 | .def_property_readonly("is_splat" , |
890 | [](PyDenseElementsAttribute &self) -> bool { |
891 | return mlirDenseElementsAttrIsSplat(self); |
892 | }) |
893 | .def("get_splat_value" , |
894 | [](PyDenseElementsAttribute &self) { |
895 | if (!mlirDenseElementsAttrIsSplat(self)) |
896 | throw py::value_error( |
897 | "get_splat_value called on a non-splat attribute" ); |
898 | return mlirDenseElementsAttrGetSplatValue(self); |
899 | }) |
900 | .def_buffer(&PyDenseElementsAttribute::accessBuffer); |
901 | } |
902 | |
903 | private: |
/// Returns true when the first character of buffer-protocol `format` is one
/// of the unsigned integer codes B, H, I, L or Q. Only the first character is
/// inspected, so a format that starts with a byte-order prefix is rejected.
static bool isUnsignedIntegerFormat(std::string_view format) {
  constexpr std::string_view kUnsignedCodes = "BHILQ";
  return !format.empty() &&
         kUnsignedCodes.find(format.front()) != std::string_view::npos;
}
911 | |
/// Returns true when the first character of buffer-protocol `format` is one
/// of the signed integer codes b, h, i, l or q. Only the first character is
/// inspected, so a format that starts with a byte-order prefix is rejected.
static bool isSignedIntegerFormat(std::string_view format) {
  constexpr std::string_view kSignedCodes = "bhilq";
  return !format.empty() &&
         kSignedCodes.find(format.front()) != std::string_view::npos;
}
919 | |
920 | template <typename Type> |
921 | py::buffer_info bufferInfo(MlirType shapedType, |
922 | const char *explicitFormat = nullptr) { |
923 | intptr_t rank = mlirShapedTypeGetRank(shapedType); |
924 | // Prepare the data for the buffer_info. |
925 | // Buffer is configured for read-only access below. |
926 | Type *data = static_cast<Type *>( |
927 | const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); |
928 | // Prepare the shape for the buffer_info. |
929 | SmallVector<intptr_t, 4> shape; |
930 | for (intptr_t i = 0; i < rank; ++i) |
931 | shape.push_back(Elt: mlirShapedTypeGetDimSize(shapedType, i)); |
932 | // Prepare the strides for the buffer_info. |
933 | SmallVector<intptr_t, 4> strides; |
934 | if (mlirDenseElementsAttrIsSplat(*this)) { |
935 | // Splats are special, only the single value is stored. |
936 | strides.assign(NumElts: rank, Elt: 0); |
937 | } else { |
938 | for (intptr_t i = 1; i < rank; ++i) { |
939 | intptr_t strideFactor = 1; |
940 | for (intptr_t j = i; j < rank; ++j) |
941 | strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); |
942 | strides.push_back(Elt: sizeof(Type) * strideFactor); |
943 | } |
944 | strides.push_back(Elt: sizeof(Type)); |
945 | } |
946 | std::string format; |
947 | if (explicitFormat) { |
948 | format = explicitFormat; |
949 | } else { |
950 | format = py::format_descriptor<Type>::format(); |
951 | } |
952 | return py::buffer_info(data, sizeof(Type), format, rank, shape, strides, |
953 | /*readonly=*/true); |
954 | } |
955 | }; // namespace |
956 | |
957 | /// Refinement of the PyDenseElementsAttribute for attributes containing integer |
958 | /// (and boolean) values. Supports element access. |
959 | class PyDenseIntElementsAttribute |
960 | : public PyConcreteAttribute<PyDenseIntElementsAttribute, |
961 | PyDenseElementsAttribute> { |
962 | public: |
963 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; |
964 | static constexpr const char *pyClassName = "DenseIntElementsAttr" ; |
965 | using PyConcreteAttribute::PyConcreteAttribute; |
966 | |
967 | /// Returns the element at the given linear position. Asserts if the index is |
968 | /// out of range. |
969 | py::int_ dunderGetItem(intptr_t pos) { |
970 | if (pos < 0 || pos >= dunderLen()) { |
971 | throw py::index_error("attempt to access out of bounds element" ); |
972 | } |
973 | |
974 | MlirType type = mlirAttributeGetType(*this); |
975 | type = mlirShapedTypeGetElementType(type); |
976 | assert(mlirTypeIsAInteger(type) && |
977 | "expected integer element type in dense int elements attribute" ); |
978 | // Dispatch element extraction to an appropriate C function based on the |
979 | // elemental type of the attribute. py::int_ is implicitly constructible |
980 | // from any C++ integral type and handles bitwidth correctly. |
981 | // TODO: consider caching the type properties in the constructor to avoid |
982 | // querying them on each element access. |
983 | unsigned width = mlirIntegerTypeGetWidth(type); |
984 | bool isUnsigned = mlirIntegerTypeIsUnsigned(type); |
985 | if (isUnsigned) { |
986 | if (width == 1) { |
987 | return mlirDenseElementsAttrGetBoolValue(*this, pos); |
988 | } |
989 | if (width == 8) { |
990 | return mlirDenseElementsAttrGetUInt8Value(*this, pos); |
991 | } |
992 | if (width == 16) { |
993 | return mlirDenseElementsAttrGetUInt16Value(*this, pos); |
994 | } |
995 | if (width == 32) { |
996 | return mlirDenseElementsAttrGetUInt32Value(*this, pos); |
997 | } |
998 | if (width == 64) { |
999 | return mlirDenseElementsAttrGetUInt64Value(*this, pos); |
1000 | } |
1001 | } else { |
1002 | if (width == 1) { |
1003 | return mlirDenseElementsAttrGetBoolValue(*this, pos); |
1004 | } |
1005 | if (width == 8) { |
1006 | return mlirDenseElementsAttrGetInt8Value(*this, pos); |
1007 | } |
1008 | if (width == 16) { |
1009 | return mlirDenseElementsAttrGetInt16Value(*this, pos); |
1010 | } |
1011 | if (width == 32) { |
1012 | return mlirDenseElementsAttrGetInt32Value(*this, pos); |
1013 | } |
1014 | if (width == 64) { |
1015 | return mlirDenseElementsAttrGetInt64Value(*this, pos); |
1016 | } |
1017 | } |
1018 | throw py::type_error("Unsupported integer type" ); |
1019 | } |
1020 | |
1021 | static void bindDerived(ClassTy &c) { |
1022 | c.def("__getitem__" , &PyDenseIntElementsAttribute::dunderGetItem); |
1023 | } |
1024 | }; |
1025 | |
1026 | class PyDenseResourceElementsAttribute |
1027 | : public PyConcreteAttribute<PyDenseResourceElementsAttribute> { |
1028 | public: |
1029 | static constexpr IsAFunctionTy isaFunction = |
1030 | mlirAttributeIsADenseResourceElements; |
1031 | static constexpr const char *pyClassName = "DenseResourceElementsAttr" ; |
1032 | using PyConcreteAttribute::PyConcreteAttribute; |
1033 | |
1034 | static PyDenseResourceElementsAttribute |
1035 | getFromBuffer(py::buffer buffer, const std::string &name, const PyType &type, |
1036 | std::optional<size_t> alignment, bool isMutable, |
1037 | DefaultingPyMlirContext contextWrapper) { |
1038 | if (!mlirTypeIsAShaped(type)) { |
1039 | throw std::invalid_argument( |
1040 | "Constructing a DenseResourceElementsAttr requires a ShapedType." ); |
1041 | } |
1042 | |
1043 | // Do not request any conversions as we must ensure to use caller |
1044 | // managed memory. |
1045 | int flags = PyBUF_STRIDES; |
1046 | std::unique_ptr<Py_buffer> view = std::make_unique<Py_buffer>(); |
1047 | if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) { |
1048 | throw py::error_already_set(); |
1049 | } |
1050 | |
1051 | // This scope releaser will only release if we haven't yet transferred |
1052 | // ownership. |
1053 | auto freeBuffer = llvm::make_scope_exit(F: [&]() { |
1054 | if (view) |
1055 | PyBuffer_Release(view.get()); |
1056 | }); |
1057 | |
1058 | if (!PyBuffer_IsContiguous(view.get(), 'A')) { |
1059 | throw std::invalid_argument("Contiguous buffer is required." ); |
1060 | } |
1061 | |
1062 | // Infer alignment to be the stride of one element if not explicit. |
1063 | size_t inferredAlignment; |
1064 | if (alignment) |
1065 | inferredAlignment = *alignment; |
1066 | else |
1067 | inferredAlignment = view->strides[view->ndim - 1]; |
1068 | |
1069 | // The userData is a Py_buffer* that the deleter owns. |
1070 | auto deleter = [](void *userData, const void *data, size_t size, |
1071 | size_t align) { |
1072 | Py_buffer *ownedView = static_cast<Py_buffer *>(userData); |
1073 | PyBuffer_Release(ownedView); |
1074 | delete ownedView; |
1075 | }; |
1076 | |
1077 | size_t rawBufferSize = view->len; |
1078 | MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet( |
1079 | type, toMlirStringRef(name), view->buf, rawBufferSize, |
1080 | inferredAlignment, isMutable, deleter, static_cast<void *>(view.get())); |
1081 | if (mlirAttributeIsNull(attr)) { |
1082 | throw std::invalid_argument( |
1083 | "DenseResourceElementsAttr could not be constructed from the given " |
1084 | "buffer. " |
1085 | "This may mean that the Python buffer layout does not match that " |
1086 | "MLIR expected layout and is a bug." ); |
1087 | } |
1088 | view.release(); |
1089 | return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr); |
1090 | } |
1091 | |
1092 | static void bindDerived(ClassTy &c) { |
1093 | c.def_static("get_from_buffer" , |
1094 | PyDenseResourceElementsAttribute::getFromBuffer, |
1095 | py::arg("array" ), py::arg("name" ), py::arg("type" ), |
1096 | py::arg("alignment" ) = py::none(), |
1097 | py::arg("is_mutable" ) = false, py::arg("context" ) = py::none(), |
1098 | kDenseResourceElementsAttrGetFromBufferDocstring); |
1099 | } |
1100 | }; |
1101 | |
1102 | class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> { |
1103 | public: |
1104 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; |
1105 | static constexpr const char *pyClassName = "DictAttr" ; |
1106 | using PyConcreteAttribute::PyConcreteAttribute; |
1107 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
1108 | mlirDictionaryAttrGetTypeID; |
1109 | |
1110 | intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } |
1111 | |
1112 | bool dunderContains(const std::string &name) { |
1113 | return !mlirAttributeIsNull( |
1114 | mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name))); |
1115 | } |
1116 | |
1117 | static void bindDerived(ClassTy &c) { |
1118 | c.def("__contains__" , &PyDictAttribute::dunderContains); |
1119 | c.def("__len__" , &PyDictAttribute::dunderLen); |
1120 | c.def_static( |
1121 | "get" , |
1122 | [](py::dict attributes, DefaultingPyMlirContext context) { |
1123 | SmallVector<MlirNamedAttribute> mlirNamedAttributes; |
1124 | mlirNamedAttributes.reserve(attributes.size()); |
1125 | for (auto &it : attributes) { |
1126 | auto &mlirAttr = it.second.cast<PyAttribute &>(); |
1127 | auto name = it.first.cast<std::string>(); |
1128 | mlirNamedAttributes.push_back(mlirNamedAttributeGet( |
1129 | mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), |
1130 | toMlirStringRef(name)), |
1131 | mlirAttr)); |
1132 | } |
1133 | MlirAttribute attr = |
1134 | mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), |
1135 | mlirNamedAttributes.data()); |
1136 | return PyDictAttribute(context->getRef(), attr); |
1137 | }, |
1138 | py::arg("value" ) = py::dict(), py::arg("context" ) = py::none(), |
1139 | "Gets an uniqued dict attribute" ); |
1140 | c.def("__getitem__" , [](PyDictAttribute &self, const std::string &name) { |
1141 | MlirAttribute attr = |
1142 | mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); |
1143 | if (mlirAttributeIsNull(attr)) |
1144 | throw py::key_error("attempt to access a non-existent attribute" ); |
1145 | return attr; |
1146 | }); |
1147 | c.def("__getitem__" , [](PyDictAttribute &self, intptr_t index) { |
1148 | if (index < 0 || index >= self.dunderLen()) { |
1149 | throw py::index_error("attempt to access out of bounds attribute" ); |
1150 | } |
1151 | MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); |
1152 | return PyNamedAttribute( |
1153 | namedAttr.attribute, |
1154 | std::string(mlirIdentifierStr(namedAttr.name).data)); |
1155 | }); |
1156 | } |
1157 | }; |
1158 | |
1159 | /// Refinement of PyDenseElementsAttribute for attributes containing |
1160 | /// floating-point values. Supports element access. |
1161 | class PyDenseFPElementsAttribute |
1162 | : public PyConcreteAttribute<PyDenseFPElementsAttribute, |
1163 | PyDenseElementsAttribute> { |
1164 | public: |
1165 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; |
1166 | static constexpr const char *pyClassName = "DenseFPElementsAttr" ; |
1167 | using PyConcreteAttribute::PyConcreteAttribute; |
1168 | |
1169 | py::float_ dunderGetItem(intptr_t pos) { |
1170 | if (pos < 0 || pos >= dunderLen()) { |
1171 | throw py::index_error("attempt to access out of bounds element" ); |
1172 | } |
1173 | |
1174 | MlirType type = mlirAttributeGetType(*this); |
1175 | type = mlirShapedTypeGetElementType(type); |
1176 | // Dispatch element extraction to an appropriate C function based on the |
1177 | // elemental type of the attribute. py::float_ is implicitly constructible |
1178 | // from float and double. |
1179 | // TODO: consider caching the type properties in the constructor to avoid |
1180 | // querying them on each element access. |
1181 | if (mlirTypeIsAF32(type)) { |
1182 | return mlirDenseElementsAttrGetFloatValue(*this, pos); |
1183 | } |
1184 | if (mlirTypeIsAF64(type)) { |
1185 | return mlirDenseElementsAttrGetDoubleValue(*this, pos); |
1186 | } |
1187 | throw py::type_error("Unsupported floating-point type" ); |
1188 | } |
1189 | |
1190 | static void bindDerived(ClassTy &c) { |
1191 | c.def("__getitem__" , &PyDenseFPElementsAttribute::dunderGetItem); |
1192 | } |
1193 | }; |
1194 | |
1195 | class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> { |
1196 | public: |
1197 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; |
1198 | static constexpr const char *pyClassName = "TypeAttr" ; |
1199 | using PyConcreteAttribute::PyConcreteAttribute; |
1200 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
1201 | mlirTypeAttrGetTypeID; |
1202 | |
1203 | static void bindDerived(ClassTy &c) { |
1204 | c.def_static( |
1205 | "get" , |
1206 | [](PyType value, DefaultingPyMlirContext context) { |
1207 | MlirAttribute attr = mlirTypeAttrGet(value.get()); |
1208 | return PyTypeAttribute(context->getRef(), attr); |
1209 | }, |
1210 | py::arg("value" ), py::arg("context" ) = py::none(), |
1211 | "Gets a uniqued Type attribute" ); |
1212 | c.def_property_readonly("value" , [](PyTypeAttribute &self) { |
1213 | return mlirTypeAttrGetValue(self.get()); |
1214 | }); |
1215 | } |
1216 | }; |
1217 | |
1218 | /// Unit Attribute subclass. Unit attributes don't have values. |
1219 | class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> { |
1220 | public: |
1221 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; |
1222 | static constexpr const char *pyClassName = "UnitAttr" ; |
1223 | using PyConcreteAttribute::PyConcreteAttribute; |
1224 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
1225 | mlirUnitAttrGetTypeID; |
1226 | |
1227 | static void bindDerived(ClassTy &c) { |
1228 | c.def_static( |
1229 | "get" , |
1230 | [](DefaultingPyMlirContext context) { |
1231 | return PyUnitAttribute(context->getRef(), |
1232 | mlirUnitAttrGet(context->get())); |
1233 | }, |
1234 | py::arg("context" ) = py::none(), "Create a Unit attribute." ); |
1235 | } |
1236 | }; |
1237 | |
1238 | /// Strided layout attribute subclass. |
1239 | class PyStridedLayoutAttribute |
1240 | : public PyConcreteAttribute<PyStridedLayoutAttribute> { |
1241 | public: |
1242 | static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout; |
1243 | static constexpr const char *pyClassName = "StridedLayoutAttr" ; |
1244 | using PyConcreteAttribute::PyConcreteAttribute; |
1245 | static constexpr GetTypeIDFunctionTy getTypeIdFunction = |
1246 | mlirStridedLayoutAttrGetTypeID; |
1247 | |
1248 | static void bindDerived(ClassTy &c) { |
1249 | c.def_static( |
1250 | "get" , |
1251 | [](int64_t offset, const std::vector<int64_t> strides, |
1252 | DefaultingPyMlirContext ctx) { |
1253 | MlirAttribute attr = mlirStridedLayoutAttrGet( |
1254 | ctx->get(), offset, strides.size(), strides.data()); |
1255 | return PyStridedLayoutAttribute(ctx->getRef(), attr); |
1256 | }, |
1257 | py::arg("offset" ), py::arg("strides" ), py::arg("context" ) = py::none(), |
1258 | "Gets a strided layout attribute." ); |
1259 | c.def_static( |
1260 | "get_fully_dynamic" , |
1261 | [](int64_t rank, DefaultingPyMlirContext ctx) { |
1262 | auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset(); |
1263 | std::vector<int64_t> strides(rank); |
1264 | std::fill(strides.begin(), strides.end(), dynamic); |
1265 | MlirAttribute attr = mlirStridedLayoutAttrGet( |
1266 | ctx->get(), dynamic, strides.size(), strides.data()); |
1267 | return PyStridedLayoutAttribute(ctx->getRef(), attr); |
1268 | }, |
1269 | py::arg("rank" ), py::arg("context" ) = py::none(), |
1270 | "Gets a strided layout attribute with dynamic offset and strides of a " |
1271 | "given rank." ); |
1272 | c.def_property_readonly( |
1273 | "offset" , |
1274 | [](PyStridedLayoutAttribute &self) { |
1275 | return mlirStridedLayoutAttrGetOffset(self); |
1276 | }, |
1277 | "Returns the value of the float point attribute" ); |
1278 | c.def_property_readonly( |
1279 | "strides" , |
1280 | [](PyStridedLayoutAttribute &self) { |
1281 | intptr_t size = mlirStridedLayoutAttrGetNumStrides(self); |
1282 | std::vector<int64_t> strides(size); |
1283 | for (intptr_t i = 0; i < size; i++) { |
1284 | strides[i] = mlirStridedLayoutAttrGetStride(self, i); |
1285 | } |
1286 | return strides; |
1287 | }, |
1288 | "Returns the value of the float point attribute" ); |
1289 | } |
1290 | }; |
1291 | |
1292 | py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { |
1293 | if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute)) |
1294 | return py::cast(PyDenseBoolArrayAttribute(pyAttribute)); |
1295 | if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute)) |
1296 | return py::cast(PyDenseI8ArrayAttribute(pyAttribute)); |
1297 | if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute)) |
1298 | return py::cast(PyDenseI16ArrayAttribute(pyAttribute)); |
1299 | if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute)) |
1300 | return py::cast(PyDenseI32ArrayAttribute(pyAttribute)); |
1301 | if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute)) |
1302 | return py::cast(PyDenseI64ArrayAttribute(pyAttribute)); |
1303 | if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute)) |
1304 | return py::cast(PyDenseF32ArrayAttribute(pyAttribute)); |
1305 | if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute)) |
1306 | return py::cast(PyDenseF64ArrayAttribute(pyAttribute)); |
1307 | std::string msg = |
1308 | std::string("Can't cast unknown element type DenseArrayAttr (" ) + |
1309 | std::string(py::repr(py::cast(pyAttribute))) + ")" ; |
1310 | throw py::cast_error(msg); |
1311 | } |
1312 | |
1313 | py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { |
1314 | if (PyDenseFPElementsAttribute::isaFunction(pyAttribute)) |
1315 | return py::cast(PyDenseFPElementsAttribute(pyAttribute)); |
1316 | if (PyDenseIntElementsAttribute::isaFunction(pyAttribute)) |
1317 | return py::cast(PyDenseIntElementsAttribute(pyAttribute)); |
1318 | std::string msg = |
1319 | std::string( |
1320 | "Can't cast unknown element type DenseIntOrFPElementsAttr (" ) + |
1321 | std::string(py::repr(py::cast(pyAttribute))) + ")" ; |
1322 | throw py::cast_error(msg); |
1323 | } |
1324 | |
1325 | py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { |
1326 | if (PyBoolAttribute::isaFunction(pyAttribute)) |
1327 | return py::cast(PyBoolAttribute(pyAttribute)); |
1328 | if (PyIntegerAttribute::isaFunction(pyAttribute)) |
1329 | return py::cast(PyIntegerAttribute(pyAttribute)); |
1330 | std::string msg = |
1331 | std::string("Can't cast unknown element type DenseArrayAttr (" ) + |
1332 | std::string(py::repr(py::cast(pyAttribute))) + ")" ; |
1333 | throw py::cast_error(msg); |
1334 | } |
1335 | |
1336 | py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { |
1337 | if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute)) |
1338 | return py::cast(PyFlatSymbolRefAttribute(pyAttribute)); |
1339 | if (PySymbolRefAttribute::isaFunction(pyAttribute)) |
1340 | return py::cast(PySymbolRefAttribute(pyAttribute)); |
1341 | std::string msg = std::string("Can't cast unknown SymbolRef attribute (" ) + |
1342 | std::string(py::repr(py::cast(pyAttribute))) + ")" ; |
1343 | throw py::cast_error(msg); |
1344 | } |
1345 | |
1346 | } // namespace |
1347 | |
/// Binds every attribute class in this file onto module `m`, together with
/// their iterator helpers, and registers the TypeID-keyed downcasting
/// functions (the *Caster functions in this file).
///
/// Statement order is significant in places. PyDenseElementsAttribute is bound
/// before PyDenseFPElementsAttribute and PyDenseIntElementsAttribute, whose
/// C++ classes name it as their base. NOTE(review): presumably pybind11 needs
/// the base registered first; keep this order when editing.
void mlir::python::populateIRAttributes(py::module &m) {
  PyAffineMapAttribute::bind(m);

  // DenseArrayAttr variants, one per element type, each with its iterator.
  PyDenseBoolArrayAttribute::bind(m);
  PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m);
  PyDenseI8ArrayAttribute::bind(m);
  PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m);
  PyDenseI16ArrayAttribute::bind(m);
  PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m);
  PyDenseI32ArrayAttribute::bind(m);
  PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m);
  PyDenseI64ArrayAttribute::bind(m);
  PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m);
  PyDenseF32ArrayAttribute::bind(m);
  PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m);
  PyDenseF64ArrayAttribute::bind(m);
  PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m);
  // A single DenseArrayAttr TypeID covers all element types; the caster picks
  // the concrete Python class.
  PyGlobals::get().registerTypeCaster(
      mlirDenseArrayAttrGetTypeID(),
      pybind11::cpp_function(denseArrayAttributeCaster));

  PyArrayAttribute::bind(m);
  PyArrayAttribute::PyArrayAttributeIterator::bind(m);
  PyBoolAttribute::bind(m);
  // Base class first, then its FP/Int refinements.
  PyDenseElementsAttribute::bind(m);
  PyDenseFPElementsAttribute::bind(m);
  PyDenseIntElementsAttribute::bind(m);
  PyGlobals::get().registerTypeCaster(
      mlirDenseIntOrFPElementsAttrGetTypeID(),
      pybind11::cpp_function(denseIntOrFPElementsAttributeCaster));
  PyDenseResourceElementsAttribute::bind(m);

  PyDictAttribute::bind(m);
  PySymbolRefAttribute::bind(m);
  // SymbolRefAttr TypeID: the caster chooses FlatSymbolRefAttr or
  // SymbolRefAttr.
  PyGlobals::get().registerTypeCaster(
      mlirSymbolRefAttrGetTypeID(),
      pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster));

  PyFlatSymbolRefAttribute::bind(m);
  PyOpaqueAttribute::bind(m);
  PyFloatAttribute::bind(m);
  PyIntegerAttribute::bind(m);
  PyStringAttribute::bind(m);
  PyTypeAttribute::bind(m);
  // IntegerAttr TypeID: the caster chooses BoolAttr or IntegerAttr.
  PyGlobals::get().registerTypeCaster(
      mlirIntegerAttrGetTypeID(),
      pybind11::cpp_function(integerOrBoolAttributeCaster));
  PyUnitAttribute::bind(m);

  PyStridedLayoutAttribute::bind(m);
}
1399 | |