1//===- RootOrderingTest.cpp - unit tests for optimal branching ------------===//
2//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "../lib/Conversion/PDLToPDLInterp/RootOrdering.h"
10#include "mlir/Dialect/Arith/IR/Arith.h"
11#include "mlir/IR/Builders.h"
12#include "mlir/IR/MLIRContext.h"
13#include "gtest/gtest.h"
14
15using namespace mlir;
16using namespace mlir::arith;
17using namespace mlir::pdl_to_pdl_interp;
18
19namespace {
20
21//===----------------------------------------------------------------------===//
22// Test Fixture
23//===----------------------------------------------------------------------===//
24
25/// The test fixture for constructing root ordering tests and verifying results.
26/// This fixture constructs the test values v. The test populates the graph
27/// with the desired costs and then calls check(), passing the expected optimal
28/// cost and the list of edges in the preorder traversal of the optimal
29/// branching.
30class RootOrderingTest : public ::testing::Test {
31protected:
32 RootOrderingTest() {
33 context.loadDialect<ArithDialect>();
34 createValues();
35 }
36
37 /// Creates the test values. These values simply act as vertices / vertex IDs
38 /// in the cost graph, rather than being a part of an IR.
39 void createValues() {
40 OpBuilder builder(&context);
41 builder.setInsertionPointToStart(&block);
42 for (int i = 0; i < 4; ++i)
43 // Ops will be deleted when `block` is destroyed.
44 v[i] = builder.create<ConstantIntOp>(location: builder.getUnknownLoc(), args&: i, args: 32);
45 }
46
47 /// Checks that optimal branching on graph has the given cost and
48 /// its preorder traversal results in the specified edges.
49 void check(unsigned cost, const OptimalBranching::EdgeList &edges) {
50 OptimalBranching opt(graph, v[0]);
51 EXPECT_EQ(opt.solve(), cost);
52 EXPECT_EQ(opt.preOrderTraversal({v, v + edges.size()}), edges);
53 for (std::pair<Value, Value> edge : edges)
54 EXPECT_EQ(opt.getRootOrderingParents().lookup(edge.first), edge.second);
55 }
56
57protected:
58 /// The context for creating the values.
59 MLIRContext context;
60
61 /// Block holding all the operations.
62 Block block;
63
64 /// Values used in the graph definition. We always use leading `n` values.
65 Value v[4];
66
67 /// The graph being tested on.
68 RootOrderingGraph graph;
69};
70
71//===----------------------------------------------------------------------===//
72// Simple 3-node graphs
73//===----------------------------------------------------------------------===//
74
75TEST_F(RootOrderingTest, simpleA) {
76 graph[v[1]][v[0]].cost = {1, 10};
77 graph[v[2]][v[0]].cost = {1, 11};
78 graph[v[1]][v[2]].cost = {2, 12};
79 graph[v[2]][v[1]].cost = {2, 13};
80 check(cost: 2, edges: {{v[0], {}}, {v[1], v[0]}, {v[2], v[0]}});
81}
82
83TEST_F(RootOrderingTest, simpleB) {
84 graph[v[1]][v[0]].cost = {1, 10};
85 graph[v[2]][v[0]].cost = {2, 11};
86 graph[v[1]][v[2]].cost = {1, 12};
87 graph[v[2]][v[1]].cost = {1, 13};
88 check(cost: 2, edges: {{v[0], {}}, {v[1], v[0]}, {v[2], v[1]}});
89}
90
91TEST_F(RootOrderingTest, simpleC) {
92 graph[v[1]][v[0]].cost = {2, 10};
93 graph[v[2]][v[0]].cost = {2, 11};
94 graph[v[1]][v[2]].cost = {1, 12};
95 graph[v[2]][v[1]].cost = {1, 13};
96 check(cost: 3, edges: {{v[0], {}}, {v[1], v[0]}, {v[2], v[1]}});
97}
98
99//===----------------------------------------------------------------------===//
100// Graph for testing contraction
101//===----------------------------------------------------------------------===//
102
103TEST_F(RootOrderingTest, contraction) {
104 graph[v[1]][v[0]].cost = {10, 0};
105 graph[v[2]][v[0]].cost = {5, 0};
106 graph[v[2]][v[1]].cost = {1, 0};
107 graph[v[3]][v[2]].cost = {2, 0};
108 graph[v[1]][v[3]].cost = {3, 0};
109 check(cost: 10, edges: {{v[0], {}}, {v[2], v[0]}, {v[3], v[2]}, {v[1], v[3]}});
110}
111
112} // namespace
113

// Source: mlir/unittests/Conversion/PDLToPDLInterp/RootOrderingTest.cpp