scirs2_sparse/csgraph/
mod.rs

1//! Compressed sparse graph algorithms module
2//!
3//! This module provides graph algorithms optimized for sparse matrices,
4//! similar to SciPy's `sparse.csgraph` module.
5//!
6//! ## Features
7//!
8//! * Shortest path algorithms (Dijkstra, Bellman-Ford)
9//! * Connected components analysis  
10//! * Graph traversal utilities (BFS, DFS)
11//! * Laplacian matrix computation
12//! * Minimum spanning tree algorithms
13//! * Graph connectivity testing
14//!
15//! ## Examples
16//!
17//! ### Shortest Path
18//!
19//! ```
20//! use scirs2_sparse::csgraph::shortest_path;
21//! use scirs2_sparse::csr_array::CsrArray;
22//!
23//! // Create a graph as a sparse matrix
24//! let rows = vec![0, 0, 1, 2];
25//! let cols = vec![1, 2, 2, 0];
26//! let data = vec![1.0, 4.0, 2.0, 3.0];
27//! let graph = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
28//!
29//! // Find shortest paths from vertex 0
30//! let distances = shortest_path(&graph, Some(0), None, "dijkstra", true, false).unwrap();
31//! ```
32//!
33//! ### Connected Components
34//!
35//! ```
36//! use scirs2_sparse::csgraph::connected_components;
37//! use scirs2_sparse::csr_array::CsrArray;
38//!
39//! // Create a graph
40//! let rows = vec![0, 1, 2, 3];
41//! let cols = vec![1, 0, 3, 2];
42//! let data = vec![1.0, 1.0, 1.0, 1.0];
43//! let graph = CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).unwrap();
44//!
45//! // Find connected components
46//! let (n_components, labels) = connected_components(&graph, false, "weak", true).unwrap();
47//! ```
48
49use crate::error::{SparseError, SparseResult};
50use crate::sparray::SparseArray;
51use scirs2_core::numeric::Float;
52use std::cmp::Ordering;
53use std::fmt::Debug;
54
55pub mod connected_components;
56pub mod laplacian;
57pub mod minimum_spanning_tree;
58pub mod shortest_path;
59pub mod traversal;
60
61pub use connected_components::*;
62pub use laplacian::*;
63pub use minimum_spanning_tree::*;
64pub use shortest_path::*;
65pub use traversal::*;
66
/// How a graph matrix should be interpreted.
///
/// `Eq` and `Hash` are derived (in addition to `PartialEq`) so the mode can be
/// used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GraphMode {
    /// Treat the matrix as a directed graph
    Directed,
    /// Treat the matrix as an undirected graph
    Undirected,
}
75
76/// Priority queue element for graph algorithms
77#[derive(Debug, Clone)]
78struct PriorityQueueNode<T>
79where
80    T: Float + PartialOrd,
81{
82    distance: T,
83    node: usize,
84}
85
86impl<T> PartialEq for PriorityQueueNode<T>
87where
88    T: Float + PartialOrd,
89{
90    fn eq(&self, other: &Self) -> bool {
91        self.distance == other.distance && self.node == other.node
92    }
93}
94
95impl<T> Eq for PriorityQueueNode<T> where T: Float + PartialOrd {}
96
97impl<T> PartialOrd for PriorityQueueNode<T>
98where
99    T: Float + PartialOrd,
100{
101    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
102        Some(self.cmp(other))
103    }
104}
105
106impl<T> Ord for PriorityQueueNode<T>
107where
108    T: Float + PartialOrd,
109{
110    fn cmp(&self, other: &Self) -> Ordering {
111        // Reverse the ordering for min-heap behavior
112        other
113            .distance
114            .partial_cmp(&self.distance)
115            .unwrap_or(Ordering::Equal)
116    }
117}
118
119/// Check if a sparse matrix represents a valid graph
120///
121/// # Arguments
122///
123/// * `matrix` - The sparse matrix to check
124/// * `directed` - Whether the graph is directed
125///
126/// # Returns
127///
128/// Result indicating if the matrix is a valid graph
129#[allow(dead_code)]
130pub fn validate_graph<T, S>(matrix: &S, directed: bool) -> SparseResult<()>
131where
132    T: Float + Debug + Copy + 'static,
133    S: SparseArray<T>,
134{
135    let (rows, cols) = matrix.shape();
136
137    // Graph matrices must be square
138    if rows != cols {
139        return Err(SparseError::ValueError(
140            "Graph _matrix must be square".to_string(),
141        ));
142    }
143
144    // Check for negative weights (not allowed in some algorithms)
145    let (row_indices, col_indices, values) = matrix.find();
146    for &value in values.iter() {
147        if value < T::zero() {
148            return Err(SparseError::ValueError(
149                "Negative edge weights not supported".to_string(),
150            ));
151        }
152    }
153
154    // For undirected graphs, check symmetry
155    if !directed {
156        for (i, (&row, &col)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
157            if row != col {
158                let weight = values[i];
159                let reverse_weight = matrix.get(col, row);
160
161                if (weight - reverse_weight).abs() > T::from(1e-10).unwrap() {
162                    return Err(SparseError::ValueError(
163                        "Undirected graph _matrix must be symmetric".to_string(),
164                    ));
165                }
166            }
167        }
168    }
169
170    Ok(())
171}
172
173/// Convert a sparse matrix to adjacency list representation
174///
175/// # Arguments
176///
177/// * `matrix` - The sparse matrix
178/// * `directed` - Whether the graph is directed
179///
180/// # Returns
181///
182/// Adjacency list as a vector of vectors of (neighbor, weight) pairs
183#[allow(dead_code)]
184pub fn to_adjacency_list<T, S>(matrix: &S, directed: bool) -> SparseResult<Vec<Vec<(usize, T)>>>
185where
186    T: Float + Debug + Copy + 'static,
187    S: SparseArray<T>,
188{
189    let (n_, _) = matrix.shape();
190    let mut adj_list = vec![Vec::new(); n_];
191
192    let (row_indices, col_indices, values) = matrix.find();
193
194    for (i, (&row, &col)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
195        let weight = values[i];
196
197        if !weight.is_zero() {
198            adj_list[row].push((col, weight));
199
200            // For undirected graphs, add the reverse edge only if it doesn't already exist
201            if !directed && row != col {
202                // Check if the reverse edge already exists in the _matrix
203                let reverse_exists = row_indices
204                    .iter()
205                    .zip(col_indices.iter())
206                    .any(|(r, c)| *r == col && *c == row);
207
208                if !reverse_exists {
209                    adj_list[col].push((row, weight));
210                }
211            }
212        }
213    }
214
215    Ok(adj_list)
216}
217
218/// Get the number of vertices in a graph matrix
219#[allow(dead_code)]
220pub fn num_vertices<T, S>(matrix: &S) -> usize
221where
222    T: Float + Debug + Copy + 'static,
223    S: SparseArray<T>,
224{
225    matrix.shape().0
226}
227
228/// Get the number of edges in a graph matrix
229#[allow(dead_code)]
230pub fn num_edges<T, S>(matrix: &S, directed: bool) -> SparseResult<usize>
231where
232    T: Float + Debug + Copy + 'static,
233    S: SparseArray<T>,
234{
235    let nnz = matrix.nnz();
236
237    if directed {
238        Ok(nnz)
239    } else {
240        // For undirected graphs, count diagonal elements once and off-diagonal elements half
241        let (row_indices, col_indices_, _) = matrix.find();
242        let mut diagonal_count = 0;
243        let mut off_diagonal_count = 0;
244
245        for (&row, &col) in row_indices.iter().zip(col_indices_.iter()) {
246            if row == col {
247                diagonal_count += 1;
248            } else {
249                off_diagonal_count += 1;
250            }
251        }
252
253        // Off-diagonal edges are counted twice in the _matrix (i,j) and (j,i)
254        Ok(diagonal_count + off_diagonal_count / 2)
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    // Four-vertex "square" graph, stored symmetrically (both (i, j) and (j, i)):
    //
    //   0 --1.0-- 1
    //   |         |
    //  2.0       3.0
    //   |         |
    //   2 --1.0-- 3
    fn create_test_graph() -> CsrArray<f64> {
        let rows = vec![0, 0, 1, 1, 2, 2, 3, 3];
        let cols = vec![1, 2, 0, 3, 0, 3, 1, 2];
        let data = vec![1.0, 2.0, 1.0, 3.0, 2.0, 1.0, 3.0, 1.0];
        CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).unwrap()
    }

    // Two vertices with entries (0, 1) and (1, 0) carrying `data[0]` and `data[1]`.
    fn two_vertex_graph(data: Vec<f64>) -> CsrArray<f64> {
        let rows = vec![0, 1];
        let cols = vec![1, 0];
        CsrArray::from_triplets(&rows, &cols, &data, (2, 2), false).unwrap()
    }

    #[test]
    fn test_validate_graph_symmetric() {
        let graph = create_test_graph();
        // A symmetric matrix is valid both as an undirected and as a directed graph.
        for directed in [false, true] {
            assert!(
                validate_graph(&graph, directed).is_ok(),
                "directed = {}",
                directed
            );
        }
    }

    #[test]
    fn test_validate_graph_asymmetric() {
        // (0, 1) and (1, 0) carry different weights.
        let graph = two_vertex_graph(vec![1.0, 2.0]);
        assert!(validate_graph(&graph, true).is_ok());
        assert!(validate_graph(&graph, false).is_err());
    }

    #[test]
    fn test_validate_graph_negative_weights() {
        let graph = two_vertex_graph(vec![-1.0, 1.0]);
        // Negative weights are rejected regardless of directedness.
        for directed in [true, false] {
            assert!(
                validate_graph(&graph, directed).is_err(),
                "directed = {}",
                directed
            );
        }
    }

    #[test]
    fn test_to_adjacency_list() {
        let adj_list = to_adjacency_list(&create_test_graph(), false).unwrap();
        assert_eq!(adj_list.len(), 4);

        let assert_neighbors = |vertex: usize, expected: &[(usize, f64)]| {
            let neighbors = &adj_list[vertex];
            assert_eq!(neighbors.len(), expected.len(), "vertex {}", vertex);
            for edge in expected {
                assert!(
                    neighbors.contains(edge),
                    "vertex {} missing {:?}",
                    vertex,
                    edge
                );
            }
        };

        assert_neighbors(0, &[(1, 1.0), (2, 2.0)]);
        assert_neighbors(1, &[(0, 1.0), (3, 3.0)]);
    }

    #[test]
    fn test_num_vertices() {
        assert_eq!(num_vertices(&create_test_graph()), 4);
    }

    #[test]
    fn test_num_edges() {
        let graph = create_test_graph();
        // Directed: every stored entry is an edge.
        assert_eq!(num_edges(&graph, true).unwrap(), 8);
        // Undirected: each mirrored pair is a single edge.
        assert_eq!(num_edges(&graph, false).unwrap(), 4);
    }
}