| 1 | use crate::*; |
| 2 | use std::collections::{HashMap, HashSet, VecDeque}; |
| 3 | |
| 4 | /// A wrapper around [Graph](struct.Graph.html) to assist diffing. |
| 5 | pub struct DiffGraph<'a> { |
| 6 | pub(crate) graph: &'a Graph, |
| 7 | pub(crate) dist_start: HashMap<&'a str, usize>, |
| 8 | pub(crate) dist_end: HashMap<&'a str, usize>, |
| 9 | } |
| 10 | |
| 11 | impl<'a> DiffGraph<'a> { |
| 12 | pub fn new(graph: &'a Graph) -> Self { |
| 13 | let adj_list = graph.adj_list(); |
| 14 | let rev_adj_list = graph.rev_adj_list(); |
| 15 | let start_nodes = Self::get_source_labels(&adj_list); |
| 16 | let end_nodes = Self::get_source_labels(&rev_adj_list); |
| 17 | DiffGraph { |
| 18 | graph, |
| 19 | dist_start: Self::bfs_shortest_dist(rev_adj_list, start_nodes), |
| 20 | dist_end: Self::bfs_shortest_dist(adj_list, end_nodes), |
| 21 | } |
| 22 | } |
| 23 | |
| 24 | /// Calculate the shortest distance to the end from the given sources nodes using bfs. |
| 25 | fn bfs_shortest_dist(adj_list: AdjList<'a>, source: Vec<&'a str>) -> HashMap<&'a str, usize> { |
| 26 | let mut dist = HashMap::new(); |
| 27 | for k in source.iter() { |
| 28 | dist.insert(*k, 0); |
| 29 | } |
| 30 | let mut visited = HashSet::new(); |
| 31 | let mut queue: VecDeque<&str> = source.into(); |
| 32 | while let Some(node) = queue.pop_front() { |
| 33 | let neighbours = adj_list.get(node).unwrap(); |
| 34 | let curr_dist = *dist.get(&node).unwrap(); |
| 35 | |
| 36 | for neighbour in neighbours { |
| 37 | if !visited.contains(neighbour) { |
| 38 | dist.insert(neighbour, curr_dist + 1); |
| 39 | queue.push_back(neighbour); |
| 40 | visited.insert(neighbour); |
| 41 | } |
| 42 | } |
| 43 | } |
| 44 | |
| 45 | dist |
| 46 | } |
| 47 | |
| 48 | /// Get the source labels for a given adjacency list. The source labels will the |
| 49 | // TODO: This is sink labels, not source labels |
| 50 | fn get_source_labels(adj_list: &AdjList<'a>) -> Vec<&'a str> { |
| 51 | adj_list |
| 52 | .iter() |
| 53 | .filter(|(_, v)| v.is_empty()) |
| 54 | .map(|(k, _)| *k) |
| 55 | .collect() |
| 56 | } |
| 57 | } |
| 58 | |