1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
22use scirs2_core::random::{thread_rng, CoreRandom, Rng};
23use sklears_core::types::Float;
24use std::collections::{HashMap, HashSet, VecDeque};
25
/// Algorithm that `CommunityDetector::detect` dispatches to.
///
/// The type is a plain, field-less enum; it is `Copy` and can be compared and
/// hashed, so it may be used as a key in `HashMap`/`HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CommunityAlgorithm {
    /// Greedy local moving of nodes between neighbouring communities.
    Louvain,
    /// Louvain followed by a pass that may split sparsely connected communities.
    Leiden,
    /// Degree-ordered bucketing of nodes into roughly `sqrt(n)` groups.
    SpectralClustering,
    /// Weighted label propagation between neighbouring nodes.
    LabelPropagation,
    /// Divisive clustering that repeatedly removes the heaviest remaining edge.
    GirvanNewman,
    /// Agglomerative merging of communities guided by a modularity gain estimate.
    FastGreedy,
}
42
43#[derive(Debug, Clone)]
45pub struct CommunityDetectionConfig {
46 pub algorithm: CommunityAlgorithm,
48 pub resolution: Float,
50 pub max_iterations: usize,
52 pub tolerance: Float,
54 pub random_seed: Option<u64>,
56 pub min_community_size: usize,
58}
59
60impl Default for CommunityDetectionConfig {
61 fn default() -> Self {
62 Self {
63 algorithm: CommunityAlgorithm::Louvain,
64 resolution: 1.0,
65 max_iterations: 100,
66 tolerance: 1e-6,
67 random_seed: None,
68 min_community_size: 1,
69 }
70 }
71}
72
73#[derive(Debug, Clone)]
75pub struct CommunityStructure {
76 pub assignments: Array1<usize>,
78 pub num_communities: usize,
80 pub modularity: Float,
82 pub community_sizes: Vec<usize>,
84 pub dendrogram: Option<Vec<(usize, usize, Float)>>,
86}
87
88pub struct CommunityDetector {
90 config: CommunityDetectionConfig,
92}
93
94impl CommunityDetector {
95 pub fn new(config: CommunityDetectionConfig) -> Self {
97 Self { config }
98 }
99
100 pub fn detect(&self, adjacency: ArrayView2<Float>) -> CommunityStructure {
108 match self.config.algorithm {
109 CommunityAlgorithm::Louvain => self.louvain(adjacency),
110 CommunityAlgorithm::Leiden => self.leiden(adjacency),
111 CommunityAlgorithm::SpectralClustering => self.spectral_clustering(adjacency),
112 CommunityAlgorithm::LabelPropagation => self.label_propagation(adjacency),
113 CommunityAlgorithm::GirvanNewman => self.girvan_newman(adjacency),
114 CommunityAlgorithm::FastGreedy => self.fast_greedy(adjacency),
115 }
116 }
117
118 fn louvain(&self, adjacency: ArrayView2<Float>) -> CommunityStructure {
120 let n = adjacency.nrows();
121 let mut assignments = Array1::from_iter(0..n);
122
123 let m = adjacency.sum() / 2.0; let mut best_modularity = 0.0;
125
126 for iteration in 0..self.config.max_iterations {
127 let mut improved = false;
128
129 for node in 0..n {
131 let current_community = assignments[node];
132
133 let mut best_delta = 0.0;
135 let mut best_community = current_community;
136
137 for neighbor in 0..n {
138 if adjacency[[node, neighbor]] > 0.0 {
139 let neighbor_community = assignments[neighbor];
140 if neighbor_community != current_community {
141 let delta = self.compute_modularity_delta(
142 &adjacency,
143 &assignments,
144 node,
145 current_community,
146 neighbor_community,
147 m,
148 );
149
150 if delta > best_delta {
151 best_delta = delta;
152 best_community = neighbor_community;
153 }
154 }
155 }
156 }
157
158 if best_delta > self.config.tolerance {
159 assignments[node] = best_community;
160 improved = true;
161 }
162 }
163
164 let current_modularity = self.compute_modularity(&adjacency, &assignments);
166
167 if current_modularity > best_modularity {
168 best_modularity = current_modularity;
169 }
170
171 if !improved {
172 break;
173 }
174 }
175
176 let (assignments, num_communities) = self.relabel_communities(assignments);
178
179 let community_sizes = self.compute_community_sizes(&assignments, num_communities);
181
182 CommunityStructure {
183 assignments,
184 num_communities,
185 modularity: best_modularity,
186 community_sizes,
187 dendrogram: None,
188 }
189 }
190
191 fn leiden(&self, adjacency: ArrayView2<Float>) -> CommunityStructure {
193 let mut louvain_result = self.louvain(adjacency);
195
196 let n = adjacency.nrows();
198 let mut refined_assignments = louvain_result.assignments.clone();
199
200 for community_id in 0..louvain_result.num_communities {
201 let community_nodes: Vec<usize> = (0..n)
203 .filter(|&i| refined_assignments[i] == community_id)
204 .collect();
205
206 if community_nodes.len() > 2 {
207 let split_assignments = self.try_split_community(&adjacency, &community_nodes);
209 if let Some(splits) = split_assignments {
210 for (idx, &node) in community_nodes.iter().enumerate() {
211 if splits[idx] == 1 {
212 refined_assignments[node] = louvain_result.num_communities;
213 }
214 }
215 louvain_result.num_communities += 1;
216 }
217 }
218 }
219
220 let (assignments, num_communities) = self.relabel_communities(refined_assignments);
221 let modularity = self.compute_modularity(&adjacency, &assignments);
222 let community_sizes = self.compute_community_sizes(&assignments, num_communities);
223
224 CommunityStructure {
225 assignments,
226 num_communities,
227 modularity,
228 community_sizes,
229 dendrogram: None,
230 }
231 }
232
233 fn spectral_clustering(&self, adjacency: ArrayView2<Float>) -> CommunityStructure {
235 let n = adjacency.nrows();
236
237 let degrees = adjacency.sum_axis(Axis(1));
239 let mut laplacian = adjacency.to_owned();
240
241 for i in 0..n {
242 laplacian[[i, i]] = degrees[i] - laplacian[[i, i]];
243 for j in 0..n {
244 if i != j {
245 laplacian[[i, j]] = -laplacian[[i, j]];
246 }
247 }
248 }
249
250 let k = (n as Float).sqrt().ceil() as usize;
253 let mut assignments = Array1::zeros(n);
254
255 let sorted_indices: Vec<usize> = {
257 let mut indices: Vec<usize> = (0..n).collect();
258 indices.sort_by(|&i, &j| degrees[i].partial_cmp(°rees[j]).unwrap());
259 indices
260 };
261
262 for (idx, &node) in sorted_indices.iter().enumerate() {
263 assignments[node] = idx * k / n;
264 }
265
266 let num_communities = k;
267 let modularity = self.compute_modularity(&adjacency, &assignments);
268 let community_sizes = self.compute_community_sizes(&assignments, num_communities);
269
270 CommunityStructure {
271 assignments,
272 num_communities,
273 modularity,
274 community_sizes,
275 dendrogram: None,
276 }
277 }
278
279 fn label_propagation(&self, adjacency: ArrayView2<Float>) -> CommunityStructure {
281 let n = adjacency.nrows();
282 let mut assignments = Array1::from_iter(0..n); let mut rng = if let Some(seed) = self.config.random_seed {
285 thread_rng() } else {
287 thread_rng()
288 };
289
290 for _ in 0..self.config.max_iterations {
291 let mut changed = false;
292
293 let mut node_order: Vec<usize> = (0..n).collect();
295 for i in 0..n {
296 let j = rng.random_range(i..n);
297 node_order.swap(i, j);
298 }
299
300 for &node in &node_order {
302 let mut label_counts: HashMap<usize, Float> = HashMap::new();
303
304 for neighbor in 0..n {
306 if adjacency[[node, neighbor]] > 0.0 {
307 let label = assignments[neighbor];
308 *label_counts.entry(label).or_insert(0.0) += adjacency[[node, neighbor]];
309 }
310 }
311
312 if let Some((&best_label, _)) = label_counts
314 .iter()
315 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
316 {
317 if best_label != assignments[node] {
318 assignments[node] = best_label;
319 changed = true;
320 }
321 }
322 }
323
324 if !changed {
325 break;
326 }
327 }
328
329 let (assignments, num_communities) = self.relabel_communities(assignments);
330 let modularity = self.compute_modularity(&adjacency, &assignments);
331 let community_sizes = self.compute_community_sizes(&assignments, num_communities);
332
333 CommunityStructure {
334 assignments,
335 num_communities,
336 modularity,
337 community_sizes,
338 dendrogram: None,
339 }
340 }
341
342 fn girvan_newman(&self, adjacency: ArrayView2<Float>) -> CommunityStructure {
344 let n = adjacency.nrows();
346 let mut graph = adjacency.to_owned();
347 let mut dendrogram = Vec::new();
348
349 let mut assignments = Array1::zeros(n);
351 let mut num_communities = 1;
352
353 for iteration in 0..self.config.max_iterations {
355 let mut max_weight = 0.0;
357 let mut max_edge = (0, 0);
358
359 for i in 0..n {
360 for j in (i + 1)..n {
361 if graph[[i, j]] > max_weight {
362 max_weight = graph[[i, j]];
363 max_edge = (i, j);
364 }
365 }
366 }
367
368 if max_weight == 0.0 {
369 break;
370 }
371
372 let (i, j) = max_edge;
374 graph[[i, j]] = 0.0;
375 graph[[j, i]] = 0.0;
376
377 dendrogram.push((i, j, max_weight));
378
379 let components = self.find_connected_components(&graph);
381 if components.len() > num_communities {
382 num_communities = components.len();
383
384 for (comp_id, component) in components.iter().enumerate() {
386 for &node in component {
387 assignments[node] = comp_id;
388 }
389 }
390 }
391
392 if num_communities >= n / self.config.min_community_size {
393 break;
394 }
395 }
396
397 let modularity = self.compute_modularity(&adjacency, &assignments);
398 let community_sizes = self.compute_community_sizes(&assignments, num_communities);
399
400 CommunityStructure {
401 assignments,
402 num_communities,
403 modularity,
404 community_sizes,
405 dendrogram: Some(dendrogram),
406 }
407 }
408
409 fn fast_greedy(&self, adjacency: ArrayView2<Float>) -> CommunityStructure {
411 let n = adjacency.nrows();
412 let mut assignments = Array1::from_iter(0..n);
413
414 let m = adjacency.sum() / 2.0;
415 let mut best_modularity = 0.0;
416 let mut best_assignments = assignments.clone();
417
418 for _ in 0..n - 1 {
420 let mut best_delta = Float::NEG_INFINITY;
421 let mut best_merge = (0, 0);
422
423 for i in 0..n {
425 for j in (i + 1)..n {
426 if assignments[i] != assignments[j] {
427 let delta = self.compute_modularity_delta(
428 &adjacency,
429 &assignments,
430 i,
431 assignments[i],
432 assignments[j],
433 m,
434 );
435
436 if delta > best_delta {
437 best_delta = delta;
438 best_merge = (i, j);
439 }
440 }
441 }
442 }
443
444 if best_delta <= 0.0 {
445 break;
446 }
447
448 let (i, j) = best_merge;
450 let comm_i = assignments[i];
451 let comm_j = assignments[j];
452
453 for k in 0..n {
454 if assignments[k] == comm_j {
455 assignments[k] = comm_i;
456 }
457 }
458
459 let modularity = self.compute_modularity(&adjacency, &assignments);
460 if modularity > best_modularity {
461 best_modularity = modularity;
462 best_assignments = assignments.clone();
463 }
464 }
465
466 let (assignments, num_communities) = self.relabel_communities(best_assignments);
467 let community_sizes = self.compute_community_sizes(&assignments, num_communities);
468
469 CommunityStructure {
470 assignments,
471 num_communities,
472 modularity: best_modularity,
473 community_sizes,
474 dendrogram: None,
475 }
476 }
477
478 fn compute_modularity(
480 &self,
481 adjacency: &ArrayView2<Float>,
482 assignments: &Array1<usize>,
483 ) -> Float {
484 let n = adjacency.nrows();
485 let m = adjacency.sum() / 2.0;
486
487 if m == 0.0 {
488 return 0.0;
489 }
490
491 let mut modularity = 0.0;
492
493 for i in 0..n {
494 for j in 0..n {
495 if assignments[i] == assignments[j] {
496 let k_i = adjacency.row(i).sum();
497 let k_j = adjacency.row(j).sum();
498 modularity += adjacency[[i, j]] - (k_i * k_j) / (2.0 * m);
499 }
500 }
501 }
502
503 modularity / (2.0 * m)
504 }
505
506 fn compute_modularity_delta(
508 &self,
509 adjacency: &ArrayView2<Float>,
510 assignments: &Array1<usize>,
511 node: usize,
512 from_comm: usize,
513 to_comm: usize,
514 m: Float,
515 ) -> Float {
516 let n = adjacency.nrows();
517
518 let mut k_i_in_from = 0.0;
520 let mut k_i_in_to = 0.0;
521
522 for j in 0..n {
523 if assignments[j] == from_comm {
524 k_i_in_from += adjacency[[node, j]];
525 }
526 if assignments[j] == to_comm {
527 k_i_in_to += adjacency[[node, j]];
528 }
529 }
530
531 let k_i = adjacency.row(node).sum();
532
533 (k_i_in_to - k_i_in_from) / m
535 - self.config.resolution * k_i * (k_i_in_to - k_i_in_from) / (m * m)
536 }
537
538 fn relabel_communities(&self, assignments: Array1<usize>) -> (Array1<usize>, usize) {
540 let unique_communities: HashSet<usize> = assignments.iter().copied().collect();
541 let community_map: HashMap<usize, usize> = unique_communities
542 .iter()
543 .enumerate()
544 .map(|(new_id, &old_id)| (old_id, new_id))
545 .collect();
546
547 let new_assignments = assignments.mapv(|c| community_map[&c]);
548 let num_communities = unique_communities.len();
549
550 (new_assignments, num_communities)
551 }
552
553 fn compute_community_sizes(
555 &self,
556 assignments: &Array1<usize>,
557 num_communities: usize,
558 ) -> Vec<usize> {
559 let mut sizes = vec![0; num_communities];
560
561 for &community in assignments.iter() {
562 if community < num_communities {
563 sizes[community] += 1;
564 }
565 }
566
567 sizes
568 }
569
570 fn find_connected_components(&self, adjacency: &Array2<Float>) -> Vec<Vec<usize>> {
572 let n = adjacency.nrows();
573 let mut visited = vec![false; n];
574 let mut components = Vec::new();
575
576 for start in 0..n {
577 if !visited[start] {
578 let mut component = Vec::new();
579 let mut queue = VecDeque::new();
580 queue.push_back(start);
581 visited[start] = true;
582
583 while let Some(node) = queue.pop_front() {
584 component.push(node);
585
586 for neighbor in 0..n {
587 if !visited[neighbor] && adjacency[[node, neighbor]] > 0.0 {
588 visited[neighbor] = true;
589 queue.push_back(neighbor);
590 }
591 }
592 }
593
594 components.push(component);
595 }
596 }
597
598 components
599 }
600
601 fn try_split_community(
603 &self,
604 adjacency: &ArrayView2<Float>,
605 nodes: &[usize],
606 ) -> Option<Vec<usize>> {
607 if nodes.len() < 2 {
608 return None;
609 }
610
611 let mut internal_edges: HashMap<(usize, usize), Float> = HashMap::new();
613
614 for (i, &node_i) in nodes.iter().enumerate() {
615 for (j, &node_j) in nodes.iter().enumerate().skip(i + 1) {
616 let weight = adjacency[[node_i, node_j]];
617 if weight > 0.0 {
618 internal_edges.insert((i, j), weight);
619 }
620 }
621 }
622
623 if internal_edges.len() > nodes.len() {
625 return None;
626 }
627
628 let mut assignments = vec![0; nodes.len()];
630 for i in (nodes.len() / 2)..nodes.len() {
631 assignments[i] = 1;
632 }
633
634 Some(assignments)
635 }
636}
637
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Detector with default settings.
    fn default_detector() -> CommunityDetector {
        CommunityDetector::new(CommunityDetectionConfig::default())
    }

    /// Detector with default settings apart from the selected algorithm.
    fn detector_for(algorithm: CommunityAlgorithm) -> CommunityDetector {
        CommunityDetector::new(CommunityDetectionConfig {
            algorithm,
            ..Default::default()
        })
    }

    #[test]
    fn test_community_detector_creation() {
        let detector = default_detector();
        assert_eq!(detector.config.algorithm, CommunityAlgorithm::Louvain);
    }

    #[test]
    fn test_louvain_simple_graph() {
        // A triangle (0, 1, 2) weakly linked to the pair (3, 4).
        let adjacency = array![
            [0.0, 1.0, 1.0, 0.0, 0.0],
            [1.0, 0.0, 1.0, 0.0, 0.0],
            [1.0, 1.0, 0.0, 0.5, 0.0],
            [0.0, 0.0, 0.5, 0.0, 1.0],
            [0.0, 0.0, 0.0, 1.0, 0.0],
        ];

        let result = detector_for(CommunityAlgorithm::Louvain).detect(adjacency.view());

        assert!((2..=5).contains(&result.num_communities));
        assert_eq!(result.assignments.len(), 5);
    }

    #[test]
    fn test_label_propagation() {
        let detector = CommunityDetector::new(CommunityDetectionConfig {
            algorithm: CommunityAlgorithm::LabelPropagation,
            max_iterations: 10,
            ..Default::default()
        });

        let adjacency = array![
            [0.0, 1.0, 1.0, 0.0],
            [1.0, 0.0, 1.0, 0.0],
            [1.0, 1.0, 0.0, 1.0],
            [0.0, 0.0, 1.0, 0.0],
        ];

        let result = detector.detect(adjacency.view());

        assert!((1..=4).contains(&result.num_communities));
    }

    #[test]
    fn test_spectral_clustering() {
        let adjacency = array![[0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0]];

        let result =
            detector_for(CommunityAlgorithm::SpectralClustering).detect(adjacency.view());

        assert_eq!(result.assignments.len(), 3);
        assert!(result.num_communities >= 1);
    }

    #[test]
    fn test_modularity_computation() {
        let detector = default_detector();

        let adjacency = array![[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]];
        let assignments = array![0, 0, 1];

        let modularity = detector.compute_modularity(&adjacency.view(), &assignments);

        assert!((-0.5..=1.0).contains(&modularity));
    }

    #[test]
    fn test_relabel_communities() {
        let detector = default_detector();

        let (relabeled, num_communities) =
            detector.relabel_communities(array![5, 5, 10, 10, 2]);

        assert_eq!(num_communities, 3);
        assert_eq!(relabeled.len(), 5);
        // Every new label must fall inside the compacted range.
        assert!(relabeled.iter().all(|&comm| comm < num_communities));
    }

    #[test]
    fn test_community_sizes() {
        let detector = default_detector();

        let assignments = array![0, 0, 1, 1, 1, 2];
        let sizes = detector.compute_community_sizes(&assignments, 3);

        assert_eq!(sizes, vec![2, 3, 1]);
    }

    #[test]
    fn test_connected_components() {
        let detector = default_detector();

        // Two disjoint edges: 0-1 and 2-3.
        let adjacency = array![
            [0.0, 1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
            [0.0, 0.0, 1.0, 0.0],
        ];

        let components = detector.find_connected_components(&adjacency);

        assert_eq!(components.len(), 2);
        assert_eq!(components[0].len() + components[1].len(), 4);
    }

    #[test]
    fn test_girvan_newman() {
        let detector = CommunityDetector::new(CommunityDetectionConfig {
            algorithm: CommunityAlgorithm::GirvanNewman,
            max_iterations: 5,
            ..Default::default()
        });

        let adjacency = array![
            [0.0, 1.0, 1.0, 0.0],
            [1.0, 0.0, 1.0, 0.5],
            [1.0, 1.0, 0.0, 0.5],
            [0.0, 0.5, 0.5, 0.0],
        ];

        let result = detector.detect(adjacency.view());

        assert!(result.dendrogram.is_some());
        assert!(result.num_communities >= 1);
    }

    #[test]
    fn test_fast_greedy() {
        let adjacency = array![
            [0.0, 2.0, 2.0, 0.0, 0.0],
            [2.0, 0.0, 2.0, 0.0, 0.0],
            [2.0, 2.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 3.0],
            [0.0, 0.0, 0.0, 3.0, 0.0],
        ];

        let result = detector_for(CommunityAlgorithm::FastGreedy).detect(adjacency.view());

        assert!(result.num_communities >= 2);
        assert!(result.modularity >= 0.0);
    }
}