1use super::{num_vertices, to_adjacency_list, validate_graph};
7use crate::error::{SparseError, SparseResult};
8use crate::sparray::SparseArray;
9use scirs2_core::ndarray::Array1;
10use scirs2_core::numeric::Float;
11use std::collections::VecDeque;
12use std::fmt::Debug;
13
/// Strategy used to visit the vertices of a graph.
///
/// The enum is a plain, field-less tag, so it is cheap to copy, can be
/// compared for full equality and can be used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TraversalOrder {
    /// Breadth-first traversal: vertices are taken from a FIFO queue.
    BreadthFirst,
    /// Depth-first traversal: vertices are taken from a LIFO stack.
    DepthFirst,
}
22
23impl TraversalOrder {
24 #[allow(clippy::should_implement_trait)]
25 pub fn from_str(s: &str) -> SparseResult<Self> {
26 match s.to_lowercase().as_str() {
27 "breadth_first" | "bfs" | "breadth-first" => Ok(Self::BreadthFirst),
28 "depth_first" | "dfs" | "depth-first" => Ok(Self::DepthFirst),
29 _ => Err(SparseError::ValueError(format!(
30 "Unknown traversal order: {s}"
31 ))),
32 }
33 }
34}
35
36#[allow(dead_code)]
68pub fn traversegraph<T, S>(
69 graph: &S,
70 start: usize,
71 directed: bool,
72 order: &str,
73 return_predecessors: bool,
74) -> SparseResult<(Vec<usize>, Option<Array1<isize>>)>
75where
76 T: Float + Debug + Copy + 'static,
77 S: SparseArray<T>,
78{
79 validate_graph(graph, directed)?;
80 let n = num_vertices(graph);
81
82 if start >= n {
83 return Err(SparseError::ValueError(format!(
84 "Start vertex {start} out of bounds for graph with {n} vertices"
85 )));
86 }
87
88 let traversal_order = TraversalOrder::from_str(order)?;
89
90 match traversal_order {
91 TraversalOrder::BreadthFirst => {
92 breadth_first_search(graph, start, directed, return_predecessors)
93 }
94 TraversalOrder::DepthFirst => {
95 depth_first_search(graph, start, directed, return_predecessors)
96 }
97 }
98}
99
100#[allow(dead_code)]
102pub fn breadth_first_search<T, S>(
103 graph: &S,
104 start: usize,
105 directed: bool,
106 return_predecessors: bool,
107) -> SparseResult<(Vec<usize>, Option<Array1<isize>>)>
108where
109 T: Float + Debug + Copy + 'static,
110 S: SparseArray<T>,
111{
112 let n = num_vertices(graph);
113 let adj_list = to_adjacency_list(graph, directed)?;
114
115 let mut visited = vec![false; n];
116 let mut queue = VecDeque::new();
117 let mut traversal_order = Vec::new();
118 let mut predecessors = if return_predecessors {
119 Some(Array1::from_elem(n, -1isize))
120 } else {
121 None
122 };
123
124 queue.push_back(start);
126 visited[start] = true;
127
128 while let Some(current) = queue.pop_front() {
129 traversal_order.push(current);
130
131 for &(neighbor, _) in &adj_list[current] {
133 if !visited[neighbor] {
134 visited[neighbor] = true;
135 queue.push_back(neighbor);
136
137 if let Some(ref mut preds) = predecessors {
138 preds[neighbor] = current as isize;
139 }
140 }
141 }
142 }
143
144 Ok((traversal_order, predecessors))
145}
146
147#[allow(dead_code)]
149pub fn depth_first_search<T, S>(
150 graph: &S,
151 start: usize,
152 directed: bool,
153 return_predecessors: bool,
154) -> SparseResult<(Vec<usize>, Option<Array1<isize>>)>
155where
156 T: Float + Debug + Copy + 'static,
157 S: SparseArray<T>,
158{
159 let n = num_vertices(graph);
160 let adj_list = to_adjacency_list(graph, directed)?;
161
162 let mut visited = vec![false; n];
163 let mut stack = Vec::new();
164 let mut traversal_order = Vec::new();
165 let mut predecessors = if return_predecessors {
166 Some(Array1::from_elem(n, -1isize))
167 } else {
168 None
169 };
170
171 stack.push(start);
173
174 while let Some(current) = stack.pop() {
175 if visited[current] {
176 continue;
177 }
178
179 visited[current] = true;
180 traversal_order.push(current);
181
182 let mut neighbor_s: Vec<_> = adj_list[current]
184 .iter()
185 .filter(|&(neighbor_, _)| !visited[*neighbor_])
186 .collect();
187 neighbor_s.reverse(); for &(neighbor_, _) in neighbor_s {
190 if !visited[neighbor_] {
191 stack.push(neighbor_);
192
193 if let Some(ref mut preds) = predecessors {
194 if preds[neighbor_] == -1 {
195 preds[neighbor_] = current as isize;
196 }
197 }
198 }
199 }
200 }
201
202 Ok((traversal_order, predecessors))
203}
204
205#[allow(dead_code)]
207pub fn depth_first_search_recursive<T, S>(
208 graph: &S,
209 start: usize,
210 directed: bool,
211 return_predecessors: bool,
212) -> SparseResult<(Vec<usize>, Option<Array1<isize>>)>
213where
214 T: Float + Debug + Copy + 'static,
215 S: SparseArray<T>,
216{
217 let n = num_vertices(graph);
218 let adj_list = to_adjacency_list(graph, directed)?;
219
220 let mut visited = vec![false; n];
221 let mut traversal_order = Vec::new();
222 let mut predecessors = if return_predecessors {
223 Some(Array1::from_elem(n, -1isize))
224 } else {
225 None
226 };
227
228 dfs_recursive_helper::<T>(
229 start,
230 &adj_list,
231 &mut visited,
232 &mut traversal_order,
233 &mut predecessors,
234 );
235
236 Ok((traversal_order, predecessors))
237}
238
239#[allow(dead_code)]
241fn dfs_recursive_helper<T>(
242 node: usize,
243 adj_list: &[Vec<(usize, T)>],
244 visited: &mut [bool],
245 traversal_order: &mut Vec<usize>,
246 predecessors: &mut Option<Array1<isize>>,
247) where
248 T: Float + Debug + Copy + 'static,
249{
250 visited[node] = true;
251 traversal_order.push(node);
252
253 for &(neighbor_, _) in &adj_list[node] {
254 if !visited[neighbor_] {
255 if let Some(ref mut preds) = predecessors {
256 preds[neighbor_] = node as isize;
257 }
258 dfs_recursive_helper(neighbor_, adj_list, visited, traversal_order, predecessors);
259 }
260 }
261}
262
263#[allow(dead_code)]
289pub fn bfs_distances<T, S>(graph: &S, start: usize, directed: bool) -> SparseResult<Array1<isize>>
290where
291 T: Float + Debug + Copy + 'static,
292 S: SparseArray<T>,
293{
294 let n = num_vertices(graph);
295 let adj_list = to_adjacency_list(graph, directed)?;
296
297 if start >= n {
298 return Err(SparseError::ValueError(format!(
299 "Start vertex {start} out of bounds for graph with {n} vertices"
300 )));
301 }
302
303 let mut distances = Array1::from_elem(n, -1isize);
304 let mut queue = VecDeque::new();
305
306 distances[start] = 0;
308 queue.push_back(start);
309
310 while let Some(current) = queue.pop_front() {
311 let current_distance = distances[current];
312
313 for &(neighbor_, _) in &adj_list[current] {
314 if distances[neighbor_] == -1 {
315 distances[neighbor_] = current_distance + 1;
316 queue.push_back(neighbor_);
317 }
318 }
319 }
320
321 Ok(distances)
322}
323
324#[allow(dead_code)]
351pub fn has_path<T, S>(graph: &S, source: usize, target: usize, directed: bool) -> SparseResult<bool>
352where
353 T: Float + Debug + Copy + 'static,
354 S: SparseArray<T>,
355{
356 let n = num_vertices(graph);
357
358 if source >= n || target >= n {
359 return Err(SparseError::ValueError(format!(
360 "Vertex index out of bounds for graph with {n} vertices"
361 )));
362 }
363
364 if source == target {
365 return Ok(true);
366 }
367
368 let (traversal_order, _) = breadth_first_search(graph, source, directed, false)?;
369 Ok(traversal_order.contains(&target))
370}
371
372#[allow(dead_code)]
398pub fn reachable_vertices<T, S>(
399 graph: &S,
400 source: usize,
401 directed: bool,
402) -> SparseResult<Vec<usize>>
403where
404 T: Float + Debug + Copy + 'static,
405 S: SparseArray<T>,
406{
407 let (traversal_order, _) = breadth_first_search(graph, source, directed, false)?;
408 Ok(traversal_order)
409}
410
411#[allow(dead_code)]
436pub fn topological_sort<T, S>(graph: &S) -> SparseResult<Vec<usize>>
437where
438 T: Float + Debug + Copy + 'static,
439 S: SparseArray<T>,
440{
441 let n = num_vertices(graph);
442 let adj_list = to_adjacency_list(graph, true)?; let mut in_degree = vec![0; n];
446 for adj in &adj_list {
447 for &(neighbor_, _) in adj {
448 in_degree[neighbor_] += 1;
449 }
450 }
451
452 let mut queue = VecDeque::new();
454 for (vertex, °ree) in in_degree.iter().enumerate() {
455 if degree == 0 {
456 queue.push_back(vertex);
457 }
458 }
459
460 let mut topo_order = Vec::new();
461
462 while let Some(vertex) = queue.pop_front() {
463 topo_order.push(vertex);
464
465 for &(neighbor_, _) in &adj_list[vertex] {
467 in_degree[neighbor_] -= 1;
468 if in_degree[neighbor_] == 0 {
469 queue.push_back(neighbor_);
470 }
471 }
472 }
473
474 if topo_order.len() != n {
476 return Err(SparseError::ValueError(
477 "Graph contains cycles - topological sort not possible".to_string(),
478 ));
479 }
480
481 Ok(topo_order)
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    /// Builds a 4-vertex graph with the symmetric edges 0-1, 0-2, 1-3 and 2-3.
    fn create_testgraph() -> CsrArray<f64> {
        let rows = vec![0, 0, 1, 1, 2, 2, 3, 3];
        let cols = vec![1, 2, 0, 3, 0, 3, 1, 2];
        let data = vec![1.0; 8];

        CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).unwrap()
    }

    /// Builds a 4-vertex DAG with the edges 0->1, 0->2, 1->3 and 2->3.
    fn create_dag() -> CsrArray<f64> {
        let rows = vec![0, 0, 1, 2];
        let cols = vec![1, 2, 3, 3];
        let data = vec![1.0; 4];

        CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).unwrap()
    }

    /// Asserts that `order` has four entries and contains each vertex 0..4.
    fn assert_covers_all_vertices(order: &[usize]) {
        assert_eq!(order.len(), 4);
        for vertex in 0..4 {
            assert!(order.contains(&vertex));
        }
    }

    #[test]
    fn test_bfs() {
        let graph = create_testgraph();
        let (order, predecessors) = breadth_first_search(&graph, 0, false, true).unwrap();

        assert_covers_all_vertices(&order);
        assert_eq!(order[0], 0);

        // The start has no predecessor; its direct neighbors point back to it.
        let preds = predecessors.unwrap();
        assert_eq!(preds[0], -1);
        assert_eq!(preds[1], 0);
        assert_eq!(preds[2], 0);
    }

    #[test]
    fn test_dfs() {
        let graph = create_testgraph();
        let (order, _) = depth_first_search(&graph, 0, false, false).unwrap();

        assert_covers_all_vertices(&order);
        assert_eq!(order[0], 0);
    }

    #[test]
    fn test_dfs_recursive() {
        let graph = create_testgraph();
        let (order, _) = depth_first_search_recursive(&graph, 0, false, false).unwrap();

        assert_covers_all_vertices(&order);
        assert_eq!(order[0], 0);
    }

    #[test]
    fn test_traversegraph_api() {
        let graph = create_testgraph();

        for name in ["bfs", "dfs"] {
            let (order, _) = traversegraph(&graph, 0, false, name, false).unwrap();
            assert_eq!(order[0], 0);
            assert_eq!(order.len(), 4);
        }
    }

    #[test]
    fn test_bfs_distances() {
        let graph = create_testgraph();
        let distances = bfs_distances(&graph, 0, false).unwrap();

        let expected: [isize; 4] = [0, 1, 1, 2];
        for (vertex, &hops) in expected.iter().enumerate() {
            assert_eq!(distances[vertex], hops);
        }
    }

    #[test]
    fn test_has_path() {
        let graph = create_testgraph();

        assert!(has_path(&graph, 0, 3, false).unwrap());
        assert!(has_path(&graph, 1, 2, false).unwrap());
        assert!(has_path(&graph, 0, 0, false).unwrap());

        // Two separate components: {0, 1} and {2, 3}.
        let rows = vec![0, 2];
        let cols = vec![1, 3];
        let data = vec![1.0, 1.0];
        let disconnected = CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).unwrap();

        assert!(has_path(&disconnected, 0, 1, false).unwrap());
        assert!(!has_path(&disconnected, 0, 2, false).unwrap());
    }

    #[test]
    fn test_reachable_vertices() {
        let graph = create_testgraph();
        let reachable = reachable_vertices(&graph, 0, false).unwrap();

        assert_covers_all_vertices(&reachable);
    }

    #[test]
    fn test_topological_sort() {
        let dag = create_dag();
        let topo_order = topological_sort(&dag).unwrap();

        assert_eq!(topo_order.len(), 4);

        let position_of = |vertex: usize| topo_order.iter().position(|&x| x == vertex).unwrap();
        let (pos_0, pos_1, pos_2, pos_3) =
            (position_of(0), position_of(1), position_of(2), position_of(3));

        // Every edge must point forward in the ordering.
        assert!(pos_0 < pos_1);
        assert!(pos_0 < pos_2);
        assert!(pos_1 < pos_3);
        assert!(pos_2 < pos_3);
    }

    #[test]
    fn test_topological_sort_cycle() {
        // Directed cycle 0 -> 1 -> 2 -> 0.
        let rows = vec![0, 1, 2];
        let cols = vec![1, 2, 0];
        let data = vec![1.0, 1.0, 1.0];
        let cyclic = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();

        assert!(topological_sort(&cyclic).is_err());
    }

    #[test]
    fn test_invalid_start_vertex() {
        let graph = create_testgraph();

        assert!(traversegraph(&graph, 10, false, "bfs", false).is_err());
        assert!(bfs_distances(&graph, 10, false).is_err());
    }
}