1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, ScalarOperand};
9use scirs2_core::numeric::{Float, FromPrimitive, Zero};
10use std::collections::{HashMap, HashSet, VecDeque};
11use std::fmt::Debug;
12
13use serde::{Deserialize, Serialize};
14
15use crate::error::{ClusteringError, Result};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct Graph<F> {
20 pub n_nodes: usize,
22 pub adjacency: Vec<Vec<(usize, F)>>,
24 pub node_features: Option<Array2<F>>,
26}
27
28impl<F: Copy + PartialOrd + Zero + 'static> Graph<F> {
31 pub fn new(_nnodes: usize) -> Self {
33 Self {
34 n_nodes: _nnodes,
35 adjacency: vec![Vec::new(); _nnodes],
36 node_features: None,
37 }
38 }
39
40 pub fn from_adjacencymatrix(_adjacencymatrix: ArrayView2<F>) -> Result<Self> {
42 let n_nodes = _adjacencymatrix.shape()[0];
43 if _adjacencymatrix.shape()[1] != n_nodes {
44 return Err(ClusteringError::InvalidInput(
45 "Adjacency _matrix must be square".to_string(),
46 ));
47 }
48
49 let mut graph = Self::new(n_nodes);
50
51 for i in 0..n_nodes {
52 for j in (i + 1)..n_nodes {
53 let weight = _adjacencymatrix[[i, j]];
54 if weight > F::zero() {
55 graph.add_edge(i, j, weight)?;
56 }
57 }
58 }
59
60 Ok(graph)
61 }
62
63 pub fn add_edge(&mut self, node1: usize, node2: usize, weight: F) -> Result<()> {
65 if node1 >= self.n_nodes || node2 >= self.n_nodes {
66 return Err(ClusteringError::InvalidInput(
67 "Node index out of bounds".to_string(),
68 ));
69 }
70
71 if node1 != node2 {
72 self.adjacency[node1].push((node2, weight));
73 self.adjacency[node2].push((node1, weight)); }
75
76 Ok(())
77 }
78
79 pub fn degree(&self, node: usize) -> usize {
81 if node < self.n_nodes {
82 self.adjacency[node].len()
83 } else {
84 0
85 }
86 }
87
88 pub fn neighbor_s(&self, node: usize) -> &[(usize, F)] {
90 if node < self.n_nodes {
91 &self.adjacency[node]
92 } else {
93 &[]
94 }
95 }
96}
97
98impl<F: Float + FromPrimitive + Debug + ScalarOperand + std::iter::Sum + 'static> Graph<F> {
100 pub fn from_knngraph(data: ArrayView2<F>, k: usize) -> Result<Self> {
102 let n_samples = data.shape()[0];
103 let mut graph = Self::new(n_samples);
104 graph.node_features = Some(data.to_owned());
105
106 for i in 0..n_samples {
108 let mut distances: Vec<(usize, F)> = Vec::new();
109
110 for j in 0..n_samples {
111 if i != j {
112 let dist = euclidean_distance(data.row(i), data.row(j));
113 distances.push((j, dist));
114 }
115 }
116
117 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("Operation failed"));
119
120 for &(neighbor_idx, distance) in distances.iter().take(k) {
121 let similarity = F::one() / (F::one() + distance);
123 graph.add_edge(i, neighbor_idx, similarity)?;
124 }
125 }
126
127 Ok(graph)
128 }
129
130 pub fn weighted_degree(&self, node: usize) -> F {
132 if node < self.n_nodes {
133 self.adjacency[node].iter().map(|(_, weight)| *weight).sum()
134 } else {
135 F::zero()
136 }
137 }
138
139 pub fn modularity(&self, communities: &[usize]) -> F {
141 let total_weight = self.total_edge_weight();
142 if total_weight == F::zero() {
143 return F::zero();
144 }
145
146 let mut modularity = F::zero();
147
148 for i in 0..self.n_nodes {
149 for j in 0..self.n_nodes {
150 if communities[i] == communities[j] {
151 let edge_weight = self.get_edge_weight(i, j);
152 let degree_i = self.weighted_degree(i);
153 let degree_j = self.weighted_degree(j);
154
155 let expected = degree_i * degree_j
156 / (F::from(2.0).expect("Failed to convert constant to float")
157 * total_weight);
158 modularity = modularity + edge_weight - expected;
159 }
160 }
161 }
162
163 modularity / (F::from(2.0).expect("Failed to convert constant to float") * total_weight)
164 }
165
166 fn get_edge_weight(&self, node1: usize, node2: usize) -> F {
168 if node1 < self.n_nodes {
169 for &(neighbor_, weight) in &self.adjacency[node1] {
170 if neighbor_ == node2 {
171 return weight;
172 }
173 }
174 }
175 F::zero()
176 }
177
178 fn total_edge_weight(&self) -> F {
180 let mut total = F::zero();
181 for node in 0..self.n_nodes {
182 for &(_, weight) in &self.adjacency[node] {
183 total = total + weight;
184 }
185 }
186 total / F::from(2.0).expect("Failed to convert constant to float") }
188}
189
190#[allow(dead_code)]
227pub fn louvain<F>(graph: &Graph<F>, resolution: f64, max_iterations: usize) -> Result<Array1<usize>>
228where
229 F: Float
230 + FromPrimitive
231 + Debug
232 + ScalarOperand
233 + std::iter::Sum
234 + std::cmp::Eq
235 + std::hash::Hash
236 + 'static,
237 f64: From<F>,
238{
239 let n_nodes = graph.n_nodes;
240 let mut communities: Array1<usize> = Array1::from_iter(0..n_nodes);
241 let mut improved = true;
242 let mut iteration = 0;
243
244 while improved && iteration < max_iterations {
245 improved = false;
246 iteration += 1;
247
248 for node in 0..n_nodes {
250 let current_community = communities[node];
251 let mut best_community = current_community;
252 let mut best_gain = F::zero();
253
254 let mut candidate_communities = HashSet::new();
256 candidate_communities.insert(current_community);
257
258 for &(neighbor_id, _weight) in graph.neighbor_s(node) {
259 candidate_communities.insert(communities[neighbor_id]);
260 }
261
262 for &candidate_community in &candidate_communities {
263 if candidate_community != current_community {
264 let gain = modularity_gain(
266 graph,
267 &communities,
268 node,
269 current_community,
270 candidate_community,
271 resolution,
272 );
273
274 if gain > best_gain {
275 best_gain = gain;
276 best_community = candidate_community;
277 }
278 }
279 }
280
281 if best_community != current_community && best_gain > F::zero() {
283 communities[node] = best_community;
284 improved = true;
285 }
286 }
287 }
288
289 Ok(communities)
290}
291
292#[allow(dead_code)]
294fn modularity_gain<F>(
295 graph: &Graph<F>,
296 communities: &Array1<usize>,
297 node: usize,
298 from_community: usize,
299 to_community: usize,
300 resolution: f64,
301) -> F
302where
303 F: Float
304 + FromPrimitive
305 + Debug
306 + ScalarOperand
307 + std::iter::Sum
308 + std::cmp::Eq
309 + std::hash::Hash
310 + 'static,
311 f64: From<F>,
312{
313 let total_weight = graph.total_edge_weight();
314 if total_weight == F::zero() {
315 return F::zero();
316 }
317
318 let node_degree = graph.weighted_degree(node);
319 let resolution_f = F::from(resolution).expect("Failed to convert to float");
320
321 let mut edges_to_target = F::zero();
323 let mut edges_from_source = F::zero();
324
325 for &(neighbor_, weight) in graph.neighbor_s(node) {
326 if communities[neighbor_] == to_community {
327 edges_to_target = edges_to_target + weight;
328 }
329 if communities[neighbor_] == from_community && neighbor_ != node {
330 edges_from_source = edges_from_source + weight;
331 }
332 }
333
334 let target_community_weight = calculate_community_weight(graph, communities, to_community);
336 let source_community_weight = calculate_community_weight(graph, communities, from_community);
337
338 let gain_to = edges_to_target
340 - resolution_f * node_degree * target_community_weight
341 / (F::from(2.0).expect("Failed to convert constant to float") * total_weight);
342 let loss_from = edges_from_source
343 - resolution_f * node_degree * (source_community_weight - node_degree)
344 / (F::from(2.0).expect("Failed to convert constant to float") * total_weight);
345
346 gain_to - loss_from
347}
348
349#[allow(dead_code)]
351fn calculate_community_weight<F>(
352 graph: &Graph<F>,
353 communities: &Array1<usize>,
354 community: usize,
355) -> F
356where
357 F: Float
358 + FromPrimitive
359 + Debug
360 + ScalarOperand
361 + std::iter::Sum
362 + std::cmp::Eq
363 + std::hash::Hash
364 + 'static,
365{
366 let mut weight = F::zero();
367 for node in 0..graph.n_nodes {
368 if communities[node] == community {
369 weight = weight + graph.weighted_degree(node);
370 }
371 }
372 weight
373}
374
375#[allow(dead_code)]
390pub fn label_propagation<F>(
391 graph: &Graph<F>,
392 max_iterations: usize,
393 tolerance: f64,
394) -> Result<Array1<usize>>
395where
396 F: Float
397 + FromPrimitive
398 + Debug
399 + ScalarOperand
400 + std::iter::Sum
401 + std::cmp::Eq
402 + std::hash::Hash
403 + 'static,
404 f64: From<F>,
405{
406 let n_nodes = graph.n_nodes;
407 let mut labels: Array1<usize> = Array1::from_iter(0..n_nodes);
408 let tolerance_f = F::from(tolerance).expect("Failed to convert to float");
409
410 for _iteration in 0..max_iterations {
411 let mut new_labels = labels.clone();
412 let mut changed_nodes = 0;
413
414 let mut node_order: Vec<usize> = (0..n_nodes).collect();
416 node_order.sort_by_key(|&i| i * 17 % n_nodes);
418
419 for &node in &node_order {
420 let mut label_weights: HashMap<usize, F> = HashMap::new();
422
423 for &(neighbor_, weight) in graph.neighbor_s(node) {
424 let label = labels[neighbor_];
425 let entry = label_weights.entry(label).or_insert(F::zero());
426 *entry = *entry + weight;
427 }
428
429 if let Some((&best_label_, _)) = label_weights
431 .iter()
432 .max_by(|a, b| a.1.partial_cmp(b.1).expect("Operation failed"))
433 {
434 if best_label_ != labels[node] {
435 new_labels[node] = best_label_;
436 changed_nodes += 1;
437 }
438 }
439 }
440
441 labels = new_labels;
442
443 let change_ratio = changed_nodes as f64 / n_nodes as f64;
445 if change_ratio < tolerance {
446 break;
447 }
448 }
449
450 let unique_labels: HashSet<usize> = labels.iter().cloned().collect();
452 let label_mapping: HashMap<usize, usize> = unique_labels
453 .into_iter()
454 .enumerate()
455 .map(|(new_label, old_label)| (old_label, new_label))
456 .collect();
457
458 for label in labels.iter_mut() {
459 *label = label_mapping[label];
460 }
461
462 Ok(labels)
463}
464
465#[allow(dead_code)]
480pub fn girvan_newman<F>(graph: &Graph<F>, ncommunities: usize) -> Result<Array1<usize>>
481where
482 F: Float
483 + FromPrimitive
484 + Debug
485 + ScalarOperand
486 + std::iter::Sum
487 + std::cmp::Eq
488 + std::hash::Hash
489 + 'static,
490{
491 if ncommunities > graph.n_nodes {
492 return Err(ClusteringError::InvalidInput(
493 "Number of _communities cannot exceed number of nodes".to_string(),
494 ));
495 }
496
497 let mut workinggraph = graph.clone();
498 let mut _communities = find_connected_components(&workinggraph);
499
500 while count_communities(&_communities) < ncommunities && has_edges(&workinggraph) {
501 let edge_betweenness = calculate_edge_betweenness(&workinggraph)?;
503
504 if let Some((max_edge_, _)) = edge_betweenness
506 .iter()
507 .max_by(|a, b| a.1.partial_cmp(b.1).expect("Operation failed"))
508 {
509 remove_edge(&mut workinggraph, max_edge_.0, max_edge_.1);
511
512 _communities = find_connected_components(&workinggraph);
514 } else {
515 break; }
517 }
518
519 Ok(Array1::from_vec(_communities))
520}
521
522#[allow(dead_code)]
524fn calculate_edge_betweenness<F>(graph: &Graph<F>) -> Result<HashMap<(usize, usize), f64>>
525where
526 F: Float
527 + FromPrimitive
528 + Debug
529 + ScalarOperand
530 + std::iter::Sum
531 + std::cmp::Eq
532 + std::hash::Hash
533 + 'static,
534{
535 let mut edge_betweenness = HashMap::new();
536
537 for node in 0..graph.n_nodes {
539 for &(neighbor_, _) in graph.neighbor_s(node) {
540 if node < neighbor_ {
541 edge_betweenness.insert((node, neighbor_), 0.0);
543 }
544 }
545 }
546
547 for source in 0..graph.n_nodes {
549 for target in (source + 1)..graph.n_nodes {
550 let paths = find_all_shortest_paths(graph, source, target);
551
552 if !paths.is_empty() {
553 let contribution = 1.0 / paths.len() as f64;
554
555 for path in paths {
556 for i in 0..(path.len() - 1) {
557 let (u, v) = if path[i] < path[i + 1] {
558 (path[i], path[i + 1])
559 } else {
560 (path[i + 1], path[i])
561 };
562
563 *edge_betweenness.entry((u, v)).or_insert(0.0) += contribution;
564 }
565 }
566 }
567 }
568 }
569
570 Ok(edge_betweenness)
571}
572
573#[allow(dead_code)]
575fn find_all_shortest_paths<F>(graph: &Graph<F>, source: usize, target: usize) -> Vec<Vec<usize>>
576where
577 F: Float
578 + FromPrimitive
579 + Debug
580 + ScalarOperand
581 + std::iter::Sum
582 + std::cmp::Eq
583 + std::hash::Hash
584 + 'static,
585{
586 let mut distances = vec![None; graph.n_nodes];
587 let mut predecessors: Vec<Vec<usize>> = vec![Vec::new(); graph.n_nodes];
588 let mut queue = VecDeque::new();
589
590 distances[source] = Some(0);
591 queue.push_back(source);
592
593 while let Some(current) = queue.pop_front() {
594 let current_dist = distances[current].expect("Operation failed");
595
596 for &(neighbor_, _) in graph.neighbor_s(current) {
597 if distances[neighbor_].is_none() {
598 distances[neighbor_] = Some(current_dist + 1);
600 predecessors[neighbor_].push(current);
601 queue.push_back(neighbor_);
602 } else if distances[neighbor_] == Some(current_dist + 1) {
603 predecessors[neighbor_].push(current);
605 }
606 }
607 }
608
609 if distances[target].is_none() {
611 return Vec::new(); }
613
614 let mut paths = Vec::new();
615 let mut current_paths = vec![vec![target]];
616
617 while !current_paths.is_empty() {
618 let mut next_paths = Vec::new();
619
620 for path in current_paths {
621 let last_node = path[path.len() - 1];
622
623 if last_node == source {
624 let mut complete_path = path.clone();
625 complete_path.reverse();
626 paths.push(complete_path);
627 } else {
628 for &pred in &predecessors[last_node] {
629 let mut new_path = path.clone();
630 new_path.push(pred);
631 next_paths.push(new_path);
632 }
633 }
634 }
635
636 current_paths = next_paths;
637 }
638
639 paths
640}
641
642#[allow(dead_code)]
644fn remove_edge<F>(graph: &mut Graph<F>, node1: usize, node2: usize)
645where
646 F: Float
647 + FromPrimitive
648 + Debug
649 + ScalarOperand
650 + std::iter::Sum
651 + std::cmp::Eq
652 + std::hash::Hash
653 + 'static,
654{
655 graph.adjacency[node1].retain(|(neighbor_, _)| *neighbor_ != node2);
656 graph.adjacency[node2].retain(|(neighbor_, _)| *neighbor_ != node1);
657}
658
659#[allow(dead_code)]
661fn has_edges<F>(graph: &Graph<F>) -> bool
662where
663 F: Float
664 + FromPrimitive
665 + Debug
666 + ScalarOperand
667 + std::iter::Sum
668 + std::cmp::Eq
669 + std::hash::Hash
670 + 'static,
671{
672 graph
673 .adjacency
674 .iter()
675 .any(|neighbor_s| !neighbor_s.is_empty())
676}
677
678#[allow(dead_code)]
680fn find_connected_components<F>(graph: &Graph<F>) -> Vec<usize>
681where
682 F: Float
683 + FromPrimitive
684 + Debug
685 + ScalarOperand
686 + std::iter::Sum
687 + std::cmp::Eq
688 + std::hash::Hash
689 + 'static,
690{
691 let mut visited = vec![false; graph.n_nodes];
692 let mut components = vec![0; graph.n_nodes];
693 let mut component_id = 0;
694
695 for node in 0..graph.n_nodes {
696 if !visited[node] {
697 dfs_component(graph, node, component_id, &mut visited, &mut components);
698 component_id += 1;
699 }
700 }
701
702 components
703}
704
705#[allow(dead_code)]
707fn dfs_component<F>(
708 graph: &Graph<F>,
709 node: usize,
710 component_id: usize,
711 visited: &mut [bool],
712 components: &mut [usize],
713) where
714 F: Float
715 + FromPrimitive
716 + Debug
717 + ScalarOperand
718 + std::iter::Sum
719 + std::cmp::Eq
720 + std::hash::Hash
721 + 'static,
722{
723 visited[node] = true;
724 components[node] = component_id;
725
726 for &(neighbor_, _) in graph.neighbor_s(node) {
727 if !visited[neighbor_] {
728 dfs_component(graph, neighbor_, component_id, visited, components);
729 }
730 }
731}
732
/// Returns how many distinct labels occur in `communities`.
#[allow(dead_code)]
fn count_communities(communities: &[usize]) -> usize {
    communities
        .iter()
        .copied()
        .collect::<HashSet<usize>>()
        .len()
}
742
743#[allow(dead_code)]
745fn euclidean_distance<F>(a: ArrayView1<F>, b: ArrayView1<F>) -> F
746where
747 F: Float + std::iter::Sum + 'static,
748{
749 let diff = &a.to_owned() - &b.to_owned();
750 diff.dot(&diff).sqrt()
751}
752
753#[derive(Debug, Clone, Serialize, Deserialize)]
755pub struct GraphClusteringConfig {
756 pub algorithm: GraphClusteringAlgorithm,
758 pub max_iterations: usize,
760 pub tolerance: f64,
762 pub resolution: f64,
764 pub ncommunities: Option<usize>,
766}
767
768#[derive(Debug, Clone, Serialize, Deserialize)]
770pub enum GraphClusteringAlgorithm {
771 Louvain,
773 LabelPropagation,
775 GirvanNewman,
777}
778
779impl Default for GraphClusteringConfig {
780 fn default() -> Self {
781 Self {
782 algorithm: GraphClusteringAlgorithm::Louvain,
783 max_iterations: 100,
784 tolerance: 1e-6,
785 resolution: 1.0,
786 ncommunities: None,
787 }
788 }
789}
790
791#[allow(dead_code)]
802pub fn graph_clustering<F>(
803 graph: &Graph<F>,
804 config: &GraphClusteringConfig,
805) -> Result<Array1<usize>>
806where
807 F: Float
808 + FromPrimitive
809 + Debug
810 + ScalarOperand
811 + std::iter::Sum
812 + std::cmp::Eq
813 + std::hash::Hash
814 + 'static,
815 f64: From<F>,
816{
817 match config.algorithm {
818 GraphClusteringAlgorithm::Louvain => {
819 louvain(graph, config.resolution, config.max_iterations)
820 }
821 GraphClusteringAlgorithm::LabelPropagation => {
822 label_propagation(graph, config.max_iterations, config.tolerance)
823 }
824 GraphClusteringAlgorithm::GirvanNewman => {
825 let ncommunities = config.ncommunities.unwrap_or(2);
826 girvan_newman(graph, ncommunities)
827 }
828 }
829}
830
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// A new graph reports the requested node count and owns one adjacency
    /// list per node.
    #[test]
    fn testgraph_creation() {
        let node_count = 5;
        let graph = Graph::<i32>::new(node_count);

        assert_eq!(graph.n_nodes, node_count);
        assert_eq!(graph.adjacency.len(), node_count);
    }

    /// The path graph 0 - 1 - 2 yields degrees 1, 2 and 1.
    #[test]
    fn testgraph_from_adjacencymatrix() {
        let cells = vec![0, 1, 0, 1, 0, 1, 0, 1, 0];
        let adjacency = Array2::from_shape_vec((3, 3), cells).expect("Operation failed");

        let graph = Graph::from_adjacencymatrix(adjacency.view()).expect("Operation failed");

        assert_eq!(graph.n_nodes, 3);
        for (node, expected_degree) in [1usize, 2, 1].iter().enumerate() {
            assert_eq!(graph.degree(node), *expected_degree);
        }
    }
}