| 1 | //===- AttributeTest.cpp - Attribute unit tests ---------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/IR/AsmState.h" |
| 10 | #include "mlir/IR/Builders.h" |
| 11 | #include "mlir/IR/BuiltinAttributes.h" |
| 12 | #include "mlir/IR/BuiltinTypes.h" |
| 13 | #include "gtest/gtest.h" |
| 14 | #include <optional> |
| 15 | |
| 16 | #include "../../test/lib/Dialect/Test/TestDialect.h" |
| 17 | |
| 18 | using namespace mlir; |
| 19 | using namespace mlir::detail; |
| 20 | |
| 21 | //===----------------------------------------------------------------------===// |
| 22 | // DenseElementsAttr |
| 23 | //===----------------------------------------------------------------------===// |
| 24 | |
| 25 | template <typename EltTy> |
| 26 | static void testSplat(Type eltType, const EltTy &splatElt) { |
| 27 | RankedTensorType shape = RankedTensorType::get({2, 1}, eltType); |
| 28 | |
| 29 | // Check that the generated splat is the same for 1 element and N elements. |
| 30 | DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt); |
| 31 | EXPECT_TRUE(splat.isSplat()); |
| 32 | |
| 33 | auto detectedSplat = |
| 34 | DenseElementsAttr::get(shape, llvm::ArrayRef({splatElt, splatElt})); |
| 35 | EXPECT_EQ(detectedSplat, splat); |
| 36 | |
| 37 | for (auto newValue : detectedSplat.template getValues<EltTy>()) |
| 38 | EXPECT_TRUE(newValue == splatElt); |
| 39 | } |
| 40 | |
| 41 | namespace { |
| 42 | TEST(DenseSplatTest, BoolSplat) { |
| 43 | MLIRContext context; |
| 44 | IntegerType boolTy = IntegerType::get(&context, 1); |
| 45 | RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy); |
| 46 | |
| 47 | // Check that splat is automatically detected for boolean values. |
| 48 | /// True. |
| 49 | DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); |
| 50 | EXPECT_TRUE(trueSplat.isSplat()); |
| 51 | /// False. |
| 52 | DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); |
| 53 | EXPECT_TRUE(falseSplat.isSplat()); |
| 54 | EXPECT_NE(falseSplat, trueSplat); |
| 55 | |
| 56 | /// Detect and handle splat within 8 elements (bool values are bit-packed). |
| 57 | /// True. |
| 58 | auto detectedSplat = DenseElementsAttr::get(shape, {true, true, true, true}); |
| 59 | EXPECT_EQ(detectedSplat, trueSplat); |
| 60 | /// False. |
| 61 | detectedSplat = DenseElementsAttr::get(shape, {false, false, false, false}); |
| 62 | EXPECT_EQ(detectedSplat, falseSplat); |
| 63 | } |
| 64 | TEST(DenseSplatTest, BoolSplatRawRoundtrip) { |
| 65 | MLIRContext context; |
| 66 | IntegerType boolTy = IntegerType::get(&context, 1); |
| 67 | RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy); |
| 68 | |
| 69 | // Check that splat booleans properly round trip via the raw API. |
| 70 | DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); |
| 71 | EXPECT_TRUE(trueSplat.isSplat()); |
| 72 | DenseElementsAttr trueSplatFromRaw = |
| 73 | DenseElementsAttr::getFromRawBuffer(shape, trueSplat.getRawData()); |
| 74 | EXPECT_TRUE(trueSplatFromRaw.isSplat()); |
| 75 | |
| 76 | EXPECT_EQ(trueSplat, trueSplatFromRaw); |
| 77 | } |
| 78 | |
| 79 | TEST(DenseSplatTest, BoolSplatSmall) { |
| 80 | MLIRContext context; |
| 81 | Builder builder(&context); |
| 82 | |
| 83 | // Check that splats that don't fill entire byte are handled properly. |
| 84 | auto tensorType = RankedTensorType::get({4}, builder.getI1Type()); |
| 85 | std::vector<char> data{0b00001111}; |
| 86 | auto trueSplatFromRaw = |
| 87 | DenseIntOrFPElementsAttr::getFromRawBuffer(tensorType, data); |
| 88 | EXPECT_TRUE(trueSplatFromRaw.isSplat()); |
| 89 | DenseElementsAttr trueSplat = DenseElementsAttr::get(tensorType, true); |
| 90 | EXPECT_EQ(trueSplat, trueSplatFromRaw); |
| 91 | } |
| 92 | |
| 93 | TEST(DenseSplatTest, LargeBoolSplat) { |
| 94 | constexpr int64_t boolCount = 56; |
| 95 | |
| 96 | MLIRContext context; |
| 97 | IntegerType boolTy = IntegerType::get(&context, 1); |
| 98 | RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy); |
| 99 | |
| 100 | // Check that splat is automatically detected for boolean values. |
| 101 | /// True. |
| 102 | DenseElementsAttr trueSplat = DenseElementsAttr::get(shape, true); |
| 103 | DenseElementsAttr falseSplat = DenseElementsAttr::get(shape, false); |
| 104 | EXPECT_TRUE(trueSplat.isSplat()); |
| 105 | EXPECT_TRUE(falseSplat.isSplat()); |
| 106 | |
| 107 | /// Detect that the large boolean arrays are properly splatted. |
| 108 | /// True. |
| 109 | SmallVector<bool, 64> trueValues(boolCount, true); |
| 110 | auto detectedSplat = DenseElementsAttr::get(shape, trueValues); |
| 111 | EXPECT_EQ(detectedSplat, trueSplat); |
| 112 | /// False. |
| 113 | SmallVector<bool, 64> falseValues(boolCount, false); |
| 114 | detectedSplat = DenseElementsAttr::get(shape, falseValues); |
| 115 | EXPECT_EQ(detectedSplat, falseSplat); |
| 116 | } |
| 117 | |
| 118 | TEST(DenseSplatTest, BoolNonSplat) { |
| 119 | MLIRContext context; |
| 120 | IntegerType boolTy = IntegerType::get(&context, 1); |
| 121 | RankedTensorType shape = RankedTensorType::get({6}, boolTy); |
| 122 | |
| 123 | // Check that we properly handle non-splat values. |
| 124 | DenseElementsAttr nonSplat = |
| 125 | DenseElementsAttr::get(shape, {false, false, true, false, false, true}); |
| 126 | EXPECT_FALSE(nonSplat.isSplat()); |
| 127 | } |
| 128 | |
| 129 | TEST(DenseSplatTest, OddIntSplat) { |
| 130 | // Test detecting a splat with an odd(non 8-bit) integer bitwidth. |
| 131 | MLIRContext context; |
| 132 | constexpr size_t intWidth = 19; |
| 133 | IntegerType intTy = IntegerType::get(&context, intWidth); |
| 134 | APInt value(intWidth, 10); |
| 135 | |
| 136 | testSplat(intTy, value); |
| 137 | } |
| 138 | |
| 139 | TEST(DenseSplatTest, Int32Splat) { |
| 140 | MLIRContext context; |
| 141 | IntegerType intTy = IntegerType::get(&context, 32); |
| 142 | int value = 64; |
| 143 | |
| 144 | testSplat(intTy, value); |
| 145 | } |
| 146 | |
| 147 | TEST(DenseSplatTest, IntAttrSplat) { |
| 148 | MLIRContext context; |
| 149 | IntegerType intTy = IntegerType::get(&context, 85); |
| 150 | Attribute value = IntegerAttr::get(intTy, 109); |
| 151 | |
| 152 | testSplat(intTy, value); |
| 153 | } |
| 154 | |
| 155 | TEST(DenseSplatTest, F32Splat) { |
| 156 | MLIRContext context; |
| 157 | FloatType floatTy = Float32Type::get(&context); |
| 158 | float value = 10.0; |
| 159 | |
| 160 | testSplat(floatTy, value); |
| 161 | } |
| 162 | |
| 163 | TEST(DenseSplatTest, F64Splat) { |
| 164 | MLIRContext context; |
| 165 | FloatType floatTy = Float64Type::get(&context); |
| 166 | double value = 10.0; |
| 167 | |
| 168 | testSplat(floatTy, APFloat(value)); |
| 169 | } |
| 170 | |
| 171 | TEST(DenseSplatTest, FloatAttrSplat) { |
| 172 | MLIRContext context; |
| 173 | FloatType floatTy = Float32Type::get(&context); |
| 174 | Attribute value = FloatAttr::get(floatTy, 10.0); |
| 175 | |
| 176 | testSplat(floatTy, value); |
| 177 | } |
| 178 | |
| 179 | TEST(DenseSplatTest, BF16Splat) { |
| 180 | MLIRContext context; |
| 181 | FloatType floatTy = BFloat16Type::get(&context); |
| 182 | Attribute value = FloatAttr::get(floatTy, 10.0); |
| 183 | |
| 184 | testSplat(floatTy, value); |
| 185 | } |
| 186 | |
| 187 | TEST(DenseSplatTest, StringSplat) { |
| 188 | MLIRContext context; |
| 189 | context.allowUnregisteredDialects(); |
| 190 | Type stringType = |
| 191 | OpaqueType::get(StringAttr::get(&context, "test" ), "string" ); |
| 192 | StringRef value = "test-string" ; |
| 193 | testSplat(eltType: stringType, splatElt: value); |
| 194 | } |
| 195 | |
| 196 | TEST(DenseSplatTest, StringAttrSplat) { |
| 197 | MLIRContext context; |
| 198 | context.allowUnregisteredDialects(); |
| 199 | Type stringType = |
| 200 | OpaqueType::get(StringAttr::get(&context, "test" ), "string" ); |
| 201 | Attribute stringAttr = StringAttr::get("test-string" , stringType); |
| 202 | testSplat(eltType: stringType, splatElt: stringAttr); |
| 203 | } |
| 204 | |
| 205 | TEST(DenseComplexTest, ComplexFloatSplat) { |
| 206 | MLIRContext context; |
| 207 | ComplexType complexType = ComplexType::get(Float32Type::get(&context)); |
| 208 | std::complex<float> value(10.0, 15.0); |
| 209 | testSplat(complexType, value); |
| 210 | } |
| 211 | |
| 212 | TEST(DenseComplexTest, ComplexIntSplat) { |
| 213 | MLIRContext context; |
| 214 | ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); |
| 215 | std::complex<int64_t> value(10, 15); |
| 216 | testSplat(complexType, value); |
| 217 | } |
| 218 | |
| 219 | TEST(DenseComplexTest, ComplexAPFloatSplat) { |
| 220 | MLIRContext context; |
| 221 | ComplexType complexType = ComplexType::get(Float32Type::get(&context)); |
| 222 | std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f)); |
| 223 | testSplat(complexType, value); |
| 224 | } |
| 225 | |
| 226 | TEST(DenseComplexTest, ComplexAPIntSplat) { |
| 227 | MLIRContext context; |
| 228 | ComplexType complexType = ComplexType::get(IntegerType::get(&context, 64)); |
| 229 | std::complex<APInt> value(APInt(64, 10), APInt(64, 15)); |
| 230 | testSplat(complexType, value); |
| 231 | } |
| 232 | |
| 233 | TEST(DenseScalarTest, ExtractZeroRankElement) { |
| 234 | MLIRContext context; |
| 235 | const int elementValue = 12; |
| 236 | IntegerType intTy = IntegerType::get(&context, 32); |
| 237 | Attribute value = IntegerAttr::get(intTy, elementValue); |
| 238 | RankedTensorType shape = RankedTensorType::get({}, intTy); |
| 239 | |
| 240 | auto attr = DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue})); |
| 241 | EXPECT_TRUE(attr.getValues<Attribute>()[0] == value); |
| 242 | } |
| 243 | |
| 244 | TEST(DenseSplatMapValuesTest, I32ToTrue) { |
| 245 | MLIRContext context; |
| 246 | const int elementValue = 12; |
| 247 | IntegerType boolTy = IntegerType::get(&context, 1); |
| 248 | IntegerType intTy = IntegerType::get(&context, 32); |
| 249 | RankedTensorType shape = RankedTensorType::get({4}, intTy); |
| 250 | |
| 251 | auto attr = |
| 252 | DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue})) |
| 253 | .mapValues(boolTy, [](const APInt &x) { |
| 254 | return x.isZero() ? APInt::getZero(numBits: 1) : APInt::getAllOnes(numBits: 1); |
| 255 | }); |
| 256 | EXPECT_EQ(attr.getNumElements(), 4); |
| 257 | EXPECT_TRUE(attr.isSplat()); |
| 258 | EXPECT_TRUE(attr.getSplatValue<BoolAttr>().getValue()); |
| 259 | } |
| 260 | |
| 261 | TEST(DenseSplatMapValuesTest, I32ToFalse) { |
| 262 | MLIRContext context; |
| 263 | const int elementValue = 0; |
| 264 | IntegerType boolTy = IntegerType::get(&context, 1); |
| 265 | IntegerType intTy = IntegerType::get(&context, 32); |
| 266 | RankedTensorType shape = RankedTensorType::get({4}, intTy); |
| 267 | |
| 268 | auto attr = |
| 269 | DenseElementsAttr::get(shape, llvm::ArrayRef({elementValue})) |
| 270 | .mapValues(boolTy, [](const APInt &x) { |
| 271 | return x.isZero() ? APInt::getZero(numBits: 1) : APInt::getAllOnes(numBits: 1); |
| 272 | }); |
| 273 | EXPECT_EQ(attr.getNumElements(), 4); |
| 274 | EXPECT_TRUE(attr.isSplat()); |
| 275 | EXPECT_FALSE(attr.getSplatValue<BoolAttr>().getValue()); |
| 276 | } |
| 277 | } // namespace |
| 278 | |
| 279 | //===----------------------------------------------------------------------===// |
| 280 | // DenseResourceElementsAttr |
| 281 | //===----------------------------------------------------------------------===// |
| 282 | |
| 283 | template <typename AttrT, typename T> |
| 284 | static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data, |
| 285 | Type elementType) { |
| 286 | auto type = RankedTensorType::get(data.size(), elementType); |
| 287 | auto attr = AttrT::get(type, "resource" , |
| 288 | UnmanagedAsmResourceBlob::allocateInferAlign(data)); |
| 289 | |
| 290 | // Check that we can access and iterate the data properly. |
| 291 | std::optional<ArrayRef<T>> attrData = attr.tryGetAsArrayRef(); |
| 292 | EXPECT_TRUE(attrData.has_value()); |
| 293 | EXPECT_EQ(*attrData, data); |
| 294 | |
| 295 | // Check that we cast to this attribute when possible. |
| 296 | Attribute genericAttr = attr; |
| 297 | EXPECT_TRUE(isa<AttrT>(genericAttr)); |
| 298 | } |
| 299 | template <typename AttrT, typename T> |
| 300 | static void checkNativeIntAccess(Builder &builder, size_t intWidth) { |
| 301 | T data[] = {0, 1, 2}; |
| 302 | checkNativeAccess<AttrT, T>(builder.getContext(), llvm::ArrayRef(data), |
| 303 | builder.getIntegerType(intWidth)); |
| 304 | } |
| 305 | |
| 306 | namespace { |
| 307 | TEST(DenseResourceElementsAttrTest, CheckNativeAccess) { |
| 308 | MLIRContext context; |
| 309 | Builder builder(&context); |
| 310 | |
| 311 | // Bool |
| 312 | bool boolData[] = {true, false, true}; |
| 313 | checkNativeAccess<DenseBoolResourceElementsAttr>( |
| 314 | &context, llvm::ArrayRef(boolData), builder.getI1Type()); |
| 315 | |
| 316 | // Unsigned integers |
| 317 | checkNativeIntAccess<DenseUI8ResourceElementsAttr, uint8_t>(builder, intWidth: 8); |
| 318 | checkNativeIntAccess<DenseUI16ResourceElementsAttr, uint16_t>(builder, intWidth: 16); |
| 319 | checkNativeIntAccess<DenseUI32ResourceElementsAttr, uint32_t>(builder, intWidth: 32); |
| 320 | checkNativeIntAccess<DenseUI64ResourceElementsAttr, uint64_t>(builder, intWidth: 64); |
| 321 | |
| 322 | // Signed integers |
| 323 | checkNativeIntAccess<DenseI8ResourceElementsAttr, int8_t>(builder, intWidth: 8); |
| 324 | checkNativeIntAccess<DenseI16ResourceElementsAttr, int16_t>(builder, intWidth: 16); |
| 325 | checkNativeIntAccess<DenseI32ResourceElementsAttr, int32_t>(builder, intWidth: 32); |
| 326 | checkNativeIntAccess<DenseI64ResourceElementsAttr, int64_t>(builder, intWidth: 64); |
| 327 | |
| 328 | // Float |
| 329 | float floatData[] = {0, 1, 2}; |
| 330 | checkNativeAccess<DenseF32ResourceElementsAttr>( |
| 331 | &context, llvm::ArrayRef(floatData), builder.getF32Type()); |
| 332 | |
| 333 | // Double |
| 334 | double doubleData[] = {0, 1, 2}; |
| 335 | checkNativeAccess<DenseF64ResourceElementsAttr>( |
| 336 | &context, llvm::ArrayRef(doubleData), builder.getF64Type()); |
| 337 | } |
| 338 | |
| 339 | TEST(DenseResourceElementsAttrTest, CheckNoCast) { |
| 340 | MLIRContext context; |
| 341 | Builder builder(&context); |
| 342 | |
| 343 | // Create a i32 attribute. |
| 344 | ArrayRef<uint32_t> data; |
| 345 | auto type = RankedTensorType::get(data.size(), builder.getI32Type()); |
| 346 | Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get( |
| 347 | type, "resource" , UnmanagedAsmResourceBlob::allocateInferAlign(data)); |
| 348 | |
| 349 | EXPECT_TRUE(isa<DenseI32ResourceElementsAttr>(i32ResourceAttr)); |
| 350 | EXPECT_FALSE(isa<DenseF32ResourceElementsAttr>(i32ResourceAttr)); |
| 351 | EXPECT_FALSE(isa<DenseBoolResourceElementsAttr>(i32ResourceAttr)); |
| 352 | } |
| 353 | |
| 354 | TEST(DenseResourceElementsAttrTest, CheckNotMutableAllocateAndCopy) { |
| 355 | MLIRContext context; |
| 356 | Builder builder(&context); |
| 357 | |
| 358 | // Create a i32 attribute. |
| 359 | std::vector<int32_t> data = {10, 20, 30}; |
| 360 | auto type = RankedTensorType::get(data.size(), builder.getI32Type()); |
| 361 | Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get( |
| 362 | type, "resource" , |
| 363 | HeapAsmResourceBlob::allocateAndCopyInferAlign<int32_t>( |
| 364 | data, /*is_mutable=*/false)); |
| 365 | |
| 366 | EXPECT_TRUE(isa<DenseI32ResourceElementsAttr>(i32ResourceAttr)); |
| 367 | } |
| 368 | |
| 369 | TEST(DenseResourceElementsAttrTest, CheckInvalidData) { |
| 370 | MLIRContext context; |
| 371 | Builder builder(&context); |
| 372 | |
| 373 | // Create a bool attribute with data of the incorrect type. |
| 374 | ArrayRef<uint32_t> data; |
| 375 | auto type = RankedTensorType::get(data.size(), builder.getI32Type()); |
| 376 | EXPECT_DEBUG_DEATH( |
| 377 | { |
| 378 | DenseBoolResourceElementsAttr::get( |
| 379 | type, "resource" , |
| 380 | UnmanagedAsmResourceBlob::allocateInferAlign(data)); |
| 381 | }, |
| 382 | "alignment mismatch between expected alignment and blob alignment" ); |
| 383 | } |
| 384 | |
| 385 | TEST(DenseResourceElementsAttrTest, CheckInvalidType) { |
| 386 | MLIRContext context; |
| 387 | Builder builder(&context); |
| 388 | |
| 389 | // Create a bool attribute with incorrect type. |
| 390 | ArrayRef<bool> data; |
| 391 | auto type = RankedTensorType::get(data.size(), builder.getI32Type()); |
| 392 | EXPECT_DEBUG_DEATH( |
| 393 | { |
| 394 | DenseBoolResourceElementsAttr::get( |
| 395 | type, "resource" , |
| 396 | UnmanagedAsmResourceBlob::allocateInferAlign(data)); |
| 397 | }, |
| 398 | "invalid shape element type for provided type `T`" ); |
| 399 | } |
| 400 | } // namespace |
| 401 | |
| 402 | //===----------------------------------------------------------------------===// |
| 403 | // SparseElementsAttr |
| 404 | //===----------------------------------------------------------------------===// |
| 405 | |
| 406 | namespace { |
| 407 | TEST(SparseElementsAttrTest, GetZero) { |
| 408 | MLIRContext context; |
| 409 | context.allowUnregisteredDialects(); |
| 410 | |
| 411 | IntegerType intTy = IntegerType::get(&context, 32); |
| 412 | FloatType floatTy = Float32Type::get(&context); |
| 413 | Type stringTy = OpaqueType::get(StringAttr::get(&context, "test" ), "string" ); |
| 414 | |
| 415 | ShapedType tensorI32 = RankedTensorType::get({2, 2}, intTy); |
| 416 | ShapedType tensorF32 = RankedTensorType::get({2, 2}, floatTy); |
| 417 | ShapedType tensorString = RankedTensorType::get({2, 2}, stringTy); |
| 418 | |
| 419 | auto indicesType = |
| 420 | RankedTensorType::get({1, 2}, IntegerType::get(&context, 64)); |
| 421 | auto indices = |
| 422 | DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)}); |
| 423 | |
| 424 | RankedTensorType intValueTy = RankedTensorType::get({1}, intTy); |
| 425 | auto intValue = DenseIntElementsAttr::get(intValueTy, {1}); |
| 426 | |
| 427 | RankedTensorType floatValueTy = RankedTensorType::get({1}, floatTy); |
| 428 | auto floatValue = DenseFPElementsAttr::get(floatValueTy, {1.0f}); |
| 429 | |
| 430 | RankedTensorType stringValueTy = RankedTensorType::get({1}, stringTy); |
| 431 | auto stringValue = DenseElementsAttr::get(stringValueTy, {StringRef("foo" )}); |
| 432 | |
| 433 | auto sparseInt = SparseElementsAttr::get(tensorI32, indices, intValue); |
| 434 | auto sparseFloat = SparseElementsAttr::get(tensorF32, indices, floatValue); |
| 435 | auto sparseString = |
| 436 | SparseElementsAttr::get(tensorString, indices, stringValue); |
| 437 | |
| 438 | // Only index (0, 0) contains an element, others are supposed to return |
| 439 | // the zero/empty value. |
| 440 | auto zeroIntValue = |
| 441 | cast<IntegerAttr>(sparseInt.getValues<Attribute>()[{1, 1}]); |
| 442 | EXPECT_EQ(zeroIntValue.getInt(), 0); |
| 443 | EXPECT_TRUE(zeroIntValue.getType() == intTy); |
| 444 | |
| 445 | auto zeroFloatValue = |
| 446 | cast<FloatAttr>(sparseFloat.getValues<Attribute>()[{1, 1}]); |
| 447 | EXPECT_EQ(zeroFloatValue.getValueAsDouble(), 0.0f); |
| 448 | EXPECT_TRUE(zeroFloatValue.getType() == floatTy); |
| 449 | |
| 450 | auto zeroStringValue = |
| 451 | cast<StringAttr>(sparseString.getValues<Attribute>()[{1, 1}]); |
| 452 | EXPECT_TRUE(zeroStringValue.empty()); |
| 453 | EXPECT_TRUE(zeroStringValue.getType() == stringTy); |
| 454 | } |
| 455 | |
| 456 | //===----------------------------------------------------------------------===// |
| 457 | // SubElements |
| 458 | //===----------------------------------------------------------------------===// |
| 459 | |
| 460 | TEST(SubElementTest, Nested) { |
| 461 | MLIRContext context; |
| 462 | Builder builder(&context); |
| 463 | |
| 464 | BoolAttr trueAttr = builder.getBoolAttr(value: true); |
| 465 | BoolAttr falseAttr = builder.getBoolAttr(value: false); |
| 466 | ArrayAttr boolArrayAttr = |
| 467 | builder.getArrayAttr({trueAttr, falseAttr, trueAttr}); |
| 468 | StringAttr strAttr = builder.getStringAttr("array" ); |
| 469 | DictionaryAttr dictAttr = |
| 470 | builder.getDictionaryAttr(builder.getNamedAttr(strAttr, boolArrayAttr)); |
| 471 | |
| 472 | SmallVector<Attribute> subAttrs; |
| 473 | dictAttr.walk([&](Attribute attr) { subAttrs.push_back(Elt: attr); }); |
| 474 | // Note that trueAttr appears only once, identical subattributes are skipped. |
| 475 | EXPECT_EQ(llvm::ArrayRef(subAttrs), |
| 476 | ArrayRef<Attribute>( |
| 477 | {strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr})); |
| 478 | } |
| 479 | |
| 480 | // Test how many times we call copy-ctor when building an attribute. |
| 481 | TEST(CopyCountAttr, CopyCount) { |
| 482 | MLIRContext context; |
| 483 | context.loadDialect<test::TestDialect>(); |
| 484 | |
| 485 | test::CopyCount::counter = 0; |
| 486 | test::CopyCount copyCount("hello" ); |
| 487 | test::TestCopyCountAttr::get(&context, std::move(copyCount)); |
| 488 | int counter1 = test::CopyCount::counter; |
| 489 | test::CopyCount::counter = 0; |
| 490 | test::TestCopyCountAttr::get(&context, std::move(copyCount)); |
| 491 | #ifndef NDEBUG |
| 492 | // One verification enabled only in assert-mode requires a copy. |
| 493 | EXPECT_EQ(counter1, 1); |
| 494 | EXPECT_EQ(test::CopyCount::counter, 1); |
| 495 | #else |
| 496 | EXPECT_EQ(counter1, 0); |
| 497 | EXPECT_EQ(test::CopyCount::counter, 0); |
| 498 | #endif |
| 499 | } |
| 500 | |
| 501 | // Test stripped printing using test dialect attribute. |
| 502 | TEST(CopyCountAttr, PrintStripped) { |
| 503 | MLIRContext context; |
| 504 | context.loadDialect<test::TestDialect>(); |
| 505 | // Doesn't matter which dialect attribute is used, just chose TestCopyCount |
| 506 | // given proximity. |
| 507 | test::CopyCount::counter = 0; |
| 508 | test::CopyCount copyCount("hello" ); |
| 509 | Attribute res = test::TestCopyCountAttr::get(&context, std::move(copyCount)); |
| 510 | |
| 511 | std::string str; |
| 512 | llvm::raw_string_ostream os(str); |
| 513 | os << "|" << res << "|" ; |
| 514 | res.printStripped(os&: os << "[" ); |
| 515 | os << "]" ; |
| 516 | EXPECT_EQ(str, "|#test.copy_count<hello>|[copy_count<hello>]" ); |
| 517 | } |
| 518 | |
| 519 | } // namespace |
| 520 | |