| 1 | //===- BuiltinAttributes.cpp - C Interface to MLIR Builtin Attributes -----===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir-c/BuiltinAttributes.h" |
| 10 | #include "mlir-c/Support.h" |
| 11 | #include "mlir/CAPI/AffineMap.h" |
| 12 | #include "mlir/CAPI/IR.h" |
| 13 | #include "mlir/CAPI/IntegerSet.h" |
| 14 | #include "mlir/CAPI/Support.h" |
| 15 | #include "mlir/IR/AsmState.h" |
| 16 | #include "mlir/IR/Attributes.h" |
| 17 | #include "mlir/IR/BuiltinAttributes.h" |
| 18 | #include "mlir/IR/BuiltinTypes.h" |
| 19 | |
| 20 | using namespace mlir; |
| 21 | |
| 22 | MlirAttribute mlirAttributeGetNull() { return {.ptr: nullptr}; } |
| 23 | |
| 24 | //===----------------------------------------------------------------------===// |
| 25 | // Location attribute. |
| 26 | //===----------------------------------------------------------------------===// |
| 27 | |
| 28 | bool mlirAttributeIsALocation(MlirAttribute attr) { |
| 29 | return llvm::isa<LocationAttr>(Val: unwrap(c: attr)); |
| 30 | } |
| 31 | |
| 32 | //===----------------------------------------------------------------------===// |
| 33 | // Affine map attribute. |
| 34 | //===----------------------------------------------------------------------===// |
| 35 | |
| 36 | bool mlirAttributeIsAAffineMap(MlirAttribute attr) { |
| 37 | return llvm::isa<AffineMapAttr>(unwrap(attr)); |
| 38 | } |
| 39 | |
| 40 | MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) { |
| 41 | return wrap(AffineMapAttr::get(unwrap(map))); |
| 42 | } |
| 43 | |
| 44 | MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) { |
| 45 | return wrap(llvm::cast<AffineMapAttr>(unwrap(attr)).getValue()); |
| 46 | } |
| 47 | |
| 48 | MlirTypeID mlirAffineMapAttrGetTypeID(void) { |
| 49 | return wrap(AffineMapAttr::getTypeID()); |
| 50 | } |
| 51 | |
| 52 | //===----------------------------------------------------------------------===// |
| 53 | // Array attribute. |
| 54 | //===----------------------------------------------------------------------===// |
| 55 | |
| 56 | bool mlirAttributeIsAArray(MlirAttribute attr) { |
| 57 | return llvm::isa<ArrayAttr>(unwrap(attr)); |
| 58 | } |
| 59 | |
| 60 | MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, |
| 61 | MlirAttribute const *elements) { |
| 62 | SmallVector<Attribute, 8> attrs; |
| 63 | return wrap( |
| 64 | ArrayAttr::get(unwrap(ctx), unwrapList(static_cast<size_t>(numElements), |
| 65 | elements, attrs))); |
| 66 | } |
| 67 | |
| 68 | intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) { |
| 69 | return static_cast<intptr_t>(llvm::cast<ArrayAttr>(unwrap(attr)).size()); |
| 70 | } |
| 71 | |
| 72 | MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) { |
| 73 | return wrap(llvm::cast<ArrayAttr>(unwrap(attr)).getValue()[pos]); |
| 74 | } |
| 75 | |
| 76 | MlirTypeID mlirArrayAttrGetTypeID(void) { return wrap(ArrayAttr::getTypeID()); } |
| 77 | |
| 78 | //===----------------------------------------------------------------------===// |
| 79 | // Dictionary attribute. |
| 80 | //===----------------------------------------------------------------------===// |
| 81 | |
| 82 | bool mlirAttributeIsADictionary(MlirAttribute attr) { |
| 83 | return llvm::isa<DictionaryAttr>(Val: unwrap(c: attr)); |
| 84 | } |
| 85 | |
| 86 | MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, |
| 87 | MlirNamedAttribute const *elements) { |
| 88 | SmallVector<NamedAttribute, 8> attributes; |
| 89 | attributes.reserve(N: numElements); |
| 90 | for (intptr_t i = 0; i < numElements; ++i) |
| 91 | attributes.emplace_back(unwrap(elements[i].name), |
| 92 | unwrap(c: elements[i].attribute)); |
| 93 | return wrap(DictionaryAttr::get(unwrap(ctx), attributes)); |
| 94 | } |
| 95 | |
| 96 | intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) { |
| 97 | return static_cast<intptr_t>(llvm::cast<DictionaryAttr>(unwrap(c: attr)).size()); |
| 98 | } |
| 99 | |
| 100 | MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr, |
| 101 | intptr_t pos) { |
| 102 | NamedAttribute attribute = |
| 103 | llvm::cast<DictionaryAttr>(unwrap(c: attr)).getValue()[pos]; |
| 104 | return {wrap(attribute.getName()), wrap(cpp: attribute.getValue())}; |
| 105 | } |
| 106 | |
| 107 | MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, |
| 108 | MlirStringRef name) { |
| 109 | return wrap(llvm::cast<DictionaryAttr>(unwrap(c: attr)).get(unwrap(ref: name))); |
| 110 | } |
| 111 | |
| 112 | MlirTypeID mlirDictionaryAttrGetTypeID(void) { |
| 113 | return wrap(DictionaryAttr::getTypeID()); |
| 114 | } |
| 115 | |
| 116 | //===----------------------------------------------------------------------===// |
| 117 | // Floating point attribute. |
| 118 | //===----------------------------------------------------------------------===// |
| 119 | |
| 120 | bool mlirAttributeIsAFloat(MlirAttribute attr) { |
| 121 | return llvm::isa<FloatAttr>(unwrap(attr)); |
| 122 | } |
| 123 | |
| 124 | MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type, |
| 125 | double value) { |
| 126 | return wrap(FloatAttr::get(unwrap(type), value)); |
| 127 | } |
| 128 | |
| 129 | MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type, |
| 130 | double value) { |
| 131 | return wrap(FloatAttr::getChecked(unwrap(loc), unwrap(type), value)); |
| 132 | } |
| 133 | |
| 134 | double mlirFloatAttrGetValueDouble(MlirAttribute attr) { |
| 135 | return llvm::cast<FloatAttr>(unwrap(attr)).getValueAsDouble(); |
| 136 | } |
| 137 | |
| 138 | MlirTypeID mlirFloatAttrGetTypeID(void) { return wrap(FloatAttr::getTypeID()); } |
| 139 | |
| 140 | //===----------------------------------------------------------------------===// |
| 141 | // Integer attribute. |
| 142 | //===----------------------------------------------------------------------===// |
| 143 | |
| 144 | bool mlirAttributeIsAInteger(MlirAttribute attr) { |
| 145 | return llvm::isa<IntegerAttr>(unwrap(attr)); |
| 146 | } |
| 147 | |
| 148 | MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) { |
| 149 | return wrap(IntegerAttr::get(unwrap(type), value)); |
| 150 | } |
| 151 | |
| 152 | int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) { |
| 153 | return llvm::cast<IntegerAttr>(unwrap(attr)).getInt(); |
| 154 | } |
| 155 | |
| 156 | int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr) { |
| 157 | return llvm::cast<IntegerAttr>(unwrap(attr)).getSInt(); |
| 158 | } |
| 159 | |
| 160 | uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) { |
| 161 | return llvm::cast<IntegerAttr>(unwrap(attr)).getUInt(); |
| 162 | } |
| 163 | |
| 164 | MlirTypeID mlirIntegerAttrGetTypeID(void) { |
| 165 | return wrap(IntegerAttr::getTypeID()); |
| 166 | } |
| 167 | |
| 168 | //===----------------------------------------------------------------------===// |
| 169 | // Bool attribute. |
| 170 | //===----------------------------------------------------------------------===// |
| 171 | |
| 172 | bool mlirAttributeIsABool(MlirAttribute attr) { |
| 173 | return llvm::isa<BoolAttr>(Val: unwrap(c: attr)); |
| 174 | } |
| 175 | |
| 176 | MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) { |
| 177 | return wrap(cpp: BoolAttr::get(context: unwrap(c: ctx), value)); |
| 178 | } |
| 179 | |
| 180 | bool mlirBoolAttrGetValue(MlirAttribute attr) { |
| 181 | return llvm::cast<BoolAttr>(Val: unwrap(c: attr)).getValue(); |
| 182 | } |
| 183 | |
| 184 | //===----------------------------------------------------------------------===// |
| 185 | // Integer set attribute. |
| 186 | //===----------------------------------------------------------------------===// |
| 187 | |
| 188 | bool mlirAttributeIsAIntegerSet(MlirAttribute attr) { |
| 189 | return llvm::isa<IntegerSetAttr>(unwrap(attr)); |
| 190 | } |
| 191 | |
| 192 | MlirTypeID mlirIntegerSetAttrGetTypeID(void) { |
| 193 | return wrap(IntegerSetAttr::getTypeID()); |
| 194 | } |
| 195 | |
| 196 | MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set) { |
| 197 | return wrap(IntegerSetAttr::get(unwrap(set))); |
| 198 | } |
| 199 | |
| 200 | MlirIntegerSet mlirIntegerSetAttrGetValue(MlirAttribute attr) { |
| 201 | return wrap(llvm::cast<IntegerSetAttr>(unwrap(attr)).getValue()); |
| 202 | } |
| 203 | |
| 204 | //===----------------------------------------------------------------------===// |
| 205 | // Opaque attribute. |
| 206 | //===----------------------------------------------------------------------===// |
| 207 | |
| 208 | bool mlirAttributeIsAOpaque(MlirAttribute attr) { |
| 209 | return llvm::isa<OpaqueAttr>(unwrap(attr)); |
| 210 | } |
| 211 | |
| 212 | MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace, |
| 213 | intptr_t dataLength, const char *data, |
| 214 | MlirType type) { |
| 215 | return wrap( |
| 216 | OpaqueAttr::get(StringAttr::get(unwrap(ctx), unwrap(dialectNamespace)), |
| 217 | StringRef(data, dataLength), unwrap(type))); |
| 218 | } |
| 219 | |
| 220 | MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) { |
| 221 | return wrap( |
| 222 | llvm::cast<OpaqueAttr>(unwrap(attr)).getDialectNamespace().strref()); |
| 223 | } |
| 224 | |
| 225 | MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) { |
| 226 | return wrap(llvm::cast<OpaqueAttr>(unwrap(attr)).getAttrData()); |
| 227 | } |
| 228 | |
| 229 | MlirTypeID mlirOpaqueAttrGetTypeID(void) { |
| 230 | return wrap(OpaqueAttr::getTypeID()); |
| 231 | } |
| 232 | |
| 233 | //===----------------------------------------------------------------------===// |
| 234 | // String attribute. |
| 235 | //===----------------------------------------------------------------------===// |
| 236 | |
| 237 | bool mlirAttributeIsAString(MlirAttribute attr) { |
| 238 | return llvm::isa<StringAttr>(Val: unwrap(c: attr)); |
| 239 | } |
| 240 | |
| 241 | MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) { |
| 242 | return wrap((Attribute)StringAttr::get(unwrap(ctx), unwrap(str))); |
| 243 | } |
| 244 | |
| 245 | MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) { |
| 246 | return wrap((Attribute)StringAttr::get(unwrap(str), unwrap(type))); |
| 247 | } |
| 248 | |
| 249 | MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) { |
| 250 | return wrap(llvm::cast<StringAttr>(unwrap(c: attr)).getValue()); |
| 251 | } |
| 252 | |
| 253 | MlirTypeID mlirStringAttrGetTypeID(void) { |
| 254 | return wrap(StringAttr::getTypeID()); |
| 255 | } |
| 256 | |
| 257 | //===----------------------------------------------------------------------===// |
| 258 | // SymbolRef attribute. |
| 259 | //===----------------------------------------------------------------------===// |
| 260 | |
| 261 | bool mlirAttributeIsASymbolRef(MlirAttribute attr) { |
| 262 | return llvm::isa<SymbolRefAttr>(unwrap(attr)); |
| 263 | } |
| 264 | |
| 265 | MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol, |
| 266 | intptr_t numReferences, |
| 267 | MlirAttribute const *references) { |
| 268 | SmallVector<FlatSymbolRefAttr, 4> refs; |
| 269 | refs.reserve(N: numReferences); |
| 270 | for (intptr_t i = 0; i < numReferences; ++i) |
| 271 | refs.push_back(Elt: llvm::cast<FlatSymbolRefAttr>(Val: unwrap(c: references[i]))); |
| 272 | auto symbolAttr = StringAttr::get(unwrap(ctx), unwrap(symbol)); |
| 273 | return wrap(SymbolRefAttr::get(symbolAttr, refs)); |
| 274 | } |
| 275 | |
| 276 | MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) { |
| 277 | return wrap( |
| 278 | llvm::cast<SymbolRefAttr>(unwrap(attr)).getRootReference().getValue()); |
| 279 | } |
| 280 | |
| 281 | MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) { |
| 282 | return wrap( |
| 283 | llvm::cast<SymbolRefAttr>(unwrap(attr)).getLeafReference().getValue()); |
| 284 | } |
| 285 | |
| 286 | intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) { |
| 287 | return static_cast<intptr_t>( |
| 288 | llvm::cast<SymbolRefAttr>(unwrap(attr)).getNestedReferences().size()); |
| 289 | } |
| 290 | |
| 291 | MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, |
| 292 | intptr_t pos) { |
| 293 | return wrap( |
| 294 | llvm::cast<SymbolRefAttr>(unwrap(attr)).getNestedReferences()[pos]); |
| 295 | } |
| 296 | |
| 297 | MlirTypeID mlirSymbolRefAttrGetTypeID(void) { |
| 298 | return wrap(SymbolRefAttr::getTypeID()); |
| 299 | } |
| 300 | |
| 301 | MlirAttribute mlirDisctinctAttrCreate(MlirAttribute referencedAttr) { |
| 302 | return wrap(mlir::DistinctAttr::create(referencedAttr: unwrap(c: referencedAttr))); |
| 303 | } |
| 304 | |
| 305 | //===----------------------------------------------------------------------===// |
| 306 | // Flat SymbolRef attribute. |
| 307 | //===----------------------------------------------------------------------===// |
| 308 | |
| 309 | bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) { |
| 310 | return llvm::isa<FlatSymbolRefAttr>(Val: unwrap(c: attr)); |
| 311 | } |
| 312 | |
| 313 | MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) { |
| 314 | return wrap(FlatSymbolRefAttr::get(ctx: unwrap(c: ctx), value: unwrap(ref: symbol))); |
| 315 | } |
| 316 | |
| 317 | MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) { |
| 318 | return wrap(ref: llvm::cast<FlatSymbolRefAttr>(Val: unwrap(c: attr)).getValue()); |
| 319 | } |
| 320 | |
| 321 | //===----------------------------------------------------------------------===// |
| 322 | // Type attribute. |
| 323 | //===----------------------------------------------------------------------===// |
| 324 | |
| 325 | bool mlirAttributeIsAType(MlirAttribute attr) { |
| 326 | return llvm::isa<TypeAttr>(unwrap(attr)); |
| 327 | } |
| 328 | |
| 329 | MlirAttribute mlirTypeAttrGet(MlirType type) { |
| 330 | return wrap(TypeAttr::get(unwrap(type))); |
| 331 | } |
| 332 | |
| 333 | MlirType mlirTypeAttrGetValue(MlirAttribute attr) { |
| 334 | return wrap(llvm::cast<TypeAttr>(unwrap(attr)).getValue()); |
| 335 | } |
| 336 | |
| 337 | MlirTypeID mlirTypeAttrGetTypeID(void) { return wrap(TypeAttr::getTypeID()); } |
| 338 | |
| 339 | //===----------------------------------------------------------------------===// |
| 340 | // Unit attribute. |
| 341 | //===----------------------------------------------------------------------===// |
| 342 | |
| 343 | bool mlirAttributeIsAUnit(MlirAttribute attr) { |
| 344 | return llvm::isa<UnitAttr>(unwrap(attr)); |
| 345 | } |
| 346 | |
| 347 | MlirAttribute mlirUnitAttrGet(MlirContext ctx) { |
| 348 | return wrap(UnitAttr::get(unwrap(ctx))); |
| 349 | } |
| 350 | |
| 351 | MlirTypeID mlirUnitAttrGetTypeID(void) { return wrap(UnitAttr::getTypeID()); } |
| 352 | |
| 353 | //===----------------------------------------------------------------------===// |
| 354 | // Elements attributes. |
| 355 | //===----------------------------------------------------------------------===// |
| 356 | |
| 357 | bool mlirAttributeIsAElements(MlirAttribute attr) { |
| 358 | return llvm::isa<ElementsAttr>(Val: unwrap(c: attr)); |
| 359 | } |
| 360 | |
| 361 | MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank, |
| 362 | uint64_t *idxs) { |
| 363 | return wrap(llvm::cast<ElementsAttr>(unwrap(c: attr)) |
| 364 | .getValues<Attribute>()[llvm::ArrayRef(idxs, rank)]); |
| 365 | } |
| 366 | |
| 367 | bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, |
| 368 | uint64_t *idxs) { |
| 369 | return llvm::cast<ElementsAttr>(unwrap(c: attr)) |
| 370 | .isValidIndex(llvm::ArrayRef(idxs, rank)); |
| 371 | } |
| 372 | |
| 373 | int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) { |
| 374 | return llvm::cast<ElementsAttr>(unwrap(c: attr)).getNumElements(); |
| 375 | } |
| 376 | |
| 377 | //===----------------------------------------------------------------------===// |
| 378 | // Dense array attribute. |
| 379 | //===----------------------------------------------------------------------===// |
| 380 | |
| 381 | MlirTypeID mlirDenseArrayAttrGetTypeID() { |
| 382 | return wrap(DenseArrayAttr::getTypeID()); |
| 383 | } |
| 384 | |
| 385 | //===----------------------------------------------------------------------===// |
| 386 | // IsA support. |
| 387 | //===----------------------------------------------------------------------===// |
| 388 | |
| 389 | bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) { |
| 390 | return llvm::isa<DenseBoolArrayAttr>(Val: unwrap(c: attr)); |
| 391 | } |
| 392 | bool mlirAttributeIsADenseI8Array(MlirAttribute attr) { |
| 393 | return llvm::isa<DenseI8ArrayAttr>(Val: unwrap(c: attr)); |
| 394 | } |
| 395 | bool mlirAttributeIsADenseI16Array(MlirAttribute attr) { |
| 396 | return llvm::isa<DenseI16ArrayAttr>(Val: unwrap(c: attr)); |
| 397 | } |
| 398 | bool mlirAttributeIsADenseI32Array(MlirAttribute attr) { |
| 399 | return llvm::isa<DenseI32ArrayAttr>(Val: unwrap(c: attr)); |
| 400 | } |
| 401 | bool mlirAttributeIsADenseI64Array(MlirAttribute attr) { |
| 402 | return llvm::isa<DenseI64ArrayAttr>(Val: unwrap(c: attr)); |
| 403 | } |
| 404 | bool mlirAttributeIsADenseF32Array(MlirAttribute attr) { |
| 405 | return llvm::isa<DenseF32ArrayAttr>(Val: unwrap(c: attr)); |
| 406 | } |
| 407 | bool mlirAttributeIsADenseF64Array(MlirAttribute attr) { |
| 408 | return llvm::isa<DenseF64ArrayAttr>(Val: unwrap(c: attr)); |
| 409 | } |
| 410 | |
| 411 | //===----------------------------------------------------------------------===// |
| 412 | // Constructors. |
| 413 | //===----------------------------------------------------------------------===// |
| 414 | |
| 415 | MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size, |
| 416 | int const *values) { |
| 417 | SmallVector<bool, 4> elements(values, values + size); |
| 418 | return wrap(DenseBoolArrayAttr::get(unwrap(ctx), elements)); |
| 419 | } |
| 420 | MlirAttribute mlirDenseI8ArrayGet(MlirContext ctx, intptr_t size, |
| 421 | int8_t const *values) { |
| 422 | return wrap( |
| 423 | DenseI8ArrayAttr::get(unwrap(ctx), ArrayRef<int8_t>(values, size))); |
| 424 | } |
| 425 | MlirAttribute mlirDenseI16ArrayGet(MlirContext ctx, intptr_t size, |
| 426 | int16_t const *values) { |
| 427 | return wrap( |
| 428 | DenseI16ArrayAttr::get(unwrap(ctx), ArrayRef<int16_t>(values, size))); |
| 429 | } |
| 430 | MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, |
| 431 | int32_t const *values) { |
| 432 | return wrap( |
| 433 | DenseI32ArrayAttr::get(unwrap(ctx), ArrayRef<int32_t>(values, size))); |
| 434 | } |
| 435 | MlirAttribute mlirDenseI64ArrayGet(MlirContext ctx, intptr_t size, |
| 436 | int64_t const *values) { |
| 437 | return wrap( |
| 438 | DenseI64ArrayAttr::get(unwrap(ctx), ArrayRef<int64_t>(values, size))); |
| 439 | } |
| 440 | MlirAttribute mlirDenseF32ArrayGet(MlirContext ctx, intptr_t size, |
| 441 | float const *values) { |
| 442 | return wrap( |
| 443 | DenseF32ArrayAttr::get(unwrap(ctx), ArrayRef<float>(values, size))); |
| 444 | } |
| 445 | MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size, |
| 446 | double const *values) { |
| 447 | return wrap( |
| 448 | DenseF64ArrayAttr::get(unwrap(ctx), ArrayRef<double>(values, size))); |
| 449 | } |
| 450 | |
| 451 | //===----------------------------------------------------------------------===// |
| 452 | // Accessors. |
| 453 | //===----------------------------------------------------------------------===// |
| 454 | |
| 455 | intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) { |
| 456 | return llvm::cast<DenseArrayAttr>(unwrap(attr)).size(); |
| 457 | } |
| 458 | |
| 459 | //===----------------------------------------------------------------------===// |
| 460 | // Indexed accessors. |
| 461 | //===----------------------------------------------------------------------===// |
| 462 | |
| 463 | bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) { |
| 464 | return llvm::cast<DenseBoolArrayAttr>(unwrap(c: attr))[pos]; |
| 465 | } |
| 466 | int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, intptr_t pos) { |
| 467 | return llvm::cast<DenseI8ArrayAttr>(unwrap(c: attr))[pos]; |
| 468 | } |
| 469 | int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, intptr_t pos) { |
| 470 | return llvm::cast<DenseI16ArrayAttr>(unwrap(c: attr))[pos]; |
| 471 | } |
| 472 | int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, intptr_t pos) { |
| 473 | return llvm::cast<DenseI32ArrayAttr>(unwrap(c: attr))[pos]; |
| 474 | } |
| 475 | int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, intptr_t pos) { |
| 476 | return llvm::cast<DenseI64ArrayAttr>(unwrap(c: attr))[pos]; |
| 477 | } |
| 478 | float mlirDenseF32ArrayGetElement(MlirAttribute attr, intptr_t pos) { |
| 479 | return llvm::cast<DenseF32ArrayAttr>(unwrap(c: attr))[pos]; |
| 480 | } |
| 481 | double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) { |
| 482 | return llvm::cast<DenseF64ArrayAttr>(unwrap(c: attr))[pos]; |
| 483 | } |
| 484 | |
| 485 | //===----------------------------------------------------------------------===// |
| 486 | // Dense elements attribute. |
| 487 | //===----------------------------------------------------------------------===// |
| 488 | |
| 489 | //===----------------------------------------------------------------------===// |
| 490 | // IsA support. |
| 491 | //===----------------------------------------------------------------------===// |
| 492 | |
| 493 | bool mlirAttributeIsADenseElements(MlirAttribute attr) { |
| 494 | return llvm::isa<DenseElementsAttr>(Val: unwrap(c: attr)); |
| 495 | } |
| 496 | |
| 497 | bool mlirAttributeIsADenseIntElements(MlirAttribute attr) { |
| 498 | return llvm::isa<DenseIntElementsAttr>(Val: unwrap(c: attr)); |
| 499 | } |
| 500 | |
| 501 | bool mlirAttributeIsADenseFPElements(MlirAttribute attr) { |
| 502 | return llvm::isa<DenseFPElementsAttr>(Val: unwrap(c: attr)); |
| 503 | } |
| 504 | |
| 505 | MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void) { |
| 506 | return wrap(DenseIntOrFPElementsAttr::getTypeID()); |
| 507 | } |
| 508 | |
| 509 | //===----------------------------------------------------------------------===// |
| 510 | // Constructors. |
| 511 | //===----------------------------------------------------------------------===// |
| 512 | |
| 513 | MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType, |
| 514 | intptr_t numElements, |
| 515 | MlirAttribute const *elements) { |
| 516 | SmallVector<Attribute, 8> attributes; |
| 517 | return wrap( |
| 518 | DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
| 519 | unwrapList(numElements, elements, attributes))); |
| 520 | } |
| 521 | |
| 522 | MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType, |
| 523 | size_t rawBufferSize, |
| 524 | const void *rawBuffer) { |
| 525 | auto shapedTypeCpp = llvm::cast<ShapedType>(unwrap(shapedType)); |
| 526 | ArrayRef<char> rawBufferCpp(static_cast<const char *>(rawBuffer), |
| 527 | rawBufferSize); |
| 528 | bool isSplat = false; |
| 529 | if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp, |
| 530 | isSplat)) |
| 531 | return mlirAttributeGetNull(); |
| 532 | return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp)); |
| 533 | } |
| 534 | |
| 535 | MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType, |
| 536 | MlirAttribute element) { |
| 537 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
| 538 | unwrap(element))); |
| 539 | } |
| 540 | MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType, |
| 541 | bool element) { |
| 542 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
| 543 | element)); |
| 544 | } |
| 545 | MlirAttribute mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType, |
| 546 | uint8_t element) { |
| 547 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
| 548 | element)); |
| 549 | } |
| 550 | MlirAttribute mlirDenseElementsAttrInt8SplatGet(MlirType shapedType, |
| 551 | int8_t element) { |
| 552 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
| 553 | element)); |
| 554 | } |
| 555 | MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType, |
| 556 | uint32_t element) { |
| 557 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
| 558 | element)); |
| 559 | } |
| 560 | MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType, |
| 561 | int32_t element) { |
| 562 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
| 563 | element)); |
| 564 | } |
| 565 | MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType, |
| 566 | uint64_t element) { |
| 567 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
| 568 | element)); |
| 569 | } |
| 570 | MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType, |
| 571 | int64_t element) { |
| 572 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
| 573 | element)); |
| 574 | } |
| 575 | MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType, |
| 576 | float element) { |
| 577 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
| 578 | element)); |
| 579 | } |
| 580 | MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType, |
| 581 | double element) { |
| 582 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
| 583 | element)); |
| 584 | } |
| 585 | |
| 586 | MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType, |
| 587 | intptr_t numElements, |
| 588 | const int *elements) { |
| 589 | SmallVector<bool, 8> values(elements, elements + numElements); |
| 590 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
| 591 | values)); |
| 592 | } |
| 593 | |
| 594 | /// Creates a dense attribute with elements of the type deduced by templates. |
| 595 | template <typename T> |
| 596 | static MlirAttribute getDenseAttribute(MlirType shapedType, |
| 597 | intptr_t numElements, |
| 598 | const T *elements) { |
| 599 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
| 600 | llvm::ArrayRef(elements, numElements))); |
| 601 | } |
| 602 | |
| 603 | MlirAttribute mlirDenseElementsAttrUInt8Get(MlirType shapedType, |
| 604 | intptr_t numElements, |
| 605 | const uint8_t *elements) { |
| 606 | return getDenseAttribute(shapedType, numElements, elements); |
| 607 | } |
| 608 | MlirAttribute mlirDenseElementsAttrInt8Get(MlirType shapedType, |
| 609 | intptr_t numElements, |
| 610 | const int8_t *elements) { |
| 611 | return getDenseAttribute(shapedType, numElements, elements); |
| 612 | } |
| 613 | MlirAttribute mlirDenseElementsAttrUInt16Get(MlirType shapedType, |
| 614 | intptr_t numElements, |
| 615 | const uint16_t *elements) { |
| 616 | return getDenseAttribute(shapedType, numElements, elements); |
| 617 | } |
| 618 | MlirAttribute mlirDenseElementsAttrInt16Get(MlirType shapedType, |
| 619 | intptr_t numElements, |
| 620 | const int16_t *elements) { |
| 621 | return getDenseAttribute(shapedType, numElements, elements); |
| 622 | } |
| 623 | MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType, |
| 624 | intptr_t numElements, |
| 625 | const uint32_t *elements) { |
| 626 | return getDenseAttribute(shapedType, numElements, elements); |
| 627 | } |
| 628 | MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType, |
| 629 | intptr_t numElements, |
| 630 | const int32_t *elements) { |
| 631 | return getDenseAttribute(shapedType, numElements, elements); |
| 632 | } |
| 633 | MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType, |
| 634 | intptr_t numElements, |
| 635 | const uint64_t *elements) { |
| 636 | return getDenseAttribute(shapedType, numElements, elements); |
| 637 | } |
| 638 | MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType, |
| 639 | intptr_t numElements, |
| 640 | const int64_t *elements) { |
| 641 | return getDenseAttribute(shapedType, numElements, elements); |
| 642 | } |
| 643 | MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType, |
| 644 | intptr_t numElements, |
| 645 | const float *elements) { |
| 646 | return getDenseAttribute(shapedType, numElements, elements); |
| 647 | } |
| 648 | MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType, |
| 649 | intptr_t numElements, |
| 650 | const double *elements) { |
| 651 | return getDenseAttribute(shapedType, numElements, elements); |
| 652 | } |
| 653 | MlirAttribute mlirDenseElementsAttrBFloat16Get(MlirType shapedType, |
| 654 | intptr_t numElements, |
| 655 | const uint16_t *elements) { |
| 656 | size_t bufferSize = numElements * 2; |
| 657 | const void *buffer = static_cast<const void *>(elements); |
| 658 | return mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize: bufferSize, rawBuffer: buffer); |
| 659 | } |
| 660 | MlirAttribute mlirDenseElementsAttrFloat16Get(MlirType shapedType, |
| 661 | intptr_t numElements, |
| 662 | const uint16_t *elements) { |
| 663 | size_t bufferSize = numElements * 2; |
| 664 | const void *buffer = static_cast<const void *>(elements); |
| 665 | return mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize: bufferSize, rawBuffer: buffer); |
| 666 | } |
| 667 | |
| 668 | MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType, |
| 669 | intptr_t numElements, |
| 670 | MlirStringRef *strs) { |
| 671 | SmallVector<StringRef, 8> values; |
| 672 | values.reserve(N: numElements); |
| 673 | for (intptr_t i = 0; i < numElements; ++i) |
| 674 | values.push_back(Elt: unwrap(ref: strs[i])); |
| 675 | |
| 676 | return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
| 677 | values)); |
| 678 | } |
| 679 | |
| 680 | MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr, |
| 681 | MlirType shapedType) { |
| 682 | return wrap(llvm::cast<DenseElementsAttr>(unwrap(attr)) |
| 683 | .reshape(llvm::cast<ShapedType>(unwrap(shapedType)))); |
| 684 | } |
| 685 | |
| 686 | //===----------------------------------------------------------------------===// |
| 687 | // Splat accessors. |
| 688 | //===----------------------------------------------------------------------===// |
| 689 | |
| 690 | bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) { |
| 691 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).isSplat(); |
| 692 | } |
| 693 | |
| 694 | MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) { |
| 695 | return wrap( |
| 696 | cpp: llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<Attribute>()); |
| 697 | } |
| 698 | int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) { |
| 699 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<bool>(); |
| 700 | } |
| 701 | int8_t mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr) { |
| 702 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<int8_t>(); |
| 703 | } |
| 704 | uint8_t mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr) { |
| 705 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<uint8_t>(); |
| 706 | } |
| 707 | int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) { |
| 708 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<int32_t>(); |
| 709 | } |
| 710 | uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) { |
| 711 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<uint32_t>(); |
| 712 | } |
| 713 | int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) { |
| 714 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<int64_t>(); |
| 715 | } |
| 716 | uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) { |
| 717 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<uint64_t>(); |
| 718 | } |
| 719 | float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) { |
| 720 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<float>(); |
| 721 | } |
| 722 | double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) { |
| 723 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<double>(); |
| 724 | } |
| 725 | MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) { |
| 726 | return wrap( |
| 727 | ref: llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getSplatValue<StringRef>()); |
| 728 | } |
| 729 | |
| 730 | //===----------------------------------------------------------------------===// |
| 731 | // Indexed accessors. |
| 732 | //===----------------------------------------------------------------------===// |
| 733 | |
| 734 | bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { |
| 735 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<bool>()[pos]; |
| 736 | } |
| 737 | int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) { |
| 738 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<int8_t>()[pos]; |
| 739 | } |
| 740 | uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) { |
| 741 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<uint8_t>()[pos]; |
| 742 | } |
| 743 | int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) { |
| 744 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<int16_t>()[pos]; |
| 745 | } |
| 746 | uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) { |
| 747 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<uint16_t>()[pos]; |
| 748 | } |
| 749 | int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) { |
| 750 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<int32_t>()[pos]; |
| 751 | } |
| 752 | uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) { |
| 753 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<uint32_t>()[pos]; |
| 754 | } |
| 755 | int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { |
| 756 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<int64_t>()[pos]; |
| 757 | } |
| 758 | uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { |
| 759 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<uint64_t>()[pos]; |
| 760 | } |
| 761 | uint64_t mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos) { |
| 762 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<uint64_t>()[pos]; |
| 763 | } |
| 764 | float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { |
| 765 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<float>()[pos]; |
| 766 | } |
| 767 | double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) { |
| 768 | return llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<double>()[pos]; |
| 769 | } |
| 770 | MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr, |
| 771 | intptr_t pos) { |
| 772 | return wrap( |
| 773 | llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getValues<StringRef>()[pos]); |
| 774 | } |
| 775 | |
| 776 | //===----------------------------------------------------------------------===// |
| 777 | // Raw data accessors. |
| 778 | //===----------------------------------------------------------------------===// |
| 779 | |
| 780 | const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) { |
| 781 | return static_cast<const void *>( |
| 782 | llvm::cast<DenseElementsAttr>(Val: unwrap(c: attr)).getRawData().data()); |
| 783 | } |
| 784 | |
| 785 | //===----------------------------------------------------------------------===// |
| 786 | // Resource blob attributes. |
| 787 | //===----------------------------------------------------------------------===// |
| 788 | |
| 789 | bool mlirAttributeIsADenseResourceElements(MlirAttribute attr) { |
| 790 | return llvm::isa<DenseResourceElementsAttr>(unwrap(attr)); |
| 791 | } |
| 792 | |
| 793 | MlirAttribute mlirUnmanagedDenseResourceElementsAttrGet( |
| 794 | MlirType shapedType, MlirStringRef name, void *data, size_t dataLength, |
| 795 | size_t dataAlignment, bool dataIsMutable, |
| 796 | void (*deleter)(void *userData, const void *data, size_t size, |
| 797 | size_t align), |
| 798 | void *userData) { |
| 799 | AsmResourceBlob::DeleterFn cppDeleter = {}; |
| 800 | if (deleter) { |
| 801 | cppDeleter = [deleter, userData](void *data, size_t size, size_t align) { |
| 802 | deleter(userData, data, size, align); |
| 803 | }; |
| 804 | } |
| 805 | AsmResourceBlob blob( |
| 806 | llvm::ArrayRef(static_cast<const char *>(data), dataLength), |
| 807 | dataAlignment, std::move(cppDeleter), dataIsMutable); |
| 808 | return wrap( |
| 809 | DenseResourceElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)), |
| 810 | unwrap(name), std::move(blob))); |
| 811 | } |
| 812 | |
| 813 | template <typename U, typename T> |
| 814 | static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name, |
| 815 | intptr_t numElements, const T *elements) { |
| 816 | return wrap(U::get(llvm::cast<ShapedType>(unwrap(shapedType)), unwrap(name), |
| 817 | UnmanagedAsmResourceBlob::allocateInferAlign( |
| 818 | llvm::ArrayRef(elements, numElements)))); |
| 819 | } |
| 820 | |
| 821 | MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet( |
| 822 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
| 823 | const int *elements) { |
| 824 | return getDenseResource<DenseBoolResourceElementsAttr>(shapedType, name, |
| 825 | numElements, elements); |
| 826 | } |
| 827 | MlirAttribute mlirUnmanagedDenseUInt8ResourceElementsAttrGet( |
| 828 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
| 829 | const uint8_t *elements) { |
| 830 | return getDenseResource<DenseUI8ResourceElementsAttr>(shapedType, name, |
| 831 | numElements, elements); |
| 832 | } |
| 833 | MlirAttribute mlirUnmanagedDenseUInt16ResourceElementsAttrGet( |
| 834 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
| 835 | const uint16_t *elements) { |
| 836 | return getDenseResource<DenseUI16ResourceElementsAttr>(shapedType, name, |
| 837 | numElements, elements); |
| 838 | } |
| 839 | MlirAttribute mlirUnmanagedDenseUInt32ResourceElementsAttrGet( |
| 840 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
| 841 | const uint32_t *elements) { |
| 842 | return getDenseResource<DenseUI32ResourceElementsAttr>(shapedType, name, |
| 843 | numElements, elements); |
| 844 | } |
| 845 | MlirAttribute mlirUnmanagedDenseUInt64ResourceElementsAttrGet( |
| 846 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
| 847 | const uint64_t *elements) { |
| 848 | return getDenseResource<DenseUI64ResourceElementsAttr>(shapedType, name, |
| 849 | numElements, elements); |
| 850 | } |
| 851 | MlirAttribute mlirUnmanagedDenseInt8ResourceElementsAttrGet( |
| 852 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
| 853 | const int8_t *elements) { |
| 854 | return getDenseResource<DenseUI8ResourceElementsAttr>(shapedType, name, |
| 855 | numElements, elements); |
| 856 | } |
| 857 | MlirAttribute mlirUnmanagedDenseInt16ResourceElementsAttrGet( |
| 858 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
| 859 | const int16_t *elements) { |
| 860 | return getDenseResource<DenseUI16ResourceElementsAttr>(shapedType, name, |
| 861 | numElements, elements); |
| 862 | } |
| 863 | MlirAttribute mlirUnmanagedDenseInt32ResourceElementsAttrGet( |
| 864 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
| 865 | const int32_t *elements) { |
| 866 | return getDenseResource<DenseUI32ResourceElementsAttr>(shapedType, name, |
| 867 | numElements, elements); |
| 868 | } |
| 869 | MlirAttribute mlirUnmanagedDenseInt64ResourceElementsAttrGet( |
| 870 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
| 871 | const int64_t *elements) { |
| 872 | return getDenseResource<DenseUI64ResourceElementsAttr>(shapedType, name, |
| 873 | numElements, elements); |
| 874 | } |
| 875 | MlirAttribute mlirUnmanagedDenseFloatResourceElementsAttrGet( |
| 876 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
| 877 | const float *elements) { |
| 878 | return getDenseResource<DenseF32ResourceElementsAttr>(shapedType, name, |
| 879 | numElements, elements); |
| 880 | } |
| 881 | MlirAttribute mlirUnmanagedDenseDoubleResourceElementsAttrGet( |
| 882 | MlirType shapedType, MlirStringRef name, intptr_t numElements, |
| 883 | const double *elements) { |
| 884 | return getDenseResource<DenseF64ResourceElementsAttr>(shapedType, name, |
| 885 | numElements, elements); |
| 886 | } |
| 887 | template <typename U, typename T> |
| 888 | static T getDenseResourceVal(MlirAttribute attr, intptr_t pos) { |
| 889 | return (*llvm::cast<U>(unwrap(c: attr)).tryGetAsArrayRef())[pos]; |
| 890 | } |
| 891 | |
| 892 | bool mlirDenseBoolResourceElementsAttrGetValue(MlirAttribute attr, |
| 893 | intptr_t pos) { |
| 894 | return getDenseResourceVal<DenseBoolResourceElementsAttr, uint8_t>(attr, pos); |
| 895 | } |
| 896 | uint8_t mlirDenseUInt8ResourceElementsAttrGetValue(MlirAttribute attr, |
| 897 | intptr_t pos) { |
| 898 | return getDenseResourceVal<DenseUI8ResourceElementsAttr, uint8_t>(attr, pos); |
| 899 | } |
| 900 | uint16_t mlirDenseUInt16ResourceElementsAttrGetValue(MlirAttribute attr, |
| 901 | intptr_t pos) { |
| 902 | return getDenseResourceVal<DenseUI16ResourceElementsAttr, uint16_t>(attr, |
| 903 | pos); |
| 904 | } |
| 905 | uint32_t mlirDenseUInt32ResourceElementsAttrGetValue(MlirAttribute attr, |
| 906 | intptr_t pos) { |
| 907 | return getDenseResourceVal<DenseUI32ResourceElementsAttr, uint32_t>(attr, |
| 908 | pos); |
| 909 | } |
| 910 | uint64_t mlirDenseUInt64ResourceElementsAttrGetValue(MlirAttribute attr, |
| 911 | intptr_t pos) { |
| 912 | return getDenseResourceVal<DenseUI64ResourceElementsAttr, uint64_t>(attr, |
| 913 | pos); |
| 914 | } |
| 915 | int8_t mlirDenseInt8ResourceElementsAttrGetValue(MlirAttribute attr, |
| 916 | intptr_t pos) { |
| 917 | return getDenseResourceVal<DenseUI8ResourceElementsAttr, int8_t>(attr, pos); |
| 918 | } |
| 919 | int16_t mlirDenseInt16ResourceElementsAttrGetValue(MlirAttribute attr, |
| 920 | intptr_t pos) { |
| 921 | return getDenseResourceVal<DenseUI16ResourceElementsAttr, int16_t>(attr, pos); |
| 922 | } |
| 923 | int32_t mlirDenseInt32ResourceElementsAttrGetValue(MlirAttribute attr, |
| 924 | intptr_t pos) { |
| 925 | return getDenseResourceVal<DenseUI32ResourceElementsAttr, int32_t>(attr, pos); |
| 926 | } |
| 927 | int64_t mlirDenseInt64ResourceElementsAttrGetValue(MlirAttribute attr, |
| 928 | intptr_t pos) { |
| 929 | return getDenseResourceVal<DenseUI64ResourceElementsAttr, int64_t>(attr, pos); |
| 930 | } |
| 931 | float mlirDenseFloatResourceElementsAttrGetValue(MlirAttribute attr, |
| 932 | intptr_t pos) { |
| 933 | return getDenseResourceVal<DenseF32ResourceElementsAttr, float>(attr, pos); |
| 934 | } |
| 935 | double mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr, |
| 936 | intptr_t pos) { |
| 937 | return getDenseResourceVal<DenseF64ResourceElementsAttr, double>(attr, pos); |
| 938 | } |
| 939 | |
| 940 | //===----------------------------------------------------------------------===// |
| 941 | // Sparse elements attribute. |
| 942 | //===----------------------------------------------------------------------===// |
| 943 | |
| 944 | bool mlirAttributeIsASparseElements(MlirAttribute attr) { |
| 945 | return llvm::isa<SparseElementsAttr>(unwrap(attr)); |
| 946 | } |
| 947 | |
| 948 | MlirAttribute mlirSparseElementsAttribute(MlirType shapedType, |
| 949 | MlirAttribute denseIndices, |
| 950 | MlirAttribute denseValues) { |
| 951 | return wrap(SparseElementsAttr::get( |
| 952 | llvm::cast<ShapedType>(unwrap(shapedType)), |
| 953 | llvm::cast<DenseElementsAttr>(unwrap(denseIndices)), |
| 954 | llvm::cast<DenseElementsAttr>(unwrap(denseValues)))); |
| 955 | } |
| 956 | |
| 957 | MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) { |
| 958 | return wrap(llvm::cast<SparseElementsAttr>(unwrap(attr)).getIndices()); |
| 959 | } |
| 960 | |
| 961 | MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) { |
| 962 | return wrap(llvm::cast<SparseElementsAttr>(unwrap(attr)).getValues()); |
| 963 | } |
| 964 | |
| 965 | MlirTypeID mlirSparseElementsAttrGetTypeID(void) { |
| 966 | return wrap(SparseElementsAttr::getTypeID()); |
| 967 | } |
| 968 | |
| 969 | //===----------------------------------------------------------------------===// |
| 970 | // Strided layout attribute. |
| 971 | //===----------------------------------------------------------------------===// |
| 972 | |
| 973 | bool mlirAttributeIsAStridedLayout(MlirAttribute attr) { |
| 974 | return llvm::isa<StridedLayoutAttr>(unwrap(attr)); |
| 975 | } |
| 976 | |
| 977 | MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, |
| 978 | intptr_t numStrides, |
| 979 | const int64_t *strides) { |
| 980 | return wrap(StridedLayoutAttr::get(unwrap(ctx), offset, |
| 981 | ArrayRef<int64_t>(strides, numStrides))); |
| 982 | } |
| 983 | |
| 984 | int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr) { |
| 985 | return llvm::cast<StridedLayoutAttr>(unwrap(attr)).getOffset(); |
| 986 | } |
| 987 | |
| 988 | intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr) { |
| 989 | return static_cast<intptr_t>( |
| 990 | llvm::cast<StridedLayoutAttr>(unwrap(attr)).getStrides().size()); |
| 991 | } |
| 992 | |
| 993 | int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) { |
| 994 | return llvm::cast<StridedLayoutAttr>(unwrap(attr)).getStrides()[pos]; |
| 995 | } |
| 996 | |
| 997 | MlirTypeID mlirStridedLayoutAttrGetTypeID(void) { |
| 998 | return wrap(StridedLayoutAttr::getTypeID()); |
| 999 | } |
| 1000 | |