Skip to main content

scirs2_cluster/
graph.rs

1//! Graph clustering and community detection algorithms
2//!
3//! This module provides implementations of various graph clustering algorithms for
4//! detecting communities and clusters in network data. These algorithms work with
5//! graph representations where nodes represent data points and edges represent
6//! similarities or connections between them.
7
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, ScalarOperand};
9use scirs2_core::numeric::{Float, FromPrimitive, Zero};
10use std::collections::{HashMap, HashSet, VecDeque};
11use std::fmt::Debug;
12
13use serde::{Deserialize, Serialize};
14
15use crate::error::{ClusteringError, Result};
16
17/// Graph representation for clustering algorithms
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct Graph<F> {
20    /// Number of nodes in the graph
21    pub n_nodes: usize,
22    /// Adjacency list representation: node_id -> [(neighbor_id, weight), ...]
23    pub adjacency: Vec<Vec<(usize, F)>>,
24    /// Optional node labels/features
25    pub node_features: Option<Array2<F>>,
26}
27
28/// Basic construction and traversal methods — no Float arithmetic required.
29/// Works with any `F` that is `Copy + PartialOrd + Zero + 'static`.
30impl<F: Copy + PartialOrd + Zero + 'static> Graph<F> {
31    /// Create a new empty graph with specified number of nodes
32    pub fn new(_nnodes: usize) -> Self {
33        Self {
34            n_nodes: _nnodes,
35            adjacency: vec![Vec::new(); _nnodes],
36            node_features: None,
37        }
38    }
39
40    /// Create a graph from an adjacency matrix
41    pub fn from_adjacencymatrix(_adjacencymatrix: ArrayView2<F>) -> Result<Self> {
42        let n_nodes = _adjacencymatrix.shape()[0];
43        if _adjacencymatrix.shape()[1] != n_nodes {
44            return Err(ClusteringError::InvalidInput(
45                "Adjacency _matrix must be square".to_string(),
46            ));
47        }
48
49        let mut graph = Self::new(n_nodes);
50
51        for i in 0..n_nodes {
52            for j in (i + 1)..n_nodes {
53                let weight = _adjacencymatrix[[i, j]];
54                if weight > F::zero() {
55                    graph.add_edge(i, j, weight)?;
56                }
57            }
58        }
59
60        Ok(graph)
61    }
62
63    /// Add an edge between two nodes
64    pub fn add_edge(&mut self, node1: usize, node2: usize, weight: F) -> Result<()> {
65        if node1 >= self.n_nodes || node2 >= self.n_nodes {
66            return Err(ClusteringError::InvalidInput(
67                "Node index out of bounds".to_string(),
68            ));
69        }
70
71        if node1 != node2 {
72            self.adjacency[node1].push((node2, weight));
73            self.adjacency[node2].push((node1, weight)); // Undirected graph
74        }
75
76        Ok(())
77    }
78
79    /// Get the degree of a node (number of neighbor_s)
80    pub fn degree(&self, node: usize) -> usize {
81        if node < self.n_nodes {
82            self.adjacency[node].len()
83        } else {
84            0
85        }
86    }
87
88    /// Get all neighbor_s of a node
89    pub fn neighbor_s(&self, node: usize) -> &[(usize, F)] {
90        if node < self.n_nodes {
91            &self.adjacency[node]
92        } else {
93            &[]
94        }
95    }
96}
97
98/// Float-arithmetic methods — modularity calculations, KNN graph construction, etc.
99impl<F: Float + FromPrimitive + Debug + ScalarOperand + std::iter::Sum + 'static> Graph<F> {
100    /// Create a k-nearest neighbor_ graph from data points
101    pub fn from_knngraph(data: ArrayView2<F>, k: usize) -> Result<Self> {
102        let n_samples = data.shape()[0];
103        let mut graph = Self::new(n_samples);
104        graph.node_features = Some(data.to_owned());
105
106        // For each point, find k nearest neighbor_s
107        for i in 0..n_samples {
108            let mut distances: Vec<(usize, F)> = Vec::new();
109
110            for j in 0..n_samples {
111                if i != j {
112                    let dist = euclidean_distance(data.row(i), data.row(j));
113                    distances.push((j, dist));
114                }
115            }
116
117            // Sort by distance and take k nearest
118            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("Operation failed"));
119
120            for &(neighbor_idx, distance) in distances.iter().take(k) {
121                // Use similarity (inverse of distance) as edge weight
122                let similarity = F::one() / (F::one() + distance);
123                graph.add_edge(i, neighbor_idx, similarity)?;
124            }
125        }
126
127        Ok(graph)
128    }
129
130    /// Get the weighted degree of a node (sum of edge weights)
131    pub fn weighted_degree(&self, node: usize) -> F {
132        if node < self.n_nodes {
133            self.adjacency[node].iter().map(|(_, weight)| *weight).sum()
134        } else {
135            F::zero()
136        }
137    }
138
139    /// Calculate modularity of a given community assignment
140    pub fn modularity(&self, communities: &[usize]) -> F {
141        let total_weight = self.total_edge_weight();
142        if total_weight == F::zero() {
143            return F::zero();
144        }
145
146        let mut modularity = F::zero();
147
148        for i in 0..self.n_nodes {
149            for j in 0..self.n_nodes {
150                if communities[i] == communities[j] {
151                    let edge_weight = self.get_edge_weight(i, j);
152                    let degree_i = self.weighted_degree(i);
153                    let degree_j = self.weighted_degree(j);
154
155                    let expected = degree_i * degree_j
156                        / (F::from(2.0).expect("Failed to convert constant to float")
157                            * total_weight);
158                    modularity = modularity + edge_weight - expected;
159                }
160            }
161        }
162
163        modularity / (F::from(2.0).expect("Failed to convert constant to float") * total_weight)
164    }
165
166    /// Get edge weight between two nodes
167    fn get_edge_weight(&self, node1: usize, node2: usize) -> F {
168        if node1 < self.n_nodes {
169            for &(neighbor_, weight) in &self.adjacency[node1] {
170                if neighbor_ == node2 {
171                    return weight;
172                }
173            }
174        }
175        F::zero()
176    }
177
178    /// Calculate total weight of all edges in the graph
179    fn total_edge_weight(&self) -> F {
180        let mut total = F::zero();
181        for node in 0..self.n_nodes {
182            for &(_, weight) in &self.adjacency[node] {
183                total = total + weight;
184            }
185        }
186        total / F::from(2.0).expect("Failed to convert constant to float") // Divide by 2 because each edge is counted twice
187    }
188}
189
190/// Louvain community detection algorithm
191///
192/// The Louvain algorithm is a greedy optimization method that attempts to optimize
193/// the modularity of a partition of the network. It produces high quality communities
194/// and has excellent performance on large networks.
195///
196/// # Arguments
197///
198/// * `graph` - Input graph
199/// * `resolution` - Resolution parameter (higher values lead to smaller communities)
200/// * `max_iterations` - Maximum number of iterations
201///
202/// # Returns
203///
204/// Community assignments for each node
205///
206/// # Example
207///
208/// ```no_run
209/// // Doctest disabled due to incompatible trait constraints (Float vs Eq+Hash)
210/// use scirs2_core::ndarray::Array2;
211/// use scirs2_cluster::graph::{Graph, louvain};
212///
213/// // Note: Graph requires F: Float + Eq + Hash, which is impossible for standard float types
214/// // This is a design issue that needs to be addressed
215/// let adjacency = Array2::from_shape_vec((4, 4), vec![
216///     0.0, 1.0, 1.0, 0.0,
217///     1.0, 0.0, 0.0, 0.0,
218///     1.0, 0.0, 0.0, 1.0,
219///     0.0, 0.0, 1.0, 0.0,
220/// ]).expect("Operation failed");
221///
222/// // This would fail to compile due to trait constraint conflicts
223/// // let graph = Graph::from_adjacencymatrix(adjacency.view()).expect("Operation failed");
224/// // let communities = louvain(&graph, 1.0, 100).expect("Operation failed");
225/// ```
226#[allow(dead_code)]
227pub fn louvain<F>(graph: &Graph<F>, resolution: f64, max_iterations: usize) -> Result<Array1<usize>>
228where
229    F: Float
230        + FromPrimitive
231        + Debug
232        + ScalarOperand
233        + std::iter::Sum
234        + std::cmp::Eq
235        + std::hash::Hash
236        + 'static,
237    f64: From<F>,
238{
239    let n_nodes = graph.n_nodes;
240    let mut communities: Array1<usize> = Array1::from_iter(0..n_nodes);
241    let mut improved = true;
242    let mut iteration = 0;
243
244    while improved && iteration < max_iterations {
245        improved = false;
246        iteration += 1;
247
248        // Phase 1: Optimize modularity by moving nodes
249        for node in 0..n_nodes {
250            let current_community = communities[node];
251            let mut best_community = current_community;
252            let mut best_gain = F::zero();
253
254            // Try moving node to each neighbor_'s community
255            let mut candidate_communities = HashSet::new();
256            candidate_communities.insert(current_community);
257
258            for &(neighbor_id, _weight) in graph.neighbor_s(node) {
259                candidate_communities.insert(communities[neighbor_id]);
260            }
261
262            for &candidate_community in &candidate_communities {
263                if candidate_community != current_community {
264                    // Calculate modularity gain from moving to this community
265                    let gain = modularity_gain(
266                        graph,
267                        &communities,
268                        node,
269                        current_community,
270                        candidate_community,
271                        resolution,
272                    );
273
274                    if gain > best_gain {
275                        best_gain = gain;
276                        best_community = candidate_community;
277                    }
278                }
279            }
280
281            // Move node to best community if improvement found
282            if best_community != current_community && best_gain > F::zero() {
283                communities[node] = best_community;
284                improved = true;
285            }
286        }
287    }
288
289    Ok(communities)
290}
291
292/// Calculate modularity gain from moving a node to a different community
293#[allow(dead_code)]
294fn modularity_gain<F>(
295    graph: &Graph<F>,
296    communities: &Array1<usize>,
297    node: usize,
298    from_community: usize,
299    to_community: usize,
300    resolution: f64,
301) -> F
302where
303    F: Float
304        + FromPrimitive
305        + Debug
306        + ScalarOperand
307        + std::iter::Sum
308        + std::cmp::Eq
309        + std::hash::Hash
310        + 'static,
311    f64: From<F>,
312{
313    let total_weight = graph.total_edge_weight();
314    if total_weight == F::zero() {
315        return F::zero();
316    }
317
318    let node_degree = graph.weighted_degree(node);
319    let resolution_f = F::from(resolution).expect("Failed to convert to float");
320
321    // Calculate connections within target _community
322    let mut edges_to_target = F::zero();
323    let mut edges_from_source = F::zero();
324
325    for &(neighbor_, weight) in graph.neighbor_s(node) {
326        if communities[neighbor_] == to_community {
327            edges_to_target = edges_to_target + weight;
328        }
329        if communities[neighbor_] == from_community && neighbor_ != node {
330            edges_from_source = edges_from_source + weight;
331        }
332    }
333
334    // Calculate _community weights
335    let target_community_weight = calculate_community_weight(graph, communities, to_community);
336    let source_community_weight = calculate_community_weight(graph, communities, from_community);
337
338    // Calculate modularity gain
339    let gain_to = edges_to_target
340        - resolution_f * node_degree * target_community_weight
341            / (F::from(2.0).expect("Failed to convert constant to float") * total_weight);
342    let loss_from = edges_from_source
343        - resolution_f * node_degree * (source_community_weight - node_degree)
344            / (F::from(2.0).expect("Failed to convert constant to float") * total_weight);
345
346    gain_to - loss_from
347}
348
349/// Calculate total weight of a community
350#[allow(dead_code)]
351fn calculate_community_weight<F>(
352    graph: &Graph<F>,
353    communities: &Array1<usize>,
354    community: usize,
355) -> F
356where
357    F: Float
358        + FromPrimitive
359        + Debug
360        + ScalarOperand
361        + std::iter::Sum
362        + std::cmp::Eq
363        + std::hash::Hash
364        + 'static,
365{
366    let mut weight = F::zero();
367    for node in 0..graph.n_nodes {
368        if communities[node] == community {
369            weight = weight + graph.weighted_degree(node);
370        }
371    }
372    weight
373}
374
375/// Label propagation algorithm for community detection
376///
377/// A fast algorithm where each node adopts the label that most of its neighbor_s have.
378/// This process continues iteratively until convergence.
379///
380/// # Arguments
381///
382/// * `graph` - Input graph
383/// * `max_iterations` - Maximum number of iterations
384/// * `tolerance` - Convergence tolerance
385///
386/// # Returns
387///
388/// Community assignments for each node
389#[allow(dead_code)]
390pub fn label_propagation<F>(
391    graph: &Graph<F>,
392    max_iterations: usize,
393    tolerance: f64,
394) -> Result<Array1<usize>>
395where
396    F: Float
397        + FromPrimitive
398        + Debug
399        + ScalarOperand
400        + std::iter::Sum
401        + std::cmp::Eq
402        + std::hash::Hash
403        + 'static,
404    f64: From<F>,
405{
406    let n_nodes = graph.n_nodes;
407    let mut labels: Array1<usize> = Array1::from_iter(0..n_nodes);
408    let tolerance_f = F::from(tolerance).expect("Failed to convert to float");
409
410    for _iteration in 0..max_iterations {
411        let mut new_labels = labels.clone();
412        let mut changed_nodes = 0;
413
414        // Process nodes in random order
415        let mut node_order: Vec<usize> = (0..n_nodes).collect();
416        // For deterministic results, we'll use a simple shuffle based on node index
417        node_order.sort_by_key(|&i| i * 17 % n_nodes);
418
419        for &node in &node_order {
420            // Count label frequencies among neighbor_s
421            let mut label_weights: HashMap<usize, F> = HashMap::new();
422
423            for &(neighbor_, weight) in graph.neighbor_s(node) {
424                let label = labels[neighbor_];
425                let entry = label_weights.entry(label).or_insert(F::zero());
426                *entry = *entry + weight;
427            }
428
429            // Choose label with highest weight
430            if let Some((&best_label_, _)) = label_weights
431                .iter()
432                .max_by(|a, b| a.1.partial_cmp(b.1).expect("Operation failed"))
433            {
434                if best_label_ != labels[node] {
435                    new_labels[node] = best_label_;
436                    changed_nodes += 1;
437                }
438            }
439        }
440
441        labels = new_labels;
442
443        // Check convergence
444        let change_ratio = changed_nodes as f64 / n_nodes as f64;
445        if change_ratio < tolerance {
446            break;
447        }
448    }
449
450    // Relabel communities to be consecutive integers starting from 0
451    let unique_labels: HashSet<usize> = labels.iter().cloned().collect();
452    let label_mapping: HashMap<usize, usize> = unique_labels
453        .into_iter()
454        .enumerate()
455        .map(|(new_label, old_label)| (old_label, new_label))
456        .collect();
457
458    for label in labels.iter_mut() {
459        *label = label_mapping[label];
460    }
461
462    Ok(labels)
463}
464
465/// Girvan-Newman algorithm for community detection
466///
467/// This algorithm removes edges with highest betweenness centrality iteratively
468/// to reveal community structure. It's more computationally expensive but can
469/// produce hierarchical community structures.
470///
471/// # Arguments
472///
473/// * `graph` - Input graph
474/// * `ncommunities` - Desired number of communities (algorithm stops when reached)
475///
476/// # Returns
477///
478/// Community assignments for each node
479#[allow(dead_code)]
480pub fn girvan_newman<F>(graph: &Graph<F>, ncommunities: usize) -> Result<Array1<usize>>
481where
482    F: Float
483        + FromPrimitive
484        + Debug
485        + ScalarOperand
486        + std::iter::Sum
487        + std::cmp::Eq
488        + std::hash::Hash
489        + 'static,
490{
491    if ncommunities > graph.n_nodes {
492        return Err(ClusteringError::InvalidInput(
493            "Number of _communities cannot exceed number of nodes".to_string(),
494        ));
495    }
496
497    let mut workinggraph = graph.clone();
498    let mut _communities = find_connected_components(&workinggraph);
499
500    while count_communities(&_communities) < ncommunities && has_edges(&workinggraph) {
501        // Calculate edge betweenness centrality
502        let edge_betweenness = calculate_edge_betweenness(&workinggraph)?;
503
504        // Find edge with highest betweenness
505        if let Some((max_edge_, _)) = edge_betweenness
506            .iter()
507            .max_by(|a, b| a.1.partial_cmp(b.1).expect("Operation failed"))
508        {
509            // Remove the edge with highest betweenness
510            remove_edge(&mut workinggraph, max_edge_.0, max_edge_.1);
511
512            // Recalculate connected components
513            _communities = find_connected_components(&workinggraph);
514        } else {
515            break; // No more edges to remove
516        }
517    }
518
519    Ok(Array1::from_vec(_communities))
520}
521
522/// Calculate edge betweenness centrality for all edges
523#[allow(dead_code)]
524fn calculate_edge_betweenness<F>(graph: &Graph<F>) -> Result<HashMap<(usize, usize), f64>>
525where
526    F: Float
527        + FromPrimitive
528        + Debug
529        + ScalarOperand
530        + std::iter::Sum
531        + std::cmp::Eq
532        + std::hash::Hash
533        + 'static,
534{
535    let mut edge_betweenness = HashMap::new();
536
537    // Initialize all edges with zero betweenness
538    for node in 0..graph.n_nodes {
539        for &(neighbor_, _) in graph.neighbor_s(node) {
540            if node < neighbor_ {
541                // Count each edge only once
542                edge_betweenness.insert((node, neighbor_), 0.0);
543            }
544        }
545    }
546
547    // For each pair of nodes, calculate shortest paths and update edge betweenness
548    for source in 0..graph.n_nodes {
549        for target in (source + 1)..graph.n_nodes {
550            let paths = find_all_shortest_paths(graph, source, target);
551
552            if !paths.is_empty() {
553                let contribution = 1.0 / paths.len() as f64;
554
555                for path in paths {
556                    for i in 0..(path.len() - 1) {
557                        let (u, v) = if path[i] < path[i + 1] {
558                            (path[i], path[i + 1])
559                        } else {
560                            (path[i + 1], path[i])
561                        };
562
563                        *edge_betweenness.entry((u, v)).or_insert(0.0) += contribution;
564                    }
565                }
566            }
567        }
568    }
569
570    Ok(edge_betweenness)
571}
572
573/// Find all shortest paths between two nodes using BFS
574#[allow(dead_code)]
575fn find_all_shortest_paths<F>(graph: &Graph<F>, source: usize, target: usize) -> Vec<Vec<usize>>
576where
577    F: Float
578        + FromPrimitive
579        + Debug
580        + ScalarOperand
581        + std::iter::Sum
582        + std::cmp::Eq
583        + std::hash::Hash
584        + 'static,
585{
586    let mut distances = vec![None; graph.n_nodes];
587    let mut predecessors: Vec<Vec<usize>> = vec![Vec::new(); graph.n_nodes];
588    let mut queue = VecDeque::new();
589
590    distances[source] = Some(0);
591    queue.push_back(source);
592
593    while let Some(current) = queue.pop_front() {
594        let current_dist = distances[current].expect("Operation failed");
595
596        for &(neighbor_, _) in graph.neighbor_s(current) {
597            if distances[neighbor_].is_none() {
598                // First time visiting this node
599                distances[neighbor_] = Some(current_dist + 1);
600                predecessors[neighbor_].push(current);
601                queue.push_back(neighbor_);
602            } else if distances[neighbor_] == Some(current_dist + 1) {
603                // Another shortest path found
604                predecessors[neighbor_].push(current);
605            }
606        }
607    }
608
609    // Reconstruct all shortest paths
610    if distances[target].is_none() {
611        return Vec::new(); // No path exists
612    }
613
614    let mut paths = Vec::new();
615    let mut current_paths = vec![vec![target]];
616
617    while !current_paths.is_empty() {
618        let mut next_paths = Vec::new();
619
620        for path in current_paths {
621            let last_node = path[path.len() - 1];
622
623            if last_node == source {
624                let mut complete_path = path.clone();
625                complete_path.reverse();
626                paths.push(complete_path);
627            } else {
628                for &pred in &predecessors[last_node] {
629                    let mut new_path = path.clone();
630                    new_path.push(pred);
631                    next_paths.push(new_path);
632                }
633            }
634        }
635
636        current_paths = next_paths;
637    }
638
639    paths
640}
641
642/// Remove an edge from the graph
643#[allow(dead_code)]
644fn remove_edge<F>(graph: &mut Graph<F>, node1: usize, node2: usize)
645where
646    F: Float
647        + FromPrimitive
648        + Debug
649        + ScalarOperand
650        + std::iter::Sum
651        + std::cmp::Eq
652        + std::hash::Hash
653        + 'static,
654{
655    graph.adjacency[node1].retain(|(neighbor_, _)| *neighbor_ != node2);
656    graph.adjacency[node2].retain(|(neighbor_, _)| *neighbor_ != node1);
657}
658
659/// Check if the graph has any edges
660#[allow(dead_code)]
661fn has_edges<F>(graph: &Graph<F>) -> bool
662where
663    F: Float
664        + FromPrimitive
665        + Debug
666        + ScalarOperand
667        + std::iter::Sum
668        + std::cmp::Eq
669        + std::hash::Hash
670        + 'static,
671{
672    graph
673        .adjacency
674        .iter()
675        .any(|neighbor_s| !neighbor_s.is_empty())
676}
677
678/// Find connected components in the graph
679#[allow(dead_code)]
680fn find_connected_components<F>(graph: &Graph<F>) -> Vec<usize>
681where
682    F: Float
683        + FromPrimitive
684        + Debug
685        + ScalarOperand
686        + std::iter::Sum
687        + std::cmp::Eq
688        + std::hash::Hash
689        + 'static,
690{
691    let mut visited = vec![false; graph.n_nodes];
692    let mut components = vec![0; graph.n_nodes];
693    let mut component_id = 0;
694
695    for node in 0..graph.n_nodes {
696        if !visited[node] {
697            dfs_component(graph, node, component_id, &mut visited, &mut components);
698            component_id += 1;
699        }
700    }
701
702    components
703}
704
705/// Depth-first search to mark connected component
706#[allow(dead_code)]
707fn dfs_component<F>(
708    graph: &Graph<F>,
709    node: usize,
710    component_id: usize,
711    visited: &mut [bool],
712    components: &mut [usize],
713) where
714    F: Float
715        + FromPrimitive
716        + Debug
717        + ScalarOperand
718        + std::iter::Sum
719        + std::cmp::Eq
720        + std::hash::Hash
721        + 'static,
722{
723    visited[node] = true;
724    components[node] = component_id;
725
726    for &(neighbor_, _) in graph.neighbor_s(node) {
727        if !visited[neighbor_] {
728            dfs_component(graph, neighbor_, component_id, visited, components);
729        }
730    }
731}
732
/// Number of distinct community labels appearing in `communities`.
#[allow(dead_code)]
fn count_communities(communities: &[usize]) -> usize {
    communities.iter().collect::<HashSet<&usize>>().len()
}
742
743/// Helper function to calculate Euclidean distance between two points
744#[allow(dead_code)]
745fn euclidean_distance<F>(a: ArrayView1<F>, b: ArrayView1<F>) -> F
746where
747    F: Float + std::iter::Sum + 'static,
748{
749    let diff = &a.to_owned() - &b.to_owned();
750    diff.dot(&diff).sqrt()
751}
752
753/// Configuration for graph clustering algorithms
754#[derive(Debug, Clone, Serialize, Deserialize)]
755pub struct GraphClusteringConfig {
756    /// Algorithm to use for clustering
757    pub algorithm: GraphClusteringAlgorithm,
758    /// Maximum number of iterations (for iterative algorithms)
759    pub max_iterations: usize,
760    /// Convergence tolerance
761    pub tolerance: f64,
762    /// Resolution parameter (for modularity-based algorithms)
763    pub resolution: f64,
764    /// Target number of communities (for hierarchical algorithms)
765    pub ncommunities: Option<usize>,
766}
767
768/// Available graph clustering algorithms
769#[derive(Debug, Clone, Serialize, Deserialize)]
770pub enum GraphClusteringAlgorithm {
771    /// Louvain community detection
772    Louvain,
773    /// Label propagation algorithm
774    LabelPropagation,
775    /// Girvan-Newman algorithm
776    GirvanNewman,
777}
778
779impl Default for GraphClusteringConfig {
780    fn default() -> Self {
781        Self {
782            algorithm: GraphClusteringAlgorithm::Louvain,
783            max_iterations: 100,
784            tolerance: 1e-6,
785            resolution: 1.0,
786            ncommunities: None,
787        }
788    }
789}
790
791/// Perform graph clustering using the specified configuration
792///
793/// # Arguments
794///
795/// * `graph` - Input graph
796/// * `config` - Clustering configuration
797///
798/// # Returns
799///
800/// Community assignments for each node
801#[allow(dead_code)]
802pub fn graph_clustering<F>(
803    graph: &Graph<F>,
804    config: &GraphClusteringConfig,
805) -> Result<Array1<usize>>
806where
807    F: Float
808        + FromPrimitive
809        + Debug
810        + ScalarOperand
811        + std::iter::Sum
812        + std::cmp::Eq
813        + std::hash::Hash
814        + 'static,
815    f64: From<F>,
816{
817    match config.algorithm {
818        GraphClusteringAlgorithm::Louvain => {
819            louvain(graph, config.resolution, config.max_iterations)
820        }
821        GraphClusteringAlgorithm::LabelPropagation => {
822            label_propagation(graph, config.max_iterations, config.tolerance)
823        }
824        GraphClusteringAlgorithm::GirvanNewman => {
825            let ncommunities = config.ncommunities.unwrap_or(2);
826            girvan_newman(graph, ncommunities)
827        }
828    }
829}
830
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn testgraph_creation() {
        let graph = Graph::<i32>::new(5);
        assert_eq!(graph.n_nodes, 5);
        assert_eq!(graph.adjacency.len(), 5);
    }

    #[test]
    fn testgraph_from_adjacencymatrix() {
        let adjacency = Array2::from_shape_vec((3, 3), vec![0, 1, 0, 1, 0, 1, 0, 1, 0])
            .expect("Operation failed");

        let graph = Graph::from_adjacencymatrix(adjacency.view()).expect("Operation failed");
        assert_eq!(graph.n_nodes, 3);
        assert_eq!(graph.degree(0), 1);
        assert_eq!(graph.degree(1), 2);
        assert_eq!(graph.degree(2), 1);
    }

    #[test]
    fn test_from_adjacencymatrix_rejects_non_square() {
        let adjacency = Array2::<f64>::zeros((2, 3));
        assert!(Graph::from_adjacencymatrix(adjacency.view()).is_err());
    }

    #[test]
    fn test_add_edge_bounds_and_self_loops() {
        let mut graph = Graph::<f64>::new(2);
        assert!(graph.add_edge(0, 2, 1.0).is_err());

        // Self-loops are accepted but not stored.
        graph.add_edge(1, 1, 1.0).expect("indices are in bounds");
        assert_eq!(graph.degree(1), 0);
        assert!(graph.neighbor_s(5).is_empty());
    }

    #[test]
    fn test_modularity_of_two_disjoint_edges() {
        let mut graph = Graph::<f64>::new(4);
        graph.add_edge(0, 1, 1.0).expect("valid nodes");
        graph.add_edge(2, 3, 1.0).expect("valid nodes");

        // Splitting along the two edges: Q = (4 - 8/4) / 4 = 0.5
        let split = graph.modularity(&[0, 0, 1, 1]);
        assert!((split - 0.5).abs() < 1e-12);

        // Everything in one community has zero modularity.
        let merged = graph.modularity(&[0, 0, 0, 0]);
        assert!(merged.abs() < 1e-12);

        // A graph without edges has zero modularity by definition.
        assert_eq!(Graph::<f64>::new(3).modularity(&[0, 1, 2]), 0.0);
    }

    #[test]
    fn test_count_communities() {
        assert_eq!(count_communities(&[0, 2, 2, 5]), 3);
        assert_eq!(count_communities(&[]), 0);
    }

    #[test]
    fn test_knngraph_creation() {
        let data = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 1.0, 5.0, 5.0, 6.0, 6.0])
            .expect("Operation failed");

        let graph = Graph::from_knngraph(data.view(), 2).expect("Operation failed");
        assert_eq!(graph.n_nodes, 4);
        assert!(graph.node_features.is_some());

        // The graph is undirected, so every node keeps at least its own k picks
        // and never links to itself.
        for node in 0..4 {
            assert!(graph.degree(node) >= 2);
            assert!(graph.neighbor_s(node).iter().all(|&(n, _)| n != node));
        }
    }

    // The tests below need `louvain` / `label_propagation` to accept `f64`
    // weights; enable them once those functions no longer require `Eq + Hash`.

    /*
    #[test]
    fn test_louvain_clustering() {
        // Create a simple graph with two obvious communities
        let adjacency = Array2::from_shape_vec(
            (4, 4),
            vec![
                0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0,
            ],
        )
        .expect("Operation failed");

        let graph = Graph::from_adjacencymatrix(adjacency.view()).expect("Operation failed");
        let communities = louvain(&graph, 1.0, 100).expect("Operation failed");

        // Nodes 0,1 should be in one community and nodes 2,3 in another
        assert_eq!(communities.len(), 4);
        assert_eq!(communities[0], communities[1]);
        assert_eq!(communities[2], communities[3]);
        assert_ne!(communities[0], communities[2]);
    }
    */

    /*
    #[test]
    fn test_label_propagation() {
        let adjacency = Array2::from_shape_vec(
            (4, 4),
            vec![
                0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0,
            ],
        )
        .expect("Operation failed");

        let graph = Graph::from_adjacencymatrix(adjacency.view()).expect("Operation failed");
        let communities = label_propagation(&graph, 100, 1e-6).expect("Operation failed");

        assert_eq!(communities.len(), 4);
        // Should detect two communities
        let unique_communities: HashSet<usize> = communities.iter().cloned().collect();
        assert_eq!(unique_communities.len(), 2);
    }
    */
}