1//===- BuiltinAttributes.cpp - C Interface to MLIR Builtin Attributes -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir-c/BuiltinAttributes.h"
10#include "mlir-c/Support.h"
11#include "mlir/CAPI/AffineMap.h"
12#include "mlir/CAPI/IR.h"
13#include "mlir/CAPI/IntegerSet.h"
14#include "mlir/CAPI/Support.h"
15#include "mlir/IR/AsmState.h"
16#include "mlir/IR/Attributes.h"
17#include "mlir/IR/BuiltinAttributes.h"
18#include "mlir/IR/BuiltinTypes.h"
19
20using namespace mlir;
21
22MlirAttribute mlirAttributeGetNull() { return {.ptr: nullptr}; }
23
24//===----------------------------------------------------------------------===//
25// Location attribute.
26//===----------------------------------------------------------------------===//
27
28bool mlirAttributeIsALocation(MlirAttribute attr) {
29 return llvm::isa<LocationAttr>(Val: unwrap(c: attr));
30}
31
32//===----------------------------------------------------------------------===//
33// Affine map attribute.
34//===----------------------------------------------------------------------===//
35
36bool mlirAttributeIsAAffineMap(MlirAttribute attr) {
37 return llvm::isa<AffineMapAttr>(unwrap(attr));
38}
39
40MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) {
41 return wrap(AffineMapAttr::get(unwrap(map)));
42}
43
44MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) {
45 return wrap(llvm::cast<AffineMapAttr>(unwrap(attr)).getValue());
46}
47
48MlirTypeID mlirAffineMapAttrGetTypeID(void) {
49 return wrap(AffineMapAttr::getTypeID());
50}
51
52//===----------------------------------------------------------------------===//
53// Array attribute.
54//===----------------------------------------------------------------------===//
55
56bool mlirAttributeIsAArray(MlirAttribute attr) {
57 return llvm::isa<ArrayAttr>(unwrap(attr));
58}
59
60MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements,
61 MlirAttribute const *elements) {
62 SmallVector<Attribute, 8> attrs;
63 return wrap(
64 ArrayAttr::get(unwrap(ctx), unwrapList(static_cast<size_t>(numElements),
65 elements, attrs)));
66}
67
68intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) {
69 return static_cast<intptr_t>(llvm::cast<ArrayAttr>(unwrap(attr)).size());
70}
71
72MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) {
73 return wrap(llvm::cast<ArrayAttr>(unwrap(attr)).getValue()[pos]);
74}
75
76MlirTypeID mlirArrayAttrGetTypeID(void) { return wrap(ArrayAttr::getTypeID()); }
77
78//===----------------------------------------------------------------------===//
79// Dictionary attribute.
80//===----------------------------------------------------------------------===//
81
82bool mlirAttributeIsADictionary(MlirAttribute attr) {
83 return llvm::isa<DictionaryAttr>(Val: unwrap(c: attr));
84}
85
86MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
87 MlirNamedAttribute const *elements) {
88 SmallVector<NamedAttribute, 8> attributes;
89 attributes.reserve(N: numElements);
90 for (intptr_t i = 0; i < numElements; ++i)
91 attributes.emplace_back(unwrap(elements[i].name),
92 unwrap(c: elements[i].attribute));
93 return wrap(DictionaryAttr::get(unwrap(ctx), attributes));
94}
95
96intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) {
97 return static_cast<intptr_t>(llvm::cast<DictionaryAttr>(unwrap(c: attr)).size());
98}
99
100MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr,
101 intptr_t pos) {
102 NamedAttribute attribute =
103 llvm::cast<DictionaryAttr>(unwrap(c: attr)).getValue()[pos];
104 return {wrap(attribute.getName()), wrap(cpp: attribute.getValue())};
105}
106
107MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,
108 MlirStringRef name) {
109 return wrap(llvm::cast<DictionaryAttr>(unwrap(c: attr)).get(unwrap(ref: name)));
110}
111
112MlirTypeID mlirDictionaryAttrGetTypeID(void) {
113 return wrap(DictionaryAttr::getTypeID());
114}
115
116//===----------------------------------------------------------------------===//
117// Floating point attribute.
118//===----------------------------------------------------------------------===//
119
120bool mlirAttributeIsAFloat(MlirAttribute attr) {
121 return llvm::isa<FloatAttr>(unwrap(attr));
122}
123
124MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
125 double value) {
126 return wrap(FloatAttr::get(unwrap(type), value));
127}
128
129MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type,
130 double value) {
131 return wrap(FloatAttr::getChecked(unwrap(loc), unwrap(type), value));
132}
133
134double mlirFloatAttrGetValueDouble(MlirAttribute attr) {
135 return llvm::cast<FloatAttr>(unwrap(attr)).getValueAsDouble();
136}
137
138MlirTypeID mlirFloatAttrGetTypeID(void) { return wrap(FloatAttr::getTypeID()); }
139
140//===----------------------------------------------------------------------===//
141// Integer attribute.
142//===----------------------------------------------------------------------===//
143
144bool mlirAttributeIsAInteger(MlirAttribute attr) {
145 return llvm::isa<IntegerAttr>(unwrap(attr));
146}
147
148MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) {
149 return wrap(IntegerAttr::get(unwrap(type), value));
150}
151
152int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) {
153 return llvm::cast<IntegerAttr>(unwrap(attr)).getInt();
154}
155
156int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr) {
157 return llvm::cast<IntegerAttr>(unwrap(attr)).getSInt();
158}
159
160uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) {
161 return llvm::cast<IntegerAttr>(unwrap(attr)).getUInt();
162}
163
164MlirTypeID mlirIntegerAttrGetTypeID(void) {
165 return wrap(IntegerAttr::getTypeID());
166}
167
168//===----------------------------------------------------------------------===//
169// Bool attribute.
170//===----------------------------------------------------------------------===//
171
172bool mlirAttributeIsABool(MlirAttribute attr) {
173 return llvm::isa<BoolAttr>(Val: unwrap(c: attr));
174}
175
176MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) {
177 return wrap(cpp: BoolAttr::get(context: unwrap(c: ctx), value));
178}
179
180bool mlirBoolAttrGetValue(MlirAttribute attr) {
181 return llvm::cast<BoolAttr>(Val: unwrap(c: attr)).getValue();
182}
183
184//===----------------------------------------------------------------------===//
185// Integer set attribute.
186//===----------------------------------------------------------------------===//
187
188bool mlirAttributeIsAIntegerSet(MlirAttribute attr) {
189 return llvm::isa<IntegerSetAttr>(unwrap(attr));
190}
191
192MlirTypeID mlirIntegerSetAttrGetTypeID(void) {
193 return wrap(IntegerSetAttr::getTypeID());
194}
195
196MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set) {
197 return wrap(IntegerSetAttr::get(unwrap(set)));
198}
199
200MlirIntegerSet mlirIntegerSetAttrGetValue(MlirAttribute attr) {
201 return wrap(llvm::cast<IntegerSetAttr>(unwrap(attr)).getValue());
202}
203
204//===----------------------------------------------------------------------===//
205// Opaque attribute.
206//===----------------------------------------------------------------------===//
207
208bool mlirAttributeIsAOpaque(MlirAttribute attr) {
209 return llvm::isa<OpaqueAttr>(unwrap(attr));
210}
211
212MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
213 intptr_t dataLength, const char *data,
214 MlirType type) {
215 return wrap(
216 OpaqueAttr::get(StringAttr::get(unwrap(ctx), unwrap(dialectNamespace)),
217 StringRef(data, dataLength), unwrap(type)));
218}
219
220MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
221 return wrap(
222 llvm::cast<OpaqueAttr>(unwrap(attr)).getDialectNamespace().strref());
223}
224
225MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) {
226 return wrap(llvm::cast<OpaqueAttr>(unwrap(attr)).getAttrData());
227}
228
229MlirTypeID mlirOpaqueAttrGetTypeID(void) {
230 return wrap(OpaqueAttr::getTypeID());
231}
232
233//===----------------------------------------------------------------------===//
234// String attribute.
235//===----------------------------------------------------------------------===//
236
237bool mlirAttributeIsAString(MlirAttribute attr) {
238 return llvm::isa<StringAttr>(Val: unwrap(c: attr));
239}
240
241MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) {
242 return wrap((Attribute)StringAttr::get(unwrap(ctx), unwrap(str)));
243}
244
245MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) {
246 return wrap((Attribute)StringAttr::get(unwrap(str), unwrap(type)));
247}
248
249MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) {
250 return wrap(llvm::cast<StringAttr>(unwrap(c: attr)).getValue());
251}
252
253MlirTypeID mlirStringAttrGetTypeID(void) {
254 return wrap(StringAttr::getTypeID());
255}
256
257//===----------------------------------------------------------------------===//
258// SymbolRef attribute.
259//===----------------------------------------------------------------------===//
260
261bool mlirAttributeIsASymbolRef(MlirAttribute attr) {
262 return llvm::isa<SymbolRefAttr>(unwrap(attr));
263}
264
265MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
266 intptr_t numReferences,
267 MlirAttribute const *references) {
268 SmallVector<FlatSymbolRefAttr, 4> refs;
269 refs.reserve(N: numReferences);
270 for (intptr_t i = 0; i < numReferences; ++i)
271 refs.push_back(Elt: llvm::cast<FlatSymbolRefAttr>(Val: unwrap(c: references[i])));
272 auto symbolAttr = StringAttr::get(unwrap(ctx), unwrap(symbol));
273 return wrap(SymbolRefAttr::get(symbolAttr, refs));
274}
275
276MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) {
277 return wrap(
278 llvm::cast<SymbolRefAttr>(unwrap(attr)).getRootReference().getValue());
279}
280
281MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) {
282 return wrap(
283 llvm::cast<SymbolRefAttr>(unwrap(attr)).getLeafReference().getValue());
284}
285
286intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) {
287 return static_cast<intptr_t>(
288 llvm::cast<SymbolRefAttr>(unwrap(attr)).getNestedReferences().size());
289}
290
291MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,
292 intptr_t pos) {
293 return wrap(
294 llvm::cast<SymbolRefAttr>(unwrap(attr)).getNestedReferences()[pos]);
295}
296
297MlirTypeID mlirSymbolRefAttrGetTypeID(void) {
298 return wrap(SymbolRefAttr::getTypeID());
299}
300
301MlirAttribute mlirDisctinctAttrCreate(MlirAttribute referencedAttr) {
302 return wrap(mlir::DistinctAttr::create(referencedAttr: unwrap(c: referencedAttr)));
303}
304
305//===----------------------------------------------------------------------===//
306// Flat SymbolRef attribute.
307//===----------------------------------------------------------------------===//
308
309bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) {
310 return llvm::isa<FlatSymbolRefAttr>(Val: unwrap(c: attr));
311}
312
313MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) {
314 return wrap(FlatSymbolRefAttr::get(ctx: unwrap(c: ctx), value: unwrap(ref: symbol)));
315}
316
317MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
318 return wrap(ref: llvm::cast<FlatSymbolRefAttr>(Val: unwrap(c: attr)).getValue());
319}
320
321//===----------------------------------------------------------------------===//
322// Type attribute.
323//===----------------------------------------------------------------------===//
324
325bool mlirAttributeIsAType(MlirAttribute attr) {
326 return llvm::isa<TypeAttr>(unwrap(attr));
327}
328
329MlirAttribute mlirTypeAttrGet(MlirType type) {
330 return wrap(TypeAttr::get(unwrap(type)));
331}
332
333MlirType mlirTypeAttrGetValue(MlirAttribute attr) {
334 return wrap(llvm::cast<TypeAttr>(unwrap(attr)).getValue());
335}
336
337MlirTypeID mlirTypeAttrGetTypeID(void) { return wrap(TypeAttr::getTypeID()); }
338
339//===----------------------------------------------------------------------===//
340// Unit attribute.
341//===----------------------------------------------------------------------===//
342
343bool mlirAttributeIsAUnit(MlirAttribute attr) {
344 return llvm::isa<UnitAttr>(unwrap(attr));
345}
346
347MlirAttribute mlirUnitAttrGet(MlirContext ctx) {
348 return wrap(UnitAttr::get(unwrap(ctx)));
349}
350
351MlirTypeID mlirUnitAttrGetTypeID(void) { return wrap(UnitAttr::getTypeID()); }
352
353//===----------------------------------------------------------------------===//
354// Elements attributes.
355//===----------------------------------------------------------------------===//
356
357bool mlirAttributeIsAElements(MlirAttribute attr) {
358 return llvm::isa<ElementsAttr>(Val: unwrap(c: attr));
359}
360
361MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank,
362 uint64_t *idxs) {
363 return wrap(llvm::cast<ElementsAttr>(unwrap(c: attr))
364 .getValues<Attribute>()[llvm::ArrayRef(idxs, rank)]);
365}
366
367bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank,
368 uint64_t *idxs) {
369 return llvm::cast<ElementsAttr>(unwrap(c: attr))
370 .isValidIndex(llvm::ArrayRef(idxs, rank));
371}
372
373int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) {
374 return llvm::cast<ElementsAttr>(unwrap(c: attr)).getNumElements();
375}
376
377//===----------------------------------------------------------------------===//
378// Dense array attribute.
379//===----------------------------------------------------------------------===//
380
381MlirTypeID mlirDenseArrayAttrGetTypeID() {
382 return wrap(DenseArrayAttr::getTypeID());
383}
384
385//===----------------------------------------------------------------------===//
386// IsA support.
387//===----------------------------------------------------------------------===//
388
389bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) {
390 return llvm::isa<DenseBoolArrayAttr>(Val: unwrap(c: attr));
391}
392bool mlirAttributeIsADenseI8Array(MlirAttribute attr) {
393 return llvm::isa<DenseI8ArrayAttr>(Val: unwrap(c: attr));
394}
395bool mlirAttributeIsADenseI16Array(MlirAttribute attr) {
396 return llvm::isa<DenseI16ArrayAttr>(Val: unwrap(c: attr));
397}
398bool mlirAttributeIsADenseI32Array(MlirAttribute attr) {
399 return llvm::isa<DenseI32ArrayAttr>(Val: unwrap(c: attr));
400}
401bool mlirAttributeIsADenseI64Array(MlirAttribute attr) {
402 return llvm::isa<DenseI64ArrayAttr>(Val: unwrap(c: attr));
403}
404bool mlirAttributeIsADenseF32Array(MlirAttribute attr) {
405 return llvm::isa<DenseF32ArrayAttr>(Val: unwrap(c: attr));
406}
407bool mlirAttributeIsADenseF64Array(MlirAttribute attr) {
408 return llvm::isa<DenseF64ArrayAttr>(Val: unwrap(c: attr));
409}
410
411//===----------------------------------------------------------------------===//
412// Constructors.
413//===----------------------------------------------------------------------===//
414
415MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size,
416 int const *values) {
417 SmallVector<bool, 4> elements(values, values + size);
418 return wrap(DenseBoolArrayAttr::get(unwrap(ctx), elements));
419}
420MlirAttribute mlirDenseI8ArrayGet(MlirContext ctx, intptr_t size,
421 int8_t const *values) {
422 return wrap(
423 DenseI8ArrayAttr::get(unwrap(ctx), ArrayRef<int8_t>(values, size)));
424}
425MlirAttribute mlirDenseI16ArrayGet(MlirContext ctx, intptr_t size,
426 int16_t const *values) {
427 return wrap(
428 DenseI16ArrayAttr::get(unwrap(ctx), ArrayRef<int16_t>(values, size)));
429}
430MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size,
431 int32_t const *values) {
432 return wrap(
433 DenseI32ArrayAttr::get(unwrap(ctx), ArrayRef<int32_t>(values, size)));
434}
435MlirAttribute mlirDenseI64ArrayGet(MlirContext ctx, intptr_t size,
436 int64_t const *values) {
437 return wrap(
438 DenseI64ArrayAttr::get(unwrap(ctx), ArrayRef<int64_t>(values, size)));
439}
440MlirAttribute mlirDenseF32ArrayGet(MlirContext ctx, intptr_t size,
441 float const *values) {
442 return wrap(
443 DenseF32ArrayAttr::get(unwrap(ctx), ArrayRef<float>(values, size)));
444}
445MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size,
446 double const *values) {
447 return wrap(
448 DenseF64ArrayAttr::get(unwrap(ctx), ArrayRef<double>(values, size)));
449}
450
451//===----------------------------------------------------------------------===//
452// Accessors.
453//===----------------------------------------------------------------------===//
454
455intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) {
456 return llvm::cast<DenseArrayAttr>(unwrap(attr)).size();
457}
458
459//===----------------------------------------------------------------------===//
460// Indexed accessors.
461//===----------------------------------------------------------------------===//
462
463bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) {
464 return llvm::cast<DenseBoolArrayAttr>(unwrap(c: attr))[pos];
465}
466int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, intptr_t pos) {
467 return llvm::cast<DenseI8ArrayAttr>(unwrap(c: attr))[pos];
468}
469int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, intptr_t pos) {
470 return llvm::cast<DenseI16ArrayAttr>(unwrap(c: attr))[pos];
471}
472int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, intptr_t pos) {
473 return llvm::cast<DenseI32ArrayAttr>(unwrap(c: attr))[pos];
474}
475int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, intptr_t pos) {
476 return llvm::cast<DenseI64ArrayAttr>(unwrap(c: attr))[pos];
477}
478float mlirDenseF32ArrayGetElement(MlirAttribute attr, intptr_t pos) {
479 return llvm::cast<DenseF32ArrayAttr>(unwrap(c: attr))[pos];
480}
481double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) {
482 return llvm::cast<DenseF64ArrayAttr>(unwrap(c: attr))[pos];
483}
484
485//===----------------------------------------------------------------------===//
486// Dense elements attribute.
487//===----------------------------------------------------------------------===//
488
489//===----------------------------------------------------------------------===//
490// IsA support.
491//===----------------------------------------------------------------------===//
492
493bool mlirAttributeIsADenseElements(MlirAttribute attr) {
494 return llvm::isa<DenseElementsAttr>(Val: unwrap(c: attr));
495}
496
497bool mlirAttributeIsADenseIntElements(MlirAttribute attr) {
498 return llvm::isa<DenseIntElementsAttr>(Val: unwrap(c: attr));
499}
500
501bool mlirAttributeIsADenseFPElements(MlirAttribute attr) {
502 return llvm::isa<DenseFPElementsAttr>(Val: unwrap(c: attr));
503}
504
505MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void) {
506 return wrap(DenseIntOrFPElementsAttr::getTypeID());
507}
508
509//===----------------------------------------------------------------------===//
510// Constructors.
511//===----------------------------------------------------------------------===//
512
513MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
514 intptr_t numElements,
515 MlirAttribute const *elements) {
516 SmallVector<Attribute, 8> attributes;
517 return wrap(
518 DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
519 unwrapList(numElements, elements, attributes)));
520}
521
522MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType,
523 size_t rawBufferSize,
524 const void *rawBuffer) {
525 auto shapedTypeCpp = llvm::cast<ShapedType>(unwrap(shapedType));
526 ArrayRef<char> rawBufferCpp(static_cast<const char *>(rawBuffer),
527 rawBufferSize);
528 bool isSplat = false;
529 if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp,
530 isSplat))
531 return mlirAttributeGetNull();
532 return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp));
533}
534
535MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,
536 MlirAttribute element) {
537 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
538 unwrap(element)));
539}
540MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,
541 bool element) {
542 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
543 element));
544}
545MlirAttribute mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType,
546 uint8_t element) {
547 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
548 element));
549}
550MlirAttribute mlirDenseElementsAttrInt8SplatGet(MlirType shapedType,
551 int8_t element) {
552 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
553 element));
554}
555MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,
556 uint32_t element) {
557 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
558 element));
559}
560MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType,
561 int32_t element) {
562 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
563 element));
564}
565MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType,
566 uint64_t element) {
567 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
568 element));
569}
570MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType,
571 int64_t element) {
572 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
573 element));
574}
575MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType,
576 float element) {
577 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
578 element));
579}
580MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,
581 double element) {
582 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
583 element));
584}
585
586MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
587 intptr_t numElements,
588 const int *elements) {
589 SmallVector<bool, 8> values(elements, elements + numElements);
590 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
591 values));
592}
593
594/// Creates a dense attribute with elements of the type deduced by templates.
595template <typename T>
596static MlirAttribute getDenseAttribute(MlirType shapedType,
597 intptr_t numElements,
598 const T *elements) {
599 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
600 llvm::ArrayRef(elements, numElements)));
601}
602
603MlirAttribute mlirDenseElementsAttrUInt8Get(MlirType shapedType,
604 intptr_t numElements,
605 const uint8_t *elements) {
606 return getDenseAttribute(shapedType, numElements, elements);
607}
608MlirAttribute mlirDenseElementsAttrInt8Get(MlirType shapedType,
609 intptr_t numElements,
610 const int8_t *elements) {
611 return getDenseAttribute(shapedType, numElements, elements);
612}
613MlirAttribute mlirDenseElementsAttrUInt16Get(MlirType shapedType,
614 intptr_t numElements,
615 const uint16_t *elements) {
616 return getDenseAttribute(shapedType, numElements, elements);
617}
618MlirAttribute mlirDenseElementsAttrInt16Get(MlirType shapedType,
619 intptr_t numElements,
620 const int16_t *elements) {
621 return getDenseAttribute(shapedType, numElements, elements);
622}
623MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
624 intptr_t numElements,
625 const uint32_t *elements) {
626 return getDenseAttribute(shapedType, numElements, elements);
627}
628MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType,
629 intptr_t numElements,
630 const int32_t *elements) {
631 return getDenseAttribute(shapedType, numElements, elements);
632}
633MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType,
634 intptr_t numElements,
635 const uint64_t *elements) {
636 return getDenseAttribute(shapedType, numElements, elements);
637}
638MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType,
639 intptr_t numElements,
640 const int64_t *elements) {
641 return getDenseAttribute(shapedType, numElements, elements);
642}
643MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType,
644 intptr_t numElements,
645 const float *elements) {
646 return getDenseAttribute(shapedType, numElements, elements);
647}
648MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
649 intptr_t numElements,
650 const double *elements) {
651 return getDenseAttribute(shapedType, numElements, elements);
652}
653MlirAttribute mlirDenseElementsAttrBFloat16Get(MlirType shapedType,
654 intptr_t numElements,
655 const uint16_t *elements) {
656 size_t bufferSize = numElements * 2;
657 const void *buffer = static_cast<const void *>(elements);
658 return mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize: bufferSize, rawBuffer: buffer);
659}
660MlirAttribute mlirDenseElementsAttrFloat16Get(MlirType shapedType,
661 intptr_t numElements,
662 const uint16_t *elements) {
663 size_t bufferSize = numElements * 2;
664 const void *buffer = static_cast<const void *>(elements);
665 return mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize: bufferSize, rawBuffer: buffer);
666}
667
668MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
669 intptr_t numElements,
670 MlirStringRef *strs) {
671 SmallVector<StringRef, 8> values;
672 values.reserve(N: numElements);
673 for (intptr_t i = 0; i < numElements; ++i)
674 values.push_back(Elt: unwrap(ref: strs[i]));
675
676 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
677 values));
678}
679
680MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
681 MlirType shapedType) {
682 return wrap(llvm::cast<DenseElementsAttr>(unwrap(attr))
683 .reshape(llvm::cast<ShapedType>(unwrap(shapedType))));
684}
685
686//===----------------------------------------------------------------------===//
687// Splat accessors.
688//===----------------------------------------------------------------------===//
689
690bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
691 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).isSplat();
692}
693
694MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) {
695 return wrap(
696 cpp: llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<Attribute>());
697}
698int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
699 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<bool>();
700}
701int8_t mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr) {
702 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<int8_t>();
703}
704uint8_t mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr) {
705 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<uint8_t>();
706}
707int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) {
708 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<int32_t>();
709}
710uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) {
711 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<uint32_t>();
712}
713int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) {
714 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<int64_t>();
715}
716uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) {
717 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<uint64_t>();
718}
719float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) {
720 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<float>();
721}
722double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) {
723 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<double>();
724}
725MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
726 return wrap(
727 ref: llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<StringRef>());
728}
729
730//===----------------------------------------------------------------------===//
731// Indexed accessors.
732//===----------------------------------------------------------------------===//
733
734bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
735 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<bool>()[pos];
736}
737int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) {
738 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<int8_t>()[pos];
739}
740uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
741 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<uint8_t>()[pos];
742}
743int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) {
744 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<int16_t>()[pos];
745}
746uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) {
747 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<uint16_t>()[pos];
748}
749int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
750 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<int32_t>()[pos];
751}
752uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) {
753 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<uint32_t>()[pos];
754}
755int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
756 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<int64_t>()[pos];
757}
758uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
759 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<uint64_t>()[pos];
760}
761uint64_t mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos) {
762 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<uint64_t>()[pos];
763}
764float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
765 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<float>()[pos];
766}
767double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
768 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<double>()[pos];
769}
770MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
771 intptr_t pos) {
772 return wrap(
773 llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<StringRef>()[pos]);
774}
775
776//===----------------------------------------------------------------------===//
777// Raw data accessors.
778//===----------------------------------------------------------------------===//
779
780const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
781 return static_cast<const void *>(
782 llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getRawData().data());
783}
784
785//===----------------------------------------------------------------------===//
786// Resource blob attributes.
787//===----------------------------------------------------------------------===//
788
789bool mlirAttributeIsADenseResourceElements(MlirAttribute attr) {
790 return llvm::isa<DenseResourceElementsAttr>(unwrap(attr));
791}
792
793MlirAttribute mlirUnmanagedDenseResourceElementsAttrGet(
794 MlirType shapedType, MlirStringRef name, void *data, size_t dataLength,
795 size_t dataAlignment, bool dataIsMutable,
796 void (*deleter)(void *userData, const void *data, size_t size,
797 size_t align),
798 void *userData) {
799 AsmResourceBlob::DeleterFn cppDeleter = {};
800 if (deleter) {
801 cppDeleter = [deleter, userData](void *data, size_t size, size_t align) {
802 deleter(userData, data, size, align);
803 };
804 }
805 AsmResourceBlob blob(
806 llvm::ArrayRef(static_cast<const char *>(data), dataLength),
807 dataAlignment, std::move(cppDeleter), dataIsMutable);
808 return wrap(
809 DenseResourceElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
810 unwrap(name), std::move(blob)));
811}
812
813template <typename U, typename T>
814static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name,
815 intptr_t numElements, const T *elements) {
816 return wrap(U::get(llvm::cast<ShapedType>(unwrap(shapedType)), unwrap(name),
817 UnmanagedAsmResourceBlob::allocateInferAlign(
818 llvm::ArrayRef(elements, numElements))));
819}
820
821MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet(
822 MlirType shapedType, MlirStringRef name, intptr_t numElements,
823 const int *elements) {
824 return getDenseResource<DenseBoolResourceElementsAttr>(shapedType, name,
825 numElements, elements);
826}
827MlirAttribute mlirUnmanagedDenseUInt8ResourceElementsAttrGet(
828 MlirType shapedType, MlirStringRef name, intptr_t numElements,
829 const uint8_t *elements) {
830 return getDenseResource<DenseUI8ResourceElementsAttr>(shapedType, name,
831 numElements, elements);
832}
833MlirAttribute mlirUnmanagedDenseUInt16ResourceElementsAttrGet(
834 MlirType shapedType, MlirStringRef name, intptr_t numElements,
835 const uint16_t *elements) {
836 return getDenseResource<DenseUI16ResourceElementsAttr>(shapedType, name,
837 numElements, elements);
838}
839MlirAttribute mlirUnmanagedDenseUInt32ResourceElementsAttrGet(
840 MlirType shapedType, MlirStringRef name, intptr_t numElements,
841 const uint32_t *elements) {
842 return getDenseResource<DenseUI32ResourceElementsAttr>(shapedType, name,
843 numElements, elements);
844}
845MlirAttribute mlirUnmanagedDenseUInt64ResourceElementsAttrGet(
846 MlirType shapedType, MlirStringRef name, intptr_t numElements,
847 const uint64_t *elements) {
848 return getDenseResource<DenseUI64ResourceElementsAttr>(shapedType, name,
849 numElements, elements);
850}
851MlirAttribute mlirUnmanagedDenseInt8ResourceElementsAttrGet(
852 MlirType shapedType, MlirStringRef name, intptr_t numElements,
853 const int8_t *elements) {
854 return getDenseResource<DenseUI8ResourceElementsAttr>(shapedType, name,
855 numElements, elements);
856}
857MlirAttribute mlirUnmanagedDenseInt16ResourceElementsAttrGet(
858 MlirType shapedType, MlirStringRef name, intptr_t numElements,
859 const int16_t *elements) {
860 return getDenseResource<DenseUI16ResourceElementsAttr>(shapedType, name,
861 numElements, elements);
862}
863MlirAttribute mlirUnmanagedDenseInt32ResourceElementsAttrGet(
864 MlirType shapedType, MlirStringRef name, intptr_t numElements,
865 const int32_t *elements) {
866 return getDenseResource<DenseUI32ResourceElementsAttr>(shapedType, name,
867 numElements, elements);
868}
869MlirAttribute mlirUnmanagedDenseInt64ResourceElementsAttrGet(
870 MlirType shapedType, MlirStringRef name, intptr_t numElements,
871 const int64_t *elements) {
872 return getDenseResource<DenseUI64ResourceElementsAttr>(shapedType, name,
873 numElements, elements);
874}
875MlirAttribute mlirUnmanagedDenseFloatResourceElementsAttrGet(
876 MlirType shapedType, MlirStringRef name, intptr_t numElements,
877 const float *elements) {
878 return getDenseResource<DenseF32ResourceElementsAttr>(shapedType, name,
879 numElements, elements);
880}
881MlirAttribute mlirUnmanagedDenseDoubleResourceElementsAttrGet(
882 MlirType shapedType, MlirStringRef name, intptr_t numElements,
883 const double *elements) {
884 return getDenseResource<DenseF64ResourceElementsAttr>(shapedType, name,
885 numElements, elements);
886}
887template <typename U, typename T>
888static T getDenseResourceVal(MlirAttribute attr, intptr_t pos) {
889 return (*llvm::cast<U>(unwrap(c: attr)).tryGetAsArrayRef())[pos];
890}
891
892bool mlirDenseBoolResourceElementsAttrGetValue(MlirAttribute attr,
893 intptr_t pos) {
894 return getDenseResourceVal<DenseBoolResourceElementsAttr, uint8_t>(attr, pos);
895}
896uint8_t mlirDenseUInt8ResourceElementsAttrGetValue(MlirAttribute attr,
897 intptr_t pos) {
898 return getDenseResourceVal<DenseUI8ResourceElementsAttr, uint8_t>(attr, pos);
899}
900uint16_t mlirDenseUInt16ResourceElementsAttrGetValue(MlirAttribute attr,
901 intptr_t pos) {
902 return getDenseResourceVal<DenseUI16ResourceElementsAttr, uint16_t>(attr,
903 pos);
904}
905uint32_t mlirDenseUInt32ResourceElementsAttrGetValue(MlirAttribute attr,
906 intptr_t pos) {
907 return getDenseResourceVal<DenseUI32ResourceElementsAttr, uint32_t>(attr,
908 pos);
909}
910uint64_t mlirDenseUInt64ResourceElementsAttrGetValue(MlirAttribute attr,
911 intptr_t pos) {
912 return getDenseResourceVal<DenseUI64ResourceElementsAttr, uint64_t>(attr,
913 pos);
914}
915int8_t mlirDenseInt8ResourceElementsAttrGetValue(MlirAttribute attr,
916 intptr_t pos) {
917 return getDenseResourceVal<DenseUI8ResourceElementsAttr, int8_t>(attr, pos);
918}
919int16_t mlirDenseInt16ResourceElementsAttrGetValue(MlirAttribute attr,
920 intptr_t pos) {
921 return getDenseResourceVal<DenseUI16ResourceElementsAttr, int16_t>(attr, pos);
922}
923int32_t mlirDenseInt32ResourceElementsAttrGetValue(MlirAttribute attr,
924 intptr_t pos) {
925 return getDenseResourceVal<DenseUI32ResourceElementsAttr, int32_t>(attr, pos);
926}
927int64_t mlirDenseInt64ResourceElementsAttrGetValue(MlirAttribute attr,
928 intptr_t pos) {
929 return getDenseResourceVal<DenseUI64ResourceElementsAttr, int64_t>(attr, pos);
930}
931float mlirDenseFloatResourceElementsAttrGetValue(MlirAttribute attr,
932 intptr_t pos) {
933 return getDenseResourceVal<DenseF32ResourceElementsAttr, float>(attr, pos);
934}
935double mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr,
936 intptr_t pos) {
937 return getDenseResourceVal<DenseF64ResourceElementsAttr, double>(attr, pos);
938}
939
940//===----------------------------------------------------------------------===//
941// Sparse elements attribute.
942//===----------------------------------------------------------------------===//
943
944bool mlirAttributeIsASparseElements(MlirAttribute attr) {
945 return llvm::isa<SparseElementsAttr>(unwrap(attr));
946}
947
948MlirAttribute mlirSparseElementsAttribute(MlirType shapedType,
949 MlirAttribute denseIndices,
950 MlirAttribute denseValues) {
951 return wrap(SparseElementsAttr::get(
952 llvm::cast<ShapedType>(unwrap(shapedType)),
953 llvm::cast<DenseElementsAttr>(unwrap(denseIndices)),
954 llvm::cast<DenseElementsAttr>(unwrap(denseValues))));
955}
956
957MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) {
958 return wrap(llvm::cast<SparseElementsAttr>(unwrap(attr)).getIndices());
959}
960
961MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) {
962 return wrap(llvm::cast<SparseElementsAttr>(unwrap(attr)).getValues());
963}
964
965MlirTypeID mlirSparseElementsAttrGetTypeID(void) {
966 return wrap(SparseElementsAttr::getTypeID());
967}
968
969//===----------------------------------------------------------------------===//
970// Strided layout attribute.
971//===----------------------------------------------------------------------===//
972
973bool mlirAttributeIsAStridedLayout(MlirAttribute attr) {
974 return llvm::isa<StridedLayoutAttr>(unwrap(attr));
975}
976
977MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset,
978 intptr_t numStrides,
979 const int64_t *strides) {
980 return wrap(StridedLayoutAttr::get(unwrap(ctx), offset,
981 ArrayRef<int64_t>(strides, numStrides)));
982}
983
984int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr) {
985 return llvm::cast<StridedLayoutAttr>(unwrap(attr)).getOffset();
986}
987
988intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr) {
989 return static_cast<intptr_t>(
990 llvm::cast<StridedLayoutAttr>(unwrap(attr)).getStrides().size());
991}
992
993int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) {
994 return llvm::cast<StridedLayoutAttr>(unwrap(attr)).getStrides()[pos];
995}
996
997MlirTypeID mlirStridedLayoutAttrGetTypeID(void) {
998 return wrap(StridedLayoutAttr::getTypeID());
999}
1000

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/CAPI/IR/BuiltinAttributes.cpp