1//===- AttrTypeReplacerTest.cpp - Sub-element replacer unit tests ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/AttrTypeSubElements.h"
10#include "mlir/IR/Builders.h"
11#include "mlir/IR/BuiltinAttributes.h"
12#include "mlir/IR/BuiltinTypes.h"
13#include "gtest/gtest.h"
14
15using namespace mlir;
16
17//===----------------------------------------------------------------------===//
18// CyclicAttrTypeReplacer
19//===----------------------------------------------------------------------===//
20
21TEST(CyclicAttrTypeReplacerTest, testNoRecursion) {
22 MLIRContext ctx;
23
24 CyclicAttrTypeReplacer replacer;
25 replacer.addReplacement(callback: [&](BoolAttr b) {
26 return StringAttr::get(&ctx, b.getValue() ? "true" : "false");
27 });
28
29 EXPECT_EQ(replacer.replace(BoolAttr::get(&ctx, true)),
30 StringAttr::get(&ctx, "true"));
31 EXPECT_EQ(replacer.replace(BoolAttr::get(&ctx, false)),
32 StringAttr::get(&ctx, "false"));
33 EXPECT_EQ(replacer.replace(mlir::UnitAttr::get(&ctx)),
34 mlir::UnitAttr::get(&ctx));
35}
36
37TEST(CyclicAttrTypeReplacerTest, testInPlaceRecursionPruneAnywhere) {
38 MLIRContext ctx;
39 Builder b(&ctx);
40
41 CyclicAttrTypeReplacer replacer;
42 // Replacer cycles through integer attrs 0 -> 1 -> 2 -> 0 -> ...
43 replacer.addReplacement([&](IntegerAttr attr) {
44 return replacer.replace(b.getI8IntegerAttr((attr.getInt() + 1) % 3));
45 });
46 // The first repeat of any integer attr is pruned into a unit attr.
47 replacer.addCycleBreaker([&](IntegerAttr attr) { return b.getUnitAttr(); });
48
49 // No recursion case.
50 EXPECT_EQ(replacer.replace(mlir::UnitAttr::get(&ctx)),
51 mlir::UnitAttr::get(&ctx));
52 // Starting at 0.
53 EXPECT_EQ(replacer.replace(b.getI8IntegerAttr(0)), mlir::UnitAttr::get(&ctx));
54 // Starting at 2.
55 EXPECT_EQ(replacer.replace(b.getI8IntegerAttr(2)), mlir::UnitAttr::get(&ctx));
56}
57
58//===----------------------------------------------------------------------===//
59// CyclicAttrTypeReplacerTest: ChainRecursion
60//===----------------------------------------------------------------------===//
61
62class CyclicAttrTypeReplacerChainRecursionPruningTest : public ::testing::Test {
63public:
64 CyclicAttrTypeReplacerChainRecursionPruningTest() : b(&ctx) {
65 // IntegerType<width = N>
66 // ==> FunctionType<() => IntegerType< width = (N+1) % 3>>.
67 // This will create a chain of infinite length without recursion pruning.
68 replacer.addReplacement([&](mlir::IntegerType intType) {
69 ++invokeCount;
70 return b.getFunctionType(
71 {}, {mlir::IntegerType::get(&ctx, (intType.getWidth() + 1) % 3)});
72 });
73 }
74
75 void setBaseCase(std::optional<unsigned> pruneAt) {
76 replacer.addCycleBreaker([&, pruneAt](mlir::IntegerType intType) {
77 return (!pruneAt || intType.getWidth() == *pruneAt)
78 ? std::make_optional(b.getIndexType())
79 : std::nullopt;
80 });
81 }
82
83 Type getFunctionTypeChain(unsigned N) {
84 Type type = b.getIndexType();
85 for (unsigned i = 0; i < N; i++)
86 type = b.getFunctionType({}, type);
87 return type;
88 };
89
90 MLIRContext ctx;
91 Builder b;
92 CyclicAttrTypeReplacer replacer;
93 int invokeCount = 0;
94};
95
96TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneAnywhere0) {
97 setBaseCase(std::nullopt);
98
99 // No recursion case.
100 EXPECT_EQ(replacer.replace(b.getIndexType()), b.getIndexType());
101 EXPECT_EQ(invokeCount, 0);
102
103 // Starting at 0. Cycle length is 3.
104 invokeCount = 0;
105 EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)),
106 getFunctionTypeChain(3));
107 EXPECT_EQ(invokeCount, 3);
108
109 // Starting at 1. Cycle length is 5 now because of a cached replacement at 0.
110 invokeCount = 0;
111 EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
112 getFunctionTypeChain(5));
113 EXPECT_EQ(invokeCount, 2);
114}
115
116TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneAnywhere1) {
117 setBaseCase(std::nullopt);
118
119 // Starting at 1. Cycle length is 3.
120 EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
121 getFunctionTypeChain(3));
122 EXPECT_EQ(invokeCount, 3);
123}
124
125TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneSpecific0) {
126 setBaseCase(0);
127
128 // Starting at 0. Cycle length is 3.
129 EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)),
130 getFunctionTypeChain(3));
131 EXPECT_EQ(invokeCount, 3);
132}
133
134TEST_F(CyclicAttrTypeReplacerChainRecursionPruningTest, testPruneSpecific1) {
135 setBaseCase(0);
136
137 // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune).
138 EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
139 getFunctionTypeChain(5));
140 EXPECT_EQ(invokeCount, 5);
141}
142
143//===----------------------------------------------------------------------===//
// CyclicAttrTypeReplacerTest: BranchingRecursion
145//===----------------------------------------------------------------------===//
146
147class CyclicAttrTypeReplacerBranchingRecusionPruningTest
148 : public ::testing::Test {
149public:
150 CyclicAttrTypeReplacerBranchingRecusionPruningTest() : b(&ctx) {
151 // IntegerType<width = N>
152 // ==> FunctionType<
153 // IntegerType< width = (N+1) % 3> =>
154 // IntegerType< width = (N+1) % 3>>.
155 // This will create a binary tree of infinite depth without pruning.
156 replacer.addReplacement([&](mlir::IntegerType intType) {
157 ++invokeCount;
158 Type child = mlir::IntegerType::get(&ctx, (intType.getWidth() + 1) % 3);
159 return b.getFunctionType({child}, {child});
160 });
161 }
162
163 void setBaseCase(std::optional<unsigned> pruneAt) {
164 replacer.addCycleBreaker([&, pruneAt](mlir::IntegerType intType) {
165 return (!pruneAt || intType.getWidth() == *pruneAt)
166 ? std::make_optional(b.getIndexType())
167 : std::nullopt;
168 });
169 }
170
171 Type getFunctionTypeTree(unsigned N) {
172 Type type = b.getIndexType();
173 for (unsigned i = 0; i < N; i++)
174 type = b.getFunctionType(type, type);
175 return type;
176 };
177
178 MLIRContext ctx;
179 Builder b;
180 CyclicAttrTypeReplacer replacer;
181 int invokeCount = 0;
182};
183
184TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneAnywhere0) {
185 setBaseCase(std::nullopt);
186
187 // No recursion case.
188 EXPECT_EQ(replacer.replace(b.getIndexType()), b.getIndexType());
189 EXPECT_EQ(invokeCount, 0);
190
191 // Starting at 0. Cycle length is 3.
192 invokeCount = 0;
193 EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)),
194 getFunctionTypeTree(3));
195 // Since both branches are identical, this should incur linear invocations
196 // of the replacement function instead of exponential.
197 EXPECT_EQ(invokeCount, 3);
198
199 // Starting at 1. Cycle length is 5 now because of a cached replacement at 0.
200 invokeCount = 0;
201 EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
202 getFunctionTypeTree(5));
203 EXPECT_EQ(invokeCount, 2);
204}
205
206TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneAnywhere1) {
207 setBaseCase(std::nullopt);
208
209 // Starting at 1. Cycle length is 3.
210 EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
211 getFunctionTypeTree(3));
212 EXPECT_EQ(invokeCount, 3);
213}
214
215TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneSpecific0) {
216 setBaseCase(0);
217
218 // Starting at 0. Cycle length is 3.
219 EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 0)),
220 getFunctionTypeTree(3));
221 EXPECT_EQ(invokeCount, 3);
222}
223
224TEST_F(CyclicAttrTypeReplacerBranchingRecusionPruningTest, testPruneSpecific1) {
225 setBaseCase(0);
226
227 // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune).
228 EXPECT_EQ(replacer.replace(mlir::IntegerType::get(&ctx, 1)),
229 getFunctionTypeTree(5));
230 EXPECT_EQ(invokeCount, 5);
231}
232

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/unittests/IR/AttrTypeReplacerTest.cpp