1use petgraph::algo::dijkstra;
10use petgraph::visit::EdgeRef;
11use std::cmp::Ordering;
12use std::collections::{BinaryHeap, HashMap};
13use std::hash::Hash;
14
15use crate::base::{DiGraph, EdgeWeight, Graph, Node};
16use crate::error::{GraphError, Result};
17
18#[derive(Debug, Clone)]
20pub struct Path<N: Node, E: EdgeWeight> {
21 pub nodes: Vec<N>,
23 pub total_weight: E,
25}
26
27#[derive(Debug, Clone)]
29pub struct AStarResult<N: Node, E: EdgeWeight> {
30 pub path: Vec<N>,
32 pub cost: E,
34}
35
36#[derive(Clone)]
38struct AStarState<N: Node, E: EdgeWeight> {
39 node: N,
40 cost: E,
41 heuristic: E,
42 path: Vec<N>,
43}
44
45impl<N: Node, E: EdgeWeight> PartialEq for AStarState<N, E> {
46 fn eq(&self, other: &Self) -> bool {
47 self.node == other.node
48 }
49}
50
51impl<N: Node, E: EdgeWeight> Eq for AStarState<N, E> {}
52
53impl<N: Node, E: EdgeWeight + std::ops::Add<Output = E> + Copy + PartialOrd> Ord
54 for AStarState<N, E>
55{
56 fn cmp(&self, other: &Self) -> Ordering {
57 let self_total = self.cost + self.heuristic;
59 let other_total = other.cost + other.heuristic;
60 other_total
61 .partial_cmp(&self_total)
62 .unwrap_or(Ordering::Equal)
63 .then_with(|| {
64 other
65 .cost
66 .partial_cmp(&self.cost)
67 .unwrap_or(Ordering::Equal)
68 })
69 }
70}
71
72impl<N: Node, E: EdgeWeight + std::ops::Add<Output = E> + Copy + PartialOrd> PartialOrd
73 for AStarState<N, E>
74{
75 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
76 Some(self.cmp(other))
77 }
78}
79
80pub fn shortest_path<N, E, Ix>(
92 graph: &Graph<N, E, Ix>,
93 source: &N,
94 target: &N,
95) -> Result<Option<Path<N, E>>>
96where
97 N: Node + std::fmt::Debug,
98 E: EdgeWeight
99 + num_traits::Zero
100 + num_traits::One
101 + std::ops::Add<Output = E>
102 + PartialOrd
103 + std::marker::Copy
104 + std::fmt::Debug
105 + std::default::Default,
106 Ix: petgraph::graph::IndexType,
107{
108 if !graph.has_node(source) {
110 return Err(GraphError::InvalidGraph(format!(
111 "Source node {:?} not found",
112 source
113 )));
114 }
115 if !graph.has_node(target) {
116 return Err(GraphError::InvalidGraph(format!(
117 "Target node {:?} not found",
118 target
119 )));
120 }
121
122 let source_idx = graph
123 .inner()
124 .node_indices()
125 .find(|&idx| graph.inner()[idx] == *source)
126 .unwrap();
127 let target_idx = graph
128 .inner()
129 .node_indices()
130 .find(|&idx| graph.inner()[idx] == *target)
131 .unwrap();
132
133 let results = dijkstra(graph.inner(), source_idx, Some(target_idx), |e| *e.weight());
135
136 if !results.contains_key(&target_idx) {
138 return Ok(None);
139 }
140
141 let total_weight = results[&target_idx];
142
143 let mut path = Vec::new();
145 let mut current = target_idx;
146
147 path.push(graph.inner()[current].clone());
148
149 while current != source_idx {
151 let min_prev = graph
152 .inner()
153 .edges_directed(current, petgraph::Direction::Incoming)
154 .filter_map(|e| {
155 let from = e.source();
156 let edge_weight = *e.weight();
157
158 if let Some(from_dist) = results.get(&from) {
160 if *from_dist + edge_weight == results[¤t] {
162 return Some((from, *from_dist));
163 }
164 }
165 None
166 })
167 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
168
169 if let Some((prev, _)) = min_prev {
170 current = prev;
171 path.push(graph.inner()[current].clone());
172 } else {
173 return Err(GraphError::AlgorithmError(
175 "Failed to reconstruct path".to_string(),
176 ));
177 }
178 }
179
180 path.reverse();
182
183 Ok(Some(Path {
184 nodes: path,
185 total_weight,
186 }))
187}
188
189pub fn shortest_path_digraph<N, E, Ix>(
191 graph: &DiGraph<N, E, Ix>,
192 source: &N,
193 target: &N,
194) -> Result<Option<Path<N, E>>>
195where
196 N: Node + std::fmt::Debug,
197 E: EdgeWeight
198 + num_traits::Zero
199 + num_traits::One
200 + std::ops::Add<Output = E>
201 + PartialOrd
202 + std::marker::Copy
203 + std::fmt::Debug
204 + std::default::Default,
205 Ix: petgraph::graph::IndexType,
206{
207 if !graph.has_node(source) {
210 return Err(GraphError::InvalidGraph(format!(
211 "Source node {:?} not found",
212 source
213 )));
214 }
215 if !graph.has_node(target) {
216 return Err(GraphError::InvalidGraph(format!(
217 "Target node {:?} not found",
218 target
219 )));
220 }
221
222 let source_idx = graph
223 .inner()
224 .node_indices()
225 .find(|&idx| graph.inner()[idx] == *source)
226 .unwrap();
227 let target_idx = graph
228 .inner()
229 .node_indices()
230 .find(|&idx| graph.inner()[idx] == *target)
231 .unwrap();
232
233 let results = dijkstra(graph.inner(), source_idx, Some(target_idx), |e| *e.weight());
235
236 if !results.contains_key(&target_idx) {
238 return Ok(None);
239 }
240
241 let total_weight = results[&target_idx];
242
243 let mut path = Vec::new();
245 let mut current = target_idx;
246
247 path.push(graph.inner()[current].clone());
248
249 while current != source_idx {
251 let min_prev = graph
252 .inner()
253 .edges_directed(current, petgraph::Direction::Incoming)
254 .filter_map(|e| {
255 let from = e.source();
256 let edge_weight = *e.weight();
257
258 if let Some(from_dist) = results.get(&from) {
260 if *from_dist + edge_weight == results[¤t] {
262 return Some((from, *from_dist));
263 }
264 }
265 None
266 })
267 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
268
269 if let Some((prev, _)) = min_prev {
270 current = prev;
271 path.push(graph.inner()[current].clone());
272 } else {
273 return Err(GraphError::AlgorithmError(
275 "Failed to reconstruct path".to_string(),
276 ));
277 }
278 }
279
280 path.reverse();
282
283 Ok(Some(Path {
284 nodes: path,
285 total_weight,
286 }))
287}
288
289pub fn floyd_warshall<N, E, Ix>(graph: &Graph<N, E, Ix>) -> Result<ndarray::Array2<f64>>
300where
301 N: Node,
302 E: EdgeWeight + Into<f64> + num_traits::Zero + Copy,
303 Ix: petgraph::graph::IndexType,
304{
305 let n = graph.node_count();
306
307 if n == 0 {
308 return Ok(ndarray::Array2::zeros((0, 0)));
309 }
310
311 let mut dist = ndarray::Array2::from_elem((n, n), f64::INFINITY);
313
314 for i in 0..n {
316 dist[[i, i]] = 0.0;
317 }
318
319 for edge in graph.inner().edge_references() {
321 let i = edge.source().index();
322 let j = edge.target().index();
323 let weight: f64 = (*edge.weight()).into();
324
325 dist[[i, j]] = weight;
326 dist[[j, i]] = weight;
328 }
329
330 for k in 0..n {
332 for i in 0..n {
333 for j in 0..n {
334 let alt = dist[[i, k]] + dist[[k, j]];
335 if alt < dist[[i, j]] {
336 dist[[i, j]] = alt;
337 }
338 }
339 }
340 }
341
342 Ok(dist)
343}
344
345pub fn floyd_warshall_digraph<N, E, Ix>(graph: &DiGraph<N, E, Ix>) -> Result<ndarray::Array2<f64>>
347where
348 N: Node,
349 E: EdgeWeight + Into<f64> + num_traits::Zero + Copy,
350 Ix: petgraph::graph::IndexType,
351{
352 let n = graph.node_count();
353
354 if n == 0 {
355 return Ok(ndarray::Array2::zeros((0, 0)));
356 }
357
358 let mut dist = ndarray::Array2::from_elem((n, n), f64::INFINITY);
360
361 for i in 0..n {
363 dist[[i, i]] = 0.0;
364 }
365
366 for edge in graph.inner().edge_references() {
368 let i = edge.source().index();
369 let j = edge.target().index();
370 let weight: f64 = (*edge.weight()).into();
371
372 dist[[i, j]] = weight;
373 }
374
375 for k in 0..n {
377 for i in 0..n {
378 for j in 0..n {
379 let alt = dist[[i, k]] + dist[[k, j]];
380 if alt < dist[[i, j]] {
381 dist[[i, j]] = alt;
382 }
383 }
384 }
385 }
386
387 Ok(dist)
388}
389
390pub fn astar_search<N, E, Ix, H>(
401 graph: &Graph<N, E, Ix>,
402 start: &N,
403 goal: &N,
404 heuristic: H,
405) -> Result<AStarResult<N, E>>
406where
407 N: Node + Clone + Hash + Eq,
408 E: EdgeWeight + Clone + std::ops::Add<Output = E> + num_traits::Zero + PartialOrd + Copy,
409 Ix: petgraph::graph::IndexType,
410 H: Fn(&N) -> E,
411{
412 if !graph.contains_node(start) || !graph.contains_node(goal) {
413 return Err(GraphError::NodeNotFound);
414 }
415
416 let mut open_set = BinaryHeap::new();
417 let mut g_score: HashMap<N, E> = HashMap::new();
418 let mut came_from: HashMap<N, N> = HashMap::new();
419
420 g_score.insert(start.clone(), E::zero());
421
422 open_set.push(AStarState {
423 node: start.clone(),
424 cost: E::zero(),
425 heuristic: heuristic(start),
426 path: vec![start.clone()],
427 });
428
429 while let Some(current_state) = open_set.pop() {
430 let current = ¤t_state.node;
431
432 if current == goal {
433 return Ok(AStarResult {
434 path: current_state.path,
435 cost: current_state.cost,
436 });
437 }
438
439 let current_g = g_score.get(current).cloned().unwrap_or_else(E::zero);
440
441 if let Ok(neighbors) = graph.neighbors(current) {
442 for neighbor in neighbors {
443 if let Ok(edge_weight) = graph.edge_weight(current, &neighbor) {
444 let tentative_g = current_g + edge_weight;
445
446 let current_neighbor_g = g_score.get(&neighbor);
447 if current_neighbor_g.is_none() || tentative_g < *current_neighbor_g.unwrap() {
448 came_from.insert(neighbor.clone(), current.clone());
449 g_score.insert(neighbor.clone(), tentative_g);
450
451 let mut new_path = current_state.path.clone();
452 new_path.push(neighbor.clone());
453
454 open_set.push(AStarState {
455 node: neighbor.clone(),
456 cost: tentative_g,
457 heuristic: heuristic(&neighbor),
458 path: new_path,
459 });
460 }
461 }
462 }
463 }
464 }
465
466 Err(GraphError::NoPath)
467}
468
469pub fn astar_search_digraph<N, E, Ix, H>(
471 graph: &DiGraph<N, E, Ix>,
472 start: &N,
473 goal: &N,
474 heuristic: H,
475) -> Result<AStarResult<N, E>>
476where
477 N: Node + Clone + Hash + Eq,
478 E: EdgeWeight + Clone + std::ops::Add<Output = E> + num_traits::Zero + PartialOrd + Copy,
479 Ix: petgraph::graph::IndexType,
480 H: Fn(&N) -> E,
481{
482 if !graph.contains_node(start) || !graph.contains_node(goal) {
483 return Err(GraphError::NodeNotFound);
484 }
485
486 let mut open_set = BinaryHeap::new();
487 let mut g_score: HashMap<N, E> = HashMap::new();
488 let mut came_from: HashMap<N, N> = HashMap::new();
489
490 g_score.insert(start.clone(), E::zero());
491
492 open_set.push(AStarState {
493 node: start.clone(),
494 cost: E::zero(),
495 heuristic: heuristic(start),
496 path: vec![start.clone()],
497 });
498
499 while let Some(current_state) = open_set.pop() {
500 let current = ¤t_state.node;
501
502 if current == goal {
503 return Ok(AStarResult {
504 path: current_state.path,
505 cost: current_state.cost,
506 });
507 }
508
509 let current_g = g_score.get(current).cloned().unwrap_or_else(E::zero);
510
511 if let Ok(successors) = graph.successors(current) {
512 for neighbor in successors {
513 if let Ok(edge_weight) = graph.edge_weight(current, &neighbor) {
514 let tentative_g = current_g + edge_weight;
515
516 let current_neighbor_g = g_score.get(&neighbor);
517 if current_neighbor_g.is_none() || tentative_g < *current_neighbor_g.unwrap() {
518 came_from.insert(neighbor.clone(), current.clone());
519 g_score.insert(neighbor.clone(), tentative_g);
520
521 let mut new_path = current_state.path.clone();
522 new_path.push(neighbor.clone());
523
524 open_set.push(AStarState {
525 node: neighbor.clone(),
526 cost: tentative_g,
527 heuristic: heuristic(&neighbor),
528 path: new_path,
529 });
530 }
531 }
532 }
533 }
534 }
535
536 Err(GraphError::NoPath)
537}
538
539pub fn k_shortest_paths<N, E, Ix>(
544 graph: &Graph<N, E, Ix>,
545 start: &N,
546 goal: &N,
547 k: usize,
548) -> Result<Vec<(f64, Vec<N>)>>
549where
550 N: Node + Clone + Hash + Eq + Ord + std::fmt::Debug,
551 E: EdgeWeight
552 + Into<f64>
553 + Clone
554 + num_traits::Zero
555 + num_traits::One
556 + std::ops::Add<Output = E>
557 + PartialOrd
558 + std::marker::Copy
559 + std::fmt::Debug
560 + std::default::Default,
561 Ix: petgraph::graph::IndexType,
562{
563 if k == 0 {
564 return Ok(vec![]);
565 }
566
567 if !graph.contains_node(start) || !graph.contains_node(goal) {
569 return Err(GraphError::NodeNotFound);
570 }
571
572 let mut paths = Vec::new();
573 let mut candidates = std::collections::BinaryHeap::new();
574
575 match shortest_path(graph, start, goal) {
577 Ok(Some(path)) => {
578 let weight: f64 = path.total_weight.into();
579 paths.push((weight, path.nodes));
580 }
581 Ok(None) => return Ok(vec![]), Err(e) => return Err(e),
583 }
584
585 for i in 0..k - 1 {
587 if i >= paths.len() {
588 break;
589 }
590
591 let (_, prev_path) = &paths[i];
592
593 for j in 0..prev_path.len() - 1 {
595 let spur_node = &prev_path[j];
596 let root_path = &prev_path[..=j];
597
598 let mut removed_edges = Vec::new();
600
601 for (_, path) in &paths {
603 if path.len() > j && &path[..=j] == root_path && j + 1 < path.len() {
604 removed_edges.push((path[j].clone(), path[j + 1].clone()));
605 }
606 }
607
608 if let Ok((spur_weight, spur_path)) =
610 shortest_path_avoiding_edges(graph, spur_node, goal, &removed_edges, root_path)
611 {
612 let mut total_weight = spur_weight;
614 for idx in 0..j {
615 if let Ok(edge_weight) = graph.edge_weight(&prev_path[idx], &prev_path[idx + 1])
616 {
617 let weight: f64 = edge_weight.into();
618 total_weight += weight;
619 }
620 }
621
622 let mut complete_path = root_path[..j].to_vec();
624 complete_path.extend(spur_path);
625
626 candidates.push((
628 std::cmp::Reverse(ordered_float::OrderedFloat(total_weight)),
629 complete_path.clone(),
630 ));
631 }
632 }
633 }
634
635 while paths.len() < k && !candidates.is_empty() {
637 let (std::cmp::Reverse(ordered_float::OrderedFloat(weight)), path) =
638 candidates.pop().unwrap();
639
640 let is_duplicate = paths.iter().any(|(_, p)| p == &path);
642 if !is_duplicate {
643 paths.push((weight, path));
644 }
645 }
646
647 Ok(paths)
648}
649
650fn shortest_path_avoiding_edges<N, E, Ix>(
652 graph: &Graph<N, E, Ix>,
653 start: &N,
654 goal: &N,
655 avoided_edges: &[(N, N)],
656 excluded_nodes: &[N],
657) -> Result<(f64, Vec<N>)>
658where
659 N: Node + Clone + Hash + Eq + Ord,
660 E: EdgeWeight + Into<f64>,
661 Ix: petgraph::graph::IndexType,
662{
663 use std::cmp::Reverse;
664
665 let mut distances: HashMap<N, f64> = HashMap::new();
666 let mut previous: HashMap<N, N> = HashMap::new();
667 let mut heap = BinaryHeap::new();
668
669 distances.insert(start.clone(), 0.0);
670 heap.push((Reverse(ordered_float::OrderedFloat(0.0)), start.clone()));
671
672 while let Some((Reverse(ordered_float::OrderedFloat(dist)), node)) = heap.pop() {
673 if &node == goal {
674 let mut path = vec![goal.clone()];
676 let mut current = goal.clone();
677
678 while let Some(prev) = previous.get(¤t) {
679 path.push(prev.clone());
680 current = prev.clone();
681 }
682
683 path.reverse();
684 return Ok((dist, path));
685 }
686
687 if distances.get(&node).is_none_or(|&d| dist > d) {
688 continue;
689 }
690
691 if let Ok(neighbors) = graph.neighbors(&node) {
692 for neighbor in neighbors {
693 if avoided_edges.contains(&(node.clone(), neighbor.clone())) {
695 continue;
696 }
697
698 if &neighbor != start && &neighbor != goal && excluded_nodes.contains(&neighbor) {
700 continue;
701 }
702
703 if let Ok(edge_weight) = graph.edge_weight(&node, &neighbor) {
704 let weight: f64 = edge_weight.into();
705 let new_dist = dist + weight;
706
707 if new_dist < *distances.get(&neighbor).unwrap_or(&f64::INFINITY) {
708 distances.insert(neighbor.clone(), new_dist);
709 previous.insert(neighbor.clone(), node.clone());
710 heap.push((Reverse(ordered_float::OrderedFloat(new_dist)), neighbor));
711 }
712 }
713 }
714 }
715 }
716
717 Err(GraphError::NoPath)
718}
719
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shortest_path() {
        let mut graph: Graph<i32, f64> = Graph::new();
        let edges = [(1, 2, 4.0), (1, 3, 2.0), (2, 3, 1.0), (2, 4, 5.0), (3, 4, 8.0)];
        for (from, to, weight) in edges {
            graph.add_edge(from, to, weight).unwrap();
        }

        let found = shortest_path(&graph, &1, &4).unwrap().unwrap();

        // 1-3-2-4 costs 2 + 1 + 5 = 8, cheaper than 1-2-4 (9) or 1-3-4 (10).
        assert_eq!(found.total_weight, 8.0);
        assert_eq!(found.nodes, vec![1, 3, 2, 4]);
    }

    #[test]
    fn test_floyd_warshall() {
        let mut graph: Graph<i32, f64> = Graph::new();
        for (from, to, weight) in [(0, 1, 1.0), (1, 2, 2.0), (2, 0, 3.0)] {
            graph.add_edge(from, to, weight).unwrap();
        }

        let matrix = floyd_warshall(&graph).unwrap();

        assert_eq!(matrix[[0, 0]], 0.0);
        assert_eq!(matrix[[0, 1]], 1.0);
        // The direct edge (3.0) ties with the detour 0-1-2 (1.0 + 2.0).
        assert_eq!(matrix[[0, 2]], 3.0);
        // Undirected edges give a symmetric matrix.
        assert_eq!(matrix[[1, 0]], 1.0);
    }

    #[test]
    fn test_astar_search() {
        let mut graph: Graph<(i32, i32), f64> = Graph::new();
        // Unit-weight square: (0,0) - (0,1) - (1,1) - (1,0) - back to (0,0).
        let corners = [(0, 0), (0, 1), (1, 1), (1, 0)];
        for (idx, &corner) in corners.iter().enumerate() {
            let next = corners[(idx + 1) % corners.len()];
            graph.add_edge(corner, next, 1.0).unwrap();
        }

        // Manhattan distance to the goal (1, 1).
        let manhattan = |&(x, y): &(i32, i32)| -> f64 { ((1 - x).abs() + (1 - y).abs()) as f64 };

        let outcome = astar_search(&graph, &(0, 0), &(1, 1), manhattan).unwrap();
        assert_eq!(outcome.cost, 2.0);
        assert_eq!(outcome.path.len(), 3);
    }

    #[test]
    fn test_k_shortest_paths() {
        let mut graph: Graph<char, f64> = Graph::new();
        let edges = [
            ('A', 'B', 2.0),
            ('B', 'D', 2.0),
            ('A', 'C', 1.0),
            ('C', 'D', 4.0),
            ('B', 'C', 1.0),
        ];
        for (from, to, weight) in edges {
            graph.add_edge(from, to, weight).unwrap();
        }

        let ranked = k_shortest_paths(&graph, &'A', &'D', 3).unwrap();

        assert!(ranked.len() >= 2);
        assert_eq!(ranked[0].0, 4.0);
        assert_eq!(ranked[0].1, vec!['A', 'B', 'D']);
    }
}