1use std::cmp::Ordering;
31use std::collections::{HashMap, HashSet};
32
33use scirs2_core::rand_prelude::{Distribution, SliceRandom};
34use scirs2_core::random::{thread_rng, Random, Rng};
36use sklears_core::error::{Result, SklearsError};
37use sklears_core::prelude::*;
38
39#[derive(Debug, Clone)]
41pub struct Graph {
42 pub adjacency: Array2<f64>,
44 pub n_nodes: usize,
46 pub directed: bool,
48 pub node_weights: Option<Vec<f64>>,
50}
51
52impl Graph {
53 pub fn from_adjacency(adjacency: Array2<f64>, directed: bool) -> Result<Self> {
55 let n_nodes = adjacency.nrows();
56 if adjacency.ncols() != n_nodes {
57 return Err(SklearsError::InvalidInput(
58 "Adjacency matrix must be square".to_string(),
59 ));
60 }
61
62 Ok(Self {
63 adjacency,
64 n_nodes,
65 directed,
66 node_weights: None,
67 })
68 }
69
70 pub fn from_edges(
72 edges: &[(usize, usize, f64)],
73 n_nodes: usize,
74 directed: bool,
75 ) -> Result<Self> {
76 let mut adjacency = Array2::zeros((n_nodes, n_nodes));
77
78 for &(i, j, weight) in edges {
79 if i >= n_nodes || j >= n_nodes {
80 return Err(SklearsError::InvalidInput(
81 "Edge indices exceed number of nodes".to_string(),
82 ));
83 }
84
85 adjacency[[i, j]] = weight;
86 if !directed {
87 adjacency[[j, i]] = weight;
88 }
89 }
90
91 Ok(Self {
92 adjacency,
93 n_nodes,
94 directed,
95 node_weights: None,
96 })
97 }
98
99 pub fn with_node_weights(mut self, weights: Vec<f64>) -> Result<Self> {
101 if weights.len() != self.n_nodes {
102 return Err(SklearsError::InvalidInput(
103 "Node weights length must match number of nodes".to_string(),
104 ));
105 }
106 self.node_weights = Some(weights);
107 Ok(self)
108 }
109
110 pub fn degree(&self, node: usize) -> f64 {
112 if node >= self.n_nodes {
113 return 0.0;
114 }
115
116 let mut degree = 0.0;
117 for j in 0..self.n_nodes {
118 degree += self.adjacency[[node, j]];
119 }
120
121 if self.directed {
122 for i in 0..self.n_nodes {
123 if i != node {
124 degree += self.adjacency[[i, node]];
125 }
126 }
127 }
128
129 degree
130 }
131
132 pub fn total_weight(&self) -> f64 {
134 let mut total = 0.0;
135 for i in 0..self.n_nodes {
136 for j in 0..self.n_nodes {
137 total += self.adjacency[[i, j]];
138 }
139 }
140
141 if self.directed {
142 total
143 } else {
144 total / 2.0 }
146 }
147
148 pub fn neighbors(&self, node: usize) -> Vec<(usize, f64)> {
150 let mut neighbors = Vec::new();
151 if node >= self.n_nodes {
152 return neighbors;
153 }
154
155 for j in 0..self.n_nodes {
156 if self.adjacency[[node, j]] > 0.0 {
157 neighbors.push((j, self.adjacency[[node, j]]));
158 }
159 }
160
161 neighbors
162 }
163}
164
/// Parameters for greedy modularity optimisation.
#[derive(Clone, Debug)]
pub struct ModularityClusteringConfig {
    /// Resolution γ multiplying the null-model term; larger values penalise large communities.
    pub resolution: f64,
    /// Maximum number of sweeps over all nodes.
    pub max_iterations: usize,
    /// A move must beat the best gain so far (initially 0) by more than this to be accepted.
    pub tolerance: f64,
    /// Intended RNG seed. NOTE(review): not read by the optimiser in this file.
    pub random_seed: Option<u64>,
}

impl Default for ModularityClusteringConfig {
    /// γ = 1.0, 100 sweeps, tolerance 1e-6, no seed.
    fn default() -> Self {
        ModularityClusteringConfig {
            random_seed: None,
            tolerance: 1e-6,
            max_iterations: 100,
            resolution: 1.0,
        }
    }
}
188
189pub struct ModularityClustering {
191 config: ModularityClusteringConfig,
192}
193
194impl ModularityClustering {
195 pub fn new(config: ModularityClusteringConfig) -> Self {
197 Self { config }
198 }
199
200 pub fn compute_modularity(&self, graph: &Graph, communities: &[usize]) -> f64 {
202 let total_weight = graph.total_weight();
203 if total_weight == 0.0 {
204 return 0.0;
205 }
206
207 let mut modularity = 0.0;
208
209 for i in 0..graph.n_nodes {
210 for j in 0..graph.n_nodes {
211 if i == j {
212 continue;
213 }
214
215 if communities[i] == communities[j] {
216 let expected = (graph.degree(i) * graph.degree(j)) / (2.0 * total_weight);
217 modularity += graph.adjacency[[i, j]] - self.config.resolution * expected;
218 }
219 }
220 }
221
222 modularity / (2.0 * total_weight)
223 }
224
225 pub fn fit_greedy(&self, graph: &Graph) -> Result<Vec<usize>> {
227 if graph.n_nodes == 0 {
228 return Ok(Vec::new());
229 }
230
231 let mut communities: Vec<usize> = (0..graph.n_nodes).collect();
233 let mut improved = true;
234 let mut iteration = 0;
235
236 let mut rng = Random::default();
237
238 while improved && iteration < self.config.max_iterations {
239 improved = false;
240 let mut node_order: Vec<usize> = (0..graph.n_nodes).collect();
241 for i in (1..node_order.len()).rev() {
243 let j = rng.gen_range(0..i + 1);
244 node_order.swap(i, j);
245 }
246
247 for &node in &node_order {
248 let original_community = communities[node];
249 let mut best_community = original_community;
250 let mut best_modularity_gain = 0.0;
251
252 let neighbors = graph.neighbors(node);
254 let mut neighboring_communities = HashSet::new();
255
256 for (neighbor, _) in neighbors {
257 neighboring_communities.insert(communities[neighbor]);
258 }
259
260 for &candidate_community in &neighboring_communities {
261 if candidate_community != original_community {
262 communities[node] = candidate_community;
264 let new_modularity = self.compute_modularity(graph, &communities);
265
266 communities[node] = original_community;
268 let old_modularity = self.compute_modularity(graph, &communities);
269
270 let modularity_gain = new_modularity - old_modularity;
271
272 if modularity_gain > best_modularity_gain + self.config.tolerance {
273 best_modularity_gain = modularity_gain;
274 best_community = candidate_community;
275 }
276 }
277 }
278
279 if best_community != original_community {
280 communities[node] = best_community;
281 improved = true;
282 }
283 }
284
285 iteration += 1;
286 }
287
288 Ok(self.relabel_communities(communities))
290 }
291
292 fn relabel_communities(&self, communities: Vec<usize>) -> Vec<usize> {
294 let mut unique_communities: Vec<usize> = communities.to_vec();
295 unique_communities.sort();
296 unique_communities.dedup();
297
298 let mut community_map = HashMap::new();
299 for (new_id, &old_id) in unique_communities.iter().enumerate() {
300 community_map.insert(old_id, new_id);
301 }
302
303 communities.iter().map(|&c| community_map[&c]).collect()
304 }
305}
306
/// Parameters for multi-level Louvain community detection.
#[derive(Clone, Debug)]
pub struct LouvainConfig {
    /// Resolution γ passed to the modularity computation.
    pub resolution: f64,
    /// Maximum local-move sweeps on each level's graph.
    pub max_iterations_per_level: usize,
    /// Maximum number of aggregation levels.
    pub max_levels: usize,
    /// A move must beat the best gain so far (initially 0) by more than this to be accepted.
    pub tolerance: f64,
    /// Intended RNG seed. NOTE(review): not read by the Louvain code in this file.
    pub random_seed: Option<u64>,
}

impl Default for LouvainConfig {
    /// γ = 1.0, 100 sweeps per level, 10 levels, tolerance 1e-6, no seed.
    fn default() -> Self {
        LouvainConfig {
            random_seed: None,
            tolerance: 1e-6,
            max_levels: 10,
            max_iterations_per_level: 100,
            resolution: 1.0,
        }
    }
}
333
334pub struct LouvainClustering {
336 config: LouvainConfig,
337}
338
339impl LouvainClustering {
340 pub fn new(config: LouvainConfig) -> Self {
342 Self { config }
343 }
344
345 pub fn fit(&self, graph: &Graph) -> Result<LouvainResult> {
347 if graph.n_nodes == 0 {
348 return Ok(LouvainResult {
349 communities: Vec::new(),
350 modularity: 0.0,
351 levels: 0,
352 community_hierarchy: Vec::new(),
353 });
354 }
355
356 let mut current_graph = graph.clone();
357 let mut communities: Vec<usize> = (0..graph.n_nodes).collect();
358 let mut community_hierarchy = Vec::new();
359 let mut level = 0;
360
361 let mut rng = Random::default();
362
363 while level < self.config.max_levels {
364 let level_communities = self.optimize_modularity(¤t_graph, &mut rng)?;
366
367 let n_communities = level_communities.iter().max().map(|&x| x + 1).unwrap_or(0);
369 if n_communities >= current_graph.n_nodes {
370 break; }
372
373 community_hierarchy.push(level_communities.clone());
374
375 communities = self.update_global_communities(&communities, &level_communities);
377
378 current_graph = self.aggregate_communities(¤t_graph, &level_communities)?;
380 level += 1;
381 }
382
383 let modularity_clustering = ModularityClustering::new(ModularityClusteringConfig {
385 resolution: self.config.resolution,
386 ..Default::default()
387 });
388 let final_modularity = modularity_clustering.compute_modularity(graph, &communities);
389
390 Ok(LouvainResult {
391 communities,
392 modularity: final_modularity,
393 levels: level,
394 community_hierarchy,
395 })
396 }
397
398 fn optimize_modularity(&self, graph: &Graph, rng: &mut impl Rng) -> Result<Vec<usize>> {
400 let mut communities: Vec<usize> = (0..graph.n_nodes).collect();
401 let mut improved = true;
402 let mut iteration = 0;
403
404 let modularity_clustering = ModularityClustering::new(ModularityClusteringConfig {
405 resolution: self.config.resolution,
406 ..Default::default()
407 });
408
409 while improved && iteration < self.config.max_iterations_per_level {
410 improved = false;
411 let mut node_order: Vec<usize> = (0..graph.n_nodes).collect();
412 node_order.shuffle(rng);
413
414 for &node in &node_order {
415 let original_community = communities[node];
416 let mut best_community = original_community;
417 let mut best_modularity_gain = 0.0;
418
419 let neighbors = graph.neighbors(node);
421 let mut neighboring_communities = HashSet::new();
422 for (neighbor, _) in neighbors {
423 neighboring_communities.insert(communities[neighbor]);
424 }
425
426 let max_community = communities.iter().max().cloned().unwrap_or(0);
428 neighboring_communities.insert(max_community + 1);
429
430 for &candidate_community in &neighboring_communities {
431 if candidate_community != original_community {
432 communities[node] = candidate_community;
433 let new_modularity =
434 modularity_clustering.compute_modularity(graph, &communities);
435
436 communities[node] = original_community;
437 let old_modularity =
438 modularity_clustering.compute_modularity(graph, &communities);
439
440 let modularity_gain = new_modularity - old_modularity;
441
442 if modularity_gain > best_modularity_gain + self.config.tolerance {
443 best_modularity_gain = modularity_gain;
444 best_community = candidate_community;
445 }
446 }
447 }
448
449 if best_community != original_community {
450 communities[node] = best_community;
451 improved = true;
452 }
453 }
454
455 iteration += 1;
456 }
457
458 Ok(modularity_clustering.relabel_communities(communities))
459 }
460
461 fn update_global_communities(
463 &self,
464 global_communities: &[usize],
465 level_communities: &[usize],
466 ) -> Vec<usize> {
467 let mut community_mapping = HashMap::new();
468 let mut next_global_id = 0;
469
470 for &local_community in level_communities {
471 if let std::collections::hash_map::Entry::Vacant(e) =
472 community_mapping.entry(local_community)
473 {
474 e.insert(next_global_id);
475 next_global_id += 1;
476 }
477 }
478
479 level_communities
480 .iter()
481 .map(|&c| community_mapping[&c])
482 .collect()
483 }
484
485 fn aggregate_communities(&self, graph: &Graph, communities: &[usize]) -> Result<Graph> {
487 let n_communities = communities.iter().max().map(|&x| x + 1).unwrap_or(0);
488 let mut new_adjacency = Array2::zeros((n_communities, n_communities));
489
490 for i in 0..graph.n_nodes {
492 for j in 0..graph.n_nodes {
493 let comm_i = communities[i];
494 let comm_j = communities[j];
495 new_adjacency[[comm_i, comm_j]] += graph.adjacency[[i, j]];
496 }
497 }
498
499 Graph::from_adjacency(new_adjacency, graph.directed)
500 }
501}
502
/// Parameters for label-propagation community detection.
#[derive(Clone, Debug)]
pub struct LabelPropagationConfig {
    /// Maximum number of sweeps over all nodes.
    pub max_iterations: usize,
    /// NOTE(review): not read by the label-propagation code in this file.
    pub tolerance: f64,
    /// Intended RNG seed. NOTE(review): not read by the label-propagation code in this file.
    pub random_seed: Option<u64>,
}

impl Default for LabelPropagationConfig {
    /// 100 sweeps, tolerance 1e-6, no seed.
    fn default() -> Self {
        LabelPropagationConfig {
            random_seed: None,
            tolerance: 1e-6,
            max_iterations: 100,
        }
    }
}
523
524pub struct LabelPropagationClustering {
526 config: LabelPropagationConfig,
527}
528
529impl LabelPropagationClustering {
530 pub fn new(config: LabelPropagationConfig) -> Self {
532 Self { config }
533 }
534
535 pub fn fit(&self, graph: &Graph) -> Result<Vec<usize>> {
537 if graph.n_nodes == 0 {
538 return Ok(Vec::new());
539 }
540
541 let mut labels: Vec<usize> = (0..graph.n_nodes).collect();
543 let mut new_labels = labels.clone();
544
545 let mut rng = Random::default();
546
547 for iteration in 0..self.config.max_iterations {
548 let mut changed = false;
549 let mut node_order: Vec<usize> = (0..graph.n_nodes).collect();
550 for i in (1..node_order.len()).rev() {
552 let j = rng.gen_range(0..i + 1);
553 node_order.swap(i, j);
554 }
555
556 for &node in &node_order {
557 let mut label_weights = HashMap::new();
559 let neighbors = graph.neighbors(node);
560
561 for (neighbor, weight) in neighbors {
562 if neighbor != node {
563 *label_weights.entry(labels[neighbor]).or_insert(0.0) += weight;
565 }
566 }
567
568 if !label_weights.is_empty() {
569 let mut best_labels = Vec::new();
571 let mut max_weight = 0.0;
572
573 for (&label, &weight) in &label_weights {
574 match weight.partial_cmp(&max_weight) {
575 Some(Ordering::Greater) => {
576 max_weight = weight;
577 best_labels.clear();
578 best_labels.push(label);
579 }
580 Some(Ordering::Equal) => {
581 best_labels.push(label);
582 }
583 _ => {}
584 }
585 }
586
587 if !best_labels.is_empty() {
589 let chosen_label = best_labels[rng.gen_range(0..best_labels.len())];
590 if chosen_label != labels[node] {
591 new_labels[node] = chosen_label;
592 changed = true;
593 }
594 }
595 }
596 }
597
598 labels = new_labels.clone();
600
601 if !changed {
602 break;
603 }
604 }
605
606 Ok(self.relabel_communities(labels))
608 }
609
610 fn relabel_communities(&self, communities: Vec<usize>) -> Vec<usize> {
612 let mut unique_communities: Vec<usize> = communities.to_vec();
613 unique_communities.sort();
614 unique_communities.dedup();
615
616 let mut community_map = HashMap::new();
617 for (new_id, &old_id) in unique_communities.iter().enumerate() {
618 community_map.insert(old_id, new_id);
619 }
620
621 communities.iter().map(|&c| community_map[&c]).collect()
622 }
623}
624
/// Parameters for spectral graph clustering.
#[derive(Clone, Debug)]
pub struct SpectralGraphConfig {
    /// Number of clusters to produce.
    pub n_clusters: usize,
    /// Embedding dimension; `None` means "use `n_clusters`".
    pub n_eigenvectors: Option<usize>,
    /// Laplacian variant: `"unnormalized"`, `"symmetric"` or `"random_walk"`.
    pub normalization: String,
    /// Intended RNG seed. NOTE(review): not read by the spectral code in this file.
    pub random_seed: Option<u64>,
}

impl Default for SpectralGraphConfig {
    /// Two clusters, embedding size = cluster count, symmetric normalisation, no seed.
    fn default() -> Self {
        SpectralGraphConfig {
            random_seed: None,
            normalization: String::from("symmetric"),
            n_eigenvectors: None,
            n_clusters: 2,
        }
    }
}
648
649pub struct SpectralGraphClustering {
651 config: SpectralGraphConfig,
652}
653
654impl SpectralGraphClustering {
655 pub fn new(config: SpectralGraphConfig) -> Self {
657 Self { config }
658 }
659
660 pub fn fit(&self, graph: &Graph) -> Result<Vec<usize>> {
662 if graph.n_nodes == 0 {
663 return Ok(Vec::new());
664 }
665
666 if self.config.n_clusters > graph.n_nodes {
667 return Err(SklearsError::InvalidInput(
668 "Number of clusters cannot exceed number of nodes".to_string(),
669 ));
670 }
671
672 let laplacian = self.compute_laplacian(graph)?;
674
675 let n_eigenvectors = self.config.n_eigenvectors.unwrap_or(self.config.n_clusters);
677 let eigenvectors = self.compute_eigenvectors(&laplacian, n_eigenvectors)?;
678
679 self.cluster_eigenvectors(&eigenvectors)
681 }
682
683 fn compute_laplacian(&self, graph: &Graph) -> Result<Array2<f64>> {
685 let n = graph.n_nodes;
686 let mut laplacian = Array2::zeros((n, n));
687
688 let mut degrees = vec![0.0; n];
690 for i in 0..n {
691 degrees[i] = graph.degree(i);
692 }
693
694 match self.config.normalization.as_str() {
695 "unnormalized" => {
696 for i in 0..n {
698 laplacian[[i, i]] = degrees[i];
699 for j in 0..n {
700 if i != j {
701 laplacian[[i, j]] = -graph.adjacency[[i, j]];
702 }
703 }
704 }
705 }
706 "symmetric" => {
707 for i in 0..n {
709 laplacian[[i, i]] = 1.0;
710 let sqrt_deg_i = if degrees[i] > 0.0 {
711 degrees[i].sqrt()
712 } else {
713 0.0
714 };
715
716 for j in 0..n {
717 if i != j && graph.adjacency[[i, j]] > 0.0 {
718 let sqrt_deg_j = if degrees[j] > 0.0 {
719 degrees[j].sqrt()
720 } else {
721 0.0
722 };
723 if sqrt_deg_i > 0.0 && sqrt_deg_j > 0.0 {
724 laplacian[[i, j]] =
725 -graph.adjacency[[i, j]] / (sqrt_deg_i * sqrt_deg_j);
726 }
727 }
728 }
729 }
730 }
731 "random_walk" => {
732 for i in 0..n {
734 laplacian[[i, i]] = 1.0;
735 if degrees[i] > 0.0 {
736 for j in 0..n {
737 if i != j {
738 laplacian[[i, j]] = -graph.adjacency[[i, j]] / degrees[i];
739 }
740 }
741 }
742 }
743 }
744 _ => {
745 return Err(SklearsError::InvalidInput(
746 "Invalid normalization method. Use 'unnormalized', 'symmetric', or 'random_walk'".to_string(),
747 ));
748 }
749 }
750
751 Ok(laplacian)
752 }
753
754 fn compute_eigenvectors(
756 &self,
757 laplacian: &Array2<f64>,
758 n_eigenvectors: usize,
759 ) -> Result<Array2<f64>> {
760 let n = laplacian.nrows();
764 if n_eigenvectors > n {
765 return Err(SklearsError::InvalidInput(
766 "Cannot compute more eigenvectors than matrix size".to_string(),
767 ));
768 }
769
770 let mut rng = thread_rng();
773 let mut eigenvectors = Array2::zeros((n, n_eigenvectors));
774
775 let normal = scirs2_core::random::RandNormal::new(0.0, 1.0).unwrap();
776 for i in 0..n {
777 for j in 0..n_eigenvectors {
778 eigenvectors[[i, j]] = normal.sample(&mut rng);
779 }
780 }
781
782 Ok(eigenvectors)
783 }
784
785 fn cluster_eigenvectors(&self, eigenvectors: &Array2<f64>) -> Result<Vec<usize>> {
787 let n_points = eigenvectors.nrows();
791 let n_clusters = self.config.n_clusters;
792
793 if n_clusters >= n_points {
794 return Ok((0..n_points).collect());
795 }
796
797 let mut rng = Random::default();
799
800 let mut clusters = Vec::new();
801 for _ in 0..n_points {
802 clusters.push(rng.gen_range(0..n_clusters));
803 }
804
805 Ok(clusters)
806 }
807}
808
/// Output of `LouvainClustering::fit`.
#[derive(Clone, Debug)]
pub struct LouvainResult {
    /// Community id for each node of the input graph.
    pub communities: Vec<usize>,
    /// Modularity of `communities` on the input graph, as computed by
    /// `ModularityClustering::compute_modularity`.
    pub modularity: f64,
    /// Number of aggregation levels that merged at least one community.
    pub levels: usize,
    /// Per-level assignments, each indexed by the nodes of that level's graph.
    pub community_hierarchy: Vec<Vec<usize>>,
}
821
/// Summary of a graph partition.
///
/// NOTE(review): not constructed anywhere in this file; the field meanings below are
/// inferred from their names and should be confirmed against callers.
#[derive(Clone, Debug)]
pub struct GraphClusteringResult {
    /// Community label per node (presumably in `0..n_communities`).
    pub communities: Vec<usize>,
    /// Modularity score of the partition.
    pub modularity: f64,
    /// Number of distinct communities.
    pub n_communities: usize,
    /// Member count per community (presumably indexed by community label).
    pub community_sizes: Vec<usize>,
}
834
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Two triangles {0,1,2} and {3,4,5} joined by the bridge 2-3.
    fn two_triangles() -> Graph {
        let edges = [
            (0, 1, 1.0),
            (1, 2, 1.0),
            (0, 2, 1.0),
            (3, 4, 1.0),
            (4, 5, 1.0),
            (3, 5, 1.0),
            (2, 3, 1.0),
        ];
        Graph::from_edges(&edges, 6, false).unwrap()
    }

    #[test]
    fn test_graph_creation() {
        let adjacency =
            Array2::from_shape_vec((3, 3), vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0])
                .unwrap();

        let graph = Graph::from_adjacency(adjacency, false).unwrap();
        assert_eq!(graph.n_nodes, 3);
        assert!(!graph.directed);
        assert_abs_diff_eq!(graph.degree(1), 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_graph_construction_errors() {
        let non_square = Array2::<f64>::zeros((2, 3));
        assert!(Graph::from_adjacency(non_square, false).is_err());
        assert!(Graph::from_edges(&[(0, 3, 1.0)], 3, false).is_err());

        let graph = Graph::from_edges(&[(0, 1, 1.0)], 2, false).unwrap();
        assert!(graph.with_node_weights(vec![1.0]).is_err());
    }

    #[test]
    fn test_directed_degree_and_neighbors() {
        let graph = Graph::from_edges(&[(0, 1, 2.0)], 2, true).unwrap();
        // Directed degree = out-weight + in-weight.
        assert_abs_diff_eq!(graph.degree(0), 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(graph.degree(1), 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(graph.total_weight(), 2.0, epsilon = 1e-10);
        assert_eq!(graph.neighbors(0), vec![(1, 2.0)]);
        assert!(graph.neighbors(1).is_empty());
        assert_abs_diff_eq!(graph.degree(5), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_modularity_computation() {
        let graph = two_triangles();
        let clustering = ModularityClustering::new(ModularityClusteringConfig::default());

        let natural = clustering.compute_modularity(&graph, &[0, 0, 0, 1, 1, 1]);
        let singletons = clustering.compute_modularity(&graph, &[0, 1, 2, 3, 4, 5]);
        let interleaved = clustering.compute_modularity(&graph, &[0, 1, 0, 1, 0, 1]);

        // Compare partitions rather than absolute values, so the test does not depend on how
        // the partition-independent diagonal terms are treated.
        assert!(natural > singletons);
        assert!(natural > interleaved);
    }

    #[test]
    fn test_label_propagation() {
        let adjacency = Array2::from_shape_vec(
            (4, 4),
            vec![
                0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0,
            ],
        )
        .unwrap();

        let graph = Graph::from_adjacency(adjacency, false).unwrap();
        let clustering = LabelPropagationClustering::new(LabelPropagationConfig {
            random_seed: Some(42),
            ..Default::default()
        });

        let communities = clustering.fit(&graph).unwrap();
        assert_eq!(communities.len(), 4);

        let mut unique_communities = communities.clone();
        unique_communities.sort();
        unique_communities.dedup();
        assert_eq!(
            unique_communities,
            (0..unique_communities.len()).collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_spectral_clustering_config() {
        let config = SpectralGraphConfig {
            n_clusters: 3,
            normalization: "symmetric".to_string(),
            ..Default::default()
        };

        let clustering = SpectralGraphClustering::new(config);

        let adjacency = Array2::eye(5);
        let graph = Graph::from_adjacency(adjacency, false).unwrap();

        let result = clustering.fit(&graph);
        assert!(result.is_ok());

        let communities = result.unwrap();
        assert_eq!(communities.len(), 5);
    }

    #[test]
    fn test_spectral_rejects_too_many_clusters() {
        let graph = Graph::from_adjacency(Array2::eye(2), false).unwrap();
        let clustering = SpectralGraphClustering::new(SpectralGraphConfig {
            n_clusters: 3,
            ..Default::default()
        });
        assert!(clustering.fit(&graph).is_err());
    }

    #[test]
    fn test_graph_from_edges() {
        let edges = vec![(0, 1, 1.0), (1, 2, 1.0), (2, 0, 1.0)];

        let graph = Graph::from_edges(&edges, 3, false).unwrap();
        assert_eq!(graph.n_nodes, 3);
        assert_abs_diff_eq!(graph.total_weight(), 3.0, epsilon = 1e-10);

        assert_abs_diff_eq!(
            graph.adjacency[[0, 1]],
            graph.adjacency[[1, 0]],
            epsilon = 1e-10
        );
    }
}