1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, ScalarOperand};
9use scirs2_core::numeric::{Float, FromPrimitive};
10use std::collections::{HashMap, HashSet, VecDeque};
11use std::fmt::Debug;
12
13use serde::{Deserialize, Serialize};
14
15use crate::error::{ClusteringError, Result};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct Graph<F: Float> {
20 pub n_nodes: usize,
22 pub adjacency: Vec<Vec<(usize, F)>>,
24 pub node_features: Option<Array2<F>>,
26}
27
28impl<
29 F: Float
30 + FromPrimitive
31 + Debug
32 + ScalarOperand
33 + std::iter::Sum
34 + std::cmp::Eq
35 + std::hash::Hash
36 + 'static,
37 > Graph<F>
38{
39 pub fn new(_nnodes: usize) -> Self {
41 Self {
42 n_nodes: _nnodes,
43 adjacency: vec![Vec::new(); _nnodes],
44 node_features: None,
45 }
46 }
47
48 pub fn from_adjacencymatrix(_adjacencymatrix: ArrayView2<F>) -> Result<Self> {
50 let n_nodes = _adjacencymatrix.shape()[0];
51 if _adjacencymatrix.shape()[1] != n_nodes {
52 return Err(ClusteringError::InvalidInput(
53 "Adjacency _matrix must be square".to_string(),
54 ));
55 }
56
57 let mut graph = Self::new(n_nodes);
58
59 for i in 0..n_nodes {
60 for j in 0..n_nodes {
61 let weight = _adjacencymatrix[[i, j]];
62 if weight > F::zero() && i != j {
63 graph.add_edge(i, j, weight)?;
64 }
65 }
66 }
67
68 Ok(graph)
69 }
70
71 pub fn from_knngraph(data: ArrayView2<F>, k: usize) -> Result<Self> {
73 let n_samples = data.shape()[0];
74 let mut graph = Self::new(n_samples);
75 graph.node_features = Some(data.to_owned());
76
77 for i in 0..n_samples {
79 let mut distances: Vec<(usize, F)> = Vec::new();
80
81 for j in 0..n_samples {
82 if i != j {
83 let dist = euclidean_distance(data.row(i), data.row(j));
84 distances.push((j, dist));
85 }
86 }
87
88 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
90
91 for &(neighbor_idx, distance) in distances.iter().take(k) {
92 let similarity = F::one() / (F::one() + distance);
94 graph.add_edge(i, neighbor_idx, similarity)?;
95 }
96 }
97
98 Ok(graph)
99 }
100
101 pub fn add_edge(&mut self, node1: usize, node2: usize, weight: F) -> Result<()> {
103 if node1 >= self.n_nodes || node2 >= self.n_nodes {
104 return Err(ClusteringError::InvalidInput(
105 "Node index out of bounds".to_string(),
106 ));
107 }
108
109 if node1 != node2 {
110 self.adjacency[node1].push((node2, weight));
111 self.adjacency[node2].push((node1, weight)); }
113
114 Ok(())
115 }
116
117 pub fn degree(&self, node: usize) -> usize {
119 if node < self.n_nodes {
120 self.adjacency[node].len()
121 } else {
122 0
123 }
124 }
125
126 pub fn weighted_degree(&self, node: usize) -> F {
128 if node < self.n_nodes {
129 self.adjacency[node].iter().map(|(_, weight)| *weight).sum()
130 } else {
131 F::zero()
132 }
133 }
134
135 pub fn neighbor_s(&self, node: usize) -> &[(usize, F)] {
137 if node < self.n_nodes {
138 &self.adjacency[node]
139 } else {
140 &[]
141 }
142 }
143
144 pub fn modularity(&self, communities: &[usize]) -> F {
146 let total_weight = self.total_edge_weight();
147 if total_weight == F::zero() {
148 return F::zero();
149 }
150
151 let mut modularity = F::zero();
152
153 for i in 0..self.n_nodes {
154 for j in 0..self.n_nodes {
155 if communities[i] == communities[j] {
156 let edge_weight = self.get_edge_weight(i, j);
157 let degree_i = self.weighted_degree(i);
158 let degree_j = self.weighted_degree(j);
159
160 let expected = degree_i * degree_j / (F::from(2.0).unwrap() * total_weight);
161 modularity = modularity + edge_weight - expected;
162 }
163 }
164 }
165
166 modularity / (F::from(2.0).unwrap() * total_weight)
167 }
168
169 fn get_edge_weight(&self, node1: usize, node2: usize) -> F {
171 if node1 < self.n_nodes {
172 for &(neighbor_, weight) in &self.adjacency[node1] {
173 if neighbor_ == node2 {
174 return weight;
175 }
176 }
177 }
178 F::zero()
179 }
180
181 fn total_edge_weight(&self) -> F {
183 let mut total = F::zero();
184 for node in 0..self.n_nodes {
185 for &(_, weight) in &self.adjacency[node] {
186 total = total + weight;
187 }
188 }
189 total / F::from(2.0).unwrap() }
191}
192
193#[allow(dead_code)]
230pub fn louvain<F>(graph: &Graph<F>, resolution: f64, max_iterations: usize) -> Result<Array1<usize>>
231where
232 F: Float
233 + FromPrimitive
234 + Debug
235 + ScalarOperand
236 + std::iter::Sum
237 + std::cmp::Eq
238 + std::hash::Hash
239 + 'static,
240 f64: From<F>,
241{
242 let n_nodes = graph.n_nodes;
243 let mut communities: Array1<usize> = Array1::from_iter(0..n_nodes);
244 let mut improved = true;
245 let mut iteration = 0;
246
247 while improved && iteration < max_iterations {
248 improved = false;
249 iteration += 1;
250
251 for node in 0..n_nodes {
253 let current_community = communities[node];
254 let mut best_community = current_community;
255 let mut best_gain = F::zero();
256
257 let mut candidate_communities = HashSet::new();
259 candidate_communities.insert(current_community);
260
261 for &(neighbor_id, _weight) in graph.neighbor_s(node) {
262 candidate_communities.insert(communities[neighbor_id]);
263 }
264
265 for &candidate_community in &candidate_communities {
266 if candidate_community != current_community {
267 let gain = modularity_gain(
269 graph,
270 &communities,
271 node,
272 current_community,
273 candidate_community,
274 resolution,
275 );
276
277 if gain > best_gain {
278 best_gain = gain;
279 best_community = candidate_community;
280 }
281 }
282 }
283
284 if best_community != current_community && best_gain > F::zero() {
286 communities[node] = best_community;
287 improved = true;
288 }
289 }
290 }
291
292 Ok(communities)
293}
294
295#[allow(dead_code)]
297fn modularity_gain<F>(
298 graph: &Graph<F>,
299 communities: &Array1<usize>,
300 node: usize,
301 from_community: usize,
302 to_community: usize,
303 resolution: f64,
304) -> F
305where
306 F: Float
307 + FromPrimitive
308 + Debug
309 + ScalarOperand
310 + std::iter::Sum
311 + std::cmp::Eq
312 + std::hash::Hash
313 + 'static,
314 f64: From<F>,
315{
316 let total_weight = graph.total_edge_weight();
317 if total_weight == F::zero() {
318 return F::zero();
319 }
320
321 let node_degree = graph.weighted_degree(node);
322 let resolution_f = F::from(resolution).unwrap();
323
324 let mut edges_to_target = F::zero();
326 let mut edges_from_source = F::zero();
327
328 for &(neighbor_, weight) in graph.neighbor_s(node) {
329 if communities[neighbor_] == to_community {
330 edges_to_target = edges_to_target + weight;
331 }
332 if communities[neighbor_] == from_community && neighbor_ != node {
333 edges_from_source = edges_from_source + weight;
334 }
335 }
336
337 let target_community_weight = calculate_community_weight(graph, communities, to_community);
339 let source_community_weight = calculate_community_weight(graph, communities, from_community);
340
341 let gain_to = edges_to_target
343 - resolution_f * node_degree * target_community_weight
344 / (F::from(2.0).unwrap() * total_weight);
345 let loss_from = edges_from_source
346 - resolution_f * node_degree * (source_community_weight - node_degree)
347 / (F::from(2.0).unwrap() * total_weight);
348
349 gain_to - loss_from
350}
351
352#[allow(dead_code)]
354fn calculate_community_weight<F>(
355 graph: &Graph<F>,
356 communities: &Array1<usize>,
357 community: usize,
358) -> F
359where
360 F: Float
361 + FromPrimitive
362 + Debug
363 + ScalarOperand
364 + std::iter::Sum
365 + std::cmp::Eq
366 + std::hash::Hash
367 + 'static,
368{
369 let mut weight = F::zero();
370 for node in 0..graph.n_nodes {
371 if communities[node] == community {
372 weight = weight + graph.weighted_degree(node);
373 }
374 }
375 weight
376}
377
378#[allow(dead_code)]
393pub fn label_propagation<F>(
394 graph: &Graph<F>,
395 max_iterations: usize,
396 tolerance: f64,
397) -> Result<Array1<usize>>
398where
399 F: Float
400 + FromPrimitive
401 + Debug
402 + ScalarOperand
403 + std::iter::Sum
404 + std::cmp::Eq
405 + std::hash::Hash
406 + 'static,
407 f64: From<F>,
408{
409 let n_nodes = graph.n_nodes;
410 let mut labels: Array1<usize> = Array1::from_iter(0..n_nodes);
411 let tolerance_f = F::from(tolerance).unwrap();
412
413 for _iteration in 0..max_iterations {
414 let mut new_labels = labels.clone();
415 let mut changed_nodes = 0;
416
417 let mut node_order: Vec<usize> = (0..n_nodes).collect();
419 node_order.sort_by_key(|&i| i * 17 % n_nodes);
421
422 for &node in &node_order {
423 let mut label_weights: HashMap<usize, F> = HashMap::new();
425
426 for &(neighbor_, weight) in graph.neighbor_s(node) {
427 let label = labels[neighbor_];
428 let entry = label_weights.entry(label).or_insert(F::zero());
429 *entry = *entry + weight;
430 }
431
432 if let Some((&best_label_, _)) = label_weights
434 .iter()
435 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
436 {
437 if best_label_ != labels[node] {
438 new_labels[node] = best_label_;
439 changed_nodes += 1;
440 }
441 }
442 }
443
444 labels = new_labels;
445
446 let change_ratio = changed_nodes as f64 / n_nodes as f64;
448 if change_ratio < tolerance {
449 break;
450 }
451 }
452
453 let unique_labels: HashSet<usize> = labels.iter().cloned().collect();
455 let label_mapping: HashMap<usize, usize> = unique_labels
456 .into_iter()
457 .enumerate()
458 .map(|(new_label, old_label)| (old_label, new_label))
459 .collect();
460
461 for label in labels.iter_mut() {
462 *label = label_mapping[label];
463 }
464
465 Ok(labels)
466}
467
468#[allow(dead_code)]
483pub fn girvan_newman<F>(graph: &Graph<F>, ncommunities: usize) -> Result<Array1<usize>>
484where
485 F: Float
486 + FromPrimitive
487 + Debug
488 + ScalarOperand
489 + std::iter::Sum
490 + std::cmp::Eq
491 + std::hash::Hash
492 + 'static,
493{
494 if ncommunities > graph.n_nodes {
495 return Err(ClusteringError::InvalidInput(
496 "Number of _communities cannot exceed number of nodes".to_string(),
497 ));
498 }
499
500 let mut workinggraph = graph.clone();
501 let mut _communities = find_connected_components(&workinggraph);
502
503 while count_communities(&_communities) < ncommunities && has_edges(&workinggraph) {
504 let edge_betweenness = calculate_edge_betweenness(&workinggraph)?;
506
507 if let Some((max_edge_, _)) = edge_betweenness
509 .iter()
510 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
511 {
512 remove_edge(&mut workinggraph, max_edge_.0, max_edge_.1);
514
515 _communities = find_connected_components(&workinggraph);
517 } else {
518 break; }
520 }
521
522 Ok(Array1::from_vec(_communities))
523}
524
525#[allow(dead_code)]
527fn calculate_edge_betweenness<F>(graph: &Graph<F>) -> Result<HashMap<(usize, usize), f64>>
528where
529 F: Float
530 + FromPrimitive
531 + Debug
532 + ScalarOperand
533 + std::iter::Sum
534 + std::cmp::Eq
535 + std::hash::Hash
536 + 'static,
537{
538 let mut edge_betweenness = HashMap::new();
539
540 for node in 0..graph.n_nodes {
542 for &(neighbor_, _) in graph.neighbor_s(node) {
543 if node < neighbor_ {
544 edge_betweenness.insert((node, neighbor_), 0.0);
546 }
547 }
548 }
549
550 for source in 0..graph.n_nodes {
552 for target in (source + 1)..graph.n_nodes {
553 let paths = find_all_shortest_paths(graph, source, target);
554
555 if !paths.is_empty() {
556 let contribution = 1.0 / paths.len() as f64;
557
558 for path in paths {
559 for i in 0..(path.len() - 1) {
560 let (u, v) = if path[i] < path[i + 1] {
561 (path[i], path[i + 1])
562 } else {
563 (path[i + 1], path[i])
564 };
565
566 *edge_betweenness.entry((u, v)).or_insert(0.0) += contribution;
567 }
568 }
569 }
570 }
571 }
572
573 Ok(edge_betweenness)
574}
575
576#[allow(dead_code)]
578fn find_all_shortest_paths<F>(graph: &Graph<F>, source: usize, target: usize) -> Vec<Vec<usize>>
579where
580 F: Float
581 + FromPrimitive
582 + Debug
583 + ScalarOperand
584 + std::iter::Sum
585 + std::cmp::Eq
586 + std::hash::Hash
587 + 'static,
588{
589 let mut distances = vec![None; graph.n_nodes];
590 let mut predecessors: Vec<Vec<usize>> = vec![Vec::new(); graph.n_nodes];
591 let mut queue = VecDeque::new();
592
593 distances[source] = Some(0);
594 queue.push_back(source);
595
596 while let Some(current) = queue.pop_front() {
597 let current_dist = distances[current].unwrap();
598
599 for &(neighbor_, _) in graph.neighbor_s(current) {
600 if distances[neighbor_].is_none() {
601 distances[neighbor_] = Some(current_dist + 1);
603 predecessors[neighbor_].push(current);
604 queue.push_back(neighbor_);
605 } else if distances[neighbor_] == Some(current_dist + 1) {
606 predecessors[neighbor_].push(current);
608 }
609 }
610 }
611
612 if distances[target].is_none() {
614 return Vec::new(); }
616
617 let mut paths = Vec::new();
618 let mut current_paths = vec![vec![target]];
619
620 while !current_paths.is_empty() {
621 let mut next_paths = Vec::new();
622
623 for path in current_paths {
624 let last_node = path[path.len() - 1];
625
626 if last_node == source {
627 let mut complete_path = path.clone();
628 complete_path.reverse();
629 paths.push(complete_path);
630 } else {
631 for &pred in &predecessors[last_node] {
632 let mut new_path = path.clone();
633 new_path.push(pred);
634 next_paths.push(new_path);
635 }
636 }
637 }
638
639 current_paths = next_paths;
640 }
641
642 paths
643}
644
645#[allow(dead_code)]
647fn remove_edge<F>(graph: &mut Graph<F>, node1: usize, node2: usize)
648where
649 F: Float
650 + FromPrimitive
651 + Debug
652 + ScalarOperand
653 + std::iter::Sum
654 + std::cmp::Eq
655 + std::hash::Hash
656 + 'static,
657{
658 graph.adjacency[node1].retain(|(neighbor_, _)| *neighbor_ != node2);
659 graph.adjacency[node2].retain(|(neighbor_, _)| *neighbor_ != node1);
660}
661
662#[allow(dead_code)]
664fn has_edges<F>(graph: &Graph<F>) -> bool
665where
666 F: Float
667 + FromPrimitive
668 + Debug
669 + ScalarOperand
670 + std::iter::Sum
671 + std::cmp::Eq
672 + std::hash::Hash
673 + 'static,
674{
675 graph
676 .adjacency
677 .iter()
678 .any(|neighbor_s| !neighbor_s.is_empty())
679}
680
681#[allow(dead_code)]
683fn find_connected_components<F>(graph: &Graph<F>) -> Vec<usize>
684where
685 F: Float
686 + FromPrimitive
687 + Debug
688 + ScalarOperand
689 + std::iter::Sum
690 + std::cmp::Eq
691 + std::hash::Hash
692 + 'static,
693{
694 let mut visited = vec![false; graph.n_nodes];
695 let mut components = vec![0; graph.n_nodes];
696 let mut component_id = 0;
697
698 for node in 0..graph.n_nodes {
699 if !visited[node] {
700 dfs_component(graph, node, component_id, &mut visited, &mut components);
701 component_id += 1;
702 }
703 }
704
705 components
706}
707
708#[allow(dead_code)]
710fn dfs_component<F>(
711 graph: &Graph<F>,
712 node: usize,
713 component_id: usize,
714 visited: &mut [bool],
715 components: &mut [usize],
716) where
717 F: Float
718 + FromPrimitive
719 + Debug
720 + ScalarOperand
721 + std::iter::Sum
722 + std::cmp::Eq
723 + std::hash::Hash
724 + 'static,
725{
726 visited[node] = true;
727 components[node] = component_id;
728
729 for &(neighbor_, _) in graph.neighbor_s(node) {
730 if !visited[neighbor_] {
731 dfs_component(graph, neighbor_, component_id, visited, components);
732 }
733 }
734}
735
#[allow(dead_code)]
/// Number of distinct labels in `communities`.
fn count_communities(communities: &[usize]) -> usize {
    let distinct: HashSet<usize> = communities.iter().copied().collect();
    distinct.len()
}
745
746#[allow(dead_code)]
748fn euclidean_distance<F>(a: ArrayView1<F>, b: ArrayView1<F>) -> F
749where
750 F: Float + std::iter::Sum + 'static,
751{
752 let diff = &a.to_owned() - &b.to_owned();
753 diff.dot(&diff).sqrt()
754}
755
756#[derive(Debug, Clone, Serialize, Deserialize)]
758pub struct GraphClusteringConfig {
759 pub algorithm: GraphClusteringAlgorithm,
761 pub max_iterations: usize,
763 pub tolerance: f64,
765 pub resolution: f64,
767 pub ncommunities: Option<usize>,
769}
770
771#[derive(Debug, Clone, Serialize, Deserialize)]
773pub enum GraphClusteringAlgorithm {
774 Louvain,
776 LabelPropagation,
778 GirvanNewman,
780}
781
782impl Default for GraphClusteringConfig {
783 fn default() -> Self {
784 Self {
785 algorithm: GraphClusteringAlgorithm::Louvain,
786 max_iterations: 100,
787 tolerance: 1e-6,
788 resolution: 1.0,
789 ncommunities: None,
790 }
791 }
792}
793
794#[allow(dead_code)]
805pub fn graph_clustering<F>(
806 graph: &Graph<F>,
807 config: &GraphClusteringConfig,
808) -> Result<Array1<usize>>
809where
810 F: Float
811 + FromPrimitive
812 + Debug
813 + ScalarOperand
814 + std::iter::Sum
815 + std::cmp::Eq
816 + std::hash::Hash
817 + 'static,
818 f64: From<F>,
819{
820 match config.algorithm {
821 GraphClusteringAlgorithm::Louvain => {
822 louvain(graph, config.resolution, config.max_iterations)
823 }
824 GraphClusteringAlgorithm::LabelPropagation => {
825 label_propagation(graph, config.max_iterations, config.tolerance)
826 }
827 GraphClusteringAlgorithm::GirvanNewman => {
828 let ncommunities = config.ncommunities.unwrap_or(2);
829 girvan_newman(graph, ncommunities)
830 }
831 }
832}
833
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn count_communities_counts_distinct_labels() {
        assert_eq!(count_communities(&[]), 0);
        assert_eq!(count_communities(&[4, 4, 1, 0, 1]), 3);
    }

    #[test]
    fn default_config_selects_louvain() {
        let config = GraphClusteringConfig::default();
        assert!(matches!(config.algorithm, GraphClusteringAlgorithm::Louvain));
        assert_eq!(config.max_iterations, 100);
        assert_eq!(config.tolerance, 1e-6);
        assert_eq!(config.resolution, 1.0);
        assert_eq!(config.ncommunities, None);
    }
}