1//===- AttributeTest.cpp - Attribute unit tests ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/AsmState.h"
10#include "mlir/IR/Builders.h"
11#include "mlir/IR/BuiltinAttributes.h"
12#include "mlir/IR/BuiltinTypes.h"
13#include "gtest/gtest.h"
14#include <optional>
15
16#include "../../test/lib/Dialect/Test/TestDialect.h"
17
18using namespace mlir;
19using namespace mlir::detail;
20
21//===----------------------------------------------------------------------===//
22// DenseElementsAttr
23//===----------------------------------------------------------------------===//
24
25template <typename EltTy>
26static void testSplat(Type eltType, const EltTy &splatElt) {
27 RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
28
29 // Check that the generated splat is the same for 1 element and N elements.
30 DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
31 EXPECT_TRUE(splat.isSplat());
32
33 auto detectedSplat =
34 DenseElementsAttr::get(shape, llvm::ArrayRef({splatElt, splatElt}));
35 EXPECT_EQ(detectedSplat, splat);
36
37 for (auto newValue : detectedSplat.template getValues<EltTy>())
38 EXPECT_TRUE(newValue == splatElt);
39}
40
41namespace {
42TEST(DenseSplatTest, BoolSplat) {
43 MLIRContext context;
44 IntegerType boolTy = IntegerType::get(&context, 1);
45 RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
46
47 // Check that splat is automatically detected for boolean values.
48 /// True.
49 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
50 EXPECT_TRUE(trueSplat.isSplat());
51 /// False.
52 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
53 EXPECT_TRUE(falseSplat.isSplat());
54 EXPECT_NE(falseSplat, trueSplat);
55
56 /// Detect and handle splat within 8 elements (bool values are bit-packed).
57 /// True.
58 auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true});
59 EXPECT_EQ(detectedSplat, trueSplat);
60 /// False.
61 detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false});
62 EXPECT_EQ(detectedSplat, falseSplat);
63}
64TEST(DenseSplatTest, BoolSplatRawRoundtrip) {
65 MLIRContext context;
66 IntegerType boolTy = IntegerType::get(&context, 1);
67 RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
68
69 // Check that splat booleans properly round trip via the raw API.
70 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
71 EXPECT_TRUE(trueSplat.isSplat());
72 DenseElementsAttr trueSplatFromRaw =
73 DenseElementsAttr::getFromRawBuffer(shape, trueSplat.getRawData());
74 EXPECT_TRUE(trueSplatFromRaw.isSplat());
75
76 EXPECT_EQ(trueSplat, trueSplatFromRaw);
77}
78
79TEST(DenseSplatTest, BoolSplatSmall) {
80 MLIRContext context;
81 Builder builder(&context);
82
83 // Check that splats that don't fill entire byte are handled properly.
84 auto tensorType = RankedTensorType::get({4}, builder.getI1Type());
85 std::vector<char> data{0b00001111};
86 auto trueSplatFromRaw =
87 DenseIntOrFPElementsAttr::getFromRawBuffer(tensorType, data);
88 EXPECT_TRUE(trueSplatFromRaw.isSplat());
89 DenseElementsAttr trueSplat = DenseElementsAttr::get(tensorType, true);
90 EXPECT_EQ(trueSplat, trueSplatFromRaw);
91}
92
93TEST(DenseSplatTest, LargeBoolSplat) {
94 constexpr int64_t boolCount = 56;
95
96 MLIRContext context;
97 IntegerType boolTy = IntegerType::get(&context, 1);
98 RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
99
100 // Check that splat is automatically detected for boolean values.
101 /// True.
102 DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true);
103 DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false);
104 EXPECT_TRUE(trueSplat.isSplat());
105 EXPECT_TRUE(falseSplat.isSplat());
106
107 /// Detect that the large boolean arrays are properly splatted.
108 /// True.
109 SmallVector<bool, 64> trueValues(boolCount, true);
110 auto detectedSplat = DenseElementsAttr::get(shape, trueValues);
111 EXPECT_EQ(detectedSplat, trueSplat);
112 /// False.
113 SmallVector<bool, 64> falseValues(boolCount, false);
114 detectedSplat = DenseElementsAttr::get(shape, falseValues);
115 EXPECT_EQ(detectedSplat, falseSplat);
116}
117
118TEST(DenseSplatTest, BoolNonSplat) {
119 MLIRContext context;
120 IntegerType boolTy = IntegerType::get(&context, 1);
121 RankedTensorType shape = RankedTensorType::get({6}, boolTy);
122
123 // Check that we properly handle non-splat values.
124 DenseElementsAttr nonSplat =
125 DenseElementsAttr::get(shape, {false, false, true, false, false, true});
126 EXPECT_FALSE(nonSplat.isSplat());
127}
128
129TEST(DenseSplatTest, OddIntSplat) {
130 // Test detecting a splat with an odd(non 8-bit) integer bitwidth.
131 MLIRContext context;
132 constexpr size_t intWidth = 19;
133 IntegerType intTy = IntegerType::get(&context, intWidth);
134 APInt value(intWidth, 10);
135
136 testSplat(intTy, value);
137}
138
139TEST(DenseSplatTest, Int32Splat) {
140 MLIRContext context;
141 IntegerType intTy = IntegerType::get(&context, 32);
142 int value = 64;
143
144 testSplat(intTy, value);
145}
146
147TEST(DenseSplatTest, IntAttrSplat) {
148 MLIRContext context;
149 IntegerType intTy = IntegerType::get(&context, 85);
150 Attribute value = IntegerAttr::get(intTy, 109);
151
152 testSplat(intTy, value);
153}
154
155TEST(DenseSplatTest, F32Splat) {
156 MLIRContext context;
157 FloatType floatTy = FloatType::getF32(ctx: &context);
158 float value = 10.0;
159
160 testSplat(eltType: floatTy, splatElt: value);
161}
162
163TEST(DenseSplatTest, F64Splat) {
164 MLIRContext context;
165 FloatType floatTy = FloatType::getF64(ctx: &context);
166 double value = 10.0;
167
168 testSplat(eltType: floatTy, splatElt: APFloat(value));
169}
170
171TEST(DenseSplatTest, FloatAttrSplat) {
172 MLIRContext context;
173 FloatType floatTy = FloatType::getF32(ctx: &context);
174 Attribute value = FloatAttr::get(floatTy, 10.0);
175
176 testSplat(eltType: floatTy, splatElt: value);
177}
178
179TEST(DenseSplatTest, BF16Splat) {
180 MLIRContext context;
181 FloatType floatTy = FloatType::getBF16(ctx: &context);
182 Attribute value = FloatAttr::get(floatTy, 10.0);
183
184 testSplat(eltType: floatTy, splatElt: value);
185}
186
187TEST(DenseSplatTest, StringSplat) {
188 MLIRContext context;
189 context.allowUnregisteredDialects();
190 Type stringType =
191 OpaqueType::get(StringAttr::get(&context, "test"), "string");
192 StringRef value = "test-string";
193 testSplat(eltType: stringType, splatElt: value);
194}
195
196TEST(DenseSplatTest, StringAttrSplat) {
197 MLIRContext context;
198 context.allowUnregisteredDialects();
199 Type stringType =
200 OpaqueType::get(StringAttr::get(&context, "test"), "string");
201 Attribute stringAttr = StringAttr::get("test-string", stringType);
202 testSplat(eltType: stringType, splatElt: stringAttr);
203}
204
205TEST(DenseComplexTest, ComplexFloatSplat) {
206 MLIRContext context;
207 ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
208 std::complex<float> value(10.0, 15.0);
209 testSplat(complexType, value);
210}
211
212TEST(DenseComplexTest, ComplexIntSplat) {
213 MLIRContext context;
214 ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
215 std::complex<int64_t> value(10, 15);
216 testSplat(complexType, value);
217}
218
219TEST(DenseComplexTest, ComplexAPFloatSplat) {
220 MLIRContext context;
221 ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
222 std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
223 testSplat(complexType, value);
224}
225
226TEST(DenseComplexTest, ComplexAPIntSplat) {
227 MLIRContext context;
228 ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64));
229 std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
230 testSplat(complexType, value);
231}
232
233TEST(DenseScalarTest, ExtractZeroRankElement) {
234 MLIRContext context;
235 const int elementValue = 12;
236 IntegerType intTy = IntegerType::get(&context, 32);
237 Attribute value = IntegerAttr::get(intTy, elementValue);
238 RankedTensorType shape = RankedTensorType::get({}, intTy);
239
240 auto attr = DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue}));
241 EXPECT_TRUE(attr.getValues<Attribute>()[0] == value);
242}
243
244TEST(DenseSplatMapValuesTest, I32ToTrue) {
245 MLIRContext context;
246 const int elementValue = 12;
247 IntegerType boolTy = IntegerType::get(&context, 1);
248 IntegerType intTy = IntegerType::get(&context, 32);
249 RankedTensorType shape = RankedTensorType::get({4}, intTy);
250
251 auto attr =
252 DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue}))
253 .mapValues(boolTy, [](const APInt &x) {
254 return x.isZero() ? APInt::getZero(numBits: 1) : APInt::getAllOnes(numBits: 1);
255 });
256 EXPECT_EQ(attr.getNumElements(), 4);
257 EXPECT_TRUE(attr.isSplat());
258 EXPECT_TRUE(attr.getSplatValue<BoolAttr>().getValue());
259}
260
261TEST(DenseSplatMapValuesTest, I32ToFalse) {
262 MLIRContext context;
263 const int elementValue = 0;
264 IntegerType boolTy = IntegerType::get(&context, 1);
265 IntegerType intTy = IntegerType::get(&context, 32);
266 RankedTensorType shape = RankedTensorType::get({4}, intTy);
267
268 auto attr =
269 DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue}))
270 .mapValues(boolTy, [](const APInt &x) {
271 return x.isZero() ? APInt::getZero(numBits: 1) : APInt::getAllOnes(numBits: 1);
272 });
273 EXPECT_EQ(attr.getNumElements(), 4);
274 EXPECT_TRUE(attr.isSplat());
275 EXPECT_FALSE(attr.getSplatValue<BoolAttr>().getValue());
276}
277} // namespace
278
279//===----------------------------------------------------------------------===//
280// DenseResourceElementsAttr
281//===----------------------------------------------------------------------===//
282
283template <typename AttrT, typename T>
284static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data,
285 Type elementType) {
286 auto type = RankedTensorType::get(data.size(), elementType);
287 auto attr = AttrT::get(type, "resource",
288 UnmanagedAsmResourceBlob::allocateInferAlign(data));
289
290 // Check that we can access and iterate the data properly.
291 std::optional<ArrayRef<T>> attrData = attr.tryGetAsArrayRef();
292 EXPECT_TRUE(attrData.has_value());
293 EXPECT_EQ(*attrData, data);
294
295 // Check that we cast to this attribute when possible.
296 Attribute genericAttr = attr;
297 EXPECT_TRUE(isa<AttrT>(genericAttr));
298}
299template <typename AttrT, typename T>
300static void checkNativeIntAccess(Builder &builder, size_t intWidth) {
301 T data[] = {0, 1, 2};
302 checkNativeAccess<AttrT, T>(builder.getContext(), llvm::ArrayRef(data),
303 builder.getIntegerType(intWidth));
304}
305
306namespace {
307TEST(DenseResourceElementsAttrTest, CheckNativeAccess) {
308 MLIRContext context;
309 Builder builder(&context);
310
311 // Bool
312 bool boolData[] = {true, false, true};
313 checkNativeAccess<DenseBoolResourceElementsAttr>(
314 &context, llvm::ArrayRef(boolData), builder.getI1Type());
315
316 // Unsigned integers
317 checkNativeIntAccess<DenseUI8ResourceElementsAttr, uint8_t>(builder, intWidth: 8);
318 checkNativeIntAccess<DenseUI16ResourceElementsAttr, uint16_t>(builder, intWidth: 16);
319 checkNativeIntAccess<DenseUI32ResourceElementsAttr, uint32_t>(builder, intWidth: 32);
320 checkNativeIntAccess<DenseUI64ResourceElementsAttr, uint64_t>(builder, intWidth: 64);
321
322 // Signed integers
323 checkNativeIntAccess<DenseI8ResourceElementsAttr, int8_t>(builder, intWidth: 8);
324 checkNativeIntAccess<DenseI16ResourceElementsAttr, int16_t>(builder, intWidth: 16);
325 checkNativeIntAccess<DenseI32ResourceElementsAttr, int32_t>(builder, intWidth: 32);
326 checkNativeIntAccess<DenseI64ResourceElementsAttr, int64_t>(builder, intWidth: 64);
327
328 // Float
329 float floatData[] = {0, 1, 2};
330 checkNativeAccess<DenseF32ResourceElementsAttr>(
331 ctx: &context, data: llvm::ArrayRef(floatData), elementType: builder.getF32Type());
332
333 // Double
334 double doubleData[] = {0, 1, 2};
335 checkNativeAccess<DenseF64ResourceElementsAttr>(
336 ctx: &context, data: llvm::ArrayRef(doubleData), elementType: builder.getF64Type());
337}
338
339TEST(DenseResourceElementsAttrTest, CheckNoCast) {
340 MLIRContext context;
341 Builder builder(&context);
342
343 // Create a i32 attribute.
344 ArrayRef<uint32_t> data;
345 auto type = RankedTensorType::get(data.size(), builder.getI32Type());
346 Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get(
347 type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data));
348
349 EXPECT_TRUE(isa<DenseI32ResourceElementsAttr>(i32ResourceAttr));
350 EXPECT_FALSE(isa<DenseF32ResourceElementsAttr>(i32ResourceAttr));
351 EXPECT_FALSE(isa<DenseBoolResourceElementsAttr>(i32ResourceAttr));
352}
353
354TEST(DenseResourceElementsAttrTest, CheckInvalidData) {
355 MLIRContext context;
356 Builder builder(&context);
357
358 // Create a bool attribute with data of the incorrect type.
359 ArrayRef<uint32_t> data;
360 auto type = RankedTensorType::get(data.size(), builder.getI32Type());
361 EXPECT_DEBUG_DEATH(
362 {
363 DenseBoolResourceElementsAttr::get(
364 type, "resource",
365 UnmanagedAsmResourceBlob::allocateInferAlign(data));
366 },
367 "alignment mismatch between expected alignment and blob alignment");
368}
369
370TEST(DenseResourceElementsAttrTest, CheckInvalidType) {
371 MLIRContext context;
372 Builder builder(&context);
373
374 // Create a bool attribute with incorrect type.
375 ArrayRef<bool> data;
376 auto type = RankedTensorType::get(data.size(), builder.getI32Type());
377 EXPECT_DEBUG_DEATH(
378 {
379 DenseBoolResourceElementsAttr::get(
380 type, "resource",
381 UnmanagedAsmResourceBlob::allocateInferAlign(data));
382 },
383 "invalid shape element type for provided type `T`");
384}
385} // namespace
386
387//===----------------------------------------------------------------------===//
388// SparseElementsAttr
389//===----------------------------------------------------------------------===//
390
391namespace {
392TEST(SparseElementsAttrTest, GetZero) {
393 MLIRContext context;
394 context.allowUnregisteredDialects();
395
396 IntegerType intTy = IntegerType::get(&context, 32);
397 FloatType floatTy = FloatType::getF32(ctx: &context);
398 Type stringTy = OpaqueType::get(StringAttr::get(&context, "test"), "string");
399
400 ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy);
401 ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy);
402 ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy);
403
404 auto indicesType =
405 RankedTensorType::get({1, 2}, IntegerType::get(&context, 64));
406 auto indices =
407 DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
408
409 RankedTensorType intValueTy = RankedTensorType::get({1}, intTy);
410 auto intValue = DenseIntElementsAttr::get(intValueTy, {1});
411
412 RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy);
413 auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f});
414
415 RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy);
416 auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo")});
417
418 auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue);
419 auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue);
420 auto sparseString =
421 SparseElementsAttr::get(tensorString, indices, stringValue);
422
423 // Only index (0, 0) contains an element, others are supposed to return
424 // the zero/empty value.
425 auto zeroIntValue =
426 cast<IntegerAttr>(sparseInt.getValues<Attribute>()[{1, 1}]);
427 EXPECT_EQ(zeroIntValue.getInt(), 0);
428 EXPECT_TRUE(zeroIntValue.getType() == intTy);
429
430 auto zeroFloatValue =
431 cast<FloatAttr>(sparseFloat.getValues<Attribute>()[{1, 1}]);
432 EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f);
433 EXPECT_TRUE(zeroFloatValue.getType() == floatTy);
434
435 auto zeroStringValue =
436 cast<StringAttr>(sparseString.getValues<Attribute>()[{1, 1}]);
437 EXPECT_TRUE(zeroStringValue.empty());
438 EXPECT_TRUE(zeroStringValue.getType() == stringTy);
439}
440
441//===----------------------------------------------------------------------===//
442// SubElements
443//===----------------------------------------------------------------------===//
444
445TEST(SubElementTest, Nested) {
446 MLIRContext context;
447 Builder builder(&context);
448
449 BoolAttr trueAttr = builder.getBoolAttr(value: true);
450 BoolAttr falseAttr = builder.getBoolAttr(value: false);
451 ArrayAttr boolArrayAttr =
452 builder.getArrayAttr({trueAttr, falseAttr, trueAttr});
453 StringAttr strAttr = builder.getStringAttr("array");
454 DictionaryAttr dictAttr =
455 builder.getDictionaryAttr(builder.getNamedAttr(strAttr, boolArrayAttr));
456
457 SmallVector<Attribute> subAttrs;
458 dictAttr.walk([&](Attribute attr) { subAttrs.push_back(Elt: attr); });
459 // Note that trueAttr appears only once, identical subattributes are skipped.
460 EXPECT_EQ(llvm::ArrayRef(subAttrs),
461 ArrayRef<Attribute>(
462 {strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr}));
463}
464
465// Test how many times we call copy-ctor when building an attribute.
466TEST(CopyCountAttr, CopyCount) {
467 MLIRContext context;
468 context.loadDialect<test::TestDialect>();
469
470 test::CopyCount::counter = 0;
471 test::CopyCount copyCount("hello");
472 test::TestCopyCountAttr::get(&context, std::move(copyCount));
473 int counter1 = test::CopyCount::counter;
474 test::CopyCount::counter = 0;
475 test::TestCopyCountAttr::get(&context, std::move(copyCount));
476#ifndef NDEBUG
477 // One verification enabled only in assert-mode requires a copy.
478 EXPECT_EQ(counter1, 1);
479 EXPECT_EQ(test::CopyCount::counter, 1);
480#else
481 EXPECT_EQ(counter1, 0);
482 EXPECT_EQ(test::CopyCount::counter, 0);
483#endif
484}
485
486// Test stripped printing using test dialect attribute.
487TEST(CopyCountAttr, PrintStripped) {
488 MLIRContext context;
489 context.loadDialect<test::TestDialect>();
490 // Doesn't matter which dialect attribute is used, just chose TestCopyCount
491 // given proximity.
492 test::CopyCount::counter = 0;
493 test::CopyCount copyCount("hello");
494 Attribute res = test::TestCopyCountAttr::get(&context, std::move(copyCount));
495
496 std::string str;
497 llvm::raw_string_ostream os(str);
498 os << "|" << res << "|";
499 res.printStripped(os&: os << "[");
500 os << "]";
501 EXPECT_EQ(os.str(), "|#test.copy_count<hello>|[copy_count<hello>]");
502}
503
504} // namespace
505

source code of mlir/unittests/IR/AttributeTest.cpp