1use crate::*;
2use std::collections::{HashMap, HashSet, VecDeque};
3
4/// A wrapper around [Graph](struct.Graph.html) to assist diffing.
5pub struct DiffGraph<'a> {
6 pub(crate) graph: &'a Graph,
7 pub(crate) dist_start: HashMap<&'a str, usize>,
8 pub(crate) dist_end: HashMap<&'a str, usize>,
9}
10
11impl<'a> DiffGraph<'a> {
12 pub fn new(graph: &'a Graph) -> Self {
13 let adj_list = graph.adj_list();
14 let rev_adj_list = graph.rev_adj_list();
15 let start_nodes = Self::get_source_labels(&adj_list);
16 let end_nodes = Self::get_source_labels(&rev_adj_list);
17 DiffGraph {
18 graph,
19 dist_start: Self::bfs_shortest_dist(rev_adj_list, start_nodes),
20 dist_end: Self::bfs_shortest_dist(adj_list, end_nodes),
21 }
22 }
23
24 /// Calculate the shortest distance to the end from the given sources nodes using bfs.
25 fn bfs_shortest_dist(adj_list: AdjList<'a>, source: Vec<&'a str>) -> HashMap<&'a str, usize> {
26 let mut dist = HashMap::new();
27 for k in source.iter() {
28 dist.insert(*k, 0);
29 }
30 let mut visited = HashSet::new();
31 let mut queue: VecDeque<&str> = source.into();
32 while let Some(node) = queue.pop_front() {
33 let neighbours = adj_list.get(node).unwrap();
34 let curr_dist = *dist.get(&node).unwrap();
35
36 for neighbour in neighbours {
37 if !visited.contains(neighbour) {
38 dist.insert(neighbour, curr_dist + 1);
39 queue.push_back(neighbour);
40 visited.insert(neighbour);
41 }
42 }
43 }
44
45 dist
46 }
47
48 /// Get the source labels for a given adjacency list. The source labels will the
49 // TODO: This is sink labels, not source labels
50 fn get_source_labels(adj_list: &AdjList<'a>) -> Vec<&'a str> {
51 adj_list
52 .iter()
53 .filter(|(_, v)| v.is_empty())
54 .map(|(k, _)| *k)
55 .collect()
56 }
57}
58