1use super::{num_vertices, to_adjacency_list, validate_graph};
7use crate::error::{SparseError, SparseResult};
8use crate::sparray::SparseArray;
9use scirs2_core::ndarray::Array1;
10use scirs2_core::numeric::Float;
11use std::fmt::Debug;
12
13#[allow(dead_code)]
44pub fn connected_components<T, S>(
45 graph: &S,
46 directed: bool,
47 connection: &str,
48 returnlabels: bool,
49) -> SparseResult<(usize, Option<Array1<usize>>)>
50where
51 T: Float + Debug + Copy + 'static,
52 S: SparseArray<T>,
53{
54 validate_graph(graph, directed)?;
55
56 let connection_type = match connection.to_lowercase().as_str() {
57 "weak" => ConnectionType::Weak,
58 "strong" => ConnectionType::Strong,
59 _ => {
60 return Err(SparseError::ValueError(format!(
61 "Unknown connection type: {connection}. Use 'weak' or 'strong'"
62 )))
63 }
64 };
65
66 if directed {
67 match connection_type {
68 ConnectionType::Weak => weakly_connected_components(graph, returnlabels),
69 ConnectionType::Strong => strongly_connected_components(graph, returnlabels),
70 }
71 } else {
72 undirected_connected_components(graph, returnlabels)
74 }
75}
76
/// Which notion of connectivity `connected_components` applies to directed graphs.
#[derive(Clone, Copy, Debug, PartialEq)]
enum ConnectionType {
    /// Connectivity with edge direction ignored (dispatched to the weak routine).
    Weak,
    /// Mutual reachability along directed edges (dispatched to Tarjan's algorithm).
    Strong,
}
85
86#[allow(dead_code)]
88pub fn undirected_connected_components<T, S>(
89 graph: &S,
90 returnlabels: bool,
91) -> SparseResult<(usize, Option<Array1<usize>>)>
92where
93 T: Float + Debug + Copy + 'static,
94 S: SparseArray<T>,
95{
96 let n = num_vertices(graph);
97 let adjlist = to_adjacency_list(graph, false)?; let mut visited = vec![false; n];
100 let mut labels = if returnlabels {
101 Some(Array1::zeros(n))
102 } else {
103 None
104 };
105
106 let mut component_count = 0;
107
108 for start in 0..n {
109 if !visited[start] {
110 dfs_component(&adjlist, start, &mut visited, component_count, &mut labels);
112 component_count += 1;
113 }
114 }
115
116 Ok((component_count, labels))
117}
118
119#[allow(dead_code)]
121pub fn weakly_connected_components<T, S>(
122 graph: &S,
123 returnlabels: bool,
124) -> SparseResult<(usize, Option<Array1<usize>>)>
125where
126 T: Float + Debug + Copy + 'static,
127 S: SparseArray<T>,
128{
129 undirected_connected_components(graph, returnlabels)
131}
132
133#[allow(dead_code)]
135pub fn strongly_connected_components<T, S>(
136 graph: &S,
137 returnlabels: bool,
138) -> SparseResult<(usize, Option<Array1<usize>>)>
139where
140 T: Float + Debug + Copy + 'static,
141 S: SparseArray<T>,
142{
143 let n = num_vertices(graph);
144 let adjlist = to_adjacency_list(graph, true)?; let mut tarjan = TarjanSCC::<T>::new(n, returnlabels);
147
148 for v in 0..n {
149 if tarjan.indices[v] == -1 {
150 tarjan.strongconnect(v, &adjlist);
151 }
152 }
153
154 Ok((tarjan.component_count, tarjan._labels))
155}
156
157#[allow(dead_code)]
159fn dfs_component<T>(
160 adjlist: &[Vec<(usize, T)>],
161 start: usize,
162 visited: &mut [bool],
163 component_id: usize,
164 labels: &mut Option<Array1<usize>>,
165) where
166 T: Float + Debug + Copy + 'static,
167{
168 let mut stack = vec![start];
169
170 while let Some(node) = stack.pop() {
171 if visited[node] {
172 continue;
173 }
174
175 visited[node] = true;
176
177 if let Some(ref mut label_array) = labels {
178 label_array[node] = component_id;
179 }
180
181 for &(neighbor, _) in &adjlist[node] {
183 if !visited[neighbor] {
184 stack.push(neighbor);
185 }
186 }
187 }
188}
189
190struct TarjanSCC<T>
192where
193 T: Float + Debug + Copy + 'static,
194{
195 indices: Vec<isize>,
196 lowlinks: Vec<isize>,
197 on_stack: Vec<bool>,
198 stack: Vec<usize>,
199 index: isize,
200 component_count: usize,
201 _labels: Option<Array1<usize>>,
202 _phantom: std::marker::PhantomData<T>,
203}
204
205impl<T> TarjanSCC<T>
206where
207 T: Float + Debug + Copy + 'static,
208{
209 fn new(n: usize, returnlabels: bool) -> Self {
210 Self {
211 indices: vec![-1; n],
212 lowlinks: vec![-1; n],
213 on_stack: vec![false; n],
214 stack: Vec::new(),
215 index: 0,
216 component_count: 0,
217 _labels: if returnlabels {
218 Some(Array1::zeros(n))
219 } else {
220 None
221 },
222 _phantom: std::marker::PhantomData,
223 }
224 }
225
226 fn strongconnect(&mut self, v: usize, adjlist: &[Vec<(usize, T)>]) {
227 self.indices[v] = self.index;
229 self.lowlinks[v] = self.index;
230 self.index += 1;
231 self.stack.push(v);
232 self.on_stack[v] = true;
233
234 for &(w, _) in &adjlist[v] {
236 if self.indices[w] == -1 {
237 self.strongconnect(w, adjlist);
239 self.lowlinks[v] = self.lowlinks[v].min(self.lowlinks[w]);
240 } else if self.on_stack[w] {
241 self.lowlinks[v] = self.lowlinks[v].min(self.indices[w]);
243 }
244 }
245
246 if self.lowlinks[v] == self.indices[v] {
248 loop {
249 let w = self.stack.pop().unwrap();
250 self.on_stack[w] = false;
251
252 if let Some(ref mut labels) = self._labels {
253 labels[w] = self.component_count;
254 }
255
256 if w == v {
257 break;
258 }
259 }
260 self.component_count += 1;
261 }
262 }
263}
264
265#[allow(dead_code)]
291pub fn is_connected<T, S>(graph: &S, directed: bool) -> SparseResult<bool>
292where
293 T: Float + Debug + Copy + 'static,
294 S: SparseArray<T>,
295{
296 let (n_components_, _) = connected_components(graph, directed, "strong", false)?;
297 Ok(n_components_ == 1)
298}
299
300#[allow(dead_code)]
329pub fn largest_component<T, S>(
330 graph: &S,
331 directed: bool,
332 connection: &str,
333) -> SparseResult<(usize, Vec<usize>)>
334where
335 T: Float + Debug + Copy + 'static,
336 S: SparseArray<T>,
337{
338 let (n_components, labels) = connected_components(graph, directed, connection, true)?;
339 let labels = labels.unwrap();
340
341 let mut component_sizes = vec![0; n_components];
343 for &label in labels.iter() {
344 component_sizes[label] += 1;
345 }
346
347 let largest_component_id = component_sizes
349 .iter()
350 .enumerate()
351 .max_by_key(|(_, &size)| size)
352 .map(|(id_, _)| id_)
353 .unwrap_or(0);
354
355 let largest_size = component_sizes[largest_component_id];
356
357 let largest_indices: Vec<usize> = labels
359 .iter()
360 .enumerate()
361 .filter_map(|(vertex, &label)| {
362 if label == largest_component_id {
363 Some(vertex)
364 } else {
365 None
366 }
367 })
368 .collect();
369
370 Ok((largest_size, largest_indices))
371}
372
/// Intended to extract the largest connected component of `graph` as its own matrix.
///
/// Returns `(subgraph, vertex_indices)`, where `vertex_indices` are the original vertex ids
/// of the largest component as reported by `largest_component`.
///
/// NOTE(review): as written this does NOT return the induced subgraph. The component's
/// entries are remapped into `new_rows` / `new_cols` / `new_values`, but those vectors are
/// never used. `subgraph` is simply `graph.clone()`, so callers receive the full original
/// graph. Building an `S` from the remapped triplets presumably requires a constructor that
/// the `SparseArray<T> + Clone` bound does not provide — TODO confirm and fix.
///
/// # Errors
///
/// Propagates any error from `largest_component`.
#[allow(dead_code)]
pub fn extract_largest_component<T, S>(
    graph: &S,
    directed: bool,
    connection: &str,
) -> SparseResult<(S, Vec<usize>)>
where
    T: Float + Debug + Copy + 'static,
    S: SparseArray<T> + Clone,
{
    let (_, vertex_indices) = largest_component(graph, directed, connection)?;

    // old vertex id -> position within the extracted component (`None` = not in it).
    let mut old_to_new = vec![None; num_vertices(graph)];
    for (new_idx, &old_idx) in vertex_indices.iter().enumerate() {
        old_to_new[old_idx] = Some(new_idx);
    }

    // `find()` presumably yields the (row, col, value) triplets of the stored entries.
    // Keep only entries whose endpoints both lie in the component, renumbered.
    let (row_indices, col_indices, values) = graph.find();
    let mut new_rows = Vec::new();
    let mut new_cols = Vec::new();
    let mut new_values = Vec::new();

    for (i, (&old_row, &old_col)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
        if let (Some(new_row), Some(new_col)) = (old_to_new[old_row], old_to_new[old_col]) {
            new_rows.push(new_row);
            new_cols.push(new_col);
            new_values.push(values[i]);
        }
    }

    // NOTE(review): the remapped triplets above are discarded; this returns the whole graph.
    let subgraph = graph.clone();

    Ok((subgraph, vertex_indices))
}
427
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    /// Undirected graph with edges 0-1 and 2-3 stored symmetrically: two components.
    fn create_disconnected_graph() -> CsrArray<f64> {
        let rows = vec![0, 1, 2, 3];
        let cols = vec![1, 0, 3, 2];
        let data = vec![1.0, 1.0, 1.0, 1.0];

        CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).unwrap()
    }

    /// Directed cycle 0 -> 1 -> 2 -> 0 plus the isolated vertex 3: two SCCs.
    fn create_strongly_connected_graph() -> CsrArray<f64> {
        let rows = vec![0, 1, 2];
        let cols = vec![1, 2, 0];
        let data = vec![1.0, 1.0, 1.0];

        CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).unwrap()
    }

    #[test]
    fn test_undirected_connected_components() {
        let graph = create_disconnected_graph();
        let (n_components, labels) = undirected_connected_components(&graph, true).unwrap();

        assert_eq!(n_components, 2);

        let labels = labels.unwrap();
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[2], labels[3]);
        assert_ne!(labels[0], labels[2]);
    }

    #[test]
    fn test_connected_components_api() {
        let graph = create_disconnected_graph();

        let (n_components_, _) = connected_components(&graph, false, "weak", false).unwrap();
        assert_eq!(n_components_, 2);

        let (n_components_, _) = connected_components(&graph, true, "weak", false).unwrap();
        assert_eq!(n_components_, 2);
    }

    #[test]
    fn test_strongly_connected_components() {
        let graph = create_strongly_connected_graph();
        let (n_components, labels) = strongly_connected_components(&graph, true).unwrap();

        assert_eq!(n_components, 2);

        let labels = labels.unwrap();
        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[1], labels[2]);
        assert_ne!(labels[0], labels[3]);
    }

    #[test]
    fn test_is_connected() {
        let disconnected = create_disconnected_graph();
        assert!(!is_connected(&disconnected, false).unwrap());

        let rows = vec![0, 1, 1, 2];
        let cols = vec![1, 0, 2, 1];
        let data = vec![1.0, 1.0, 1.0, 1.0];
        let connected = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();

        assert!(is_connected(&connected, false).unwrap());
    }

    #[test]
    fn test_largest_component() {
        let rows = vec![0, 1, 1, 2, 3, 4];
        let cols = vec![1, 0, 2, 1, 4, 3];
        let data = vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let graph = CsrArray::from_triplets(&rows, &cols, &data, (6, 6), false).unwrap();

        let (size, indices) = largest_component(&graph, false, "weak").unwrap();

        assert_eq!(size, 3);
        assert_eq!(indices.len(), 3);
        assert!(indices.contains(&0));
        assert!(indices.contains(&1));
        assert!(indices.contains(&2));
    }

    #[test]
    fn test_single_component() {
        let rows = vec![0, 1, 1, 2];
        let cols = vec![1, 0, 2, 1];
        let data = vec![1.0, 1.0, 1.0, 1.0];
        let graph = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();

        let (n_components_, _) = connected_components(&graph, false, "weak", false).unwrap();
        assert_eq!(n_components_, 1);

        let (size, indices) = largest_component(&graph, false, "weak").unwrap();
        assert_eq!(size, 3);
        assert_eq!(indices, vec![0, 1, 2]);
    }

    #[test]
    fn test_isolated_vertices() {
        let rows = vec![0, 1];
        let cols = vec![1, 0];
        let data = vec![1.0, 1.0];
        let graph = CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).unwrap();

        let (n_components, labels) = connected_components(&graph, false, "weak", true).unwrap();

        assert_eq!(n_components, 3);

        let labels = labels.unwrap();
        assert_eq!(labels[0], labels[1]);
        assert_ne!(labels[0], labels[2]);
        assert_ne!(labels[0], labels[3]);
        assert_ne!(labels[2], labels[3]);
    }

    #[test]
    fn test_connection_type_is_case_insensitive() {
        let graph = create_disconnected_graph();

        let (n_components, _) = connected_components(&graph, true, "STRONG", false).unwrap();
        assert_eq!(n_components, 2);

        let (n_components, _) = connected_components(&graph, true, "Weak", false).unwrap();
        assert_eq!(n_components, 2);
    }

    #[test]
    fn test_unknown_connection_type_is_rejected() {
        let graph = create_disconnected_graph();

        // The connection string is validated even when the graph is undirected.
        assert!(connected_components(&graph, false, "medium", false).is_err());
        assert!(connected_components(&graph, true, "", false).is_err());
    }
}