1//===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/BuiltinAttributes.h"
10#include "AttributeDetail.h"
11#include "mlir/IR/AffineMap.h"
12#include "mlir/IR/BuiltinDialect.h"
13#include "mlir/IR/Dialect.h"
14#include "mlir/IR/DialectResourceBlobManager.h"
15#include "mlir/IR/IntegerSet.h"
16#include "mlir/IR/OpImplementation.h"
17#include "mlir/IR/Operation.h"
18#include "mlir/IR/SymbolTable.h"
19#include "mlir/IR/Types.h"
20#include "llvm/ADT/APSInt.h"
21#include "llvm/ADT/Sequence.h"
22#include "llvm/ADT/TypeSwitch.h"
23#include "llvm/Support/Debug.h"
24#include "llvm/Support/Endian.h"
25#include <optional>
26
27#define DEBUG_TYPE "builtinattributes"
28
29using namespace mlir;
30using namespace mlir::detail;
31
32//===----------------------------------------------------------------------===//
33/// Tablegen Attribute Definitions
34//===----------------------------------------------------------------------===//
35
36#define GET_ATTRDEF_CLASSES
37#include "mlir/IR/BuiltinAttributes.cpp.inc"
38
39//===----------------------------------------------------------------------===//
40// BuiltinDialect
41//===----------------------------------------------------------------------===//
42
/// Registers the builtin attribute classes with the builtin dialect.
///
/// The first `addAttributes` call takes its template argument list from the
/// GET_ATTRDEF_LIST section of the tablegen-generated
/// BuiltinAttributes.cpp.inc. DistinctAttr is then registered by a separate
/// call (presumably because it is not part of that generated list — confirm
/// against BuiltinAttributes.td).
void BuiltinDialect::registerAttributes() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/IR/BuiltinAttributes.cpp.inc"
      >();
  // Registered explicitly, outside of the generated list above.
  addAttributes<DistinctAttr>();
}
50
51//===----------------------------------------------------------------------===//
52// DictionaryAttr
53//===----------------------------------------------------------------------===//
54
55/// Helper function that does either an in place sort or sorts from source array
56/// into destination. If inPlace then storage is both the source and the
57/// destination, else value is the source and storage destination. Returns
58/// whether source was sorted.
59template <bool inPlace>
60static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
61 SmallVectorImpl<NamedAttribute> &storage) {
62 // Specialize for the common case.
63 switch (value.size()) {
64 case 0:
65 // Zero already sorted.
66 if (!inPlace)
67 storage.clear();
68 break;
69 case 1:
70 // One already sorted but may need to be copied.
71 if (!inPlace)
72 storage.assign(IL: {value[0]});
73 break;
74 case 2: {
75 bool isSorted = value[0] < value[1];
76 if (inPlace) {
77 if (!isSorted)
78 std::swap(storage[0], storage[1]);
79 } else if (isSorted) {
80 storage.assign(IL: {value[0], value[1]});
81 } else {
82 storage.assign(IL: {value[1], value[0]});
83 }
84 return !isSorted;
85 }
86 default:
87 if (!inPlace)
88 storage.assign(value.begin(), value.end());
89 // Check to see they are sorted already.
90 bool isSorted = llvm::is_sorted(value);
91 // If not, do a general sort.
92 if (!isSorted)
93 llvm::array_pod_sort(storage.begin(), storage.end());
94 return !isSorted;
95 }
96 return false;
97}
98
99/// Returns an entry with a duplicate name from the given sorted array of named
100/// attributes. Returns std::nullopt if all elements have unique names.
101static std::optional<NamedAttribute>
102findDuplicateElement(ArrayRef<NamedAttribute> value) {
103 const std::optional<NamedAttribute> none{std::nullopt};
104 if (value.size() < 2)
105 return none;
106
107 if (value.size() == 2)
108 return value[0].getName() == value[1].getName() ? value[0] : none;
109
110 const auto *it = std::adjacent_find(first: value.begin(), last: value.end(),
111 binary_pred: [](NamedAttribute l, NamedAttribute r) {
112 return l.getName() == r.getName();
113 });
114 return it != value.end() ? *it : none;
115}
116
117bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
118 SmallVectorImpl<NamedAttribute> &storage) {
119 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
120 assert(!findDuplicateElement(storage) &&
121 "DictionaryAttr element names must be unique");
122 return isSorted;
123}
124
125bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
126 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
127 assert(!findDuplicateElement(array) &&
128 "DictionaryAttr element names must be unique");
129 return isSorted;
130}
131
132std::optional<NamedAttribute>
133DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
134 bool isSorted) {
135 if (!isSorted)
136 dictionaryAttrSort</*inPlace=*/true>(array, array);
137 return findDuplicateElement(array);
138}
139
140DictionaryAttr DictionaryAttr::get(MLIRContext *context,
141 ArrayRef<NamedAttribute> value) {
142 if (value.empty())
143 return DictionaryAttr::getEmpty(context);
144
145 // We need to sort the element list to canonicalize it.
146 SmallVector<NamedAttribute, 8> storage;
147 if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
148 value = storage;
149 assert(!findDuplicateElement(value) &&
150 "DictionaryAttr element names must be unique");
151 return Base::get(context, value);
152}
153/// Construct a dictionary with an array of values that is known to already be
154/// sorted by name and uniqued.
155DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
156 ArrayRef<NamedAttribute> value) {
157 if (value.empty())
158 return DictionaryAttr::getEmpty(context);
159 // Ensure that the attribute elements are unique and sorted.
160 assert(llvm::is_sorted(
161 value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) &&
162 "expected attribute values to be sorted");
163 assert(!findDuplicateElement(value) &&
164 "DictionaryAttr element names must be unique");
165 return Base::get(context, value);
166}
167
168/// Return the specified attribute if present, null otherwise.
169Attribute DictionaryAttr::get(StringRef name) const {
170 auto it = impl::findAttrSorted(begin(), end(), name);
171 return it.second ? it.first->getValue() : Attribute();
172}
173Attribute DictionaryAttr::get(StringAttr name) const {
174 auto it = impl::findAttrSorted(begin(), end(), name);
175 return it.second ? it.first->getValue() : Attribute();
176}
177
178/// Return the specified named attribute if present, std::nullopt otherwise.
179std::optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
180 auto it = impl::findAttrSorted(begin(), end(), name);
181 return it.second ? *it.first : std::optional<NamedAttribute>();
182}
183std::optional<NamedAttribute> DictionaryAttr::getNamed(StringAttr name) const {
184 auto it = impl::findAttrSorted(begin(), end(), name);
185 return it.second ? *it.first : std::optional<NamedAttribute>();
186}
187
188/// Return whether the specified attribute is present.
189bool DictionaryAttr::contains(StringRef name) const {
190 return impl::findAttrSorted(begin(), end(), name).second;
191}
192bool DictionaryAttr::contains(StringAttr name) const {
193 return impl::findAttrSorted(begin(), end(), name).second;
194}
195
196DictionaryAttr::iterator DictionaryAttr::begin() const {
197 return getValue().begin();
198}
199DictionaryAttr::iterator DictionaryAttr::end() const {
200 return getValue().end();
201}
202size_t DictionaryAttr::size() const { return getValue().size(); }
203
204DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
205 return Base::get(context, ArrayRef<NamedAttribute>());
206}
207
208//===----------------------------------------------------------------------===//
209// StridedLayoutAttr
210//===----------------------------------------------------------------------===//
211
212/// Prints a strided layout attribute.
213void StridedLayoutAttr::print(llvm::raw_ostream &os) const {
214 auto printIntOrQuestion = [&](int64_t value) {
215 if (ShapedType::isDynamic(value))
216 os << "?";
217 else
218 os << value;
219 };
220
221 os << "strided<[";
222 llvm::interleaveComma(getStrides(), os, printIntOrQuestion);
223 os << "]";
224
225 if (getOffset() != 0) {
226 os << ", offset: ";
227 printIntOrQuestion(getOffset());
228 }
229 os << ">";
230}
231
232/// Returns the strided layout as an affine map.
233AffineMap StridedLayoutAttr::getAffineMap() const {
234 return makeStridedLinearLayoutMap(getStrides(), getOffset(), getContext());
235}
236
237/// Checks that the type-agnostic strided layout invariants are satisfied.
238LogicalResult
239StridedLayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
240 int64_t offset, ArrayRef<int64_t> strides) {
241 if (llvm::is_contained(strides, 0))
242 return emitError() << "strides must not be zero";
243
244 return success();
245}
246
247/// Checks that the type-specific strided layout invariants are satisfied.
248LogicalResult StridedLayoutAttr::verifyLayout(
249 ArrayRef<int64_t> shape,
250 function_ref<InFlightDiagnostic()> emitError) const {
251 if (shape.size() != getStrides().size())
252 return emitError() << "expected the number of strides to match the rank";
253
254 return success();
255}
256
257//===----------------------------------------------------------------------===//
258// StringAttr
259//===----------------------------------------------------------------------===//
260
261StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) {
262 return Base::get(context, "", NoneType::get(context));
263}
264
265/// Twine support for StringAttr.
266StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) {
267 // Fast-path empty twine.
268 if (twine.isTriviallyEmpty())
269 return get(context);
270 SmallVector<char, 32> tempStr;
271 return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context));
272}
273
274/// Twine support for StringAttr.
275StringAttr StringAttr::get(const Twine &twine, Type type) {
276 SmallVector<char, 32> tempStr;
277 return Base::get(type.getContext(), twine.toStringRef(tempStr), type);
278}
279
280StringRef StringAttr::getValue() const { return getImpl()->value; }
281
282Type StringAttr::getType() const { return getImpl()->type; }
283
284Dialect *StringAttr::getReferencedDialect() const {
285 return getImpl()->referencedDialect;
286}
287
288//===----------------------------------------------------------------------===//
289// FloatAttr
290//===----------------------------------------------------------------------===//
291
292double FloatAttr::getValueAsDouble() const {
293 return getValueAsDouble(getValue());
294}
295double FloatAttr::getValueAsDouble(APFloat value) {
296 if (&value.getSemantics() != &APFloat::IEEEdouble()) {
297 bool losesInfo = false;
298 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
299 &losesInfo);
300 }
301 return value.convertToDouble();
302}
303
304LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
305 Type type, APFloat value) {
306 // Verify that the type is correct.
307 if (!llvm::isa<FloatType>(type))
308 return emitError() << "expected floating point type";
309
310 // Verify that the type semantics match that of the value.
311 if (&llvm::cast<FloatType>(type).getFloatSemantics() !=
312 &value.getSemantics()) {
313 return emitError()
314 << "FloatAttr type doesn't match the type implied by its value";
315 }
316 return success();
317}
318
319//===----------------------------------------------------------------------===//
320// SymbolRefAttr
321//===----------------------------------------------------------------------===//
322
323SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
324 ArrayRef<FlatSymbolRefAttr> nestedRefs) {
325 return get(StringAttr::get(ctx, value), nestedRefs);
326}
327
328FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
329 return llvm::cast<FlatSymbolRefAttr>(get(ctx, value, {}));
330}
331
332FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) {
333 return llvm::cast<FlatSymbolRefAttr>(get(value, {}));
334}
335
336FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) {
337 auto symName =
338 symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
339 assert(symName && "value does not have a valid symbol name");
340 return SymbolRefAttr::get(symName);
341}
342
343StringAttr SymbolRefAttr::getLeafReference() const {
344 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
345 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
346}
347
348//===----------------------------------------------------------------------===//
349// IntegerAttr
350//===----------------------------------------------------------------------===//
351
352int64_t IntegerAttr::getInt() const {
353 assert((getType().isIndex() || getType().isSignlessInteger()) &&
354 "must be signless integer");
355 return getValue().getSExtValue();
356}
357
358int64_t IntegerAttr::getSInt() const {
359 assert(getType().isSignedInteger() && "must be signed integer");
360 return getValue().getSExtValue();
361}
362
363uint64_t IntegerAttr::getUInt() const {
364 assert(getType().isUnsignedInteger() && "must be unsigned integer");
365 return getValue().getZExtValue();
366}
367
368/// Return the value as an APSInt which carries the signed from the type of
369/// the attribute. This traps on signless integers types!
370APSInt IntegerAttr::getAPSInt() const {
371 assert(!getType().isSignlessInteger() &&
372 "Signless integers don't carry a sign for APSInt");
373 return APSInt(getValue(), getType().isUnsignedInteger());
374}
375
376LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
377 Type type, APInt value) {
378 if (IntegerType integerType = llvm::dyn_cast<IntegerType>(type)) {
379 if (integerType.getWidth() != value.getBitWidth())
380 return emitError() << "integer type bit width (" << integerType.getWidth()
381 << ") doesn't match value bit width ("
382 << value.getBitWidth() << ")";
383 return success();
384 }
385 if (llvm::isa<IndexType>(type)) {
386 if (value.getBitWidth() != IndexType::kInternalStorageBitWidth)
387 return emitError()
388 << "value bit width (" << value.getBitWidth()
389 << ") doesn't match index type internal storage bit width ("
390 << IndexType::kInternalStorageBitWidth << ")";
391 return success();
392 }
393 return emitError() << "expected integer or index type";
394}
395
396BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
397 auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value));
398 return llvm::cast<BoolAttr>(attr);
399}
400
401//===----------------------------------------------------------------------===//
402// BoolAttr
403//===----------------------------------------------------------------------===//
404
405bool BoolAttr::getValue() const {
406 auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
407 return storage->value.getBoolValue();
408}
409
410bool BoolAttr::classof(Attribute attr) {
411 IntegerAttr intAttr = llvm::dyn_cast<IntegerAttr>(attr);
412 return intAttr && intAttr.getType().isSignlessInteger(1);
413}
414
415//===----------------------------------------------------------------------===//
416// OpaqueAttr
417//===----------------------------------------------------------------------===//
418
419LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
420 StringAttr dialect, StringRef attrData,
421 Type type) {
422 if (!Dialect::isValidNamespace(dialect.strref()))
423 return emitError() << "invalid dialect namespace '" << dialect << "'";
424
425 // Check that the dialect is actually registered.
426 MLIRContext *context = dialect.getContext();
427 if (!context->allowsUnregisteredDialects() &&
428 !context->getLoadedDialect(dialect.strref())) {
429 return emitError()
430 << "#" << dialect << "<\"" << attrData << "\"> : " << type
431 << " attribute created with unregistered dialect. If this is "
432 "intended, please call allowUnregisteredDialects() on the "
433 "MLIRContext, or use -allow-unregistered-dialect with "
434 "the MLIR opt tool used";
435 }
436
437 return success();
438}
439
440//===----------------------------------------------------------------------===//
441// DenseElementsAttr Utilities
442//===----------------------------------------------------------------------===//
443
// Out-of-line definitions of the storage's byte constants: kSplatTrue has all
// bits set (~0) and kSplatFalse has all bits clear (0). Presumably these are
// the byte encodings of a boolean splat's value — confirm in
// AttributeDetail.h.
const char DenseIntOrFPElementsAttrStorage::kSplatTrue = ~0;
const char DenseIntOrFPElementsAttrStorage::kSplatFalse = 0;
446
447/// Get the bitwidth of a dense element type within the buffer.
448/// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
449static size_t getDenseElementStorageWidth(size_t origWidth) {
450 return origWidth == 1 ? origWidth : llvm::alignTo<8>(Value: origWidth);
451}
452static size_t getDenseElementStorageWidth(Type elementType) {
453 return getDenseElementStorageWidth(origWidth: getDenseElementBitWidth(eltType: elementType));
454}
455
/// Sets bit `bitPos` of `rawData` to `value`. Bits are numbered LSB-first
/// within each byte: byte `bitPos / CHAR_BIT`, bit `bitPos % CHAR_BIT`.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  if (!value) {
    byte = static_cast<char>(byte & ~mask);
    return;
  }
  byte = static_cast<char>(byte | mask);
}
463
/// Returns bit `bitPos` of `rawData`, using the same LSB-first numbering
/// within each byte as setBit.
static bool getBit(const char *rawData, size_t bitPos) {
  const auto byte = static_cast<unsigned char>(rawData[bitPos / CHAR_BIT]);
  const unsigned shift = bitPos % CHAR_BIT;
  return ((byte >> shift) & 1u) != 0;
}
468
469/// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for
470/// BE format.
471static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
472 char *result) {
473 assert(llvm::endianness::native == llvm::endianness::big);
474 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
475
476 // Copy the words filled with data.
477 // For example, when `value` has 2 words, the first word is filled with data.
478 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
479 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
480 std::copy_n(first: reinterpret_cast<const char *>(value.getRawData()),
481 n: numFilledWords, result: result);
482 // Convert last word of APInt to LE format and store it in char
483 // array(`valueLE`).
484 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------|
485 size_t lastWordPos = numFilledWords;
486 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
487 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
488 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
489 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
490 // Extract actual APInt data from `valueLE`, convert endianness to BE format,
491 // and store it in `result`.
492 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij|
493 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
494 valueLE.begin(), result + lastWordPos,
495 (numBytes - lastWordPos) * CHAR_BIT, 1);
496}
497
498/// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
499/// format.
500static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
501 APInt &result) {
502 assert(llvm::endianness::native == llvm::endianness::big);
503 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
504
505 // Copy the data that fills the word of `result` from `inArray`.
506 // For example, when `result` has 2 words, the first word will be filled with
507 // data. So, the first 8 bytes are copied from `inArray` here.
508 // `inArray` (10 bytes, BE): |abcdefgh|ij|
509 // ==> `result` (2 words, BE): |abcdefgh|--------|
510 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
511 std::copy_n(
512 first: inArray, n: numFilledWords,
513 result: const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));
514
515 // Convert array data which will be last word of `result` to LE format, and
516 // store it in char array(`inArrayLE`).
517 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------|
518 size_t lastWordPos = numFilledWords;
519 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
520 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
521 inArray + lastWordPos, inArrayLE.begin(),
522 (numBytes - lastWordPos) * CHAR_BIT, 1);
523
524 // Convert `inArrayLE` to BE format, and store it in last word of `result`.
525 // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij|
526 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
527 inArrayLE.begin(),
528 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
529 lastWordPos,
530 APInt::APINT_BITS_PER_WORD, 1);
531}
532
533/// Writes value to the bit position `bitPos` in array `rawData`.
534static void writeBits(char *rawData, size_t bitPos, APInt value) {
535 size_t bitWidth = value.getBitWidth();
536
537 // If the bitwidth is 1 we just toggle the specific bit.
538 if (bitWidth == 1)
539 return setBit(rawData, bitPos, value: value.isOne());
540
541 // Otherwise, the bit position is guaranteed to be byte aligned.
542 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
543 if (llvm::endianness::native == llvm::endianness::big) {
544 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
545 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
546 // work correctly in BE format.
547 // ex. `value` (2 words including 10 bytes)
548 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------|
549 copyAPIntToArrayForBEmachine(value, numBytes: llvm::divideCeil(Numerator: bitWidth, CHAR_BIT),
550 result: rawData + (bitPos / CHAR_BIT));
551 } else {
552 std::copy_n(first: reinterpret_cast<const char *>(value.getRawData()),
553 n: llvm::divideCeil(Numerator: bitWidth, CHAR_BIT),
554 result: rawData + (bitPos / CHAR_BIT));
555 }
556}
557
558/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
559/// `rawData`.
560static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
561 // Handle a boolean bit position.
562 if (bitWidth == 1)
563 return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
564
565 // Otherwise, the bit position must be 8-bit aligned.
566 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
567 APInt result(bitWidth, 0);
568 if (llvm::endianness::native == llvm::endianness::big) {
569 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
570 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
571 // work correctly in BE format.
572 // ex. `result` (2 words including 10 bytes)
573 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function
574 copyArrayToAPIntForBEmachine(inArray: rawData + (bitPos / CHAR_BIT),
575 numBytes: llvm::divideCeil(Numerator: bitWidth, CHAR_BIT), result);
576 } else {
577 std::copy_n(first: rawData + (bitPos / CHAR_BIT),
578 n: llvm::divideCeil(Numerator: bitWidth, CHAR_BIT),
579 result: const_cast<char *>(
580 reinterpret_cast<const char *>(result.getRawData())));
581 }
582 return result;
583}
584
585/// Returns true if 'values' corresponds to a splat, i.e. one element, or has
586/// the same element count as 'type'.
587template <typename Values>
588static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
589 return (values.size() == 1) ||
590 (type.getNumElements() == static_cast<int64_t>(values.size()));
591}
592
593//===----------------------------------------------------------------------===//
594// DenseElementsAttr Iterators
595//===----------------------------------------------------------------------===//
596
597//===----------------------------------------------------------------------===//
598// AttributeElementIterator
599
600DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
601 DenseElementsAttr attr, size_t index)
602 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
603 Attribute, Attribute, Attribute>(
604 attr.getAsOpaquePointer(), index) {}
605
606Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
607 auto owner = llvm::cast<DenseElementsAttr>(Val: getFromOpaquePointer(ptr: base));
608 Type eltTy = owner.getElementType();
609 if (llvm::dyn_cast<IntegerType>(eltTy))
610 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
611 if (llvm::isa<IndexType>(eltTy))
612 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
613 if (auto floatEltTy = llvm::dyn_cast<FloatType>(Val&: eltTy)) {
614 IntElementIterator intIt(owner, index);
615 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
616 return FloatAttr::get(eltTy, *floatIt);
617 }
618 if (auto complexTy = llvm::dyn_cast<ComplexType>(eltTy)) {
619 auto complexEltTy = complexTy.getElementType();
620 ComplexIntElementIterator complexIntIt(owner, index);
621 if (llvm::isa<IntegerType>(complexEltTy)) {
622 auto value = *complexIntIt;
623 auto real = IntegerAttr::get(complexEltTy, value.real());
624 auto imag = IntegerAttr::get(complexEltTy, value.imag());
625 return ArrayAttr::get(complexTy.getContext(),
626 ArrayRef<Attribute>{real, imag});
627 }
628
629 ComplexFloatElementIterator complexFloatIt(
630 llvm::cast<FloatType>(complexEltTy).getFloatSemantics(), complexIntIt);
631 auto value = *complexFloatIt;
632 auto real = FloatAttr::get(complexEltTy, value.real());
633 auto imag = FloatAttr::get(complexEltTy, value.imag());
634 return ArrayAttr::get(complexTy.getContext(),
635 ArrayRef<Attribute>{real, imag});
636 }
637 if (llvm::isa<DenseStringElementsAttr>(owner)) {
638 ArrayRef<StringRef> vals = owner.getRawStringData();
639 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
640 }
641 llvm_unreachable("unexpected element type");
642}
643
644//===----------------------------------------------------------------------===//
645// BoolElementIterator
646
647DenseElementsAttr::BoolElementIterator::BoolElementIterator(
648 DenseElementsAttr attr, size_t dataIndex)
649 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
650 attr.getRawData().data(), attr.isSplat(), dataIndex) {}
651
652bool DenseElementsAttr::BoolElementIterator::operator*() const {
653 return getBit(rawData: getData(), bitPos: getDataIndex());
654}
655
656//===----------------------------------------------------------------------===//
657// IntElementIterator
658
659DenseElementsAttr::IntElementIterator::IntElementIterator(
660 DenseElementsAttr attr, size_t dataIndex)
661 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
662 attr.getRawData().data(), attr.isSplat(), dataIndex),
663 bitWidth(getDenseElementBitWidth(eltType: attr.getElementType())) {}
664
665APInt DenseElementsAttr::IntElementIterator::operator*() const {
666 return readBits(rawData: getData(),
667 bitPos: getDataIndex() * getDenseElementStorageWidth(origWidth: bitWidth),
668 bitWidth);
669}
670
671//===----------------------------------------------------------------------===//
672// ComplexIntElementIterator
673
674DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
675 DenseElementsAttr attr, size_t dataIndex)
676 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
677 std::complex<APInt>, std::complex<APInt>,
678 std::complex<APInt>>(
679 attr.getRawData().data(), attr.isSplat(), dataIndex) {
680 auto complexType = llvm::cast<ComplexType>(attr.getElementType());
681 bitWidth = getDenseElementBitWidth(complexType.getElementType());
682}
683
684std::complex<APInt>
685DenseElementsAttr::ComplexIntElementIterator::operator*() const {
686 size_t storageWidth = getDenseElementStorageWidth(origWidth: bitWidth);
687 size_t offset = getDataIndex() * storageWidth * 2;
688 return {readBits(rawData: getData(), bitPos: offset, bitWidth),
689 readBits(rawData: getData(), bitPos: offset + storageWidth, bitWidth)};
690}
691
692//===----------------------------------------------------------------------===//
693// DenseArrayAttr
694//===----------------------------------------------------------------------===//
695
696LogicalResult
697DenseArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError,
698 Type elementType, int64_t size, ArrayRef<char> rawData) {
699 if (!elementType.isIntOrIndexOrFloat())
700 return emitError() << "expected integer or floating point element type";
701 int64_t dataSize = rawData.size();
702 int64_t elementSize =
703 llvm::divideCeil(elementType.getIntOrFloatBitWidth(), CHAR_BIT);
704 if (size * elementSize != dataSize) {
705 return emitError() << "expected data size (" << size << " elements, "
706 << elementSize
707 << " bytes each) does not match: " << dataSize
708 << " bytes";
709 }
710 return success();
711}
712
713namespace {
714/// Instantiations of this class provide utilities for interacting with native
715/// data types in the context of DenseArrayAttr.
716template <size_t width,
717 IntegerType::SignednessSemantics signedness = IntegerType::Signless>
718struct DenseArrayAttrIntUtil {
719 static bool checkElementType(Type eltType) {
720 auto type = llvm::dyn_cast<IntegerType>(eltType);
721 if (!type || type.getWidth() != width)
722 return false;
723 return type.getSignedness() == signedness;
724 }
725
726 static Type getElementType(MLIRContext *ctx) {
727 return IntegerType::get(ctx, width, signedness);
728 }
729
730 template <typename T>
731 static void printElement(raw_ostream &os, T value) {
732 os << value;
733 }
734
735 template <typename T>
736 static ParseResult parseElement(AsmParser &parser, T &value) {
737 return parser.parseInteger(value);
738 }
739};
740template <typename T>
741struct DenseArrayAttrUtil;
742
743/// Specialization for boolean elements to print 'true' and 'false' literals for
744/// elements.
745template <>
746struct DenseArrayAttrUtil<bool> : public DenseArrayAttrIntUtil<1> {
747 static void printElement(raw_ostream &os, bool value) {
748 os << (value ? "true" : "false");
749 }
750};
751
752/// Specialization for 8-bit integers to ensure values are printed as integers
753/// and not characters.
754template <>
755struct DenseArrayAttrUtil<int8_t> : public DenseArrayAttrIntUtil<8> {
756 static void printElement(raw_ostream &os, int8_t value) {
757 os << static_cast<int>(value);
758 }
759};
760template <>
761struct DenseArrayAttrUtil<int16_t> : public DenseArrayAttrIntUtil<16> {};
762template <>
763struct DenseArrayAttrUtil<int32_t> : public DenseArrayAttrIntUtil<32> {};
764template <>
765struct DenseArrayAttrUtil<int64_t> : public DenseArrayAttrIntUtil<64> {};
766
767/// Specialization for 32-bit floats.
768template <>
769struct DenseArrayAttrUtil<float> {
770 static bool checkElementType(Type eltType) { return eltType.isF32(); }
771 static Type getElementType(MLIRContext *ctx) { return Float32Type::get(ctx); }
772 static void printElement(raw_ostream &os, float value) { os << value; }
773
774 /// Parse a double and cast it to a float.
775 static ParseResult parseElement(AsmParser &parser, float &value) {
776 double doubleVal;
777 if (parser.parseFloat(result&: doubleVal))
778 return failure();
779 value = doubleVal;
780 return success();
781 }
782};
783
784/// Specialization for 64-bit floats.
785template <>
786struct DenseArrayAttrUtil<double> {
787 static bool checkElementType(Type eltType) { return eltType.isF64(); }
788 static Type getElementType(MLIRContext *ctx) { return Float64Type::get(ctx); }
789 static void printElement(raw_ostream &os, float value) { os << value; }
790 static ParseResult parseElement(AsmParser &parser, double &value) {
791 return parser.parseFloat(result&: value);
792 }
793};
794} // namespace
795
796template <typename T>
797void DenseArrayAttrImpl<T>::print(AsmPrinter &printer) const {
798 print(os&: printer.getStream());
799}
800
801template <typename T>
802void DenseArrayAttrImpl<T>::printWithoutBraces(raw_ostream &os) const {
803 llvm::interleaveComma(asArrayRef(), os, [&](T value) {
804 DenseArrayAttrUtil<T>::printElement(os, value);
805 });
806}
807
808template <typename T>
809void DenseArrayAttrImpl<T>::print(raw_ostream &os) const {
810 os << "[";
811 printWithoutBraces(os);
812 os << "]";
813}
814
815/// Parse a DenseArrayAttr without the braces: `1, 2, 3`
816template <typename T>
817Attribute DenseArrayAttrImpl<T>::parseWithoutBraces(AsmParser &parser,
818 Type odsType) {
819 SmallVector<T> data;
820 if (failed(parser.parseCommaSeparatedList([&]() {
821 T value;
822 if (DenseArrayAttrUtil<T>::parseElement(parser, value))
823 return failure();
824 data.push_back(value);
825 return success();
826 })))
827 return {};
828 return get(context: parser.getContext(), content: data);
829}
830
831/// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`
832template <typename T>
833Attribute DenseArrayAttrImpl<T>::parse(AsmParser &parser, Type odsType) {
834 if (parser.parseLSquare())
835 return {};
836 // Handle empty list case.
837 if (succeeded(result: parser.parseOptionalRSquare()))
838 return get(context: parser.getContext(), content: {});
839 Attribute result = parseWithoutBraces(parser, odsType);
840 if (parser.parseRSquare())
841 return {};
842 return result;
843}
844
845/// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
846template <typename T>
847DenseArrayAttrImpl<T>::operator ArrayRef<T>() const {
848 ArrayRef<char> raw = getRawData();
849 assert((raw.size() % sizeof(T)) == 0);
850 return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
851 raw.size() / sizeof(T));
852}
853
854/// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
855template <typename T>
856DenseArrayAttrImpl<T> DenseArrayAttrImpl<T>::get(MLIRContext *context,
857 ArrayRef<T> content) {
858 Type elementType = DenseArrayAttrUtil<T>::getElementType(context);
859 auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
860 content.size() * sizeof(T));
861 return llvm::cast<DenseArrayAttrImpl<T>>(
862 Base::get(context, elementType, content.size(), rawArray));
863}
864
865template <typename T>
866bool DenseArrayAttrImpl<T>::classof(Attribute attr) {
867 if (auto denseArray = llvm::dyn_cast<DenseArrayAttr>(attr))
868 return DenseArrayAttrUtil<T>::checkElementType(denseArray.getElementType());
869 return false;
870}
871
872namespace mlir {
873namespace detail {
874// Explicit instantiation for all the supported DenseArrayAttr.
875template class DenseArrayAttrImpl<bool>;
876template class DenseArrayAttrImpl<int8_t>;
877template class DenseArrayAttrImpl<int16_t>;
878template class DenseArrayAttrImpl<int32_t>;
879template class DenseArrayAttrImpl<int64_t>;
880template class DenseArrayAttrImpl<float>;
881template class DenseArrayAttrImpl<double>;
882} // namespace detail
883} // namespace mlir
884
885//===----------------------------------------------------------------------===//
886// DenseElementsAttr
887//===----------------------------------------------------------------------===//
888
889/// Method for support type inquiry through isa, cast and dyn_cast.
890bool DenseElementsAttr::classof(Attribute attr) {
891 return llvm::isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(attr);
892}
893
894DenseElementsAttr DenseElementsAttr::get(ShapedType type,
895 ArrayRef<Attribute> values) {
896 assert(hasSameElementsOrSplat(type, values));
897
898 Type eltType = type.getElementType();
899
900 // Take care complex type case first.
901 if (auto complexType = llvm::dyn_cast<ComplexType>(eltType)) {
902 if (complexType.getElementType().isIntOrIndex()) {
903 SmallVector<std::complex<APInt>> complexValues;
904 complexValues.reserve(N: values.size());
905 for (Attribute attr : values) {
906 assert(llvm::isa<ArrayAttr>(attr) && "expected ArrayAttr for complex");
907 auto arrayAttr = llvm::cast<ArrayAttr>(attr);
908 assert(arrayAttr.size() == 2 && "expected 2 element for complex");
909 auto attr0 = arrayAttr[0];
910 auto attr1 = arrayAttr[1];
911 complexValues.push_back(
912 std::complex<APInt>(llvm::cast<IntegerAttr>(attr0).getValue(),
913 llvm::cast<IntegerAttr>(attr1).getValue()));
914 }
915 return DenseElementsAttr::get(type, complexValues);
916 }
917 // Must be float.
918 SmallVector<std::complex<APFloat>> complexValues;
919 complexValues.reserve(N: values.size());
920 for (Attribute attr : values) {
921 assert(llvm::isa<ArrayAttr>(attr) && "expected ArrayAttr for complex");
922 auto arrayAttr = llvm::cast<ArrayAttr>(attr);
923 assert(arrayAttr.size() == 2 && "expected 2 element for complex");
924 auto attr0 = arrayAttr[0];
925 auto attr1 = arrayAttr[1];
926 complexValues.push_back(
927 std::complex<APFloat>(llvm::cast<FloatAttr>(attr0).getValue(),
928 llvm::cast<FloatAttr>(attr1).getValue()));
929 }
930 return DenseElementsAttr::get(type, complexValues);
931 }
932
933 // If the element type is not based on int/float/index, assume it is a string
934 // type.
935 if (!eltType.isIntOrIndexOrFloat()) {
936 SmallVector<StringRef, 8> stringValues;
937 stringValues.reserve(N: values.size());
938 for (Attribute attr : values) {
939 assert(llvm::isa<StringAttr>(attr) &&
940 "expected string value for non integer/index/float element");
941 stringValues.push_back(Elt: llvm::cast<StringAttr>(attr).getValue());
942 }
943 return get(type, stringValues);
944 }
945
946 // Otherwise, get the raw storage width to use for the allocation.
947 size_t bitWidth = getDenseElementBitWidth(eltType);
948 size_t storageBitWidth = getDenseElementStorageWidth(origWidth: bitWidth);
949
950 // Compress the attribute values into a character buffer.
951 SmallVector<char, 8> data(
952 llvm::divideCeil(Numerator: storageBitWidth * values.size(), CHAR_BIT));
953 APInt intVal;
954 for (unsigned i = 0, e = values.size(); i < e; ++i) {
955 if (auto floatAttr = llvm::dyn_cast<FloatAttr>(values[i])) {
956 assert(floatAttr.getType() == eltType &&
957 "expected float attribute type to equal element type");
958 intVal = floatAttr.getValue().bitcastToAPInt();
959 } else {
960 auto intAttr = llvm::cast<IntegerAttr>(values[i]);
961 assert(intAttr.getType() == eltType &&
962 "expected integer attribute type to equal element type");
963 intVal = intAttr.getValue();
964 }
965
966 assert(intVal.getBitWidth() == bitWidth &&
967 "expected value to have same bitwidth as element type");
968 writeBits(rawData: data.data(), bitPos: i * storageBitWidth, value: intVal);
969 }
970
971 // Handle the special encoding of splat of bool.
972 if (values.size() == 1 && eltType.isInteger(width: 1))
973 data[0] = data[0] ? -1 : 0;
974
975 return DenseIntOrFPElementsAttr::getRaw(type, data);
976}
977
978DenseElementsAttr DenseElementsAttr::get(ShapedType type,
979 ArrayRef<bool> values) {
980 assert(hasSameElementsOrSplat(type, values));
981 assert(type.getElementType().isInteger(1));
982
983 std::vector<char> buff(llvm::divideCeil(Numerator: values.size(), CHAR_BIT));
984
985 if (!values.empty()) {
986 bool isSplat = true;
987 bool firstValue = values[0];
988 for (int i = 0, e = values.size(); i != e; ++i) {
989 isSplat &= values[i] == firstValue;
990 setBit(rawData: buff.data(), bitPos: i, value: values[i]);
991 }
992
993 // Splat of bool is encoded as a byte with all-ones in it.
994 if (isSplat) {
995 buff.resize(new_size: 1);
996 buff[0] = values[0] ? -1 : 0;
997 }
998 }
999
1000 return DenseIntOrFPElementsAttr::getRaw(type, buff);
1001}
1002
1003DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1004 ArrayRef<StringRef> values) {
1005 assert(!type.getElementType().isIntOrFloat());
1006 return DenseStringElementsAttr::get(type, values);
1007}
1008
1009/// Constructs a dense integer elements attribute from an array of APInt
1010/// values. Each APInt value is expected to have the same bitwidth as the
1011/// element type of 'type'.
1012DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1013 ArrayRef<APInt> values) {
1014 assert(type.getElementType().isIntOrIndex());
1015 assert(hasSameElementsOrSplat(type, values));
1016 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
1017 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
1018}
1019DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1020 ArrayRef<std::complex<APInt>> values) {
1021 ComplexType complex = llvm::cast<ComplexType>(type.getElementType());
1022 assert(llvm::isa<IntegerType>(complex.getElementType()));
1023 assert(hasSameElementsOrSplat(type, values));
1024 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
1025 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
1026 values.size() * 2);
1027 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals);
1028}
1029
1030// Constructs a dense float elements attribute from an array of APFloat
1031// values. Each APFloat value is expected to have the same bitwidth as the
1032// element type of 'type'.
1033DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1034 ArrayRef<APFloat> values) {
1035 assert(llvm::isa<FloatType>(type.getElementType()));
1036 assert(hasSameElementsOrSplat(type, values));
1037 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
1038 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
1039}
1040DenseElementsAttr
1041DenseElementsAttr::get(ShapedType type,
1042 ArrayRef<std::complex<APFloat>> values) {
1043 ComplexType complex = llvm::cast<ComplexType>(type.getElementType());
1044 assert(llvm::isa<FloatType>(complex.getElementType()));
1045 assert(hasSameElementsOrSplat(type, values));
1046 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
1047 values.size() * 2);
1048 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
1049 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals);
1050}
1051
1052/// Construct a dense elements attribute from a raw buffer representing the
1053/// data for this attribute. Users should generally not use this methods as
1054/// the expected buffer format may not be a form the user expects.
1055DenseElementsAttr
1056DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef<char> rawBuffer) {
1057 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer);
1058}
1059
1060/// Returns true if the given buffer is a valid raw buffer for the given type.
1061bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
1062 ArrayRef<char> rawBuffer,
1063 bool &detectedSplat) {
1064 size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
1065 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
1066 int64_t numElements = type.getNumElements();
1067
1068 // The initializer is always a splat if the result type has a single element.
1069 detectedSplat = numElements == 1;
1070
1071 // Storage width of 1 is special as it is packed by the bit.
1072 if (storageWidth == 1) {
1073 // Check for a splat, or a buffer equal to the number of elements which
1074 // consists of either all 0's or all 1's.
1075 if (rawBuffer.size() == 1) {
1076 auto rawByte = static_cast<uint8_t>(rawBuffer[0]);
1077 if (rawByte == 0 || rawByte == 0xff) {
1078 detectedSplat = true;
1079 return true;
1080 }
1081 }
1082
1083 // This is a valid non-splat buffer if it has the right size.
1084 return rawBufferWidth == llvm::alignTo<8>(Value: numElements);
1085 }
1086
1087 // All other types are 8-bit aligned, so we can just check the buffer width
1088 // to know if only a single initializer element was passed in.
1089 if (rawBufferWidth == storageWidth) {
1090 detectedSplat = true;
1091 return true;
1092 }
1093
1094 // The raw buffer is valid if it has the right size.
1095 return rawBufferWidth == storageWidth * numElements;
1096}
1097
1098/// Check the information for a C++ data type, check if this type is valid for
1099/// the current attribute. This method is used to verify specific type
1100/// invariants that the templatized 'getValues' method cannot.
1101static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
1102 bool isSigned) {
1103 // Make sure that the data element size is the same as the type element width.
1104 auto denseEltBitWidth = getDenseElementBitWidth(eltType: type);
1105 auto dataSize = static_cast<size_t>(dataEltSize * CHAR_BIT);
1106 if (denseEltBitWidth != dataSize) {
1107 LLVM_DEBUG(llvm::dbgs() << "expected dense element bit width "
1108 << denseEltBitWidth << " to match data size "
1109 << dataSize << " for type " << type << "\n");
1110 return false;
1111 }
1112
1113 // Check that the element type is either float or integer or index.
1114 if (!isInt) {
1115 bool valid = llvm::isa<FloatType>(Val: type);
1116 if (!valid)
1117 LLVM_DEBUG(llvm::dbgs()
1118 << "expected float type when isInt is false, but found "
1119 << type << "\n");
1120 return valid;
1121 }
1122 if (type.isIndex())
1123 return true;
1124
1125 auto intType = llvm::dyn_cast<IntegerType>(type);
1126 if (!intType) {
1127 LLVM_DEBUG(llvm::dbgs()
1128 << "expected integer type when isInt is true, but found " << type
1129 << "\n");
1130 return false;
1131 }
1132
1133 // Make sure signedness semantics is consistent.
1134 if (intType.isSignless())
1135 return true;
1136
1137 bool valid = intType.isSigned() == isSigned;
1138 if (!valid)
1139 LLVM_DEBUG(llvm::dbgs() << "expected signedness " << isSigned
1140 << " to match type " << type << "\n");
1141 return valid;
1142}
1143
1144/// Defaults down the subclass implementation.
1145DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
1146 ArrayRef<char> data,
1147 int64_t dataEltSize,
1148 bool isInt, bool isSigned) {
1149 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
1150 isSigned);
1151}
1152DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
1153 ArrayRef<char> data,
1154 int64_t dataEltSize,
1155 bool isInt,
1156 bool isSigned) {
1157 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
1158 isInt, isSigned);
1159}
1160
1161bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
1162 bool isSigned) const {
1163 return ::isValidIntOrFloat(type: getElementType(), dataEltSize, isInt, isSigned);
1164}
1165bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
1166 bool isSigned) const {
1167 return ::isValidIntOrFloat(
1168 llvm::cast<ComplexType>(getElementType()).getElementType(),
1169 dataEltSize / 2, isInt, isSigned);
1170}
1171
1172/// Returns true if this attribute corresponds to a splat, i.e. if all element
1173/// values are the same.
1174bool DenseElementsAttr::isSplat() const {
1175 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
1176}
1177
1178/// Return if the given complex type has an integer element type.
1179static bool isComplexOfIntType(Type type) {
1180 return llvm::isa<IntegerType>(llvm::cast<ComplexType>(type).getElementType());
1181}
1182
1183auto DenseElementsAttr::tryGetComplexIntValues() const
1184 -> FailureOr<iterator_range_impl<ComplexIntElementIterator>> {
1185 if (!isComplexOfIntType(type: getElementType()))
1186 return failure();
1187 return iterator_range_impl<ComplexIntElementIterator>(
1188 getType(), ComplexIntElementIterator(*this, 0),
1189 ComplexIntElementIterator(*this, getNumElements()));
1190}
1191
1192auto DenseElementsAttr::tryGetFloatValues() const
1193 -> FailureOr<iterator_range_impl<FloatElementIterator>> {
1194 auto eltTy = llvm::dyn_cast<FloatType>(Val: getElementType());
1195 if (!eltTy)
1196 return failure();
1197 const auto &elementSemantics = eltTy.getFloatSemantics();
1198 return iterator_range_impl<FloatElementIterator>(
1199 getType(), FloatElementIterator(elementSemantics, raw_int_begin()),
1200 FloatElementIterator(elementSemantics, raw_int_end()));
1201}
1202
1203auto DenseElementsAttr::tryGetComplexFloatValues() const
1204 -> FailureOr<iterator_range_impl<ComplexFloatElementIterator>> {
1205 auto complexTy = llvm::dyn_cast<ComplexType>(getElementType());
1206 if (!complexTy)
1207 return failure();
1208 auto eltTy = llvm::dyn_cast<FloatType>(complexTy.getElementType());
1209 if (!eltTy)
1210 return failure();
1211 const auto &semantics = eltTy.getFloatSemantics();
1212 return iterator_range_impl<ComplexFloatElementIterator>(
1213 getType(), {semantics, {*this, 0}},
1214 {semantics, {*this, static_cast<size_t>(getNumElements())}});
1215}
1216
1217/// Return the raw storage data held by this attribute.
1218ArrayRef<char> DenseElementsAttr::getRawData() const {
1219 return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
1220}
1221
1222ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
1223 return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
1224}
1225
1226/// Return a new DenseElementsAttr that has the same data as the current
1227/// attribute, but has been reshaped to 'newType'. The new type must have the
1228/// same total number of elements as well as element type.
1229DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
1230 ShapedType curType = getType();
1231 if (curType == newType)
1232 return *this;
1233
1234 assert(newType.getElementType() == curType.getElementType() &&
1235 "expected the same element type");
1236 assert(newType.getNumElements() == curType.getNumElements() &&
1237 "expected the same number of elements");
1238 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
1239}
1240
1241DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
1242 assert(isSplat() && "expected a splat type");
1243
1244 ShapedType curType = getType();
1245 if (curType == newType)
1246 return *this;
1247
1248 assert(newType.getElementType() == curType.getElementType() &&
1249 "expected the same element type");
1250 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
1251}
1252
1253/// Return a new DenseElementsAttr that has the same data as the current
1254/// attribute, but has bitcast elements such that it is now 'newType'. The new
1255/// type must have the same shape and element types of the same bitwidth as the
1256/// current type.
1257DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
1258 ShapedType curType = getType();
1259 Type curElType = curType.getElementType();
1260 if (curElType == newElType)
1261 return *this;
1262
1263 assert(getDenseElementBitWidth(newElType) ==
1264 getDenseElementBitWidth(curElType) &&
1265 "expected element types with the same bitwidth");
1266 return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType),
1267 getRawData());
1268}
1269
1270DenseElementsAttr
1271DenseElementsAttr::mapValues(Type newElementType,
1272 function_ref<APInt(const APInt &)> mapping) const {
1273 return llvm::cast<DenseIntElementsAttr>(Val: *this).mapValues(newElementType,
1274 mapping);
1275}
1276
1277DenseElementsAttr DenseElementsAttr::mapValues(
1278 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1279 return llvm::cast<DenseFPElementsAttr>(Val: *this).mapValues(newElementType,
1280 mapping);
1281}
1282
1283ShapedType DenseElementsAttr::getType() const {
1284 return static_cast<const DenseElementsAttributeStorage *>(impl)->type;
1285}
1286
1287Type DenseElementsAttr::getElementType() const {
1288 return getType().getElementType();
1289}
1290
1291int64_t DenseElementsAttr::getNumElements() const {
1292 return getType().getNumElements();
1293}
1294
1295//===----------------------------------------------------------------------===//
1296// DenseIntOrFPElementsAttr
1297//===----------------------------------------------------------------------===//
1298
1299/// Utility method to write a range of APInt values to a buffer.
1300template <typename APRangeT>
1301static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1302 APRangeT &&values) {
1303 size_t numValues = llvm::size(values);
1304 data.resize(new_size: llvm::divideCeil(Numerator: storageWidth * numValues, CHAR_BIT));
1305 size_t offset = 0;
1306 for (auto it = values.begin(), e = values.end(); it != e;
1307 ++it, offset += storageWidth) {
1308 assert((*it).getBitWidth() <= storageWidth);
1309 writeBits(data.data(), offset, *it);
1310 }
1311
1312 // Handle the special encoding of splat of a boolean.
1313 if (numValues == 1 && (*values.begin()).getBitWidth() == 1)
1314 data[0] = data[0] ? -1 : 0;
1315}
1316
1317/// Constructs a dense elements attribute from an array of raw APFloat values.
1318/// Each APFloat value is expected to have the same bitwidth as the element
1319/// type of 'type'. 'type' must be a vector or tensor with static shape.
1320DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1321 size_t storageWidth,
1322 ArrayRef<APFloat> values) {
1323 std::vector<char> data;
1324 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1325 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1326 return DenseIntOrFPElementsAttr::getRaw(type, data);
1327}
1328
1329/// Constructs a dense elements attribute from an array of raw APInt values.
1330/// Each APInt value is expected to have the same bitwidth as the element type
1331/// of 'type'.
1332DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1333 size_t storageWidth,
1334 ArrayRef<APInt> values) {
1335 std::vector<char> data;
1336 writeAPIntsToBuffer(storageWidth, data, values);
1337 return DenseIntOrFPElementsAttr::getRaw(type, data);
1338}
1339
1340DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1341 ArrayRef<char> data) {
1342 assert(type.hasStaticShape() && "type must have static shape");
1343 bool isSplat = false;
1344 bool isValid = isValidRawBuffer(type, data, isSplat);
1345 assert(isValid);
1346 (void)isValid;
1347 return Base::get(type.getContext(), type, data, isSplat);
1348}
1349
1350/// Overload of the raw 'get' method that asserts that the given type is of
1351/// complex type. This method is used to verify type invariants that the
1352/// templatized 'get' method cannot.
1353DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1354 ArrayRef<char> data,
1355 int64_t dataEltSize,
1356 bool isInt,
1357 bool isSigned) {
1358 assert(::isValidIntOrFloat(
1359 llvm::cast<ComplexType>(type.getElementType()).getElementType(),
1360 dataEltSize / 2, isInt, isSigned) &&
1361 "Try re-running with -debug-only=builtinattributes");
1362
1363 int64_t numElements = data.size() / dataEltSize;
1364 (void)numElements;
1365 assert(numElements == 1 || numElements == type.getNumElements());
1366 return getRaw(type, data);
1367}
1368
1369/// Overload of the 'getRaw' method that asserts that the given type is of
1370/// integer type. This method is used to verify type invariants that the
1371/// templatized 'get' method cannot.
1372DenseElementsAttr
1373DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1374 int64_t dataEltSize, bool isInt,
1375 bool isSigned) {
1376 assert(::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt,
1377 isSigned) &&
1378 "Try re-running with -debug-only=builtinattributes");
1379
1380 int64_t numElements = data.size() / dataEltSize;
1381 assert(numElements == 1 || numElements == type.getNumElements());
1382 (void)numElements;
1383 return getRaw(type, data);
1384}
1385
1386void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1387 const char *inRawData, char *outRawData, size_t elementBitWidth,
1388 size_t numElements) {
1389 using llvm::support::ulittle16_t;
1390 using llvm::support::ulittle32_t;
1391 using llvm::support::ulittle64_t;
1392
1393 assert(llvm::endianness::native == llvm::endianness::big);
1394 // NOLINT to avoid warning message about replacing by static_assert()
1395
1396 // Following std::copy_n always converts endianness on BE machine.
1397 switch (elementBitWidth) {
1398 case 16: {
1399 const ulittle16_t *inRawDataPos =
1400 reinterpret_cast<const ulittle16_t *>(inRawData);
1401 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1402 std::copy_n(inRawDataPos, numElements, outDataPos);
1403 break;
1404 }
1405 case 32: {
1406 const ulittle32_t *inRawDataPos =
1407 reinterpret_cast<const ulittle32_t *>(inRawData);
1408 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1409 std::copy_n(inRawDataPos, numElements, outDataPos);
1410 break;
1411 }
1412 case 64: {
1413 const ulittle64_t *inRawDataPos =
1414 reinterpret_cast<const ulittle64_t *>(inRawData);
1415 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1416 std::copy_n(inRawDataPos, numElements, outDataPos);
1417 break;
1418 }
1419 default: {
1420 size_t nBytes = elementBitWidth / CHAR_BIT;
1421 for (size_t i = 0; i < nBytes; i++)
1422 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1423 break;
1424 }
1425 }
1426}
1427
1428void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1429 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1430 ShapedType type) {
1431 size_t numElements = type.getNumElements();
1432 Type elementType = type.getElementType();
1433 if (ComplexType complexTy = llvm::dyn_cast<ComplexType>(elementType)) {
1434 elementType = complexTy.getElementType();
1435 numElements = numElements * 2;
1436 }
1437 size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1438 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1439 inRawData.size() <= outRawData.size());
1440 if (elementBitWidth <= CHAR_BIT)
1441 std::memcpy(outRawData.begin(), inRawData.begin(), inRawData.size());
1442 else
1443 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1444 elementBitWidth, numElements);
1445}
1446
1447//===----------------------------------------------------------------------===//
1448// DenseFPElementsAttr
1449//===----------------------------------------------------------------------===//
1450
1451template <typename Fn, typename Attr>
1452static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1453 Type newElementType,
1454 llvm::SmallVectorImpl<char> &data) {
1455 size_t bitWidth = getDenseElementBitWidth(eltType: newElementType);
1456 size_t storageBitWidth = getDenseElementStorageWidth(origWidth: bitWidth);
1457
1458 ShapedType newArrayType = inType.cloneWith(inType.getShape(), newElementType);
1459
1460 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1461 data.resize(N: llvm::divideCeil(Numerator: storageBitWidth * numRawElements, CHAR_BIT));
1462
1463 // Functor used to process a single element value of the attribute.
1464 auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1465 auto newInt = mapping(value);
1466 assert(newInt.getBitWidth() == bitWidth);
1467 writeBits(data.data(), index * storageBitWidth, newInt);
1468 };
1469
1470 // Check for the splat case.
1471 if (attr.isSplat()) {
1472 if (bitWidth == 1) {
1473 // Handle the special encoding of splat of bool.
1474 data[0] = mapping(*attr.begin()).isZero() ? 0 : -1;
1475 } else {
1476 processElt(*attr.begin(), /*index=*/0);
1477 }
1478 return newArrayType;
1479 }
1480
1481 // Otherwise, process all of the element values.
1482 uint64_t elementIdx = 0;
1483 for (auto value : attr)
1484 processElt(value, elementIdx++);
1485 return newArrayType;
1486}
1487
1488DenseElementsAttr DenseFPElementsAttr::mapValues(
1489 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1490 llvm::SmallVector<char, 8> elementData;
1491 auto newArrayType =
1492 mappingHelper(mapping, *this, getType(), newElementType, elementData);
1493
1494 return getRaw(newArrayType, elementData);
1495}
1496
1497/// Method for supporting type inquiry through isa, cast and dyn_cast.
1498bool DenseFPElementsAttr::classof(Attribute attr) {
1499 if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(Val&: attr))
1500 return llvm::isa<FloatType>(denseAttr.getType().getElementType());
1501 return false;
1502}
1503
1504//===----------------------------------------------------------------------===//
1505// DenseIntElementsAttr
1506//===----------------------------------------------------------------------===//
1507
1508DenseElementsAttr DenseIntElementsAttr::mapValues(
1509 Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1510 llvm::SmallVector<char, 8> elementData;
1511 auto newArrayType =
1512 mappingHelper(mapping, *this, getType(), newElementType, elementData);
1513 return getRaw(newArrayType, elementData);
1514}
1515
1516/// Method for supporting type inquiry through isa, cast and dyn_cast.
1517bool DenseIntElementsAttr::classof(Attribute attr) {
1518 if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(Val&: attr))
1519 return denseAttr.getType().getElementType().isIntOrIndex();
1520 return false;
1521}
1522
1523//===----------------------------------------------------------------------===//
1524// DenseResourceElementsAttr
1525//===----------------------------------------------------------------------===//
1526
1527DenseResourceElementsAttr
1528DenseResourceElementsAttr::get(ShapedType type,
1529 DenseResourceElementsHandle handle) {
1530 return Base::get(type.getContext(), type, handle);
1531}
1532
1533DenseResourceElementsAttr DenseResourceElementsAttr::get(ShapedType type,
1534 StringRef blobName,
1535 AsmResourceBlob blob) {
1536 // Extract the builtin dialect resource manager from context and construct a
1537 // handle by inserting a new resource using the provided blob.
1538 auto &manager =
1539 DenseResourceElementsHandle::getManagerInterface(type.getContext());
1540 return get(type, manager.insert(blobName, std::move(blob)));
1541}
1542
1543//===----------------------------------------------------------------------===//
1544// DenseResourceElementsAttrBase
1545
1546namespace {
1547/// Instantiations of this class provide utilities for interacting with native
1548/// data types in the context of DenseResourceElementsAttr.
1549template <typename T>
1550struct DenseResourceAttrUtil;
1551template <size_t width, bool isSigned>
1552struct DenseResourceElementsAttrIntUtil {
1553 static bool checkElementType(Type eltType) {
1554 IntegerType type = llvm::dyn_cast<IntegerType>(eltType);
1555 if (!type || type.getWidth() != width)
1556 return false;
1557 return isSigned ? !type.isUnsigned() : !type.isSigned();
1558 }
1559};
1560template <>
1561struct DenseResourceAttrUtil<bool> {
1562 static bool checkElementType(Type eltType) {
1563 return eltType.isSignlessInteger(width: 1);
1564 }
1565};
1566template <>
1567struct DenseResourceAttrUtil<int8_t>
1568 : public DenseResourceElementsAttrIntUtil<8, true> {};
1569template <>
1570struct DenseResourceAttrUtil<uint8_t>
1571 : public DenseResourceElementsAttrIntUtil<8, false> {};
1572template <>
1573struct DenseResourceAttrUtil<int16_t>
1574 : public DenseResourceElementsAttrIntUtil<16, true> {};
1575template <>
1576struct DenseResourceAttrUtil<uint16_t>
1577 : public DenseResourceElementsAttrIntUtil<16, false> {};
1578template <>
1579struct DenseResourceAttrUtil<int32_t>
1580 : public DenseResourceElementsAttrIntUtil<32, true> {};
1581template <>
1582struct DenseResourceAttrUtil<uint32_t>
1583 : public DenseResourceElementsAttrIntUtil<32, false> {};
1584template <>
1585struct DenseResourceAttrUtil<int64_t>
1586 : public DenseResourceElementsAttrIntUtil<64, true> {};
1587template <>
1588struct DenseResourceAttrUtil<uint64_t>
1589 : public DenseResourceElementsAttrIntUtil<64, false> {};
1590template <>
1591struct DenseResourceAttrUtil<float> {
1592 static bool checkElementType(Type eltType) { return eltType.isF32(); }
1593};
1594template <>
1595struct DenseResourceAttrUtil<double> {
1596 static bool checkElementType(Type eltType) { return eltType.isF64(); }
1597};
1598} // namespace
1599
1600template <typename T>
1601DenseResourceElementsAttrBase<T>
1602DenseResourceElementsAttrBase<T>::get(ShapedType type, StringRef blobName,
1603 AsmResourceBlob blob) {
1604 // Check that the blob is in the form we were expecting.
1605 assert(blob.getDataAlignment() == alignof(T) &&
1606 "alignment mismatch between expected alignment and blob alignment");
1607 assert(((blob.getData().size() % sizeof(T)) == 0) &&
1608 "size mismatch between expected element width and blob size");
1609 assert(DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) &&
1610 "invalid shape element type for provided type `T`");
1611 return llvm::cast<DenseResourceElementsAttrBase<T>>(
1612 DenseResourceElementsAttr::get(type, blobName, std::move(blob)));
1613}
1614
1615template <typename T>
1616std::optional<ArrayRef<T>>
1617DenseResourceElementsAttrBase<T>::tryGetAsArrayRef() const {
1618 if (AsmResourceBlob *blob = this->getRawHandle().getBlob())
1619 return blob->template getDataAs<T>();
1620 return std::nullopt;
1621}
1622
1623template <typename T>
1624bool DenseResourceElementsAttrBase<T>::classof(Attribute attr) {
1625 auto resourceAttr = llvm::dyn_cast<DenseResourceElementsAttr>(attr);
1626 return resourceAttr && DenseResourceAttrUtil<T>::checkElementType(
1627 resourceAttr.getElementType());
1628}
1629
1630namespace mlir {
1631namespace detail {
1632// Explicit instantiation for all the supported DenseResourceElementsAttr.
1633template class DenseResourceElementsAttrBase<bool>;
1634template class DenseResourceElementsAttrBase<int8_t>;
1635template class DenseResourceElementsAttrBase<int16_t>;
1636template class DenseResourceElementsAttrBase<int32_t>;
1637template class DenseResourceElementsAttrBase<int64_t>;
1638template class DenseResourceElementsAttrBase<uint8_t>;
1639template class DenseResourceElementsAttrBase<uint16_t>;
1640template class DenseResourceElementsAttrBase<uint32_t>;
1641template class DenseResourceElementsAttrBase<uint64_t>;
1642template class DenseResourceElementsAttrBase<float>;
1643template class DenseResourceElementsAttrBase<double>;
1644} // namespace detail
1645} // namespace mlir
1646
1647//===----------------------------------------------------------------------===//
1648// SparseElementsAttr
1649//===----------------------------------------------------------------------===//
1650
1651/// Get a zero APFloat for the given sparse attribute.
1652APFloat SparseElementsAttr::getZeroAPFloat() const {
1653 auto eltType = llvm::cast<FloatType>(getElementType());
1654 return APFloat(eltType.getFloatSemantics());
1655}
1656
1657/// Get a zero APInt for the given sparse attribute.
1658APInt SparseElementsAttr::getZeroAPInt() const {
1659 auto eltType = llvm::cast<IntegerType>(getElementType());
1660 return APInt::getZero(eltType.getWidth());
1661}
1662
1663/// Get a zero attribute for the given attribute type.
1664Attribute SparseElementsAttr::getZeroAttr() const {
1665 auto eltType = getElementType();
1666
1667 // Handle floating point elements.
1668 if (llvm::isa<FloatType>(eltType))
1669 return FloatAttr::get(eltType, 0);
1670
1671 // Handle complex elements.
1672 if (auto complexTy = llvm::dyn_cast<ComplexType>(eltType)) {
1673 auto eltType = complexTy.getElementType();
1674 Attribute zero;
1675 if (llvm::isa<FloatType>(eltType))
1676 zero = FloatAttr::get(eltType, 0);
1677 else // must be integer
1678 zero = IntegerAttr::get(eltType, 0);
1679 return ArrayAttr::get(complexTy.getContext(),
1680 ArrayRef<Attribute>{zero, zero});
1681 }
1682
1683 // Handle string type.
1684 if (llvm::isa<DenseStringElementsAttr>(getValues()))
1685 return StringAttr::get("", eltType);
1686
1687 // Otherwise, this is an integer.
1688 return IntegerAttr::get(eltType, 0);
1689}
1690
1691/// Flatten, and return, all of the sparse indices in this attribute in
1692/// row-major order.
1693std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1694 std::vector<ptrdiff_t> flatSparseIndices;
1695
1696 // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1697 // as a 1-D index array.
1698 auto sparseIndices = getIndices();
1699 auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1700 if (sparseIndices.isSplat()) {
1701 SmallVector<uint64_t, 8> indices(getType().getRank(),
1702 *sparseIndexValues.begin());
1703 flatSparseIndices.push_back(getFlattenedIndex(indices));
1704 return flatSparseIndices;
1705 }
1706
1707 // Otherwise, reinterpret each index as an ArrayRef when flattening.
1708 auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1709 size_t rank = getType().getRank();
1710 for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1711 flatSparseIndices.push_back(getFlattenedIndex(
1712 {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1713 return flatSparseIndices;
1714}
1715
1716LogicalResult
1717SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1718 ShapedType type, DenseIntElementsAttr sparseIndices,
1719 DenseElementsAttr values) {
1720 ShapedType valuesType = values.getType();
1721 if (valuesType.getRank() != 1)
1722 return emitError() << "expected 1-d tensor for sparse element values";
1723
1724 // Verify the indices and values shape.
1725 ShapedType indicesType = sparseIndices.getType();
1726 auto emitShapeError = [&]() {
1727 return emitError() << "expected shape ([" << type.getShape()
1728 << "]); inferred shape of indices literal (["
1729 << indicesType.getShape()
1730 << "]); inferred shape of values literal (["
1731 << valuesType.getShape() << "])";
1732 };
1733 // Verify indices shape.
1734 size_t rank = type.getRank(), indicesRank = indicesType.getRank();
1735 if (indicesRank == 2) {
1736 if (indicesType.getDimSize(1) != static_cast<int64_t>(rank))
1737 return emitShapeError();
1738 } else if (indicesRank != 1 || rank != 1) {
1739 return emitShapeError();
1740 }
1741 // Verify the values shape.
1742 int64_t numSparseIndices = indicesType.getDimSize(0);
1743 if (numSparseIndices != valuesType.getDimSize(0))
1744 return emitShapeError();
1745
1746 // Verify that the sparse indices are within the value shape.
1747 auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) {
1748 return emitError()
1749 << "sparse index #" << indexNum
1750 << " is not contained within the value shape, with index=[" << index
1751 << "], and type=" << type;
1752 };
1753
1754 // Handle the case where the index values are a splat.
1755 auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1756 if (sparseIndices.isSplat()) {
1757 SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin());
1758 if (!ElementsAttr::isValidIndex(type, indices))
1759 return emitIndexError(0, indices);
1760 return success();
1761 }
1762
1763 // Otherwise, reinterpret each index as an ArrayRef.
1764 for (size_t i = 0, e = numSparseIndices; i != e; ++i) {
1765 ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank),
1766 rank);
1767 if (!ElementsAttr::isValidIndex(type, index))
1768 return emitIndexError(i, index);
1769 }
1770
1771 return success();
1772}
1773
1774//===----------------------------------------------------------------------===//
1775// DistinctAttr
1776//===----------------------------------------------------------------------===//
1777
1778DistinctAttr DistinctAttr::create(Attribute referencedAttr) {
1779 return Base::get(referencedAttr.getContext(), referencedAttr);
1780}
1781
1782Attribute DistinctAttr::getReferencedAttr() const {
1783 return getImpl()->referencedAttr;
1784}
1785
1786//===----------------------------------------------------------------------===//
1787// Attribute Utilities
1788//===----------------------------------------------------------------------===//
1789
1790AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
1791 int64_t offset,
1792 MLIRContext *context) {
1793 AffineExpr expr;
1794 unsigned nSymbols = 0;
1795
1796 // AffineExpr for offset.
1797 // Static case.
1798 if (!ShapedType::isDynamic(offset)) {
1799 auto cst = getAffineConstantExpr(constant: offset, context);
1800 expr = cst;
1801 } else {
1802 // Dynamic case, new symbol for the offset.
1803 auto sym = getAffineSymbolExpr(position: nSymbols++, context);
1804 expr = sym;
1805 }
1806
1807 // AffineExpr for strides.
1808 for (const auto &en : llvm::enumerate(First&: strides)) {
1809 auto dim = en.index();
1810 auto stride = en.value();
1811 assert(stride != 0 && "Invalid stride specification");
1812 auto d = getAffineDimExpr(position: dim, context);
1813 AffineExpr mult;
1814 // Static case.
1815 if (!ShapedType::isDynamic(stride))
1816 mult = getAffineConstantExpr(constant: stride, context);
1817 else
1818 // Dynamic case, new symbol for each new stride.
1819 mult = getAffineSymbolExpr(position: nSymbols++, context);
1820 expr = expr + d * mult;
1821 }
1822
1823 return AffineMap::get(dimCount: strides.size(), symbolCount: nSymbols, result: expr);
1824}
1825

source code of mlir/lib/IR/BuiltinAttributes.cpp