scirs2_sparse/csgraph/
traversal.rs

1//! Graph traversal algorithms for sparse graphs
2//!
3//! This module provides breadth-first search (BFS) and depth-first search (DFS)
4//! algorithms for sparse matrices representing graphs.
5
6use super::{num_vertices, to_adjacency_list, validate_graph};
7use crate::error::{SparseError, SparseResult};
8use crate::sparray::SparseArray;
9use scirs2_core::ndarray::Array1;
10use scirs2_core::numeric::Float;
11use std::collections::VecDeque;
12use std::fmt::Debug;
13
/// Order in which [`traversegraph`] visits the vertices of a graph.
///
/// Values can be parsed from user-facing names with [`TraversalOrder::from_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TraversalOrder {
    /// Breadth-first search: vertices are visited level by level from the start vertex.
    BreadthFirst,
    /// Depth-first search: each branch is followed as deep as possible before backtracking.
    DepthFirst,
}
22
23impl TraversalOrder {
24    #[allow(clippy::should_implement_trait)]
25    pub fn from_str(s: &str) -> SparseResult<Self> {
26        match s.to_lowercase().as_str() {
27            "breadth_first" | "bfs" | "breadth-first" => Ok(Self::BreadthFirst),
28            "depth_first" | "dfs" | "depth-first" => Ok(Self::DepthFirst),
29            _ => Err(SparseError::ValueError(format!(
30                "Unknown traversal order: {s}"
31            ))),
32        }
33    }
34}
35
36/// Perform graph traversal from a starting vertex
37///
38/// # Arguments
39///
40/// * `graph` - The graph as a sparse matrix
41/// * `start` - Starting vertex
42/// * `directed` - Whether the graph is directed
43/// * `order` - Traversal order (BFS or DFS)
44/// * `return_predecessors` - Whether to return predecessor information
45///
46/// # Returns
47///
48/// A tuple containing:
49/// - Traversal order as a vector of vertex indices
50/// - Optional predecessor array
51///
52/// # Examples
53///
54/// ```
55/// use scirs2_sparse::csgraph::traversegraph;
56/// use scirs2_sparse::csr_array::CsrArray;
57///
58/// // Create a simple graph
59/// let rows = vec![0, 1, 1, 2];
60/// let cols = vec![1, 0, 2, 1];
61/// let data = vec![1.0, 1.0, 1.0, 1.0];
62/// let graph = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
63///
64/// // Perform BFS from vertex 0
65/// let (order, _) = traversegraph(&graph, 0, false, "bfs", false).unwrap();
66/// ```
67#[allow(dead_code)]
68pub fn traversegraph<T, S>(
69    graph: &S,
70    start: usize,
71    directed: bool,
72    order: &str,
73    return_predecessors: bool,
74) -> SparseResult<(Vec<usize>, Option<Array1<isize>>)>
75where
76    T: Float + Debug + Copy + 'static,
77    S: SparseArray<T>,
78{
79    validate_graph(graph, directed)?;
80    let n = num_vertices(graph);
81
82    if start >= n {
83        return Err(SparseError::ValueError(format!(
84            "Start vertex {start} out of bounds for graph with {n} vertices"
85        )));
86    }
87
88    let traversal_order = TraversalOrder::from_str(order)?;
89
90    match traversal_order {
91        TraversalOrder::BreadthFirst => {
92            breadth_first_search(graph, start, directed, return_predecessors)
93        }
94        TraversalOrder::DepthFirst => {
95            depth_first_search(graph, start, directed, return_predecessors)
96        }
97    }
98}
99
100/// Breadth-first search traversal
101#[allow(dead_code)]
102pub fn breadth_first_search<T, S>(
103    graph: &S,
104    start: usize,
105    directed: bool,
106    return_predecessors: bool,
107) -> SparseResult<(Vec<usize>, Option<Array1<isize>>)>
108where
109    T: Float + Debug + Copy + 'static,
110    S: SparseArray<T>,
111{
112    let n = num_vertices(graph);
113    let adj_list = to_adjacency_list(graph, directed)?;
114
115    let mut visited = vec![false; n];
116    let mut queue = VecDeque::new();
117    let mut traversal_order = Vec::new();
118    let mut predecessors = if return_predecessors {
119        Some(Array1::from_elem(n, -1isize))
120    } else {
121        None
122    };
123
124    // Start BFS from the given vertex
125    queue.push_back(start);
126    visited[start] = true;
127
128    while let Some(current) = queue.pop_front() {
129        traversal_order.push(current);
130
131        // Visit all unvisited neighbors
132        for &(neighbor, _) in &adj_list[current] {
133            if !visited[neighbor] {
134                visited[neighbor] = true;
135                queue.push_back(neighbor);
136
137                if let Some(ref mut preds) = predecessors {
138                    preds[neighbor] = current as isize;
139                }
140            }
141        }
142    }
143
144    Ok((traversal_order, predecessors))
145}
146
147/// Depth-first search traversal
148#[allow(dead_code)]
149pub fn depth_first_search<T, S>(
150    graph: &S,
151    start: usize,
152    directed: bool,
153    return_predecessors: bool,
154) -> SparseResult<(Vec<usize>, Option<Array1<isize>>)>
155where
156    T: Float + Debug + Copy + 'static,
157    S: SparseArray<T>,
158{
159    let n = num_vertices(graph);
160    let adj_list = to_adjacency_list(graph, directed)?;
161
162    let mut visited = vec![false; n];
163    let mut stack = Vec::new();
164    let mut traversal_order = Vec::new();
165    let mut predecessors = if return_predecessors {
166        Some(Array1::from_elem(n, -1isize))
167    } else {
168        None
169    };
170
171    // Start DFS from the given vertex
172    stack.push(start);
173
174    while let Some(current) = stack.pop() {
175        if visited[current] {
176            continue;
177        }
178
179        visited[current] = true;
180        traversal_order.push(current);
181
182        // Add all unvisited neighbor_s to the stack (in reverse order for consistent ordering)
183        let mut neighbor_s: Vec<_> = adj_list[current]
184            .iter()
185            .filter(|&(neighbor_, _)| !visited[*neighbor_])
186            .collect();
187        neighbor_s.reverse(); // Reverse to maintain left-to-right order when popping
188
189        for &(neighbor_, _) in neighbor_s {
190            if !visited[neighbor_] {
191                stack.push(neighbor_);
192
193                if let Some(ref mut preds) = predecessors {
194                    if preds[neighbor_] == -1 {
195                        preds[neighbor_] = current as isize;
196                    }
197                }
198            }
199        }
200    }
201
202    Ok((traversal_order, predecessors))
203}
204
205/// Recursive depth-first search traversal
206#[allow(dead_code)]
207pub fn depth_first_search_recursive<T, S>(
208    graph: &S,
209    start: usize,
210    directed: bool,
211    return_predecessors: bool,
212) -> SparseResult<(Vec<usize>, Option<Array1<isize>>)>
213where
214    T: Float + Debug + Copy + 'static,
215    S: SparseArray<T>,
216{
217    let n = num_vertices(graph);
218    let adj_list = to_adjacency_list(graph, directed)?;
219
220    let mut visited = vec![false; n];
221    let mut traversal_order = Vec::new();
222    let mut predecessors = if return_predecessors {
223        Some(Array1::from_elem(n, -1isize))
224    } else {
225        None
226    };
227
228    dfs_recursive_helper::<T>(
229        start,
230        &adj_list,
231        &mut visited,
232        &mut traversal_order,
233        &mut predecessors,
234    );
235
236    Ok((traversal_order, predecessors))
237}
238
239/// Helper function for recursive DFS
240#[allow(dead_code)]
241fn dfs_recursive_helper<T>(
242    node: usize,
243    adj_list: &[Vec<(usize, T)>],
244    visited: &mut [bool],
245    traversal_order: &mut Vec<usize>,
246    predecessors: &mut Option<Array1<isize>>,
247) where
248    T: Float + Debug + Copy + 'static,
249{
250    visited[node] = true;
251    traversal_order.push(node);
252
253    for &(neighbor_, _) in &adj_list[node] {
254        if !visited[neighbor_] {
255            if let Some(ref mut preds) = predecessors {
256                preds[neighbor_] = node as isize;
257            }
258            dfs_recursive_helper(neighbor_, adj_list, visited, traversal_order, predecessors);
259        }
260    }
261}
262
263/// Compute distances from a source vertex using BFS
264///
265/// # Arguments
266///
267/// * `graph` - The graph as a sparse matrix (unweighted)
268/// * `start` - Starting vertex
269/// * `directed` - Whether the graph is directed
270///
271/// # Returns
272///
273/// Array of distances from the start vertex to all other vertices
274///
275/// # Examples
276///
277/// ```
278/// use scirs2_sparse::csgraph::bfs_distances;
279/// use scirs2_sparse::csr_array::CsrArray;
280///
281/// let rows = vec![0, 1, 1, 2];
282/// let cols = vec![1, 0, 2, 1];
283/// let data = vec![1.0, 1.0, 1.0, 1.0];
284/// let graph = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
285///
286/// let distances = bfs_distances(&graph, 0, false).unwrap();
287/// ```
288#[allow(dead_code)]
289pub fn bfs_distances<T, S>(graph: &S, start: usize, directed: bool) -> SparseResult<Array1<isize>>
290where
291    T: Float + Debug + Copy + 'static,
292    S: SparseArray<T>,
293{
294    let n = num_vertices(graph);
295    let adj_list = to_adjacency_list(graph, directed)?;
296
297    if start >= n {
298        return Err(SparseError::ValueError(format!(
299            "Start vertex {start} out of bounds for graph with {n} vertices"
300        )));
301    }
302
303    let mut distances = Array1::from_elem(n, -1isize);
304    let mut queue = VecDeque::new();
305
306    // Start BFS
307    distances[start] = 0;
308    queue.push_back(start);
309
310    while let Some(current) = queue.pop_front() {
311        let current_distance = distances[current];
312
313        for &(neighbor_, _) in &adj_list[current] {
314            if distances[neighbor_] == -1 {
315                distances[neighbor_] = current_distance + 1;
316                queue.push_back(neighbor_);
317            }
318        }
319    }
320
321    Ok(distances)
322}
323
324/// Check if there is a path between two vertices
325///
326/// # Arguments
327///
328/// * `graph` - The graph as a sparse matrix
329/// * `source` - Source vertex
330/// * `target` - Target vertex
331/// * `directed` - Whether the graph is directed
332///
333/// # Returns
334///
335/// True if there is a path from source to target, false otherwise
336///
337/// # Examples
338///
339/// ```
340/// use scirs2_sparse::csgraph::has_path;
341/// use scirs2_sparse::csr_array::CsrArray;
342///
343/// let rows = vec![0, 1, 1, 2];
344/// let cols = vec![1, 0, 2, 1];
345/// let data = vec![1.0, 1.0, 1.0, 1.0];
346/// let graph = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
347///
348/// assert!(has_path(&graph, 0, 2, false).unwrap());
349/// ```
350#[allow(dead_code)]
351pub fn has_path<T, S>(graph: &S, source: usize, target: usize, directed: bool) -> SparseResult<bool>
352where
353    T: Float + Debug + Copy + 'static,
354    S: SparseArray<T>,
355{
356    let n = num_vertices(graph);
357
358    if source >= n || target >= n {
359        return Err(SparseError::ValueError(format!(
360            "Vertex index out of bounds for graph with {n} vertices"
361        )));
362    }
363
364    if source == target {
365        return Ok(true);
366    }
367
368    let (traversal_order, _) = breadth_first_search(graph, source, directed, false)?;
369    Ok(traversal_order.contains(&target))
370}
371
372/// Find all vertices reachable from a source vertex
373///
374/// # Arguments
375///
376/// * `graph` - The graph as a sparse matrix
377/// * `source` - Source vertex
378/// * `directed` - Whether the graph is directed
379///
380/// # Returns
381///
382/// Vector of all vertices reachable from the source
383///
384/// # Examples
385///
386/// ```
387/// use scirs2_sparse::csgraph::reachable_vertices;
388/// use scirs2_sparse::csr_array::CsrArray;
389///
390/// let rows = vec![0, 1, 1, 2];
391/// let cols = vec![1, 0, 2, 1];
392/// let data = vec![1.0, 1.0, 1.0, 1.0];
393/// let graph = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
394///
395/// let reachable = reachable_vertices(&graph, 0, false).unwrap();
396/// ```
397#[allow(dead_code)]
398pub fn reachable_vertices<T, S>(
399    graph: &S,
400    source: usize,
401    directed: bool,
402) -> SparseResult<Vec<usize>>
403where
404    T: Float + Debug + Copy + 'static,
405    S: SparseArray<T>,
406{
407    let (traversal_order, _) = breadth_first_search(graph, source, directed, false)?;
408    Ok(traversal_order)
409}
410
411/// Topological sort of a directed acyclic graph (DAG)
412///
413/// # Arguments
414///
415/// * `graph` - The directed graph as a sparse matrix
416///
417/// # Returns
418///
419/// Topologically sorted order of vertices, or error if the graph has cycles
420///
421/// # Examples
422///
423/// ```
424/// use scirs2_sparse::csgraph::topological_sort;
425/// use scirs2_sparse::csr_array::CsrArray;
426///
427/// // Create a DAG: 0 -> 1 -> 2
428/// let rows = vec![0, 1];
429/// let cols = vec![1, 2];
430/// let data = vec![1.0, 1.0];
431/// let graph = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
432///
433/// let topo_order = topological_sort(&graph).unwrap();
434/// ```
435#[allow(dead_code)]
436pub fn topological_sort<T, S>(graph: &S) -> SparseResult<Vec<usize>>
437where
438    T: Float + Debug + Copy + 'static,
439    S: SparseArray<T>,
440{
441    let n = num_vertices(graph);
442    let adj_list = to_adjacency_list(graph, true)?; // Must be directed
443
444    // Compute in-degrees
445    let mut in_degree = vec![0; n];
446    for adj in &adj_list {
447        for &(neighbor_, _) in adj {
448            in_degree[neighbor_] += 1;
449        }
450    }
451
452    // Initialize queue with vertices having in-degree 0
453    let mut queue = VecDeque::new();
454    for (vertex, &degree) in in_degree.iter().enumerate() {
455        if degree == 0 {
456            queue.push_back(vertex);
457        }
458    }
459
460    let mut topo_order = Vec::new();
461
462    while let Some(vertex) = queue.pop_front() {
463        topo_order.push(vertex);
464
465        // Remove this vertex and update in-degrees of its neighbor_s
466        for &(neighbor_, _) in &adj_list[vertex] {
467            in_degree[neighbor_] -= 1;
468            if in_degree[neighbor_] == 0 {
469                queue.push_back(neighbor_);
470            }
471        }
472    }
473
474    // Check if all vertices were processed (no cycles)
475    if topo_order.len() != n {
476        return Err(SparseError::ValueError(
477            "Graph contains cycles - topological sort not possible".to_string(),
478        ));
479    }
480
481    Ok(topo_order)
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    // Undirected 4-cycle, each edge stored in both directions:
    //   0 -- 1
    //   |    |
    //   2 -- 3
    fn create_testgraph() -> CsrArray<f64> {
        let rows = vec![0, 0, 1, 1, 2, 2, 3, 3];
        let cols = vec![1, 2, 0, 3, 0, 3, 1, 2];
        let data = vec![1.0; 8];
        CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).unwrap()
    }

    // DAG with edges 0 -> 1 -> 3 and 0 -> 2 -> 3.
    fn create_dag() -> CsrArray<f64> {
        let rows = vec![0, 0, 1, 2];
        let cols = vec![1, 2, 3, 3];
        let data = vec![1.0; 4];
        CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).unwrap()
    }

    // Asserts that `order` has exactly `n` entries and contains every vertex in `0..n`.
    fn assert_visits_every_vertex(order: &[usize], n: usize) {
        assert_eq!(order.len(), n);
        for vertex in 0..n {
            assert!(order.contains(&vertex), "vertex {vertex} missing from {order:?}");
        }
    }

    #[test]
    fn test_bfs() {
        let graph = create_testgraph();
        let (order, predecessors) = breadth_first_search(&graph, 0, false, true).unwrap();

        assert_visits_every_vertex(&order, 4);
        assert_eq!(order[0], 0, "traversal must begin at the start vertex");

        let preds = predecessors.expect("predecessors were requested");
        assert_eq!(preds[0], -1, "start vertex has no predecessor");
        assert_eq!(preds[1], 0);
        assert_eq!(preds[2], 0);
    }

    #[test]
    fn test_dfs() {
        let graph = create_testgraph();
        let (order, _) = depth_first_search(&graph, 0, false, false).unwrap();

        assert_visits_every_vertex(&order, 4);
        assert_eq!(order[0], 0, "traversal must begin at the start vertex");
    }

    #[test]
    fn test_dfs_recursive() {
        let graph = create_testgraph();
        let (order, _) = depth_first_search_recursive(&graph, 0, false, false).unwrap();

        assert_visits_every_vertex(&order, 4);
        assert_eq!(order[0], 0, "traversal must begin at the start vertex");
    }

    #[test]
    fn test_traversegraph_api() {
        let graph = create_testgraph();

        for name in ["bfs", "dfs"] {
            let (order, _) = traversegraph(&graph, 0, false, name, false).unwrap();
            assert_eq!(order[0], 0, "{name} must start at vertex 0");
            assert_eq!(order.len(), 4, "{name} must reach every vertex");
        }
    }

    #[test]
    fn test_bfs_distances() {
        let graph = create_testgraph();
        let distances = bfs_distances(&graph, 0, false).unwrap();

        // 0 is the source, 1 and 2 are direct neighbours, 3 is two hops away via 1 or 2.
        let expected: [(usize, isize); 4] = [(0, 0), (1, 1), (2, 1), (3, 2)];
        for (vertex, distance) in expected {
            assert_eq!(distances[vertex], distance, "distance to vertex {vertex}");
        }
    }

    #[test]
    fn test_has_path() {
        let graph = create_testgraph();

        // Every pair in the 4-cycle is connected, and a vertex always reaches itself.
        assert!(has_path(&graph, 0, 3, false).unwrap());
        assert!(has_path(&graph, 1, 2, false).unwrap());
        assert!(has_path(&graph, 0, 0, false).unwrap());

        // Two separate components: {0, 1} and {2, 3}.
        let rows = vec![0, 2];
        let cols = vec![1, 3];
        let data = vec![1.0, 1.0];
        let disconnected = CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).unwrap();

        assert!(has_path(&disconnected, 0, 1, false).unwrap());
        assert!(!has_path(&disconnected, 0, 2, false).unwrap());
    }

    #[test]
    fn test_reachable_vertices() {
        let graph = create_testgraph();
        let reachable = reachable_vertices(&graph, 0, false).unwrap();

        assert_visits_every_vertex(&reachable, 4);
    }

    #[test]
    fn test_topological_sort() {
        let dag = create_dag();
        let topo_order = topological_sort(&dag).unwrap();

        assert_eq!(topo_order.len(), 4);

        let position = |vertex: usize| {
            topo_order
                .iter()
                .position(|&x| x == vertex)
                .unwrap_or_else(|| panic!("vertex {vertex} missing from {topo_order:?}"))
        };

        // Every edge u -> v must place u before v.
        let edges: [(usize, usize); 4] = [(0, 1), (0, 2), (1, 3), (2, 3)];
        for (from, to) in edges {
            assert!(position(from) < position(to), "edge {from} -> {to} violated");
        }
    }

    #[test]
    fn test_topological_sort_cycle() {
        // 0 -> 1 -> 2 -> 0 has no topological order.
        let rows = vec![0, 1, 2];
        let cols = vec![1, 2, 0];
        let data = vec![1.0; 3];
        let cyclic = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();

        assert!(topological_sort(&cyclic).is_err());
    }

    #[test]
    fn test_invalid_start_vertex() {
        let graph = create_testgraph();

        // Vertex 10 does not exist in a 4-vertex graph.
        assert!(traversegraph(&graph, 10, false, "bfs", false).is_err());
        assert!(bfs_distances(&graph, 10, false).is_err());
    }
}