| 1 | //===- AttrTypeReplacerTest.cpp - Sub-element replacer unit tests ---------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/IR/AttrTypeSubElements.h" |
| 10 | #include "mlir/IR/Builders.h" |
| 11 | #include "mlir/IR/BuiltinAttributes.h" |
| 12 | #include "mlir/IR/BuiltinTypes.h" |
| 13 | #include "gtest/gtest.h" |
| 14 | |
| 15 | using namespace mlir; |
| 16 | |
| 17 | //===----------------------------------------------------------------------===// |
| 18 | // CyclicAttrTypeReplacer |
| 19 | //===----------------------------------------------------------------------===// |
| 20 | |
| 21 | TEST(CyclicAttrTypeReplacerTest, testNoRecursion) { |
| 22 | MLIRContext ctx; |
| 23 | |
| 24 | CyclicAttrTypeReplacer replacer; |
| 25 | replacer.addReplacement(callback: [&](BoolAttr b) { |
| 26 | return StringAttr::get(&ctx, b.getValue() ? "true" : "false" ); |
| 27 | }); |
| 28 | |
| 29 | EXPECT_EQ(replacer.replace(BoolAttr::get(&ctx, true)), |
| 30 | StringAttr::get(&ctx, "true" )); |
| 31 | EXPECT_EQ(replacer.replace(BoolAttr::get(&ctx, false)), |
| 32 | StringAttr::get(&ctx, "false" )); |
| 33 | EXPECT_EQ(replacer.replace(mlir::UnitAttr::get(&ctx)), |
| 34 | mlir::UnitAttr::get(&ctx)); |
| 35 | } |
| 36 | |
| 37 | TEST(CyclicAttrTypeReplacerTest, testInPlaceRecursionPruneAnywhere) { |
| 38 | MLIRContext ctx; |
| 39 | Builder b(&ctx); |
| 40 | |
| 41 | CyclicAttrTypeReplacer replacer; |
| 42 | // Replacer cycles through integer attrs 0 -> 1 -> 2 -> 0 -> ... |
| 43 | replacer.addReplacement([&](IntegerAttr attr) { |
| 44 | return replacer.replace(b.getI8IntegerAttr((attr.getInt() + 1) % 3)); |
| 45 | }); |
| 46 | // The first repeat of any integer attr is pruned into a unit attr. |
| 47 | replacer.addCycleBreaker([&](IntegerAttr attr) { return b.getUnitAttr(); }); |
| 48 | |
| 49 | // No recursion case. |
| 50 | EXPECT_EQ(replacer.replace(mlir::UnitAttr::get(&ctx)), |
| 51 | mlir::UnitAttr::get(&ctx)); |
| 52 | // Starting at 0. |
| 53 | EXPECT_EQ(replacer.replace(b.getI8IntegerAttr(0)), mlir::UnitAttr::get(&ctx)); |
| 54 | // Starting at 2. |
| 55 | EXPECT_EQ(replacer.replace(b.getI8IntegerAttr(2)), mlir::UnitAttr::get(&ctx)); |
| 56 | } |
| 57 | |
| 58 | //===----------------------------------------------------------------------===// |
| 59 | // CyclicAttrTypeReplacerTest: ChainRecursion |
| 60 | //===----------------------------------------------------------------------===// |
| 61 | |
| 62 | class CyclicAttrTypeReplacerChainRecursionPruningTest : public ::testing::Test { |
| 63 | public: |
| 64 | CyclicAttrTypeReplacerChainRecursionPruningTest() : b(&ctx) { |
| 65 | // IntegerType<width = N> |
| 66 | // ==> FunctionType<() => IntegerType< width = (N+1) % 3>>. |
| 67 | // This will create a chain of infinite length without recursion pruning. |
| 68 | replacer.addReplacement([&](mlir::IntegerType intType) { |
| 69 | ++invokeCount; |
| 70 | return b.getFunctionType( |
| 71 | {}, {mlir::IntegerType::get(&ctx, (intType.getWidth() + 1) % 3)}); |
| 72 | }); |
| 73 | } |
| 74 | |
| 75 | void setBaseCase(std::optional<unsigned> pruneAt) { |
| 76 | replacer.addCycleBreaker([&, pruneAt](mlir::IntegerType intType) { |
| 77 | return (!pruneAt || intType.getWidth() == *pruneAt) |
| 78 | ? std::make_optional(b.getIndexType()) |
| 79 | : std::nullopt; |
| 80 | }); |
| 81 | } |
| 82 | |
| 83 | Type getFunctionTypeChain(unsigned N) { |
| 84 | Type type = b.getIndexType(); |
| 85 | for (unsigned i = 0; i < N; i++) |
| 86 | type = b.getFunctionType({}, type); |
| 87 | return type; |
| 88 | }; |
| 89 | |
| 90 | MLIRContext ctx; |
| 91 | Builder b; |
| 92 | CyclicAttrTypeReplacer replacer; |
| 93 | int invokeCount = 0; |
| 94 | }; |
| 95 | |
| 96 | TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneAnywhere0) { |
| 97 | setBaseCase(std::nullopt); |
| 98 | |
| 99 | // No recursion case. |
| 100 | EXPECT_EQ(replacer.replace(b.getIndexType()), b.getIndexType()); |
| 101 | EXPECT_EQ(invokeCount, 0); |
| 102 | |
| 103 | // Starting at 0. Cycle length is 3. |
| 104 | invokeCount = 0; |
| 105 | EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)), |
| 106 | getFunctionTypeChain(3)); |
| 107 | EXPECT_EQ(invokeCount, 3); |
| 108 | |
| 109 | // Starting at 1. Cycle length is 5 now because of a cached replacement at 0. |
| 110 | invokeCount = 0; |
| 111 | EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), |
| 112 | getFunctionTypeChain(5)); |
| 113 | EXPECT_EQ(invokeCount, 2); |
| 114 | } |
| 115 | |
| 116 | TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneAnywhere1) { |
| 117 | setBaseCase(std::nullopt); |
| 118 | |
| 119 | // Starting at 1. Cycle length is 3. |
| 120 | EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), |
| 121 | getFunctionTypeChain(3)); |
| 122 | EXPECT_EQ(invokeCount, 3); |
| 123 | } |
| 124 | |
| 125 | TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneSpecific0) { |
| 126 | setBaseCase(0); |
| 127 | |
| 128 | // Starting at 0. Cycle length is 3. |
| 129 | EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)), |
| 130 | getFunctionTypeChain(3)); |
| 131 | EXPECT_EQ(invokeCount, 3); |
| 132 | } |
| 133 | |
| 134 | TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneSpecific1) { |
| 135 | setBaseCase(0); |
| 136 | |
| 137 | // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune). |
| 138 | EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), |
| 139 | getFunctionTypeChain(5)); |
| 140 | EXPECT_EQ(invokeCount, 5); |
| 141 | } |
| 142 | |
| 143 | //===----------------------------------------------------------------------===// |
// CyclicAttrTypeReplacerTest: BranchingRecursion
| 145 | //===----------------------------------------------------------------------===// |
| 146 | |
| 147 | class CyclicAttrTypeReplacerBranchingRecusionPruningTest |
| 148 | : public ::testing::Test { |
| 149 | public: |
| 150 | CyclicAttrTypeReplacerBranchingRecusionPruningTest() : b(&ctx) { |
| 151 | // IntegerType<width = N> |
| 152 | // ==> FunctionType< |
| 153 | // IntegerType< width = (N+1) % 3> => |
| 154 | // IntegerType< width = (N+1) % 3>>. |
| 155 | // This will create a binary tree of infinite depth without pruning. |
| 156 | replacer.addReplacement([&](mlir::IntegerType intType) { |
| 157 | ++invokeCount; |
| 158 | Type child = mlir::IntegerType::get(&ctx, (intType.getWidth() + 1) % 3); |
| 159 | return b.getFunctionType({child}, {child}); |
| 160 | }); |
| 161 | } |
| 162 | |
| 163 | void setBaseCase(std::optional<unsigned> pruneAt) { |
| 164 | replacer.addCycleBreaker([&, pruneAt](mlir::IntegerType intType) { |
| 165 | return (!pruneAt || intType.getWidth() == *pruneAt) |
| 166 | ? std::make_optional(b.getIndexType()) |
| 167 | : std::nullopt; |
| 168 | }); |
| 169 | } |
| 170 | |
| 171 | Type getFunctionTypeTree(unsigned N) { |
| 172 | Type type = b.getIndexType(); |
| 173 | for (unsigned i = 0; i < N; i++) |
| 174 | type = b.getFunctionType(type, type); |
| 175 | return type; |
| 176 | }; |
| 177 | |
| 178 | MLIRContext ctx; |
| 179 | Builder b; |
| 180 | CyclicAttrTypeReplacer replacer; |
| 181 | int invokeCount = 0; |
| 182 | }; |
| 183 | |
| 184 | TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneAnywhere0) { |
| 185 | setBaseCase(std::nullopt); |
| 186 | |
| 187 | // No recursion case. |
| 188 | EXPECT_EQ(replacer.replace(b.getIndexType()), b.getIndexType()); |
| 189 | EXPECT_EQ(invokeCount, 0); |
| 190 | |
| 191 | // Starting at 0. Cycle length is 3. |
| 192 | invokeCount = 0; |
| 193 | EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)), |
| 194 | getFunctionTypeTree(3)); |
| 195 | // Since both branches are identical, this should incur linear invocations |
| 196 | // of the replacement function instead of exponential. |
| 197 | EXPECT_EQ(invokeCount, 3); |
| 198 | |
| 199 | // Starting at 1. Cycle length is 5 now because of a cached replacement at 0. |
| 200 | invokeCount = 0; |
| 201 | EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), |
| 202 | getFunctionTypeTree(5)); |
| 203 | EXPECT_EQ(invokeCount, 2); |
| 204 | } |
| 205 | |
| 206 | TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneAnywhere1) { |
| 207 | setBaseCase(std::nullopt); |
| 208 | |
| 209 | // Starting at 1. Cycle length is 3. |
| 210 | EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), |
| 211 | getFunctionTypeTree(3)); |
| 212 | EXPECT_EQ(invokeCount, 3); |
| 213 | } |
| 214 | |
| 215 | TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneSpecific0) { |
| 216 | setBaseCase(0); |
| 217 | |
| 218 | // Starting at 0. Cycle length is 3. |
| 219 | EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)), |
| 220 | getFunctionTypeTree(3)); |
| 221 | EXPECT_EQ(invokeCount, 3); |
| 222 | } |
| 223 | |
| 224 | TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneSpecific1) { |
| 225 | setBaseCase(0); |
| 226 | |
| 227 | // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune). |
| 228 | EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), |
| 229 | getFunctionTypeTree(5)); |
| 230 | EXPECT_EQ(invokeCount, 5); |
| 231 | } |
| 232 | |