1 | //===- BuiltinAttributes.cpp - C Interface to MLIR Builtin Attributes -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir-c/BuiltinAttributes.h" |
10 | #include "mlir-c/Support.h" |
11 | #include "mlir/CAPI/AffineMap.h" |
12 | #include "mlir/CAPI/IR.h" |
13 | #include "mlir/CAPI/Support.h" |
14 | #include "mlir/IR/AsmState.h" |
15 | #include "mlir/IR/Attributes.h" |
16 | #include "mlir/IR/BuiltinAttributes.h" |
17 | #include "mlir/IR/BuiltinTypes.h" |
18 | |
19 | using namespace mlir; |
20 | |
21 | MlirAttribute mlirAttributeGetNull() { return {.ptr: nullptr}; } |
22 | |
23 | //===----------------------------------------------------------------------===// |
24 | // Location attribute. |
25 | //===----------------------------------------------------------------------===// |
26 | |
27 | bool mlirAttributeIsALocation(MlirAttribute attr) { |
28 | return llvm::isa<LocationAttr>(Val: unwrap(c: attr)); |
29 | } |
30 | |
31 | //===----------------------------------------------------------------------===// |
32 | // Affine map attribute. |
33 | //===----------------------------------------------------------------------===// |
34 | |
35 | bool mlirAttributeIsAAffineMap(MlirAttribute attr) { |
36 | return llvm::isa<AffineMapAttr>(unwrap(attr)); |
37 | } |
38 | |
39 | MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) { |
40 | return wrap(AffineMapAttr::get(unwrap(map))); |
41 | } |
42 | |
43 | MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) { |
44 | return wrap(llvm::cast<AffineMapAttr>(unwrap(attr)).getValue()); |
45 | } |
46 | |
47 | MlirTypeID mlirAffineMapAttrGetTypeID(void) { |
48 | return wrap(AffineMapAttr::getTypeID()); |
49 | } |
50 | |
51 | //===----------------------------------------------------------------------===// |
52 | // Array attribute. |
53 | //===----------------------------------------------------------------------===// |
54 | |
55 | bool mlirAttributeIsAArray(MlirAttribute attr) { |
56 | return llvm::isa<ArrayAttr>(unwrap(attr)); |
57 | } |
58 | |
59 | MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, |
60 | MlirAttribute const *elements) { |
61 | SmallVector<Attribute, 8> attrs; |
62 | return wrap( |
63 | ArrayAttr::get(unwrap(ctx), unwrapList(static_cast<size_t>(numElements), |
64 | elements, attrs))); |
65 | } |
66 | |
67 | intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) { |
68 | return static_cast<intptr_t>(llvm::cast<ArrayAttr>(unwrap(attr)).size()); |
69 | } |
70 | |
71 | MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) { |
72 | return wrap(llvm::cast<ArrayAttr>(unwrap(attr)).getValue()[pos]); |
73 | } |
74 | |
75 | MlirTypeID mlirArrayAttrGetTypeID(void) { return wrap(ArrayAttr::getTypeID()); } |
76 | |
77 | //===----------------------------------------------------------------------===// |
78 | // Dictionary attribute. |
79 | //===----------------------------------------------------------------------===// |
80 | |
81 | bool mlirAttributeIsADictionary(MlirAttribute attr) { |
82 | return llvm::isa<DictionaryAttr>(Val: unwrap(c: attr)); |
83 | } |
84 | |
85 | MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, |
86 | MlirNamedAttribute const *elements) { |
87 | SmallVector<NamedAttribute, 8> attributes; |
88 | attributes.reserve(N: numElements); |
89 | for (intptr_t i = 0; i < numElements; ++i) |
90 | attributes.emplace_back(unwrap(elements[i].name), |
91 | unwrap(c: elements[i].attribute)); |
92 | return wrap(DictionaryAttr::get(unwrap(ctx), attributes)); |
93 | } |
94 | |
95 | intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) { |
96 | return static_cast<intptr_t>(llvm::cast<DictionaryAttr>(unwrap(c: attr)).size()); |
97 | } |
98 | |
99 | MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr, |
100 | intptr_t pos) { |
101 | NamedAttribute attribute = |
102 | llvm::cast<DictionaryAttr>(unwrap(c: attr)).getValue()[pos]; |
103 | return {wrap(attribute.getName()), wrap(cpp: attribute.getValue())}; |
104 | } |
105 | |
106 | MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, |
107 | MlirStringRef name) { |
108 | return wrap(llvm::cast<DictionaryAttr>(unwrap(c: attr)).get(unwrap(ref: name))); |
109 | } |
110 | |
111 | MlirTypeID mlirDictionaryAttrGetTypeID(void) { |
112 | return wrap(DictionaryAttr::getTypeID()); |
113 | } |
114 | |
115 | //===----------------------------------------------------------------------===// |
116 | // Floating point attribute. |
117 | //===----------------------------------------------------------------------===// |
118 | |
119 | bool mlirAttributeIsAFloat(MlirAttribute attr) { |
120 | return llvm::isa<FloatAttr>(unwrap(attr)); |
121 | } |
122 | |
123 | MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type, |
124 | double value) { |
125 | return wrap(FloatAttr::get(unwrap(type), value)); |
126 | } |
127 | |
128 | MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type, |
129 | double value) { |
130 | return wrap(FloatAttr::getChecked(unwrap(loc), unwrap(type), value)); |
131 | } |
132 | |
133 | double mlirFloatAttrGetValueDouble(MlirAttribute attr) { |
134 | return llvm::cast<FloatAttr>(unwrap(attr)).getValueAsDouble(); |
135 | } |
136 | |
137 | MlirTypeID mlirFloatAttrGetTypeID(void) { return wrap(FloatAttr::getTypeID()); } |
138 | |
139 | //===----------------------------------------------------------------------===// |
140 | // Integer attribute. |
141 | //===----------------------------------------------------------------------===// |
142 | |
143 | bool mlirAttributeIsAInteger(MlirAttribute attr) { |
144 | return llvm::isa<IntegerAttr>(unwrap(attr)); |
145 | } |
146 | |
147 | MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) { |
148 | return wrap(IntegerAttr::get(unwrap(type), value)); |
149 | } |
150 | |
151 | int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) { |
152 | return llvm::cast<IntegerAttr>(unwrap(attr)).getInt(); |
153 | } |
154 | |
155 | int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr) { |
156 | return llvm::cast<IntegerAttr>(unwrap(attr)).getSInt(); |
157 | } |
158 | |
159 | uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) { |
160 | return llvm::cast<IntegerAttr>(unwrap(attr)).getUInt(); |
161 | } |
162 | |
163 | MlirTypeID mlirIntegerAttrGetTypeID(void) { |
164 | return wrap(IntegerAttr::getTypeID()); |
165 | } |
166 | |
167 | //===----------------------------------------------------------------------===// |
168 | // Bool attribute. |
169 | //===----------------------------------------------------------------------===// |
170 | |
171 | bool mlirAttributeIsABool(MlirAttribute attr) { |
172 | return llvm::isa<BoolAttr>(Val: unwrap(c: attr)); |
173 | } |
174 | |
175 | MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) { |
176 | return wrap(cpp: BoolAttr::get(context: unwrap(c: ctx), value)); |
177 | } |
178 | |
179 | bool mlirBoolAttrGetValue(MlirAttribute attr) { |
180 | return llvm::cast<BoolAttr>(Val: unwrap(c: attr)).getValue(); |
181 | } |
182 | |
183 | //===----------------------------------------------------------------------===// |
184 | // Integer set attribute. |
185 | //===----------------------------------------------------------------------===// |
186 | |
187 | bool mlirAttributeIsAIntegerSet(MlirAttribute attr) { |
188 | return llvm::isa<IntegerSetAttr>(unwrap(attr)); |
189 | } |
190 | |
191 | MlirTypeID mlirIntegerSetAttrGetTypeID(void) { |
192 | return wrap(IntegerSetAttr::getTypeID()); |
193 | } |
194 | |
195 | //===----------------------------------------------------------------------===// |
196 | // Opaque attribute. |
197 | //===----------------------------------------------------------------------===// |
198 | |
199 | bool mlirAttributeIsAOpaque(MlirAttribute attr) { |
200 | return llvm::isa<OpaqueAttr>(unwrap(attr)); |
201 | } |
202 | |
203 | MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace, |
204 | intptr_t dataLength, const char *data, |
205 | MlirType type) { |
206 | return wrap( |
207 | OpaqueAttr::get(StringAttr::get(unwrap(ctx), unwrap(dialectNamespace)), |
208 | StringRef(data, dataLength), unwrap(type))); |
209 | } |
210 | |
211 | MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) { |
212 | return wrap( |
213 | llvm::cast<OpaqueAttr>(unwrap(attr)).getDialectNamespace().strref()); |
214 | } |
215 | |
216 | MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) { |
217 | return wrap(llvm::cast<OpaqueAttr>(unwrap(attr)).getAttrData()); |
218 | } |
219 | |
220 | MlirTypeID mlirOpaqueAttrGetTypeID(void) { |
221 | return wrap(OpaqueAttr::getTypeID()); |
222 | } |
223 | |
224 | //===----------------------------------------------------------------------===// |
225 | // String attribute. |
226 | //===----------------------------------------------------------------------===// |
227 | |
228 | bool mlirAttributeIsAString(MlirAttribute attr) { |
229 | return llvm::isa<StringAttr>(Val: unwrap(c: attr)); |
230 | } |
231 | |
232 | MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) { |
233 | return wrap((Attribute)StringAttr::get(unwrap(ctx), unwrap(str))); |
234 | } |
235 | |
236 | MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) { |
237 | return wrap((Attribute)StringAttr::get(unwrap(str), unwrap(type))); |
238 | } |
239 | |
240 | MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) { |
241 | return wrap(llvm::cast<StringAttr>(unwrap(c: attr)).getValue()); |
242 | } |
243 | |
244 | MlirTypeID mlirStringAttrGetTypeID(void) { |
245 | return wrap(StringAttr::getTypeID()); |
246 | } |
247 | |
248 | //===----------------------------------------------------------------------===// |
249 | // SymbolRef attribute. |
250 | //===----------------------------------------------------------------------===// |
251 | |
252 | bool mlirAttributeIsASymbolRef(MlirAttribute attr) { |
253 | return llvm::isa<SymbolRefAttr>(unwrap(attr)); |
254 | } |
255 | |
256 | MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol, |
257 | intptr_t numReferences, |
258 | MlirAttribute const *references) { |
259 | SmallVector<FlatSymbolRefAttr, 4> refs; |
260 | refs.reserve(N: numReferences); |
261 | for (intptr_t i = 0; i < numReferences; ++i) |
262 | refs.push_back(Elt: llvm::cast<FlatSymbolRefAttr>(Val: unwrap(c: references[i]))); |
263 | auto symbolAttr = StringAttr::get(unwrap(ctx), unwrap(symbol)); |
264 | return wrap(SymbolRefAttr::get(symbolAttr, refs)); |
265 | } |
266 | |
267 | MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) { |
268 | return wrap( |
269 | llvm::cast<SymbolRefAttr>(unwrap(attr)).getRootReference().getValue()); |
270 | } |
271 | |
272 | MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) { |
273 | return wrap( |
274 | llvm::cast<SymbolRefAttr>(unwrap(attr)).getLeafReference().getValue()); |
275 | } |
276 | |
277 | intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) { |
278 | return static_cast<intptr_t>( |
279 | llvm::cast<SymbolRefAttr>(unwrap(attr)).getNestedReferences().size()); |
280 | } |
281 | |
282 | MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, |
283 | intptr_t pos) { |
284 | return wrap( |
285 | llvm::cast<SymbolRefAttr>(unwrap(attr)).getNestedReferences()[pos]); |
286 | } |
287 | |
288 | MlirTypeID mlirSymbolRefAttrGetTypeID(void) { |
289 | return wrap(SymbolRefAttr::getTypeID()); |
290 | } |
291 | |
292 | MlirAttribute mlirDisctinctAttrCreate(MlirAttribute referencedAttr) { |
293 | return wrap(mlir::DistinctAttr::create(referencedAttr: unwrap(c: referencedAttr))); |
294 | } |
295 | |
296 | //===----------------------------------------------------------------------===// |
297 | // Flat SymbolRef attribute. |
298 | //===----------------------------------------------------------------------===// |
299 | |
300 | bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) { |
301 | return llvm::isa<FlatSymbolRefAttr>(Val: unwrap(c: attr)); |
302 | } |
303 | |
304 | MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) { |
305 | return wrap(FlatSymbolRefAttr::get(ctx: unwrap(c: ctx), value: unwrap(ref: symbol))); |
306 | } |
307 | |
308 | MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) { |
309 | return wrap(ref: llvm::cast<FlatSymbolRefAttr>(Val: unwrap(c: attr)).getValue()); |
310 | } |
311 | |
312 | //===----------------------------------------------------------------------===// |
313 | // Type attribute. |
314 | //===----------------------------------------------------------------------===// |
315 | |
316 | bool mlirAttributeIsAType(MlirAttribute attr) { |
317 | return llvm::isa<TypeAttr>(unwrap(attr)); |
318 | } |
319 | |
320 | MlirAttribute mlirTypeAttrGet(MlirType type) { |
321 | return wrap(TypeAttr::get(unwrap(type))); |
322 | } |
323 | |
324 | MlirType mlirTypeAttrGetValue(MlirAttribute attr) { |
325 | return wrap(llvm::cast<TypeAttr>(unwrap(attr)).getValue()); |
326 | } |
327 | |
328 | MlirTypeID mlirTypeAttrGetTypeID(void) { return wrap(TypeAttr::getTypeID()); } |
329 | |
330 | //===----------------------------------------------------------------------===// |
331 | // Unit attribute. |
332 | //===----------------------------------------------------------------------===// |
333 | |
334 | bool mlirAttributeIsAUnit(MlirAttribute attr) { |
335 | return llvm::isa<UnitAttr>(unwrap(attr)); |
336 | } |
337 | |
338 | MlirAttribute mlirUnitAttrGet(MlirContext ctx) { |
339 | return wrap(UnitAttr::get(unwrap(ctx))); |
340 | } |
341 | |
342 | MlirTypeID mlirUnitAttrGetTypeID(void) { return wrap(UnitAttr::getTypeID()); } |
343 | |
344 | //===----------------------------------------------------------------------===// |
345 | // Elements attributes. |
346 | //===----------------------------------------------------------------------===// |
347 | |
348 | bool mlirAttributeIsAElements(MlirAttribute attr) { |
349 | return llvm::isa<ElementsAttr>(Val: unwrap(c: attr)); |
350 | } |
351 | |
352 | MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank, |
353 | uint64_t *idxs) { |
354 | return wrap(llvm::cast<ElementsAttr>(unwrap(c: attr)) |
355 | .getValues<Attribute>()[llvm::ArrayRef(idxs, rank)]); |
356 | } |
357 | |
358 | bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, |
359 | uint64_t *idxs) { |
360 | return llvm::cast<ElementsAttr>(unwrap(c: attr)) |
361 | .isValidIndex(llvm::ArrayRef(idxs, rank)); |
362 | } |
363 | |
364 | int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) { |
365 | return llvm::cast<ElementsAttr>(unwrap(c: attr)).getNumElements(); |
366 | } |
367 | |
368 | //===----------------------------------------------------------------------===// |
369 | // Dense array attribute. |
370 | //===----------------------------------------------------------------------===// |
371 | |
372 | MlirTypeID mlirDenseArrayAttrGetTypeID() { |
373 | return wrap(DenseArrayAttr::getTypeID()); |
374 | } |
375 | |
376 | //===----------------------------------------------------------------------===// |
377 | // IsA support. |
378 | //===----------------------------------------------------------------------===// |
379 | |
380 | bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) { |
381 | return llvm::isa<DenseBoolArrayAttr>(Val: unwrap(c: attr)); |
382 | } |
383 | bool mlirAttributeIsADenseI8Array(MlirAttribute attr) { |
384 | return llvm::isa<DenseI8ArrayAttr>(Val: unwrap(c: attr)); |
385 | } |
386 | bool mlirAttributeIsADenseI16Array(MlirAttribute attr) { |
387 | return llvm::isa<DenseI16ArrayAttr>(Val: unwrap(c: attr)); |
388 | } |
389 | bool mlirAttributeIsADenseI32Array(MlirAttribute attr) { |
390 | return llvm::isa<DenseI32ArrayAttr>(Val: unwrap(c: attr)); |
391 | } |
392 | bool mlirAttributeIsADenseI64Array(MlirAttribute attr) { |
393 | return llvm::isa<DenseI64ArrayAttr>(Val: unwrap(c: attr)); |
394 | } |
395 | bool mlirAttributeIsADenseF32Array(MlirAttribute attr) { |
396 | return llvm::isa<DenseF32ArrayAttr>(Val: unwrap(c: attr)); |
397 | } |
398 | bool mlirAttributeIsADenseF64Array(MlirAttribute attr) { |
399 | return llvm::isa<DenseF64ArrayAttr>(Val: unwrap(c: attr)); |
400 | } |
401 | |
402 | //===----------------------------------------------------------------------===// |
403 | // Constructors. |
404 | //===----------------------------------------------------------------------===// |
405 | |
406 | MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size, |
407 | int const *values) { |
408 | SmallVector<bool, 4> elements(values, values + size); |
409 | return wrap(DenseBoolArrayAttr::get(unwrap(ctx), elements)); |
410 | } |
411 | MlirAttribute mlirDenseI8ArrayGet(MlirContext ctx, intptr_t size, |
412 | int8_t const *values) { |
413 | return wrap( |
414 | DenseI8ArrayAttr::get(unwrap(ctx), ArrayRef<int8_t>(values, size))); |
415 | } |
416 | MlirAttribute mlirDenseI16ArrayGet(MlirContext ctx, intptr_t size, |
417 | int16_t const *values) { |
418 | return wrap( |
419 | DenseI16ArrayAttr::get(unwrap(ctx), ArrayRef<int16_t>(values, size))); |
420 | } |
421 | MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, |
422 | int32_t const *values) { |
423 | return wrap( |
424 | DenseI32ArrayAttr::get(unwrap(ctx), ArrayRef<int32_t>(values, size))); |
425 | } |
426 | MlirAttribute mlirDenseI64ArrayGet(MlirContext ctx, intptr_t size, |
427 | int64_t const *values) { |
428 | return wrap( |
429 | DenseI64ArrayAttr::get(unwrap(ctx), ArrayRef<int64_t>(values, size))); |
430 | } |
431 | MlirAttribute mlirDenseF32ArrayGet(MlirContext ctx, intptr_t size, |
432 | float const *values) { |
433 | return wrap( |
434 | DenseF32ArrayAttr::get(unwrap(ctx), ArrayRef<float>(values, size))); |
435 | } |
436 | MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size, |
437 | double const *values) { |
438 | return wrap( |
439 | DenseF64ArrayAttr::get(unwrap(ctx), ArrayRef<double>(values, size))); |
440 | } |
441 | |
442 | //===----------------------------------------------------------------------===// |
443 | // Accessors. |
444 | //===----------------------------------------------------------------------===// |
445 | |
446 | intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) { |
447 | return llvm::cast<DenseArrayAttr>(unwrap(attr)).size(); |
448 | } |
449 | |
450 | //===----------------------------------------------------------------------===// |
451 | // Indexed accessors. |
452 | //===----------------------------------------------------------------------===// |
453 | |
454 | bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) { |
455 | return llvm::cast<DenseBoolArrayAttr>(unwrap(c: attr))[pos]; |
456 | } |
457 | int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, intptr_t pos) { |
458 | return llvm::cast<DenseI8ArrayAttr>(unwrap(c: attr))[pos]; |
459 | } |
460 | int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, intptr_t pos) { |
461 | return llvm::cast<DenseI16ArrayAttr>(unwrap(c: attr))[pos]; |
462 | } |
463 | int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, intptr_t pos) { |
464 | return llvm::cast<DenseI32ArrayAttr>(unwrap(c: attr))[pos]; |
465 | } |
466 | int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, intptr_t pos) { |
467 | return llvm::cast<DenseI64ArrayAttr>(unwrap(c: attr))[pos]; |
468 | } |
469 | float mlirDenseF32ArrayGetElement(MlirAttribute attr, intptr_t pos) { |
470 | return llvm::cast<DenseF32ArrayAttr>(unwrap(c: attr))[pos]; |
471 | } |
472 | double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) { |
473 | return llvm::cast<DenseF64ArrayAttr>(unwrap(c: attr))[pos]; |
474 | } |
475 | |
476 | //===----------------------------------------------------------------------===// |
477 | // Dense elements attribute. |
478 | //===----------------------------------------------------------------------===// |
479 | |
480 | //===----------------------------------------------------------------------===// |
481 | // IsA support. |
482 | //===----------------------------------------------------------------------===// |
483 | |
484 | bool mlirAttributeIsADenseElements(MlirAttribute attr) { |
485 | return llvm::isa<DenseElementsAttr>(Val: unwrap(c: attr)); |
486 | } |
487 | |
488 | bool mlirAttributeIsADenseIntElements(MlirAttribute attr) { |
489 | return llvm::isa<DenseIntElementsAttr>(Val: unwrap(c: attr)); |
490 | } |
491 | |
492 | bool mlirAttributeIsADenseFPElements(MlirAttribute attr) { |
493 | return llvm::isa<DenseFPElementsAttr>(Val: unwrap(c: attr)); |
494 | } |
495 | |
496 | MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void) { |
497 | return wrap(DenseIntOrFPElementsAttr::getTypeID()); |
498 | } |
499 | |
500 | //===----------------------------------------------------------------------===// |
501 | // Constructors. |
502 | //===----------------------------------------------------------------------===// |
503 | |
504 | MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType, |
505 | intptr_t numElements, |
506 | MlirAttribute const *elements) { |
507 | SmallVector<Attribute, 8> attributes; |
508 | return wrap( |
509 | DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
510 | unwrapList(numElements, elements, attributes))); |
511 | } |
512 | |
513 | MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType, |
514 | size_t rawBufferSize, |
515 | const void *rawBuffer) { |
516 | auto shapedTypeCpp = llvm::cast<ShapedType>(unwrap(shapedType)); |
517 | ArrayRef<char> rawBufferCpp(static_cast<const char *>(rawBuffer), |
518 | rawBufferSize); |
519 | bool isSplat = false; |
520 | if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp, |
521 | isSplat)) |
522 | return mlirAttributeGetNull(); |
523 | return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp)); |
524 | } |
525 | |
526 | MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType, |
527 | MlirAttribute element) { |
528 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
529 | unwrap(element))); |
530 | } |
531 | MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType, |
532 | bool element) { |
533 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
534 | element)); |
535 | } |
536 | MlirAttribute mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType, |
537 | uint8_t element) { |
538 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
539 | element)); |
540 | } |
541 | MlirAttribute mlirDenseElementsAttrInt8SplatGet(MlirType shapedType, |
542 | int8_t element) { |
543 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
544 | element)); |
545 | } |
546 | MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType, |
547 | uint32_t element) { |
548 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
549 | element)); |
550 | } |
551 | MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType, |
552 | int32_t element) { |
553 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
554 | element)); |
555 | } |
556 | MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType, |
557 | uint64_t element) { |
558 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
559 | element)); |
560 | } |
561 | MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType, |
562 | int64_t element) { |
563 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
564 | element)); |
565 | } |
566 | MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType, |
567 | float element) { |
568 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
569 | element)); |
570 | } |
571 | MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType, |
572 | double element) { |
573 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
574 | element)); |
575 | } |
576 | |
577 | MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType, |
578 | intptr_t numElements, |
579 | const int *elements) { |
580 | SmallVector<bool, 8> values(elements, elements + numElements); |
581 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
582 | values)); |
583 | } |
584 | |
585 | /// Creates a dense attribute with elements of the type deduced by templates. |
586 | template <typename T> |
587 | static MlirAttribute getDenseAttribute(MlirType shapedType, |
588 | intptr_t numElements, |
589 | const T *elements) { |
590 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
591 | llvm::ArrayRef(elements, numElements))); |
592 | } |
593 | |
594 | MlirAttribute mlirDenseElementsAttrUInt8Get(MlirType shapedType, |
595 | intptr_t numElements, |
596 | const uint8_t *elements) { |
597 | return getDenseAttribute(shapedType, numElements, elements); |
598 | } |
599 | MlirAttribute mlirDenseElementsAttrInt8Get(MlirType shapedType, |
600 | intptr_t numElements, |
601 | const int8_t *elements) { |
602 | return getDenseAttribute(shapedType, numElements, elements); |
603 | } |
604 | MlirAttribute mlirDenseElementsAttrUInt16Get(MlirType shapedType, |
605 | intptr_t numElements, |
606 | const uint16_t *elements) { |
607 | return getDenseAttribute(shapedType, numElements, elements); |
608 | } |
609 | MlirAttribute mlirDenseElementsAttrInt16Get(MlirType shapedType, |
610 | intptr_t numElements, |
611 | const int16_t *elements) { |
612 | return getDenseAttribute(shapedType, numElements, elements); |
613 | } |
614 | MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType, |
615 | intptr_t numElements, |
616 | const uint32_t *elements) { |
617 | return getDenseAttribute(shapedType, numElements, elements); |
618 | } |
619 | MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType, |
620 | intptr_t numElements, |
621 | const int32_t *elements) { |
622 | return getDenseAttribute(shapedType, numElements, elements); |
623 | } |
624 | MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType, |
625 | intptr_t numElements, |
626 | const uint64_t *elements) { |
627 | return getDenseAttribute(shapedType, numElements, elements); |
628 | } |
629 | MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType, |
630 | intptr_t numElements, |
631 | const int64_t *elements) { |
632 | return getDenseAttribute(shapedType, numElements, elements); |
633 | } |
634 | MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType, |
635 | intptr_t numElements, |
636 | const float *elements) { |
637 | return getDenseAttribute(shapedType, numElements, elements); |
638 | } |
639 | MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType, |
640 | intptr_t numElements, |
641 | const double *elements) { |
642 | return getDenseAttribute(shapedType, numElements, elements); |
643 | } |
644 | MlirAttribute mlirDenseElementsAttrBFloat16Get(MlirType shapedType, |
645 | intptr_t numElements, |
646 | const uint16_t *elements) { |
647 | size_t bufferSize = numElements * 2; |
648 | const void *buffer = static_cast<const void *>(elements); |
649 | return mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize: bufferSize, rawBuffer: buffer); |
650 | } |
651 | MlirAttribute mlirDenseElementsAttrFloat16Get(MlirType shapedType, |
652 | intptr_t numElements, |
653 | const uint16_t *elements) { |
654 | size_t bufferSize = numElements * 2; |
655 | const void *buffer = static_cast<const void *>(elements); |
656 | return mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize: bufferSize, rawBuffer: buffer); |
657 | } |
658 | |
659 | MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType, |
660 | intptr_t numElements, |
661 | MlirStringRef *strs) { |
662 | SmallVector<StringRef, 8> values; |
663 | values.reserve(N: numElements); |
664 | for (intptr_t i = 0; i < numElements; ++i) |
665 | values.push_back(Elt: unwrap(ref: strs[i])); |
666 | |
667 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
668 | values)); |
669 | } |
670 | |
671 | MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr, |
672 | MlirType shapedType) { |
673 | return wrap(llvm::cast<DenseElementsAttr>(unwrap(attr)) |
674 | .reshape(llvm::cast<ShapedType>(unwrap(shapedType)))); |
675 | } |
676 | |
677 | //===----------------------------------------------------------------------===// |
678 | // Splat accessors. |
679 | //===----------------------------------------------------------------------===// |
680 | |
681 | bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) { |
682 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).isSplat(); |
683 | } |
684 | |
685 | MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) { |
686 | return wrap( |
687 | cpp: llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<Attribute>()); |
688 | } |
689 | int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) { |
690 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<bool>(); |
691 | } |
692 | int8_t mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr) { |
693 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<int8_t>(); |
694 | } |
695 | uint8_t mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr) { |
696 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<uint8_t>(); |
697 | } |
698 | int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) { |
699 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<int32_t>(); |
700 | } |
701 | uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) { |
702 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<uint32_t>(); |
703 | } |
704 | int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) { |
705 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<int64_t>(); |
706 | } |
707 | uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) { |
708 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<uint64_t>(); |
709 | } |
710 | float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) { |
711 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<float>(); |
712 | } |
713 | double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) { |
714 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<double>(); |
715 | } |
716 | MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) { |
717 | return wrap( |
718 | ref: llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<StringRef>()); |
719 | } |
720 | |
721 | //===----------------------------------------------------------------------===// |
722 | // Indexed accessors. |
723 | //===----------------------------------------------------------------------===// |
724 | |
725 | bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { |
726 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<bool>()[pos]; |
727 | } |
728 | int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) { |
729 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<int8_t>()[pos]; |
730 | } |
731 | uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) { |
732 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<uint8_t>()[pos]; |
733 | } |
734 | int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) { |
735 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<int16_t>()[pos]; |
736 | } |
737 | uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) { |
738 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<uint16_t>()[pos]; |
739 | } |
740 | int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) { |
741 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<int32_t>()[pos]; |
742 | } |
743 | uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) { |
744 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<uint32_t>()[pos]; |
745 | } |
746 | int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { |
747 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<int64_t>()[pos]; |
748 | } |
749 | uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { |
750 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<uint64_t>()[pos]; |
751 | } |
752 | float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { |
753 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<float>()[pos]; |
754 | } |
755 | double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) { |
756 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<double>()[pos]; |
757 | } |
758 | MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr, |
759 | intptr_t pos) { |
760 | return wrap( |
761 | llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<StringRef>()[pos]); |
762 | } |
763 | |
764 | //===----------------------------------------------------------------------===// |
765 | // Raw data accessors. |
766 | //===----------------------------------------------------------------------===// |
767 | |
768 | const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) { |
769 | return static_cast<const void *>( |
770 | llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getRawData().data()); |
771 | } |
772 | |
773 | //===----------------------------------------------------------------------===// |
774 | // Resource blob attributes. |
775 | //===----------------------------------------------------------------------===// |
776 | |
777 | bool mlirAttributeIsADenseResourceElements(MlirAttribute attr) { |
778 | return llvm::isa<DenseResourceElementsAttr>(unwrap(attr)); |
779 | } |
780 | |
781 | MlirAttribute mlirUnmanagedDenseResourceElementsAttrGet( |
782 | MlirType shapedType, MlirStringRef name, void *data, size_t dataLength, |
783 | size_t dataAlignment, bool dataIsMutable, |
784 | void (*deleter)(void *userData, const void *data, size_t size, |
785 | size_t align), |
786 | void *userData) { |
787 | AsmResourceBlob::DeleterFn cppDeleter = {}; |
788 | if (deleter) { |
789 | cppDeleter = [deleter, userData](void *data, size_t size, size_t align) { |
790 | deleter(userData, data, size, align); |
791 | }; |
792 | } |
793 | AsmResourceBlob blob( |
794 | llvm::ArrayRef(static_cast<const char *>(data), dataLength), |
795 | dataAlignment, std::move(cppDeleter), dataIsMutable); |
796 | return wrap( |
797 | DenseResourceElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
798 | unwrap(name), std::move(blob))); |
799 | } |
800 | |
801 | template <typename U, typename T> |
802 | static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name, |
803 | intptr_t numElements, const T *elements) { |
804 | return wrap(U::get(llvm::cast<ShapedType>(unwrap(shapedType)), unwrap(name), |
805 | UnmanagedAsmResourceBlob::allocateInferAlign( |
806 | llvm::ArrayRef(elements, numElements)))); |
807 | } |
808 | |
809 | MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet( |
810 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
811 | const int *elements) { |
812 | return getDenseResource<DenseBoolResourceElementsAttr>(shapedType, name, |
813 | numElements, elements); |
814 | } |
815 | MlirAttribute mlirUnmanagedDenseUInt8ResourceElementsAttrGet( |
816 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
817 | const uint8_t *elements) { |
818 | return getDenseResource<DenseUI8ResourceElementsAttr>(shapedType, name, |
819 | numElements, elements); |
820 | } |
821 | MlirAttribute mlirUnmanagedDenseUInt16ResourceElementsAttrGet( |
822 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
823 | const uint16_t *elements) { |
824 | return getDenseResource<DenseUI16ResourceElementsAttr>(shapedType, name, |
825 | numElements, elements); |
826 | } |
827 | MlirAttribute mlirUnmanagedDenseUInt32ResourceElementsAttrGet( |
828 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
829 | const uint32_t *elements) { |
830 | return getDenseResource<DenseUI32ResourceElementsAttr>(shapedType, name, |
831 | numElements, elements); |
832 | } |
833 | MlirAttribute mlirUnmanagedDenseUInt64ResourceElementsAttrGet( |
834 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
835 | const uint64_t *elements) { |
836 | return getDenseResource<DenseUI64ResourceElementsAttr>(shapedType, name, |
837 | numElements, elements); |
838 | } |
839 | MlirAttribute mlirUnmanagedDenseInt8ResourceElementsAttrGet( |
840 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
841 | const int8_t *elements) { |
842 | return getDenseResource<DenseUI8ResourceElementsAttr>(shapedType, name, |
843 | numElements, elements); |
844 | } |
845 | MlirAttribute mlirUnmanagedDenseInt16ResourceElementsAttrGet( |
846 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
847 | const int16_t *elements) { |
848 | return getDenseResource<DenseUI16ResourceElementsAttr>(shapedType, name, |
849 | numElements, elements); |
850 | } |
851 | MlirAttribute mlirUnmanagedDenseInt32ResourceElementsAttrGet( |
852 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
853 | const int32_t *elements) { |
854 | return getDenseResource<DenseUI32ResourceElementsAttr>(shapedType, name, |
855 | numElements, elements); |
856 | } |
857 | MlirAttribute mlirUnmanagedDenseInt64ResourceElementsAttrGet( |
858 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
859 | const int64_t *elements) { |
860 | return getDenseResource<DenseUI64ResourceElementsAttr>(shapedType, name, |
861 | numElements, elements); |
862 | } |
863 | MlirAttribute mlirUnmanagedDenseFloatResourceElementsAttrGet( |
864 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
865 | const float *elements) { |
866 | return getDenseResource<DenseF32ResourceElementsAttr>(shapedType, name, |
867 | numElements, elements); |
868 | } |
869 | MlirAttribute mlirUnmanagedDenseDoubleResourceElementsAttrGet( |
870 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
871 | const double *elements) { |
872 | return getDenseResource<DenseF64ResourceElementsAttr>(shapedType, name, |
873 | numElements, elements); |
874 | } |
875 | template <typename U, typename T> |
876 | static T getDenseResourceVal(MlirAttribute attr, intptr_t pos) { |
877 | return (*llvm::cast<U>(unwrap(c: attr)).tryGetAsArrayRef())[pos]; |
878 | } |
879 | |
880 | bool mlirDenseBoolResourceElementsAttrGetValue(MlirAttribute attr, |
881 | intptr_t pos) { |
882 | return getDenseResourceVal<DenseBoolResourceElementsAttr, uint8_t>(attr, pos); |
883 | } |
884 | uint8_t mlirDenseUInt8ResourceElementsAttrGetValue(MlirAttribute attr, |
885 | intptr_t pos) { |
886 | return getDenseResourceVal<DenseUI8ResourceElementsAttr, uint8_t>(attr, pos); |
887 | } |
888 | uint16_t mlirDenseUInt16ResourceElementsAttrGetValue(MlirAttribute attr, |
889 | intptr_t pos) { |
890 | return getDenseResourceVal<DenseUI16ResourceElementsAttr, uint16_t>(attr, |
891 | pos); |
892 | } |
893 | uint32_t mlirDenseUInt32ResourceElementsAttrGetValue(MlirAttribute attr, |
894 | intptr_t pos) { |
895 | return getDenseResourceVal<DenseUI32ResourceElementsAttr, uint32_t>(attr, |
896 | pos); |
897 | } |
898 | uint64_t mlirDenseUInt64ResourceElementsAttrGetValue(MlirAttribute attr, |
899 | intptr_t pos) { |
900 | return getDenseResourceVal<DenseUI64ResourceElementsAttr, uint64_t>(attr, |
901 | pos); |
902 | } |
903 | int8_t mlirDenseInt8ResourceElementsAttrGetValue(MlirAttribute attr, |
904 | intptr_t pos) { |
905 | return getDenseResourceVal<DenseUI8ResourceElementsAttr, int8_t>(attr, pos); |
906 | } |
907 | int16_t mlirDenseInt16ResourceElementsAttrGetValue(MlirAttribute attr, |
908 | intptr_t pos) { |
909 | return getDenseResourceVal<DenseUI16ResourceElementsAttr, int16_t>(attr, pos); |
910 | } |
911 | int32_t mlirDenseInt32ResourceElementsAttrGetValue(MlirAttribute attr, |
912 | intptr_t pos) { |
913 | return getDenseResourceVal<DenseUI32ResourceElementsAttr, int32_t>(attr, pos); |
914 | } |
915 | int64_t mlirDenseInt64ResourceElementsAttrGetValue(MlirAttribute attr, |
916 | intptr_t pos) { |
917 | return getDenseResourceVal<DenseUI64ResourceElementsAttr, int64_t>(attr, pos); |
918 | } |
919 | float mlirDenseFloatResourceElementsAttrGetValue(MlirAttribute attr, |
920 | intptr_t pos) { |
921 | return getDenseResourceVal<DenseF32ResourceElementsAttr, float>(attr, pos); |
922 | } |
923 | double mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr, |
924 | intptr_t pos) { |
925 | return getDenseResourceVal<DenseF64ResourceElementsAttr, double>(attr, pos); |
926 | } |
927 | |
928 | //===----------------------------------------------------------------------===// |
929 | // Sparse elements attribute. |
930 | //===----------------------------------------------------------------------===// |
931 | |
932 | bool mlirAttributeIsASparseElements(MlirAttribute attr) { |
933 | return llvm::isa<SparseElementsAttr>(unwrap(attr)); |
934 | } |
935 | |
936 | MlirAttribute mlirSparseElementsAttribute(MlirType shapedType, |
937 | MlirAttribute denseIndices, |
938 | MlirAttribute denseValues) { |
939 | return wrap(SparseElementsAttr::get( |
940 | llvm::cast<ShapedType>(unwrap(shapedType)), |
941 | llvm::cast<DenseElementsAttr>(unwrap(denseIndices)), |
942 | llvm::cast<DenseElementsAttr>(unwrap(denseValues)))); |
943 | } |
944 | |
945 | MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) { |
946 | return wrap(llvm::cast<SparseElementsAttr>(unwrap(attr)).getIndices()); |
947 | } |
948 | |
949 | MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) { |
950 | return wrap(llvm::cast<SparseElementsAttr>(unwrap(attr)).getValues()); |
951 | } |
952 | |
953 | MlirTypeID mlirSparseElementsAttrGetTypeID(void) { |
954 | return wrap(SparseElementsAttr::getTypeID()); |
955 | } |
956 | |
957 | //===----------------------------------------------------------------------===// |
958 | // Strided layout attribute. |
959 | //===----------------------------------------------------------------------===// |
960 | |
961 | bool mlirAttributeIsAStridedLayout(MlirAttribute attr) { |
962 | return llvm::isa<StridedLayoutAttr>(unwrap(attr)); |
963 | } |
964 | |
965 | MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, |
966 | intptr_t numStrides, |
967 | const int64_t *strides) { |
968 | return wrap(StridedLayoutAttr::get(unwrap(ctx), offset, |
969 | ArrayRef<int64_t>(strides, numStrides))); |
970 | } |
971 | |
972 | int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr) { |
973 | return llvm::cast<StridedLayoutAttr>(unwrap(attr)).getOffset(); |
974 | } |
975 | |
976 | intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr) { |
977 | return static_cast<intptr_t>( |
978 | llvm::cast<StridedLayoutAttr>(unwrap(attr)).getStrides().size()); |
979 | } |
980 | |
981 | int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) { |
982 | return llvm::cast<StridedLayoutAttr>(unwrap(attr)).getStrides()[pos]; |
983 | } |
984 | |
985 | MlirTypeID mlirStridedLayoutAttrGetTypeID(void) { |
986 | return wrap(StridedLayoutAttr::getTypeID()); |
987 | } |
988 | |