1use petgraph::algo::dijkstra;
10use petgraph::visit::EdgeRef;
11use std::cmp::Ordering;
12use std::collections::{BinaryHeap, HashMap};
13use std::hash::Hash;
14
15use crate::base::{DiGraph, EdgeWeight, Graph, Node};
16use crate::error::{GraphError, Result};
17
18#[derive(Debug, Clone)]
20pub struct Path<N: Node + std::fmt::Debug, E: EdgeWeight> {
21 pub nodes: Vec<N>,
23 pub total_weight: E,
25}
26
27#[derive(Debug, Clone)]
29pub struct AStarResult<N: Node + std::fmt::Debug, E: EdgeWeight> {
30 pub path: Vec<N>,
32 pub cost: E,
34}
35
36#[derive(Clone)]
38struct AStarState<N: Node + std::fmt::Debug, E: EdgeWeight> {
39 node: N,
40 cost: E,
41 heuristic: E,
42 path: Vec<N>,
43}
44
45impl<N: Node + std::fmt::Debug, E: EdgeWeight> PartialEq for AStarState<N, E> {
46 fn eq(&self, other: &Self) -> bool {
47 self.node == other.node
48 }
49}
50
51impl<N: Node + std::fmt::Debug, E: EdgeWeight> Eq for AStarState<N, E> {}
52
53impl<N: Node + std::fmt::Debug, E: EdgeWeight + std::ops::Add<Output = E> + Copy + PartialOrd> Ord
54 for AStarState<N, E>
55{
56 fn cmp(&self, other: &Self) -> Ordering {
57 let self_total = self.cost + self.heuristic;
59 let other_total = other.cost + other.heuristic;
60 other_total
61 .partial_cmp(&self_total)
62 .unwrap_or(Ordering::Equal)
63 .then_with(|| {
64 other
65 .cost
66 .partial_cmp(&self.cost)
67 .unwrap_or(Ordering::Equal)
68 })
69 }
70}
71
72impl<N: Node + std::fmt::Debug, E: EdgeWeight + std::ops::Add<Output = E> + Copy + PartialOrd>
73 PartialOrd for AStarState<N, E>
74{
75 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
76 Some(self.cmp(other))
77 }
78}
79
80#[deprecated(
104 since = "0.1.0-beta.2",
105 note = "Use `dijkstra_path` for future compatibility. This function will return PathResult in v1.0"
106)]
107#[allow(dead_code)]
108pub fn shortest_path<N, E, Ix>(
109 graph: &Graph<N, E, Ix>,
110 source: &N,
111 target: &N,
112) -> Result<Option<Path<N, E>>>
113where
114 N: Node + std::fmt::Debug,
115 E: EdgeWeight
116 + num_traits::Zero
117 + num_traits::One
118 + std::ops::Add<Output = E>
119 + PartialOrd
120 + std::marker::Copy
121 + std::fmt::Debug
122 + std::default::Default,
123 Ix: petgraph::graph::IndexType,
124{
125 if !graph.has_node(source) {
127 return Err(GraphError::InvalidGraph(format!(
128 "Source node {source:?} not found"
129 )));
130 }
131 if !graph.has_node(target) {
132 return Err(GraphError::InvalidGraph(format!(
133 "Target node {target:?} not found"
134 )));
135 }
136
137 let source_idx = graph
138 .inner()
139 .node_indices()
140 .find(|&idx| graph.inner()[idx] == *source)
141 .unwrap();
142 let target_idx = graph
143 .inner()
144 .node_indices()
145 .find(|&idx| graph.inner()[idx] == *target)
146 .unwrap();
147
148 let results = dijkstra(graph.inner(), source_idx, Some(target_idx), |e| *e.weight());
150
151 if !results.contains_key(&target_idx) {
153 return Ok(None);
154 }
155
156 let total_weight = results[&target_idx];
157
158 let mut path = Vec::new();
160 let mut current = target_idx;
161
162 path.push(graph.inner()[current].clone());
163
164 while current != source_idx {
166 let min_prev = graph
167 .inner()
168 .edges_directed(current, petgraph::Direction::Incoming)
169 .filter_map(|e| {
170 let from = e.source();
171 let edge_weight = *e.weight();
172
173 if let Some(from_dist) = results.get(&from) {
175 if *from_dist + edge_weight == results[¤t] {
177 return Some((from, *from_dist));
178 }
179 }
180 None
181 })
182 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
183
184 if let Some((prev, _)) = min_prev {
185 current = prev;
186 path.push(graph.inner()[current].clone());
187 } else {
188 return Err(GraphError::AlgorithmError(
190 "Failed to reconstruct path".to_string(),
191 ));
192 }
193 }
194
195 path.reverse();
197
198 Ok(Some(Path {
199 nodes: path,
200 total_weight,
201 }))
202}
203
204#[allow(dead_code)]
239pub fn dijkstra_path<N, E, Ix>(
240 graph: &Graph<N, E, Ix>,
241 source: &N,
242 target: &N,
243) -> Result<Option<Path<N, E>>>
244where
245 N: Node + std::fmt::Debug,
246 E: EdgeWeight
247 + num_traits::Zero
248 + num_traits::One
249 + std::ops::Add<Output = E>
250 + PartialOrd
251 + std::marker::Copy
252 + std::fmt::Debug
253 + std::default::Default,
254 Ix: petgraph::graph::IndexType,
255{
256 #[allow(deprecated)]
257 shortest_path(graph, source, target)
258}
259
260#[deprecated(
267 since = "0.1.0-beta.2",
268 note = "Use `dijkstra_path_digraph` for future compatibility. This function will return PathResult in v1.0"
269)]
270#[allow(dead_code)]
271pub fn shortest_path_digraph<N, E, Ix>(
272 graph: &DiGraph<N, E, Ix>,
273 source: &N,
274 target: &N,
275) -> Result<Option<Path<N, E>>>
276where
277 N: Node + std::fmt::Debug,
278 E: EdgeWeight
279 + num_traits::Zero
280 + num_traits::One
281 + std::ops::Add<Output = E>
282 + PartialOrd
283 + std::marker::Copy
284 + std::fmt::Debug
285 + std::default::Default,
286 Ix: petgraph::graph::IndexType,
287{
288 if !graph.has_node(source) {
291 return Err(GraphError::InvalidGraph(format!(
292 "Source node {source:?} not found"
293 )));
294 }
295 if !graph.has_node(target) {
296 return Err(GraphError::InvalidGraph(format!(
297 "Target node {target:?} not found"
298 )));
299 }
300
301 let source_idx = graph
302 .inner()
303 .node_indices()
304 .find(|&idx| graph.inner()[idx] == *source)
305 .unwrap();
306 let target_idx = graph
307 .inner()
308 .node_indices()
309 .find(|&idx| graph.inner()[idx] == *target)
310 .unwrap();
311
312 let results = dijkstra(graph.inner(), source_idx, Some(target_idx), |e| *e.weight());
314
315 if !results.contains_key(&target_idx) {
317 return Ok(None);
318 }
319
320 let total_weight = results[&target_idx];
321
322 let mut path = Vec::new();
324 let mut current = target_idx;
325
326 path.push(graph.inner()[current].clone());
327
328 while current != source_idx {
330 let min_prev = graph
331 .inner()
332 .edges_directed(current, petgraph::Direction::Incoming)
333 .filter_map(|e| {
334 let from = e.source();
335 let edge_weight = *e.weight();
336
337 if let Some(from_dist) = results.get(&from) {
339 if *from_dist + edge_weight == results[¤t] {
341 return Some((from, *from_dist));
342 }
343 }
344 None
345 })
346 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
347
348 if let Some((prev, _)) = min_prev {
349 current = prev;
350 path.push(graph.inner()[current].clone());
351 } else {
352 return Err(GraphError::AlgorithmError(
354 "Failed to reconstruct path".to_string(),
355 ));
356 }
357 }
358
359 path.reverse();
361
362 Ok(Some(Path {
363 nodes: path,
364 total_weight,
365 }))
366}
367
368#[allow(dead_code)]
389pub fn dijkstra_path_digraph<N, E, Ix>(
390 graph: &DiGraph<N, E, Ix>,
391 source: &N,
392 target: &N,
393) -> Result<Option<Path<N, E>>>
394where
395 N: Node + std::fmt::Debug,
396 E: EdgeWeight
397 + num_traits::Zero
398 + num_traits::One
399 + std::ops::Add<Output = E>
400 + PartialOrd
401 + std::marker::Copy
402 + std::fmt::Debug
403 + std::default::Default,
404 Ix: petgraph::graph::IndexType,
405{
406 #[allow(deprecated)]
407 shortest_path_digraph(graph, source, target)
408}
409
410#[allow(dead_code)]
433pub fn floyd_warshall<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Result<ndarray::Array2<f64>>
434where
435 N: Node + std::fmt::Debug,
436 E: EdgeWeight + Into<f64> + num_traits::Zero + Copy,
437 Ix: petgraph::graph::IndexType,
438{
439 let n = graph.node_count();
440
441 if n == 0 {
442 return Ok(ndarray::Array2::zeros((0, 0)));
443 }
444
445 let mut dist = ndarray::Array2::from_elem((n, n), f64::INFINITY);
447
448 for i in 0..n {
450 dist[[i, i]] = 0.0;
451 }
452
453 for edge in graph.inner().edge_references() {
455 let i = edge.source().index();
456 let j = edge.target().index();
457 let weight: f64 = (*edge.weight()).into();
458
459 dist[[i, j]] = weight;
460 dist[[j, i]] = weight;
462 }
463
464 for k in 0..n {
466 for i in 0..n {
467 for j in 0..n {
468 let alt = dist[[i, k]] + dist[[k, j]];
469 if alt < dist[[i, j]] {
470 dist[[i, j]] = alt;
471 }
472 }
473 }
474 }
475
476 Ok(dist)
477}
478
479#[allow(dead_code)]
481pub fn floyd_warshall_digraph<N, E, Ix>(graph: &DiGraph<N, E, Ix>) -> Result<ndarray::Array2<f64>>
482where
483 N: Node + std::fmt::Debug,
484 E: EdgeWeight + Into<f64> + num_traits::Zero + Copy,
485 Ix: petgraph::graph::IndexType,
486{
487 let n = graph.node_count();
488
489 if n == 0 {
490 return Ok(ndarray::Array2::zeros((0, 0)));
491 }
492
493 let mut dist = ndarray::Array2::from_elem((n, n), f64::INFINITY);
495
496 for i in 0..n {
498 dist[[i, i]] = 0.0;
499 }
500
501 for edge in graph.inner().edge_references() {
503 let i = edge.source().index();
504 let j = edge.target().index();
505 let weight: f64 = (*edge.weight()).into();
506
507 dist[[i, j]] = weight;
508 }
509
510 for k in 0..n {
512 for i in 0..n {
513 for j in 0..n {
514 let alt = dist[[i, k]] + dist[[k, j]];
515 if alt < dist[[i, j]] {
516 dist[[i, j]] = alt;
517 }
518 }
519 }
520 }
521
522 Ok(dist)
523}
524
525#[allow(dead_code)]
550pub fn astar_search<N, E, Ix, H>(
551 graph: &Graph<N, E, Ix>,
552 source: &N,
553 target: &N,
554 heuristic: H,
555) -> Result<AStarResult<N, E>>
556where
557 N: Node + std::fmt::Debug + Clone + Hash + Eq,
558 E: EdgeWeight + Clone + std::ops::Add<Output = E> + num_traits::Zero + PartialOrd + Copy,
559 Ix: petgraph::graph::IndexType,
560 H: Fn(&N) -> E,
561{
562 if !graph.contains_node(source) || !graph.contains_node(target) {
563 return Err(GraphError::node_not_found("node"));
564 }
565
566 let mut open_set = BinaryHeap::new();
567 let mut g_score: HashMap<N, E> = HashMap::new();
568 let mut came_from: HashMap<N, N> = HashMap::new();
569
570 g_score.insert(source.clone(), E::zero());
571
572 open_set.push(AStarState {
573 node: source.clone(),
574 cost: E::zero(),
575 heuristic: heuristic(source),
576 path: vec![source.clone()],
577 });
578
579 while let Some(current_state) = open_set.pop() {
580 let current = ¤t_state.node;
581
582 if current == target {
583 return Ok(AStarResult {
584 path: current_state.path,
585 cost: current_state.cost,
586 });
587 }
588
589 let current_g = g_score.get(current).cloned().unwrap_or_else(E::zero);
590
591 if let Ok(neighbors) = graph.neighbors(current) {
592 for neighbor in neighbors {
593 if let Ok(edge_weight) = graph.edge_weight(current, &neighbor) {
594 let tentative_g = current_g + edge_weight;
595
596 let current_neighbor_g = g_score.get(&neighbor);
597 if current_neighbor_g.is_none() || tentative_g < *current_neighbor_g.unwrap() {
598 came_from.insert(neighbor.clone(), current.clone());
599 g_score.insert(neighbor.clone(), tentative_g);
600
601 let mut new_path = current_state.path.clone();
602 new_path.push(neighbor.clone());
603
604 open_set.push(AStarState {
605 node: neighbor.clone(),
606 cost: tentative_g,
607 heuristic: heuristic(&neighbor),
608 path: new_path,
609 });
610 }
611 }
612 }
613 }
614 }
615
616 Err(GraphError::NoPath {
617 src_node: format!("{source:?}"),
618 target: format!("{target:?}"),
619 nodes: 0,
620 edges: 0,
621 })
622}
623
624#[allow(dead_code)]
626pub fn astar_search_digraph<N, E, Ix, H>(
627 graph: &DiGraph<N, E, Ix>,
628 source: &N,
629 target: &N,
630 heuristic: H,
631) -> Result<AStarResult<N, E>>
632where
633 N: Node + std::fmt::Debug + Clone + Hash + Eq,
634 E: EdgeWeight + Clone + std::ops::Add<Output = E> + num_traits::Zero + PartialOrd + Copy,
635 Ix: petgraph::graph::IndexType,
636 H: Fn(&N) -> E,
637{
638 if !graph.contains_node(source) || !graph.contains_node(target) {
639 return Err(GraphError::node_not_found("node"));
640 }
641
642 let mut open_set = BinaryHeap::new();
643 let mut g_score: HashMap<N, E> = HashMap::new();
644 let mut came_from: HashMap<N, N> = HashMap::new();
645
646 g_score.insert(source.clone(), E::zero());
647
648 open_set.push(AStarState {
649 node: source.clone(),
650 cost: E::zero(),
651 heuristic: heuristic(source),
652 path: vec![source.clone()],
653 });
654
655 while let Some(current_state) = open_set.pop() {
656 let current = ¤t_state.node;
657
658 if current == target {
659 return Ok(AStarResult {
660 path: current_state.path,
661 cost: current_state.cost,
662 });
663 }
664
665 let current_g = g_score.get(current).cloned().unwrap_or_else(E::zero);
666
667 if let Ok(successors) = graph.successors(current) {
668 for neighbor in successors {
669 if let Ok(edge_weight) = graph.edge_weight(current, &neighbor) {
670 let tentative_g = current_g + edge_weight;
671
672 let current_neighbor_g = g_score.get(&neighbor);
673 if current_neighbor_g.is_none() || tentative_g < *current_neighbor_g.unwrap() {
674 came_from.insert(neighbor.clone(), current.clone());
675 g_score.insert(neighbor.clone(), tentative_g);
676
677 let mut new_path = current_state.path.clone();
678 new_path.push(neighbor.clone());
679
680 open_set.push(AStarState {
681 node: neighbor.clone(),
682 cost: tentative_g,
683 heuristic: heuristic(&neighbor),
684 path: new_path,
685 });
686 }
687 }
688 }
689 }
690 }
691
692 Err(GraphError::NoPath {
693 src_node: format!("{source:?}"),
694 target: format!("{target:?}"),
695 nodes: 0,
696 edges: 0,
697 })
698}
699
700#[allow(dead_code)]
705pub fn k_shortest_paths<N, E, Ix>(
706 graph: &Graph<N, E, Ix>,
707 source: &N,
708 target: &N,
709 k: usize,
710) -> Result<Vec<(f64, Vec<N>)>>
711where
712 N: Node + std::fmt::Debug + Clone + Hash + Eq + Ord,
713 E: EdgeWeight
714 + Into<f64>
715 + Clone
716 + num_traits::Zero
717 + num_traits::One
718 + std::ops::Add<Output = E>
719 + PartialOrd
720 + std::marker::Copy
721 + std::fmt::Debug
722 + std::default::Default,
723 Ix: petgraph::graph::IndexType,
724{
725 if k == 0 {
726 return Ok(vec![]);
727 }
728
729 if !graph.contains_node(source) || !graph.contains_node(target) {
731 return Err(GraphError::node_not_found("node"));
732 }
733
734 let mut paths = Vec::new();
735 let mut candidates = std::collections::BinaryHeap::new();
736
737 match dijkstra_path(graph, source, target) {
739 Ok(Some(path)) => {
740 let weight: f64 = path.total_weight.into();
741 paths.push((weight, path.nodes));
742 }
743 Ok(None) => return Ok(vec![]), Err(e) => return Err(e),
745 }
746
747 for i in 0..k - 1 {
749 if i >= paths.len() {
750 break;
751 }
752
753 let (_, prev_path) = &paths[i];
754
755 for j in 0..prev_path.len() - 1 {
757 let spur_node = &prev_path[j];
758 let root_path = &prev_path[..=j];
759
760 let mut removed_edges = Vec::new();
762
763 for (_, path) in &paths {
765 if path.len() > j && &path[..=j] == root_path && j + 1 < path.len() {
766 removed_edges.push((path[j].clone(), path[j + 1].clone()));
767 }
768 }
769
770 if let Ok((spur_weight, spur_path)) =
772 shortest_path_avoiding_edges(graph, spur_node, target, &removed_edges, root_path)
773 {
774 let mut total_weight = spur_weight;
776 for idx in 0..j {
777 if let Ok(edge_weight) = graph.edge_weight(&prev_path[idx], &prev_path[idx + 1])
778 {
779 let weight: f64 = edge_weight.into();
780 total_weight += weight;
781 }
782 }
783
784 let mut complete_path = root_path[..j].to_vec();
786 complete_path.extend(spur_path);
787
788 candidates.push((
790 std::cmp::Reverse(ordered_float::OrderedFloat(total_weight)),
791 complete_path.clone(),
792 ));
793 }
794 }
795 }
796
797 while paths.len() < k && !candidates.is_empty() {
799 let (std::cmp::Reverse(ordered_float::OrderedFloat(weight)), path) =
800 candidates.pop().unwrap();
801
802 let is_duplicate = paths.iter().any(|(_, p)| p == &path);
804 if !is_duplicate {
805 paths.push((weight, path));
806 }
807 }
808
809 Ok(paths)
810}
811
812#[allow(dead_code)]
814fn shortest_path_avoiding_edges<N, E, Ix>(
815 graph: &Graph<N, E, Ix>,
816 source: &N,
817 target: &N,
818 avoided_edges: &[(N, N)],
819 excluded_nodes: &[N],
820) -> Result<(f64, Vec<N>)>
821where
822 N: Node + std::fmt::Debug + Clone + Hash + Eq + Ord,
823 E: EdgeWeight + Into<f64>,
824 Ix: petgraph::graph::IndexType,
825{
826 use std::cmp::Reverse;
827
828 let mut distances: HashMap<N, f64> = HashMap::new();
829 let mut previous: HashMap<N, N> = HashMap::new();
830 let mut heap = BinaryHeap::new();
831
832 distances.insert(source.clone(), 0.0);
833 heap.push((Reverse(ordered_float::OrderedFloat(0.0)), source.clone()));
834
835 while let Some((Reverse(ordered_float::OrderedFloat(dist)), node)) = heap.pop() {
836 if &node == target {
837 let mut path = vec![target.clone()];
839 let mut current = target.clone();
840
841 while let Some(prev) = previous.get(¤t) {
842 path.push(prev.clone());
843 current = prev.clone();
844 }
845
846 path.reverse();
847 return Ok((dist, path));
848 }
849
850 if distances.get(&node).is_none_or(|&d| dist > d) {
851 continue;
852 }
853
854 if let Ok(neighbors) = graph.neighbors(&node) {
855 for neighbor in neighbors {
856 if avoided_edges.contains(&(node.clone(), neighbor.clone())) {
858 continue;
859 }
860
861 if &neighbor != source && &neighbor != target && excluded_nodes.contains(&neighbor)
863 {
864 continue;
865 }
866
867 if let Ok(edge_weight) = graph.edge_weight(&node, &neighbor) {
868 let weight: f64 = edge_weight.into();
869 let new_dist = dist + weight;
870
871 if new_dist < *distances.get(&neighbor).unwrap_or(&f64::INFINITY) {
872 distances.insert(neighbor.clone(), new_dist);
873 previous.insert(neighbor.clone(), node.clone());
874 heap.push((Reverse(ordered_float::OrderedFloat(new_dist)), neighbor));
875 }
876 }
877 }
878 }
879 }
880
881 Err(GraphError::NoPath {
882 src_node: format!("{source:?}"),
883 target: format!("{target:?}"),
884 nodes: 0,
885 edges: 0,
886 })
887}
888
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[allow(deprecated)]
    fn test_shortest_path() {
        let mut graph: Graph<i32, f64> = Graph::new();
        let edges = [(1, 2, 4.0), (1, 3, 2.0), (2, 3, 1.0), (2, 4, 5.0), (3, 4, 8.0)];
        for (from, to, weight) in edges {
            graph.add_edge(from, to, weight).unwrap();
        }

        let path = shortest_path(&graph, &1, &4).unwrap().unwrap();

        // 1 -> 3 -> 2 -> 4 costs 2 + 1 + 5 = 8, beating 1 -> 2 -> 4 at 9.
        assert_eq!(path.total_weight, 8.0);
        assert_eq!(path.nodes, vec![1, 3, 2, 4]);
    }

    #[test]
    fn test_floyd_warshall() {
        let mut graph: Graph<i32, f64> = Graph::new();
        for (from, to, weight) in [(0, 1, 1.0), (1, 2, 2.0), (2, 0, 3.0)] {
            graph.add_edge(from, to, weight).unwrap();
        }

        let distances = floyd_warshall(&graph).unwrap();

        let expected: [(usize, usize, f64); 4] =
            [(0, 0, 0.0), (0, 1, 1.0), (0, 2, 3.0), (1, 0, 1.0)];
        for (i, j, want) in expected {
            assert_eq!(distances[[i, j]], want);
        }
    }

    #[test]
    fn test_astar_search() {
        let mut graph: Graph<(i32, i32), f64> = Graph::new();
        let square = [
            ((0, 0), (0, 1)),
            ((0, 1), (1, 1)),
            ((1, 1), (1, 0)),
            ((1, 0), (0, 0)),
        ];
        for (from, to) in square {
            graph.add_edge(from, to, 1.0).unwrap();
        }

        // Manhattan distance to the goal (1, 1).
        let heuristic = |&(x, y): &(i32, i32)| -> f64 { ((1 - x).abs() + (1 - y).abs()) as f64 };

        let result = astar_search(&graph, &(0, 0), &(1, 1), heuristic).unwrap();
        assert_eq!(result.cost, 2.0);
        assert_eq!(result.path.len(), 3);
    }

    #[test]
    fn test_k_shortest_paths() {
        let mut graph: Graph<char, f64> = Graph::new();
        let edges = [
            ('A', 'B', 2.0),
            ('B', 'D', 2.0),
            ('A', 'C', 1.0),
            ('C', 'D', 4.0),
            ('B', 'C', 1.0),
        ];
        for (from, to, weight) in edges {
            graph.add_edge(from, to, weight).unwrap();
        }

        let paths = k_shortest_paths(&graph, &'A', &'D', 3).unwrap();

        assert!(paths.len() >= 2);
        assert_eq!(paths[0].0, 4.0);
        assert_eq!(paths[0].1, vec!['A', 'B', 'D']);
    }
}