1 | //===- AttrTypeReplacerTest.cpp - Sub-element replacer unit tests ---------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/IR/AttrTypeSubElements.h" |
10 | #include "mlir/IR/Builders.h" |
11 | #include "mlir/IR/BuiltinAttributes.h" |
12 | #include "mlir/IR/BuiltinTypes.h" |
13 | #include "gtest/gtest.h" |
14 | |
15 | using namespace mlir; |
16 | |
17 | //===----------------------------------------------------------------------===// |
18 | // CyclicAttrTypeReplacer |
19 | //===----------------------------------------------------------------------===// |
20 | |
21 | TEST(CyclicAttrTypeReplacerTest, testNoRecursion) { |
22 | MLIRContext ctx; |
23 | |
24 | CyclicAttrTypeReplacer replacer; |
25 | replacer.addReplacement(callback: [&](BoolAttr b) { |
26 | return StringAttr::get(&ctx, b.getValue() ? "true" : "false" ); |
27 | }); |
28 | |
29 | EXPECT_EQ(replacer.replace(BoolAttr::get(&ctx, true)), |
30 | StringAttr::get(&ctx, "true" )); |
31 | EXPECT_EQ(replacer.replace(BoolAttr::get(&ctx, false)), |
32 | StringAttr::get(&ctx, "false" )); |
33 | EXPECT_EQ(replacer.replace(mlir::UnitAttr::get(&ctx)), |
34 | mlir::UnitAttr::get(&ctx)); |
35 | } |
36 | |
37 | TEST(CyclicAttrTypeReplacerTest, testInPlaceRecursionPruneAnywhere) { |
38 | MLIRContext ctx; |
39 | Builder b(&ctx); |
40 | |
41 | CyclicAttrTypeReplacer replacer; |
42 | // Replacer cycles through integer attrs 0 -> 1 -> 2 -> 0 -> ... |
43 | replacer.addReplacement([&](IntegerAttr attr) { |
44 | return replacer.replace(b.getI8IntegerAttr((attr.getInt() + 1) % 3)); |
45 | }); |
46 | // The first repeat of any integer attr is pruned into a unit attr. |
47 | replacer.addCycleBreaker([&](IntegerAttr attr) { return b.getUnitAttr(); }); |
48 | |
49 | // No recursion case. |
50 | EXPECT_EQ(replacer.replace(mlir::UnitAttr::get(&ctx)), |
51 | mlir::UnitAttr::get(&ctx)); |
52 | // Starting at 0. |
53 | EXPECT_EQ(replacer.replace(b.getI8IntegerAttr(0)), mlir::UnitAttr::get(&ctx)); |
54 | // Starting at 2. |
55 | EXPECT_EQ(replacer.replace(b.getI8IntegerAttr(2)), mlir::UnitAttr::get(&ctx)); |
56 | } |
57 | |
58 | //===----------------------------------------------------------------------===// |
59 | // CyclicAttrTypeReplacerTest: ChainRecursion |
60 | //===----------------------------------------------------------------------===// |
61 | |
62 | class CyclicAttrTypeReplacerChainRecursionPruningTest : public ::testing::Test { |
63 | public: |
64 | CyclicAttrTypeReplacerChainRecursionPruningTest() : b(&ctx) { |
65 | // IntegerType<width = N> |
66 | // ==> FunctionType<() => IntegerType< width = (N+1) % 3>>. |
67 | // This will create a chain of infinite length without recursion pruning. |
68 | replacer.addReplacement([&](mlir::IntegerType intType) { |
69 | ++invokeCount; |
70 | return b.getFunctionType( |
71 | {}, {mlir::IntegerType::get(&ctx, (intType.getWidth() + 1) % 3)}); |
72 | }); |
73 | } |
74 | |
75 | void setBaseCase(std::optional<unsigned> pruneAt) { |
76 | replacer.addCycleBreaker([&, pruneAt](mlir::IntegerType intType) { |
77 | return (!pruneAt || intType.getWidth() == *pruneAt) |
78 | ? std::make_optional(b.getIndexType()) |
79 | : std::nullopt; |
80 | }); |
81 | } |
82 | |
83 | Type getFunctionTypeChain(unsigned N) { |
84 | Type type = b.getIndexType(); |
85 | for (unsigned i = 0; i < N; i++) |
86 | type = b.getFunctionType({}, type); |
87 | return type; |
88 | }; |
89 | |
90 | MLIRContext ctx; |
91 | Builder b; |
92 | CyclicAttrTypeReplacer replacer; |
93 | int invokeCount = 0; |
94 | }; |
95 | |
96 | TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneAnywhere0) { |
97 | setBaseCase(std::nullopt); |
98 | |
99 | // No recursion case. |
100 | EXPECT_EQ(replacer.replace(b.getIndexType()), b.getIndexType()); |
101 | EXPECT_EQ(invokeCount, 0); |
102 | |
103 | // Starting at 0. Cycle length is 3. |
104 | invokeCount = 0; |
105 | EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)), |
106 | getFunctionTypeChain(3)); |
107 | EXPECT_EQ(invokeCount, 3); |
108 | |
109 | // Starting at 1. Cycle length is 5 now because of a cached replacement at 0. |
110 | invokeCount = 0; |
111 | EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), |
112 | getFunctionTypeChain(5)); |
113 | EXPECT_EQ(invokeCount, 2); |
114 | } |
115 | |
116 | TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneAnywhere1) { |
117 | setBaseCase(std::nullopt); |
118 | |
119 | // Starting at 1. Cycle length is 3. |
120 | EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), |
121 | getFunctionTypeChain(3)); |
122 | EXPECT_EQ(invokeCount, 3); |
123 | } |
124 | |
125 | TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneSpecific0) { |
126 | setBaseCase(0); |
127 | |
128 | // Starting at 0. Cycle length is 3. |
129 | EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)), |
130 | getFunctionTypeChain(3)); |
131 | EXPECT_EQ(invokeCount, 3); |
132 | } |
133 | |
134 | TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneSpecific1) { |
135 | setBaseCase(0); |
136 | |
137 | // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune). |
138 | EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), |
139 | getFunctionTypeChain(5)); |
140 | EXPECT_EQ(invokeCount, 5); |
141 | } |
142 | |
143 | //===----------------------------------------------------------------------===// |
// CyclicAttrTypeReplacerTest: BranchingRecursion
145 | //===----------------------------------------------------------------------===// |
146 | |
147 | class CyclicAttrTypeReplacerBranchingRecusionPruningTest |
148 | : public ::testing::Test { |
149 | public: |
150 | CyclicAttrTypeReplacerBranchingRecusionPruningTest() : b(&ctx) { |
151 | // IntegerType<width = N> |
152 | // ==> FunctionType< |
153 | // IntegerType< width = (N+1) % 3> => |
154 | // IntegerType< width = (N+1) % 3>>. |
155 | // This will create a binary tree of infinite depth without pruning. |
156 | replacer.addReplacement([&](mlir::IntegerType intType) { |
157 | ++invokeCount; |
158 | Type child = mlir::IntegerType::get(&ctx, (intType.getWidth() + 1) % 3); |
159 | return b.getFunctionType({child}, {child}); |
160 | }); |
161 | } |
162 | |
163 | void setBaseCase(std::optional<unsigned> pruneAt) { |
164 | replacer.addCycleBreaker([&, pruneAt](mlir::IntegerType intType) { |
165 | return (!pruneAt || intType.getWidth() == *pruneAt) |
166 | ? std::make_optional(b.getIndexType()) |
167 | : std::nullopt; |
168 | }); |
169 | } |
170 | |
171 | Type getFunctionTypeTree(unsigned N) { |
172 | Type type = b.getIndexType(); |
173 | for (unsigned i = 0; i < N; i++) |
174 | type = b.getFunctionType(type, type); |
175 | return type; |
176 | }; |
177 | |
178 | MLIRContext ctx; |
179 | Builder b; |
180 | CyclicAttrTypeReplacer replacer; |
181 | int invokeCount = 0; |
182 | }; |
183 | |
184 | TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneAnywhere0) { |
185 | setBaseCase(std::nullopt); |
186 | |
187 | // No recursion case. |
188 | EXPECT_EQ(replacer.replace(b.getIndexType()), b.getIndexType()); |
189 | EXPECT_EQ(invokeCount, 0); |
190 | |
191 | // Starting at 0. Cycle length is 3. |
192 | invokeCount = 0; |
193 | EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)), |
194 | getFunctionTypeTree(3)); |
195 | // Since both branches are identical, this should incur linear invocations |
196 | // of the replacement function instead of exponential. |
197 | EXPECT_EQ(invokeCount, 3); |
198 | |
199 | // Starting at 1. Cycle length is 5 now because of a cached replacement at 0. |
200 | invokeCount = 0; |
201 | EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), |
202 | getFunctionTypeTree(5)); |
203 | EXPECT_EQ(invokeCount, 2); |
204 | } |
205 | |
206 | TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneAnywhere1) { |
207 | setBaseCase(std::nullopt); |
208 | |
209 | // Starting at 1. Cycle length is 3. |
210 | EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), |
211 | getFunctionTypeTree(3)); |
212 | EXPECT_EQ(invokeCount, 3); |
213 | } |
214 | |
215 | TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneSpecific0) { |
216 | setBaseCase(0); |
217 | |
218 | // Starting at 0. Cycle length is 3. |
219 | EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)), |
220 | getFunctionTypeTree(3)); |
221 | EXPECT_EQ(invokeCount, 3); |
222 | } |
223 | |
224 | TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneSpecific1) { |
225 | setBaseCase(0); |
226 | |
227 | // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune). |
228 | EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)), |
229 | getFunctionTypeTree(5)); |
230 | EXPECT_EQ(invokeCount, 5); |
231 | } |
232 | |