scirs2_cluster/
graph.rs

//! Graph clustering and community detection algorithms
//!
//! This module provides implementations of various graph clustering algorithms for
//! detecting communities and clusters in network data. These algorithms work with
//! graph representations where nodes represent data points and edges represent
//! similarities or connections between them.
7
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, ScalarOperand};
9use scirs2_core::numeric::{Float, FromPrimitive};
10use std::collections::{HashMap, HashSet, VecDeque};
11use std::fmt::Debug;
12
13use serde::{Deserialize, Serialize};
14
15use crate::error::{ClusteringError, Result};
16
17/// Graph representation for clustering algorithms
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct Graph<F: Float> {
20    /// Number of nodes in the graph
21    pub n_nodes: usize,
22    /// Adjacency list representation: node_id -> [(neighbor_id, weight), ...]
23    pub adjacency: Vec<Vec<(usize, F)>>,
24    /// Optional node labels/features
25    pub node_features: Option<Array2<F>>,
26}
27
28impl<
29        F: Float
30            + FromPrimitive
31            + Debug
32            + ScalarOperand
33            + std::iter::Sum
34            + std::cmp::Eq
35            + std::hash::Hash
36            + 'static,
37    > Graph<F>
38{
39    /// Create a new empty graph with specified number of nodes
40    pub fn new(_nnodes: usize) -> Self {
41        Self {
42            n_nodes: _nnodes,
43            adjacency: vec![Vec::new(); _nnodes],
44            node_features: None,
45        }
46    }
47
48    /// Create a graph from an adjacency matrix
49    pub fn from_adjacencymatrix(_adjacencymatrix: ArrayView2<F>) -> Result<Self> {
50        let n_nodes = _adjacencymatrix.shape()[0];
51        if _adjacencymatrix.shape()[1] != n_nodes {
52            return Err(ClusteringError::InvalidInput(
53                "Adjacency _matrix must be square".to_string(),
54            ));
55        }
56
57        let mut graph = Self::new(n_nodes);
58
59        for i in 0..n_nodes {
60            for j in 0..n_nodes {
61                let weight = _adjacencymatrix[[i, j]];
62                if weight > F::zero() && i != j {
63                    graph.add_edge(i, j, weight)?;
64                }
65            }
66        }
67
68        Ok(graph)
69    }
70
71    /// Create a k-nearest neighbor_ graph from data points
72    pub fn from_knngraph(data: ArrayView2<F>, k: usize) -> Result<Self> {
73        let n_samples = data.shape()[0];
74        let mut graph = Self::new(n_samples);
75        graph.node_features = Some(data.to_owned());
76
77        // For each point, find k nearest neighbor_s
78        for i in 0..n_samples {
79            let mut distances: Vec<(usize, F)> = Vec::new();
80
81            for j in 0..n_samples {
82                if i != j {
83                    let dist = euclidean_distance(data.row(i), data.row(j));
84                    distances.push((j, dist));
85                }
86            }
87
88            // Sort by distance and take k nearest
89            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
90
91            for &(neighbor_idx, distance) in distances.iter().take(k) {
92                // Use similarity (inverse of distance) as edge weight
93                let similarity = F::one() / (F::one() + distance);
94                graph.add_edge(i, neighbor_idx, similarity)?;
95            }
96        }
97
98        Ok(graph)
99    }
100
101    /// Add an edge between two nodes
102    pub fn add_edge(&mut self, node1: usize, node2: usize, weight: F) -> Result<()> {
103        if node1 >= self.n_nodes || node2 >= self.n_nodes {
104            return Err(ClusteringError::InvalidInput(
105                "Node index out of bounds".to_string(),
106            ));
107        }
108
109        if node1 != node2 {
110            self.adjacency[node1].push((node2, weight));
111            self.adjacency[node2].push((node1, weight)); // Undirected graph
112        }
113
114        Ok(())
115    }
116
117    /// Get the degree of a node (number of neighbor_s)
118    pub fn degree(&self, node: usize) -> usize {
119        if node < self.n_nodes {
120            self.adjacency[node].len()
121        } else {
122            0
123        }
124    }
125
126    /// Get the weighted degree of a node (sum of edge weights)
127    pub fn weighted_degree(&self, node: usize) -> F {
128        if node < self.n_nodes {
129            self.adjacency[node].iter().map(|(_, weight)| *weight).sum()
130        } else {
131            F::zero()
132        }
133    }
134
135    /// Get all neighbor_s of a node
136    pub fn neighbor_s(&self, node: usize) -> &[(usize, F)] {
137        if node < self.n_nodes {
138            &self.adjacency[node]
139        } else {
140            &[]
141        }
142    }
143
144    /// Calculate modularity of a given community assignment
145    pub fn modularity(&self, communities: &[usize]) -> F {
146        let total_weight = self.total_edge_weight();
147        if total_weight == F::zero() {
148            return F::zero();
149        }
150
151        let mut modularity = F::zero();
152
153        for i in 0..self.n_nodes {
154            for j in 0..self.n_nodes {
155                if communities[i] == communities[j] {
156                    let edge_weight = self.get_edge_weight(i, j);
157                    let degree_i = self.weighted_degree(i);
158                    let degree_j = self.weighted_degree(j);
159
160                    let expected = degree_i * degree_j / (F::from(2.0).unwrap() * total_weight);
161                    modularity = modularity + edge_weight - expected;
162                }
163            }
164        }
165
166        modularity / (F::from(2.0).unwrap() * total_weight)
167    }
168
169    /// Get edge weight between two nodes
170    fn get_edge_weight(&self, node1: usize, node2: usize) -> F {
171        if node1 < self.n_nodes {
172            for &(neighbor_, weight) in &self.adjacency[node1] {
173                if neighbor_ == node2 {
174                    return weight;
175                }
176            }
177        }
178        F::zero()
179    }
180
181    /// Calculate total weight of all edges in the graph
182    fn total_edge_weight(&self) -> F {
183        let mut total = F::zero();
184        for node in 0..self.n_nodes {
185            for &(_, weight) in &self.adjacency[node] {
186                total = total + weight;
187            }
188        }
189        total / F::from(2.0).unwrap() // Divide by 2 because each edge is counted twice
190    }
191}
192
193/// Louvain community detection algorithm
194///
195/// The Louvain algorithm is a greedy optimization method that attempts to optimize
196/// the modularity of a partition of the network. It produces high quality communities
197/// and has excellent performance on large networks.
198///
199/// # Arguments
200///
201/// * `graph` - Input graph
202/// * `resolution` - Resolution parameter (higher values lead to smaller communities)
203/// * `max_iterations` - Maximum number of iterations
204///
205/// # Returns
206///
207/// Community assignments for each node
208///
209/// # Example
210///
211/// ```no_run
212/// // Doctest disabled due to incompatible trait constraints (Float vs Eq+Hash)
213/// use scirs2_core::ndarray::Array2;
214/// use scirs2_cluster::graph::{Graph, louvain};
215///
216/// // Note: Graph requires F: Float + Eq + Hash, which is impossible for standard float types
217/// // This is a design issue that needs to be addressed
218/// let adjacency = Array2::from_shape_vec((4, 4), vec![
219///     0.0, 1.0, 1.0, 0.0,
220///     1.0, 0.0, 0.0, 0.0,
221///     1.0, 0.0, 0.0, 1.0,
222///     0.0, 0.0, 1.0, 0.0,
223/// ]).unwrap();
224///
225/// // This would fail to compile due to trait constraint conflicts
226/// // let graph = Graph::from_adjacencymatrix(adjacency.view()).unwrap();
227/// // let communities = louvain(&graph, 1.0, 100).unwrap();
228/// ```
229#[allow(dead_code)]
230pub fn louvain<F>(graph: &Graph<F>, resolution: f64, max_iterations: usize) -> Result<Array1<usize>>
231where
232    F: Float
233        + FromPrimitive
234        + Debug
235        + ScalarOperand
236        + std::iter::Sum
237        + std::cmp::Eq
238        + std::hash::Hash
239        + 'static,
240    f64: From<F>,
241{
242    let n_nodes = graph.n_nodes;
243    let mut communities: Array1<usize> = Array1::from_iter(0..n_nodes);
244    let mut improved = true;
245    let mut iteration = 0;
246
247    while improved && iteration < max_iterations {
248        improved = false;
249        iteration += 1;
250
251        // Phase 1: Optimize modularity by moving nodes
252        for node in 0..n_nodes {
253            let current_community = communities[node];
254            let mut best_community = current_community;
255            let mut best_gain = F::zero();
256
257            // Try moving node to each neighbor_'s community
258            let mut candidate_communities = HashSet::new();
259            candidate_communities.insert(current_community);
260
261            for &(neighbor_id, _weight) in graph.neighbor_s(node) {
262                candidate_communities.insert(communities[neighbor_id]);
263            }
264
265            for &candidate_community in &candidate_communities {
266                if candidate_community != current_community {
267                    // Calculate modularity gain from moving to this community
268                    let gain = modularity_gain(
269                        graph,
270                        &communities,
271                        node,
272                        current_community,
273                        candidate_community,
274                        resolution,
275                    );
276
277                    if gain > best_gain {
278                        best_gain = gain;
279                        best_community = candidate_community;
280                    }
281                }
282            }
283
284            // Move node to best community if improvement found
285            if best_community != current_community && best_gain > F::zero() {
286                communities[node] = best_community;
287                improved = true;
288            }
289        }
290    }
291
292    Ok(communities)
293}
294
295/// Calculate modularity gain from moving a node to a different community
296#[allow(dead_code)]
297fn modularity_gain<F>(
298    graph: &Graph<F>,
299    communities: &Array1<usize>,
300    node: usize,
301    from_community: usize,
302    to_community: usize,
303    resolution: f64,
304) -> F
305where
306    F: Float
307        + FromPrimitive
308        + Debug
309        + ScalarOperand
310        + std::iter::Sum
311        + std::cmp::Eq
312        + std::hash::Hash
313        + 'static,
314    f64: From<F>,
315{
316    let total_weight = graph.total_edge_weight();
317    if total_weight == F::zero() {
318        return F::zero();
319    }
320
321    let node_degree = graph.weighted_degree(node);
322    let resolution_f = F::from(resolution).unwrap();
323
324    // Calculate connections within target _community
325    let mut edges_to_target = F::zero();
326    let mut edges_from_source = F::zero();
327
328    for &(neighbor_, weight) in graph.neighbor_s(node) {
329        if communities[neighbor_] == to_community {
330            edges_to_target = edges_to_target + weight;
331        }
332        if communities[neighbor_] == from_community && neighbor_ != node {
333            edges_from_source = edges_from_source + weight;
334        }
335    }
336
337    // Calculate _community weights
338    let target_community_weight = calculate_community_weight(graph, communities, to_community);
339    let source_community_weight = calculate_community_weight(graph, communities, from_community);
340
341    // Calculate modularity gain
342    let gain_to = edges_to_target
343        - resolution_f * node_degree * target_community_weight
344            / (F::from(2.0).unwrap() * total_weight);
345    let loss_from = edges_from_source
346        - resolution_f * node_degree * (source_community_weight - node_degree)
347            / (F::from(2.0).unwrap() * total_weight);
348
349    gain_to - loss_from
350}
351
352/// Calculate total weight of a community
353#[allow(dead_code)]
354fn calculate_community_weight<F>(
355    graph: &Graph<F>,
356    communities: &Array1<usize>,
357    community: usize,
358) -> F
359where
360    F: Float
361        + FromPrimitive
362        + Debug
363        + ScalarOperand
364        + std::iter::Sum
365        + std::cmp::Eq
366        + std::hash::Hash
367        + 'static,
368{
369    let mut weight = F::zero();
370    for node in 0..graph.n_nodes {
371        if communities[node] == community {
372            weight = weight + graph.weighted_degree(node);
373        }
374    }
375    weight
376}
377
378/// Label propagation algorithm for community detection
379///
380/// A fast algorithm where each node adopts the label that most of its neighbor_s have.
381/// This process continues iteratively until convergence.
382///
383/// # Arguments
384///
385/// * `graph` - Input graph
386/// * `max_iterations` - Maximum number of iterations
387/// * `tolerance` - Convergence tolerance
388///
389/// # Returns
390///
391/// Community assignments for each node
392#[allow(dead_code)]
393pub fn label_propagation<F>(
394    graph: &Graph<F>,
395    max_iterations: usize,
396    tolerance: f64,
397) -> Result<Array1<usize>>
398where
399    F: Float
400        + FromPrimitive
401        + Debug
402        + ScalarOperand
403        + std::iter::Sum
404        + std::cmp::Eq
405        + std::hash::Hash
406        + 'static,
407    f64: From<F>,
408{
409    let n_nodes = graph.n_nodes;
410    let mut labels: Array1<usize> = Array1::from_iter(0..n_nodes);
411    let tolerance_f = F::from(tolerance).unwrap();
412
413    for _iteration in 0..max_iterations {
414        let mut new_labels = labels.clone();
415        let mut changed_nodes = 0;
416
417        // Process nodes in random order
418        let mut node_order: Vec<usize> = (0..n_nodes).collect();
419        // For deterministic results, we'll use a simple shuffle based on node index
420        node_order.sort_by_key(|&i| i * 17 % n_nodes);
421
422        for &node in &node_order {
423            // Count label frequencies among neighbor_s
424            let mut label_weights: HashMap<usize, F> = HashMap::new();
425
426            for &(neighbor_, weight) in graph.neighbor_s(node) {
427                let label = labels[neighbor_];
428                let entry = label_weights.entry(label).or_insert(F::zero());
429                *entry = *entry + weight;
430            }
431
432            // Choose label with highest weight
433            if let Some((&best_label_, _)) = label_weights
434                .iter()
435                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
436            {
437                if best_label_ != labels[node] {
438                    new_labels[node] = best_label_;
439                    changed_nodes += 1;
440                }
441            }
442        }
443
444        labels = new_labels;
445
446        // Check convergence
447        let change_ratio = changed_nodes as f64 / n_nodes as f64;
448        if change_ratio < tolerance {
449            break;
450        }
451    }
452
453    // Relabel communities to be consecutive integers starting from 0
454    let unique_labels: HashSet<usize> = labels.iter().cloned().collect();
455    let label_mapping: HashMap<usize, usize> = unique_labels
456        .into_iter()
457        .enumerate()
458        .map(|(new_label, old_label)| (old_label, new_label))
459        .collect();
460
461    for label in labels.iter_mut() {
462        *label = label_mapping[label];
463    }
464
465    Ok(labels)
466}
467
468/// Girvan-Newman algorithm for community detection
469///
470/// This algorithm removes edges with highest betweenness centrality iteratively
471/// to reveal community structure. It's more computationally expensive but can
472/// produce hierarchical community structures.
473///
474/// # Arguments
475///
476/// * `graph` - Input graph
477/// * `ncommunities` - Desired number of communities (algorithm stops when reached)
478///
479/// # Returns
480///
481/// Community assignments for each node
482#[allow(dead_code)]
483pub fn girvan_newman<F>(graph: &Graph<F>, ncommunities: usize) -> Result<Array1<usize>>
484where
485    F: Float
486        + FromPrimitive
487        + Debug
488        + ScalarOperand
489        + std::iter::Sum
490        + std::cmp::Eq
491        + std::hash::Hash
492        + 'static,
493{
494    if ncommunities > graph.n_nodes {
495        return Err(ClusteringError::InvalidInput(
496            "Number of _communities cannot exceed number of nodes".to_string(),
497        ));
498    }
499
500    let mut workinggraph = graph.clone();
501    let mut _communities = find_connected_components(&workinggraph);
502
503    while count_communities(&_communities) < ncommunities && has_edges(&workinggraph) {
504        // Calculate edge betweenness centrality
505        let edge_betweenness = calculate_edge_betweenness(&workinggraph)?;
506
507        // Find edge with highest betweenness
508        if let Some((max_edge_, _)) = edge_betweenness
509            .iter()
510            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
511        {
512            // Remove the edge with highest betweenness
513            remove_edge(&mut workinggraph, max_edge_.0, max_edge_.1);
514
515            // Recalculate connected components
516            _communities = find_connected_components(&workinggraph);
517        } else {
518            break; // No more edges to remove
519        }
520    }
521
522    Ok(Array1::from_vec(_communities))
523}
524
525/// Calculate edge betweenness centrality for all edges
526#[allow(dead_code)]
527fn calculate_edge_betweenness<F>(graph: &Graph<F>) -> Result<HashMap<(usize, usize), f64>>
528where
529    F: Float
530        + FromPrimitive
531        + Debug
532        + ScalarOperand
533        + std::iter::Sum
534        + std::cmp::Eq
535        + std::hash::Hash
536        + 'static,
537{
538    let mut edge_betweenness = HashMap::new();
539
540    // Initialize all edges with zero betweenness
541    for node in 0..graph.n_nodes {
542        for &(neighbor_, _) in graph.neighbor_s(node) {
543            if node < neighbor_ {
544                // Count each edge only once
545                edge_betweenness.insert((node, neighbor_), 0.0);
546            }
547        }
548    }
549
550    // For each pair of nodes, calculate shortest paths and update edge betweenness
551    for source in 0..graph.n_nodes {
552        for target in (source + 1)..graph.n_nodes {
553            let paths = find_all_shortest_paths(graph, source, target);
554
555            if !paths.is_empty() {
556                let contribution = 1.0 / paths.len() as f64;
557
558                for path in paths {
559                    for i in 0..(path.len() - 1) {
560                        let (u, v) = if path[i] < path[i + 1] {
561                            (path[i], path[i + 1])
562                        } else {
563                            (path[i + 1], path[i])
564                        };
565
566                        *edge_betweenness.entry((u, v)).or_insert(0.0) += contribution;
567                    }
568                }
569            }
570        }
571    }
572
573    Ok(edge_betweenness)
574}
575
576/// Find all shortest paths between two nodes using BFS
577#[allow(dead_code)]
578fn find_all_shortest_paths<F>(graph: &Graph<F>, source: usize, target: usize) -> Vec<Vec<usize>>
579where
580    F: Float
581        + FromPrimitive
582        + Debug
583        + ScalarOperand
584        + std::iter::Sum
585        + std::cmp::Eq
586        + std::hash::Hash
587        + 'static,
588{
589    let mut distances = vec![None; graph.n_nodes];
590    let mut predecessors: Vec<Vec<usize>> = vec![Vec::new(); graph.n_nodes];
591    let mut queue = VecDeque::new();
592
593    distances[source] = Some(0);
594    queue.push_back(source);
595
596    while let Some(current) = queue.pop_front() {
597        let current_dist = distances[current].unwrap();
598
599        for &(neighbor_, _) in graph.neighbor_s(current) {
600            if distances[neighbor_].is_none() {
601                // First time visiting this node
602                distances[neighbor_] = Some(current_dist + 1);
603                predecessors[neighbor_].push(current);
604                queue.push_back(neighbor_);
605            } else if distances[neighbor_] == Some(current_dist + 1) {
606                // Another shortest path found
607                predecessors[neighbor_].push(current);
608            }
609        }
610    }
611
612    // Reconstruct all shortest paths
613    if distances[target].is_none() {
614        return Vec::new(); // No path exists
615    }
616
617    let mut paths = Vec::new();
618    let mut current_paths = vec![vec![target]];
619
620    while !current_paths.is_empty() {
621        let mut next_paths = Vec::new();
622
623        for path in current_paths {
624            let last_node = path[path.len() - 1];
625
626            if last_node == source {
627                let mut complete_path = path.clone();
628                complete_path.reverse();
629                paths.push(complete_path);
630            } else {
631                for &pred in &predecessors[last_node] {
632                    let mut new_path = path.clone();
633                    new_path.push(pred);
634                    next_paths.push(new_path);
635                }
636            }
637        }
638
639        current_paths = next_paths;
640    }
641
642    paths
643}
644
645/// Remove an edge from the graph
646#[allow(dead_code)]
647fn remove_edge<F>(graph: &mut Graph<F>, node1: usize, node2: usize)
648where
649    F: Float
650        + FromPrimitive
651        + Debug
652        + ScalarOperand
653        + std::iter::Sum
654        + std::cmp::Eq
655        + std::hash::Hash
656        + 'static,
657{
658    graph.adjacency[node1].retain(|(neighbor_, _)| *neighbor_ != node2);
659    graph.adjacency[node2].retain(|(neighbor_, _)| *neighbor_ != node1);
660}
661
662/// Check if the graph has any edges
663#[allow(dead_code)]
664fn has_edges<F>(graph: &Graph<F>) -> bool
665where
666    F: Float
667        + FromPrimitive
668        + Debug
669        + ScalarOperand
670        + std::iter::Sum
671        + std::cmp::Eq
672        + std::hash::Hash
673        + 'static,
674{
675    graph
676        .adjacency
677        .iter()
678        .any(|neighbor_s| !neighbor_s.is_empty())
679}
680
681/// Find connected components in the graph
682#[allow(dead_code)]
683fn find_connected_components<F>(graph: &Graph<F>) -> Vec<usize>
684where
685    F: Float
686        + FromPrimitive
687        + Debug
688        + ScalarOperand
689        + std::iter::Sum
690        + std::cmp::Eq
691        + std::hash::Hash
692        + 'static,
693{
694    let mut visited = vec![false; graph.n_nodes];
695    let mut components = vec![0; graph.n_nodes];
696    let mut component_id = 0;
697
698    for node in 0..graph.n_nodes {
699        if !visited[node] {
700            dfs_component(graph, node, component_id, &mut visited, &mut components);
701            component_id += 1;
702        }
703    }
704
705    components
706}
707
708/// Depth-first search to mark connected component
709#[allow(dead_code)]
710fn dfs_component<F>(
711    graph: &Graph<F>,
712    node: usize,
713    component_id: usize,
714    visited: &mut [bool],
715    components: &mut [usize],
716) where
717    F: Float
718        + FromPrimitive
719        + Debug
720        + ScalarOperand
721        + std::iter::Sum
722        + std::cmp::Eq
723        + std::hash::Hash
724        + 'static,
725{
726    visited[node] = true;
727    components[node] = component_id;
728
729    for &(neighbor_, _) in graph.neighbor_s(node) {
730        if !visited[neighbor_] {
731            dfs_component(graph, neighbor_, component_id, visited, components);
732        }
733    }
734}
735
/// Number of distinct labels appearing in `communities`.
#[allow(dead_code)]
fn count_communities(communities: &[usize]) -> usize {
    communities.iter().collect::<HashSet<&usize>>().len()
}
745
746/// Helper function to calculate Euclidean distance between two points
747#[allow(dead_code)]
748fn euclidean_distance<F>(a: ArrayView1<F>, b: ArrayView1<F>) -> F
749where
750    F: Float + std::iter::Sum + 'static,
751{
752    let diff = &a.to_owned() - &b.to_owned();
753    diff.dot(&diff).sqrt()
754}
755
756/// Configuration for graph clustering algorithms
757#[derive(Debug, Clone, Serialize, Deserialize)]
758pub struct GraphClusteringConfig {
759    /// Algorithm to use for clustering
760    pub algorithm: GraphClusteringAlgorithm,
761    /// Maximum number of iterations (for iterative algorithms)
762    pub max_iterations: usize,
763    /// Convergence tolerance
764    pub tolerance: f64,
765    /// Resolution parameter (for modularity-based algorithms)
766    pub resolution: f64,
767    /// Target number of communities (for hierarchical algorithms)
768    pub ncommunities: Option<usize>,
769}
770
771/// Available graph clustering algorithms
772#[derive(Debug, Clone, Serialize, Deserialize)]
773pub enum GraphClusteringAlgorithm {
774    /// Louvain community detection
775    Louvain,
776    /// Label propagation algorithm
777    LabelPropagation,
778    /// Girvan-Newman algorithm
779    GirvanNewman,
780}
781
782impl Default for GraphClusteringConfig {
783    fn default() -> Self {
784        Self {
785            algorithm: GraphClusteringAlgorithm::Louvain,
786            max_iterations: 100,
787            tolerance: 1e-6,
788            resolution: 1.0,
789            ncommunities: None,
790        }
791    }
792}
793
794/// Perform graph clustering using the specified configuration
795///
796/// # Arguments
797///
798/// * `graph` - Input graph
799/// * `config` - Clustering configuration
800///
801/// # Returns
802///
803/// Community assignments for each node
804#[allow(dead_code)]
805pub fn graph_clustering<F>(
806    graph: &Graph<F>,
807    config: &GraphClusteringConfig,
808) -> Result<Array1<usize>>
809where
810    F: Float
811        + FromPrimitive
812        + Debug
813        + ScalarOperand
814        + std::iter::Sum
815        + std::cmp::Eq
816        + std::hash::Hash
817        + 'static,
818    f64: From<F>,
819{
820    match config.algorithm {
821        GraphClusteringAlgorithm::Louvain => {
822            louvain(graph, config.resolution, config.max_iterations)
823        }
824        GraphClusteringAlgorithm::LabelPropagation => {
825            label_propagation(graph, config.max_iterations, config.tolerance)
826        }
827        GraphClusteringAlgorithm::GirvanNewman => {
828            let ncommunities = config.ncommunities.unwrap_or(2);
829            girvan_newman(graph, ncommunities)
830        }
831    }
832}
833
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_count_communities() {
        assert_eq!(count_communities(&[]), 0);
        assert_eq!(count_communities(&[3, 3, 1, 0, 1]), 3);
    }

    #[test]
    fn test_euclidean_distance() {
        let points = Array2::from_shape_vec((2, 2), vec![0.0_f64, 0.0, 3.0, 4.0]).unwrap();
        let distance = euclidean_distance(points.row(0), points.row(1));
        assert!((distance - 5.0).abs() < 1e-12);
        assert_eq!(euclidean_distance(points.row(1), points.row(1)), 0.0);
    }

    // The tests below stay disabled while `Graph`'s methods and the clustering
    // algorithms require `F: Eq + Hash`. No floating-point type implements those
    // traits. Re-enable (with `f64` data) once the bounds are relaxed.
    /*
    #[test]
    fn testgraph_creation() {
        let graph = Graph::<i32>::new(5);
        assert_eq!(graph.n_nodes, 5);
        assert_eq!(graph.adjacency.len(), 5);
    }

    #[test]
    fn testgraph_from_adjacencymatrix() {
        let adjacency =
            Array2::from_shape_vec((3, 3), vec![0, 1, 0, 1, 0, 1, 0, 1, 0])
                .unwrap();

        let graph = Graph::from_adjacencymatrix(adjacency.view()).unwrap();
        assert_eq!(graph.n_nodes, 3);
        assert_eq!(graph.degree(0), 1);
        assert_eq!(graph.degree(1), 2);
        assert_eq!(graph.degree(2), 1);
    }
    */

    /*
    #[test]
    fn test_louvain_clustering() {
        // Create a simple graph with two obvious communities
        let adjacency = Array2::from_shape_vec(
            (4, 4),
            vec![
                0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0,
            ],
        )
        .unwrap();

        let graph = Graph::from_adjacencymatrix(adjacency.view()).unwrap();
        let communities = louvain(&graph, 1.0, 100).unwrap();

        // Nodes 0,1 should be in one community and nodes 2,3 in another
        assert_eq!(communities.len(), 4);
        assert_eq!(communities[0], communities[1]);
        assert_eq!(communities[2], communities[3]);
        assert_ne!(communities[0], communities[2]);
    }
    */

    /*
    #[test]
    fn test_label_propagation() {
        let adjacency = Array2::from_shape_vec(
            (4, 4),
            vec![
                0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0,
            ],
        )
        .unwrap();

        let graph = Graph::from_adjacencymatrix(adjacency.view()).unwrap();
        let communities = label_propagation(&graph, 100, 1e-6).unwrap();

        assert_eq!(communities.len(), 4);
        // Should detect two communities
        let unique_communities: HashSet<usize> = communities.iter().cloned().collect();
        assert_eq!(unique_communities.len(), 2);
    }
    */

    /*
    #[test]
    fn test_knngraph_creation() {
        let data =
            Array2::from_shape_vec((4, 2), vec![0, 0, 1, 1, 5, 5, 6, 6]).unwrap();

        let graph = Graph::from_knngraph(data.view(), 2).unwrap();
        assert_eq!(graph.n_nodes, 4);

        // Each node should have exactly 2 neighbor_s
        for node in 0..4 {
            assert_eq!(graph.degree(node), 2);
        }
    }
    */
}