1//===- CyclicReplacerCacheTest.cpp ----------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Support/CyclicReplacerCache.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SetVector.h"
#include "gmock/gmock.h"
#include <functional>
#include <map>
#include <optional>
#include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
15
16using namespace mlir;
17
18TEST(CachedCyclicReplacerTest, testNoRecursion) {
19 CachedCyclicReplacer<int, bool> replacer(
20 /*replacer=*/[](int n) { return static_cast<bool>(n); },
21 /*cycleBreaker=*/[](int n) { return std::nullopt; });
22
23 EXPECT_EQ(replacer(3), true);
24 EXPECT_EQ(replacer(0), false);
25}
26
27TEST(CachedCyclicReplacerTest, testInPlaceRecursionPruneAnywhere) {
28 // Replacer cycles through ints 0 -> 1 -> 2 -> 0 -> ...
29 std::optional<CachedCyclicReplacer<int, int>> replacer;
30 replacer.emplace(
31 /*replacer=*/args: [&](int n) { return (*replacer)((n + 1) % 3); },
32 /*cycleBreaker=*/args: [&](int n) { return -1; });
33
34 // Starting at 0.
35 EXPECT_EQ((*replacer)(0), -1);
36 // Starting at 2.
37 EXPECT_EQ((*replacer)(2), -1);
38}
39
40//===----------------------------------------------------------------------===//
41// CachedCyclicReplacer: ChainRecursion
42//===----------------------------------------------------------------------===//
43
44/// This set of tests uses a replacer function that maps ints into vectors of
45/// ints.
46///
/// The replacement result for input `n` is the replacement result of `(n+1)%3`
/// appended with an element `42`. Theoretically, this will produce an
/// infinitely long vector. The cycle-breaker function prunes this infinite
/// recursion in the replacer logic by returning an empty vector upon the first
/// re-occurrence of an input value, or, when `baseCase` is set, only upon the
/// first re-occurrence of that specific value.
52namespace {
53class CachedCyclicReplacerChainRecursionPruningTest : public ::testing::Test {
54public:
55 // N ==> (N+1) % 3
56 // This will create a chain of infinite length without recursion pruning.
57 CachedCyclicReplacerChainRecursionPruningTest()
58 : replacer(
59 [&](int n) {
60 ++invokeCount;
61 std::vector<int> result = replacer((n + 1) % 3);
62 result.push_back(x: 42);
63 return result;
64 },
65 [&](int n) -> std::optional<std::vector<int>> {
66 return baseCase.value_or(u&: n) == n
67 ? std::make_optional(t: std::vector<int>{})
68 : std::nullopt;
69 }) {}
70
71 std::vector<int> getChain(unsigned N) { return std::vector<int>(N, 42); };
72
73 CachedCyclicReplacer<int, std::vector<int>> replacer;
74 int invokeCount = 0;
75 std::optional<int> baseCase = std::nullopt;
76};
77} // namespace
78
79TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere0) {
80 // Starting at 0. Cycle length is 3.
81 EXPECT_EQ(replacer(0), getChain(3));
82 EXPECT_EQ(invokeCount, 3);
83
84 // Starting at 1. Cycle length is 5 now because of a cached replacement at 0.
85 invokeCount = 0;
86 EXPECT_EQ(replacer(1), getChain(5));
87 EXPECT_EQ(invokeCount, 2);
88
89 // Starting at 2. Cycle length is 4. Entire result is cached.
90 invokeCount = 0;
91 EXPECT_EQ(replacer(2), getChain(4));
92 EXPECT_EQ(invokeCount, 0);
93}
94
95TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneAnywhere1) {
96 // Starting at 1. Cycle length is 3.
97 EXPECT_EQ(replacer(1), getChain(3));
98 EXPECT_EQ(invokeCount, 3);
99}
100
101TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific0) {
102 baseCase = 0;
103
104 // Starting at 0. Cycle length is 3.
105 EXPECT_EQ(replacer(0), getChain(3));
106 EXPECT_EQ(invokeCount, 3);
107}
108
109TEST_F(CachedCyclicReplacerChainRecursionPruningTest, testPruneSpecific1) {
110 baseCase = 0;
111
112 // Starting at 1. Cycle length is 5 (1 -> 2 -> 0 -> 1 -> 2 -> Prune).
113 EXPECT_EQ(replacer(1), getChain(5));
114 EXPECT_EQ(invokeCount, 5);
115
116 // Starting at 0. Cycle length is 3. Entire result is cached.
117 invokeCount = 0;
118 EXPECT_EQ(replacer(0), getChain(3));
119 EXPECT_EQ(invokeCount, 0);
120}
121
122//===----------------------------------------------------------------------===//
123// CachedCyclicReplacer: GraphReplacement
124//===----------------------------------------------------------------------===//
125
126/// This set of tests uses a replacer function that maps from cyclic graphs to
127/// trees, pruning out cycles in the process.
128///
/// It consists of two helper classes:
///  - Graph
///    - A directed graph where nodes are non-negative integers.
///  - PrunedGraph
///    - A Graph where edges that used to cause cycles are now represented with
///      an indirection (a recursionId).
135namespace {
136class CachedCyclicReplacerGraphReplacement : public ::testing::Test {
137public:
138 /// A directed graph where nodes are non-negative integers.
139 struct Graph {
140 using Node = int64_t;
141
142 /// Use ordered containers for deterministic output.
143 /// Nodes without outgoing edges are considered nonexistent.
144 std::map<Node, std::set<Node>> edges;
145
146 void addEdge(Node src, Node sink) { edges[src].insert(x: sink); }
147
148 bool isCyclic() const {
149 DenseSet<Node> visited;
150 for (Node root : llvm::make_first_range(c: edges)) {
151 if (visited.contains(V: root))
152 continue;
153
154 SetVector<Node> path;
155 SmallVector<Node> workstack;
156 workstack.push_back(Elt: root);
157 while (!workstack.empty()) {
158 Node curr = workstack.back();
159 workstack.pop_back();
160
161 if (curr < 0) {
162 // A negative node signals the end of processing all of this node's
163 // children. Remove self from path.
164 assert(path.back() == -curr && "internal inconsistency");
165 path.pop_back();
166 continue;
167 }
168
169 if (path.contains(key: curr))
170 return true;
171
172 visited.insert(V: curr);
173 auto edgesIter = edges.find(x: curr);
174 if (edgesIter == edges.end() || edgesIter->second.empty())
175 continue;
176
177 path.insert(X: curr);
178 // Push negative node to signify recursion return.
179 workstack.push_back(Elt: -curr);
180 workstack.insert(I: workstack.end(), From: edgesIter->second.begin(),
181 To: edgesIter->second.end());
182 }
183 }
184 return false;
185 }
186
187 /// Deterministic output for testing.
188 std::string serialize() const {
189 std::ostringstream oss;
190 for (const auto &[src, neighbors] : edges) {
191 oss << src << ":";
192 for (Graph::Node neighbor : neighbors)
193 oss << " " << neighbor;
194 oss << "\n";
195 }
196 return oss.str();
197 }
198 };
199
200 /// A Graph where edges that used to cause cycles (back-edges) are now
201 /// represented with an indirection (a recursionId).
202 ///
203 /// In addition to each node having an integer ID, each node also tracks the
204 /// original integer ID it had in the original graph. This way for every
205 /// back-edge, we can represent it as pointing to a new instance of the
206 /// original node. Then we mark the original node and the new instance with
207 /// a new unique recursionId to indicate that they're supposed to be the same
208 /// node.
209 struct PrunedGraph {
210 using Node = Graph::Node;
211 struct NodeInfo {
212 Graph::Node originalId;
213 /// A negative recursive index means not recursive. Otherwise nodes with
214 /// the same originalId & recursionId are the same node in the original
215 /// graph.
216 int64_t recursionId;
217 };
218
219 /// Add a regular non-recursive-self node.
220 Node addNode(Graph::Node originalId, int64_t recursionIndex = -1) {
221 Node id = nextConnectionId++;
222 info[id] = {.originalId: originalId, .recursionId: recursionIndex};
223 return id;
224 }
225 /// Add a recursive-self-node, i.e. a duplicate of the original node that is
226 /// meant to represent an indirection to it.
227 std::pair<Node, int64_t> addRecursiveSelfNode(Graph::Node originalId) {
228 auto node = addNode(originalId, recursionIndex: nextRecursionId);
229 return {node, nextRecursionId++};
230 }
231 void addEdge(Node src, Node sink) { connections.addEdge(src, sink); }
232
233 /// Deterministic output for testing.
234 std::string serialize() const {
235 std::ostringstream oss;
236 oss << "nodes\n";
237 for (const auto &[nodeId, nodeInfo] : info) {
238 oss << nodeId << ": n" << nodeInfo.originalId;
239 if (nodeInfo.recursionId >= 0)
240 oss << '<' << nodeInfo.recursionId << '>';
241 oss << "\n";
242 }
243 oss << "edges\n";
244 oss << connections.serialize();
245 return oss.str();
246 }
247
248 bool isCyclic() const { return connections.isCyclic(); }
249
250 private:
251 Graph connections;
252 int64_t nextRecursionId = 0;
253 int64_t nextConnectionId = 0;
254 /// Use ordered map for deterministic output.
255 std::map<Graph::Node, NodeInfo> info;
256 };
257
258 PrunedGraph breakCycles(const Graph &input) {
259 assert(input.isCyclic() && "input graph is not cyclic");
260
261 PrunedGraph output;
262
263 DenseMap<Graph::Node, int64_t> recMap;
264 auto cycleBreaker = [&](Graph::Node inNode) -> std::optional<Graph::Node> {
265 auto [node, recId] = output.addRecursiveSelfNode(originalId: inNode);
266 recMap[inNode] = recId;
267 return node;
268 };
269
270 CyclicReplacerCache<Graph::Node, Graph::Node> cache(cycleBreaker);
271
272 std::function<Graph::Node(Graph::Node)> replaceNode =
273 [&](Graph::Node inNode) {
274 auto cacheEntry = cache.lookupOrInit(element: inNode);
275 if (std::optional<Graph::Node> result = cacheEntry.get())
276 return *result;
277
278 // Recursively replace its neighbors.
279 SmallVector<Graph::Node> neighbors;
280 if (auto it = input.edges.find(x: inNode); it != input.edges.end())
281 neighbors = SmallVector<Graph::Node>(
282 llvm::map_range(C: it->second, F: replaceNode));
283
284 // Create a new node in the output graph.
285 int64_t recursionIndex =
286 cacheEntry.wasRepeated() ? recMap.lookup(Val: inNode) : -1;
287 Graph::Node result = output.addNode(originalId: inNode, recursionIndex);
288
289 for (Graph::Node neighbor : neighbors)
290 output.addEdge(src: result, sink: neighbor);
291
292 cacheEntry.resolve(result);
293 return result;
294 };
295
296 /// Translate starting from each node.
297 for (Graph::Node root : llvm::make_first_range(c: input.edges))
298 replaceNode(root);
299
300 return output;
301 }
302
303 /// Helper for serialization tests that allow putting comments in the
304 /// serialized format. Every line that begins with a `;` is considered a
305 /// comment. The entire line, incl. the terminating `\n` is removed.
306 std::string trimComments(StringRef input) {
307 std::ostringstream oss;
308 bool isNewLine = false;
309 bool isComment = false;
310 for (char c : input) {
311 // Lines beginning with ';' are comments.
312 if (isNewLine && c == ';')
313 isComment = true;
314
315 if (!isComment)
316 oss << c;
317
318 if (c == '\n') {
319 isNewLine = true;
320 isComment = false;
321 }
322 }
323 return oss.str();
324 }
325};
326} // namespace
327
328TEST_F(CachedCyclicReplacerGraphReplacement, testSingleLoop) {
329 // 0 -> 1 -> 2
330 // ^ |
331 // +---------+
332 Graph input = {.edges: {{0, {1}}, {1, {2}}, {2, {0}}}};
333 PrunedGraph output = breakCycles(input);
334 ASSERT_FALSE(output.isCyclic()) << output.serialize();
335 EXPECT_EQ(output.serialize(), trimComments(R"(nodes
336; root 0
3370: n0<0>
3381: n2
3392: n1
3403: n0<0>
341; root 1
3424: n2
343; root 2
3445: n1
345edges
3461: 0
3472: 1
3483: 2
3494: 3
3505: 4
351)"));
352}
353
354TEST_F(CachedCyclicReplacerGraphReplacement, testDualLoop) {
355 // +----> 1 -----+
356 // | v
357 // 0 <---------- 3
358 // | ^
359 // +----> 2 -----+
360 //
361 // Two loops:
362 // 0 -> 1 -> 3 -> 0
363 // 0 -> 2 -> 3 -> 0
364 Graph input = {.edges: {{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0}}}};
365 PrunedGraph output = breakCycles(input);
366 ASSERT_FALSE(output.isCyclic()) << output.serialize();
367 EXPECT_EQ(output.serialize(), trimComments(R"(nodes
368; root 0
3690: n0<0>
3701: n3
3712: n1
3723: n2
3734: n0<0>
374; root 1
3755: n3
3766: n1
377; root 2
3787: n2
379edges
3801: 0
3812: 1
3823: 1
3834: 2 3
3845: 4
3856: 5
3867: 5
387)"));
388}
389
390TEST_F(CachedCyclicReplacerGraphReplacement, testNestedLoops) {
391 // +----> 1 -----+
392 // | ^ v
393 // 0 <----+----- 2
394 //
395 // Two nested loops:
396 // 0 -> 1 -> 2 -> 0
397 // 1 -> 2 -> 1
398 Graph input = {.edges: {{0, {1}}, {1, {2}}, {2, {0, 1}}}};
399 PrunedGraph output = breakCycles(input);
400 ASSERT_FALSE(output.isCyclic()) << output.serialize();
401 EXPECT_EQ(output.serialize(), trimComments(R"(nodes
402; root 0
4030: n0<0>
4041: n1<1>
4052: n2
4063: n1<1>
4074: n0<0>
408; root 1
4095: n1<2>
4106: n2
4117: n1<2>
412; root 2
4138: n2
414edges
4152: 0 1
4163: 2
4174: 3
4186: 4 5
4197: 6
4208: 4 7
421)"));
422}
423
424TEST_F(CachedCyclicReplacerGraphReplacement, testDualNestedLoops) {
425 // +----> 1 -----+
426 // | ^ v
427 // 0 <----+----- 3
428 // | v ^
429 // +----> 2 -----+
430 //
431 // Two sets of nested loops:
432 // 0 -> 1 -> 3 -> 0
433 // 1 -> 3 -> 1
434 // 0 -> 2 -> 3 -> 0
435 // 2 -> 3 -> 2
436 Graph input = {.edges: {{0, {1, 2}}, {1, {3}}, {2, {3}}, {3, {0, 1, 2}}}};
437 PrunedGraph output = breakCycles(input);
438 ASSERT_FALSE(output.isCyclic()) << output.serialize();
439 EXPECT_EQ(output.serialize(), trimComments(R"(nodes
440; root 0
4410: n0<0>
4421: n1<1>
4432: n3<2>
4443: n2
4454: n3<2>
4465: n1<1>
4476: n2<3>
4487: n3
4498: n2<3>
4509: n0<0>
451; root 1
45210: n1<4>
45311: n3<5>
45412: n2
45513: n3<5>
45614: n1<4>
457; root 2
45815: n2<6>
45916: n3
46017: n2<6>
461; root 3
46218: n3
463edges
464; root 0
4653: 2
4664: 0 1 3
4675: 4
4687: 0 5 6
4698: 7
4709: 5 8
471; root 1
47212: 11
47313: 9 10 12
47414: 13
475; root 2
47616: 9 14 15
47717: 16
478; root 3
47918: 9 14 17
480)"));
481}
482

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/unittests/Support/CyclicReplacerCacheTest.cpp