| 1 | //===- RootOrdering.h - Optimal root ordering ------------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
// This file contains the definition of a cost graph over candidate roots and
// an implementation of an algorithm to determine the optimal ordering over
// these roots. Each edge in this graph indicates that the target root can be
// connected (via a chain of positions) to the source root, and its cost
// indicates the estimated cost of such traversal. The optimal root ordering
// is then formulated as that of finding a spanning arborescence (i.e., a
// directed spanning tree) of minimal weight.
| 16 | // |
| 17 | //===----------------------------------------------------------------------===// |
| 18 | |
| 19 | #ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_ROOTORDERING_H_ |
| 20 | #define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_ROOTORDERING_H_ |
| 21 | |
| 22 | #include "mlir/IR/Value.h" |
| 23 | #include "llvm/ADT/DenseMap.h" |
| 24 | #include "llvm/ADT/SmallVector.h" |
| 25 | #include <functional> |
| 26 | #include <vector> |
| 27 | |
| 28 | namespace mlir { |
| 29 | namespace pdl_to_pdl_interp { |
| 30 | |
| 31 | /// The information associated with an edge in the cost graph. Each node in |
| 32 | /// the cost graph corresponds to a candidate root detected in the pdl.pattern, |
| 33 | /// and each edge in the cost graph corresponds to connecting the two candidate |
| 34 | /// roots via a chain of operations. The cost of an edge is the smallest number |
| 35 | /// of upward traversals required to go from the source to the target root, and |
| 36 | /// the connector is a `Value` in the intersection of the two subtrees rooted at |
| 37 | /// the source and target root that results in that smallest number of upward |
| 38 | /// traversals. Consider the following pattern with 3 roots op3, op4, and op5: |
| 39 | /// |
| 40 | /// argA ---> op1 ---> op2 ---> op3 ---> res3 |
| 41 | /// ^ ^ |
| 42 | /// | | |
| 43 | /// argB argC |
| 44 | /// | | |
| 45 | /// v v |
| 46 | /// res4 <--- op4 op5 ---> res5 |
| 47 | /// ^ ^ |
| 48 | /// | | |
| 49 | /// op6 op7 |
| 50 | /// |
| 51 | /// The cost of the edge op3 -> op4 is 1 (the upward traversal argB -> op4), |
| 52 | /// with argB being the connector `Value` and similarly for op3 -> op5 (cost 1, |
| 53 | /// connector argC). The cost of the edge op4 -> op3 is 3 (upward traversals |
| 54 | /// argB -> op1 -> op2 -> op3, connector argB), while the cost of edge op5 -> |
| 55 | /// op3 is 2 (uwpard traversals argC -> op2 -> op3). There are no edges between |
| 56 | /// op4 and op5 in the cost graph, because the subtrees rooted at these two |
| 57 | /// roots do not intersect. It is easy to see that the optimal root for this |
| 58 | /// pattern is op3, resulting in the spanning arborescence op3 -> {op4, op5}. |
| 59 | struct RootOrderingEntry { |
| 60 | /// The depth of the connector `Value` w.r.t. the target root. |
| 61 | /// |
| 62 | /// This is a pair where the first value is the additive cost (the depth of |
| 63 | /// the connector), and the second value is a priority for breaking ties |
| 64 | /// (with 0 being the highest). Typically, the priority is a unique edge ID. |
| 65 | std::pair<unsigned, unsigned> cost; |
| 66 | |
| 67 | /// The connector value in the intersection of the two subtrees rooted at |
| 68 | /// the source and target root that results in that smallest depth w.r.t. |
| 69 | /// the target root. |
| 70 | Value connector; |
| 71 | }; |
| 72 | |
| 73 | /// A directed graph representing the cost of ordering the roots in the |
| 74 | /// predicate tree. It is represented as an adjacency map, where the outer map |
| 75 | /// is indexed by the target node, and the inner map is indexed by the source |
| 76 | /// node. Each edge is associated with a cost and the underlying connector |
| 77 | /// value. |
| 78 | using RootOrderingGraph = DenseMap<Value, DenseMap<Value, RootOrderingEntry>>; |
| 79 | |
| 80 | /// The optimal branching algorithm solver. This solver accepts a graph and the |
| 81 | /// root in its constructor, and is invoked via the solve() member function. |
| 82 | /// This is a direct implementation of the Edmonds' algorithm, see |
| 83 | /// https://en.wikipedia.org/wiki/Edmonds%27_algorithm. The worst-case |
| 84 | /// computational complexity of this algorithm is O(N^3), for a single root. |
| 85 | /// The PDL-to-PDLInterp lowering calls this N times (once for each candidate |
| 86 | /// root), so the overall complexity root ordering is O(N^4). If needed, this |
| 87 | /// could be reduced to O(N^3) with a more efficient algorithm. However, note |
| 88 | /// that the underlying implementation is very efficient, and N in our |
| 89 | /// instances tends to be very small (<10). |
| 90 | class OptimalBranching { |
| 91 | public: |
| 92 | /// A list of edges (child, parent). |
| 93 | using EdgeList = std::vector<std::pair<Value, Value>>; |
| 94 | |
| 95 | /// Constructs the solver for the given graph and root value. |
| 96 | OptimalBranching(RootOrderingGraph graph, Value root); |
| 97 | |
| 98 | /// Runs the Edmonds' algorithm for the current `graph`, returning the total |
| 99 | /// cost of the minimum-weight spanning arborescence (sum of the edge costs). |
| 100 | /// This function first determines the optimal local choice of the parents |
| 101 | /// and stores this choice in the `parents` mapping. If this choice results |
| 102 | /// in an acyclic graph, the function returns immediately. Otherwise, it |
| 103 | /// takes an arbitrary cycle, contracts it, and recurses on the new graph |
| 104 | /// (which is guaranteed to have fewer nodes than we began with). After we |
| 105 | /// return from recursion, we redirect the edges to/from the contracted node, |
| 106 | /// so the `parents` map contains a valid solution for the current graph. |
| 107 | unsigned solve(); |
| 108 | |
| 109 | /// Returns the computed parent map. This is the unique predecessor for each |
| 110 | /// node (root) in the optimal branching. |
| 111 | const DenseMap<Value, Value> &getRootOrderingParents() const { |
| 112 | return parents; |
| 113 | } |
| 114 | |
| 115 | /// Returns the computed edges as visited in the preorder traversal. |
| 116 | /// The specified array determines the order for breaking any ties. |
| 117 | EdgeList preOrderTraversal(ArrayRef<Value> nodes) const; |
| 118 | |
| 119 | private: |
| 120 | /// The graph whose optimal branching we wish to determine. |
| 121 | RootOrderingGraph graph; |
| 122 | |
| 123 | /// The root of the optimal branching. |
| 124 | Value root; |
| 125 | |
| 126 | /// The computed parent mapping. This is the unique predecessor for each node |
| 127 | /// in the optimal branching. The keys of this map correspond to the keys of |
| 128 | /// the outer map of the input graph, and each value is one of the keys of |
| 129 | /// the inner map for this node. Also used as an intermediate (possibly |
| 130 | /// cyclical) result in the optimal branching algorithm. |
| 131 | DenseMap<Value, Value> parents; |
| 132 | }; |
| 133 | |
| 134 | } // namespace pdl_to_pdl_interp |
| 135 | } // namespace mlir |
| 136 | |
#endif // MLIR_LIB_CONVERSION_PDLTOPDLINTERP_ROOTORDERING_H_
| 138 | |