1//===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/BuiltinAttributes.h"
10#include "AttributeDetail.h"
11#include "mlir/IR/AffineMap.h"
12#include "mlir/IR/BuiltinDialect.h"
13#include "mlir/IR/Dialect.h"
14#include "mlir/IR/DialectResourceBlobManager.h"
15#include "mlir/IR/IntegerSet.h"
16#include "mlir/IR/OpImplementation.h"
17#include "mlir/IR/Operation.h"
18#include "mlir/IR/SymbolTable.h"
19#include "mlir/IR/Types.h"
20#include "llvm/ADT/APSInt.h"
21#include "llvm/ADT/Sequence.h"
22#include "llvm/ADT/TypeSwitch.h"
23#include "llvm/Support/Debug.h"
24#include "llvm/Support/Endian.h"
25#include <optional>
26
27#define DEBUG_TYPE "builtinattributes"
28
29using namespace mlir;
30using namespace mlir::detail;
31
32//===----------------------------------------------------------------------===//
33/// Tablegen Attribute Definitions
34//===----------------------------------------------------------------------===//
35
36#define GET_ATTRDEF_CLASSES
37#include "mlir/IR/BuiltinAttributes.cpp.inc"
38
39//===----------------------------------------------------------------------===//
40// BuiltinDialect
41//===----------------------------------------------------------------------===//
42
43void BuiltinDialect::registerAttributes() {
44 addAttributes<
45#define GET_ATTRDEF_LIST
46#include "mlir/IR/BuiltinAttributes.cpp.inc"
47 >();
48 addAttributes<DistinctAttr>();
49}
50
51//===----------------------------------------------------------------------===//
52// DictionaryAttr
53//===----------------------------------------------------------------------===//
54
55/// Helper function that does either an in place sort or sorts from source array
56/// into destination. If inPlace then storage is both the source and the
57/// destination, else value is the source and storage destination. Returns
58/// whether source was sorted.
59template <bool inPlace>
60static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
61 SmallVectorImpl<NamedAttribute> &storage) {
62 // Specialize for the common case.
63 switch (value.size()) {
64 case 0:
65 // Zero already sorted.
66 if (!inPlace)
67 storage.clear();
68 break;
69 case 1:
70 // One already sorted but may need to be copied.
71 if (!inPlace)
72 storage.assign(IL: {value[0]});
73 break;
74 case 2: {
75 bool isSorted = value[0] < value[1];
76 if (inPlace) {
77 if (!isSorted)
78 std::swap(storage[0], storage[1]);
79 } else if (isSorted) {
80 storage.assign(IL: {value[0], value[1]});
81 } else {
82 storage.assign(IL: {value[1], value[0]});
83 }
84 return !isSorted;
85 }
86 default:
87 if (!inPlace)
88 storage.assign(value.begin(), value.end());
89 // Check to see they are sorted already.
90 bool isSorted = llvm::is_sorted(value);
91 // If not, do a general sort.
92 if (!isSorted)
93 llvm::array_pod_sort(storage.begin(), storage.end());
94 return !isSorted;
95 }
96 return false;
97}
98
99/// Returns an entry with a duplicate name from the given sorted array of named
100/// attributes. Returns std::nullopt if all elements have unique names.
101static std::optional<NamedAttribute>
102findDuplicateElement(ArrayRef<NamedAttribute> value) {
103 const std::optional<NamedAttribute> none{std::nullopt};
104 if (value.size() < 2)
105 return none;
106
107 if (value.size() == 2)
108 return value[0].getName() == value[1].getName() ? value[0] : none;
109
110 const auto *it = std::adjacent_find(first: value.begin(), last: value.end(),
111 binary_pred: [](NamedAttribute l, NamedAttribute r) {
112 return l.getName() == r.getName();
113 });
114 return it != value.end() ? *it : none;
115}
116
117bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
118 SmallVectorImpl<NamedAttribute> &storage) {
119 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
120 assert(!findDuplicateElement(storage) &&
121 "DictionaryAttr element names must be unique");
122 return isSorted;
123}
124
125bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
126 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
127 assert(!findDuplicateElement(array) &&
128 "DictionaryAttr element names must be unique");
129 return isSorted;
130}
131
132std::optional<NamedAttribute>
133DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
134 bool isSorted) {
135 if (!isSorted)
136 dictionaryAttrSort</*inPlace=*/true>(array, array);
137 return findDuplicateElement(array);
138}
139
140DictionaryAttr DictionaryAttr::get(MLIRContext *context,
141 ArrayRef<NamedAttribute> value) {
142 if (value.empty())
143 return DictionaryAttr::getEmpty(context);
144
145 // We need to sort the element list to canonicalize it.
146 SmallVector<NamedAttribute, 8> storage;
147 if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
148 value = storage;
149 assert(!findDuplicateElement(value) &&
150 "DictionaryAttr element names must be unique");
151 return Base::get(context, value);
152}
153/// Construct a dictionary with an array of values that is known to already be
154/// sorted by name and uniqued.
155DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
156 ArrayRef<NamedAttribute> value) {
157 if (value.empty())
158 return DictionaryAttr::getEmpty(context);
159 // Ensure that the attribute elements are unique and sorted.
160 assert(llvm::is_sorted(
161 value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) &&
162 "expected attribute values to be sorted");
163 assert(!findDuplicateElement(value) &&
164 "DictionaryAttr element names must be unique");
165 return Base::get(context, value);
166}
167
168/// Return the specified attribute if present, null otherwise.
169Attribute DictionaryAttr::get(StringRef name) const {
170 auto it = impl::findAttrSorted(begin(), end(), name);
171 return it.second ? it.first->getValue() : Attribute();
172}
173Attribute DictionaryAttr::get(StringAttr name) const {
174 auto it = impl::findAttrSorted(begin(), end(), name);
175 return it.second ? it.first->getValue() : Attribute();
176}
177
178/// Return the specified named attribute if present, std::nullopt otherwise.
179std::optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
180 auto it = impl::findAttrSorted(begin(), end(), name);
181 return it.second ? *it.first : std::optional<NamedAttribute>();
182}
183std::optional<NamedAttribute> DictionaryAttr::getNamed(StringAttr name) const {
184 auto it = impl::findAttrSorted(begin(), end(), name);
185 return it.second ? *it.first : std::optional<NamedAttribute>();
186}
187
188/// Return whether the specified attribute is present.
189bool DictionaryAttr::contains(StringRef name) const {
190 return impl::findAttrSorted(begin(), end(), name).second;
191}
192bool DictionaryAttr::contains(StringAttr name) const {
193 return impl::findAttrSorted(begin(), end(), name).second;
194}
195
196DictionaryAttr::iterator DictionaryAttr::begin() const {
197 return getValue().begin();
198}
199DictionaryAttr::iterator DictionaryAttr::end() const {
200 return getValue().end();
201}
202size_t DictionaryAttr::size() const { return getValue().size(); }
203
204DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
205 return Base::get(context, ArrayRef<NamedAttribute>());
206}
207
208//===----------------------------------------------------------------------===//
209// StridedLayoutAttr
210//===----------------------------------------------------------------------===//
211
212/// Prints a strided layout attribute.
213void StridedLayoutAttr::print(llvm::raw_ostream &os) const {
214 auto printIntOrQuestion = [&](int64_t value) {
215 if (ShapedType::isDynamic(value))
216 os << "?";
217 else
218 os << value;
219 };
220
221 os << "strided<[";
222 llvm::interleaveComma(getStrides(), os, printIntOrQuestion);
223 os << "]";
224
225 if (getOffset() != 0) {
226 os << ", offset: ";
227 printIntOrQuestion(getOffset());
228 }
229 os << ">";
230}
231
232/// Returns true if this layout is static, i.e. the strides and offset all have
233/// a known value > 0.
234bool StridedLayoutAttr::hasStaticLayout() const {
235 return !ShapedType::isDynamic(getOffset()) &&
236 !ShapedType::isDynamicShape(getStrides());
237}
238
239/// Returns the strided layout as an affine map.
240AffineMap StridedLayoutAttr::getAffineMap() const {
241 return makeStridedLinearLayoutMap(getStrides(), getOffset(), getContext());
242}
243
244/// Checks that the type-agnostic strided layout invariants are satisfied.
245LogicalResult
246StridedLayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
247 int64_t offset, ArrayRef<int64_t> strides) {
248 return success();
249}
250
251/// Checks that the type-specific strided layout invariants are satisfied.
252LogicalResult StridedLayoutAttr::verifyLayout(
253 ArrayRef<int64_t> shape,
254 function_ref<InFlightDiagnostic()> emitError) const {
255 if (shape.size() != getStrides().size())
256 return emitError() << "expected the number of strides to match the rank";
257
258 return success();
259}
260
261LogicalResult
262StridedLayoutAttr::getStridesAndOffset(ArrayRef<int64_t>,
263 SmallVectorImpl<int64_t> &strides,
264 int64_t &offset) const {
265 llvm::append_range(strides, getStrides());
266 offset = getOffset();
267 return success();
268}
269
270//===----------------------------------------------------------------------===//
271// StringAttr
272//===----------------------------------------------------------------------===//
273
274StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) {
275 return Base::get(context, "", NoneType::get(context));
276}
277
278/// Twine support for StringAttr.
279StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) {
280 // Fast-path empty twine.
281 if (twine.isTriviallyEmpty())
282 return get(context);
283 SmallVector<char, 32> tempStr;
284 return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context));
285}
286
287/// Twine support for StringAttr.
288StringAttr StringAttr::get(const Twine &twine, Type type) {
289 SmallVector<char, 32> tempStr;
290 return Base::get(type.getContext(), twine.toStringRef(tempStr), type);
291}
292
293StringRef StringAttr::getValue() const { return getImpl()->value; }
294
295Type StringAttr::getType() const { return getImpl()->type; }
296
297Dialect *StringAttr::getReferencedDialect() const {
298 return getImpl()->referencedDialect;
299}
300
301//===----------------------------------------------------------------------===//
302// FloatAttr
303//===----------------------------------------------------------------------===//
304
305double FloatAttr::getValueAsDouble() const {
306 return getValueAsDouble(getValue());
307}
308double FloatAttr::getValueAsDouble(APFloat value) {
309 if (&value.getSemantics() != &APFloat::IEEEdouble()) {
310 bool losesInfo = false;
311 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
312 &losesInfo);
313 }
314 return value.convertToDouble();
315}
316
317LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
318 Type type, APFloat value) {
319 // Verify that the type is correct.
320 if (!llvm::isa<FloatType>(type))
321 return emitError() << "expected floating point type";
322
323 // Verify that the type semantics match that of the value.
324 if (&llvm::cast<FloatType>(type).getFloatSemantics() !=
325 &value.getSemantics()) {
326 return emitError()
327 << "FloatAttr type doesn't match the type implied by its value";
328 }
329 return success();
330}
331
332//===----------------------------------------------------------------------===//
333// SymbolRefAttr
334//===----------------------------------------------------------------------===//
335
336SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
337 ArrayRef<FlatSymbolRefAttr> nestedRefs) {
338 return get(StringAttr::get(ctx, value), nestedRefs);
339}
340
341FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
342 return llvm::cast<FlatSymbolRefAttr>(get(ctx, value, {}));
343}
344
345FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) {
346 return llvm::cast<FlatSymbolRefAttr>(get(value, {}));
347}
348
349FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) {
350 auto symName =
351 symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
352 assert(symName && "value does not have a valid symbol name");
353 return SymbolRefAttr::get(symName);
354}
355
356StringAttr SymbolRefAttr::getLeafReference() const {
357 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
358 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
359}
360
361//===----------------------------------------------------------------------===//
362// IntegerAttr
363//===----------------------------------------------------------------------===//
364
365int64_t IntegerAttr::getInt() const {
366 assert((getType().isIndex() || getType().isSignlessInteger()) &&
367 "must be signless integer");
368 return getValue().getSExtValue();
369}
370
371int64_t IntegerAttr::getSInt() const {
372 assert(getType().isSignedInteger() && "must be signed integer");
373 return getValue().getSExtValue();
374}
375
376uint64_t IntegerAttr::getUInt() const {
377 assert(getType().isUnsignedInteger() && "must be unsigned integer");
378 return getValue().getZExtValue();
379}
380
381/// Return the value as an APSInt which carries the signed from the type of
382/// the attribute. This traps on signless integers types!
383APSInt IntegerAttr::getAPSInt() const {
384 assert(!getType().isSignlessInteger() &&
385 "Signless integers don't carry a sign for APSInt");
386 return APSInt(getValue(), getType().isUnsignedInteger());
387}
388
389LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
390 Type type, APInt value) {
391 if (IntegerType integerType = llvm::dyn_cast<IntegerType>(type)) {
392 if (integerType.getWidth() != value.getBitWidth())
393 return emitError() << "integer type bit width (" << integerType.getWidth()
394 << ") doesn't match value bit width ("
395 << value.getBitWidth() << ")";
396 return success();
397 }
398 if (llvm::isa<IndexType>(type)) {
399 if (value.getBitWidth() != IndexType::kInternalStorageBitWidth)
400 return emitError()
401 << "value bit width (" << value.getBitWidth()
402 << ") doesn't match index type internal storage bit width ("
403 << IndexType::kInternalStorageBitWidth << ")";
404 return success();
405 }
406 return emitError() << "expected integer or index type";
407}
408
409BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
410 auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value));
411 return llvm::cast<BoolAttr>(attr);
412}
413
414//===----------------------------------------------------------------------===//
415// BoolAttr
416//===----------------------------------------------------------------------===//
417
418bool BoolAttr::getValue() const {
419 auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
420 return storage->value.getBoolValue();
421}
422
423bool BoolAttr::classof(Attribute attr) {
424 IntegerAttr intAttr = llvm::dyn_cast<IntegerAttr>(attr);
425 return intAttr && intAttr.getType().isSignlessInteger(1);
426}
427
428//===----------------------------------------------------------------------===//
429// OpaqueAttr
430//===----------------------------------------------------------------------===//
431
432LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
433 StringAttr dialect, StringRef attrData,
434 Type type) {
435 if (!Dialect::isValidNamespace(dialect.strref()))
436 return emitError() << "invalid dialect namespace '" << dialect << "'";
437
438 // Check that the dialect is actually registered.
439 MLIRContext *context = dialect.getContext();
440 if (!context->allowsUnregisteredDialects() &&
441 !context->getLoadedDialect(dialect.strref())) {
442 return emitError()
443 << "#" << dialect << "<\"" << attrData << "\"> : " << type
444 << " attribute created with unregistered dialect. If this is "
445 "intended, please call allowUnregisteredDialects() on the "
446 "MLIRContext, or use -allow-unregistered-dialect with "
447 "the MLIR opt tool used";
448 }
449
450 return success();
451}
452
453//===----------------------------------------------------------------------===//
454// DenseElementsAttr Utilities
455//===----------------------------------------------------------------------===//
456
457const char DenseIntOrFPElementsAttrStorage::kSplatTrue = ~0;
458const char DenseIntOrFPElementsAttrStorage::kSplatFalse = 0;
459
/// Returns the number of bits that an element of bit width `origWidth`
/// occupies in a DenseElementsAttr buffer. Single-bit elements are packed.
/// Every other width is rounded up to a multiple of 8 bits (whole bytes).
static size_t getDenseElementStorageWidth(size_t origWidth) {
  if (origWidth == 1)
    return 1;
  constexpr size_t kByteBits = 8;
  return (origWidth + kByteBits - 1) / kByteBits * kByteBits;
}
465static size_t getDenseElementStorageWidth(Type elementType) {
466 return getDenseElementStorageWidth(origWidth: getDenseElementBitWidth(eltType: elementType));
467}
468
/// Sets bit `bitPos` of the packed buffer `rawData` to `value`. Bits are
/// numbered from the least significant bit of byte 0 upward.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  if (!value) {
    byte &= ~mask;
    return;
  }
  byte |= mask;
}
476
/// Returns bit `bitPos` of the packed buffer `rawData`. Bits are numbered from
/// the least significant bit of byte 0 upward.
static bool getBit(const char *rawData, size_t bitPos) {
  const char byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  return (byte & mask) != 0;
}
481
482/// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for
483/// BE format.
484static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
485 char *result) {
486 assert(llvm::endianness::native == llvm::endianness::big);
487 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
488
489 // Copy the words filled with data.
490 // For example, when `value` has 2 words, the first word is filled with data.
491 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
492 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
493 std::copy_n(first: reinterpret_cast<const char *>(value.getRawData()),
494 n: numFilledWords, result: result);
495 // Convert last word of APInt to LE format and store it in char
496 // array(`valueLE`).
497 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------|
498 size_t lastWordPos = numFilledWords;
499 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
500 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
501 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
502 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
503 // Extract actual APInt data from `valueLE`, convert endianness to BE format,
504 // and store it in `result`.
505 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij|
506 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
507 valueLE.begin(), result + lastWordPos,
508 (numBytes - lastWordPos) * CHAR_BIT, 1);
509}
510
511/// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
512/// format.
513static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
514 APInt &result) {
515 assert(llvm::endianness::native == llvm::endianness::big);
516 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
517
518 // Copy the data that fills the word of `result` from `inArray`.
519 // For example, when `result` has 2 words, the first word will be filled with
520 // data. So, the first 8 bytes are copied from `inArray` here.
521 // `inArray` (10 bytes, BE): |abcdefgh|ij|
522 // ==> `result` (2 words, BE): |abcdefgh|--------|
523 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
524 std::copy_n(
525 first: inArray, n: numFilledWords,
526 result: const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));
527
528 // Convert array data which will be last word of `result` to LE format, and
529 // store it in char array(`inArrayLE`).
530 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------|
531 size_t lastWordPos = numFilledWords;
532 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
533 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
534 inArray + lastWordPos, inArrayLE.begin(),
535 (numBytes - lastWordPos) * CHAR_BIT, 1);
536
537 // Convert `inArrayLE` to BE format, and store it in last word of `result`.
538 // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij|
539 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
540 inArrayLE.begin(),
541 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
542 lastWordPos,
543 APInt::APINT_BITS_PER_WORD, 1);
544}
545
546/// Writes value to the bit position `bitPos` in array `rawData`.
547static void writeBits(char *rawData, size_t bitPos, APInt value) {
548 size_t bitWidth = value.getBitWidth();
549
550 // If the bitwidth is 1 we just toggle the specific bit.
551 if (bitWidth == 1)
552 return setBit(rawData, bitPos, value: value.isOne());
553
554 // Otherwise, the bit position is guaranteed to be byte aligned.
555 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
556 if (llvm::endianness::native == llvm::endianness::big) {
557 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
558 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
559 // work correctly in BE format.
560 // ex. `value` (2 words including 10 bytes)
561 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------|
562 copyAPIntToArrayForBEmachine(value, numBytes: llvm::divideCeil(Numerator: bitWidth, CHAR_BIT),
563 result: rawData + (bitPos / CHAR_BIT));
564 } else {
565 std::copy_n(first: reinterpret_cast<const char *>(value.getRawData()),
566 n: llvm::divideCeil(Numerator: bitWidth, CHAR_BIT),
567 result: rawData + (bitPos / CHAR_BIT));
568 }
569}
570
571/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
572/// `rawData`.
573static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
574 // Handle a boolean bit position.
575 if (bitWidth == 1)
576 return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
577
578 // Otherwise, the bit position must be 8-bit aligned.
579 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
580 APInt result(bitWidth, 0);
581 if (llvm::endianness::native == llvm::endianness::big) {
582 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
583 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
584 // work correctly in BE format.
585 // ex. `result` (2 words including 10 bytes)
586 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function
587 copyArrayToAPIntForBEmachine(inArray: rawData + (bitPos / CHAR_BIT),
588 numBytes: llvm::divideCeil(Numerator: bitWidth, CHAR_BIT), result);
589 } else {
590 std::copy_n(first: rawData + (bitPos / CHAR_BIT),
591 n: llvm::divideCeil(Numerator: bitWidth, CHAR_BIT),
592 result: const_cast<char *>(
593 reinterpret_cast<const char *>(result.getRawData())));
594 }
595 return result;
596}
597
598/// Returns true if 'values' corresponds to a splat, i.e. one element, or has
599/// the same element count as 'type'.
600template <typename Values>
601static bool hasSameNumElementsOrSplat(ShapedType type, const Values &values) {
602 return (values.size() == 1) ||
603 (type.getNumElements() == static_cast<int64_t>(values.size()));
604}
605
606//===----------------------------------------------------------------------===//
607// DenseElementsAttr Iterators
608//===----------------------------------------------------------------------===//
609
610//===----------------------------------------------------------------------===//
611// AttributeElementIterator
612//===----------------------------------------------------------------------===//
613
614DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
615 DenseElementsAttr attr, size_t index)
616 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
617 Attribute, Attribute, Attribute>(
618 attr.getAsOpaquePointer(), index) {}
619
620Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
621 auto owner = llvm::cast<DenseElementsAttr>(Val: getFromOpaquePointer(ptr: base));
622 Type eltTy = owner.getElementType();
623 if (llvm::dyn_cast<IntegerType>(eltTy))
624 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
625 if (llvm::isa<IndexType>(eltTy))
626 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
627 if (auto floatEltTy = llvm::dyn_cast<FloatType>(eltTy)) {
628 IntElementIterator intIt(owner, index);
629 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
630 return FloatAttr::get(eltTy, *floatIt);
631 }
632 if (auto complexTy = llvm::dyn_cast<ComplexType>(eltTy)) {
633 auto complexEltTy = complexTy.getElementType();
634 ComplexIntElementIterator complexIntIt(owner, index);
635 if (llvm::isa<IntegerType>(complexEltTy)) {
636 auto value = *complexIntIt;
637 auto real = IntegerAttr::get(complexEltTy, value.real());
638 auto imag = IntegerAttr::get(complexEltTy, value.imag());
639 return ArrayAttr::get(complexTy.getContext(),
640 ArrayRef<Attribute>{real, imag});
641 }
642
643 ComplexFloatElementIterator complexFloatIt(
644 llvm::cast<FloatType>(complexEltTy).getFloatSemantics(), complexIntIt);
645 auto value = *complexFloatIt;
646 auto real = FloatAttr::get(complexEltTy, value.real());
647 auto imag = FloatAttr::get(complexEltTy, value.imag());
648 return ArrayAttr::get(complexTy.getContext(),
649 ArrayRef<Attribute>{real, imag});
650 }
651 if (llvm::isa<DenseStringElementsAttr>(owner)) {
652 ArrayRef<StringRef> vals = owner.getRawStringData();
653 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
654 }
655 llvm_unreachable("unexpected element type");
656}
657
658//===----------------------------------------------------------------------===//
659// BoolElementIterator
660//===----------------------------------------------------------------------===//
661
662DenseElementsAttr::BoolElementIterator::BoolElementIterator(
663 DenseElementsAttr attr, size_t dataIndex)
664 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
665 attr.getRawData().data(), attr.isSplat(), dataIndex) {}
666
667bool DenseElementsAttr::BoolElementIterator::operator*() const {
668 return getBit(rawData: getData(), bitPos: getDataIndex());
669}
670
671//===----------------------------------------------------------------------===//
672// IntElementIterator
673//===----------------------------------------------------------------------===//
674
675DenseElementsAttr::IntElementIterator::IntElementIterator(
676 DenseElementsAttr attr, size_t dataIndex)
677 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
678 attr.getRawData().data(), attr.isSplat(), dataIndex),
679 bitWidth(getDenseElementBitWidth(eltType: attr.getElementType())) {}
680
681APInt DenseElementsAttr::IntElementIterator::operator*() const {
682 return readBits(rawData: getData(),
683 bitPos: getDataIndex() * getDenseElementStorageWidth(origWidth: bitWidth),
684 bitWidth);
685}
686
687//===----------------------------------------------------------------------===//
688// ComplexIntElementIterator
689//===----------------------------------------------------------------------===//
690
691DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
692 DenseElementsAttr attr, size_t dataIndex)
693 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
694 std::complex<APInt>, std::complex<APInt>,
695 std::complex<APInt>>(
696 attr.getRawData().data(), attr.isSplat(), dataIndex) {
697 auto complexType = llvm::cast<ComplexType>(attr.getElementType());
698 bitWidth = getDenseElementBitWidth(complexType.getElementType());
699}
700
701std::complex<APInt>
702DenseElementsAttr::ComplexIntElementIterator::operator*() const {
703 size_t storageWidth = getDenseElementStorageWidth(origWidth: bitWidth);
704 size_t offset = getDataIndex() * storageWidth * 2;
705 return {readBits(rawData: getData(), bitPos: offset, bitWidth),
706 readBits(rawData: getData(), bitPos: offset + storageWidth, bitWidth)};
707}
708
709//===----------------------------------------------------------------------===//
710// DenseArrayAttr
711//===----------------------------------------------------------------------===//
712
713LogicalResult
714DenseArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError,
715 Type elementType, int64_t size, ArrayRef<char> rawData) {
716 if (!elementType.isIntOrIndexOrFloat())
717 return emitError() << "expected integer or floating point element type";
718 int64_t dataSize = rawData.size();
719 int64_t elementSize =
720 llvm::divideCeil(elementType.getIntOrFloatBitWidth(), CHAR_BIT);
721 if (size * elementSize != dataSize) {
722 return emitError() << "expected data size (" << size << " elements, "
723 << elementSize
724 << " bytes each) does not match: " << dataSize
725 << " bytes";
726 }
727 return success();
728}
729
730namespace {
731/// Instantiations of this class provide utilities for interacting with native
732/// data types in the context of DenseArrayAttr.
733template <size_t width,
734 IntegerType::SignednessSemantics signedness = IntegerType::Signless>
735struct DenseArrayAttrIntUtil {
736 static bool checkElementType(Type eltType) {
737 auto type = llvm::dyn_cast<IntegerType>(eltType);
738 if (!type || type.getWidth() != width)
739 return false;
740 return type.getSignedness() == signedness;
741 }
742
743 static Type getElementType(MLIRContext *ctx) {
744 return IntegerType::get(ctx, width, signedness);
745 }
746
747 template <typename T>
748 static void printElement(raw_ostream &os, T value) {
749 os << value;
750 }
751
752 template <typename T>
753 static ParseResult parseElement(AsmParser &parser, T &value) {
754 return parser.parseInteger(value);
755 }
756};
757template <typename T>
758struct DenseArrayAttrUtil;
759
760/// Specialization for boolean elements to print 'true' and 'false' literals for
761/// elements.
762template <>
763struct DenseArrayAttrUtil<bool> : public DenseArrayAttrIntUtil<1> {
764 static void printElement(raw_ostream &os, bool value) {
765 os << (value ? "true" : "false");
766 }
767};
768
769/// Specialization for 8-bit integers to ensure values are printed as integers
770/// and not characters.
771template <>
772struct DenseArrayAttrUtil<int8_t> : public DenseArrayAttrIntUtil<8> {
773 static void printElement(raw_ostream &os, int8_t value) {
774 os << static_cast<int>(value);
775 }
776};
777template <>
778struct DenseArrayAttrUtil<int16_t> : public DenseArrayAttrIntUtil<16> {};
779template <>
780struct DenseArrayAttrUtil<int32_t> : public DenseArrayAttrIntUtil<32> {};
781template <>
782struct DenseArrayAttrUtil<int64_t> : public DenseArrayAttrIntUtil<64> {};
783
784/// Specialization for 32-bit floats.
785template <>
786struct DenseArrayAttrUtil<float> {
787 static bool checkElementType(Type eltType) { return eltType.isF32(); }
788 static Type getElementType(MLIRContext *ctx) { return Float32Type::get(ctx); }
789 static void printElement(raw_ostream &os, float value) { os << value; }
790
791 /// Parse a double and cast it to a float.
792 static ParseResult parseElement(AsmParser &parser, float &value) {
793 double doubleVal;
794 if (parser.parseFloat(result&: doubleVal))
795 return failure();
796 value = doubleVal;
797 return success();
798 }
799};
800
801/// Specialization for 64-bit floats.
802template <>
803struct DenseArrayAttrUtil<double> {
804 static bool checkElementType(Type eltType) { return eltType.isF64(); }
805 static Type getElementType(MLIRContext *ctx) { return Float64Type::get(ctx); }
806 static void printElement(raw_ostream &os, float value) { os << value; }
807 static ParseResult parseElement(AsmParser &parser, double &value) {
808 return parser.parseFloat(result&: value);
809 }
810};
811} // namespace
812
813template <typename T>
814void DenseArrayAttrImpl<T>::print(AsmPrinter &printer) const {
815 print(os&: printer.getStream());
816}
817
818template <typename T>
819void DenseArrayAttrImpl<T>::printWithoutBraces(raw_ostream &os) const {
820 llvm::interleaveComma(asArrayRef(), os, [&](T value) {
821 DenseArrayAttrUtil<T>::printElement(os, value);
822 });
823}
824
825template <typename T>
826void DenseArrayAttrImpl<T>::print(raw_ostream &os) const {
827 os << "[";
828 printWithoutBraces(os);
829 os << "]";
830}
831
832/// Parse a DenseArrayAttr without the braces: `1, 2, 3`
833template <typename T>
834Attribute DenseArrayAttrImpl<T>::parseWithoutBraces(AsmParser &parser,
835 Type odsType) {
836 SmallVector<T> data;
837 if (failed(parser.parseCommaSeparatedList([&]() {
838 T value;
839 if (DenseArrayAttrUtil<T>::parseElement(parser, value))
840 return failure();
841 data.push_back(value);
842 return success();
843 })))
844 return {};
845 return get(context: parser.getContext(), content: data);
846}
847
848/// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`
849template <typename T>
850Attribute DenseArrayAttrImpl<T>::parse(AsmParser &parser, Type odsType) {
851 if (parser.parseLSquare())
852 return {};
853 // Handle empty list case.
854 if (succeeded(Result: parser.parseOptionalRSquare()))
855 return get(context: parser.getContext(), content: {});
856 Attribute result = parseWithoutBraces(parser, odsType);
857 if (parser.parseRSquare())
858 return {};
859 return result;
860}
861
862/// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
863template <typename T>
864DenseArrayAttrImpl<T>::operator ArrayRef<T>() const {
865 ArrayRef<char> raw = getRawData();
866 assert((raw.size() % sizeof(T)) == 0);
867 return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
868 raw.size() / sizeof(T));
869}
870
871/// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
872template <typename T>
873DenseArrayAttrImpl<T> DenseArrayAttrImpl<T>::get(MLIRContext *context,
874 ArrayRef<T> content) {
875 Type elementType = DenseArrayAttrUtil<T>::getElementType(context);
876 auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
877 content.size() * sizeof(T));
878 return llvm::cast<DenseArrayAttrImpl<T>>(
879 Base::get(context, elementType, content.size(), rawArray));
880}
881
882template <typename T>
883bool DenseArrayAttrImpl<T>::classof(Attribute attr) {
884 if (auto denseArray = llvm::dyn_cast<DenseArrayAttr>(attr))
885 return DenseArrayAttrUtil<T>::checkElementType(denseArray.getElementType());
886 return false;
887}
888
namespace mlir {
namespace detail {
// Explicit instantiation for all the supported DenseArrayAttr.
// The templated DenseArrayAttrImpl<T> member functions are defined only in
// this .cpp file, so every supported element type must be instantiated here.
// Each listed type has a DenseArrayAttrUtil specialization in this file.
template class DenseArrayAttrImpl<bool>;
template class DenseArrayAttrImpl<int8_t>;
template class DenseArrayAttrImpl<int16_t>;
template class DenseArrayAttrImpl<int32_t>;
template class DenseArrayAttrImpl<int64_t>;
template class DenseArrayAttrImpl<float>;
template class DenseArrayAttrImpl<double>;
} // namespace detail
} // namespace mlir
901
902//===----------------------------------------------------------------------===//
903// DenseElementsAttr
904//===----------------------------------------------------------------------===//
905
906/// Method for support type inquiry through isa, cast and dyn_cast.
907bool DenseElementsAttr::classof(Attribute attr) {
908 return llvm::isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(attr);
909}
910
/// Build a dense elements attribute of shape `type` from `values`. `values`
/// holds one attribute per element, or a single attribute treated as a splat.
///
/// The element type of `type` determines the attribute kind expected in
/// `values`. Mismatches are diagnosed only by asserts:
///   - complex of int/index: an ArrayAttr of two IntegerAttrs (real, imag);
///   - other complex (treated as complex of float): an ArrayAttr of two
///     FloatAttrs (real, imag);
///   - not int/index/float: a StringAttr. The result is built through the
///     StringRef overload;
///   - int/index/float: an IntegerAttr or FloatAttr whose type equals the
///     element type. Its bits are packed into the raw storage buffer.
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                         ArrayRef<Attribute> values) {
  assert(hasSameNumElementsOrSplat(type, values));

  Type eltType = type.getElementType();

  // Take care complex type case first.
  if (auto complexType = llvm::dyn_cast<ComplexType>(eltType)) {
    if (complexType.getElementType().isIntOrIndex()) {
      SmallVector<std::complex<APInt>> complexValues;
      complexValues.reserve(N: values.size());
      for (Attribute attr : values) {
        assert(llvm::isa<ArrayAttr>(attr) && "expected ArrayAttr for complex");
        auto arrayAttr = llvm::cast<ArrayAttr>(attr);
        assert(arrayAttr.size() == 2 && "expected 2 element for complex");
        // attr0 becomes the real part and attr1 the imaginary part.
        auto attr0 = arrayAttr[0];
        auto attr1 = arrayAttr[1];
        complexValues.push_back(
            std::complex<APInt>(llvm::cast<IntegerAttr>(attr0).getValue(),
                                llvm::cast<IntegerAttr>(attr1).getValue()));
      }
      return DenseElementsAttr::get(type, complexValues);
    }
    // Must be float.
    SmallVector<std::complex<APFloat>> complexValues;
    complexValues.reserve(N: values.size());
    for (Attribute attr : values) {
      assert(llvm::isa<ArrayAttr>(attr) && "expected ArrayAttr for complex");
      auto arrayAttr = llvm::cast<ArrayAttr>(attr);
      assert(arrayAttr.size() == 2 && "expected 2 element for complex");
      // attr0 becomes the real part and attr1 the imaginary part.
      auto attr0 = arrayAttr[0];
      auto attr1 = arrayAttr[1];
      complexValues.push_back(
          std::complex<APFloat>(llvm::cast<FloatAttr>(attr0).getValue(),
                                llvm::cast<FloatAttr>(attr1).getValue()));
    }
    return DenseElementsAttr::get(type, complexValues);
  }

  // If the element type is not based on int/float/index, assume it is a string
  // type.
  if (!eltType.isIntOrIndexOrFloat()) {
    SmallVector<StringRef, 8> stringValues;
    stringValues.reserve(N: values.size());
    for (Attribute attr : values) {
      assert(llvm::isa<StringAttr>(attr) &&
             "expected string value for non integer/index/float element");
      stringValues.push_back(Elt: llvm::cast<StringAttr>(attr).getValue());
    }
    return get(type, stringValues);
  }

  // Otherwise, get the raw storage width to use for the allocation.
  size_t bitWidth = getDenseElementBitWidth(eltType);
  size_t storageBitWidth = getDenseElementStorageWidth(origWidth: bitWidth);

  // Compress the attribute values into a character buffer.
  SmallVector<char, 8> data(
      llvm::divideCeil(Numerator: storageBitWidth * values.size(), CHAR_BIT));
  APInt intVal;
  // NOTE(review): `unsigned` index/bound would truncate values.size() beyond
  // UINT_MAX; consider size_t.
  for (unsigned i = 0, e = values.size(); i < e; ++i) {
    // Floats are stored by their bit pattern, integers by their value.
    if (auto floatAttr = llvm::dyn_cast<FloatAttr>(values[i])) {
      assert(floatAttr.getType() == eltType &&
             "expected float attribute type to equal element type");
      intVal = floatAttr.getValue().bitcastToAPInt();
    } else {
      auto intAttr = llvm::cast<IntegerAttr>(values[i]);
      assert(intAttr.getType() == eltType &&
             "expected integer attribute type to equal element type");
      intVal = intAttr.getValue();
    }

    assert(intVal.getBitWidth() == bitWidth &&
           "expected value to have same bitwidth as element type");
    // Element i occupies bits [i * storageBitWidth, (i + 1) * storageBitWidth).
    writeBits(rawData: data.data(), bitPos: i * storageBitWidth, value: intVal);
  }

  // Handle the special encoding of splat of bool.
  // A single i1 value is stored as a whole byte of all zeros or all ones.
  if (values.size() == 1 && eltType.isInteger(width: 1))
    data[0] = data[0] ? -1 : 0;

  return DenseIntOrFPElementsAttr::getRaw(type, data);
}
994
995DenseElementsAttr DenseElementsAttr::get(ShapedType type,
996 ArrayRef<bool> values) {
997 assert(hasSameNumElementsOrSplat(type, values));
998 assert(type.getElementType().isInteger(1));
999
1000 SmallVector<char> buff(llvm::divideCeil(Numerator: values.size(), CHAR_BIT));
1001
1002 if (!values.empty()) {
1003 bool isSplat = true;
1004 bool firstValue = values[0];
1005 for (int i = 0, e = values.size(); i != e; ++i) {
1006 isSplat &= values[i] == firstValue;
1007 setBit(rawData: buff.data(), bitPos: i, value: values[i]);
1008 }
1009
1010 // Splat of bool is encoded as a byte with all-ones in it.
1011 if (isSplat) {
1012 buff.resize(N: 1);
1013 buff[0] = values[0] ? -1 : 0;
1014 }
1015 }
1016
1017 return DenseIntOrFPElementsAttr::getRaw(type, buff);
1018}
1019
1020DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1021 ArrayRef<StringRef> values) {
1022 assert(!type.getElementType().isIntOrFloat());
1023 return DenseStringElementsAttr::get(type, values);
1024}
1025
1026/// Constructs a dense integer elements attribute from an array of APInt
1027/// values. Each APInt value is expected to have the same bitwidth as the
1028/// element type of 'type'.
1029DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1030 ArrayRef<APInt> values) {
1031 assert(type.getElementType().isIntOrIndex());
1032 assert(hasSameNumElementsOrSplat(type, values));
1033 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
1034 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
1035}
1036DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1037 ArrayRef<std::complex<APInt>> values) {
1038 ComplexType complex = llvm::cast<ComplexType>(type.getElementType());
1039 assert(llvm::isa<IntegerType>(complex.getElementType()));
1040 assert(hasSameNumElementsOrSplat(type, values));
1041 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
1042 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
1043 values.size() * 2);
1044 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals);
1045}
1046
1047// Constructs a dense float elements attribute from an array of APFloat
1048// values. Each APFloat value is expected to have the same bitwidth as the
1049// element type of 'type'.
1050DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1051 ArrayRef<APFloat> values) {
1052 assert(llvm::isa<FloatType>(type.getElementType()));
1053 assert(hasSameNumElementsOrSplat(type, values));
1054 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
1055 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
1056}
1057DenseElementsAttr
1058DenseElementsAttr::get(ShapedType type,
1059 ArrayRef<std::complex<APFloat>> values) {
1060 ComplexType complex = llvm::cast<ComplexType>(type.getElementType());
1061 assert(llvm::isa<FloatType>(complex.getElementType()));
1062 assert(hasSameNumElementsOrSplat(type, values));
1063 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
1064 values.size() * 2);
1065 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
1066 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals);
1067}
1068
1069/// Construct a dense elements attribute from a raw buffer representing the
1070/// data for this attribute. Users should generally not use this methods as
1071/// the expected buffer format may not be a form the user expects.
1072DenseElementsAttr
1073DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef<char> rawBuffer) {
1074 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer);
1075}
1076
/// Returns true if the given buffer is a valid raw buffer for the given type.
///
/// `detectedSplat` is set to true when the buffer holds a single element
/// meant to be broadcast (or when `type` has exactly one element).
/// Accepted layouts:
///   - 1-bit elements: one byte of all zeros / all ones (splat), or
///     numElements bits rounded up to whole bytes (non-splat);
///   - wider elements: exactly one element's storage width (splat), or
///     numElements * storage width bits (non-splat).
bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
                                         ArrayRef<char> rawBuffer,
                                         bool &detectedSplat) {
  size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
  size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
  int64_t numElements = type.getNumElements();

  // The initializer is always a splat if the result type has a single element.
  detectedSplat = numElements == 1;

  // Storage width of 1 is special as it is packed by the bit.
  if (storageWidth == 1) {
    // Check for a splat, or a buffer equal to the number of elements which
    // consists of either all 0's or all 1's.
    if (rawBuffer.size() == 1) {
      auto rawByte = static_cast<uint8_t>(rawBuffer[0]);
      if (rawByte == 0 || rawByte == 0xff) {
        detectedSplat = true;
        return true;
      }
    }

    // This is a valid non-splat buffer if it has the right size.
    // One bit per element, padded up to a whole number of bytes.
    return rawBufferWidth == llvm::alignTo<8>(Value: numElements);
  }

  // All other types are 8-bit aligned, so we can just check the buffer width
  // to know if only a single initializer element was passed in.
  if (rawBufferWidth == storageWidth) {
    detectedSplat = true;
    return true;
  }

  // The raw buffer is valid if it has the right size.
  return rawBufferWidth == storageWidth * numElements;
}
1114
/// Check the information for a C++ data type, check if this type is valid for
/// the current attribute. This method is used to verify specific type
/// invariants that the templatized 'getValues' method cannot.
///
/// \param type        the MLIR element type to check against.
/// \param dataEltSize size in bytes of the C++ element type.
/// \param isInt       true for integer C++ types, false for floating point.
/// \param isSigned    signedness of the C++ integer type (ignored for floats,
///                    and for signless or index element types).
/// \returns false, with an LLVM_DEBUG message, on the first mismatch found.
static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
                              bool isSigned) {
  // Make sure that the data element size is the same as the type element width.
  auto denseEltBitWidth = getDenseElementBitWidth(eltType: type);
  auto dataSize = static_cast<size_t>(dataEltSize * CHAR_BIT);
  if (denseEltBitWidth != dataSize) {
    LLVM_DEBUG(llvm::dbgs() << "expected dense element bit width "
                            << denseEltBitWidth << " to match data size "
                            << dataSize << " for type " << type << "\n");
    return false;
  }

  // Check that the element type is either float or integer or index.
  if (!isInt) {
    bool valid = llvm::isa<FloatType>(Val: type);
    if (!valid)
      LLVM_DEBUG(llvm::dbgs()
                 << "expected float type when isInt is false, but found "
                 << type << "\n");
    return valid;
  }
  // Index has no signedness, so any integer C++ type of matching size works.
  if (type.isIndex())
    return true;

  auto intType = llvm::dyn_cast<IntegerType>(type);
  if (!intType) {
    LLVM_DEBUG(llvm::dbgs()
               << "expected integer type when isInt is true, but found " << type
               << "\n");
    return false;
  }

  // Make sure signedness semantics is consistent.
  // Signless integers accept either signedness of the C++ type.
  if (intType.isSignless())
    return true;

  bool valid = intType.isSigned() == isSigned;
  if (!valid)
    LLVM_DEBUG(llvm::dbgs() << "expected signedness " << isSigned
                            << " to match type " << type << "\n");
  return valid;
}
1160
1161/// Defaults down the subclass implementation.
1162DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
1163 ArrayRef<char> data,
1164 int64_t dataEltSize,
1165 bool isInt, bool isSigned) {
1166 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
1167 isSigned);
1168}
1169DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
1170 ArrayRef<char> data,
1171 int64_t dataEltSize,
1172 bool isInt,
1173 bool isSigned) {
1174 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
1175 isInt, isSigned);
1176}
1177
1178bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
1179 bool isSigned) const {
1180 return ::isValidIntOrFloat(type: getElementType(), dataEltSize, isInt, isSigned);
1181}
1182bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
1183 bool isSigned) const {
1184 return ::isValidIntOrFloat(
1185 llvm::cast<ComplexType>(getElementType()).getElementType(),
1186 dataEltSize / 2, isInt, isSigned);
1187}
1188
1189/// Returns true if this attribute corresponds to a splat, i.e. if all element
1190/// values are the same.
1191bool DenseElementsAttr::isSplat() const {
1192 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
1193}
1194
1195/// Return if the given complex type has an integer element type.
1196static bool isComplexOfIntType(Type type) {
1197 return llvm::isa<IntegerType>(llvm::cast<ComplexType>(type).getElementType());
1198}
1199
1200auto DenseElementsAttr::tryGetComplexIntValues() const
1201 -> FailureOr<iterator_range_impl<ComplexIntElementIterator>> {
1202 if (!isComplexOfIntType(type: getElementType()))
1203 return failure();
1204 return iterator_range_impl<ComplexIntElementIterator>(
1205 getType(), ComplexIntElementIterator(*this, 0),
1206 ComplexIntElementIterator(*this, getNumElements()));
1207}
1208
1209auto DenseElementsAttr::tryGetFloatValues() const
1210 -> FailureOr<iterator_range_impl<FloatElementIterator>> {
1211 auto eltTy = llvm::dyn_cast<FloatType>(getElementType());
1212 if (!eltTy)
1213 return failure();
1214 const auto &elementSemantics = eltTy.getFloatSemantics();
1215 return iterator_range_impl<FloatElementIterator>(
1216 getType(), FloatElementIterator(elementSemantics, raw_int_begin()),
1217 FloatElementIterator(elementSemantics, raw_int_end()));
1218}
1219
1220auto DenseElementsAttr::tryGetComplexFloatValues() const
1221 -> FailureOr<iterator_range_impl<ComplexFloatElementIterator>> {
1222 auto complexTy = llvm::dyn_cast<ComplexType>(getElementType());
1223 if (!complexTy)
1224 return failure();
1225 auto eltTy = llvm::dyn_cast<FloatType>(complexTy.getElementType());
1226 if (!eltTy)
1227 return failure();
1228 const auto &semantics = eltTy.getFloatSemantics();
1229 return iterator_range_impl<ComplexFloatElementIterator>(
1230 getType(), {semantics, {*this, 0}},
1231 {semantics, {*this, static_cast<size_t>(getNumElements())}});
1232}
1233
1234/// Return the raw storage data held by this attribute.
1235ArrayRef<char> DenseElementsAttr::getRawData() const {
1236 return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
1237}
1238
1239ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
1240 return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
1241}
1242
1243/// Return a new DenseElementsAttr that has the same data as the current
1244/// attribute, but has been reshaped to 'newType'. The new type must have the
1245/// same total number of elements as well as element type.
1246DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
1247 ShapedType curType = getType();
1248 if (curType == newType)
1249 return *this;
1250
1251 assert(newType.getElementType() == curType.getElementType() &&
1252 "expected the same element type");
1253 assert(newType.getNumElements() == curType.getNumElements() &&
1254 "expected the same number of elements");
1255 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
1256}
1257
1258DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
1259 assert(isSplat() && "expected a splat type");
1260
1261 ShapedType curType = getType();
1262 if (curType == newType)
1263 return *this;
1264
1265 assert(newType.getElementType() == curType.getElementType() &&
1266 "expected the same element type");
1267 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
1268}
1269
1270/// Return a new DenseElementsAttr that has the same data as the current
1271/// attribute, but has bitcast elements such that it is now 'newType'. The new
1272/// type must have the same shape and element types of the same bitwidth as the
1273/// current type.
1274DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
1275 ShapedType curType = getType();
1276 Type curElType = curType.getElementType();
1277 if (curElType == newElType)
1278 return *this;
1279
1280 assert(getDenseElementBitWidth(newElType) ==
1281 getDenseElementBitWidth(curElType) &&
1282 "expected element types with the same bitwidth");
1283 return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType),
1284 getRawData());
1285}
1286
1287DenseElementsAttr
1288DenseElementsAttr::mapValues(Type newElementType,
1289 function_ref<APInt(const APInt &)> mapping) const {
1290 return llvm::cast<DenseIntElementsAttr>(Val: *this).mapValues(newElementType,
1291 mapping);
1292}
1293
1294DenseElementsAttr DenseElementsAttr::mapValues(
1295 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1296 return llvm::cast<DenseFPElementsAttr>(Val: *this).mapValues(newElementType,
1297 mapping);
1298}
1299
1300ShapedType DenseElementsAttr::getType() const {
1301 return static_cast<const DenseElementsAttributeStorage *>(impl)->type;
1302}
1303
1304Type DenseElementsAttr::getElementType() const {
1305 return getType().getElementType();
1306}
1307
1308int64_t DenseElementsAttr::getNumElements() const {
1309 return getType().getNumElements();
1310}
1311
1312//===----------------------------------------------------------------------===//
1313// DenseIntOrFPElementsAttr
1314//===----------------------------------------------------------------------===//
1315
1316/// Utility method to write a range of APInt values to a buffer.
1317template <typename APRangeT>
1318static void writeAPIntsToBuffer(size_t storageWidth,
1319 SmallVectorImpl<char> &data,
1320 APRangeT &&values) {
1321 size_t numValues = llvm::size(values);
1322 data.resize(N: llvm::divideCeil(Numerator: storageWidth * numValues, CHAR_BIT));
1323 size_t offset = 0;
1324 for (auto it = values.begin(), e = values.end(); it != e;
1325 ++it, offset += storageWidth) {
1326 assert((*it).getBitWidth() <= storageWidth);
1327 writeBits(data.data(), offset, *it);
1328 }
1329
1330 // Handle the special encoding of splat of a boolean.
1331 if (numValues == 1 && (*values.begin()).getBitWidth() == 1)
1332 data[0] = data[0] ? -1 : 0;
1333}
1334
1335/// Constructs a dense elements attribute from an array of raw APFloat values.
1336/// Each APFloat value is expected to have the same bitwidth as the element
1337/// type of 'type'. 'type' must be a vector or tensor with static shape.
1338DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1339 size_t storageWidth,
1340 ArrayRef<APFloat> values) {
1341 SmallVector<char> data;
1342 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1343 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1344 return DenseIntOrFPElementsAttr::getRaw(type, data);
1345}
1346
1347/// Constructs a dense elements attribute from an array of raw APInt values.
1348/// Each APInt value is expected to have the same bitwidth as the element type
1349/// of 'type'.
1350DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1351 size_t storageWidth,
1352 ArrayRef<APInt> values) {
1353 SmallVector<char> data;
1354 writeAPIntsToBuffer(storageWidth, data, values);
1355 return DenseIntOrFPElementsAttr::getRaw(type, data);
1356}
1357
1358DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1359 ArrayRef<char> data) {
1360 assert(type.hasStaticShape() && "type must have static shape");
1361 bool isSplat = false;
1362 bool isValid = isValidRawBuffer(type, data, isSplat);
1363 assert(isValid);
1364 (void)isValid;
1365 return Base::get(type.getContext(), type, data, isSplat);
1366}
1367
1368/// Overload of the raw 'get' method that asserts that the given type is of
1369/// complex type. This method is used to verify type invariants that the
1370/// templatized 'get' method cannot.
1371DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1372 ArrayRef<char> data,
1373 int64_t dataEltSize,
1374 bool isInt,
1375 bool isSigned) {
1376 assert(::isValidIntOrFloat(
1377 llvm::cast<ComplexType>(type.getElementType()).getElementType(),
1378 dataEltSize / 2, isInt, isSigned) &&
1379 "Try re-running with -debug-only=builtinattributes");
1380
1381 int64_t numElements = data.size() / dataEltSize;
1382 (void)numElements;
1383 assert(numElements == 1 || numElements == type.getNumElements());
1384 return getRaw(type, data);
1385}
1386
1387/// Overload of the 'getRaw' method that asserts that the given type is of
1388/// integer type. This method is used to verify type invariants that the
1389/// templatized 'get' method cannot.
1390DenseElementsAttr
1391DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1392 int64_t dataEltSize, bool isInt,
1393 bool isSigned) {
1394 assert(::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt,
1395 isSigned) &&
1396 "Try re-running with -debug-only=builtinattributes");
1397
1398 int64_t numElements = data.size() / dataEltSize;
1399 assert(numElements == 1 || numElements == type.getNumElements());
1400 (void)numElements;
1401 return getRaw(type, data);
1402}
1403
1404void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1405 const char *inRawData, char *outRawData, size_t elementBitWidth,
1406 size_t numElements) {
1407 using llvm::support::ulittle16_t;
1408 using llvm::support::ulittle32_t;
1409 using llvm::support::ulittle64_t;
1410
1411 assert(llvm::endianness::native == llvm::endianness::big);
1412 // NOLINT to avoid warning message about replacing by static_assert()
1413
1414 // Following std::copy_n always converts endianness on BE machine.
1415 switch (elementBitWidth) {
1416 case 16: {
1417 const ulittle16_t *inRawDataPos =
1418 reinterpret_cast<const ulittle16_t *>(inRawData);
1419 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1420 std::copy_n(inRawDataPos, numElements, outDataPos);
1421 break;
1422 }
1423 case 32: {
1424 const ulittle32_t *inRawDataPos =
1425 reinterpret_cast<const ulittle32_t *>(inRawData);
1426 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1427 std::copy_n(inRawDataPos, numElements, outDataPos);
1428 break;
1429 }
1430 case 64: {
1431 const ulittle64_t *inRawDataPos =
1432 reinterpret_cast<const ulittle64_t *>(inRawData);
1433 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1434 std::copy_n(inRawDataPos, numElements, outDataPos);
1435 break;
1436 }
1437 default: {
1438 size_t nBytes = elementBitWidth / CHAR_BIT;
1439 for (size_t i = 0; i < nBytes; i++)
1440 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1441 break;
1442 }
1443 }
1444}
1445
1446void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1447 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1448 ShapedType type) {
1449 size_t numElements = type.getNumElements();
1450 Type elementType = type.getElementType();
1451 if (ComplexType complexTy = llvm::dyn_cast<ComplexType>(elementType)) {
1452 elementType = complexTy.getElementType();
1453 numElements = numElements * 2;
1454 }
1455 size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1456 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1457 inRawData.size() <= outRawData.size());
1458 if (elementBitWidth <= CHAR_BIT)
1459 std::memcpy(outRawData.begin(), inRawData.begin(), inRawData.size());
1460 else
1461 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1462 elementBitWidth, numElements);
1463}
1464
1465//===----------------------------------------------------------------------===//
1466// DenseFPElementsAttr
1467//===----------------------------------------------------------------------===//
1468
/// Apply `mapping` to the elements of `attr` and pack the resulting APInts
/// into `data`, using the storage layout of `newElementType`.
///
/// \param mapping        callable from an element value to an APInt of
///                       `newElementType`'s bit width (asserted per element).
/// \param attr           the source dense attribute. If it is a splat, only
///                       one element is mapped and stored.
/// \param inType         the source shaped type. Its shape is reused.
/// \param newElementType element type of the result.
/// \param data           output buffer, resized to hold the packed result.
/// \returns `inType` with its element type replaced by `newElementType`.
template <typename Fn, typename Attr>
static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
                                Type newElementType,
                                llvm::SmallVectorImpl<char> &data) {
  size_t bitWidth = getDenseElementBitWidth(eltType: newElementType);
  size_t storageBitWidth = getDenseElementStorageWidth(origWidth: bitWidth);

  ShapedType newArrayType = inType.cloneWith(inType.getShape(), newElementType);

  // A splat stores a single element regardless of the shape.
  size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
  data.resize(N: llvm::divideCeil(Numerator: storageBitWidth * numRawElements, CHAR_BIT));

  // Functor used to process a single element value of the attribute.
  auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
    auto newInt = mapping(value);
    assert(newInt.getBitWidth() == bitWidth);
    writeBits(data.data(), index * storageBitWidth, newInt);
  };

  // Check for the splat case.
  if (attr.isSplat()) {
    if (bitWidth == 1) {
      // Handle the special encoding of splat of bool.
      // The whole byte becomes all zeros or all ones.
      data[0] = mapping(*attr.begin()).isZero() ? 0 : -1;
    } else {
      processElt(*attr.begin(), /*index=*/0);
    }
    return newArrayType;
  }

  // Otherwise, process all of the element values.
  uint64_t elementIdx = 0;
  for (auto value : attr)
    processElt(value, elementIdx++);
  return newArrayType;
}
1505
1506DenseElementsAttr DenseFPElementsAttr::mapValues(
1507 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1508 llvm::SmallVector<char, 8> elementData;
1509 auto newArrayType =
1510 mappingHelper(mapping, *this, getType(), newElementType, elementData);
1511
1512 return getRaw(newArrayType, elementData);
1513}
1514
1515/// Method for supporting type inquiry through isa, cast and dyn_cast.
1516bool DenseFPElementsAttr::classof(Attribute attr) {
1517 if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(Val&: attr))
1518 return llvm::isa<FloatType>(denseAttr.getType().getElementType());
1519 return false;
1520}
1521
1522//===----------------------------------------------------------------------===//
1523// DenseIntElementsAttr
1524//===----------------------------------------------------------------------===//
1525
1526DenseElementsAttr DenseIntElementsAttr::mapValues(
1527 Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1528 llvm::SmallVector<char, 8> elementData;
1529 auto newArrayType =
1530 mappingHelper(mapping, *this, getType(), newElementType, elementData);
1531 return getRaw(newArrayType, elementData);
1532}
1533
1534/// Method for supporting type inquiry through isa, cast and dyn_cast.
1535bool DenseIntElementsAttr::classof(Attribute attr) {
1536 if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(Val&: attr))
1537 return denseAttr.getType().getElementType().isIntOrIndex();
1538 return false;
1539}
1540
1541//===----------------------------------------------------------------------===//
1542// DenseResourceElementsAttr
1543//===----------------------------------------------------------------------===//
1544
1545DenseResourceElementsAttr
1546DenseResourceElementsAttr::get(ShapedType type,
1547 DenseResourceElementsHandle handle) {
1548 return Base::get(type.getContext(), type, handle);
1549}
1550
1551DenseResourceElementsAttr DenseResourceElementsAttr::get(ShapedType type,
1552 StringRef blobName,
1553 AsmResourceBlob blob) {
1554 // Extract the builtin dialect resource manager from context and construct a
1555 // handle by inserting a new resource using the provided blob.
1556 auto &manager =
1557 DenseResourceElementsHandle::getManagerInterface(type.getContext());
1558 return get(type, manager.insert(blobName, std::move(blob)));
1559}
1560
1561ArrayRef<char> DenseResourceElementsAttr::getData() {
1562 if (AsmResourceBlob *blob = this->getRawHandle().getBlob())
1563 return blob->getDataAs<char>();
1564 return {};
1565}
1566
1567//===----------------------------------------------------------------------===//
1568// DenseResourceElementsAttrBase
1569//===----------------------------------------------------------------------===//
1570
1571namespace {
1572/// Instantiations of this class provide utilities for interacting with native
1573/// data types in the context of DenseResourceElementsAttr.
1574template <typename T>
1575struct DenseResourceAttrUtil;
1576template <size_t width, bool isSigned>
1577struct DenseResourceElementsAttrIntUtil {
1578 static bool checkElementType(Type eltType) {
1579 IntegerType type = llvm::dyn_cast<IntegerType>(eltType);
1580 if (!type || type.getWidth() != width)
1581 return false;
1582 return isSigned ? !type.isUnsigned() : !type.isSigned();
1583 }
1584};
1585template <>
1586struct DenseResourceAttrUtil<bool> {
1587 static bool checkElementType(Type eltType) {
1588 return eltType.isSignlessInteger(width: 1);
1589 }
1590};
1591template <>
1592struct DenseResourceAttrUtil<int8_t>
1593 : public DenseResourceElementsAttrIntUtil<8, true> {};
1594template <>
1595struct DenseResourceAttrUtil<uint8_t>
1596 : public DenseResourceElementsAttrIntUtil<8, false> {};
1597template <>
1598struct DenseResourceAttrUtil<int16_t>
1599 : public DenseResourceElementsAttrIntUtil<16, true> {};
1600template <>
1601struct DenseResourceAttrUtil<uint16_t>
1602 : public DenseResourceElementsAttrIntUtil<16, false> {};
1603template <>
1604struct DenseResourceAttrUtil<int32_t>
1605 : public DenseResourceElementsAttrIntUtil<32, true> {};
1606template <>
1607struct DenseResourceAttrUtil<uint32_t>
1608 : public DenseResourceElementsAttrIntUtil<32, false> {};
1609template <>
1610struct DenseResourceAttrUtil<int64_t>
1611 : public DenseResourceElementsAttrIntUtil<64, true> {};
1612template <>
1613struct DenseResourceAttrUtil<uint64_t>
1614 : public DenseResourceElementsAttrIntUtil<64, false> {};
1615template <>
1616struct DenseResourceAttrUtil<float> {
1617 static bool checkElementType(Type eltType) { return eltType.isF32(); }
1618};
1619template <>
1620struct DenseResourceAttrUtil<double> {
1621 static bool checkElementType(Type eltType) { return eltType.isF64(); }
1622};
1623} // namespace
1624
1625template <typename T>
1626DenseResourceElementsAttrBase<T>
1627DenseResourceElementsAttrBase<T>::get(ShapedType type, StringRef blobName,
1628 AsmResourceBlob blob) {
1629 // Check that the blob is in the form we were expecting.
1630 assert(blob.getDataAlignment() == alignof(T) &&
1631 "alignment mismatch between expected alignment and blob alignment");
1632 assert(((blob.getData().size() % sizeof(T)) == 0) &&
1633 "size mismatch between expected element width and blob size");
1634 assert(DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) &&
1635 "invalid shape element type for provided type `T`");
1636 return llvm::cast<DenseResourceElementsAttrBase<T>>(
1637 DenseResourceElementsAttr::get(type, blobName, std::move(blob)));
1638}
1639
1640template <typename T>
1641std::optional<ArrayRef<T>>
1642DenseResourceElementsAttrBase<T>::tryGetAsArrayRef() const {
1643 if (AsmResourceBlob *blob = this->getRawHandle().getBlob())
1644 return blob->template getDataAs<T>();
1645 return std::nullopt;
1646}
1647
1648template <typename T>
1649bool DenseResourceElementsAttrBase<T>::classof(Attribute attr) {
1650 auto resourceAttr = llvm::dyn_cast<DenseResourceElementsAttr>(attr);
1651 return resourceAttr && DenseResourceAttrUtil<T>::checkElementType(
1652 resourceAttr.getElementType());
1653}
1654
1655namespace mlir {
1656namespace detail {
1657// Explicit instantiation for all the supported DenseResourceElementsAttr.
1658template class DenseResourceElementsAttrBase<bool>;
1659template class DenseResourceElementsAttrBase<int8_t>;
1660template class DenseResourceElementsAttrBase<int16_t>;
1661template class DenseResourceElementsAttrBase<int32_t>;
1662template class DenseResourceElementsAttrBase<int64_t>;
1663template class DenseResourceElementsAttrBase<uint8_t>;
1664template class DenseResourceElementsAttrBase<uint16_t>;
1665template class DenseResourceElementsAttrBase<uint32_t>;
1666template class DenseResourceElementsAttrBase<uint64_t>;
1667template class DenseResourceElementsAttrBase<float>;
1668template class DenseResourceElementsAttrBase<double>;
1669} // namespace detail
1670} // namespace mlir
1671
1672//===----------------------------------------------------------------------===//
1673// SparseElementsAttr
1674//===----------------------------------------------------------------------===//
1675
1676/// Get a zero APFloat for the given sparse attribute.
1677APFloat SparseElementsAttr::getZeroAPFloat() const {
1678 auto eltType = llvm::cast<FloatType>(getElementType());
1679 return APFloat(eltType.getFloatSemantics());
1680}
1681
1682/// Get a zero APInt for the given sparse attribute.
1683APInt SparseElementsAttr::getZeroAPInt() const {
1684 auto eltType = llvm::cast<IntegerType>(getElementType());
1685 return APInt::getZero(eltType.getWidth());
1686}
1687
1688/// Get a zero attribute for the given attribute type.
1689Attribute SparseElementsAttr::getZeroAttr() const {
1690 auto eltType = getElementType();
1691
1692 // Handle floating point elements.
1693 if (llvm::isa<FloatType>(eltType))
1694 return FloatAttr::get(eltType, 0);
1695
1696 // Handle complex elements.
1697 if (auto complexTy = llvm::dyn_cast<ComplexType>(eltType)) {
1698 auto eltType = complexTy.getElementType();
1699 Attribute zero;
1700 if (llvm::isa<FloatType>(eltType))
1701 zero = FloatAttr::get(eltType, 0);
1702 else // must be integer
1703 zero = IntegerAttr::get(eltType, 0);
1704 return ArrayAttr::get(complexTy.getContext(),
1705 ArrayRef<Attribute>{zero, zero});
1706 }
1707
1708 // Handle string type.
1709 if (llvm::isa<DenseStringElementsAttr>(getValues()))
1710 return StringAttr::get("", eltType);
1711
1712 // Otherwise, this is an integer.
1713 return IntegerAttr::get(eltType, 0);
1714}
1715
1716/// Flatten, and return, all of the sparse indices in this attribute in
1717/// row-major order.
1718SmallVector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1719 SmallVector<ptrdiff_t> flatSparseIndices;
1720
1721 // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1722 // as a 1-D index array.
1723 auto sparseIndices = getIndices();
1724 auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1725 if (sparseIndices.isSplat()) {
1726 SmallVector<uint64_t, 8> indices(getType().getRank(),
1727 *sparseIndexValues.begin());
1728 flatSparseIndices.push_back(getFlattenedIndex(indices));
1729 return flatSparseIndices;
1730 }
1731
1732 // Otherwise, reinterpret each index as an ArrayRef when flattening.
1733 auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1734 size_t rank = getType().getRank();
1735 for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1736 flatSparseIndices.push_back(getFlattenedIndex(
1737 {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1738 return flatSparseIndices;
1739}
1740
1741LogicalResult
1742SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1743 ShapedType type, DenseIntElementsAttr sparseIndices,
1744 DenseElementsAttr values) {
1745 ShapedType valuesType = values.getType();
1746 if (valuesType.getRank() != 1)
1747 return emitError() << "expected 1-d tensor for sparse element values";
1748
1749 // Verify the indices and values shape.
1750 ShapedType indicesType = sparseIndices.getType();
1751 auto emitShapeError = [&]() {
1752 return emitError() << "expected shape ([" << type.getShape()
1753 << "]); inferred shape of indices literal (["
1754 << indicesType.getShape()
1755 << "]); inferred shape of values literal (["
1756 << valuesType.getShape() << "])";
1757 };
1758 // Verify indices shape.
1759 size_t rank = type.getRank(), indicesRank = indicesType.getRank();
1760 if (indicesRank == 2) {
1761 if (indicesType.getDimSize(1) != static_cast<int64_t>(rank))
1762 return emitShapeError();
1763 } else if (indicesRank != 1 || rank != 1) {
1764 return emitShapeError();
1765 }
1766 // Verify the values shape.
1767 int64_t numSparseIndices = indicesType.getDimSize(0);
1768 if (numSparseIndices != valuesType.getDimSize(0))
1769 return emitShapeError();
1770
1771 // Verify that the sparse indices are within the value shape.
1772 auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) {
1773 return emitError()
1774 << "sparse index #" << indexNum
1775 << " is not contained within the value shape, with index=[" << index
1776 << "], and type=" << type;
1777 };
1778
1779 // Handle the case where the index values are a splat.
1780 auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1781 if (sparseIndices.isSplat()) {
1782 SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin());
1783 if (!ElementsAttr::isValidIndex(type, indices))
1784 return emitIndexError(0, indices);
1785 return success();
1786 }
1787
1788 // Otherwise, reinterpret each index as an ArrayRef.
1789 for (size_t i = 0, e = numSparseIndices; i != e; ++i) {
1790 ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank),
1791 rank);
1792 if (!ElementsAttr::isValidIndex(type, index))
1793 return emitIndexError(i, index);
1794 }
1795
1796 return success();
1797}
1798
1799//===----------------------------------------------------------------------===//
1800// DistinctAttr
1801//===----------------------------------------------------------------------===//
1802
1803DistinctAttr DistinctAttr::create(Attribute referencedAttr) {
1804 return Base::get(referencedAttr.getContext(), referencedAttr);
1805}
1806
1807Attribute DistinctAttr::getReferencedAttr() const {
1808 return getImpl()->referencedAttr;
1809}
1810
1811//===----------------------------------------------------------------------===//
1812// Attribute Utilities
1813//===----------------------------------------------------------------------===//
1814
1815AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
1816 int64_t offset,
1817 MLIRContext *context) {
1818 AffineExpr expr;
1819 unsigned nSymbols = 0;
1820
1821 // AffineExpr for offset.
1822 // Static case.
1823 if (!ShapedType::isDynamic(offset)) {
1824 auto cst = getAffineConstantExpr(constant: offset, context);
1825 expr = cst;
1826 } else {
1827 // Dynamic case, new symbol for the offset.
1828 auto sym = getAffineSymbolExpr(position: nSymbols++, context);
1829 expr = sym;
1830 }
1831
1832 // AffineExpr for strides.
1833 for (const auto &en : llvm::enumerate(First&: strides)) {
1834 auto dim = en.index();
1835 auto stride = en.value();
1836 auto d = getAffineDimExpr(position: dim, context);
1837 AffineExpr mult;
1838 // Static case.
1839 if (!ShapedType::isDynamic(stride))
1840 mult = getAffineConstantExpr(constant: stride, context);
1841 else
1842 // Dynamic case, new symbol for each new stride.
1843 mult = getAffineSymbolExpr(position: nSymbols++, context);
1844 expr = expr + d * mult;
1845 }
1846
1847 return AffineMap::get(dimCount: strides.size(), symbolCount: nSymbols, result: expr);
1848}
1849

source code of mlir/lib/IR/BuiltinAttributes.cpp