1use petgraph::algo::dijkstra;
10use petgraph::visit::EdgeRef;
11use std::cmp::Ordering;
12use std::collections::{BinaryHeap, HashMap};
13use std::hash::Hash;
14
15use crate::base::{DiGraph, EdgeWeight, Graph, Node};
16use crate::error::{GraphError, Result};
17
18#[derive(Debug, Clone)]
20pub struct Path<N: Node + std::fmt::Debug, E: EdgeWeight> {
21 pub nodes: Vec<N>,
23 pub total_weight: E,
25}
26
27#[derive(Debug, Clone)]
29pub struct AStarResult<N: Node + std::fmt::Debug, E: EdgeWeight> {
30 pub path: Vec<N>,
32 pub cost: E,
34}
35
36#[derive(Clone)]
38struct AStarState<N: Node + std::fmt::Debug, E: EdgeWeight> {
39 node: N,
40 cost: E,
41 heuristic: E,
42 path: Vec<N>,
43}
44
45impl<N: Node + std::fmt::Debug, E: EdgeWeight> PartialEq for AStarState<N, E> {
46 fn eq(&self, other: &Self) -> bool {
47 self.node == other.node
48 }
49}
50
51impl<N: Node + std::fmt::Debug, E: EdgeWeight> Eq for AStarState<N, E> {}
52
53impl<N: Node + std::fmt::Debug, E: EdgeWeight + std::ops::Add<Output = E> + Copy + PartialOrd> Ord
54 for AStarState<N, E>
55{
56 fn cmp(&self, other: &Self) -> Ordering {
57 let self_total = self.cost + self.heuristic;
59 let other_total = other.cost + other.heuristic;
60 other_total
61 .partial_cmp(&self_total)
62 .unwrap_or(Ordering::Equal)
63 .then_with(|| {
64 other
65 .cost
66 .partial_cmp(&self.cost)
67 .unwrap_or(Ordering::Equal)
68 })
69 }
70}
71
72impl<N: Node + std::fmt::Debug, E: EdgeWeight + std::ops::Add<Output = E> + Copy + PartialOrd>
73 PartialOrd for AStarState<N, E>
74{
75 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
76 Some(self.cmp(other))
77 }
78}
79
80#[deprecated(
104 since = "0.1.0-beta.2",
105 note = "Use `dijkstra_path` for future compatibility. This function will return PathResult in v1.0"
106)]
107#[allow(dead_code)]
108pub fn shortest_path<N, E, Ix>(
109 graph: &Graph<N, E, Ix>,
110 source: &N,
111 target: &N,
112) -> Result<Option<Path<N, E>>>
113where
114 N: Node + std::fmt::Debug,
115 E: EdgeWeight
116 + scirs2_core::numeric::Zero
117 + scirs2_core::numeric::One
118 + std::ops::Add<Output = E>
119 + PartialOrd
120 + std::marker::Copy
121 + std::fmt::Debug
122 + std::default::Default,
123 Ix: petgraph::graph::IndexType,
124{
125 if !graph.has_node(source) {
127 return Err(GraphError::InvalidGraph(format!(
128 "Source node {source:?} not found"
129 )));
130 }
131 if !graph.has_node(target) {
132 return Err(GraphError::InvalidGraph(format!(
133 "Target node {target:?} not found"
134 )));
135 }
136
137 let source_idx = graph
138 .inner()
139 .node_indices()
140 .find(|&idx| graph.inner()[idx] == *source)
141 .unwrap();
142 let target_idx = graph
143 .inner()
144 .node_indices()
145 .find(|&idx| graph.inner()[idx] == *target)
146 .unwrap();
147
148 let results = dijkstra(graph.inner(), source_idx, Some(target_idx), |e| *e.weight());
150
151 if !results.contains_key(&target_idx) {
153 return Ok(None);
154 }
155
156 let total_weight = results[&target_idx];
157
158 let mut path = Vec::new();
160 let mut current = target_idx;
161
162 path.push(graph.inner()[current].clone());
163
164 while current != source_idx {
166 let min_prev = graph
167 .inner()
168 .edges_directed(current, petgraph::Direction::Incoming)
169 .filter_map(|e| {
170 let from = e.source();
171 let edge_weight = *e.weight();
172
173 if let Some(from_dist) = results.get(&from) {
175 if *from_dist + edge_weight == results[¤t] {
177 return Some((from, *from_dist));
178 }
179 }
180 None
181 })
182 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
183
184 if let Some((prev, _)) = min_prev {
185 current = prev;
186 path.push(graph.inner()[current].clone());
187 } else {
188 return Err(GraphError::AlgorithmError(
190 "Failed to reconstruct path".to_string(),
191 ));
192 }
193 }
194
195 path.reverse();
197
198 Ok(Some(Path {
199 nodes: path,
200 total_weight,
201 }))
202}
203
204#[allow(dead_code)]
239pub fn dijkstra_path<N, E, Ix>(
240 graph: &Graph<N, E, Ix>,
241 source: &N,
242 target: &N,
243) -> Result<Option<Path<N, E>>>
244where
245 N: Node + std::fmt::Debug,
246 E: EdgeWeight
247 + scirs2_core::numeric::Zero
248 + scirs2_core::numeric::One
249 + std::ops::Add<Output = E>
250 + PartialOrd
251 + std::marker::Copy
252 + std::fmt::Debug
253 + std::default::Default,
254 Ix: petgraph::graph::IndexType,
255{
256 #[allow(deprecated)]
257 shortest_path(graph, source, target)
258}
259
260#[deprecated(
267 since = "0.1.0-beta.2",
268 note = "Use `dijkstra_path_digraph` for future compatibility. This function will return PathResult in v1.0"
269)]
270#[allow(dead_code)]
271pub fn shortest_path_digraph<N, E, Ix>(
272 graph: &DiGraph<N, E, Ix>,
273 source: &N,
274 target: &N,
275) -> Result<Option<Path<N, E>>>
276where
277 N: Node + std::fmt::Debug,
278 E: EdgeWeight
279 + scirs2_core::numeric::Zero
280 + scirs2_core::numeric::One
281 + std::ops::Add<Output = E>
282 + PartialOrd
283 + std::marker::Copy
284 + std::fmt::Debug
285 + std::default::Default,
286 Ix: petgraph::graph::IndexType,
287{
288 if !graph.has_node(source) {
291 return Err(GraphError::InvalidGraph(format!(
292 "Source node {source:?} not found"
293 )));
294 }
295 if !graph.has_node(target) {
296 return Err(GraphError::InvalidGraph(format!(
297 "Target node {target:?} not found"
298 )));
299 }
300
301 let source_idx = graph
302 .inner()
303 .node_indices()
304 .find(|&idx| graph.inner()[idx] == *source)
305 .unwrap();
306 let target_idx = graph
307 .inner()
308 .node_indices()
309 .find(|&idx| graph.inner()[idx] == *target)
310 .unwrap();
311
312 let results = dijkstra(graph.inner(), source_idx, Some(target_idx), |e| *e.weight());
314
315 if !results.contains_key(&target_idx) {
317 return Ok(None);
318 }
319
320 let total_weight = results[&target_idx];
321
322 let mut path = Vec::new();
324 let mut current = target_idx;
325
326 path.push(graph.inner()[current].clone());
327
328 while current != source_idx {
330 let min_prev = graph
331 .inner()
332 .edges_directed(current, petgraph::Direction::Incoming)
333 .filter_map(|e| {
334 let from = e.source();
335 let edge_weight = *e.weight();
336
337 if let Some(from_dist) = results.get(&from) {
339 if *from_dist + edge_weight == results[¤t] {
341 return Some((from, *from_dist));
342 }
343 }
344 None
345 })
346 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
347
348 if let Some((prev, _)) = min_prev {
349 current = prev;
350 path.push(graph.inner()[current].clone());
351 } else {
352 return Err(GraphError::AlgorithmError(
354 "Failed to reconstruct path".to_string(),
355 ));
356 }
357 }
358
359 path.reverse();
361
362 Ok(Some(Path {
363 nodes: path,
364 total_weight,
365 }))
366}
367
368#[allow(dead_code)]
389pub fn dijkstra_path_digraph<N, E, Ix>(
390 graph: &DiGraph<N, E, Ix>,
391 source: &N,
392 target: &N,
393) -> Result<Option<Path<N, E>>>
394where
395 N: Node + std::fmt::Debug,
396 E: EdgeWeight
397 + scirs2_core::numeric::Zero
398 + scirs2_core::numeric::One
399 + std::ops::Add<Output = E>
400 + PartialOrd
401 + std::marker::Copy
402 + std::fmt::Debug
403 + std::default::Default,
404 Ix: petgraph::graph::IndexType,
405{
406 #[allow(deprecated)]
407 shortest_path_digraph(graph, source, target)
408}
409
410#[allow(dead_code)]
433pub fn floyd_warshall<N, E, Ix>(
434 graph: &Graph<N, E, Ix>,
435) -> Result<scirs2_core::ndarray::Array2<f64>>
436where
437 N: Node + std::fmt::Debug,
438 E: EdgeWeight + Into<f64> + scirs2_core::numeric::Zero + Copy,
439 Ix: petgraph::graph::IndexType,
440{
441 let n = graph.node_count();
442
443 if n == 0 {
444 return Ok(scirs2_core::ndarray::Array2::zeros((0, 0)));
445 }
446
447 let mut dist = scirs2_core::ndarray::Array2::from_elem((n, n), f64::INFINITY);
449
450 for i in 0..n {
452 dist[[i, i]] = 0.0;
453 }
454
455 for edge in graph.inner().edge_references() {
457 let i = edge.source().index();
458 let j = edge.target().index();
459 let weight: f64 = (*edge.weight()).into();
460
461 dist[[i, j]] = weight;
462 dist[[j, i]] = weight;
464 }
465
466 for k in 0..n {
468 for i in 0..n {
469 for j in 0..n {
470 let alt = dist[[i, k]] + dist[[k, j]];
471 if alt < dist[[i, j]] {
472 dist[[i, j]] = alt;
473 }
474 }
475 }
476 }
477
478 Ok(dist)
479}
480
481#[allow(dead_code)]
483pub fn floyd_warshall_digraph<N, E, Ix>(
484 graph: &DiGraph<N, E, Ix>,
485) -> Result<scirs2_core::ndarray::Array2<f64>>
486where
487 N: Node + std::fmt::Debug,
488 E: EdgeWeight + Into<f64> + scirs2_core::numeric::Zero + Copy,
489 Ix: petgraph::graph::IndexType,
490{
491 let n = graph.node_count();
492
493 if n == 0 {
494 return Ok(scirs2_core::ndarray::Array2::zeros((0, 0)));
495 }
496
497 let mut dist = scirs2_core::ndarray::Array2::from_elem((n, n), f64::INFINITY);
499
500 for i in 0..n {
502 dist[[i, i]] = 0.0;
503 }
504
505 for edge in graph.inner().edge_references() {
507 let i = edge.source().index();
508 let j = edge.target().index();
509 let weight: f64 = (*edge.weight()).into();
510
511 dist[[i, j]] = weight;
512 }
513
514 for k in 0..n {
516 for i in 0..n {
517 for j in 0..n {
518 let alt = dist[[i, k]] + dist[[k, j]];
519 if alt < dist[[i, j]] {
520 dist[[i, j]] = alt;
521 }
522 }
523 }
524 }
525
526 Ok(dist)
527}
528
529#[allow(dead_code)]
554pub fn astar_search<N, E, Ix, H>(
555 graph: &Graph<N, E, Ix>,
556 source: &N,
557 target: &N,
558 heuristic: H,
559) -> Result<AStarResult<N, E>>
560where
561 N: Node + std::fmt::Debug + Clone + Hash + Eq,
562 E: EdgeWeight
563 + Clone
564 + std::ops::Add<Output = E>
565 + scirs2_core::numeric::Zero
566 + PartialOrd
567 + Copy,
568 Ix: petgraph::graph::IndexType,
569 H: Fn(&N) -> E,
570{
571 if !graph.contains_node(source) || !graph.contains_node(target) {
572 return Err(GraphError::node_not_found("node"));
573 }
574
575 let mut open_set = BinaryHeap::new();
576 let mut g_score: HashMap<N, E> = HashMap::new();
577 let mut came_from: HashMap<N, N> = HashMap::new();
578
579 g_score.insert(source.clone(), E::zero());
580
581 open_set.push(AStarState {
582 node: source.clone(),
583 cost: E::zero(),
584 heuristic: heuristic(source),
585 path: vec![source.clone()],
586 });
587
588 while let Some(current_state) = open_set.pop() {
589 let current = ¤t_state.node;
590
591 if current == target {
592 return Ok(AStarResult {
593 path: current_state.path,
594 cost: current_state.cost,
595 });
596 }
597
598 let current_g = g_score.get(current).cloned().unwrap_or_else(E::zero);
599
600 if let Ok(neighbors) = graph.neighbors(current) {
601 for neighbor in neighbors {
602 if let Ok(edge_weight) = graph.edge_weight(current, &neighbor) {
603 let tentative_g = current_g + edge_weight;
604
605 let current_neighbor_g = g_score.get(&neighbor);
606 if current_neighbor_g.is_none() || tentative_g < *current_neighbor_g.unwrap() {
607 came_from.insert(neighbor.clone(), current.clone());
608 g_score.insert(neighbor.clone(), tentative_g);
609
610 let mut new_path = current_state.path.clone();
611 new_path.push(neighbor.clone());
612
613 open_set.push(AStarState {
614 node: neighbor.clone(),
615 cost: tentative_g,
616 heuristic: heuristic(&neighbor),
617 path: new_path,
618 });
619 }
620 }
621 }
622 }
623 }
624
625 Err(GraphError::NoPath {
626 src_node: format!("{source:?}"),
627 target: format!("{target:?}"),
628 nodes: 0,
629 edges: 0,
630 })
631}
632
633#[allow(dead_code)]
635pub fn astar_search_digraph<N, E, Ix, H>(
636 graph: &DiGraph<N, E, Ix>,
637 source: &N,
638 target: &N,
639 heuristic: H,
640) -> Result<AStarResult<N, E>>
641where
642 N: Node + std::fmt::Debug + Clone + Hash + Eq,
643 E: EdgeWeight
644 + Clone
645 + std::ops::Add<Output = E>
646 + scirs2_core::numeric::Zero
647 + PartialOrd
648 + Copy,
649 Ix: petgraph::graph::IndexType,
650 H: Fn(&N) -> E,
651{
652 if !graph.contains_node(source) || !graph.contains_node(target) {
653 return Err(GraphError::node_not_found("node"));
654 }
655
656 let mut open_set = BinaryHeap::new();
657 let mut g_score: HashMap<N, E> = HashMap::new();
658 let mut came_from: HashMap<N, N> = HashMap::new();
659
660 g_score.insert(source.clone(), E::zero());
661
662 open_set.push(AStarState {
663 node: source.clone(),
664 cost: E::zero(),
665 heuristic: heuristic(source),
666 path: vec![source.clone()],
667 });
668
669 while let Some(current_state) = open_set.pop() {
670 let current = ¤t_state.node;
671
672 if current == target {
673 return Ok(AStarResult {
674 path: current_state.path,
675 cost: current_state.cost,
676 });
677 }
678
679 let current_g = g_score.get(current).cloned().unwrap_or_else(E::zero);
680
681 if let Ok(successors) = graph.successors(current) {
682 for neighbor in successors {
683 if let Ok(edge_weight) = graph.edge_weight(current, &neighbor) {
684 let tentative_g = current_g + edge_weight;
685
686 let current_neighbor_g = g_score.get(&neighbor);
687 if current_neighbor_g.is_none() || tentative_g < *current_neighbor_g.unwrap() {
688 came_from.insert(neighbor.clone(), current.clone());
689 g_score.insert(neighbor.clone(), tentative_g);
690
691 let mut new_path = current_state.path.clone();
692 new_path.push(neighbor.clone());
693
694 open_set.push(AStarState {
695 node: neighbor.clone(),
696 cost: tentative_g,
697 heuristic: heuristic(&neighbor),
698 path: new_path,
699 });
700 }
701 }
702 }
703 }
704 }
705
706 Err(GraphError::NoPath {
707 src_node: format!("{source:?}"),
708 target: format!("{target:?}"),
709 nodes: 0,
710 edges: 0,
711 })
712}
713
714#[allow(dead_code)]
719pub fn k_shortest_paths<N, E, Ix>(
720 graph: &Graph<N, E, Ix>,
721 source: &N,
722 target: &N,
723 k: usize,
724) -> Result<Vec<(f64, Vec<N>)>>
725where
726 N: Node + std::fmt::Debug + Clone + Hash + Eq + Ord,
727 E: EdgeWeight
728 + Into<f64>
729 + Clone
730 + scirs2_core::numeric::Zero
731 + scirs2_core::numeric::One
732 + std::ops::Add<Output = E>
733 + PartialOrd
734 + std::marker::Copy
735 + std::fmt::Debug
736 + std::default::Default,
737 Ix: petgraph::graph::IndexType,
738{
739 if k == 0 {
740 return Ok(vec![]);
741 }
742
743 if !graph.contains_node(source) || !graph.contains_node(target) {
745 return Err(GraphError::node_not_found("node"));
746 }
747
748 let mut paths = Vec::new();
749 let mut candidates = std::collections::BinaryHeap::new();
750
751 match dijkstra_path(graph, source, target) {
753 Ok(Some(path)) => {
754 let weight: f64 = path.total_weight.into();
755 paths.push((weight, path.nodes));
756 }
757 Ok(None) => return Ok(vec![]), Err(e) => return Err(e),
759 }
760
761 for i in 0..k - 1 {
763 if i >= paths.len() {
764 break;
765 }
766
767 let (_, prev_path) = &paths[i];
768
769 for j in 0..prev_path.len() - 1 {
771 let spur_node = &prev_path[j];
772 let root_path = &prev_path[..=j];
773
774 let mut removed_edges = Vec::new();
776
777 for (_, path) in &paths {
779 if path.len() > j && &path[..=j] == root_path && j + 1 < path.len() {
780 removed_edges.push((path[j].clone(), path[j + 1].clone()));
781 }
782 }
783
784 if let Ok((spur_weight, spur_path)) =
786 shortest_path_avoiding_edges(graph, spur_node, target, &removed_edges, root_path)
787 {
788 let mut total_weight = spur_weight;
790 for idx in 0..j {
791 if let Ok(edge_weight) = graph.edge_weight(&prev_path[idx], &prev_path[idx + 1])
792 {
793 let weight: f64 = edge_weight.into();
794 total_weight += weight;
795 }
796 }
797
798 let mut complete_path = root_path[..j].to_vec();
800 complete_path.extend(spur_path);
801
802 candidates.push((
804 std::cmp::Reverse(ordered_float::OrderedFloat(total_weight)),
805 complete_path.clone(),
806 ));
807 }
808 }
809 }
810
811 while paths.len() < k && !candidates.is_empty() {
813 let (std::cmp::Reverse(ordered_float::OrderedFloat(weight)), path) =
814 candidates.pop().unwrap();
815
816 let is_duplicate = paths.iter().any(|(_, p)| p == &path);
818 if !is_duplicate {
819 paths.push((weight, path));
820 }
821 }
822
823 Ok(paths)
824}
825
826#[allow(dead_code)]
828fn shortest_path_avoiding_edges<N, E, Ix>(
829 graph: &Graph<N, E, Ix>,
830 source: &N,
831 target: &N,
832 avoided_edges: &[(N, N)],
833 excluded_nodes: &[N],
834) -> Result<(f64, Vec<N>)>
835where
836 N: Node + std::fmt::Debug + Clone + Hash + Eq + Ord,
837 E: EdgeWeight + Into<f64>,
838 Ix: petgraph::graph::IndexType,
839{
840 use std::cmp::Reverse;
841
842 let mut distances: HashMap<N, f64> = HashMap::new();
843 let mut previous: HashMap<N, N> = HashMap::new();
844 let mut heap = BinaryHeap::new();
845
846 distances.insert(source.clone(), 0.0);
847 heap.push((Reverse(ordered_float::OrderedFloat(0.0)), source.clone()));
848
849 while let Some((Reverse(ordered_float::OrderedFloat(dist)), node)) = heap.pop() {
850 if &node == target {
851 let mut path = vec![target.clone()];
853 let mut current = target.clone();
854
855 while let Some(prev) = previous.get(¤t) {
856 path.push(prev.clone());
857 current = prev.clone();
858 }
859
860 path.reverse();
861 return Ok((dist, path));
862 }
863
864 if distances.get(&node).is_none_or(|&d| dist > d) {
865 continue;
866 }
867
868 if let Ok(neighbors) = graph.neighbors(&node) {
869 for neighbor in neighbors {
870 if avoided_edges.contains(&(node.clone(), neighbor.clone())) {
872 continue;
873 }
874
875 if &neighbor != source && &neighbor != target && excluded_nodes.contains(&neighbor)
877 {
878 continue;
879 }
880
881 if let Ok(edge_weight) = graph.edge_weight(&node, &neighbor) {
882 let weight: f64 = edge_weight.into();
883 let new_dist = dist + weight;
884
885 if new_dist < *distances.get(&neighbor).unwrap_or(&f64::INFINITY) {
886 distances.insert(neighbor.clone(), new_dist);
887 previous.insert(neighbor.clone(), node.clone());
888 heap.push((Reverse(ordered_float::OrderedFloat(new_dist)), neighbor));
889 }
890 }
891 }
892 }
893 }
894
895 Err(GraphError::NoPath {
896 src_node: format!("{source:?}"),
897 target: format!("{target:?}"),
898 nodes: 0,
899 edges: 0,
900 })
901}
902
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[allow(deprecated)]
    fn test_shortest_path() {
        let mut g: Graph<i32, f64> = Graph::new();
        for (u, v, w) in [
            (1, 2, 4.0),
            (1, 3, 2.0),
            (2, 3, 1.0),
            (2, 4, 5.0),
            (3, 4, 8.0),
        ] {
            g.add_edge(u, v, w).unwrap();
        }

        // 1-3-2-4 costs 2 + 1 + 5 = 8, beating 1-2-4 (9) and 1-3-4 (10).
        let found = shortest_path(&g, &1, &4).unwrap().unwrap();
        assert_eq!(found.total_weight, 8.0);
        assert_eq!(found.nodes, vec![1, 3, 2, 4]);
    }

    #[test]
    fn test_floyd_warshall() {
        let mut g: Graph<i32, f64> = Graph::new();
        for (u, v, w) in [(0, 1, 1.0), (1, 2, 2.0), (2, 0, 3.0)] {
            g.add_edge(u, v, w).unwrap();
        }

        // Triangle: 0-2 is 3 either directly or via 1 (1 + 2); edges are symmetric.
        let d = floyd_warshall(&g).unwrap();
        assert_eq!(d[[0, 0]], 0.0);
        assert_eq!(d[[0, 1]], 1.0);
        assert_eq!(d[[0, 2]], 3.0);
        assert_eq!(d[[1, 0]], 1.0);
    }

    #[test]
    fn test_astar_search() {
        let mut g: Graph<(i32, i32), f64> = Graph::new();
        for (u, v, w) in [
            ((0, 0), (0, 1), 1.0),
            ((0, 1), (1, 1), 1.0),
            ((1, 1), (1, 0), 1.0),
            ((1, 0), (0, 0), 1.0),
        ] {
            g.add_edge(u, v, w).unwrap();
        }

        // Manhattan distance to the goal (1, 1).
        let manhattan_to_goal =
            |&(x, y): &(i32, i32)| -> f64 { ((1 - x).abs() + (1 - y).abs()) as f64 };

        // Opposite corners of a unit square: two steps, three nodes.
        let outcome = astar_search(&g, &(0, 0), &(1, 1), manhattan_to_goal).unwrap();
        assert_eq!(outcome.cost, 2.0);
        assert_eq!(outcome.path.len(), 3);
    }

    #[test]
    fn test_k_shortest_paths() {
        let mut g: Graph<char, f64> = Graph::new();
        for (u, v, w) in [
            ('A', 'B', 2.0),
            ('B', 'D', 2.0),
            ('A', 'C', 1.0),
            ('C', 'D', 4.0),
            ('B', 'C', 1.0),
        ] {
            g.add_edge(u, v, w).unwrap();
        }

        let ranked = k_shortest_paths(&g, &'A', &'D', 3).unwrap();

        assert!(ranked.len() >= 2);
        // A-B-D (2 + 2) is the first shortest path reported.
        assert_eq!(ranked[0].0, 4.0);
        assert_eq!(ranked[0].1, vec!['A', 'B', 'D']);
    }
}