yagraphc/graph/
traits.rs

1//! Traits that correspond to the main functionalities of the crate.
2//!
//! `Traversable` is the main trait for working with general graph traversal, such as
//! BFS and DFS.
5//!
6//! `ArithmeticallyWeightedGraph` is the main trait for working with path finding,
7//! such as Dijkstra's algorithm or A*. It is also intended to handle more general
8//! algorithms that rely on arithmetical weights in the future.
9use std::collections::BinaryHeap;
10use std::collections::HashMap;
11use std::collections::HashSet;
12use std::collections::VecDeque;
13use std::hash::Hash;
14
15use thiserror::Error;
16
/// Error type returned by the fallible graph operations declared in this module
/// (`remove_edge`, `remove_node` and `edge_weight`).
///
/// Its `Display` output is the fixed message `node not found`.
#[derive(Error, Debug)]
#[error("node not found")]
pub struct NodeNotFound;
20
21pub struct BfsIter<'a, T, W> {
22    pub(crate) queue: VecDeque<(T, usize)>,
23    pub(crate) visited: HashSet<T>,
24    pub(crate) graph: &'a dyn Traversable<T, W>,
25}
26
27impl<'a, T, W> Iterator for BfsIter<'a, T, W>
28where
29    T: Clone + Copy + Hash + PartialEq + Eq,
30    W: Clone + Copy,
31{
32    type Item = (T, usize);
33
34    fn next(&mut self) -> Option<Self::Item> {
35        while let Some((node, depth)) = self.queue.pop_front() {
36            if self.visited.contains(&node) {
37                continue;
38            }
39
40            self.visited.insert(node);
41
42            for (next, _) in self.graph.edges(&node) {
43                if self.visited.contains(&next) {
44                    continue;
45                } else {
46                    self.queue.push_back((next, depth + 1))
47                }
48            }
49
50            return Some((node, depth));
51        }
52        None
53    }
54}
55
56pub struct DfsIter<'a, T, W> {
57    pub(crate) queue: VecDeque<(T, usize)>,
58    pub(crate) visited: HashSet<T>,
59    pub(crate) graph: &'a dyn Traversable<T, W>,
60}
61
62impl<'a, T, W> Iterator for DfsIter<'a, T, W>
63where
64    T: Clone + Copy + Hash + PartialEq + Eq,
65    W: Clone + Copy,
66{
67    type Item = (T, usize);
68
69    fn next(&mut self) -> Option<Self::Item> {
70        while let Some((node, depth)) = self.queue.pop_front() {
71            if self.visited.contains(&node) {
72                continue;
73            }
74
75            self.visited.insert(node);
76
77            for (next, _) in self.graph.edges(&node) {
78                if self.visited.contains(&next) {
79                    continue;
80                } else {
81                    self.queue.push_front((next, depth + 1))
82                }
83            }
84
85            return Some((node, depth));
86        }
87
88        None
89    }
90}
91
92pub struct PostOrderDfsIter<T> {
93    pub(crate) queue: VecDeque<(T, usize)>,
94}
95
96impl<'a, T> PostOrderDfsIter<T>
97where
98    T: Clone + Copy + Hash + PartialEq + Eq,
99{
100    pub fn new<W>(graph: &'a dyn Traversable<T, W>, from: T) -> Self
101    where
102        W: Clone + Copy,
103    {
104        let mut queue = VecDeque::new();
105        let mut visited = HashSet::new();
106
107        dfs_post_order(graph, from, 0, &mut queue, &mut visited);
108
109        Self { queue }
110    }
111}
112
113fn dfs_post_order<T, W>(
114    graph: &dyn Traversable<T, W>,
115    node: T,
116    depth: usize,
117    queue: &mut VecDeque<(T, usize)>,
118    visited: &mut HashSet<T>,
119) where
120    T: Clone + Copy + Hash + PartialEq + Eq,
121    W: Clone + Copy,
122{
123    visited.insert(node);
124    for (next, _) in graph.edges(&node) {
125        if visited.contains(&next) {
126            continue;
127        }
128
129        dfs_post_order(graph, next, depth + 1, queue, visited);
130    }
131
132    queue.push_back((node, depth));
133}
134
135impl<T> Iterator for PostOrderDfsIter<T>
136where
137    T: Clone + Copy + Hash + PartialEq + Eq,
138{
139    type Item = (T, usize);
140
141    fn next(&mut self) -> Option<Self::Item> {
142        self.queue.pop_front()
143    }
144}
145
/// Iterator over the nodes of a graph, yielding each node by value.
pub struct NodeIter<'a, T> {
    /// Borrowing iterator over the underlying node set.
    pub(crate) nodes_iter: std::collections::hash_set::Iter<'a, T>,
}

impl<T> Iterator for NodeIter<'_, T>
where
    T: Clone + Copy,
{
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        self.nodes_iter.next().copied()
    }

    /// Forwards the wrapped iterator's hint so that consumers such as
    /// `collect` can pre-allocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.nodes_iter.size_hint()
    }
}
159
160pub enum EdgeIterType<'a, T, W> {
161    EdgeIter(EdgeIter<'a, T, W>),
162    EdgeIterVec(EdgeIterVec<'a, T, W>),
163}
164
165impl<'a, T, W> Iterator for EdgeIterType<'a, T, W>
166where
167    T: Clone + Copy,
168    W: Clone + Copy,
169{
170    type Item = (T, W);
171    fn next(&mut self) -> Option<Self::Item> {
172        match self {
173            EdgeIterType::EdgeIter(iter) => iter.next(),
174            EdgeIterType::EdgeIterVec(iter) => iter.next(),
175        }
176    }
177}
178
179pub struct EdgeIter<'a, T, W> {
180    pub(crate) edge_iter: std::collections::hash_map::Iter<'a, T, W>,
181}
182
183impl<T, W> Iterator for EdgeIter<'_, T, W>
184where
185    T: Clone + Copy,
186    W: Clone + Copy,
187{
188    type Item = (T, W);
189    fn next(&mut self) -> Option<Self::Item> {
190        self.edge_iter.next().map(copy_tuple)
191    }
192}
193
/// Iterator over edges stored in a slice of `(node, weight)` pairs, yielding
/// each pair by value.
pub struct EdgeIterVec<'a, T, W> {
    /// Borrowing iterator over the underlying slice.
    pub(crate) edge_iter: std::slice::Iter<'a, (T, W)>,
}

impl<T, W> Iterator for EdgeIterVec<'_, T, W>
where
    T: Clone + Copy,
    W: Clone + Copy,
{
    type Item = (T, W);

    fn next(&mut self) -> Option<Self::Item> {
        self.edge_iter.next().copied()
    }

    /// Forwards the wrapped iterator's hint so that consumers such as
    /// `collect` can pre-allocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.edge_iter.size_hint()
    }
}
208
/// Operations for building and editing a graph.
///
/// `T` is the node type and `W` the edge-weight type.
pub trait GraphBuilding<T, W>
where
    T: Clone + Copy + Eq + Hash + PartialEq,
    W: Clone + Copy,
{
    /// Adds an edge between `from` and `to` carrying `weight`.
    ///
    /// NOTE(review): judging by the doc examples in this module, endpoints
    /// that are not yet part of the graph are added implicitly; confirm
    /// against the implementors.
    fn add_edge(&mut self, from: T, to: T, weight: W);

    /// Adds `node` to the graph.
    ///
    /// NOTE(review): the returned `bool` presumably reports whether the node
    /// was newly inserted; confirm against the implementors.
    fn add_node(&mut self, node: T) -> bool;

    /// Removes the edge between `from` and `to`.
    ///
    /// # Errors
    ///
    /// Returns [`NodeNotFound`] on failure. NOTE(review): the exact failure
    /// conditions (missing endpoint vs. missing edge) are defined by the
    /// implementors.
    fn remove_edge(&mut self, from: T, to: T) -> Result<(), NodeNotFound>;

    /// Removes `node` from the graph.
    ///
    /// # Errors
    ///
    /// Returns [`NodeNotFound`] on failure, presumably when `node` is not part
    /// of the graph; confirm against the implementors.
    fn remove_node(&mut self, node: T) -> Result<(), NodeNotFound>;

    /// Returns whether the graph has an edge between `from` and `to`.
    fn has_edge(&self, from: T, to: T) -> bool;
}
224
225pub trait Traversable<T, W>
226where
227    T: Clone + Copy + Eq + Hash + PartialEq,
228    W: Clone + Copy,
229{
    /// Iterates over the edges of node `n`, yielding each neighbouring node
    /// together with the weight of the connecting edge.
    ///
    /// If the graph is undirected, should return the nodes that are connected to it by
    /// an edge.
    ///
    /// If the graph is directed, should return the outbound nodes.
    ///
    /// # Examples
    /// ```rust
    /// use yagraphc::graph::{UnGraph, DiGraph};
    /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
    ///
    /// let mut graph = UnGraph::default();
    ///
    /// graph.add_edge(1, 2, ());
    /// graph.add_edge(2, 3, ());
    ///
    /// let edges = graph.edges(&2);
    ///
    /// assert_eq!(edges.count(), 2);
    ///
    /// let mut graph = DiGraph::default();
    ///
    /// graph.add_edge(1, 2, ());
    /// graph.add_edge(2, 3, ());
    ///
    /// let edges = graph.edges(&2);
    ///
    /// assert_eq!(edges.count(), 1);
    /// ```
    fn edges(&self, n: &T) -> EdgeIterType<T, W>;
260
    /// Returns the weight of the edge from `from` to `to`.
    ///
    /// # Errors
    ///
    /// Returns [`NodeNotFound`] on failure. NOTE(review): presumably when an
    /// endpoint or the edge itself does not exist; confirm against the
    /// implementors.
    fn edge_weight(&self, from: T, to: T) -> Result<W, NodeNotFound>;
262
    /// Iterates over the inbound edges of node `n`, yielding the node at the
    /// other end of each edge together with the edge weight.
    ///
    /// If the graph is undirected, should return the nodes that are connected to it by
    /// an edge. Thus, it is equivalent to `edges` in that case.
    ///
    /// If the graph is directed, should return the inbound nodes.
    ///
    /// # Examples
    /// ```rust
    /// use yagraphc::graph::{UnGraph, DiGraph};
    /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
    ///
    /// let mut graph = UnGraph::default();
    ///
    /// graph.add_edge(1, 2, ());
    /// graph.add_edge(2, 3, ());
    /// graph.add_edge(4, 2, ());
    ///
    /// let edges = graph.in_edges(&2);
    ///
    /// assert_eq!(edges.count(), 3);
    ///
    /// let mut graph = DiGraph::default();
    ///
    /// graph.add_edge(1, 2, ());
    /// graph.add_edge(2, 3, ());
    /// graph.add_edge(4, 2, ());
    ///
    /// let edges = graph.in_edges(&2);
    ///
    /// assert_eq!(edges.count(), 2);
    /// ```
    fn in_edges(&self, n: &T) -> EdgeIterType<T, W>;
295
    /// Returns an iterator over all nodes of the graph, each yielded by value.
    ///
    /// As the example shows, adding a node that is already present does not
    /// create a duplicate.
    ///
    /// # Examples
    /// ```rust
    /// use yagraphc::graph::UnGraph;
    /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
    ///
    /// let mut graph = UnGraph::default();
    ///
    /// graph.add_edge(1, 2, ());
    /// graph.add_edge(2, 3, ());
    ///
    /// graph.add_node(4);
    ///
    /// assert_eq!(graph.nodes().count(), 4);
    ///
    /// graph.add_node(2);
    /// assert_eq!(graph.nodes().count(), 4);
    /// ```
    fn nodes(&self) -> NodeIter<T>;
315
316    /// Returns an iterator of nodes in breadth-first order.
317    ///
318    /// Iterator includes the depth at which the nodes were found. Nodes at the
319    /// same depth might be randomly shuffled.
320    ///
321    /// # Examples
322    /// ```
323    /// use yagraphc::graph::UnGraph;
324    /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
325    ///
326    /// let mut graph = UnGraph::new();
327    ///
328    /// graph.add_edge(1, 2, ());
329    /// graph.add_edge(1, 3, ());
330    /// graph.add_edge(2, 4, ());
331    /// graph.add_edge(2, 5, ());
332    ///
333    /// let bfs = graph.bfs(1);
334    ///
335    /// let depths = bfs.map(|(_, depth)| depth).collect::<Vec<_>>();
336    ///
337    /// assert_eq!(depths, vec![0, 1, 1, 2, 2]);
338    fn bfs(&self, from: T) -> BfsIter<T, W>
339    where
340        Self: Sized,
341    {
342        let visited = HashSet::new();
343
344        let mut queue = VecDeque::new();
345        queue.push_front((from, 0));
346
347        BfsIter {
348            queue,
349            visited,
350            graph: self,
351        }
352    }
353
354    /// Returns an iterator of nodes in depth-first order, in pre-order.
355    ///
356    /// Iterator includes the depth at which the nodes were found. Order is not
357    /// deterministic.
358    ///
359    /// # Examples
360    /// ```
361    /// use yagraphc::graph::UnGraph;
362    /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
363    ///
364    /// let mut graph = UnGraph::new();
365    ///
366    /// graph.add_edge(1, 2, ());
367    /// graph.add_edge(2, 3, ());
368    /// graph.add_edge(3, 4, ());
369    /// graph.add_edge(1, 5, ());
370    /// graph.add_edge(5, 6, ());
371    ///
372    /// let dfs = graph.dfs(1);
373    ///
374    /// let depths = dfs.map(|(node, _)| node).collect::<Vec<_>>();
375    ///
376    /// assert!(matches!(depths[..], [1, 2, 3, 4, 5, 6] | [1, 5, 6, 2, 3, 4]));
377    fn dfs(&self, from: T) -> DfsIter<T, W>
378    where
379        Self: Sized,
380    {
381        let visited = HashSet::new();
382
383        let mut queue = VecDeque::new();
384        queue.push_front((from, 0));
385
386        DfsIter {
387            queue,
388            visited,
389            graph: self,
390        }
391    }
392
    /// Returns an iterator of nodes in depth-first order, in post-order.
    ///
    /// Iterator includes the depth at which the nodes were found. Order is not
    /// deterministic.
    ///
    /// The whole traversal is computed eagerly when this method is called;
    /// the returned iterator only replays the recorded order.
    ///
    /// # Examples
    /// ```
    /// use yagraphc::graph::UnGraph;
    /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
    ///
    /// let mut graph = UnGraph::new();
    ///
    /// graph.add_edge(1, 2, ());
    /// graph.add_edge(2, 3, ());
    /// graph.add_edge(3, 4, ());
    /// graph.add_edge(1, 5, ());
    /// graph.add_edge(5, 6, ());
    ///
    /// let dfs = graph.dfs_post_order(1);
    ///
    /// let depths = dfs.map(|(node, _)| node).collect::<Vec<_>>();
    ///
    /// assert!(matches!(depths[..], [6, 5, 4, 3, 2, 1] | [4, 3, 2, 6, 5, 1]));
    /// ```
    // NOTE: the traversal itself is performed by the free function
    // `dfs_post_order`, called from `PostOrderDfsIter::new`.
    fn dfs_post_order(&self, from: T) -> PostOrderDfsIter<T>
    where
        Self: Sized,
    {
        PostOrderDfsIter::new(self, from)
    }
426
427    /// Finds path from `from` to `to` using BFS.
428    ///
429    /// Returns `None` if there is no path.
430    ///
431    /// # Examples
432    ///
433    /// ```rust
434    /// use yagraphc::graph::UnGraph;
435    /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
436    ///
437    /// let mut graph = UnGraph::new();
438    ///
439    /// graph.add_edge(1, 2, ());
440    /// graph.add_edge(2, 3, ());
441    /// graph.add_edge(3, 4, ());
442    /// graph.add_edge(1, 5, ());
443    /// graph.add_edge(5, 6, ());
444    ///
445    /// let path = graph.find_path(1, 4);
446    ///
447    /// assert_eq!(path, Some(vec![1, 2, 3, 4]));
448    fn find_path(&self, from: T, to: T) -> Option<Vec<T>>
449    where
450        Self: Sized,
451    {
452        let mut visited = HashSet::new();
453        let mut pairs = HashMap::new();
454        let mut queue = VecDeque::new();
455
456        queue.push_back((from, from));
457
458        while let Some((prev, current)) = queue.pop_front() {
459            if visited.contains(&current) {
460                continue;
461            }
462            visited.insert(current);
463            pairs.insert(current, prev);
464
465            if current == to {
466                let mut node = current;
467
468                let mut path = Vec::new();
469                while node != from {
470                    path.push(node);
471
472                    node = pairs[&node];
473                }
474
475                path.push(from);
476
477                path.reverse();
478
479                return Some(path);
480            }
481
482            for (target, _) in self.edges(&current) {
483                if visited.contains(&target) {
484                    continue;
485                }
486
487                queue.push_back((current, target));
488            }
489        }
490
491        None
492    }
493
494    /// Finds path from `from` to `to` using BFS while filtering edges.
495    ///
496    /// Returns `None` if there is no path.
497    ///
498    /// For an undirected graph, strive to make predicate(x,y) == predicate(y,x).
499    /// As of now, the order is related to the exploration direction of bfs,
500    /// but it is advisable not to rely on direction concepts on undirected graphs.
501    ///
502    /// # Examples
503    ///
504    /// ```rust
505    /// use yagraphc::graph::UnGraph;
506    /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
507    /// let mut graph = UnGraph::default();
508
509    /// graph.add_edge(1, 2, ());
510    /// graph.add_edge(2, 3, ());
511    /// graph.add_edge(3, 4, ());
512    /// graph.add_edge(4, 5, ());
513
514    /// graph.add_edge(1, 7, ());
515    /// graph.add_edge(7, 5, ());
516
517    /// let path = graph.find_path(1, 5).unwrap();
518
519    /// assert_eq!(path, vec![1, 7, 5]);
520
521    /// let path = graph
522    ///     .find_path_filter_edges(1, 5, |x, y| (x, y) != (1, 7))
523    ///     .unwrap();
524
525    /// assert_eq!(path, vec![1, 2, 3, 4, 5]);
526    fn find_path_filter_edges<G>(&self, from: T, to: T, predicate: G) -> Option<Vec<T>>
527    where
528        Self: Sized,
529        G: Fn(T, T) -> bool,
530    {
531        let mut visited = HashSet::new();
532        let mut pairs = HashMap::new();
533        let mut queue = VecDeque::new();
534
535        queue.push_back((from, from));
536
537        while let Some((prev, current)) = queue.pop_front() {
538            if visited.contains(&current) {
539                continue;
540            }
541            visited.insert(current);
542            pairs.insert(current, prev);
543
544            if current == to {
545                let mut node = current;
546
547                let mut path = Vec::new();
548                while node != from {
549                    path.push(node);
550
551                    node = pairs[&node];
552                }
553
554                path.push(from);
555
556                path.reverse();
557
558                return Some(path);
559            }
560
561            for (target, _) in self.edges(&current) {
562                if visited.contains(&target) || !predicate(current, target) {
563                    continue;
564                }
565
566                queue.push_back((current, target));
567            }
568        }
569
570        None
571    }
572
573    /// Returns a list of connected components of the graph.
574    ///
575    /// If being used in a directed graph, those are the strongly connected components,
576    /// computed using Kosaraju's algorithm.
577    ///
578    /// # Examples
579    ///
580    /// ```rust
581    /// use yagraphc::graph::{UnGraph, DiGraph};
582    /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
583    ///
584    /// let mut graph = UnGraph::new();
585    ///
586    /// graph.add_edge(1, 2, ());
587    /// graph.add_edge(2, 3, ());
588    ///
589    /// graph.add_edge(4, 5, ());
590    /// graph.add_edge(5, 6, ());
591    /// graph.add_edge(6, 4, ());
592    ///
593    /// let components = graph.connected_components();
594    ///
595    /// assert_eq!(components.len(), 2);
596    ///
597    /// let mut graph = DiGraph::new();
598    ///
599    /// graph.add_edge(1, 2, ());
600    /// graph.add_edge(2, 3, ());
601    ///
602    /// graph.add_edge(4, 5, ());
603    /// graph.add_edge(5, 6, ());
604    /// graph.add_edge(6, 4, ());
605    ///
606    /// let components = graph.connected_components();
607    ///
608    /// assert_eq!(components.len(), 4);
609    fn connected_components(&self) -> Vec<Vec<T>>
610    where
611        Self: Sized,
612    {
613        let mut visited = HashSet::new();
614        let mut stack = Vec::new();
615
616        for node in self.nodes() {
617            if visited.contains(&node) {
618                continue;
619            }
620
621            for (inner_node, _) in self.dfs_post_order(node) {
622                visited.insert(inner_node);
623                stack.push(inner_node);
624            }
625        }
626
627        stack.reverse();
628
629        let mut visited = HashSet::new();
630        let mut components = Vec::new();
631
632        for node in stack {
633            if visited.contains(&node) {
634                continue;
635            }
636
637            let mut component = Vec::new();
638
639            let mut stack = Vec::new();
640            stack.push(node);
641
642            while let Some(node) = stack.pop() {
643                if visited.contains(&node) {
644                    continue;
645                }
646
647                component.push(node);
648                visited.insert(node);
649
650                for (inner_node, _) in self.in_edges(&node) {
651                    stack.push(inner_node);
652                }
653            }
654
655            components.push(component);
656        }
657
658        components
659    }
660}
661
/// Turns a pair of references into a pair of values by copying both sides.
///
/// Used to adapt the `(&K, &V)` items of a hash-map iterator to `(K, V)`.
fn copy_tuple<T, W>(x: (&T, &W)) -> (T, W)
where
    T: Clone + Copy,
    W: Clone + Copy,
{
    let (node, weight) = x;
    (*node, *weight)
}
669
/// Priority-queue entry pairing a node with the cost used to rank it.
///
/// All comparison traits look at `cur_cost` only; `node` never takes part in
/// ordering or equality. `Ord` is reversed so that `BinaryHeap`, which is a
/// max-heap, pops the entry with the *lowest* cost first.
struct QueueEntry<T, W> {
    pub node: T,
    pub cur_cost: W,
}

impl<T, W: Ord> Ord for QueueEntry<T, W> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Operands are swapped on purpose: a cheaper entry compares as greater.
        other.cur_cost.cmp(&self.cur_cost)
    }
}

impl<T, W: Ord> PartialOrd for QueueEntry<T, W> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl<T, W: Ord> PartialEq for QueueEntry<T, W> {
    fn eq(&self, other: &Self) -> bool {
        self.cur_cost.eq(&other.cur_cost)
    }
}

impl<T, W: Ord> Eq for QueueEntry<T, W> {}
715
716pub trait ArithmeticallyWeightedGraph<T, W>
717where
718    T: Clone + Copy + Eq + Hash + PartialEq,
719    W: Clone
720        + Copy
721        + std::ops::Add<Output = W>
722        + std::ops::Sub<Output = W>
723        + PartialOrd
724        + Ord
725        + Default,
726{
727    /// Returns the shortest length among paths from `from` to `to`.
728    ///
729    /// # Examples
730    ///
731    /// ```rust
732    /// use yagraphc::graph::UnGraph;
733    /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
734    /// use yagraphc::graph::traits::ArithmeticallyWeightedGraph;
735    ///
736    /// let mut graph = UnGraph::new();
737    ///
738    /// graph.add_edge(1, 2, 1);
739    /// graph.add_edge(2, 3, 2);
740    /// graph.add_edge(3, 4, 3);
741    /// graph.add_edge(1, 4, 7);
742    ///
743    /// assert_eq!(graph.dijkstra(1, 4), Some(6));
744    /// assert_eq!(graph.dijkstra(1, 5), None);
745    fn dijkstra(&self, from: T, to: T) -> Option<W>
746    where
747        Self: Traversable<T, W>,
748    {
749        let mut visited = HashSet::new();
750        let mut distances = HashMap::new();
751
752        distances.insert(from, W::default());
753
754        let mut queue = BinaryHeap::new();
755        queue.push(QueueEntry {
756            node: from,
757            cur_cost: W::default(),
758        });
759
760        while let Some(QueueEntry {
761            node,
762            cur_cost: cur_dist,
763        }) = queue.pop()
764        {
765            if visited.contains(&node) {
766                continue;
767            }
768
769            if node == to {
770                return Some(cur_dist);
771            }
772
773            for (target, weight) in self.edges(&node) {
774                let mut distance = cur_dist + weight;
775
776                distances
777                    .entry(target)
778                    .and_modify(|dist| {
779                        let best_dist = (*dist).min(cur_dist + weight);
780
781                        distance = best_dist;
782                        *dist = best_dist;
783                    })
784                    .or_insert(cur_dist + weight);
785
786                if !visited.contains(&target) {
787                    queue.push(QueueEntry {
788                        node: target,
789                        cur_cost: distance,
790                    })
791                }
792            }
793
794            visited.insert(node);
795        }
796
797        None
798    }
799
800    /// Returns the shortest path among paths from `from` to `to`, together with its length.
801    ///
802    /// # Examples
803    ///
804    /// ```rust
805    /// use yagraphc::graph::UnGraph;
806    /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
807    /// use yagraphc::graph::traits::ArithmeticallyWeightedGraph;
808    ///
809    /// let mut graph = UnGraph::new();
810    ///
811    /// graph.add_edge(1, 2, 1);
812    /// graph.add_edge(2, 3, 2);
813    /// graph.add_edge(3, 4, 3);
814    /// graph.add_edge(1, 4, 7);
815    ///
816    /// assert_eq!(graph.dijkstra_with_path(1, 4).unwrap().0, vec![1, 2, 3, 4]);
817    /// assert_eq!(graph.dijkstra_with_path(1, 5), None);
818    fn dijkstra_with_path(&self, from: T, to: T) -> Option<(Vec<T>, W)>
819    where
820        Self: Traversable<T, W>,
821    {
822        let mut visited = HashSet::new();
823        let mut distances = HashMap::new();
824
825        distances.insert(from, (W::default(), from));
826
827        let mut queue = BinaryHeap::new();
828        queue.push(QueueEntry {
829            node: from,
830            cur_cost: W::default(),
831        });
832
833        while let Some(QueueEntry {
834            node,
835            cur_cost: cur_dist,
836        }) = queue.pop()
837        {
838            if visited.contains(&node) {
839                continue;
840            }
841
842            if node == to {
843                let mut path = Vec::new();
844
845                let mut node = node;
846
847                while node != from {
848                    path.push(node);
849
850                    node = distances[&node].1;
851                }
852
853                path.push(from);
854
855                path.reverse();
856
857                return Some((path, cur_dist));
858            }
859
860            for (target, weight) in self.edges(&node) {
861                distances
862                    .entry(target)
863                    .and_modify(|(dist, previous)| {
864                        if cur_dist + weight < *dist {
865                            *dist = cur_dist + weight;
866                            *previous = node;
867                        }
868                    })
869                    .or_insert((cur_dist + weight, node));
870
871                if !visited.contains(&target) {
872                    queue.push(QueueEntry {
873                        node: target,
874                        cur_cost: distances[&target].0,
875                    })
876                }
877            }
878
879            visited.insert(node);
880        }
881
882        None
883    }
884
885    /// Returns the shortest length among paths from `from` to `to` using A*.
886    ///
887    /// `heuristic` corresponds to the heuristic function of the A* algorithm.
888    ///
889    /// # Examples
890    ///
891    /// ```rust
892    /// use yagraphc::graph::UnGraph;
893    /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
894    /// use yagraphc::graph::traits::ArithmeticallyWeightedGraph;
895    ///
896    /// let mut graph = UnGraph::new();
897    ///
898    /// graph.add_edge(1, 2, 1);
899    /// graph.add_edge(2, 3, 2);
900    /// graph.add_edge(3, 4, 3);
901    /// graph.add_edge(1, 4, 7);
902    ///
903    /// assert_eq!(graph.a_star(1, 4, |_| 0), Some(6));
904    /// assert_eq!(graph.a_star(1, 5, |_| 0), None);
905    fn a_star<G>(&self, from: T, to: T, heuristic: G) -> Option<W>
906    where
907        Self: Traversable<T, W>,
908        G: Fn(T) -> W,
909    {
910        let mut visited = HashSet::new();
911        let mut distances = HashMap::new();
912
913        distances.insert(from, W::default());
914
915        let mut queue = BinaryHeap::new();
916        queue.push(QueueEntry {
917            node: from,
918            cur_cost: W::default() + heuristic(from),
919        });
920
921        while let Some(QueueEntry { node, .. }) = queue.pop() {
922            if visited.contains(&node) {
923                continue;
924            }
925
926            if node == to {
927                return Some(distances[&node]);
928            }
929
930            for (target, weight) in self.edges(&node) {
931                let mut distance = distances[&node] + weight;
932
933                distances
934                    .entry(target)
935                    .and_modify(|dist| {
936                        let best_dist = (*dist).min(distance);
937
938                        distance = best_dist;
939                        *dist = best_dist;
940                    })
941                    .or_insert(distance);
942
943                if !visited.contains(&target) {
944                    queue.push(QueueEntry {
945                        node: target,
946                        cur_cost: distance + heuristic(target),
947                    })
948                }
949            }
950
951            visited.insert(node);
952        }
953
954        None
955    }
956
957    /// Returns the shortest path from `from` to `to` using A*, together with its length.
958    ///
959    /// `heuristic` corresponds to the heuristic function of the A* algorithm.
960    ///
961    /// # Examples
962    ///
963    /// ```rust
964    /// use yagraphc::graph::UnGraph;
965    /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
966    /// use yagraphc::graph::traits::ArithmeticallyWeightedGraph;
967    ///
968    /// let mut graph = UnGraph::new();
969    ///
970    /// graph.add_edge(1, 2, 1);
971    /// graph.add_edge(2, 3, 2);
972    /// graph.add_edge(3, 4, 3);
973    /// graph.add_edge(1, 4, 7);
974    ///
975    /// assert_eq!(graph.a_star_with_path(1, 4, |_| 0).unwrap().0, vec![1, 2, 3, 4]);
976    /// assert_eq!(graph.a_star_with_path(1, 5, |_| 0), None);
977    fn a_star_with_path<G>(&self, from: T, to: T, heuristic: G) -> Option<(Vec<T>, W)>
978    where
979        Self: Traversable<T, W>,
980        G: Fn(T) -> W,
981    {
982        let mut visited = HashSet::new();
983        let mut distances = HashMap::new();
984
985        distances.insert(from, (W::default(), from));
986
987        let mut queue = BinaryHeap::new();
988        queue.push(QueueEntry {
989            node: from,
990            cur_cost: W::default() + heuristic(from),
991        });
992
993        while let Some(QueueEntry { node, .. }) = queue.pop() {
994            if visited.contains(&node) {
995                continue;
996            }
997
998            if node == to {
999                let mut path = Vec::new();
1000
1001                let mut node = node;
1002
1003                while node != from {
1004                    path.push(node);
1005
1006                    node = distances[&node].1;
1007                }
1008
1009                path.push(from);
1010
1011                path.reverse();
1012
1013                return Some((path, distances[&node].0));
1014            }
1015
1016            for (target, weight) in self.edges(&node) {
1017                let mut distance = distances[&node].0 + weight;
1018
1019                distances
1020                    .entry(target)
1021                    .and_modify(|(dist, prev)| {
1022                        if distance < *dist {
1023                            *dist = distance;
1024                            *prev = node;
1025                        } else {
1026                            distance = *dist
1027                        }
1028                    })
1029                    .or_insert((distance, node));
1030
1031                if !visited.contains(&target) {
1032                    queue.push(QueueEntry {
1033                        node: target,
1034                        cur_cost: distance + heuristic(target),
1035                    })
1036                }
1037            }
1038
1039            visited.insert(node);
1040        }
1041
1042        None
1043    }
1044
1045    /// Runs the Edmonds-Karp algorithm on the graph to find max flow.
1046    ///
1047    /// Assumes the edge weights are the capacities.
1048    ///
1049    /// Returns a HashMap with the flow values for each edge.
1050    ///
1051    /// Please select a number type for W which allows for subtraction
1052    /// and negative values, otherwise there may be undefined behavior.
1053    ///
1054    /// # Examples
1055    ///
1056    /// ```rust
1057    /// use yagraphc::graph::UnGraph;
1058    /// use yagraphc::graph::traits::{ArithmeticallyWeightedGraph, GraphBuilding, Traversable};
1059    ///
1060    /// let mut graph = UnGraph::new();
1061    /// graph.add_edge(1, 2, 1000);
1062    /// graph.add_edge(2, 4, 1000);
1063    /// graph.add_edge(1, 3, 1000);
1064    /// graph.add_edge(3, 4, 1000);
1065    ///
1066    /// graph.add_edge(2, 3, 1);
1067    ///
1068    /// let flows = graph.edmonds_karp(1, 4);
1069    /// assert_eq!(*flows.get(&(1, 2)).unwrap(), 1000);
1070    /// assert_eq!(*flows.get(&(1, 3)).unwrap(), 1000);
1071    ///
1072    /// assert_eq!(*flows.get(&(2, 3)).unwrap_or(&0), 0);
1073    fn edmonds_karp(&self, source: T, sink: T) -> HashMap<(T, T), W>
1074    where
1075        Self: Traversable<T, W> + Sized,
1076    {
1077        let flows = HashMap::new();
1078
1079        let mut residual_obtension = ResidualNetwork { flows, graph: self };
1080
1081        while let Some(path) = self.find_path_filter_edges(source, sink, |x, y| {
1082            residual_obtension.get_residual_capacity(x, y) > W::default()
1083        }) {
1084            let residuals_in_path = path
1085                .iter()
1086                .zip(&path[1..])
1087                .map(|(&x, &y)| residual_obtension.get_residual_capacity(x, y));
1088
1089            let min_res = residuals_in_path.min().expect("Path should not be empty");
1090
1091            path.iter().zip(&path[1..]).for_each(|(&x, &y)| {
1092                residual_obtension
1093                    .flows
1094                    .entry((x, y))
1095                    .and_modify(|v| *v = *v + min_res)
1096                    .or_insert(min_res);
1097                residual_obtension
1098                    .flows
1099                    .entry((y, x))
1100                    .and_modify(|v| *v = *v - min_res)
1101                    .or_insert(W::default() - min_res);
1102            });
1103
1104            residual_obtension = ResidualNetwork {
1105                flows: residual_obtension.flows,
1106                graph: self,
1107            };
1108        }
1109
1110        residual_obtension.flows
1111    }
1112}
1113
1114struct ResidualNetwork<'a, T, W> {
1115    flows: HashMap<(T, T), W>,
1116    graph: &'a dyn Traversable<T, W>,
1117}
1118
1119impl<'a, T, W> ResidualNetwork<'a, T, W>
1120where
1121    T: Clone + Copy + Eq + Hash + PartialEq,
1122    W: Clone
1123        + Copy
1124        + std::ops::Add<Output = W>
1125        + std::ops::Sub<Output = W>
1126        + PartialOrd
1127        + Ord
1128        + Default,
1129{
1130    fn get_residual_capacity(&self, s: T, t: T) -> W {
1131        self.graph
1132            .edge_weight(s, t)
1133            .expect("Should only be considering existing edges")
1134            - self.flows.get(&(s, t)).copied().unwrap_or(W::default())
1135    }
1136}
1137
#[cfg(test)]
mod tests {
    use super::*;

    /// The error exposes its fixed message through `Display` and its type
    /// name through `Debug`.
    #[test]
    fn test_error() {
        let e = NodeNotFound;

        assert_eq!(e.to_string(), "node not found");
        assert_eq!(format!("{:?}", e), "NodeNotFound");
    }

    /// Queue entries compare by cost only, with the ordering reversed so that
    /// cheaper entries rank higher.
    #[test]
    fn test_queue_entry() {
        let queue_entry1 = QueueEntry {
            node: 1,
            cur_cost: 12,
        };
        let queue_entry2 = QueueEntry {
            node: 4,
            cur_cost: 12,
        };

        assert!(queue_entry1 == queue_entry2);

        let cheaper = QueueEntry {
            node: 1,
            cur_cost: 3,
        };

        assert!(cheaper > queue_entry1);
    }
}