1 | //===- RootOrderingTest.cpp - unit tests for optimal branching ------------===// |
2 | // |
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 | // Exceptions. See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "../lib/Conversion/PDLToPDLInterp/RootOrdering.h" |
10 | #include "mlir/Dialect/Arith/IR/Arith.h" |
11 | #include "mlir/IR/Builders.h" |
12 | #include "mlir/IR/MLIRContext.h" |
13 | #include "gtest/gtest.h" |
14 | |
15 | using namespace mlir; |
16 | using namespace mlir::arith; |
17 | using namespace mlir::pdl_to_pdl_interp; |
18 | |
19 | namespace { |
20 | |
21 | //===----------------------------------------------------------------------===// |
22 | // Test Fixture |
23 | //===----------------------------------------------------------------------===// |
24 | |
25 | /// The test fixture for constructing root ordering tests and verifying results. |
26 | /// This fixture constructs the test values v. The test populates the graph |
27 | /// with the desired costs and then calls check(), passing the expected optimal |
28 | /// cost and the list of edges in the preorder traversal of the optimal |
29 | /// branching. |
30 | class RootOrderingTest : public ::testing::Test { |
31 | protected: |
32 | RootOrderingTest() { |
33 | context.loadDialect<ArithDialect>(); |
34 | createValues(); |
35 | } |
36 | |
37 | /// Creates the test values. These values simply act as vertices / vertex IDs |
38 | /// in the cost graph, rather than being a part of an IR. |
39 | void createValues() { |
40 | OpBuilder builder(&context); |
41 | builder.setInsertionPointToStart(&block); |
42 | for (int i = 0; i < 4; ++i) |
43 | // Ops will be deleted when `block` is destroyed. |
44 | v[i] = builder.create<ConstantIntOp>(location: builder.getUnknownLoc(), args&: i, args: 32); |
45 | } |
46 | |
47 | /// Checks that optimal branching on graph has the given cost and |
48 | /// its preorder traversal results in the specified edges. |
49 | void check(unsigned cost, const OptimalBranching::EdgeList &edges) { |
50 | OptimalBranching opt(graph, v[0]); |
51 | EXPECT_EQ(opt.solve(), cost); |
52 | EXPECT_EQ(opt.preOrderTraversal({v, v + edges.size()}), edges); |
53 | for (std::pair<Value, Value> edge : edges) |
54 | EXPECT_EQ(opt.getRootOrderingParents().lookup(edge.first), edge.second); |
55 | } |
56 | |
57 | protected: |
58 | /// The context for creating the values. |
59 | MLIRContext context; |
60 | |
61 | /// Block holding all the operations. |
62 | Block block; |
63 | |
64 | /// Values used in the graph definition. We always use leading `n` values. |
65 | Value v[4]; |
66 | |
67 | /// The graph being tested on. |
68 | RootOrderingGraph graph; |
69 | }; |
70 | |
71 | //===----------------------------------------------------------------------===// |
72 | // Simple 3-node graphs |
73 | //===----------------------------------------------------------------------===// |
74 | |
75 | TEST_F(RootOrderingTest, simpleA) { |
76 | graph[v[1]][v[0]].cost = {1, 10}; |
77 | graph[v[2]][v[0]].cost = {1, 11}; |
78 | graph[v[1]][v[2]].cost = {2, 12}; |
79 | graph[v[2]][v[1]].cost = {2, 13}; |
80 | check(cost: 2, edges: {{v[0], {}}, {v[1], v[0]}, {v[2], v[0]}}); |
81 | } |
82 | |
83 | TEST_F(RootOrderingTest, simpleB) { |
84 | graph[v[1]][v[0]].cost = {1, 10}; |
85 | graph[v[2]][v[0]].cost = {2, 11}; |
86 | graph[v[1]][v[2]].cost = {1, 12}; |
87 | graph[v[2]][v[1]].cost = {1, 13}; |
88 | check(cost: 2, edges: {{v[0], {}}, {v[1], v[0]}, {v[2], v[1]}}); |
89 | } |
90 | |
91 | TEST_F(RootOrderingTest, simpleC) { |
92 | graph[v[1]][v[0]].cost = {2, 10}; |
93 | graph[v[2]][v[0]].cost = {2, 11}; |
94 | graph[v[1]][v[2]].cost = {1, 12}; |
95 | graph[v[2]][v[1]].cost = {1, 13}; |
96 | check(cost: 3, edges: {{v[0], {}}, {v[1], v[0]}, {v[2], v[1]}}); |
97 | } |
98 | |
99 | //===----------------------------------------------------------------------===// |
100 | // Graph for testing contraction |
101 | //===----------------------------------------------------------------------===// |
102 | |
103 | TEST_F(RootOrderingTest, contraction) { |
104 | graph[v[1]][v[0]].cost = {10, 0}; |
105 | graph[v[2]][v[0]].cost = {5, 0}; |
106 | graph[v[2]][v[1]].cost = {1, 0}; |
107 | graph[v[3]][v[2]].cost = {2, 0}; |
108 | graph[v[1]][v[3]].cost = {3, 0}; |
109 | check(cost: 10, edges: {{v[0], {}}, {v[2], v[0]}, {v[3], v[2]}, {v[1], v[3]}}); |
110 | } |
111 | |
112 | } // namespace |
113 | |