1use super::{num_vertices, to_adjacency_list, validate_graph, PriorityQueueNode};
7use crate::error::{SparseError, SparseResult};
8use crate::sparray::SparseArray;
9use ndarray::{Array1, Array2};
10use num_traits::Float;
11use std::collections::BinaryHeap;
12use std::fmt::Debug;
13
/// Algorithm used by the shortest-path routines in this module.
///
/// Usually obtained from a user-supplied name via [`ShortestPathMethod::from_str`].
/// Derives `Eq` and `Hash` (in addition to `PartialEq`) so the selector can be used as a key
/// in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShortestPathMethod {
    /// Dijkstra's algorithm (binary-heap based); expects non-negative edge weights.
    Dijkstra,
    /// Bellman-Ford; handles negative edge weights and reports negative cycles.
    BellmanFord,
    /// Floyd-Warshall; an all-pairs method over a dense distance matrix.
    FloydWarshall,
    /// Choose a method automatically based on the graph (size and/or weight signs).
    Auto,
}
26
27impl ShortestPathMethod {
28 #[allow(clippy::should_implement_trait)]
29 pub fn from_str(s: &str) -> SparseResult<Self> {
30 match s.to_lowercase().as_str() {
31 "dijkstra" | "dij" => Ok(Self::Dijkstra),
32 "bellman-ford" | "bellman_ford" | "bf" => Ok(Self::BellmanFord),
33 "floyd-warshall" | "floyd_warshall" | "fw" => Ok(Self::FloydWarshall),
34 "auto" => Ok(Self::Auto),
35 _ => Err(SparseError::ValueError(format!(
36 "Unknown shortest path method: {s}"
37 ))),
38 }
39 }
40}
41
42#[allow(dead_code)]
77#[allow(clippy::too_many_arguments)]
78pub fn shortest_path<T, S>(
79 graph: &S,
80 from_vertex: Option<usize>,
81 to_vertex: Option<usize>,
82 method: &str,
83 directed: bool,
84 returnpredecessors: bool,
85) -> SparseResult<(Array2<T>, Option<Array2<isize>>)>
86where
87 T: Float + Debug + Copy + 'static,
88 S: SparseArray<T>,
89{
90 validate_graph(graph, directed)?;
91 let method = ShortestPathMethod::from_str(method)?;
92 let n = num_vertices(graph);
93
94 match (from_vertex, to_vertex) {
95 (None, None) => {
96 all_pairs_shortest_path(graph, method, directed, returnpredecessors)
98 }
99 (Some(source), None) => {
100 let (distances_, predecessors) =
102 single_source_shortest_path(graph, source, method, directed, returnpredecessors)?;
103
104 let mut dist_matrix = Array2::from_elem((n, n), T::infinity());
106 let mut pred_matrix = if returnpredecessors {
107 Some(Array2::from_elem((n, n), -1isize))
108 } else {
109 None
110 };
111
112 for i in 0..n {
113 dist_matrix[[source, i]] = distances_[i];
114 if let Some(ref preds) = predecessors {
115 if let Some(ref mut pred_mat) = pred_matrix {
116 pred_mat[[source, i]] = preds[i];
117 }
118 }
119 }
120
121 Ok((dist_matrix, pred_matrix))
122 }
123 (Some(source), Some(target)) => {
124 let (distances_, predecessors) =
126 single_source_shortest_path(graph, source, method, directed, returnpredecessors)?;
127
128 let dist_matrix = Array2::from_elem((1, 1), distances_[target]);
129 let pred_matrix = if returnpredecessors {
130 let mut pred_mat = Array2::from_elem((1, 1), -1isize);
131 if let Some(ref preds) = predecessors {
132 pred_mat[[0, 0]] = preds[target];
133 }
134 Some(pred_mat)
135 } else {
136 None
137 };
138
139 Ok((dist_matrix, pred_matrix))
140 }
141 (None, Some(_)) => Err(SparseError::ValueError(
142 "Cannot specify target _vertex without source _vertex".to_string(),
143 )),
144 }
145}
146
147#[allow(dead_code)]
149pub fn single_source_shortest_path<T, S>(
150 graph: &S,
151 source: usize,
152 method: ShortestPathMethod,
153 directed: bool,
154 returnpredecessors: bool,
155) -> SparseResult<(Array1<T>, Option<Array1<isize>>)>
156where
157 T: Float + Debug + Copy + 'static,
158 S: SparseArray<T>,
159{
160 let n = num_vertices(graph);
161
162 if source >= n {
163 return Err(SparseError::ValueError(format!(
164 "Source vertex {source} out of bounds for graph with {n} vertices"
165 )));
166 }
167
168 let actual_method = match method {
169 ShortestPathMethod::Auto => {
170 let (_, _, values) = graph.find();
172 if values.iter().any(|&w| w < T::zero()) {
173 ShortestPathMethod::BellmanFord
174 } else {
175 ShortestPathMethod::Dijkstra
176 }
177 }
178 m => m,
179 };
180
181 match actual_method {
182 ShortestPathMethod::Dijkstra => {
183 dijkstra_single_source(graph, source, directed, returnpredecessors)
184 }
185 ShortestPathMethod::BellmanFord => {
186 bellman_ford_single_source(graph, source, directed, returnpredecessors)
187 }
188 _ => Err(SparseError::ValueError(
189 "Method not supported for single source shortest paths".to_string(),
190 )),
191 }
192}
193
194#[allow(dead_code)]
196pub fn all_pairs_shortest_path<T, S>(
197 graph: &S,
198 method: ShortestPathMethod,
199 directed: bool,
200 returnpredecessors: bool,
201) -> SparseResult<(Array2<T>, Option<Array2<isize>>)>
202where
203 T: Float + Debug + Copy + 'static,
204 S: SparseArray<T>,
205{
206 let n = num_vertices(graph);
207
208 let actual_method = match method {
209 ShortestPathMethod::Auto => {
210 if n <= 100 {
211 ShortestPathMethod::FloydWarshall
212 } else {
213 ShortestPathMethod::Dijkstra
214 }
215 }
216 m => m,
217 };
218
219 match actual_method {
220 ShortestPathMethod::FloydWarshall => floyd_warshall(graph, directed, returnpredecessors),
221 ShortestPathMethod::Dijkstra => {
222 let mut distances = Array2::from_elem((n, n), T::infinity());
224 let mut predecessors = if returnpredecessors {
225 Some(Array2::from_elem((n, n), -1isize))
226 } else {
227 None
228 };
229
230 for source in 0..n {
231 let (dist, pred) =
232 dijkstra_single_source(graph, source, directed, returnpredecessors)?;
233
234 for target in 0..n {
235 distances[[source, target]] = dist[target];
236 if let Some(ref pred_vec) = pred {
237 if let Some(ref mut pred_matrix) = predecessors {
238 pred_matrix[[source, target]] = pred_vec[target];
239 }
240 }
241 }
242 }
243
244 Ok((distances, predecessors))
245 }
246 _ => Err(SparseError::ValueError(
247 "Method not supported for all pairs shortest paths".to_string(),
248 )),
249 }
250}
251
252#[allow(dead_code)]
254pub fn dijkstra_single_source<T, S>(
255 graph: &S,
256 source: usize,
257 directed: bool,
258 returnpredecessors: bool,
259) -> SparseResult<(Array1<T>, Option<Array1<isize>>)>
260where
261 T: Float + Debug + Copy + 'static,
262 S: SparseArray<T>,
263{
264 let n = num_vertices(graph);
265 let adj_list = to_adjacency_list(graph, directed)?;
266
267 let mut distances = Array1::from_elem(n, T::infinity());
268 let mut predecessors = if returnpredecessors {
269 Some(Array1::from_elem(n, -1isize))
270 } else {
271 None
272 };
273
274 distances[source] = T::zero();
275
276 let mut heap = BinaryHeap::new();
277 heap.push(PriorityQueueNode {
278 distance: T::zero(),
279 node: source,
280 });
281
282 let mut visited = vec![false; n];
283
284 while let Some(PriorityQueueNode { distance, node }) = heap.pop() {
285 if visited[node] {
286 continue;
287 }
288
289 visited[node] = true;
290
291 if distance == T::infinity() {
293 break;
294 }
295
296 for &(neighbor, weight) in &adj_list[node] {
297 if visited[neighbor] {
298 continue;
299 }
300
301 let new_distance = distance + weight;
302
303 if new_distance < distances[neighbor] {
304 distances[neighbor] = new_distance;
305
306 if let Some(ref mut preds) = predecessors {
307 preds[neighbor] = node as isize;
308 }
309
310 heap.push(PriorityQueueNode {
311 distance: new_distance,
312 node: neighbor,
313 });
314 }
315 }
316 }
317
318 Ok((distances, predecessors))
319}
320
321#[allow(dead_code)]
323pub fn bellman_ford_single_source<T, S>(
324 graph: &S,
325 source: usize,
326 directed: bool,
327 returnpredecessors: bool,
328) -> SparseResult<(Array1<T>, Option<Array1<isize>>)>
329where
330 T: Float + Debug + Copy + 'static,
331 S: SparseArray<T>,
332{
333 let n = num_vertices(graph);
334 let (row_indices, col_indices, values) = graph.find();
335
336 let mut distances = Array1::from_elem(n, T::infinity());
337 let mut predecessors = if returnpredecessors {
338 Some(Array1::from_elem(n, -1isize))
339 } else {
340 None
341 };
342
343 distances[source] = T::zero();
344
345 let mut edges = Vec::new();
347 for (i, (&row, &col)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
348 let weight = values[i];
349 if !weight.is_zero() {
350 edges.push((row, col, weight));
351
352 if !directed && row != col {
354 edges.push((col, row, weight));
355 }
356 }
357 }
358
359 for _ in 0..(n - 1) {
361 let mut updated = false;
362
363 for &(u, v, weight) in &edges {
364 if distances[u] != T::infinity() {
365 let new_distance = distances[u] + weight;
366
367 if new_distance < distances[v] {
368 distances[v] = new_distance;
369
370 if let Some(ref mut preds) = predecessors {
371 preds[v] = u as isize;
372 }
373
374 updated = true;
375 }
376 }
377 }
378
379 if !updated {
381 break;
382 }
383 }
384
385 for &(u, v, weight) in &edges {
387 if distances[u] != T::infinity() && distances[u] + weight < distances[v] {
388 return Err(SparseError::ValueError(
389 "Graph contains negative cycles".to_string(),
390 ));
391 }
392 }
393
394 Ok((distances, predecessors))
395}
396
397#[allow(dead_code)]
399pub fn floyd_warshall<T, S>(
400 graph: &S,
401 directed: bool,
402 returnpredecessors: bool,
403) -> SparseResult<(Array2<T>, Option<Array2<isize>>)>
404where
405 T: Float + Debug + Copy + 'static,
406 S: SparseArray<T>,
407{
408 let n = num_vertices(graph);
409
410 let mut distances = Array2::from_elem((n, n), T::infinity());
412 let mut predecessors = if returnpredecessors {
413 Some(Array2::from_elem((n, n), -1isize))
414 } else {
415 None
416 };
417
418 for i in 0..n {
420 distances[[i, i]] = T::zero();
421 }
422
423 let (row_indices, col_indices, values) = graph.find();
425 for (i, (&row, &col)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
426 let weight = values[i];
427 if !weight.is_zero() {
428 distances[[row, col]] = weight;
429
430 if let Some(ref mut preds) = predecessors {
431 if row != col {
432 preds[[row, col]] = row as isize;
433 }
434 }
435
436 if !directed && row != col {
438 distances[[col, row]] = weight;
439
440 if let Some(ref mut preds) = predecessors {
441 preds[[col, row]] = col as isize;
442 }
443 }
444 }
445 }
446
447 for k in 0..n {
449 for i in 0..n {
450 for j in 0..n {
451 let through_k = distances[[i, k]] + distances[[k, j]];
452
453 if through_k < distances[[i, j]] {
454 distances[[i, j]] = through_k;
455
456 if let Some(ref mut preds) = predecessors {
457 preds[[i, j]] = preds[[k, j]];
458 }
459 }
460 }
461 }
462 }
463
464 Ok((distances, predecessors))
465}
466
467#[allow(dead_code)]
469pub fn reconstruct_path(
470 predecessors: &Array1<isize>,
471 source: usize,
472 target: usize,
473) -> SparseResult<Vec<usize>> {
474 let mut path = Vec::new();
475 let mut current = target;
476
477 if predecessors[target] == -1 && source != target {
479 return Ok(path); }
481
482 while current != source {
483 path.push(current);
484
485 let pred = predecessors[current];
486 if pred == -1 {
487 return Err(SparseError::ValueError(
488 "Invalid predecessor information".to_string(),
489 ));
490 }
491
492 current = pred as usize;
493 }
494
495 path.push(source);
496 path.reverse();
497
498 Ok(path)
499}
500
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;
    use approx::assert_relative_eq;

    /// Expected shortest distances from vertex 0 in the fixture graph
    /// (0 -> 2 -> 3 costs 3, beating 0 -> 1 -> 3 at 4).
    const FROM_ZERO: [f64; 4] = [0.0, 1.0, 2.0, 3.0];

    /// Symmetric 4-vertex fixture: undirected edges 0-1 (1.0), 0-2 (2.0), 1-3 (3.0) and
    /// 2-3 (1.0), each stored in both directions.
    fn create_test_graph() -> CsrArray<f64> {
        let rows = vec![0, 0, 1, 1, 2, 2, 3, 3];
        let cols = vec![1, 2, 0, 3, 0, 3, 1, 2];
        let data = vec![1.0, 2.0, 1.0, 3.0, 2.0, 1.0, 3.0, 1.0];

        CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).unwrap()
    }

    #[test]
    fn test_dijkstra_single_source() {
        let graph = create_test_graph();
        let (dist, _) = dijkstra_single_source(&graph, 0, false, false).unwrap();
        for (vertex, &want) in FROM_ZERO.iter().enumerate() {
            assert_relative_eq!(dist[vertex], want);
        }
    }

    #[test]
    fn test_dijkstra_withpredecessors() {
        let graph = create_test_graph();
        let (_, preds) = dijkstra_single_source(&graph, 0, false, true).unwrap();
        // 0 is the source, 1 and 2 are reached directly, 3 is reached through 2.
        assert_eq!(preds.unwrap().to_vec(), vec![-1, 0, 0, 2]);
    }

    #[test]
    fn test_bellman_ford() {
        let graph = create_test_graph();
        let (dist, _) = bellman_ford_single_source(&graph, 0, false, false).unwrap();
        for (vertex, &want) in FROM_ZERO.iter().enumerate() {
            assert_relative_eq!(dist[vertex], want);
        }
    }

    #[test]
    fn test_floyd_warshall() {
        let graph = create_test_graph();
        let (dist, _) = floyd_warshall(&graph, false, false).unwrap();
        // Row 1: 1 -> 2 goes through 0 (1 + 2 = 3); 1 -> 3 is the direct edge (3).
        let from_one = [1.0, 0.0, 3.0, 3.0];
        for (target, (&row0, &row1)) in FROM_ZERO.iter().zip(from_one.iter()).enumerate() {
            assert_relative_eq!(dist[[0, target]], row0);
            assert_relative_eq!(dist[[1, target]], row1);
        }
    }

    #[test]
    fn test_shortest_path_api() {
        let graph = create_test_graph();

        let (row_matrix, _) =
            shortest_path(&graph, Some(0), None, "dijkstra", false, false).unwrap();
        assert_relative_eq!(row_matrix[[0, 1]], FROM_ZERO[1]);
        assert_relative_eq!(row_matrix[[0, 3]], FROM_ZERO[3]);

        let (single, _) = shortest_path(&graph, Some(0), Some(3), "auto", false, false).unwrap();
        assert_relative_eq!(single[[0, 0]], FROM_ZERO[3]);
    }

    #[test]
    fn test_reconstruct_path() {
        let graph = create_test_graph();
        let (_, preds) = dijkstra_single_source(&graph, 0, false, true).unwrap();
        let preds = preds.expect("predecessors were requested");

        assert_eq!(reconstruct_path(&preds, 0, 3).unwrap(), vec![0, 2, 3]);
    }

    #[test]
    fn test_negative_cycle_detection() {
        // Directed cycle 0 -> 1 -> 2 -> 0 with total weight 1 + 1 - 3 = -1.
        let rows = vec![0, 1, 2];
        let cols = vec![1, 2, 0];
        let data = vec![1.0, 1.0, -3.0];
        let graph = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();

        let outcome = bellman_ford_single_source(&graph, 0, true, false);
        assert!(outcome.is_err(), "a reachable negative cycle must be reported");
    }
}