| 1 | use crate::*; | 
| 2 | use std::collections::{HashMap, HashSet, VecDeque}; | 
|---|
| 3 |  | 
|---|
| 4 | /// A wrapper around [Graph](struct.Graph.html) to assist diffing. | 
|---|
| 5 | pub struct DiffGraph<'a> { | 
|---|
| 6 | pub(crate) graph: &'a Graph, | 
|---|
| 7 | pub(crate) dist_start: HashMap<&'a str, usize>, | 
|---|
| 8 | pub(crate) dist_end: HashMap<&'a str, usize>, | 
|---|
| 9 | } | 
|---|
| 10 |  | 
|---|
| 11 | impl<'a> DiffGraph<'a> { | 
|---|
| 12 | pub fn new(graph: &'a Graph) -> Self { | 
|---|
| 13 | let adj_list = graph.adj_list(); | 
|---|
| 14 | let rev_adj_list = graph.rev_adj_list(); | 
|---|
| 15 | let start_nodes = Self::get_source_labels(&adj_list); | 
|---|
| 16 | let end_nodes = Self::get_source_labels(&rev_adj_list); | 
|---|
| 17 | DiffGraph { | 
|---|
| 18 | graph, | 
|---|
| 19 | dist_start: Self::bfs_shortest_dist(rev_adj_list, start_nodes), | 
|---|
| 20 | dist_end: Self::bfs_shortest_dist(adj_list, end_nodes), | 
|---|
| 21 | } | 
|---|
| 22 | } | 
|---|
| 23 |  | 
|---|
| 24 | /// Calculate the shortest distance to the end from the given sources nodes using bfs. | 
|---|
| 25 | fn bfs_shortest_dist(adj_list: AdjList<'a>, source: Vec<&'a str>) -> HashMap<&'a str, usize> { | 
|---|
| 26 | let mut dist = HashMap::new(); | 
|---|
| 27 | for k in source.iter() { | 
|---|
| 28 | dist.insert(*k, 0); | 
|---|
| 29 | } | 
|---|
| 30 | let mut visited = HashSet::new(); | 
|---|
| 31 | let mut queue: VecDeque<&str> = source.into(); | 
|---|
| 32 | while let Some(node) = queue.pop_front() { | 
|---|
| 33 | let neighbours = adj_list.get(node).unwrap(); | 
|---|
| 34 | let curr_dist = *dist.get(&node).unwrap(); | 
|---|
| 35 |  | 
|---|
| 36 | for neighbour in neighbours { | 
|---|
| 37 | if !visited.contains(neighbour) { | 
|---|
| 38 | dist.insert(neighbour, curr_dist + 1); | 
|---|
| 39 | queue.push_back(neighbour); | 
|---|
| 40 | visited.insert(neighbour); | 
|---|
| 41 | } | 
|---|
| 42 | } | 
|---|
| 43 | } | 
|---|
| 44 |  | 
|---|
| 45 | dist | 
|---|
| 46 | } | 
|---|
| 47 |  | 
|---|
| 48 | /// Get the source labels for a given adjacency list. The source labels will the | 
|---|
| 49 | // TODO: This is sink labels, not source labels | 
|---|
| 50 | fn get_source_labels(adj_list: &AdjList<'a>) -> Vec<&'a str> { | 
|---|
| 51 | adj_list | 
|---|
| 52 | .iter() | 
|---|
| 53 | .filter(|(_, v)| v.is_empty()) | 
|---|
| 54 | .map(|(k, _)| *k) | 
|---|
| 55 | .collect() | 
|---|
| 56 | } | 
|---|
| 57 | } | 
|---|
| 58 |  | 
|---|