yagraphc/graph/traits.rs
1//! Traits that correspond to the main functionalities of the crate.
2//!
//! `Traversable` is the main trait for working with general graph traversal, such
//! as BFS and DFS.
5//!
6//! `ArithmeticallyWeightedGraph` is the main trait for working with path finding,
7//! such as Dijkstra's algorithm or A*. It is also intended to handle more general
8//! algorithms that rely on arithmetical weights in the future.
9use std::collections::BinaryHeap;
10use std::collections::HashMap;
11use std::collections::HashSet;
12use std::collections::VecDeque;
13use std::hash::Hash;
14
15use thiserror::Error;
16
17#[derive(Error, Debug)]
18#[error("node not found")]
19pub struct NodeNotFound;
20
21pub struct BfsIter<'a, T, W> {
22 pub(crate) queue: VecDeque<(T, usize)>,
23 pub(crate) visited: HashSet<T>,
24 pub(crate) graph: &'a dyn Traversable<T, W>,
25}
26
27impl<'a, T, W> Iterator for BfsIter<'a, T, W>
28where
29 T: Clone + Copy + Hash + PartialEq + Eq,
30 W: Clone + Copy,
31{
32 type Item = (T, usize);
33
34 fn next(&mut self) -> Option<Self::Item> {
35 while let Some((node, depth)) = self.queue.pop_front() {
36 if self.visited.contains(&node) {
37 continue;
38 }
39
40 self.visited.insert(node);
41
42 for (next, _) in self.graph.edges(&node) {
43 if self.visited.contains(&next) {
44 continue;
45 } else {
46 self.queue.push_back((next, depth + 1))
47 }
48 }
49
50 return Some((node, depth));
51 }
52 None
53 }
54}
55
56pub struct DfsIter<'a, T, W> {
57 pub(crate) queue: VecDeque<(T, usize)>,
58 pub(crate) visited: HashSet<T>,
59 pub(crate) graph: &'a dyn Traversable<T, W>,
60}
61
62impl<'a, T, W> Iterator for DfsIter<'a, T, W>
63where
64 T: Clone + Copy + Hash + PartialEq + Eq,
65 W: Clone + Copy,
66{
67 type Item = (T, usize);
68
69 fn next(&mut self) -> Option<Self::Item> {
70 while let Some((node, depth)) = self.queue.pop_front() {
71 if self.visited.contains(&node) {
72 continue;
73 }
74
75 self.visited.insert(node);
76
77 for (next, _) in self.graph.edges(&node) {
78 if self.visited.contains(&next) {
79 continue;
80 } else {
81 self.queue.push_front((next, depth + 1))
82 }
83 }
84
85 return Some((node, depth));
86 }
87
88 None
89 }
90}
91
/// Depth-first post-order iterator, created by [`Traversable::dfs_post_order`].
///
/// The whole traversal is computed up front; iteration just drains the
/// precomputed `(node, depth)` sequence.
pub struct PostOrderDfsIter<T> {
    /// Remaining `(node, depth)` pairs, in post-order, consumed from the front.
    pub(crate) queue: VecDeque<(T, usize)>,
}
95
96impl<'a, T> PostOrderDfsIter<T>
97where
98 T: Clone + Copy + Hash + PartialEq + Eq,
99{
100 pub fn new<W>(graph: &'a dyn Traversable<T, W>, from: T) -> Self
101 where
102 W: Clone + Copy,
103 {
104 let mut queue = VecDeque::new();
105 let mut visited = HashSet::new();
106
107 dfs_post_order(graph, from, 0, &mut queue, &mut visited);
108
109 Self { queue }
110 }
111}
112
113fn dfs_post_order<T, W>(
114 graph: &dyn Traversable<T, W>,
115 node: T,
116 depth: usize,
117 queue: &mut VecDeque<(T, usize)>,
118 visited: &mut HashSet<T>,
119) where
120 T: Clone + Copy + Hash + PartialEq + Eq,
121 W: Clone + Copy,
122{
123 visited.insert(node);
124 for (next, _) in graph.edges(&node) {
125 if visited.contains(&next) {
126 continue;
127 }
128
129 dfs_post_order(graph, next, depth + 1, queue, visited);
130 }
131
132 queue.push_back((node, depth));
133}
134
135impl<T> Iterator for PostOrderDfsIter<T>
136where
137 T: Clone + Copy + Hash + PartialEq + Eq,
138{
139 type Item = (T, usize);
140
141 fn next(&mut self) -> Option<Self::Item> {
142 self.queue.pop_front()
143 }
144}
145
/// Iterator over the nodes of a graph, yielding each node by value.
///
/// It wraps an iterator over a `HashSet`, so the order is unspecified.
pub struct NodeIter<'a, T> {
    /// Underlying borrowed iterator over the node set.
    pub(crate) nodes_iter: std::collections::hash_set::Iter<'a, T>,
}

impl<T> Iterator for NodeIter<'_, T>
where
    T: Clone + Copy,
{
    type Item = T;
    fn next(&mut self) -> Option<Self::Item> {
        let node = self.nodes_iter.next()?;
        Some(*node)
    }
}
159
160pub enum EdgeIterType<'a, T, W> {
161 EdgeIter(EdgeIter<'a, T, W>),
162 EdgeIterVec(EdgeIterVec<'a, T, W>),
163}
164
165impl<'a, T, W> Iterator for EdgeIterType<'a, T, W>
166where
167 T: Clone + Copy,
168 W: Clone + Copy,
169{
170 type Item = (T, W);
171 fn next(&mut self) -> Option<Self::Item> {
172 match self {
173 EdgeIterType::EdgeIter(iter) => iter.next(),
174 EdgeIterType::EdgeIterVec(iter) => iter.next(),
175 }
176 }
177}
178
179pub struct EdgeIter<'a, T, W> {
180 pub(crate) edge_iter: std::collections::hash_map::Iter<'a, T, W>,
181}
182
183impl<T, W> Iterator for EdgeIter<'_, T, W>
184where
185 T: Clone + Copy,
186 W: Clone + Copy,
187{
188 type Item = (T, W);
189 fn next(&mut self) -> Option<Self::Item> {
190 self.edge_iter.next().map(copy_tuple)
191 }
192}
193
/// Edge iterator over a slice of `(neighbour, weight)` pairs, yielding copies
/// in slice order.
pub struct EdgeIterVec<'a, T, W> {
    /// Borrowed iterator over the adjacency slice.
    pub(crate) edge_iter: core::slice::Iter<'a, (T, W)>,
}

impl<T, W> Iterator for EdgeIterVec<'_, T, W>
where
    T: Clone + Copy,
    W: Clone + Copy,
{
    type Item = (T, W);
    fn next(&mut self) -> Option<Self::Item> {
        let &(neighbour, weight) = self.edge_iter.next()?;
        Some((neighbour, weight))
    }
}
208
209pub trait GraphBuilding<T, W>
210where
211 T: Clone + Copy + Eq + Hash + PartialEq,
212 W: Clone + Copy,
213{
214 fn add_edge(&mut self, from: T, to: T, weight: W);
215
216 fn add_node(&mut self, node: T) -> bool;
217
218 fn remove_edge(&mut self, from: T, to: T) -> Result<(), NodeNotFound>;
219
220 fn remove_node(&mut self, node: T) -> Result<(), NodeNotFound>;
221
222 fn has_edge(&self, from: T, to: T) -> bool;
223}
224
225pub trait Traversable<T, W>
226where
227 T: Clone + Copy + Eq + Hash + PartialEq,
228 W: Clone + Copy,
229{
230 /// Iterates over edges of the node as the target nodes and the edge weight.
231 ///
232 /// If the graph is undirected, should return the nodes that are connected to it by
233 /// an edge.
234 ///
235 /// If the graph is directed, should return the outbound nodes.
236 ///
237 /// # Examples
238 /// ```rust
239 /// use yagraphc::graph::{UnGraph, DiGraph};
240 /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
241 ///
242 /// let mut graph = UnGraph::default();
243 ///
244 /// graph.add_edge(1, 2, ());
245 /// graph.add_edge(2, 3, ());
246 ///
247 /// let edges = graph.edges(&2);
248 ///
249 /// assert_eq!(edges.count(), 2);
250 ///
251 /// let mut graph = DiGraph::default();
252 ///
253 /// graph.add_edge(1, 2, ());
254 /// graph.add_edge(2, 3, ());
255 ///
256 /// let edges = graph.edges(&2);
257 ///
258 /// assert_eq!(edges.count(), 1);
259 fn edges(&self, n: &T) -> EdgeIterType<T, W>;
260
261 fn edge_weight(&self, from: T, to: T) -> Result<W, NodeNotFound>;
262
263 /// Iterates over inbound-edges of the node as the target nodes and the edge weight.
264 ///
265 /// If the graph is undirected, should return the nodes that are connected to it by
266 /// an edge. Thus, it is equivalent to `edges` in that case.
267 ///
268 /// If the graph is directed, should return the inbound nodes.
269 ///
270 /// # Examples
271 /// ```rust
272 /// use yagraphc::graph::{UnGraph, DiGraph};
273 /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
274 ///
275 /// let mut graph = UnGraph::default();
276 ///
277 /// graph.add_edge(1, 2, ());
278 /// graph.add_edge(2, 3, ());
279 /// graph.add_edge(4, 2, ());
280 ///
281 /// let edges = graph.in_edges(&2);
282 ///
283 /// assert_eq!(edges.count(), 3);
284 ///
285 /// let mut graph = DiGraph::default();
286 ///
287 /// graph.add_edge(1, 2, ());
288 /// graph.add_edge(2, 3, ());
289 /// graph.add_edge(4, 2, ());
290 ///
291 /// let edges = graph.in_edges(&2);
292 ///
293 /// assert_eq!(edges.count(), 2);
294 fn in_edges(&self, n: &T) -> EdgeIterType<T, W>;
295
296 /// Returns an iterator over all nodes.
297 ///
298 /// # Examples
299 /// ```rust
300 /// use yagraphc::graph::UnGraph;
301 /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
302 ///
303 /// let mut graph = UnGraph::default();
304 ///
305 /// graph.add_edge(1, 2, ());
306 /// graph.add_edge(2, 3, ());
307 ///
308 /// graph.add_node(4);
309 ///
310 /// assert_eq!(graph.nodes().count(), 4);
311 ///
312 /// graph.add_node(2);
313 /// assert_eq!(graph.nodes().count(), 4);
314 fn nodes(&self) -> NodeIter<T>;
315
316 /// Returns an iterator of nodes in breadth-first order.
317 ///
318 /// Iterator includes the depth at which the nodes were found. Nodes at the
319 /// same depth might be randomly shuffled.
320 ///
321 /// # Examples
322 /// ```
323 /// use yagraphc::graph::UnGraph;
324 /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
325 ///
326 /// let mut graph = UnGraph::new();
327 ///
328 /// graph.add_edge(1, 2, ());
329 /// graph.add_edge(1, 3, ());
330 /// graph.add_edge(2, 4, ());
331 /// graph.add_edge(2, 5, ());
332 ///
333 /// let bfs = graph.bfs(1);
334 ///
335 /// let depths = bfs.map(|(_, depth)| depth).collect::<Vec<_>>();
336 ///
337 /// assert_eq!(depths, vec![0, 1, 1, 2, 2]);
338 fn bfs(&self, from: T) -> BfsIter<T, W>
339 where
340 Self: Sized,
341 {
342 let visited = HashSet::new();
343
344 let mut queue = VecDeque::new();
345 queue.push_front((from, 0));
346
347 BfsIter {
348 queue,
349 visited,
350 graph: self,
351 }
352 }
353
354 /// Returns an iterator of nodes in depth-first order, in pre-order.
355 ///
356 /// Iterator includes the depth at which the nodes were found. Order is not
357 /// deterministic.
358 ///
359 /// # Examples
360 /// ```
361 /// use yagraphc::graph::UnGraph;
362 /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
363 ///
364 /// let mut graph = UnGraph::new();
365 ///
366 /// graph.add_edge(1, 2, ());
367 /// graph.add_edge(2, 3, ());
368 /// graph.add_edge(3, 4, ());
369 /// graph.add_edge(1, 5, ());
370 /// graph.add_edge(5, 6, ());
371 ///
372 /// let dfs = graph.dfs(1);
373 ///
374 /// let depths = dfs.map(|(node, _)| node).collect::<Vec<_>>();
375 ///
376 /// assert!(matches!(depths[..], [1, 2, 3, 4, 5, 6] | [1, 5, 6, 2, 3, 4]));
377 fn dfs(&self, from: T) -> DfsIter<T, W>
378 where
379 Self: Sized,
380 {
381 let visited = HashSet::new();
382
383 let mut queue = VecDeque::new();
384 queue.push_front((from, 0));
385
386 DfsIter {
387 queue,
388 visited,
389 graph: self,
390 }
391 }
392
393 /// Returns an iterator of nodes in depth-first order, in post-order.
394 ///
395 /// Iterator includes the depth at which the nodes were found. Order is not
396 /// deterministic.
397 ///
398 /// Currently implemented recursively. To be changed to a non-recursive
399 /// implemented at some point.
400 ///
401 /// # Examples
402 /// ```
403 /// use yagraphc::graph::UnGraph;
404 /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
405 ///
406 /// let mut graph = UnGraph::new();
407 ///
408 /// graph.add_edge(1, 2, ());
409 /// graph.add_edge(2, 3, ());
410 /// graph.add_edge(3, 4, ());
411 /// graph.add_edge(1, 5, ());
412 /// graph.add_edge(5, 6, ());
413 ///
414 /// let dfs = graph.dfs_post_order(1);
415 ///
416 /// let depths = dfs.map(|(node, _)| node).collect::<Vec<_>>();
417 ///
418 /// assert!(matches!(depths[..], [6, 5, 4, 3, 2, 1] | [4, 3, 2, 6, 5, 1]));
419 // TODO: Implement post-order non-recursively.
420 fn dfs_post_order(&self, from: T) -> PostOrderDfsIter<T>
421 where
422 Self: Sized,
423 {
424 PostOrderDfsIter::new(self, from)
425 }
426
427 /// Finds path from `from` to `to` using BFS.
428 ///
429 /// Returns `None` if there is no path.
430 ///
431 /// # Examples
432 ///
433 /// ```rust
434 /// use yagraphc::graph::UnGraph;
435 /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
436 ///
437 /// let mut graph = UnGraph::new();
438 ///
439 /// graph.add_edge(1, 2, ());
440 /// graph.add_edge(2, 3, ());
441 /// graph.add_edge(3, 4, ());
442 /// graph.add_edge(1, 5, ());
443 /// graph.add_edge(5, 6, ());
444 ///
445 /// let path = graph.find_path(1, 4);
446 ///
447 /// assert_eq!(path, Some(vec![1, 2, 3, 4]));
448 fn find_path(&self, from: T, to: T) -> Option<Vec<T>>
449 where
450 Self: Sized,
451 {
452 let mut visited = HashSet::new();
453 let mut pairs = HashMap::new();
454 let mut queue = VecDeque::new();
455
456 queue.push_back((from, from));
457
458 while let Some((prev, current)) = queue.pop_front() {
459 if visited.contains(¤t) {
460 continue;
461 }
462 visited.insert(current);
463 pairs.insert(current, prev);
464
465 if current == to {
466 let mut node = current;
467
468 let mut path = Vec::new();
469 while node != from {
470 path.push(node);
471
472 node = pairs[&node];
473 }
474
475 path.push(from);
476
477 path.reverse();
478
479 return Some(path);
480 }
481
482 for (target, _) in self.edges(¤t) {
483 if visited.contains(&target) {
484 continue;
485 }
486
487 queue.push_back((current, target));
488 }
489 }
490
491 None
492 }
493
494 /// Finds path from `from` to `to` using BFS while filtering edges.
495 ///
496 /// Returns `None` if there is no path.
497 ///
498 /// For an undirected graph, strive to make predicate(x,y) == predicate(y,x).
499 /// As of now, the order is related to the exploration direction of bfs,
500 /// but it is advisable not to rely on direction concepts on undirected graphs.
501 ///
502 /// # Examples
503 ///
504 /// ```rust
505 /// use yagraphc::graph::UnGraph;
506 /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
507 /// let mut graph = UnGraph::default();
508
509 /// graph.add_edge(1, 2, ());
510 /// graph.add_edge(2, 3, ());
511 /// graph.add_edge(3, 4, ());
512 /// graph.add_edge(4, 5, ());
513
514 /// graph.add_edge(1, 7, ());
515 /// graph.add_edge(7, 5, ());
516
517 /// let path = graph.find_path(1, 5).unwrap();
518
519 /// assert_eq!(path, vec![1, 7, 5]);
520
521 /// let path = graph
522 /// .find_path_filter_edges(1, 5, |x, y| (x, y) != (1, 7))
523 /// .unwrap();
524
525 /// assert_eq!(path, vec![1, 2, 3, 4, 5]);
526 fn find_path_filter_edges<G>(&self, from: T, to: T, predicate: G) -> Option<Vec<T>>
527 where
528 Self: Sized,
529 G: Fn(T, T) -> bool,
530 {
531 let mut visited = HashSet::new();
532 let mut pairs = HashMap::new();
533 let mut queue = VecDeque::new();
534
535 queue.push_back((from, from));
536
537 while let Some((prev, current)) = queue.pop_front() {
538 if visited.contains(¤t) {
539 continue;
540 }
541 visited.insert(current);
542 pairs.insert(current, prev);
543
544 if current == to {
545 let mut node = current;
546
547 let mut path = Vec::new();
548 while node != from {
549 path.push(node);
550
551 node = pairs[&node];
552 }
553
554 path.push(from);
555
556 path.reverse();
557
558 return Some(path);
559 }
560
561 for (target, _) in self.edges(¤t) {
562 if visited.contains(&target) || !predicate(current, target) {
563 continue;
564 }
565
566 queue.push_back((current, target));
567 }
568 }
569
570 None
571 }
572
573 /// Returns a list of connected components of the graph.
574 ///
575 /// If being used in a directed graph, those are the strongly connected components,
576 /// computed using Kosaraju's algorithm.
577 ///
578 /// # Examples
579 ///
580 /// ```rust
581 /// use yagraphc::graph::{UnGraph, DiGraph};
582 /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
583 ///
584 /// let mut graph = UnGraph::new();
585 ///
586 /// graph.add_edge(1, 2, ());
587 /// graph.add_edge(2, 3, ());
588 ///
589 /// graph.add_edge(4, 5, ());
590 /// graph.add_edge(5, 6, ());
591 /// graph.add_edge(6, 4, ());
592 ///
593 /// let components = graph.connected_components();
594 ///
595 /// assert_eq!(components.len(), 2);
596 ///
597 /// let mut graph = DiGraph::new();
598 ///
599 /// graph.add_edge(1, 2, ());
600 /// graph.add_edge(2, 3, ());
601 ///
602 /// graph.add_edge(4, 5, ());
603 /// graph.add_edge(5, 6, ());
604 /// graph.add_edge(6, 4, ());
605 ///
606 /// let components = graph.connected_components();
607 ///
608 /// assert_eq!(components.len(), 4);
609 fn connected_components(&self) -> Vec<Vec<T>>
610 where
611 Self: Sized,
612 {
613 let mut visited = HashSet::new();
614 let mut stack = Vec::new();
615
616 for node in self.nodes() {
617 if visited.contains(&node) {
618 continue;
619 }
620
621 for (inner_node, _) in self.dfs_post_order(node) {
622 visited.insert(inner_node);
623 stack.push(inner_node);
624 }
625 }
626
627 stack.reverse();
628
629 let mut visited = HashSet::new();
630 let mut components = Vec::new();
631
632 for node in stack {
633 if visited.contains(&node) {
634 continue;
635 }
636
637 let mut component = Vec::new();
638
639 let mut stack = Vec::new();
640 stack.push(node);
641
642 while let Some(node) = stack.pop() {
643 if visited.contains(&node) {
644 continue;
645 }
646
647 component.push(node);
648 visited.insert(node);
649
650 for (inner_node, _) in self.in_edges(&node) {
651 stack.push(inner_node);
652 }
653 }
654
655 components.push(component);
656 }
657
658 components
659 }
660}
661
/// Copies both halves of a borrowed pair, such as the `(&key, &value)` items
/// yielded by `HashMap::iter`, into an owned `(key, value)` tuple.
fn copy_tuple<T, W>((first, second): (&T, &W)) -> (T, W)
where
    T: Clone + Copy,
    W: Clone + Copy,
{
    (*first, *second)
}
669
/// Priority-queue entry used by the shortest-path searches.
///
/// Ordering and equality look only at `cur_cost`; `node` is payload. The
/// ordering is reversed, so `BinaryHeap` (a max-heap) pops the entry with the
/// *smallest* cost first. Trait bounds live on the impls that need them rather
/// than on the struct itself.
struct QueueEntry<T, W> {
    pub node: T,
    pub cur_cost: W,
}

impl<T, W: Ord> Ord for QueueEntry<T, W> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Reversed on purpose: a lower cost means a higher heap priority.
        other.cur_cost.cmp(&self.cur_cost)
    }
}

impl<T, W: Ord> PartialOrd for QueueEntry<T, W> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl<T, W: Ord> PartialEq for QueueEntry<T, W> {
    // Consistent with `Ord`: entries with equal cost compare equal.
    fn eq(&self, other: &Self) -> bool {
        self.cur_cost == other.cur_cost
    }
}

impl<T, W: Ord> Eq for QueueEntry<T, W> {}
715
716pub trait ArithmeticallyWeightedGraph<T, W>
717where
718 T: Clone + Copy + Eq + Hash + PartialEq,
719 W: Clone
720 + Copy
721 + std::ops::Add<Output = W>
722 + std::ops::Sub<Output = W>
723 + PartialOrd
724 + Ord
725 + Default,
726{
727 /// Returns the shortest length among paths from `from` to `to`.
728 ///
729 /// # Examples
730 ///
731 /// ```rust
732 /// use yagraphc::graph::UnGraph;
733 /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
734 /// use yagraphc::graph::traits::ArithmeticallyWeightedGraph;
735 ///
736 /// let mut graph = UnGraph::new();
737 ///
738 /// graph.add_edge(1, 2, 1);
739 /// graph.add_edge(2, 3, 2);
740 /// graph.add_edge(3, 4, 3);
741 /// graph.add_edge(1, 4, 7);
742 ///
743 /// assert_eq!(graph.dijkstra(1, 4), Some(6));
744 /// assert_eq!(graph.dijkstra(1, 5), None);
745 fn dijkstra(&self, from: T, to: T) -> Option<W>
746 where
747 Self: Traversable<T, W>,
748 {
749 let mut visited = HashSet::new();
750 let mut distances = HashMap::new();
751
752 distances.insert(from, W::default());
753
754 let mut queue = BinaryHeap::new();
755 queue.push(QueueEntry {
756 node: from,
757 cur_cost: W::default(),
758 });
759
760 while let Some(QueueEntry {
761 node,
762 cur_cost: cur_dist,
763 }) = queue.pop()
764 {
765 if visited.contains(&node) {
766 continue;
767 }
768
769 if node == to {
770 return Some(cur_dist);
771 }
772
773 for (target, weight) in self.edges(&node) {
774 let mut distance = cur_dist + weight;
775
776 distances
777 .entry(target)
778 .and_modify(|dist| {
779 let best_dist = (*dist).min(cur_dist + weight);
780
781 distance = best_dist;
782 *dist = best_dist;
783 })
784 .or_insert(cur_dist + weight);
785
786 if !visited.contains(&target) {
787 queue.push(QueueEntry {
788 node: target,
789 cur_cost: distance,
790 })
791 }
792 }
793
794 visited.insert(node);
795 }
796
797 None
798 }
799
800 /// Returns the shortest path among paths from `from` to `to`, together with its length.
801 ///
802 /// # Examples
803 ///
804 /// ```rust
805 /// use yagraphc::graph::UnGraph;
806 /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
807 /// use yagraphc::graph::traits::ArithmeticallyWeightedGraph;
808 ///
809 /// let mut graph = UnGraph::new();
810 ///
811 /// graph.add_edge(1, 2, 1);
812 /// graph.add_edge(2, 3, 2);
813 /// graph.add_edge(3, 4, 3);
814 /// graph.add_edge(1, 4, 7);
815 ///
816 /// assert_eq!(graph.dijkstra_with_path(1, 4).unwrap().0, vec![1, 2, 3, 4]);
817 /// assert_eq!(graph.dijkstra_with_path(1, 5), None);
818 fn dijkstra_with_path(&self, from: T, to: T) -> Option<(Vec<T>, W)>
819 where
820 Self: Traversable<T, W>,
821 {
822 let mut visited = HashSet::new();
823 let mut distances = HashMap::new();
824
825 distances.insert(from, (W::default(), from));
826
827 let mut queue = BinaryHeap::new();
828 queue.push(QueueEntry {
829 node: from,
830 cur_cost: W::default(),
831 });
832
833 while let Some(QueueEntry {
834 node,
835 cur_cost: cur_dist,
836 }) = queue.pop()
837 {
838 if visited.contains(&node) {
839 continue;
840 }
841
842 if node == to {
843 let mut path = Vec::new();
844
845 let mut node = node;
846
847 while node != from {
848 path.push(node);
849
850 node = distances[&node].1;
851 }
852
853 path.push(from);
854
855 path.reverse();
856
857 return Some((path, cur_dist));
858 }
859
860 for (target, weight) in self.edges(&node) {
861 distances
862 .entry(target)
863 .and_modify(|(dist, previous)| {
864 if cur_dist + weight < *dist {
865 *dist = cur_dist + weight;
866 *previous = node;
867 }
868 })
869 .or_insert((cur_dist + weight, node));
870
871 if !visited.contains(&target) {
872 queue.push(QueueEntry {
873 node: target,
874 cur_cost: distances[&target].0,
875 })
876 }
877 }
878
879 visited.insert(node);
880 }
881
882 None
883 }
884
885 /// Returns the shortest length among paths from `from` to `to` using A*.
886 ///
887 /// `heuristic` corresponds to the heuristic function of the A* algorithm.
888 ///
889 /// # Examples
890 ///
891 /// ```rust
892 /// use yagraphc::graph::UnGraph;
893 /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
894 /// use yagraphc::graph::traits::ArithmeticallyWeightedGraph;
895 ///
896 /// let mut graph = UnGraph::new();
897 ///
898 /// graph.add_edge(1, 2, 1);
899 /// graph.add_edge(2, 3, 2);
900 /// graph.add_edge(3, 4, 3);
901 /// graph.add_edge(1, 4, 7);
902 ///
903 /// assert_eq!(graph.a_star(1, 4, |_| 0), Some(6));
904 /// assert_eq!(graph.a_star(1, 5, |_| 0), None);
905 fn a_star<G>(&self, from: T, to: T, heuristic: G) -> Option<W>
906 where
907 Self: Traversable<T, W>,
908 G: Fn(T) -> W,
909 {
910 let mut visited = HashSet::new();
911 let mut distances = HashMap::new();
912
913 distances.insert(from, W::default());
914
915 let mut queue = BinaryHeap::new();
916 queue.push(QueueEntry {
917 node: from,
918 cur_cost: W::default() + heuristic(from),
919 });
920
921 while let Some(QueueEntry { node, .. }) = queue.pop() {
922 if visited.contains(&node) {
923 continue;
924 }
925
926 if node == to {
927 return Some(distances[&node]);
928 }
929
930 for (target, weight) in self.edges(&node) {
931 let mut distance = distances[&node] + weight;
932
933 distances
934 .entry(target)
935 .and_modify(|dist| {
936 let best_dist = (*dist).min(distance);
937
938 distance = best_dist;
939 *dist = best_dist;
940 })
941 .or_insert(distance);
942
943 if !visited.contains(&target) {
944 queue.push(QueueEntry {
945 node: target,
946 cur_cost: distance + heuristic(target),
947 })
948 }
949 }
950
951 visited.insert(node);
952 }
953
954 None
955 }
956
957 /// Returns the shortest path from `from` to `to` using A*, together with its length.
958 ///
959 /// `heuristic` corresponds to the heuristic function of the A* algorithm.
960 ///
961 /// # Examples
962 ///
963 /// ```rust
964 /// use yagraphc::graph::UnGraph;
965 /// use yagraphc::graph::traits::{GraphBuilding, Traversable};
966 /// use yagraphc::graph::traits::ArithmeticallyWeightedGraph;
967 ///
968 /// let mut graph = UnGraph::new();
969 ///
970 /// graph.add_edge(1, 2, 1);
971 /// graph.add_edge(2, 3, 2);
972 /// graph.add_edge(3, 4, 3);
973 /// graph.add_edge(1, 4, 7);
974 ///
975 /// assert_eq!(graph.a_star_with_path(1, 4, |_| 0).unwrap().0, vec![1, 2, 3, 4]);
976 /// assert_eq!(graph.a_star_with_path(1, 5, |_| 0), None);
977 fn a_star_with_path<G>(&self, from: T, to: T, heuristic: G) -> Option<(Vec<T>, W)>
978 where
979 Self: Traversable<T, W>,
980 G: Fn(T) -> W,
981 {
982 let mut visited = HashSet::new();
983 let mut distances = HashMap::new();
984
985 distances.insert(from, (W::default(), from));
986
987 let mut queue = BinaryHeap::new();
988 queue.push(QueueEntry {
989 node: from,
990 cur_cost: W::default() + heuristic(from),
991 });
992
993 while let Some(QueueEntry { node, .. }) = queue.pop() {
994 if visited.contains(&node) {
995 continue;
996 }
997
998 if node == to {
999 let mut path = Vec::new();
1000
1001 let mut node = node;
1002
1003 while node != from {
1004 path.push(node);
1005
1006 node = distances[&node].1;
1007 }
1008
1009 path.push(from);
1010
1011 path.reverse();
1012
1013 return Some((path, distances[&node].0));
1014 }
1015
1016 for (target, weight) in self.edges(&node) {
1017 let mut distance = distances[&node].0 + weight;
1018
1019 distances
1020 .entry(target)
1021 .and_modify(|(dist, prev)| {
1022 if distance < *dist {
1023 *dist = distance;
1024 *prev = node;
1025 } else {
1026 distance = *dist
1027 }
1028 })
1029 .or_insert((distance, node));
1030
1031 if !visited.contains(&target) {
1032 queue.push(QueueEntry {
1033 node: target,
1034 cur_cost: distance + heuristic(target),
1035 })
1036 }
1037 }
1038
1039 visited.insert(node);
1040 }
1041
1042 None
1043 }
1044
1045 /// Runs the Edmonds-Karp algorithm on the graph to find max flow.
1046 ///
1047 /// Assumes the edge weights are the capacities.
1048 ///
1049 /// Returns a HashMap with the flow values for each edge.
1050 ///
1051 /// Please select a number type for W which allows for subtraction
1052 /// and negative values, otherwise there may be undefined behavior.
1053 ///
1054 /// # Examples
1055 ///
1056 /// ```rust
1057 /// use yagraphc::graph::UnGraph;
1058 /// use yagraphc::graph::traits::{ArithmeticallyWeightedGraph, GraphBuilding, Traversable};
1059 ///
1060 /// let mut graph = UnGraph::new();
1061 /// graph.add_edge(1, 2, 1000);
1062 /// graph.add_edge(2, 4, 1000);
1063 /// graph.add_edge(1, 3, 1000);
1064 /// graph.add_edge(3, 4, 1000);
1065 ///
1066 /// graph.add_edge(2, 3, 1);
1067 ///
1068 /// let flows = graph.edmonds_karp(1, 4);
1069 /// assert_eq!(*flows.get(&(1, 2)).unwrap(), 1000);
1070 /// assert_eq!(*flows.get(&(1, 3)).unwrap(), 1000);
1071 ///
1072 /// assert_eq!(*flows.get(&(2, 3)).unwrap_or(&0), 0);
1073 fn edmonds_karp(&self, source: T, sink: T) -> HashMap<(T, T), W>
1074 where
1075 Self: Traversable<T, W> + Sized,
1076 {
1077 let flows = HashMap::new();
1078
1079 let mut residual_obtension = ResidualNetwork { flows, graph: self };
1080
1081 while let Some(path) = self.find_path_filter_edges(source, sink, |x, y| {
1082 residual_obtension.get_residual_capacity(x, y) > W::default()
1083 }) {
1084 let residuals_in_path = path
1085 .iter()
1086 .zip(&path[1..])
1087 .map(|(&x, &y)| residual_obtension.get_residual_capacity(x, y));
1088
1089 let min_res = residuals_in_path.min().expect("Path should not be empty");
1090
1091 path.iter().zip(&path[1..]).for_each(|(&x, &y)| {
1092 residual_obtension
1093 .flows
1094 .entry((x, y))
1095 .and_modify(|v| *v = *v + min_res)
1096 .or_insert(min_res);
1097 residual_obtension
1098 .flows
1099 .entry((y, x))
1100 .and_modify(|v| *v = *v - min_res)
1101 .or_insert(W::default() - min_res);
1102 });
1103
1104 residual_obtension = ResidualNetwork {
1105 flows: residual_obtension.flows,
1106 graph: self,
1107 };
1108 }
1109
1110 residual_obtension.flows
1111 }
1112}
1113
1114struct ResidualNetwork<'a, T, W> {
1115 flows: HashMap<(T, T), W>,
1116 graph: &'a dyn Traversable<T, W>,
1117}
1118
1119impl<'a, T, W> ResidualNetwork<'a, T, W>
1120where
1121 T: Clone + Copy + Eq + Hash + PartialEq,
1122 W: Clone
1123 + Copy
1124 + std::ops::Add<Output = W>
1125 + std::ops::Sub<Output = W>
1126 + PartialOrd
1127 + Ord
1128 + Default,
1129{
1130 fn get_residual_capacity(&self, s: T, t: T) -> W {
1131 self.graph
1132 .edge_weight(s, t)
1133 .expect("Should only be considering existing edges")
1134 - self.flows.get(&(s, t)).copied().unwrap_or(W::default())
1135 }
1136}
1137
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BinaryHeap;

    /// The error type is `Debug`-printable and has the expected message.
    #[test]
    fn test_error() {
        let e = NodeNotFound;
        dbg!(&e);
        assert_eq!(e.to_string(), "node not found");
    }

    /// Equality on queue entries ignores the node and compares costs only.
    #[test]
    fn test_queue_entry() {
        let queue_entry1 = QueueEntry {
            node: 1,
            cur_cost: 12,
        };
        let queue_entry2 = QueueEntry {
            node: 4,
            cur_cost: 12,
        };

        assert!(queue_entry1 == queue_entry2);
    }

    /// The reversed ordering turns `BinaryHeap` into a min-heap on cost.
    #[test]
    fn test_queue_entry_min_heap_order() {
        let mut heap = BinaryHeap::new();
        for (node, cost) in [(1, 30), (2, 10), (3, 20)] {
            heap.push(QueueEntry {
                node,
                cur_cost: cost,
            });
        }

        let popped: Vec<i32> =
            std::iter::from_fn(|| heap.pop().map(|entry| entry.node)).collect();

        assert_eq!(popped, vec![2, 3, 1]);
    }
}