1use crate::error::{SparseError, SparseResult};
50use crate::sparray::SparseArray;
51use scirs2_core::numeric::Float;
52use std::cmp::Ordering;
53use std::fmt::Debug;
54
55pub mod connected_components;
56pub mod laplacian;
57pub mod minimum_spanning_tree;
58pub mod shortest_path;
59pub mod traversal;
60
61pub use connected_components::*;
62pub use laplacian::*;
63pub use minimum_spanning_tree::*;
64pub use shortest_path::*;
65pub use traversal::*;
66
/// Whether edges of a graph are interpreted as one-way or two-way.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GraphMode {
    /// Each stored entry `(i, j)` is a one-way edge.
    Directed,
    /// Each edge is traversable in both directions.
    Undirected,
}
75
76#[derive(Debug, Clone)]
78struct PriorityQueueNode<T>
79where
80 T: Float + PartialOrd,
81{
82 distance: T,
83 node: usize,
84}
85
86impl<T> PartialEq for PriorityQueueNode<T>
87where
88 T: Float + PartialOrd,
89{
90 fn eq(&self, other: &Self) -> bool {
91 self.distance == other.distance && self.node == other.node
92 }
93}
94
95impl<T> Eq for PriorityQueueNode<T> where T: Float + PartialOrd {}
96
97impl<T> PartialOrd for PriorityQueueNode<T>
98where
99 T: Float + PartialOrd,
100{
101 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
102 Some(self.cmp(other))
103 }
104}
105
106impl<T> Ord for PriorityQueueNode<T>
107where
108 T: Float + PartialOrd,
109{
110 fn cmp(&self, other: &Self) -> Ordering {
111 other
113 .distance
114 .partial_cmp(&self.distance)
115 .unwrap_or(Ordering::Equal)
116 }
117}
118
119#[allow(dead_code)]
130pub fn validate_graph<T, S>(matrix: &S, directed: bool) -> SparseResult<()>
131where
132 T: Float + Debug + Copy + 'static,
133 S: SparseArray<T>,
134{
135 let (rows, cols) = matrix.shape();
136
137 if rows != cols {
139 return Err(SparseError::ValueError(
140 "Graph _matrix must be square".to_string(),
141 ));
142 }
143
144 let (row_indices, col_indices, values) = matrix.find();
146 for &value in values.iter() {
147 if value < T::zero() {
148 return Err(SparseError::ValueError(
149 "Negative edge weights not supported".to_string(),
150 ));
151 }
152 }
153
154 if !directed {
156 for (i, (&row, &col)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
157 if row != col {
158 let weight = values[i];
159 let reverse_weight = matrix.get(col, row);
160
161 if (weight - reverse_weight).abs() > T::from(1e-10).unwrap() {
162 return Err(SparseError::ValueError(
163 "Undirected graph _matrix must be symmetric".to_string(),
164 ));
165 }
166 }
167 }
168 }
169
170 Ok(())
171}
172
173#[allow(dead_code)]
184pub fn to_adjacency_list<T, S>(matrix: &S, directed: bool) -> SparseResult<Vec<Vec<(usize, T)>>>
185where
186 T: Float + Debug + Copy + 'static,
187 S: SparseArray<T>,
188{
189 let (n_, _) = matrix.shape();
190 let mut adj_list = vec![Vec::new(); n_];
191
192 let (row_indices, col_indices, values) = matrix.find();
193
194 for (i, (&row, &col)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
195 let weight = values[i];
196
197 if !weight.is_zero() {
198 adj_list[row].push((col, weight));
199
200 if !directed && row != col {
202 let reverse_exists = row_indices
204 .iter()
205 .zip(col_indices.iter())
206 .any(|(r, c)| *r == col && *c == row);
207
208 if !reverse_exists {
209 adj_list[col].push((row, weight));
210 }
211 }
212 }
213 }
214
215 Ok(adj_list)
216}
217
218#[allow(dead_code)]
220pub fn num_vertices<T, S>(matrix: &S) -> usize
221where
222 T: Float + Debug + Copy + 'static,
223 S: SparseArray<T>,
224{
225 matrix.shape().0
226}
227
228#[allow(dead_code)]
230pub fn num_edges<T, S>(matrix: &S, directed: bool) -> SparseResult<usize>
231where
232 T: Float + Debug + Copy + 'static,
233 S: SparseArray<T>,
234{
235 let nnz = matrix.nnz();
236
237 if directed {
238 Ok(nnz)
239 } else {
240 let (row_indices, col_indices_, _) = matrix.find();
242 let mut diagonal_count = 0;
243 let mut off_diagonal_count = 0;
244
245 for (&row, &col) in row_indices.iter().zip(col_indices_.iter()) {
246 if row == col {
247 diagonal_count += 1;
248 } else {
249 off_diagonal_count += 1;
250 }
251 }
252
253 Ok(diagonal_count + off_diagonal_count / 2)
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    /// Symmetric 4-vertex graph, stored in both directions, with edges
    /// 0-1 (1.0), 0-2 (2.0), 1-3 (3.0) and 2-3 (1.0).
    fn create_test_graph() -> CsrArray<f64> {
        let row_idx = vec![0, 0, 1, 1, 2, 2, 3, 3];
        let col_idx = vec![1, 2, 0, 3, 0, 3, 1, 2];
        let weights = vec![1.0, 2.0, 1.0, 3.0, 2.0, 1.0, 3.0, 1.0];
        CsrArray::from_triplets(&row_idx, &col_idx, &weights, (4, 4), false).unwrap()
    }

    /// 2x2 graph with one entry at (0, 1) and one at (1, 0).
    fn two_vertex_graph(forward: f64, backward: f64) -> CsrArray<f64> {
        let row_idx = vec![0, 1];
        let col_idx = vec![1, 0];
        let weights = vec![forward, backward];
        CsrArray::from_triplets(&row_idx, &col_idx, &weights, (2, 2), false).unwrap()
    }

    #[test]
    fn test_validate_graph_symmetric() {
        let g = create_test_graph();

        // A symmetric, non-negative matrix is valid in either mode.
        assert!(validate_graph(&g, false).is_ok());
        assert!(validate_graph(&g, true).is_ok());
    }

    #[test]
    fn test_validate_graph_asymmetric() {
        let g = two_vertex_graph(1.0, 2.0);

        // Mismatched weights are fine for a directed graph only.
        assert!(validate_graph(&g, true).is_ok());
        assert!(validate_graph(&g, false).is_err());
    }

    #[test]
    fn test_validate_graph_negative_weights() {
        let g = two_vertex_graph(-1.0, 1.0);

        // Negative weights are rejected regardless of mode.
        assert!(validate_graph(&g, true).is_err());
        assert!(validate_graph(&g, false).is_err());
    }

    #[test]
    fn test_to_adjacency_list() {
        let g = create_test_graph();
        let neighbours = to_adjacency_list(&g, false).unwrap();

        assert_eq!(neighbours.len(), 4);

        let of_zero = &neighbours[0];
        assert_eq!(of_zero.len(), 2);
        assert!(of_zero.contains(&(1, 1.0)));
        assert!(of_zero.contains(&(2, 2.0)));

        let of_one = &neighbours[1];
        assert_eq!(of_one.len(), 2);
        assert!(of_one.contains(&(0, 1.0)));
        assert!(of_one.contains(&(3, 3.0)));
    }

    #[test]
    fn test_num_vertices() {
        assert_eq!(num_vertices(&create_test_graph()), 4);
    }

    #[test]
    fn test_num_edges() {
        let g = create_test_graph();

        // Directed: all 8 stored entries. Undirected: the 4 unordered pairs.
        assert_eq!(num_edges(&g, true).unwrap(), 8);
        assert_eq!(num_edges(&g, false).unwrap(), 4);
    }
}