1use crate::error::{SparseError, SparseResult};
50use crate::sparray::SparseArray;
51use scirs2_core::numeric::{Float, SparseElement};
52use std::cmp::Ordering;
53use std::fmt::Debug;
54
55pub mod connected_components;
56pub mod laplacian;
57pub mod minimum_spanning_tree;
58pub mod shortest_path;
59pub mod traversal;
60
61pub mod centrality;
63pub mod community_detection;
64pub mod max_flow;
65
66pub use centrality::*;
67pub use community_detection::*;
68pub use connected_components::*;
69pub use laplacian::*;
70pub use max_flow::*;
71pub use minimum_spanning_tree::*;
72pub use shortest_path::*;
73pub use traversal::*;
74
/// Selects whether a graph is interpreted as directed or undirected.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so the mode can be used as a
/// key in hashed collections and in `Eq`-bounded generic code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GraphMode {
    /// Edges are one-way.
    Directed,
    /// Edges are two-way.
    Undirected,
}
83
84#[derive(Debug, Clone)]
86struct PriorityQueueNode<T>
87where
88 T: Float + PartialOrd,
89{
90 distance: T,
91 node: usize,
92}
93
94impl<T> PartialEq for PriorityQueueNode<T>
95where
96 T: Float + PartialOrd,
97{
98 fn eq(&self, other: &Self) -> bool {
99 self.distance == other.distance && self.node == other.node
100 }
101}
102
103impl<T> Eq for PriorityQueueNode<T> where T: Float + PartialOrd {}
104
105impl<T> PartialOrd for PriorityQueueNode<T>
106where
107 T: Float + PartialOrd,
108{
109 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
110 Some(self.cmp(other))
111 }
112}
113
114impl<T> Ord for PriorityQueueNode<T>
115where
116 T: Float + PartialOrd,
117{
118 fn cmp(&self, other: &Self) -> Ordering {
119 other
121 .distance
122 .partial_cmp(&self.distance)
123 .unwrap_or(Ordering::Equal)
124 }
125}
126
127#[allow(dead_code)]
138pub fn validate_graph<T, S>(matrix: &S, directed: bool) -> SparseResult<()>
139where
140 T: Float + SparseElement + Debug + Copy + 'static,
141 S: SparseArray<T>,
142{
143 let (rows, cols) = matrix.shape();
144
145 if rows != cols {
147 return Err(SparseError::ValueError(
148 "Graph _matrix must be square".to_string(),
149 ));
150 }
151
152 let (row_indices, col_indices, values) = matrix.find();
154 for &value in values.iter() {
155 if value < T::sparse_zero() {
156 return Err(SparseError::ValueError(
157 "Negative edge weights not supported".to_string(),
158 ));
159 }
160 }
161
162 if !directed {
164 for (i, (&row, &col)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
165 if row != col {
166 let weight = values[i];
167 let reverse_weight = matrix.get(col, row);
168
169 if (weight - reverse_weight).abs() > T::from(1e-10).expect("Operation failed") {
170 return Err(SparseError::ValueError(
171 "Undirected graph _matrix must be symmetric".to_string(),
172 ));
173 }
174 }
175 }
176 }
177
178 Ok(())
179}
180
181#[allow(dead_code)]
192pub fn to_adjacency_list<T, S>(matrix: &S, directed: bool) -> SparseResult<Vec<Vec<(usize, T)>>>
193where
194 T: Float + SparseElement + Debug + Copy + 'static,
195 S: SparseArray<T>,
196{
197 let (n_, _) = matrix.shape();
198 let mut adj_list = vec![Vec::new(); n_];
199
200 let (row_indices, col_indices, values) = matrix.find();
201
202 for (i, (&row, &col)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
203 let weight = values[i];
204
205 if !SparseElement::is_zero(&weight) {
206 adj_list[row].push((col, weight));
207
208 if !directed && row != col {
210 let reverse_exists = row_indices
212 .iter()
213 .zip(col_indices.iter())
214 .any(|(r, c)| *r == col && *c == row);
215
216 if !reverse_exists {
217 adj_list[col].push((row, weight));
218 }
219 }
220 }
221 }
222
223 Ok(adj_list)
224}
225
226#[allow(dead_code)]
228pub fn num_vertices<T, S>(matrix: &S) -> usize
229where
230 T: Float + SparseElement + Debug + Copy + 'static,
231 S: SparseArray<T>,
232{
233 matrix.shape().0
234}
235
236#[allow(dead_code)]
238pub fn num_edges<T, S>(matrix: &S, directed: bool) -> SparseResult<usize>
239where
240 T: Float + SparseElement + Debug + Copy + 'static,
241 S: SparseArray<T>,
242{
243 let nnz = matrix.nnz();
244
245 if directed {
246 Ok(nnz)
247 } else {
248 let (row_indices, col_indices_, _) = matrix.find();
250 let mut diagonal_count = 0;
251 let mut off_diagonal_count = 0;
252
253 for (&row, &col) in row_indices.iter().zip(col_indices_.iter()) {
254 if row == col {
255 diagonal_count += 1;
256 } else {
257 off_diagonal_count += 1;
258 }
259 }
260
261 Ok(diagonal_count + off_diagonal_count / 2)
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    /// Symmetric 4-vertex weighted graph storing both directions of the edges
    /// 0-1 (1.0), 0-2 (2.0), 1-3 (3.0) and 2-3 (1.0).
    fn create_test_graph() -> CsrArray<f64> {
        let rows = vec![0, 0, 1, 1, 2, 2, 3, 3];
        let cols = vec![1, 2, 0, 3, 0, 3, 1, 2];
        let data = vec![1.0, 2.0, 1.0, 3.0, 2.0, 1.0, 3.0, 1.0];

        CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).expect("Operation failed")
    }

    /// Two-vertex graph with weight `w01` stored at (0, 1) and `w10` at (1, 0).
    fn two_vertex_graph(w01: f64, w10: f64) -> CsrArray<f64> {
        let rows = vec![0, 1];
        let cols = vec![1, 0];
        let data = vec![w01, w10];

        CsrArray::from_triplets(&rows, &cols, &data, (2, 2), false).expect("Operation failed")
    }

    #[test]
    fn test_validate_graph_symmetric() {
        let graph = create_test_graph();

        // A symmetric matrix is valid as both an undirected and a directed graph.
        for &directed in [false, true].iter() {
            assert!(validate_graph(&graph, directed).is_ok());
        }
    }

    #[test]
    fn test_validate_graph_asymmetric() {
        let graph = two_vertex_graph(1.0, 2.0);

        // Fine as a directed graph, but the weights disagree for undirected use.
        let as_directed = validate_graph(&graph, true);
        assert!(as_directed.is_ok());

        let as_undirected = validate_graph(&graph, false);
        assert!(as_undirected.is_err());
    }

    #[test]
    fn test_validate_graph_negative_weights() {
        let graph = two_vertex_graph(-1.0, 1.0);

        // Negative weights are rejected regardless of directedness.
        for &directed in [true, false].iter() {
            assert!(validate_graph(&graph, directed).is_err());
        }
    }

    #[test]
    fn test_to_adjacency_list() {
        let graph = create_test_graph();
        let adj_list = to_adjacency_list(&graph, false).expect("Operation failed");

        assert_eq!(adj_list.len(), 4);

        let expectations: [(usize, [(usize, f64); 2]); 2] =
            [(0, [(1, 1.0), (2, 2.0)]), (1, [(0, 1.0), (3, 3.0)])];

        for (vertex, neighbours) in expectations.iter() {
            assert_eq!(adj_list[*vertex].len(), neighbours.len());
            for edge in neighbours.iter() {
                assert!(adj_list[*vertex].contains(edge));
            }
        }
    }

    #[test]
    fn test_num_vertices() {
        let vertex_count = num_vertices(&create_test_graph());
        assert_eq!(vertex_count, 4);
    }

    #[test]
    fn test_num_edges() {
        let graph = create_test_graph();

        // Directed: all 8 stored entries; undirected: 4 mirrored pairs.
        let directed_edges = num_edges(&graph, true).expect("Operation failed");
        assert_eq!(directed_edges, 8);

        let undirected_edges = num_edges(&graph, false).expect("Operation failed");
        assert_eq!(undirected_edges, 4);
    }
}