1 | use crate::*; |
2 | use std::collections::{HashMap, HashSet, VecDeque}; |
3 | |
4 | /// A wrapper around [Graph](struct.Graph.html) to assist diffing. |
5 | pub struct DiffGraph<'a> { |
6 | pub(crate) graph: &'a Graph, |
7 | pub(crate) dist_start: HashMap<&'a str, usize>, |
8 | pub(crate) dist_end: HashMap<&'a str, usize>, |
9 | } |
10 | |
11 | impl<'a> DiffGraph<'a> { |
12 | pub fn new(graph: &'a Graph) -> Self { |
13 | let adj_list = graph.adj_list(); |
14 | let rev_adj_list = graph.rev_adj_list(); |
15 | let start_nodes = Self::get_source_labels(&adj_list); |
16 | let end_nodes = Self::get_source_labels(&rev_adj_list); |
17 | DiffGraph { |
18 | graph, |
19 | dist_start: Self::bfs_shortest_dist(rev_adj_list, start_nodes), |
20 | dist_end: Self::bfs_shortest_dist(adj_list, end_nodes), |
21 | } |
22 | } |
23 | |
24 | /// Calculate the shortest distance to the end from the given sources nodes using bfs. |
25 | fn bfs_shortest_dist(adj_list: AdjList<'a>, source: Vec<&'a str>) -> HashMap<&'a str, usize> { |
26 | let mut dist = HashMap::new(); |
27 | for k in source.iter() { |
28 | dist.insert(*k, 0); |
29 | } |
30 | let mut visited = HashSet::new(); |
31 | let mut queue: VecDeque<&str> = source.into(); |
32 | while let Some(node) = queue.pop_front() { |
33 | let neighbours = adj_list.get(node).unwrap(); |
34 | let curr_dist = *dist.get(&node).unwrap(); |
35 | |
36 | for neighbour in neighbours { |
37 | if !visited.contains(neighbour) { |
38 | dist.insert(neighbour, curr_dist + 1); |
39 | queue.push_back(neighbour); |
40 | visited.insert(neighbour); |
41 | } |
42 | } |
43 | } |
44 | |
45 | dist |
46 | } |
47 | |
48 | /// Get the source labels for a given adjacency list. The source labels will the |
49 | // TODO: This is sink labels, not source labels |
50 | fn get_source_labels(adj_list: &AdjList<'a>) -> Vec<&'a str> { |
51 | adj_list |
52 | .iter() |
53 | .filter(|(_, v)| v.is_empty()) |
54 | .map(|(k, _)| *k) |
55 | .collect() |
56 | } |
57 | } |
58 | |