1//===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/BuiltinAttributes.h"
10#include "AttributeDetail.h"
11#include "mlir/IR/AffineMap.h"
12#include "mlir/IR/BuiltinDialect.h"
13#include "mlir/IR/Dialect.h"
14#include "mlir/IR/DialectResourceBlobManager.h"
15#include "mlir/IR/IntegerSet.h"
16#include "mlir/IR/OpImplementation.h"
17#include "mlir/IR/Operation.h"
18#include "mlir/IR/SymbolTable.h"
19#include "mlir/IR/Types.h"
20#include "llvm/ADT/APSInt.h"
21#include "llvm/Support/Debug.h"
22#include "llvm/Support/Endian.h"
23#include <optional>
24
25#define DEBUG_TYPE "builtinattributes"
26
27using namespace mlir;
28using namespace mlir::detail;
29
30//===----------------------------------------------------------------------===//
31/// Tablegen Attribute Definitions
32//===----------------------------------------------------------------------===//
33
34#define GET_ATTRDEF_CLASSES
35#include "mlir/IR/BuiltinAttributes.cpp.inc"
36
37//===----------------------------------------------------------------------===//
38// BuiltinDialect
39//===----------------------------------------------------------------------===//
40
/// Registers the builtin attributes with the builtin dialect: first every
/// attribute class listed by the tablegen-generated GET_ATTRDEF_LIST, then
/// DistinctAttr through a separate addAttributes call.
void BuiltinDialect::registerAttributes() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/IR/BuiltinAttributes.cpp.inc"
      >();
  addAttributes<DistinctAttr>();
}
48
49//===----------------------------------------------------------------------===//
50// DictionaryAttr
51//===----------------------------------------------------------------------===//
52
53/// Helper function that does either an in place sort or sorts from source array
54/// into destination. If inPlace then storage is both the source and the
55/// destination, else value is the source and storage destination. Returns
56/// whether source was sorted.
57template <bool inPlace>
58static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
59 SmallVectorImpl<NamedAttribute> &storage) {
60 // Specialize for the common case.
61 switch (value.size()) {
62 case 0:
63 // Zero already sorted.
64 if (!inPlace)
65 storage.clear();
66 break;
67 case 1:
68 // One already sorted but may need to be copied.
69 if (!inPlace)
70 storage.assign(IL: {value[0]});
71 break;
72 case 2: {
73 bool isSorted = value[0] < value[1];
74 if (inPlace) {
75 if (!isSorted)
76 std::swap(a&: storage[0], b&: storage[1]);
77 } else if (isSorted) {
78 storage.assign(IL: {value[0], value[1]});
79 } else {
80 storage.assign(IL: {value[1], value[0]});
81 }
82 return !isSorted;
83 }
84 default:
85 if (!inPlace)
86 storage.assign(in_start: value.begin(), in_end: value.end());
87 // Check to see they are sorted already.
88 bool isSorted = llvm::is_sorted(Range&: value);
89 // If not, do a general sort.
90 if (!isSorted)
91 llvm::array_pod_sort(Start: storage.begin(), End: storage.end());
92 return !isSorted;
93 }
94 return false;
95}
96
97/// Returns an entry with a duplicate name from the given sorted array of named
98/// attributes. Returns std::nullopt if all elements have unique names.
99static std::optional<NamedAttribute>
100findDuplicateElement(ArrayRef<NamedAttribute> value) {
101 const std::optional<NamedAttribute> none{std::nullopt};
102 if (value.size() < 2)
103 return none;
104
105 if (value.size() == 2)
106 return value[0].getName() == value[1].getName() ? value[0] : none;
107
108 const auto *it = std::adjacent_find(first: value.begin(), last: value.end(),
109 binary_pred: [](NamedAttribute l, NamedAttribute r) {
110 return l.getName() == r.getName();
111 });
112 return it != value.end() ? *it : none;
113}
114
115bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
116 SmallVectorImpl<NamedAttribute> &storage) {
117 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
118 assert(!findDuplicateElement(storage) &&
119 "DictionaryAttr element names must be unique");
120 return isSorted;
121}
122
123bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
124 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(value: array, storage&: array);
125 assert(!findDuplicateElement(array) &&
126 "DictionaryAttr element names must be unique");
127 return isSorted;
128}
129
130std::optional<NamedAttribute>
131DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
132 bool isSorted) {
133 if (!isSorted)
134 dictionaryAttrSort</*inPlace=*/true>(value: array, storage&: array);
135 return findDuplicateElement(value: array);
136}
137
138DictionaryAttr DictionaryAttr::get(MLIRContext *context,
139 ArrayRef<NamedAttribute> value) {
140 if (value.empty())
141 return DictionaryAttr::getEmpty(context);
142
143 // We need to sort the element list to canonicalize it.
144 SmallVector<NamedAttribute, 8> storage;
145 if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
146 value = storage;
147 assert(!findDuplicateElement(value) &&
148 "DictionaryAttr element names must be unique");
149 return Base::get(ctx: context, args&: value);
150}
151/// Construct a dictionary with an array of values that is known to already be
152/// sorted by name and uniqued.
153DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
154 ArrayRef<NamedAttribute> value) {
155 if (value.empty())
156 return DictionaryAttr::getEmpty(context);
157 // Ensure that the attribute elements are unique and sorted.
158 assert(llvm::is_sorted(
159 value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) &&
160 "expected attribute values to be sorted");
161 assert(!findDuplicateElement(value) &&
162 "DictionaryAttr element names must be unique");
163 return Base::get(ctx: context, args&: value);
164}
165
166/// Return the specified attribute if present, null otherwise.
167Attribute DictionaryAttr::get(StringRef name) const {
168 auto it = impl::findAttrSorted(first: begin(), last: end(), name);
169 return it.second ? it.first->getValue() : Attribute();
170}
171Attribute DictionaryAttr::get(StringAttr name) const {
172 auto it = impl::findAttrSorted(first: begin(), last: end(), name);
173 return it.second ? it.first->getValue() : Attribute();
174}
175
176/// Return the specified named attribute if present, std::nullopt otherwise.
177std::optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
178 auto it = impl::findAttrSorted(first: begin(), last: end(), name);
179 return it.second ? *it.first : std::optional<NamedAttribute>();
180}
181std::optional<NamedAttribute> DictionaryAttr::getNamed(StringAttr name) const {
182 auto it = impl::findAttrSorted(first: begin(), last: end(), name);
183 return it.second ? *it.first : std::optional<NamedAttribute>();
184}
185
186/// Return whether the specified attribute is present.
187bool DictionaryAttr::contains(StringRef name) const {
188 return impl::findAttrSorted(first: begin(), last: end(), name).second;
189}
190bool DictionaryAttr::contains(StringAttr name) const {
191 return impl::findAttrSorted(first: begin(), last: end(), name).second;
192}
193
194DictionaryAttr::iterator DictionaryAttr::begin() const {
195 return getValue().begin();
196}
197DictionaryAttr::iterator DictionaryAttr::end() const {
198 return getValue().end();
199}
200size_t DictionaryAttr::size() const { return getValue().size(); }
201
202DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
203 return Base::get(ctx: context, args: ArrayRef<NamedAttribute>());
204}
205
206//===----------------------------------------------------------------------===//
207// StridedLayoutAttr
208//===----------------------------------------------------------------------===//
209
210/// Prints a strided layout attribute.
211void StridedLayoutAttr::print(llvm::raw_ostream &os) const {
212 auto printIntOrQuestion = [&](int64_t value) {
213 if (ShapedType::isDynamic(dValue: value))
214 os << "?";
215 else
216 os << value;
217 };
218
219 os << "strided<[";
220 llvm::interleaveComma(c: getStrides(), os, each_fn: printIntOrQuestion);
221 os << "]";
222
223 if (getOffset() != 0) {
224 os << ", offset: ";
225 printIntOrQuestion(getOffset());
226 }
227 os << ">";
228}
229
230/// Returns true if this layout is static, i.e. the strides and offset all have
231/// a known value > 0.
232bool StridedLayoutAttr::hasStaticLayout() const {
233 return ShapedType::isStatic(dValue: getOffset()) &&
234 ShapedType::isStaticShape(dSizes: getStrides());
235}
236
237/// Returns the strided layout as an affine map.
238AffineMap StridedLayoutAttr::getAffineMap() const {
239 return makeStridedLinearLayoutMap(strides: getStrides(), offset: getOffset(), context: getContext());
240}
241
242/// Checks that the type-agnostic strided layout invariants are satisfied.
243LogicalResult
244StridedLayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
245 int64_t offset, ArrayRef<int64_t> strides) {
246 return success();
247}
248
249/// Checks that the type-specific strided layout invariants are satisfied.
250LogicalResult StridedLayoutAttr::verifyLayout(
251 ArrayRef<int64_t> shape,
252 function_ref<InFlightDiagnostic()> emitError) const {
253 if (shape.size() != getStrides().size())
254 return emitError() << "expected the number of strides to match the rank";
255
256 return success();
257}
258
259LogicalResult
260StridedLayoutAttr::getStridesAndOffset(ArrayRef<int64_t>,
261 SmallVectorImpl<int64_t> &strides,
262 int64_t &offset) const {
263 llvm::append_range(C&: strides, R: getStrides());
264 offset = getOffset();
265 return success();
266}
267
268//===----------------------------------------------------------------------===//
269// StringAttr
270//===----------------------------------------------------------------------===//
271
272StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) {
273 return Base::get(ctx: context, args: "", args: NoneType::get(context));
274}
275
276/// Twine support for StringAttr.
277StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) {
278 // Fast-path empty twine.
279 if (twine.isTriviallyEmpty())
280 return get(context);
281 SmallVector<char, 32> tempStr;
282 return Base::get(ctx: context, args: twine.toStringRef(Out&: tempStr), args: NoneType::get(context));
283}
284
285/// Twine support for StringAttr.
286StringAttr StringAttr::get(const Twine &twine, Type type) {
287 SmallVector<char, 32> tempStr;
288 return Base::get(ctx: type.getContext(), args: twine.toStringRef(Out&: tempStr), args&: type);
289}
290
291StringRef StringAttr::getValue() const { return getImpl()->value; }
292
293Type StringAttr::getType() const { return getImpl()->type; }
294
295Dialect *StringAttr::getReferencedDialect() const {
296 return getImpl()->referencedDialect;
297}
298
299//===----------------------------------------------------------------------===//
300// FloatAttr
301//===----------------------------------------------------------------------===//
302
303double FloatAttr::getValueAsDouble() const {
304 return getValueAsDouble(val: getValue());
305}
306double FloatAttr::getValueAsDouble(APFloat value) {
307 if (&value.getSemantics() != &APFloat::IEEEdouble()) {
308 bool losesInfo = false;
309 value.convert(ToSemantics: APFloat::IEEEdouble(), RM: APFloat::rmNearestTiesToEven,
310 losesInfo: &losesInfo);
311 }
312 return value.convertToDouble();
313}
314
315LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
316 Type type, APFloat value) {
317 // Verify that the type is correct.
318 if (!llvm::isa<FloatType>(Val: type))
319 return emitError() << "expected floating point type";
320
321 // Verify that the type semantics match that of the value.
322 if (&llvm::cast<FloatType>(Val&: type).getFloatSemantics() !=
323 &value.getSemantics()) {
324 return emitError()
325 << "FloatAttr type doesn't match the type implied by its value";
326 }
327 return success();
328}
329
330//===----------------------------------------------------------------------===//
331// SymbolRefAttr
332//===----------------------------------------------------------------------===//
333
334SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
335 ArrayRef<FlatSymbolRefAttr> nestedRefs) {
336 return get(rootReference: StringAttr::get(context: ctx, twine: value), nestedReferences: nestedRefs);
337}
338
339FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
340 return llvm::cast<FlatSymbolRefAttr>(Val: get(ctx, value, nestedRefs: {}));
341}
342
343FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) {
344 return llvm::cast<FlatSymbolRefAttr>(Val: get(rootReference: value, nestedReferences: {}));
345}
346
347FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) {
348 auto symName =
349 symbol->getAttrOfType<StringAttr>(name: SymbolTable::getSymbolAttrName());
350 assert(symName && "value does not have a valid symbol name");
351 return SymbolRefAttr::get(value: symName);
352}
353
354StringAttr SymbolRefAttr::getLeafReference() const {
355 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
356 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
357}
358
359//===----------------------------------------------------------------------===//
360// IntegerAttr
361//===----------------------------------------------------------------------===//
362
363int64_t IntegerAttr::getInt() const {
364 assert((getType().isIndex() || getType().isSignlessInteger()) &&
365 "must be signless integer");
366 return getValue().getSExtValue();
367}
368
369int64_t IntegerAttr::getSInt() const {
370 assert(getType().isSignedInteger() && "must be signed integer");
371 return getValue().getSExtValue();
372}
373
374uint64_t IntegerAttr::getUInt() const {
375 assert(getType().isUnsignedInteger() && "must be unsigned integer");
376 return getValue().getZExtValue();
377}
378
379/// Return the value as an APSInt which carries the signed from the type of
380/// the attribute. This traps on signless integers types!
381APSInt IntegerAttr::getAPSInt() const {
382 assert(!getType().isSignlessInteger() &&
383 "Signless integers don't carry a sign for APSInt");
384 return APSInt(getValue(), getType().isUnsignedInteger());
385}
386
387LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
388 Type type, APInt value) {
389 if (IntegerType integerType = llvm::dyn_cast<IntegerType>(Val&: type)) {
390 if (integerType.getWidth() != value.getBitWidth())
391 return emitError() << "integer type bit width (" << integerType.getWidth()
392 << ") doesn't match value bit width ("
393 << value.getBitWidth() << ")";
394 return success();
395 }
396 if (llvm::isa<IndexType>(Val: type)) {
397 if (value.getBitWidth() != IndexType::kInternalStorageBitWidth)
398 return emitError()
399 << "value bit width (" << value.getBitWidth()
400 << ") doesn't match index type internal storage bit width ("
401 << IndexType::kInternalStorageBitWidth << ")";
402 return success();
403 }
404 return emitError() << "expected integer or index type";
405}
406
407BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
408 auto attr = Base::get(ctx: type.getContext(), args&: type, args: APInt(/*numBits=*/1, value));
409 return llvm::cast<BoolAttr>(Val&: attr);
410}
411
412//===----------------------------------------------------------------------===//
413// BoolAttr
414//===----------------------------------------------------------------------===//
415
416bool BoolAttr::getValue() const {
417 auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
418 return storage->value.getBoolValue();
419}
420
421bool BoolAttr::classof(Attribute attr) {
422 IntegerAttr intAttr = llvm::dyn_cast<IntegerAttr>(Val&: attr);
423 return intAttr && intAttr.getType().isSignlessInteger(width: 1);
424}
425
426//===----------------------------------------------------------------------===//
427// OpaqueAttr
428//===----------------------------------------------------------------------===//
429
430LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
431 StringAttr dialect, StringRef attrData,
432 Type type) {
433 if (!Dialect::isValidNamespace(str: dialect.strref()))
434 return emitError() << "invalid dialect namespace '" << dialect << "'";
435
436 // Check that the dialect is actually registered.
437 MLIRContext *context = dialect.getContext();
438 if (!context->allowsUnregisteredDialects() &&
439 !context->getLoadedDialect(name: dialect.strref())) {
440 return emitError()
441 << "#" << dialect << "<\"" << attrData << "\"> : " << type
442 << " attribute created with unregistered dialect. If this is "
443 "intended, please call allowUnregisteredDialects() on the "
444 "MLIRContext, or use -allow-unregistered-dialect with "
445 "the MLIR opt tool used";
446 }
447
448 return success();
449}
450
451//===----------------------------------------------------------------------===//
452// DenseElementsAttr Utilities
453//===----------------------------------------------------------------------===//
454
// Byte constants with all bits set (`kSplatTrue`) and all bits clear
// (`kSplatFalse`). NOTE(review): presumably the raw byte used for a splatted
// boolean value, judging by the names; confirm against
// DenseIntOrFPElementsAttrStorage.
const char DenseIntOrFPElementsAttrStorage::kSplatTrue = ~0;
const char DenseIntOrFPElementsAttrStorage::kSplatFalse = 0;
457
458/// Get the bitwidth of a dense element type within the buffer.
459/// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
460static size_t getDenseElementStorageWidth(size_t origWidth) {
461 return origWidth == 1 ? origWidth : llvm::alignTo<8>(Value: origWidth);
462}
463static size_t getDenseElementStorageWidth(Type elementType) {
464 return getDenseElementStorageWidth(origWidth: getDenseElementBitWidth(eltType: elementType));
465}
466
/// Sets (when `value` is true) or clears (when false) bit `bitPos` of the
/// bit-packed buffer `rawData`, leaving all other bits unchanged. Bit 0 is the
/// least significant bit of byte 0.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  if (!value) {
    byte &= ~mask;
    return;
  }
  byte |= mask;
}
474
/// Returns bit `bitPos` of the bit-packed buffer `rawData`. Bit 0 is the least
/// significant bit of byte 0.
static bool getBit(const char *rawData, size_t bitPos) {
  const unsigned char byte =
      static_cast<unsigned char>(rawData[bitPos / CHAR_BIT]);
  return ((byte >> (bitPos % CHAR_BIT)) & 1u) != 0;
}
479
480/// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for
481/// BE format.
482static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
483 char *result) {
484 assert(llvm::endianness::native == llvm::endianness::big);
485 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
486
487 // Copy the words filled with data.
488 // For example, when `value` has 2 words, the first word is filled with data.
489 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
490 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
491 std::copy_n(first: reinterpret_cast<const char *>(value.getRawData()),
492 n: numFilledWords, result: result);
493 // Convert last word of APInt to LE format and store it in char
494 // array(`valueLE`).
495 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------|
496 size_t lastWordPos = numFilledWords;
497 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
498 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
499 inRawData: reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
500 outRawData: valueLE.begin(), elementBitWidth: APInt::APINT_BITS_PER_WORD, numElements: 1);
501 // Extract actual APInt data from `valueLE`, convert endianness to BE format,
502 // and store it in `result`.
503 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij|
504 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
505 inRawData: valueLE.begin(), outRawData: result + lastWordPos,
506 elementBitWidth: (numBytes - lastWordPos) * CHAR_BIT, numElements: 1);
507}
508
509/// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
510/// format.
511static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
512 APInt &result) {
513 assert(llvm::endianness::native == llvm::endianness::big);
514 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);
515
516 // Copy the data that fills the word of `result` from `inArray`.
517 // For example, when `result` has 2 words, the first word will be filled with
518 // data. So, the first 8 bytes are copied from `inArray` here.
519 // `inArray` (10 bytes, BE): |abcdefgh|ij|
520 // ==> `result` (2 words, BE): |abcdefgh|--------|
521 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
522 std::copy_n(
523 first: inArray, n: numFilledWords,
524 result: const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));
525
526 // Convert array data which will be last word of `result` to LE format, and
527 // store it in char array(`inArrayLE`).
528 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------|
529 size_t lastWordPos = numFilledWords;
530 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
531 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
532 inRawData: inArray + lastWordPos, outRawData: inArrayLE.begin(),
533 elementBitWidth: (numBytes - lastWordPos) * CHAR_BIT, numElements: 1);
534
535 // Convert `inArrayLE` to BE format, and store it in last word of `result`.
536 // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij|
537 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
538 inRawData: inArrayLE.begin(),
539 outRawData: const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
540 lastWordPos,
541 elementBitWidth: APInt::APINT_BITS_PER_WORD, numElements: 1);
542}
543
544/// Writes value to the bit position `bitPos` in array `rawData`.
545static void writeBits(char *rawData, size_t bitPos, APInt value) {
546 size_t bitWidth = value.getBitWidth();
547
548 // If the bitwidth is 1 we just toggle the specific bit.
549 if (bitWidth == 1)
550 return setBit(rawData, bitPos, value: value.isOne());
551
552 // Otherwise, the bit position is guaranteed to be byte aligned.
553 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
554 if (llvm::endianness::native == llvm::endianness::big) {
555 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
556 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
557 // work correctly in BE format.
558 // ex. `value` (2 words including 10 bytes)
559 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------|
560 copyAPIntToArrayForBEmachine(value, numBytes: llvm::divideCeil(Numerator: bitWidth, CHAR_BIT),
561 result: rawData + (bitPos / CHAR_BIT));
562 } else {
563 std::copy_n(first: reinterpret_cast<const char *>(value.getRawData()),
564 n: llvm::divideCeil(Numerator: bitWidth, CHAR_BIT),
565 result: rawData + (bitPos / CHAR_BIT));
566 }
567}
568
569/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
570/// `rawData`.
571static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
572 // Handle a boolean bit position.
573 if (bitWidth == 1)
574 return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
575
576 // Otherwise, the bit position must be 8-bit aligned.
577 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned");
578 APInt result(bitWidth, 0);
579 if (llvm::endianness::native == llvm::endianness::big) {
580 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
581 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
582 // work correctly in BE format.
583 // ex. `result` (2 words including 10 bytes)
584 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function
585 copyArrayToAPIntForBEmachine(inArray: rawData + (bitPos / CHAR_BIT),
586 numBytes: llvm::divideCeil(Numerator: bitWidth, CHAR_BIT), result);
587 } else {
588 std::copy_n(first: rawData + (bitPos / CHAR_BIT),
589 n: llvm::divideCeil(Numerator: bitWidth, CHAR_BIT),
590 result: const_cast<char *>(
591 reinterpret_cast<const char *>(result.getRawData())));
592 }
593 return result;
594}
595
596/// Returns true if 'values' corresponds to a splat, i.e. one element, or has
597/// the same element count as 'type'.
598template <typename Values>
599static bool hasSameNumElementsOrSplat(ShapedType type, const Values &values) {
600 return (values.size() == 1) ||
601 (type.getNumElements() == static_cast<int64_t>(values.size()));
602}
603
604//===----------------------------------------------------------------------===//
605// DenseElementsAttr Iterators
606//===----------------------------------------------------------------------===//
607
608//===----------------------------------------------------------------------===//
609// AttributeElementIterator
610//===----------------------------------------------------------------------===//
611
612DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
613 DenseElementsAttr attr, size_t index)
614 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
615 Attribute, Attribute, Attribute>(
616 attr.getAsOpaquePointer(), index) {}
617
618Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
619 auto owner = llvm::cast<DenseElementsAttr>(Val: getFromOpaquePointer(ptr: base));
620 Type eltTy = owner.getElementType();
621 if (llvm::dyn_cast<IntegerType>(Val&: eltTy))
622 return IntegerAttr::get(type: eltTy, value: *IntElementIterator(owner, index));
623 if (llvm::isa<IndexType>(Val: eltTy))
624 return IntegerAttr::get(type: eltTy, value: *IntElementIterator(owner, index));
625 if (auto floatEltTy = llvm::dyn_cast<FloatType>(Val&: eltTy)) {
626 IntElementIterator intIt(owner, index);
627 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
628 return FloatAttr::get(type: eltTy, value: *floatIt);
629 }
630 if (auto complexTy = llvm::dyn_cast<ComplexType>(Val&: eltTy)) {
631 auto complexEltTy = complexTy.getElementType();
632 ComplexIntElementIterator complexIntIt(owner, index);
633 if (llvm::isa<IntegerType>(Val: complexEltTy)) {
634 auto value = *complexIntIt;
635 auto real = IntegerAttr::get(type: complexEltTy, value: value.real());
636 auto imag = IntegerAttr::get(type: complexEltTy, value: value.imag());
637 return ArrayAttr::get(context: complexTy.getContext(),
638 value: ArrayRef<Attribute>{real, imag});
639 }
640
641 ComplexFloatElementIterator complexFloatIt(
642 llvm::cast<FloatType>(Val&: complexEltTy).getFloatSemantics(), complexIntIt);
643 auto value = *complexFloatIt;
644 auto real = FloatAttr::get(type: complexEltTy, value: value.real());
645 auto imag = FloatAttr::get(type: complexEltTy, value: value.imag());
646 return ArrayAttr::get(context: complexTy.getContext(),
647 value: ArrayRef<Attribute>{real, imag});
648 }
649 if (llvm::isa<DenseStringElementsAttr>(Val: owner)) {
650 ArrayRef<StringRef> vals = owner.getRawStringData();
651 return StringAttr::get(twine: owner.isSplat() ? vals.front() : vals[index], type: eltTy);
652 }
653 llvm_unreachable("unexpected element type");
654}
655
656//===----------------------------------------------------------------------===//
657// BoolElementIterator
658//===----------------------------------------------------------------------===//
659
660DenseElementsAttr::BoolElementIterator::BoolElementIterator(
661 DenseElementsAttr attr, size_t dataIndex)
662 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
663 attr.getRawData().data(), attr.isSplat(), dataIndex) {}
664
665bool DenseElementsAttr::BoolElementIterator::operator*() const {
666 return getBit(rawData: getData(), bitPos: getDataIndex());
667}
668
669//===----------------------------------------------------------------------===//
670// IntElementIterator
671//===----------------------------------------------------------------------===//
672
673DenseElementsAttr::IntElementIterator::IntElementIterator(
674 DenseElementsAttr attr, size_t dataIndex)
675 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
676 attr.getRawData().data(), attr.isSplat(), dataIndex),
677 bitWidth(getDenseElementBitWidth(eltType: attr.getElementType())) {}
678
679APInt DenseElementsAttr::IntElementIterator::operator*() const {
680 return readBits(rawData: getData(),
681 bitPos: getDataIndex() * getDenseElementStorageWidth(origWidth: bitWidth),
682 bitWidth);
683}
684
685//===----------------------------------------------------------------------===//
686// ComplexIntElementIterator
687//===----------------------------------------------------------------------===//
688
689DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
690 DenseElementsAttr attr, size_t dataIndex)
691 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
692 std::complex<APInt>, std::complex<APInt>,
693 std::complex<APInt>>(
694 attr.getRawData().data(), attr.isSplat(), dataIndex) {
695 auto complexType = llvm::cast<ComplexType>(Val: attr.getElementType());
696 bitWidth = getDenseElementBitWidth(eltType: complexType.getElementType());
697}
698
699std::complex<APInt>
700DenseElementsAttr::ComplexIntElementIterator::operator*() const {
701 size_t storageWidth = getDenseElementStorageWidth(origWidth: bitWidth);
702 size_t offset = getDataIndex() * storageWidth * 2;
703 return {readBits(rawData: getData(), bitPos: offset, bitWidth),
704 readBits(rawData: getData(), bitPos: offset + storageWidth, bitWidth)};
705}
706
707//===----------------------------------------------------------------------===//
708// DenseArrayAttr
709//===----------------------------------------------------------------------===//
710
711LogicalResult
712DenseArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError,
713 Type elementType, int64_t size, ArrayRef<char> rawData) {
714 if (!elementType.isIntOrIndexOrFloat())
715 return emitError() << "expected integer or floating point element type";
716 int64_t dataSize = rawData.size();
717 int64_t elementSize =
718 llvm::divideCeil(Numerator: elementType.getIntOrFloatBitWidth(), CHAR_BIT);
719 if (size * elementSize != dataSize) {
720 return emitError() << "expected data size (" << size << " elements, "
721 << elementSize
722 << " bytes each) does not match: " << dataSize
723 << " bytes";
724 }
725 return success();
726}
727
728namespace {
729/// Instantiations of this class provide utilities for interacting with native
730/// data types in the context of DenseArrayAttr.
731template <size_t width,
732 IntegerType::SignednessSemantics signedness = IntegerType::Signless>
733struct DenseArrayAttrIntUtil {
734 static bool checkElementType(Type eltType) {
735 auto type = llvm::dyn_cast<IntegerType>(Val&: eltType);
736 if (!type || type.getWidth() != width)
737 return false;
738 return type.getSignedness() == signedness;
739 }
740
741 static Type getElementType(MLIRContext *ctx) {
742 return IntegerType::get(context: ctx, width, signedness);
743 }
744
745 template <typename T>
746 static void printElement(raw_ostream &os, T value) {
747 os << value;
748 }
749
750 template <typename T>
751 static ParseResult parseElement(AsmParser &parser, T &value) {
752 return parser.parseInteger(value);
753 }
754};
755template <typename T>
756struct DenseArrayAttrUtil;
757
758/// Specialization for boolean elements to print 'true' and 'false' literals for
759/// elements.
760template <>
761struct DenseArrayAttrUtil<bool> : public DenseArrayAttrIntUtil<1> {
762 static void printElement(raw_ostream &os, bool value) {
763 os << (value ? "true" : "false");
764 }
765};
766
767/// Specialization for 8-bit integers to ensure values are printed as integers
768/// and not characters.
769template <>
770struct DenseArrayAttrUtil<int8_t> : public DenseArrayAttrIntUtil<8> {
771 static void printElement(raw_ostream &os, int8_t value) {
772 os << static_cast<int>(value);
773 }
774};
775template <>
776struct DenseArrayAttrUtil<int16_t> : public DenseArrayAttrIntUtil<16> {};
777template <>
778struct DenseArrayAttrUtil<int32_t> : public DenseArrayAttrIntUtil<32> {};
779template <>
780struct DenseArrayAttrUtil<int64_t> : public DenseArrayAttrIntUtil<64> {};
781
782/// Specialization for 32-bit floats.
783template <>
784struct DenseArrayAttrUtil<float> {
785 static bool checkElementType(Type eltType) { return eltType.isF32(); }
786 static Type getElementType(MLIRContext *ctx) { return Float32Type::get(context: ctx); }
787 static void printElement(raw_ostream &os, float value) { os << value; }
788
789 /// Parse a double and cast it to a float.
790 static ParseResult parseElement(AsmParser &parser, float &value) {
791 double doubleVal;
792 if (parser.parseFloat(result&: doubleVal))
793 return failure();
794 value = doubleVal;
795 return success();
796 }
797};
798
799/// Specialization for 64-bit floats.
800template <>
801struct DenseArrayAttrUtil<double> {
802 static bool checkElementType(Type eltType) { return eltType.isF64(); }
803 static Type getElementType(MLIRContext *ctx) { return Float64Type::get(context: ctx); }
804 static void printElement(raw_ostream &os, float value) { os << value; }
805 static ParseResult parseElement(AsmParser &parser, double &value) {
806 return parser.parseFloat(result&: value);
807 }
808};
809} // namespace
810
811template <typename T>
812void DenseArrayAttrImpl<T>::print(AsmPrinter &printer) const {
813 print(printer.getStream());
814}
815
816template <typename T>
817void DenseArrayAttrImpl<T>::printWithoutBraces(raw_ostream &os) const {
818 llvm::interleaveComma(asArrayRef(), os, [&](T value) {
819 DenseArrayAttrUtil<T>::printElement(os, value);
820 });
821}
822
823template <typename T>
824void DenseArrayAttrImpl<T>::print(raw_ostream &os) const {
825 os << "[";
826 printWithoutBraces(os);
827 os << "]";
828}
829
830/// Parse a DenseArrayAttr without the braces: `1, 2, 3`
831template <typename T>
832Attribute DenseArrayAttrImpl<T>::parseWithoutBraces(AsmParser &parser,
833 Type odsType) {
834 SmallVector<T> data;
835 if (failed(parser.parseCommaSeparatedList([&]() {
836 T value;
837 if (DenseArrayAttrUtil<T>::parseElement(parser, value))
838 return failure();
839 data.push_back(value);
840 return success();
841 })))
842 return {};
843 return get(context: parser.getContext(), content: data);
844}
845
846/// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`
847template <typename T>
848Attribute DenseArrayAttrImpl<T>::parse(AsmParser &parser, Type odsType) {
849 if (parser.parseLSquare())
850 return {};
851 // Handle empty list case.
852 if (succeeded(Result: parser.parseOptionalRSquare()))
853 return get(context: parser.getContext(), content: {});
854 Attribute result = parseWithoutBraces(parser, odsType);
855 if (parser.parseRSquare())
856 return {};
857 return result;
858}
859
860/// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
861template <typename T>
862DenseArrayAttrImpl<T>::operator ArrayRef<T>() const {
863 ArrayRef<char> raw = getRawData();
864 assert((raw.size() % sizeof(T)) == 0);
865 return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
866 raw.size() / sizeof(T));
867}
868
869/// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
870template <typename T>
871DenseArrayAttrImpl<T> DenseArrayAttrImpl<T>::get(MLIRContext *context,
872 ArrayRef<T> content) {
873 Type elementType = DenseArrayAttrUtil<T>::getElementType(context);
874 auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
875 content.size() * sizeof(T));
876 return llvm::cast<DenseArrayAttrImpl<T>>(
877 Base::get(context, elementType, content.size(), rawArray));
878}
879
880template <typename T>
881bool DenseArrayAttrImpl<T>::classof(Attribute attr) {
882 if (auto denseArray = llvm::dyn_cast<DenseArrayAttr>(Val&: attr))
883 return DenseArrayAttrUtil<T>::checkElementType(denseArray.getElementType());
884 return false;
885}
886
887namespace mlir {
888namespace detail {
889// Explicit instantiation for all the supported DenseArrayAttr.
890template class DenseArrayAttrImpl<bool>;
891template class DenseArrayAttrImpl<int8_t>;
892template class DenseArrayAttrImpl<int16_t>;
893template class DenseArrayAttrImpl<int32_t>;
894template class DenseArrayAttrImpl<int64_t>;
895template class DenseArrayAttrImpl<float>;
896template class DenseArrayAttrImpl<double>;
897} // namespace detail
898} // namespace mlir
899
900//===----------------------------------------------------------------------===//
901// DenseElementsAttr
902//===----------------------------------------------------------------------===//
903
904/// Method for support type inquiry through isa, cast and dyn_cast.
905bool DenseElementsAttr::classof(Attribute attr) {
906 return llvm::isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(Val: attr);
907}
908
909DenseElementsAttr DenseElementsAttr::get(ShapedType type,
910 ArrayRef<Attribute> values) {
911 assert(hasSameNumElementsOrSplat(type, values));
912
913 Type eltType = type.getElementType();
914
915 // Take care complex type case first.
916 if (auto complexType = llvm::dyn_cast<ComplexType>(Val&: eltType)) {
917 if (complexType.getElementType().isIntOrIndex()) {
918 SmallVector<std::complex<APInt>> complexValues;
919 complexValues.reserve(N: values.size());
920 for (Attribute attr : values) {
921 assert(llvm::isa<ArrayAttr>(attr) && "expected ArrayAttr for complex");
922 auto arrayAttr = llvm::cast<ArrayAttr>(Val&: attr);
923 assert(arrayAttr.size() == 2 && "expected 2 element for complex");
924 auto attr0 = arrayAttr[0];
925 auto attr1 = arrayAttr[1];
926 complexValues.push_back(
927 Elt: std::complex<APInt>(llvm::cast<IntegerAttr>(Val&: attr0).getValue(),
928 llvm::cast<IntegerAttr>(Val&: attr1).getValue()));
929 }
930 return DenseElementsAttr::get(type, values: complexValues);
931 }
932 // Must be float.
933 SmallVector<std::complex<APFloat>> complexValues;
934 complexValues.reserve(N: values.size());
935 for (Attribute attr : values) {
936 assert(llvm::isa<ArrayAttr>(attr) && "expected ArrayAttr for complex");
937 auto arrayAttr = llvm::cast<ArrayAttr>(Val&: attr);
938 assert(arrayAttr.size() == 2 && "expected 2 element for complex");
939 auto attr0 = arrayAttr[0];
940 auto attr1 = arrayAttr[1];
941 complexValues.push_back(
942 Elt: std::complex<APFloat>(llvm::cast<FloatAttr>(Val&: attr0).getValue(),
943 llvm::cast<FloatAttr>(Val&: attr1).getValue()));
944 }
945 return DenseElementsAttr::get(type, values: complexValues);
946 }
947
948 // If the element type is not based on int/float/index, assume it is a string
949 // type.
950 if (!eltType.isIntOrIndexOrFloat()) {
951 SmallVector<StringRef, 8> stringValues;
952 stringValues.reserve(N: values.size());
953 for (Attribute attr : values) {
954 assert(llvm::isa<StringAttr>(attr) &&
955 "expected string value for non integer/index/float element");
956 stringValues.push_back(Elt: llvm::cast<StringAttr>(Val&: attr).getValue());
957 }
958 return get(type, values: stringValues);
959 }
960
961 // Otherwise, get the raw storage width to use for the allocation.
962 size_t bitWidth = getDenseElementBitWidth(eltType);
963 size_t storageBitWidth = getDenseElementStorageWidth(origWidth: bitWidth);
964
965 // Compress the attribute values into a character buffer.
966 SmallVector<char, 8> data(
967 llvm::divideCeil(Numerator: storageBitWidth * values.size(), CHAR_BIT));
968 APInt intVal;
969 for (unsigned i = 0, e = values.size(); i < e; ++i) {
970 if (auto floatAttr = llvm::dyn_cast<FloatAttr>(Val: values[i])) {
971 assert(floatAttr.getType() == eltType &&
972 "expected float attribute type to equal element type");
973 intVal = floatAttr.getValue().bitcastToAPInt();
974 } else {
975 auto intAttr = llvm::cast<IntegerAttr>(Val: values[i]);
976 assert(intAttr.getType() == eltType &&
977 "expected integer attribute type to equal element type");
978 intVal = intAttr.getValue();
979 }
980
981 assert(intVal.getBitWidth() == bitWidth &&
982 "expected value to have same bitwidth as element type");
983 writeBits(rawData: data.data(), bitPos: i * storageBitWidth, value: intVal);
984 }
985
986 // Handle the special encoding of splat of bool.
987 if (values.size() == 1 && eltType.isInteger(width: 1))
988 data[0] = data[0] ? -1 : 0;
989
990 return DenseIntOrFPElementsAttr::getRaw(type, data);
991}
992
993DenseElementsAttr DenseElementsAttr::get(ShapedType type,
994 ArrayRef<bool> values) {
995 assert(hasSameNumElementsOrSplat(type, values));
996 assert(type.getElementType().isInteger(1));
997
998 SmallVector<char> buff(llvm::divideCeil(Numerator: values.size(), CHAR_BIT));
999
1000 if (!values.empty()) {
1001 bool isSplat = true;
1002 bool firstValue = values[0];
1003 for (int i = 0, e = values.size(); i != e; ++i) {
1004 isSplat &= values[i] == firstValue;
1005 setBit(rawData: buff.data(), bitPos: i, value: values[i]);
1006 }
1007
1008 // Splat of bool is encoded as a byte with all-ones in it.
1009 if (isSplat) {
1010 buff.resize(N: 1);
1011 buff[0] = values[0] ? -1 : 0;
1012 }
1013 }
1014
1015 return DenseIntOrFPElementsAttr::getRaw(type, data: buff);
1016}
1017
1018DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1019 ArrayRef<StringRef> values) {
1020 assert(!type.getElementType().isIntOrFloat());
1021 return DenseStringElementsAttr::get(type, values);
1022}
1023
1024/// Constructs a dense integer elements attribute from an array of APInt
1025/// values. Each APInt value is expected to have the same bitwidth as the
1026/// element type of 'type'.
1027DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1028 ArrayRef<APInt> values) {
1029 assert(type.getElementType().isIntOrIndex());
1030 assert(hasSameNumElementsOrSplat(type, values));
1031 size_t storageBitWidth = getDenseElementStorageWidth(elementType: type.getElementType());
1032 return DenseIntOrFPElementsAttr::getRaw(type, storageWidth: storageBitWidth, values);
1033}
1034DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1035 ArrayRef<std::complex<APInt>> values) {
1036 ComplexType complex = llvm::cast<ComplexType>(Val: type.getElementType());
1037 assert(llvm::isa<IntegerType>(complex.getElementType()));
1038 assert(hasSameNumElementsOrSplat(type, values));
1039 size_t storageBitWidth = getDenseElementStorageWidth(elementType: complex) / 2;
1040 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
1041 values.size() * 2);
1042 return DenseIntOrFPElementsAttr::getRaw(type, storageWidth: storageBitWidth, values: intVals);
1043}
1044
1045// Constructs a dense float elements attribute from an array of APFloat
1046// values. Each APFloat value is expected to have the same bitwidth as the
1047// element type of 'type'.
1048DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1049 ArrayRef<APFloat> values) {
1050 assert(llvm::isa<FloatType>(type.getElementType()));
1051 assert(hasSameNumElementsOrSplat(type, values));
1052 size_t storageBitWidth = getDenseElementStorageWidth(elementType: type.getElementType());
1053 return DenseIntOrFPElementsAttr::getRaw(type, storageWidth: storageBitWidth, values);
1054}
1055DenseElementsAttr
1056DenseElementsAttr::get(ShapedType type,
1057 ArrayRef<std::complex<APFloat>> values) {
1058 ComplexType complex = llvm::cast<ComplexType>(Val: type.getElementType());
1059 assert(llvm::isa<FloatType>(complex.getElementType()));
1060 assert(hasSameNumElementsOrSplat(type, values));
1061 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
1062 values.size() * 2);
1063 size_t storageBitWidth = getDenseElementStorageWidth(elementType: complex) / 2;
1064 return DenseIntOrFPElementsAttr::getRaw(type, storageWidth: storageBitWidth, values: apVals);
1065}
1066
1067/// Construct a dense elements attribute from a raw buffer representing the
1068/// data for this attribute. Users should generally not use this methods as
1069/// the expected buffer format may not be a form the user expects.
1070DenseElementsAttr
1071DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef<char> rawBuffer) {
1072 return DenseIntOrFPElementsAttr::getRaw(type, data: rawBuffer);
1073}
1074
1075/// Returns true if the given buffer is a valid raw buffer for the given type.
1076bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
1077 ArrayRef<char> rawBuffer,
1078 bool &detectedSplat) {
1079 size_t storageWidth = getDenseElementStorageWidth(elementType: type.getElementType());
1080 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT;
1081 int64_t numElements = type.getNumElements();
1082
1083 // The initializer is always a splat if the result type has a single element.
1084 detectedSplat = numElements == 1;
1085
1086 // Storage width of 1 is special as it is packed by the bit.
1087 if (storageWidth == 1) {
1088 // Check for a splat, or a buffer equal to the number of elements which
1089 // consists of either all 0's or all 1's.
1090 if (rawBuffer.size() == 1) {
1091 auto rawByte = static_cast<uint8_t>(rawBuffer[0]);
1092 if (rawByte == 0 || rawByte == 0xff) {
1093 detectedSplat = true;
1094 return true;
1095 }
1096 }
1097
1098 // This is a valid non-splat buffer if it has the right size.
1099 return rawBufferWidth == llvm::alignTo<8>(Value: numElements);
1100 }
1101
1102 // All other types are 8-bit aligned, so we can just check the buffer width
1103 // to know if only a single initializer element was passed in.
1104 if (rawBufferWidth == storageWidth) {
1105 detectedSplat = true;
1106 return true;
1107 }
1108
1109 // The raw buffer is valid if it has the right size.
1110 return rawBufferWidth == storageWidth * numElements;
1111}
1112
1113/// Check the information for a C++ data type, check if this type is valid for
1114/// the current attribute. This method is used to verify specific type
1115/// invariants that the templatized 'getValues' method cannot.
1116static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
1117 bool isSigned) {
1118 // Make sure that the data element size is the same as the type element width.
1119 auto denseEltBitWidth = getDenseElementBitWidth(eltType: type);
1120 auto dataSize = static_cast<size_t>(dataEltSize * CHAR_BIT);
1121 if (denseEltBitWidth != dataSize) {
1122 LLVM_DEBUG(llvm::dbgs() << "expected dense element bit width "
1123 << denseEltBitWidth << " to match data size "
1124 << dataSize << " for type " << type << "\n");
1125 return false;
1126 }
1127
1128 // Check that the element type is either float or integer or index.
1129 if (!isInt) {
1130 bool valid = llvm::isa<FloatType>(Val: type);
1131 if (!valid)
1132 LLVM_DEBUG(llvm::dbgs()
1133 << "expected float type when isInt is false, but found "
1134 << type << "\n");
1135 return valid;
1136 }
1137 if (type.isIndex())
1138 return true;
1139
1140 auto intType = llvm::dyn_cast<IntegerType>(Val&: type);
1141 if (!intType) {
1142 LLVM_DEBUG(llvm::dbgs()
1143 << "expected integer type when isInt is true, but found " << type
1144 << "\n");
1145 return false;
1146 }
1147
1148 // Make sure signedness semantics is consistent.
1149 if (intType.isSignless())
1150 return true;
1151
1152 bool valid = intType.isSigned() == isSigned;
1153 if (!valid)
1154 LLVM_DEBUG(llvm::dbgs() << "expected signedness " << isSigned
1155 << " to match type " << type << "\n");
1156 return valid;
1157}
1158
1159/// Defaults down the subclass implementation.
1160DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
1161 ArrayRef<char> data,
1162 int64_t dataEltSize,
1163 bool isInt, bool isSigned) {
1164 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
1165 isSigned);
1166}
1167DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
1168 ArrayRef<char> data,
1169 int64_t dataEltSize,
1170 bool isInt,
1171 bool isSigned) {
1172 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
1173 isInt, isSigned);
1174}
1175
1176bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
1177 bool isSigned) const {
1178 return ::isValidIntOrFloat(type: getElementType(), dataEltSize, isInt, isSigned);
1179}
1180bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
1181 bool isSigned) const {
1182 return ::isValidIntOrFloat(
1183 type: llvm::cast<ComplexType>(Val: getElementType()).getElementType(),
1184 dataEltSize: dataEltSize / 2, isInt, isSigned);
1185}
1186
1187/// Returns true if this attribute corresponds to a splat, i.e. if all element
1188/// values are the same.
1189bool DenseElementsAttr::isSplat() const {
1190 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
1191}
1192
1193/// Return if the given complex type has an integer element type.
1194static bool isComplexOfIntType(Type type) {
1195 return llvm::isa<IntegerType>(Val: llvm::cast<ComplexType>(Val&: type).getElementType());
1196}
1197
1198auto DenseElementsAttr::tryGetComplexIntValues() const
1199 -> FailureOr<iterator_range_impl<ComplexIntElementIterator>> {
1200 if (!isComplexOfIntType(type: getElementType()))
1201 return failure();
1202 return iterator_range_impl<ComplexIntElementIterator>(
1203 getType(), ComplexIntElementIterator(*this, 0),
1204 ComplexIntElementIterator(*this, getNumElements()));
1205}
1206
1207auto DenseElementsAttr::tryGetFloatValues() const
1208 -> FailureOr<iterator_range_impl<FloatElementIterator>> {
1209 auto eltTy = llvm::dyn_cast<FloatType>(Val: getElementType());
1210 if (!eltTy)
1211 return failure();
1212 const auto &elementSemantics = eltTy.getFloatSemantics();
1213 return iterator_range_impl<FloatElementIterator>(
1214 getType(), FloatElementIterator(elementSemantics, raw_int_begin()),
1215 FloatElementIterator(elementSemantics, raw_int_end()));
1216}
1217
1218auto DenseElementsAttr::tryGetComplexFloatValues() const
1219 -> FailureOr<iterator_range_impl<ComplexFloatElementIterator>> {
1220 auto complexTy = llvm::dyn_cast<ComplexType>(Val: getElementType());
1221 if (!complexTy)
1222 return failure();
1223 auto eltTy = llvm::dyn_cast<FloatType>(Val: complexTy.getElementType());
1224 if (!eltTy)
1225 return failure();
1226 const auto &semantics = eltTy.getFloatSemantics();
1227 return iterator_range_impl<ComplexFloatElementIterator>(
1228 getType(), {semantics, {*this, 0}},
1229 {semantics, {*this, static_cast<size_t>(getNumElements())}});
1230}
1231
1232/// Return the raw storage data held by this attribute.
1233ArrayRef<char> DenseElementsAttr::getRawData() const {
1234 return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
1235}
1236
1237ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
1238 return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
1239}
1240
1241/// Return a new DenseElementsAttr that has the same data as the current
1242/// attribute, but has been reshaped to 'newType'. The new type must have the
1243/// same total number of elements as well as element type.
1244DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
1245 ShapedType curType = getType();
1246 if (curType == newType)
1247 return *this;
1248
1249 assert(newType.getElementType() == curType.getElementType() &&
1250 "expected the same element type");
1251 assert(newType.getNumElements() == curType.getNumElements() &&
1252 "expected the same number of elements");
1253 return DenseIntOrFPElementsAttr::getRaw(type: newType, data: getRawData());
1254}
1255
1256DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
1257 assert(isSplat() && "expected a splat type");
1258
1259 ShapedType curType = getType();
1260 if (curType == newType)
1261 return *this;
1262
1263 assert(newType.getElementType() == curType.getElementType() &&
1264 "expected the same element type");
1265 return DenseIntOrFPElementsAttr::getRaw(type: newType, data: getRawData());
1266}
1267
1268/// Return a new DenseElementsAttr that has the same data as the current
1269/// attribute, but has bitcast elements such that it is now 'newType'. The new
1270/// type must have the same shape and element types of the same bitwidth as the
1271/// current type.
1272DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
1273 ShapedType curType = getType();
1274 Type curElType = curType.getElementType();
1275 if (curElType == newElType)
1276 return *this;
1277
1278 assert(getDenseElementBitWidth(newElType) ==
1279 getDenseElementBitWidth(curElType) &&
1280 "expected element types with the same bitwidth");
1281 return DenseIntOrFPElementsAttr::getRaw(type: curType.clone(elementType: newElType),
1282 data: getRawData());
1283}
1284
1285DenseElementsAttr
1286DenseElementsAttr::mapValues(Type newElementType,
1287 function_ref<APInt(const APInt &)> mapping) const {
1288 return llvm::cast<DenseIntElementsAttr>(Val: *this).mapValues(newElementType,
1289 mapping);
1290}
1291
1292DenseElementsAttr DenseElementsAttr::mapValues(
1293 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1294 return llvm::cast<DenseFPElementsAttr>(Val: *this).mapValues(newElementType,
1295 mapping);
1296}
1297
1298ShapedType DenseElementsAttr::getType() const {
1299 return static_cast<const DenseElementsAttributeStorage *>(impl)->type;
1300}
1301
1302Type DenseElementsAttr::getElementType() const {
1303 return getType().getElementType();
1304}
1305
1306int64_t DenseElementsAttr::getNumElements() const {
1307 return getType().getNumElements();
1308}
1309
1310//===----------------------------------------------------------------------===//
1311// DenseIntOrFPElementsAttr
1312//===----------------------------------------------------------------------===//
1313
1314/// Utility method to write a range of APInt values to a buffer.
1315template <typename APRangeT>
1316static void writeAPIntsToBuffer(size_t storageWidth,
1317 SmallVectorImpl<char> &data,
1318 APRangeT &&values) {
1319 size_t numValues = llvm::size(values);
1320 data.resize(N: llvm::divideCeil(Numerator: storageWidth * numValues, CHAR_BIT));
1321 size_t offset = 0;
1322 for (auto it = values.begin(), e = values.end(); it != e;
1323 ++it, offset += storageWidth) {
1324 assert((*it).getBitWidth() <= storageWidth);
1325 writeBits(data.data(), offset, *it);
1326 }
1327
1328 // Handle the special encoding of splat of a boolean.
1329 if (numValues == 1 && (*values.begin()).getBitWidth() == 1)
1330 data[0] = data[0] ? -1 : 0;
1331}
1332
1333/// Constructs a dense elements attribute from an array of raw APFloat values.
1334/// Each APFloat value is expected to have the same bitwidth as the element
1335/// type of 'type'. 'type' must be a vector or tensor with static shape.
1336DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1337 size_t storageWidth,
1338 ArrayRef<APFloat> values) {
1339 SmallVector<char> data;
1340 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1341 writeAPIntsToBuffer(storageWidth, data, values: llvm::map_range(C&: values, F: unwrapFloat));
1342 return DenseIntOrFPElementsAttr::getRaw(type, data);
1343}
1344
1345/// Constructs a dense elements attribute from an array of raw APInt values.
1346/// Each APInt value is expected to have the same bitwidth as the element type
1347/// of 'type'.
1348DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1349 size_t storageWidth,
1350 ArrayRef<APInt> values) {
1351 SmallVector<char> data;
1352 writeAPIntsToBuffer(storageWidth, data, values);
1353 return DenseIntOrFPElementsAttr::getRaw(type, data);
1354}
1355
1356DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1357 ArrayRef<char> data) {
1358 assert(type.hasStaticShape() && "type must have static shape");
1359 bool isSplat = false;
1360 bool isValid = isValidRawBuffer(type, rawBuffer: data, detectedSplat&: isSplat);
1361 assert(isValid);
1362 (void)isValid;
1363 return Base::get(ctx: type.getContext(), args&: type, args&: data, args&: isSplat);
1364}
1365
1366/// Overload of the raw 'get' method that asserts that the given type is of
1367/// complex type. This method is used to verify type invariants that the
1368/// templatized 'get' method cannot.
1369DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1370 ArrayRef<char> data,
1371 int64_t dataEltSize,
1372 bool isInt,
1373 bool isSigned) {
1374 assert(::isValidIntOrFloat(
1375 llvm::cast<ComplexType>(type.getElementType()).getElementType(),
1376 dataEltSize / 2, isInt, isSigned) &&
1377 "Try re-running with -debug-only=builtinattributes");
1378
1379 int64_t numElements = data.size() / dataEltSize;
1380 (void)numElements;
1381 assert(numElements == 1 || numElements == type.getNumElements());
1382 return getRaw(type, data);
1383}
1384
1385/// Overload of the 'getRaw' method that asserts that the given type is of
1386/// integer type. This method is used to verify type invariants that the
1387/// templatized 'get' method cannot.
1388DenseElementsAttr
1389DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1390 int64_t dataEltSize, bool isInt,
1391 bool isSigned) {
1392 assert(::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt,
1393 isSigned) &&
1394 "Try re-running with -debug-only=builtinattributes");
1395
1396 int64_t numElements = data.size() / dataEltSize;
1397 assert(numElements == 1 || numElements == type.getNumElements());
1398 (void)numElements;
1399 return getRaw(type, data);
1400}
1401
1402void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1403 const char *inRawData, char *outRawData, size_t elementBitWidth,
1404 size_t numElements) {
1405 using llvm::support::ulittle16_t;
1406 using llvm::support::ulittle32_t;
1407 using llvm::support::ulittle64_t;
1408
1409 assert(llvm::endianness::native == llvm::endianness::big);
1410 // NOLINT to avoid warning message about replacing by static_assert()
1411
1412 // Following std::copy_n always converts endianness on BE machine.
1413 switch (elementBitWidth) {
1414 case 16: {
1415 const ulittle16_t *inRawDataPos =
1416 reinterpret_cast<const ulittle16_t *>(inRawData);
1417 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1418 std::copy_n(first: inRawDataPos, n: numElements, result: outDataPos);
1419 break;
1420 }
1421 case 32: {
1422 const ulittle32_t *inRawDataPos =
1423 reinterpret_cast<const ulittle32_t *>(inRawData);
1424 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1425 std::copy_n(first: inRawDataPos, n: numElements, result: outDataPos);
1426 break;
1427 }
1428 case 64: {
1429 const ulittle64_t *inRawDataPos =
1430 reinterpret_cast<const ulittle64_t *>(inRawData);
1431 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1432 std::copy_n(first: inRawDataPos, n: numElements, result: outDataPos);
1433 break;
1434 }
1435 default: {
1436 size_t nBytes = elementBitWidth / CHAR_BIT;
1437 for (size_t i = 0; i < nBytes; i++)
1438 std::copy_n(first: inRawData + (nBytes - 1 - i), n: 1, result: outRawData + i);
1439 break;
1440 }
1441 }
1442}
1443
1444void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1445 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1446 ShapedType type) {
1447 size_t numElements = type.getNumElements();
1448 Type elementType = type.getElementType();
1449 if (ComplexType complexTy = llvm::dyn_cast<ComplexType>(Val&: elementType)) {
1450 elementType = complexTy.getElementType();
1451 numElements = numElements * 2;
1452 }
1453 size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1454 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&
1455 inRawData.size() <= outRawData.size());
1456 if (elementBitWidth <= CHAR_BIT)
1457 std::memcpy(dest: outRawData.begin(), src: inRawData.begin(), n: inRawData.size());
1458 else
1459 convertEndianOfCharForBEmachine(inRawData: inRawData.begin(), outRawData: outRawData.begin(),
1460 elementBitWidth, numElements);
1461}
1462
1463//===----------------------------------------------------------------------===//
1464// DenseFPElementsAttr
1465//===----------------------------------------------------------------------===//
1466
1467template <typename Fn, typename Attr>
1468static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1469 Type newElementType,
1470 llvm::SmallVectorImpl<char> &data) {
1471 size_t bitWidth = getDenseElementBitWidth(eltType: newElementType);
1472 size_t storageBitWidth = getDenseElementStorageWidth(origWidth: bitWidth);
1473
1474 ShapedType newArrayType = inType.cloneWith(shape: inType.getShape(), elementType: newElementType);
1475
1476 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1477 data.resize(N: llvm::divideCeil(Numerator: storageBitWidth * numRawElements, CHAR_BIT));
1478
1479 // Functor used to process a single element value of the attribute.
1480 auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1481 auto newInt = mapping(value);
1482 assert(newInt.getBitWidth() == bitWidth);
1483 writeBits(data.data(), index * storageBitWidth, newInt);
1484 };
1485
1486 // Check for the splat case.
1487 if (attr.isSplat()) {
1488 if (bitWidth == 1) {
1489 // Handle the special encoding of splat of bool.
1490 data[0] = mapping(*attr.begin()).isZero() ? 0 : -1;
1491 } else {
1492 processElt(*attr.begin(), /*index=*/0);
1493 }
1494 return newArrayType;
1495 }
1496
1497 // Otherwise, process all of the element values.
1498 uint64_t elementIdx = 0;
1499 for (auto value : attr)
1500 processElt(value, elementIdx++);
1501 return newArrayType;
1502}
1503
1504DenseElementsAttr DenseFPElementsAttr::mapValues(
1505 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1506 llvm::SmallVector<char, 8> elementData;
1507 auto newArrayType =
1508 mappingHelper(mapping, attr: *this, inType: getType(), newElementType, data&: elementData);
1509
1510 return getRaw(type: newArrayType, data: elementData);
1511}
1512
1513/// Method for supporting type inquiry through isa, cast and dyn_cast.
1514bool DenseFPElementsAttr::classof(Attribute attr) {
1515 if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(Val&: attr))
1516 return llvm::isa<FloatType>(Val: denseAttr.getType().getElementType());
1517 return false;
1518}
1519
1520//===----------------------------------------------------------------------===//
1521// DenseIntElementsAttr
1522//===----------------------------------------------------------------------===//
1523
1524DenseElementsAttr DenseIntElementsAttr::mapValues(
1525 Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1526 llvm::SmallVector<char, 8> elementData;
1527 auto newArrayType =
1528 mappingHelper(mapping, attr: *this, inType: getType(), newElementType, data&: elementData);
1529 return getRaw(type: newArrayType, data: elementData);
1530}
1531
1532/// Method for supporting type inquiry through isa, cast and dyn_cast.
1533bool DenseIntElementsAttr::classof(Attribute attr) {
1534 if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(Val&: attr))
1535 return denseAttr.getType().getElementType().isIntOrIndex();
1536 return false;
1537}
1538
1539//===----------------------------------------------------------------------===//
1540// DenseResourceElementsAttr
1541//===----------------------------------------------------------------------===//
1542
1543DenseResourceElementsAttr
1544DenseResourceElementsAttr::get(ShapedType type,
1545 DenseResourceElementsHandle handle) {
1546 return Base::get(ctx: type.getContext(), args&: type, args&: handle);
1547}
1548
1549DenseResourceElementsAttr DenseResourceElementsAttr::get(ShapedType type,
1550 StringRef blobName,
1551 AsmResourceBlob blob) {
1552 // Extract the builtin dialect resource manager from context and construct a
1553 // handle by inserting a new resource using the provided blob.
1554 auto &manager =
1555 DenseResourceElementsHandle::getManagerInterface(ctx: type.getContext());
1556 return get(type, handle: manager.insert(name: blobName, blob: std::move(blob)));
1557}
1558
1559ArrayRef<char> DenseResourceElementsAttr::getData() {
1560 if (AsmResourceBlob *blob = this->getRawHandle().getBlob())
1561 return blob->getDataAs<char>();
1562 return {};
1563}
1564
1565//===----------------------------------------------------------------------===//
1566// DenseResourceElementsAttrBase
1567//===----------------------------------------------------------------------===//
1568
1569namespace {
1570/// Instantiations of this class provide utilities for interacting with native
1571/// data types in the context of DenseResourceElementsAttr.
1572template <typename T>
1573struct DenseResourceAttrUtil;
1574template <size_t width, bool isSigned>
1575struct DenseResourceElementsAttrIntUtil {
1576 static bool checkElementType(Type eltType) {
1577 IntegerType type = llvm::dyn_cast<IntegerType>(Val&: eltType);
1578 if (!type || type.getWidth() != width)
1579 return false;
1580 return isSigned ? !type.isUnsigned() : !type.isSigned();
1581 }
1582};
1583template <>
1584struct DenseResourceAttrUtil<bool> {
1585 static bool checkElementType(Type eltType) {
1586 return eltType.isSignlessInteger(width: 1);
1587 }
1588};
1589template <>
1590struct DenseResourceAttrUtil<int8_t>
1591 : public DenseResourceElementsAttrIntUtil<8, true> {};
1592template <>
1593struct DenseResourceAttrUtil<uint8_t>
1594 : public DenseResourceElementsAttrIntUtil<8, false> {};
1595template <>
1596struct DenseResourceAttrUtil<int16_t>
1597 : public DenseResourceElementsAttrIntUtil<16, true> {};
1598template <>
1599struct DenseResourceAttrUtil<uint16_t>
1600 : public DenseResourceElementsAttrIntUtil<16, false> {};
1601template <>
1602struct DenseResourceAttrUtil<int32_t>
1603 : public DenseResourceElementsAttrIntUtil<32, true> {};
1604template <>
1605struct DenseResourceAttrUtil<uint32_t>
1606 : public DenseResourceElementsAttrIntUtil<32, false> {};
1607template <>
1608struct DenseResourceAttrUtil<int64_t>
1609 : public DenseResourceElementsAttrIntUtil<64, true> {};
1610template <>
1611struct DenseResourceAttrUtil<uint64_t>
1612 : public DenseResourceElementsAttrIntUtil<64, false> {};
1613template <>
1614struct DenseResourceAttrUtil<float> {
1615 static bool checkElementType(Type eltType) { return eltType.isF32(); }
1616};
1617template <>
1618struct DenseResourceAttrUtil<double> {
1619 static bool checkElementType(Type eltType) { return eltType.isF64(); }
1620};
1621} // namespace
1622
1623template <typename T>
1624DenseResourceElementsAttrBase<T>
1625DenseResourceElementsAttrBase<T>::get(ShapedType type, StringRef blobName,
1626 AsmResourceBlob blob) {
1627 // Check that the blob is in the form we were expecting.
1628 assert(blob.getDataAlignment() == alignof(T) &&
1629 "alignment mismatch between expected alignment and blob alignment");
1630 assert(((blob.getData().size() % sizeof(T)) == 0) &&
1631 "size mismatch between expected element width and blob size");
1632 assert(DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) &&
1633 "invalid shape element type for provided type `T`");
1634 return llvm::cast<DenseResourceElementsAttrBase<T>>(
1635 DenseResourceElementsAttr::get(type, blobName, blob: std::move(blob)));
1636}
1637
1638template <typename T>
1639std::optional<ArrayRef<T>>
1640DenseResourceElementsAttrBase<T>::tryGetAsArrayRef() const {
1641 if (AsmResourceBlob *blob = this->getRawHandle().getBlob())
1642 return blob->template getDataAs<T>();
1643 return std::nullopt;
1644}
1645
1646template <typename T>
1647bool DenseResourceElementsAttrBase<T>::classof(Attribute attr) {
1648 auto resourceAttr = llvm::dyn_cast<DenseResourceElementsAttr>(Val&: attr);
1649 return resourceAttr && DenseResourceAttrUtil<T>::checkElementType(
1650 resourceAttr.getElementType());
1651}
1652
1653namespace mlir {
1654namespace detail {
1655// Explicit instantiation for all the supported DenseResourceElementsAttr.
1656template class DenseResourceElementsAttrBase<bool>;
1657template class DenseResourceElementsAttrBase<int8_t>;
1658template class DenseResourceElementsAttrBase<int16_t>;
1659template class DenseResourceElementsAttrBase<int32_t>;
1660template class DenseResourceElementsAttrBase<int64_t>;
1661template class DenseResourceElementsAttrBase<uint8_t>;
1662template class DenseResourceElementsAttrBase<uint16_t>;
1663template class DenseResourceElementsAttrBase<uint32_t>;
1664template class DenseResourceElementsAttrBase<uint64_t>;
1665template class DenseResourceElementsAttrBase<float>;
1666template class DenseResourceElementsAttrBase<double>;
1667} // namespace detail
1668} // namespace mlir
1669
1670//===----------------------------------------------------------------------===//
1671// SparseElementsAttr
1672//===----------------------------------------------------------------------===//
1673
1674/// Get a zero APFloat for the given sparse attribute.
1675APFloat SparseElementsAttr::getZeroAPFloat() const {
1676 auto eltType = llvm::cast<FloatType>(Val: getElementType());
1677 return APFloat(eltType.getFloatSemantics());
1678}
1679
1680/// Get a zero APInt for the given sparse attribute.
1681APInt SparseElementsAttr::getZeroAPInt() const {
1682 auto eltType = llvm::cast<IntegerType>(Val: getElementType());
1683 return APInt::getZero(numBits: eltType.getWidth());
1684}
1685
1686/// Get a zero attribute for the given attribute type.
1687Attribute SparseElementsAttr::getZeroAttr() const {
1688 auto eltType = getElementType();
1689
1690 // Handle floating point elements.
1691 if (llvm::isa<FloatType>(Val: eltType))
1692 return FloatAttr::get(type: eltType, value: 0);
1693
1694 // Handle complex elements.
1695 if (auto complexTy = llvm::dyn_cast<ComplexType>(Val&: eltType)) {
1696 auto eltType = complexTy.getElementType();
1697 Attribute zero;
1698 if (llvm::isa<FloatType>(Val: eltType))
1699 zero = FloatAttr::get(type: eltType, value: 0);
1700 else // must be integer
1701 zero = IntegerAttr::get(type: eltType, value: 0);
1702 return ArrayAttr::get(context: complexTy.getContext(),
1703 value: ArrayRef<Attribute>{zero, zero});
1704 }
1705
1706 // Handle string type.
1707 if (llvm::isa<DenseStringElementsAttr>(Val: getValues()))
1708 return StringAttr::get(twine: "", type: eltType);
1709
1710 // Otherwise, this is an integer.
1711 return IntegerAttr::get(type: eltType, value: 0);
1712}
1713
1714/// Flatten, and return, all of the sparse indices in this attribute in
1715/// row-major order.
1716SmallVector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1717 SmallVector<ptrdiff_t> flatSparseIndices;
1718
1719 // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1720 // as a 1-D index array.
1721 auto sparseIndices = getIndices();
1722 auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1723 if (sparseIndices.isSplat()) {
1724 SmallVector<uint64_t, 8> indices(getType().getRank(),
1725 *sparseIndexValues.begin());
1726 flatSparseIndices.push_back(Elt: getFlattenedIndex(index: indices));
1727 return flatSparseIndices;
1728 }
1729
1730 // Otherwise, reinterpret each index as an ArrayRef when flattening.
1731 auto numSparseIndices = sparseIndices.getType().getDimSize(idx: 0);
1732 size_t rank = getType().getRank();
1733 for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1734 flatSparseIndices.push_back(Elt: getFlattenedIndex(
1735 index: {&*std::next(x: sparseIndexValues.begin(), n: i * rank), rank}));
1736 return flatSparseIndices;
1737}
1738
1739LogicalResult
1740SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1741 ShapedType type, DenseIntElementsAttr sparseIndices,
1742 DenseElementsAttr values) {
1743 ShapedType valuesType = values.getType();
1744 if (valuesType.getRank() != 1)
1745 return emitError() << "expected 1-d tensor for sparse element values";
1746
1747 // Verify the indices and values shape.
1748 ShapedType indicesType = sparseIndices.getType();
1749 auto emitShapeError = [&]() {
1750 return emitError() << "expected shape ([" << type.getShape()
1751 << "]); inferred shape of indices literal (["
1752 << indicesType.getShape()
1753 << "]); inferred shape of values literal (["
1754 << valuesType.getShape() << "])";
1755 };
1756 // Verify indices shape.
1757 size_t rank = type.getRank(), indicesRank = indicesType.getRank();
1758 if (indicesRank == 2) {
1759 if (indicesType.getDimSize(idx: 1) != static_cast<int64_t>(rank))
1760 return emitShapeError();
1761 } else if (indicesRank != 1 || rank != 1) {
1762 return emitShapeError();
1763 }
1764 // Verify the values shape.
1765 int64_t numSparseIndices = indicesType.getDimSize(idx: 0);
1766 if (numSparseIndices != valuesType.getDimSize(idx: 0))
1767 return emitShapeError();
1768
1769 // Verify that the sparse indices are within the value shape.
1770 auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) {
1771 return emitError()
1772 << "sparse index #" << indexNum
1773 << " is not contained within the value shape, with index=[" << index
1774 << "], and type=" << type;
1775 };
1776
1777 // Handle the case where the index values are a splat.
1778 auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1779 if (sparseIndices.isSplat()) {
1780 SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin());
1781 if (!ElementsAttr::isValidIndex(type, index: indices))
1782 return emitIndexError(0, indices);
1783 return success();
1784 }
1785
1786 // Otherwise, reinterpret each index as an ArrayRef.
1787 for (size_t i = 0, e = numSparseIndices; i != e; ++i) {
1788 ArrayRef<uint64_t> index(&*std::next(x: sparseIndexValues.begin(), n: i * rank),
1789 rank);
1790 if (!ElementsAttr::isValidIndex(type, index))
1791 return emitIndexError(i, index);
1792 }
1793
1794 return success();
1795}
1796
1797//===----------------------------------------------------------------------===//
1798// DistinctAttr
1799//===----------------------------------------------------------------------===//
1800
1801DistinctAttr DistinctAttr::create(Attribute referencedAttr) {
1802 return Base::get(ctx: referencedAttr.getContext(), args&: referencedAttr);
1803}
1804
1805Attribute DistinctAttr::getReferencedAttr() const {
1806 return getImpl()->referencedAttr;
1807}
1808
1809//===----------------------------------------------------------------------===//
1810// Attribute Utilities
1811//===----------------------------------------------------------------------===//
1812
1813AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
1814 int64_t offset,
1815 MLIRContext *context) {
1816 AffineExpr expr;
1817 unsigned nSymbols = 0;
1818
1819 // AffineExpr for offset.
1820 // Static case.
1821 if (ShapedType::isStatic(dValue: offset)) {
1822 auto cst = getAffineConstantExpr(constant: offset, context);
1823 expr = cst;
1824 } else {
1825 // Dynamic case, new symbol for the offset.
1826 auto sym = getAffineSymbolExpr(position: nSymbols++, context);
1827 expr = sym;
1828 }
1829
1830 // AffineExpr for strides.
1831 for (const auto &en : llvm::enumerate(First&: strides)) {
1832 auto dim = en.index();
1833 auto stride = en.value();
1834 auto d = getAffineDimExpr(position: dim, context);
1835 AffineExpr mult;
1836 // Static case.
1837 if (ShapedType::isStatic(dValue: stride))
1838 mult = getAffineConstantExpr(constant: stride, context);
1839 else
1840 // Dynamic case, new symbol for each new stride.
1841 mult = getAffineSymbolExpr(position: nSymbols++, context);
1842 expr = expr + d * mult;
1843 }
1844
1845 return AffineMap::get(dimCount: strides.size(), symbolCount: nSymbols, result: expr);
1846}
1847

// Source code of mlir/lib/IR/BuiltinAttributes.cpp