1//===- RootOrdering.cpp - Optimal root ordering ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// An implementation of Edmonds' optimal branching algorithm. This is a
10// directed analogue of the minimum spanning tree problem for a given root.
11//
12//===----------------------------------------------------------------------===//
13
14#include "RootOrdering.h"
15
16#include "llvm/ADT/DenseMap.h"
17#include "llvm/ADT/DenseSet.h"
18#include "llvm/ADT/SmallVector.h"
19#include <queue>
20#include <utility>
21
22using namespace mlir;
23using namespace mlir::pdl_to_pdl_interp;
24
25/// Returns the cycle implied by the specified parent relation, starting at the
26/// given node.
27static SmallVector<Value> getCycle(const DenseMap<Value, Value> &parents,
28 Value rep) {
29 SmallVector<Value> cycle;
30 Value node = rep;
31 do {
32 cycle.push_back(Elt: node);
33 node = parents.lookup(Val: node);
34 assert(node && "got an empty value in the cycle");
35 } while (node != rep);
36 return cycle;
37}
38
39/// Contracts the specified cycle in the given graph in-place.
40/// The parentsCost map specifies, for each node in the cycle, the lowest cost
41/// among the edges entering that node. Then, the nodes in the cycle C are
42/// replaced with a single node v_C (the first node in the cycle). All edges
43/// (u, v) entering the cycle, v \in C, are replaced with a single edge
44/// (u, v_C) with an appropriately chosen cost, and the selected node v is
45/// marked in the output map actualTarget[u]. All edges (u, v) leaving the
46/// cycle, u \in C, are replaced with a single edge (v_C, v), and the selected
47/// node u is marked in the ouptut map actualSource[v].
48static void contract(RootOrderingGraph &graph, ArrayRef<Value> cycle,
49 const DenseMap<Value, unsigned> &parentDepths,
50 DenseMap<Value, Value> &actualSource,
51 DenseMap<Value, Value> &actualTarget) {
52 Value rep = cycle.front();
53 DenseSet<Value> cycleSet(cycle.begin(), cycle.end());
54
55 // Now, contract the cycle, marking the actual sources and targets.
56 DenseMap<Value, RootOrderingEntry> repEntries;
57 for (auto outer = graph.begin(), e = graph.end(); outer != e; ++outer) {
58 Value target = outer->first;
59 if (cycleSet.contains(V: target)) {
60 // Target in the cycle => edges incoming to the cycle or within the cycle.
61 unsigned parentDepth = parentDepths.lookup(Val: target);
62 for (const auto &inner : outer->second) {
63 Value source = inner.first;
64 // Ignore edges within the cycle.
65 if (cycleSet.contains(V: source))
66 continue;
67
68 // Edge incoming to the cycle.
69 std::pair<unsigned, unsigned> cost = inner.second.cost;
70 assert(parentDepth <= cost.first && "invalid parent depth");
71
72 // Subtract the cost of the parent within the cycle from the cost of
73 // the edge incoming to the cycle. This update ensures that the cost
74 // of the minimum-weight spanning arborescence of the entire graph is
75 // the cost of arborescence for the contracted graph plus the cost of
76 // the cycle, no matter which edge in the cycle we choose to drop.
77 cost.first -= parentDepth;
78 auto it = repEntries.find(Val: source);
79 if (it == repEntries.end() || it->second.cost > cost) {
80 actualTarget[source] = target;
81 // Do not bother populating the connector (the connector is only
82 // relevant for the final traversal, not for the optimal branching).
83 repEntries[source].cost = cost;
84 }
85 }
86 // Erase the node in the cycle.
87 graph.erase(I: outer);
88 } else {
89 // Target not in cycle => edges going away from or unrelated to the cycle.
90 DenseMap<Value, RootOrderingEntry> &entries = outer->second;
91 Value bestSource;
92 std::pair<unsigned, unsigned> bestCost;
93 auto inner = entries.begin(), innerE = entries.end();
94 while (inner != innerE) {
95 Value source = inner->first;
96 if (cycleSet.contains(V: source)) {
97 // Going-away edge => get its cost and erase it.
98 if (!bestSource || bestCost > inner->second.cost) {
99 bestSource = source;
100 bestCost = inner->second.cost;
101 }
102 entries.erase(I: inner++);
103 } else {
104 ++inner;
105 }
106 }
107
108 // There were going-away edges, contract them.
109 if (bestSource) {
110 entries[rep].cost = bestCost;
111 actualSource[target] = bestSource;
112 }
113 }
114 }
115
116 // Store the edges to the representative.
117 graph[rep] = std::move(repEntries);
118}
119
120OptimalBranching::OptimalBranching(RootOrderingGraph graph, Value root)
121 : graph(std::move(graph)), root(root) {}
122
123unsigned OptimalBranching::solve() {
124 // Initialize the parents and total cost.
125 parents.clear();
126 parents[root] = Value();
127 unsigned totalCost = 0;
128
129 // A map that stores the cost of the optimal local choice for each node
130 // in a directed cycle. This map is cleared every time we seed the search.
131 DenseMap<Value, unsigned> parentDepths;
132 parentDepths.reserve(NumEntries: graph.size());
133
134 // Determine if the optimal local choice results in an acyclic graph. This is
135 // done by computing the optimal local choice and traversing up the computed
136 // parents. On success, `parents` will contain the parent of each node.
137 for (const auto &outer : graph) {
138 Value node = outer.first;
139 if (parents.count(Val: node)) // already visited
140 continue;
141
142 // Follow the trail of best sources until we reach an already visited node.
143 // The code will assert if we cannot reach an already visited node, i.e.,
144 // the graph is not strongly connected.
145 parentDepths.clear();
146 do {
147 auto it = graph.find(Val: node);
148 assert(it != graph.end() && "the graph is not strongly connected");
149
150 // Find the best local parent, taking into account both the depth and the
151 // tie breaking rules.
152 Value &bestSource = parents[node];
153 std::pair<unsigned, unsigned> bestCost;
154 for (const auto &inner : it->second) {
155 const RootOrderingEntry &entry = inner.second;
156 if (!bestSource /* initial */ || bestCost > entry.cost) {
157 bestSource = inner.first;
158 bestCost = entry.cost;
159 }
160 }
161 assert(bestSource && "the graph is not strongly connected");
162 parentDepths[node] = bestCost.first;
163 node = bestSource;
164 totalCost += bestCost.first;
165 } while (!parents.count(Val: node));
166
167 // If we reached a non-root node, we have a cycle.
168 if (parentDepths.count(Val: node)) {
169 // Determine the cycle starting at the representative node.
170 SmallVector<Value> cycle = getCycle(parents, rep: node);
171
172 // The following maps disambiguate the source / target of the edges
173 // going out of / into the cycle.
174 DenseMap<Value, Value> actualSource, actualTarget;
175
176 // Contract the cycle and recurse.
177 contract(graph, cycle, parentDepths, actualSource, actualTarget);
178 totalCost = solve();
179
180 // Redirect the going-away edges.
181 for (auto &p : parents)
182 if (p.second == node)
183 // The parent is the node representating the cycle; replace it
184 // with the actual (best) source in the cycle.
185 p.second = actualSource.lookup(Val: p.first);
186
187 // Redirect the unique incoming edge and copy the cycle.
188 Value parent = parents.lookup(Val: node);
189 Value entry = actualTarget.lookup(Val: parent);
190 cycle.push_back(Elt: node); // complete the cycle
191 for (size_t i = 0, e = cycle.size() - 1; i < e; ++i) {
192 totalCost += parentDepths.lookup(Val: cycle[i]);
193 if (cycle[i] == entry)
194 parents[cycle[i]] = parent; // break the cycle
195 else
196 parents[cycle[i]] = cycle[i + 1];
197 }
198
199 // `parents` has a complete solution.
200 break;
201 }
202 }
203
204 return totalCost;
205}
206
207OptimalBranching::EdgeList
208OptimalBranching::preOrderTraversal(ArrayRef<Value> nodes) const {
209 // Invert the parent mapping.
210 DenseMap<Value, std::vector<Value>> children;
211 for (Value node : nodes) {
212 if (node != root) {
213 Value parent = parents.lookup(Val: node);
214 assert(parent && "invalid parent");
215 children[parent].push_back(x: node);
216 }
217 }
218
219 // The result which simultaneously acts as a queue.
220 EdgeList result;
221 result.reserve(n: nodes.size());
222 result.emplace_back(args: root, args: Value());
223
224 // Perform a BFS, pushing into the queue.
225 for (size_t i = 0; i < result.size(); ++i) {
226 Value node = result[i].first;
227 for (Value child : children[node])
228 result.emplace_back(args&: child, args&: node);
229 }
230
231 return result;
232}
233

source code of mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp