1//===- BuiltinAttributes.cpp - C Interface to MLIR Builtin Attributes -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir-c/BuiltinAttributes.h"
10#include "mlir-c/Support.h"
11#include "mlir/CAPI/AffineMap.h"
12#include "mlir/CAPI/IR.h"
13#include "mlir/CAPI/Support.h"
14#include "mlir/IR/AsmState.h"
15#include "mlir/IR/Attributes.h"
16#include "mlir/IR/BuiltinAttributes.h"
17#include "mlir/IR/BuiltinTypes.h"
18
19using namespace mlir;
20
21MlirAttribute mlirAttributeGetNull() { return {.ptr: nullptr}; }
22
23//===----------------------------------------------------------------------===//
24// Location attribute.
25//===----------------------------------------------------------------------===//
26
27bool mlirAttributeIsALocation(MlirAttribute attr) {
28 return llvm::isa<LocationAttr>(Val: unwrap(c: attr));
29}
30
31//===----------------------------------------------------------------------===//
32// Affine map attribute.
33//===----------------------------------------------------------------------===//
34
35bool mlirAttributeIsAAffineMap(MlirAttribute attr) {
36 return llvm::isa<AffineMapAttr>(unwrap(attr));
37}
38
39MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) {
40 return wrap(AffineMapAttr::get(unwrap(map)));
41}
42
43MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) {
44 return wrap(llvm::cast<AffineMapAttr>(unwrap(attr)).getValue());
45}
46
47MlirTypeID mlirAffineMapAttrGetTypeID(void) {
48 return wrap(AffineMapAttr::getTypeID());
49}
50
51//===----------------------------------------------------------------------===//
52// Array attribute.
53//===----------------------------------------------------------------------===//
54
55bool mlirAttributeIsAArray(MlirAttribute attr) {
56 return llvm::isa<ArrayAttr>(unwrap(attr));
57}
58
59MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements,
60 MlirAttribute const *elements) {
61 SmallVector<Attribute, 8> attrs;
62 return wrap(
63 ArrayAttr::get(unwrap(ctx), unwrapList(static_cast<size_t>(numElements),
64 elements, attrs)));
65}
66
67intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) {
68 return static_cast<intptr_t>(llvm::cast<ArrayAttr>(unwrap(attr)).size());
69}
70
71MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) {
72 return wrap(llvm::cast<ArrayAttr>(unwrap(attr)).getValue()[pos]);
73}
74
75MlirTypeID mlirArrayAttrGetTypeID(void) { return wrap(ArrayAttr::getTypeID()); }
76
77//===----------------------------------------------------------------------===//
78// Dictionary attribute.
79//===----------------------------------------------------------------------===//
80
81bool mlirAttributeIsADictionary(MlirAttribute attr) {
82 return llvm::isa<DictionaryAttr>(Val: unwrap(c: attr));
83}
84
85MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
86 MlirNamedAttribute const *elements) {
87 SmallVector<NamedAttribute, 8> attributes;
88 attributes.reserve(N: numElements);
89 for (intptr_t i = 0; i < numElements; ++i)
90 attributes.emplace_back(unwrap(elements[i].name),
91 unwrap(c: elements[i].attribute));
92 return wrap(DictionaryAttr::get(unwrap(ctx), attributes));
93}
94
95intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) {
96 return static_cast<intptr_t>(llvm::cast<DictionaryAttr>(unwrap(c: attr)).size());
97}
98
99MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr,
100 intptr_t pos) {
101 NamedAttribute attribute =
102 llvm::cast<DictionaryAttr>(unwrap(c: attr)).getValue()[pos];
103 return {wrap(attribute.getName()), wrap(cpp: attribute.getValue())};
104}
105
106MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,
107 MlirStringRef name) {
108 return wrap(llvm::cast<DictionaryAttr>(unwrap(c: attr)).get(unwrap(ref: name)));
109}
110
111MlirTypeID mlirDictionaryAttrGetTypeID(void) {
112 return wrap(DictionaryAttr::getTypeID());
113}
114
115//===----------------------------------------------------------------------===//
116// Floating point attribute.
117//===----------------------------------------------------------------------===//
118
119bool mlirAttributeIsAFloat(MlirAttribute attr) {
120 return llvm::isa<FloatAttr>(unwrap(attr));
121}
122
123MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
124 double value) {
125 return wrap(FloatAttr::get(unwrap(type), value));
126}
127
128MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type,
129 double value) {
130 return wrap(FloatAttr::getChecked(unwrap(loc), unwrap(type), value));
131}
132
133double mlirFloatAttrGetValueDouble(MlirAttribute attr) {
134 return llvm::cast<FloatAttr>(unwrap(attr)).getValueAsDouble();
135}
136
137MlirTypeID mlirFloatAttrGetTypeID(void) { return wrap(FloatAttr::getTypeID()); }
138
139//===----------------------------------------------------------------------===//
140// Integer attribute.
141//===----------------------------------------------------------------------===//
142
143bool mlirAttributeIsAInteger(MlirAttribute attr) {
144 return llvm::isa<IntegerAttr>(unwrap(attr));
145}
146
147MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) {
148 return wrap(IntegerAttr::get(unwrap(type), value));
149}
150
151int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) {
152 return llvm::cast<IntegerAttr>(unwrap(attr)).getInt();
153}
154
155int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr) {
156 return llvm::cast<IntegerAttr>(unwrap(attr)).getSInt();
157}
158
159uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) {
160 return llvm::cast<IntegerAttr>(unwrap(attr)).getUInt();
161}
162
163MlirTypeID mlirIntegerAttrGetTypeID(void) {
164 return wrap(IntegerAttr::getTypeID());
165}
166
167//===----------------------------------------------------------------------===//
168// Bool attribute.
169//===----------------------------------------------------------------------===//
170
171bool mlirAttributeIsABool(MlirAttribute attr) {
172 return llvm::isa<BoolAttr>(Val: unwrap(c: attr));
173}
174
175MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) {
176 return wrap(cpp: BoolAttr::get(context: unwrap(c: ctx), value));
177}
178
179bool mlirBoolAttrGetValue(MlirAttribute attr) {
180 return llvm::cast<BoolAttr>(Val: unwrap(c: attr)).getValue();
181}
182
183//===----------------------------------------------------------------------===//
184// Integer set attribute.
185//===----------------------------------------------------------------------===//
186
187bool mlirAttributeIsAIntegerSet(MlirAttribute attr) {
188 return llvm::isa<IntegerSetAttr>(unwrap(attr));
189}
190
191MlirTypeID mlirIntegerSetAttrGetTypeID(void) {
192 return wrap(IntegerSetAttr::getTypeID());
193}
194
195//===----------------------------------------------------------------------===//
196// Opaque attribute.
197//===----------------------------------------------------------------------===//
198
199bool mlirAttributeIsAOpaque(MlirAttribute attr) {
200 return llvm::isa<OpaqueAttr>(unwrap(attr));
201}
202
203MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
204 intptr_t dataLength, const char *data,
205 MlirType type) {
206 return wrap(
207 OpaqueAttr::get(StringAttr::get(unwrap(ctx), unwrap(dialectNamespace)),
208 StringRef(data, dataLength), unwrap(type)));
209}
210
211MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
212 return wrap(
213 llvm::cast<OpaqueAttr>(unwrap(attr)).getDialectNamespace().strref());
214}
215
216MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) {
217 return wrap(llvm::cast<OpaqueAttr>(unwrap(attr)).getAttrData());
218}
219
220MlirTypeID mlirOpaqueAttrGetTypeID(void) {
221 return wrap(OpaqueAttr::getTypeID());
222}
223
224//===----------------------------------------------------------------------===//
225// String attribute.
226//===----------------------------------------------------------------------===//
227
228bool mlirAttributeIsAString(MlirAttribute attr) {
229 return llvm::isa<StringAttr>(Val: unwrap(c: attr));
230}
231
232MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) {
233 return wrap((Attribute)StringAttr::get(unwrap(ctx), unwrap(str)));
234}
235
236MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) {
237 return wrap((Attribute)StringAttr::get(unwrap(str), unwrap(type)));
238}
239
240MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) {
241 return wrap(llvm::cast<StringAttr>(unwrap(c: attr)).getValue());
242}
243
244MlirTypeID mlirStringAttrGetTypeID(void) {
245 return wrap(StringAttr::getTypeID());
246}
247
248//===----------------------------------------------------------------------===//
249// SymbolRef attribute.
250//===----------------------------------------------------------------------===//
251
252bool mlirAttributeIsASymbolRef(MlirAttribute attr) {
253 return llvm::isa<SymbolRefAttr>(unwrap(attr));
254}
255
256MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
257 intptr_t numReferences,
258 MlirAttribute const *references) {
259 SmallVector<FlatSymbolRefAttr, 4> refs;
260 refs.reserve(N: numReferences);
261 for (intptr_t i = 0; i < numReferences; ++i)
262 refs.push_back(Elt: llvm::cast<FlatSymbolRefAttr>(Val: unwrap(c: references[i])));
263 auto symbolAttr = StringAttr::get(unwrap(ctx), unwrap(symbol));
264 return wrap(SymbolRefAttr::get(symbolAttr, refs));
265}
266
267MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) {
268 return wrap(
269 llvm::cast<SymbolRefAttr>(unwrap(attr)).getRootReference().getValue());
270}
271
272MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) {
273 return wrap(
274 llvm::cast<SymbolRefAttr>(unwrap(attr)).getLeafReference().getValue());
275}
276
277intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) {
278 return static_cast<intptr_t>(
279 llvm::cast<SymbolRefAttr>(unwrap(attr)).getNestedReferences().size());
280}
281
282MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,
283 intptr_t pos) {
284 return wrap(
285 llvm::cast<SymbolRefAttr>(unwrap(attr)).getNestedReferences()[pos]);
286}
287
288MlirTypeID mlirSymbolRefAttrGetTypeID(void) {
289 return wrap(SymbolRefAttr::getTypeID());
290}
291
292MlirAttribute mlirDisctinctAttrCreate(MlirAttribute referencedAttr) {
293 return wrap(mlir::DistinctAttr::create(referencedAttr: unwrap(c: referencedAttr)));
294}
295
296//===----------------------------------------------------------------------===//
297// Flat SymbolRef attribute.
298//===----------------------------------------------------------------------===//
299
300bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) {
301 return llvm::isa<FlatSymbolRefAttr>(Val: unwrap(c: attr));
302}
303
304MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) {
305 return wrap(FlatSymbolRefAttr::get(ctx: unwrap(c: ctx), value: unwrap(ref: symbol)));
306}
307
308MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
309 return wrap(ref: llvm::cast<FlatSymbolRefAttr>(Val: unwrap(c: attr)).getValue());
310}
311
312//===----------------------------------------------------------------------===//
313// Type attribute.
314//===----------------------------------------------------------------------===//
315
316bool mlirAttributeIsAType(MlirAttribute attr) {
317 return llvm::isa<TypeAttr>(unwrap(attr));
318}
319
320MlirAttribute mlirTypeAttrGet(MlirType type) {
321 return wrap(TypeAttr::get(unwrap(type)));
322}
323
324MlirType mlirTypeAttrGetValue(MlirAttribute attr) {
325 return wrap(llvm::cast<TypeAttr>(unwrap(attr)).getValue());
326}
327
328MlirTypeID mlirTypeAttrGetTypeID(void) { return wrap(TypeAttr::getTypeID()); }
329
330//===----------------------------------------------------------------------===//
331// Unit attribute.
332//===----------------------------------------------------------------------===//
333
334bool mlirAttributeIsAUnit(MlirAttribute attr) {
335 return llvm::isa<UnitAttr>(unwrap(attr));
336}
337
338MlirAttribute mlirUnitAttrGet(MlirContext ctx) {
339 return wrap(UnitAttr::get(unwrap(ctx)));
340}
341
342MlirTypeID mlirUnitAttrGetTypeID(void) { return wrap(UnitAttr::getTypeID()); }
343
344//===----------------------------------------------------------------------===//
345// Elements attributes.
346//===----------------------------------------------------------------------===//
347
348bool mlirAttributeIsAElements(MlirAttribute attr) {
349 return llvm::isa<ElementsAttr>(Val: unwrap(c: attr));
350}
351
352MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank,
353 uint64_t *idxs) {
354 return wrap(llvm::cast<ElementsAttr>(unwrap(c: attr))
355 .getValues<Attribute>()[llvm::ArrayRef(idxs, rank)]);
356}
357
358bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank,
359 uint64_t *idxs) {
360 return llvm::cast<ElementsAttr>(unwrap(c: attr))
361 .isValidIndex(llvm::ArrayRef(idxs, rank));
362}
363
364int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) {
365 return llvm::cast<ElementsAttr>(unwrap(c: attr)).getNumElements();
366}
367
368//===----------------------------------------------------------------------===//
369// Dense array attribute.
370//===----------------------------------------------------------------------===//
371
372MlirTypeID mlirDenseArrayAttrGetTypeID() {
373 return wrap(DenseArrayAttr::getTypeID());
374}
375
376//===----------------------------------------------------------------------===//
377// IsA support.
378//===----------------------------------------------------------------------===//
379
380bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) {
381 return llvm::isa<DenseBoolArrayAttr>(Val: unwrap(c: attr));
382}
383bool mlirAttributeIsADenseI8Array(MlirAttribute attr) {
384 return llvm::isa<DenseI8ArrayAttr>(Val: unwrap(c: attr));
385}
386bool mlirAttributeIsADenseI16Array(MlirAttribute attr) {
387 return llvm::isa<DenseI16ArrayAttr>(Val: unwrap(c: attr));
388}
389bool mlirAttributeIsADenseI32Array(MlirAttribute attr) {
390 return llvm::isa<DenseI32ArrayAttr>(Val: unwrap(c: attr));
391}
392bool mlirAttributeIsADenseI64Array(MlirAttribute attr) {
393 return llvm::isa<DenseI64ArrayAttr>(Val: unwrap(c: attr));
394}
395bool mlirAttributeIsADenseF32Array(MlirAttribute attr) {
396 return llvm::isa<DenseF32ArrayAttr>(Val: unwrap(c: attr));
397}
398bool mlirAttributeIsADenseF64Array(MlirAttribute attr) {
399 return llvm::isa<DenseF64ArrayAttr>(Val: unwrap(c: attr));
400}
401
402//===----------------------------------------------------------------------===//
403// Constructors.
404//===----------------------------------------------------------------------===//
405
406MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size,
407 int const *values) {
408 SmallVector<bool, 4> elements(values, values + size);
409 return wrap(DenseBoolArrayAttr::get(unwrap(ctx), elements));
410}
411MlirAttribute mlirDenseI8ArrayGet(MlirContext ctx, intptr_t size,
412 int8_t const *values) {
413 return wrap(
414 DenseI8ArrayAttr::get(unwrap(ctx), ArrayRef<int8_t>(values, size)));
415}
416MlirAttribute mlirDenseI16ArrayGet(MlirContext ctx, intptr_t size,
417 int16_t const *values) {
418 return wrap(
419 DenseI16ArrayAttr::get(unwrap(ctx), ArrayRef<int16_t>(values, size)));
420}
421MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size,
422 int32_t const *values) {
423 return wrap(
424 DenseI32ArrayAttr::get(unwrap(ctx), ArrayRef<int32_t>(values, size)));
425}
426MlirAttribute mlirDenseI64ArrayGet(MlirContext ctx, intptr_t size,
427 int64_t const *values) {
428 return wrap(
429 DenseI64ArrayAttr::get(unwrap(ctx), ArrayRef<int64_t>(values, size)));
430}
431MlirAttribute mlirDenseF32ArrayGet(MlirContext ctx, intptr_t size,
432 float const *values) {
433 return wrap(
434 DenseF32ArrayAttr::get(unwrap(ctx), ArrayRef<float>(values, size)));
435}
436MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size,
437 double const *values) {
438 return wrap(
439 DenseF64ArrayAttr::get(unwrap(ctx), ArrayRef<double>(values, size)));
440}
441
442//===----------------------------------------------------------------------===//
443// Accessors.
444//===----------------------------------------------------------------------===//
445
446intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) {
447 return llvm::cast<DenseArrayAttr>(unwrap(attr)).size();
448}
449
450//===----------------------------------------------------------------------===//
451// Indexed accessors.
452//===----------------------------------------------------------------------===//
453
454bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) {
455 return llvm::cast<DenseBoolArrayAttr>(unwrap(c: attr))[pos];
456}
457int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, intptr_t pos) {
458 return llvm::cast<DenseI8ArrayAttr>(unwrap(c: attr))[pos];
459}
460int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, intptr_t pos) {
461 return llvm::cast<DenseI16ArrayAttr>(unwrap(c: attr))[pos];
462}
463int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, intptr_t pos) {
464 return llvm::cast<DenseI32ArrayAttr>(unwrap(c: attr))[pos];
465}
466int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, intptr_t pos) {
467 return llvm::cast<DenseI64ArrayAttr>(unwrap(c: attr))[pos];
468}
469float mlirDenseF32ArrayGetElement(MlirAttribute attr, intptr_t pos) {
470 return llvm::cast<DenseF32ArrayAttr>(unwrap(c: attr))[pos];
471}
472double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) {
473 return llvm::cast<DenseF64ArrayAttr>(unwrap(c: attr))[pos];
474}
475
476//===----------------------------------------------------------------------===//
477// Dense elements attribute.
478//===----------------------------------------------------------------------===//
479
480//===----------------------------------------------------------------------===//
481// IsA support.
482//===----------------------------------------------------------------------===//
483
484bool mlirAttributeIsADenseElements(MlirAttribute attr) {
485 return llvm::isa<DenseElementsAttr>(Val: unwrap(c: attr));
486}
487
488bool mlirAttributeIsADenseIntElements(MlirAttribute attr) {
489 return llvm::isa<DenseIntElementsAttr>(Val: unwrap(c: attr));
490}
491
492bool mlirAttributeIsADenseFPElements(MlirAttribute attr) {
493 return llvm::isa<DenseFPElementsAttr>(Val: unwrap(c: attr));
494}
495
496MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void) {
497 return wrap(DenseIntOrFPElementsAttr::getTypeID());
498}
499
500//===----------------------------------------------------------------------===//
501// Constructors.
502//===----------------------------------------------------------------------===//
503
504MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
505 intptr_t numElements,
506 MlirAttribute const *elements) {
507 SmallVector<Attribute, 8> attributes;
508 return wrap(
509 DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
510 unwrapList(numElements, elements, attributes)));
511}
512
513MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType,
514 size_t rawBufferSize,
515 const void *rawBuffer) {
516 auto shapedTypeCpp = llvm::cast<ShapedType>(unwrap(shapedType));
517 ArrayRef<char> rawBufferCpp(static_cast<const char *>(rawBuffer),
518 rawBufferSize);
519 bool isSplat = false;
520 if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp,
521 isSplat))
522 return mlirAttributeGetNull();
523 return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp));
524}
525
526MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,
527 MlirAttribute element) {
528 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
529 unwrap(element)));
530}
531MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,
532 bool element) {
533 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
534 element));
535}
536MlirAttribute mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType,
537 uint8_t element) {
538 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
539 element));
540}
541MlirAttribute mlirDenseElementsAttrInt8SplatGet(MlirType shapedType,
542 int8_t element) {
543 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
544 element));
545}
546MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,
547 uint32_t element) {
548 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
549 element));
550}
551MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType,
552 int32_t element) {
553 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
554 element));
555}
556MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType,
557 uint64_t element) {
558 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
559 element));
560}
561MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType,
562 int64_t element) {
563 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
564 element));
565}
566MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType,
567 float element) {
568 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
569 element));
570}
571MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,
572 double element) {
573 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
574 element));
575}
576
577MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
578 intptr_t numElements,
579 const int *elements) {
580 SmallVector<bool, 8> values(elements, elements + numElements);
581 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
582 values));
583}
584
585/// Creates a dense attribute with elements of the type deduced by templates.
586template <typename T>
587static MlirAttribute getDenseAttribute(MlirType shapedType,
588 intptr_t numElements,
589 const T *elements) {
590 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
591 llvm::ArrayRef(elements, numElements)));
592}
593
594MlirAttribute mlirDenseElementsAttrUInt8Get(MlirType shapedType,
595 intptr_t numElements,
596 const uint8_t *elements) {
597 return getDenseAttribute(shapedType, numElements, elements);
598}
599MlirAttribute mlirDenseElementsAttrInt8Get(MlirType shapedType,
600 intptr_t numElements,
601 const int8_t *elements) {
602 return getDenseAttribute(shapedType, numElements, elements);
603}
604MlirAttribute mlirDenseElementsAttrUInt16Get(MlirType shapedType,
605 intptr_t numElements,
606 const uint16_t *elements) {
607 return getDenseAttribute(shapedType, numElements, elements);
608}
609MlirAttribute mlirDenseElementsAttrInt16Get(MlirType shapedType,
610 intptr_t numElements,
611 const int16_t *elements) {
612 return getDenseAttribute(shapedType, numElements, elements);
613}
614MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
615 intptr_t numElements,
616 const uint32_t *elements) {
617 return getDenseAttribute(shapedType, numElements, elements);
618}
619MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType,
620 intptr_t numElements,
621 const int32_t *elements) {
622 return getDenseAttribute(shapedType, numElements, elements);
623}
624MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType,
625 intptr_t numElements,
626 const uint64_t *elements) {
627 return getDenseAttribute(shapedType, numElements, elements);
628}
629MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType,
630 intptr_t numElements,
631 const int64_t *elements) {
632 return getDenseAttribute(shapedType, numElements, elements);
633}
634MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType,
635 intptr_t numElements,
636 const float *elements) {
637 return getDenseAttribute(shapedType, numElements, elements);
638}
639MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
640 intptr_t numElements,
641 const double *elements) {
642 return getDenseAttribute(shapedType, numElements, elements);
643}
644MlirAttribute mlirDenseElementsAttrBFloat16Get(MlirType shapedType,
645 intptr_t numElements,
646 const uint16_t *elements) {
647 size_t bufferSize = numElements * 2;
648 const void *buffer = static_cast<const void *>(elements);
649 return mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize: bufferSize, rawBuffer: buffer);
650}
651MlirAttribute mlirDenseElementsAttrFloat16Get(MlirType shapedType,
652 intptr_t numElements,
653 const uint16_t *elements) {
654 size_t bufferSize = numElements * 2;
655 const void *buffer = static_cast<const void *>(elements);
656 return mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize: bufferSize, rawBuffer: buffer);
657}
658
659MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
660 intptr_t numElements,
661 MlirStringRef *strs) {
662 SmallVector<StringRef, 8> values;
663 values.reserve(N: numElements);
664 for (intptr_t i = 0; i < numElements; ++i)
665 values.push_back(Elt: unwrap(ref: strs[i]));
666
667 return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
668 values));
669}
670
671MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
672 MlirType shapedType) {
673 return wrap(llvm::cast<DenseElementsAttr>(unwrap(attr))
674 .reshape(llvm::cast<ShapedType>(unwrap(shapedType))));
675}
676
677//===----------------------------------------------------------------------===//
678// Splat accessors.
679//===----------------------------------------------------------------------===//
680
681bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
682 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).isSplat();
683}
684
685MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) {
686 return wrap(
687 cpp: llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<Attribute>());
688}
689int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
690 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<bool>();
691}
692int8_t mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr) {
693 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<int8_t>();
694}
695uint8_t mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr) {
696 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<uint8_t>();
697}
698int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) {
699 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<int32_t>();
700}
701uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) {
702 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<uint32_t>();
703}
704int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) {
705 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<int64_t>();
706}
707uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) {
708 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<uint64_t>();
709}
710float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) {
711 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<float>();
712}
713double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) {
714 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<double>();
715}
716MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) {
717 return wrap(
718 ref: llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<StringRef>());
719}
720
721//===----------------------------------------------------------------------===//
722// Indexed accessors.
723//===----------------------------------------------------------------------===//
724
725bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
726 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<bool>()[pos];
727}
728int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) {
729 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<int8_t>()[pos];
730}
731uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
732 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<uint8_t>()[pos];
733}
734int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) {
735 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<int16_t>()[pos];
736}
737uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) {
738 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<uint16_t>()[pos];
739}
740int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
741 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<int32_t>()[pos];
742}
743uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) {
744 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<uint32_t>()[pos];
745}
746int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
747 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<int64_t>()[pos];
748}
749uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
750 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<uint64_t>()[pos];
751}
752float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
753 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<float>()[pos];
754}
755double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
756 return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<double>()[pos];
757}
758MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr,
759 intptr_t pos) {
760 return wrap(
761 llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<StringRef>()[pos]);
762}
763
764//===----------------------------------------------------------------------===//
765// Raw data accessors.
766//===----------------------------------------------------------------------===//
767
768const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) {
769 return static_cast<const void *>(
770 llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getRawData().data());
771}
772
773//===----------------------------------------------------------------------===//
774// Resource blob attributes.
775//===----------------------------------------------------------------------===//
776
777bool mlirAttributeIsADenseResourceElements(MlirAttribute attr) {
778 return llvm::isa<DenseResourceElementsAttr>(unwrap(attr));
779}
780
781MlirAttribute mlirUnmanagedDenseResourceElementsAttrGet(
782 MlirType shapedType, MlirStringRef name, void *data, size_t dataLength,
783 size_t dataAlignment, bool dataIsMutable,
784 void (*deleter)(void *userData, const void *data, size_t size,
785 size_t align),
786 void *userData) {
787 AsmResourceBlob::DeleterFn cppDeleter = {};
788 if (deleter) {
789 cppDeleter = [deleter, userData](void *data, size_t size, size_t align) {
790 deleter(userData, data, size, align);
791 };
792 }
793 AsmResourceBlob blob(
794 llvm::ArrayRef(static_cast<const char *>(data), dataLength),
795 dataAlignment, std::move(cppDeleter), dataIsMutable);
796 return wrap(
797 DenseResourceElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
798 unwrap(name), std::move(blob)));
799}
800
801template <typename U, typename T>
802static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name,
803 intptr_t numElements, const T *elements) {
804 return wrap(U::get(llvm::cast<ShapedType>(unwrap(shapedType)), unwrap(name),
805 UnmanagedAsmResourceBlob::allocateInferAlign(
806 llvm::ArrayRef(elements, numElements))));
807}
808
809MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet(
810 MlirType shapedType, MlirStringRef name, intptr_t numElements,
811 const int *elements) {
812 return getDenseResource<DenseBoolResourceElementsAttr>(shapedType, name,
813 numElements, elements);
814}
815MlirAttribute mlirUnmanagedDenseUInt8ResourceElementsAttrGet(
816 MlirType shapedType, MlirStringRef name, intptr_t numElements,
817 const uint8_t *elements) {
818 return getDenseResource<DenseUI8ResourceElementsAttr>(shapedType, name,
819 numElements, elements);
820}
821MlirAttribute mlirUnmanagedDenseUInt16ResourceElementsAttrGet(
822 MlirType shapedType, MlirStringRef name, intptr_t numElements,
823 const uint16_t *elements) {
824 return getDenseResource<DenseUI16ResourceElementsAttr>(shapedType, name,
825 numElements, elements);
826}
827MlirAttribute mlirUnmanagedDenseUInt32ResourceElementsAttrGet(
828 MlirType shapedType, MlirStringRef name, intptr_t numElements,
829 const uint32_t *elements) {
830 return getDenseResource<DenseUI32ResourceElementsAttr>(shapedType, name,
831 numElements, elements);
832}
833MlirAttribute mlirUnmanagedDenseUInt64ResourceElementsAttrGet(
834 MlirType shapedType, MlirStringRef name, intptr_t numElements,
835 const uint64_t *elements) {
836 return getDenseResource<DenseUI64ResourceElementsAttr>(shapedType, name,
837 numElements, elements);
838}
839MlirAttribute mlirUnmanagedDenseInt8ResourceElementsAttrGet(
840 MlirType shapedType, MlirStringRef name, intptr_t numElements,
841 const int8_t *elements) {
842 return getDenseResource<DenseUI8ResourceElementsAttr>(shapedType, name,
843 numElements, elements);
844}
845MlirAttribute mlirUnmanagedDenseInt16ResourceElementsAttrGet(
846 MlirType shapedType, MlirStringRef name, intptr_t numElements,
847 const int16_t *elements) {
848 return getDenseResource<DenseUI16ResourceElementsAttr>(shapedType, name,
849 numElements, elements);
850}
851MlirAttribute mlirUnmanagedDenseInt32ResourceElementsAttrGet(
852 MlirType shapedType, MlirStringRef name, intptr_t numElements,
853 const int32_t *elements) {
854 return getDenseResource<DenseUI32ResourceElementsAttr>(shapedType, name,
855 numElements, elements);
856}
857MlirAttribute mlirUnmanagedDenseInt64ResourceElementsAttrGet(
858 MlirType shapedType, MlirStringRef name, intptr_t numElements,
859 const int64_t *elements) {
860 return getDenseResource<DenseUI64ResourceElementsAttr>(shapedType, name,
861 numElements, elements);
862}
863MlirAttribute mlirUnmanagedDenseFloatResourceElementsAttrGet(
864 MlirType shapedType, MlirStringRef name, intptr_t numElements,
865 const float *elements) {
866 return getDenseResource<DenseF32ResourceElementsAttr>(shapedType, name,
867 numElements, elements);
868}
869MlirAttribute mlirUnmanagedDenseDoubleResourceElementsAttrGet(
870 MlirType shapedType, MlirStringRef name, intptr_t numElements,
871 const double *elements) {
872 return getDenseResource<DenseF64ResourceElementsAttr>(shapedType, name,
873 numElements, elements);
874}
875template <typename U, typename T>
876static T getDenseResourceVal(MlirAttribute attr, intptr_t pos) {
877 return (*llvm::cast<U>(unwrap(c: attr)).tryGetAsArrayRef())[pos];
878}
879
880bool mlirDenseBoolResourceElementsAttrGetValue(MlirAttribute attr,
881 intptr_t pos) {
882 return getDenseResourceVal<DenseBoolResourceElementsAttr, uint8_t>(attr, pos);
883}
884uint8_t mlirDenseUInt8ResourceElementsAttrGetValue(MlirAttribute attr,
885 intptr_t pos) {
886 return getDenseResourceVal<DenseUI8ResourceElementsAttr, uint8_t>(attr, pos);
887}
888uint16_t mlirDenseUInt16ResourceElementsAttrGetValue(MlirAttribute attr,
889 intptr_t pos) {
890 return getDenseResourceVal<DenseUI16ResourceElementsAttr, uint16_t>(attr,
891 pos);
892}
893uint32_t mlirDenseUInt32ResourceElementsAttrGetValue(MlirAttribute attr,
894 intptr_t pos) {
895 return getDenseResourceVal<DenseUI32ResourceElementsAttr, uint32_t>(attr,
896 pos);
897}
898uint64_t mlirDenseUInt64ResourceElementsAttrGetValue(MlirAttribute attr,
899 intptr_t pos) {
900 return getDenseResourceVal<DenseUI64ResourceElementsAttr, uint64_t>(attr,
901 pos);
902}
903int8_t mlirDenseInt8ResourceElementsAttrGetValue(MlirAttribute attr,
904 intptr_t pos) {
905 return getDenseResourceVal<DenseUI8ResourceElementsAttr, int8_t>(attr, pos);
906}
907int16_t mlirDenseInt16ResourceElementsAttrGetValue(MlirAttribute attr,
908 intptr_t pos) {
909 return getDenseResourceVal<DenseUI16ResourceElementsAttr, int16_t>(attr, pos);
910}
911int32_t mlirDenseInt32ResourceElementsAttrGetValue(MlirAttribute attr,
912 intptr_t pos) {
913 return getDenseResourceVal<DenseUI32ResourceElementsAttr, int32_t>(attr, pos);
914}
915int64_t mlirDenseInt64ResourceElementsAttrGetValue(MlirAttribute attr,
916 intptr_t pos) {
917 return getDenseResourceVal<DenseUI64ResourceElementsAttr, int64_t>(attr, pos);
918}
919float mlirDenseFloatResourceElementsAttrGetValue(MlirAttribute attr,
920 intptr_t pos) {
921 return getDenseResourceVal<DenseF32ResourceElementsAttr, float>(attr, pos);
922}
923double mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr,
924 intptr_t pos) {
925 return getDenseResourceVal<DenseF64ResourceElementsAttr, double>(attr, pos);
926}
927
928//===----------------------------------------------------------------------===//
929// Sparse elements attribute.
930//===----------------------------------------------------------------------===//
931
932bool mlirAttributeIsASparseElements(MlirAttribute attr) {
933 return llvm::isa<SparseElementsAttr>(unwrap(attr));
934}
935
936MlirAttribute mlirSparseElementsAttribute(MlirType shapedType,
937 MlirAttribute denseIndices,
938 MlirAttribute denseValues) {
939 return wrap(SparseElementsAttr::get(
940 llvm::cast<ShapedType>(unwrap(shapedType)),
941 llvm::cast<DenseElementsAttr>(unwrap(denseIndices)),
942 llvm::cast<DenseElementsAttr>(unwrap(denseValues))));
943}
944
945MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) {
946 return wrap(llvm::cast<SparseElementsAttr>(unwrap(attr)).getIndices());
947}
948
949MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) {
950 return wrap(llvm::cast<SparseElementsAttr>(unwrap(attr)).getValues());
951}
952
953MlirTypeID mlirSparseElementsAttrGetTypeID(void) {
954 return wrap(SparseElementsAttr::getTypeID());
955}
956
957//===----------------------------------------------------------------------===//
958// Strided layout attribute.
959//===----------------------------------------------------------------------===//
960
961bool mlirAttributeIsAStridedLayout(MlirAttribute attr) {
962 return llvm::isa<StridedLayoutAttr>(unwrap(attr));
963}
964
965MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset,
966 intptr_t numStrides,
967 const int64_t *strides) {
968 return wrap(StridedLayoutAttr::get(unwrap(ctx), offset,
969 ArrayRef<int64_t>(strides, numStrides)));
970}
971
972int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr) {
973 return llvm::cast<StridedLayoutAttr>(unwrap(attr)).getOffset();
974}
975
976intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr) {
977 return static_cast<intptr_t>(
978 llvm::cast<StridedLayoutAttr>(unwrap(attr)).getStrides().size());
979}
980
981int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) {
982 return llvm::cast<StridedLayoutAttr>(unwrap(attr)).getStrides()[pos];
983}
984
985MlirTypeID mlirStridedLayoutAttrGetTypeID(void) {
986 return wrap(StridedLayoutAttr::getTypeID());
987}
988

source code of mlir/lib/CAPI/IR/BuiltinAttributes.cpp