scirs2_graph/algorithms/community/
label_propagation.rs

1//! Label propagation algorithm for community detection
2
3use super::types::CommunityResult;
4use crate::base::{EdgeWeight, Graph, IndexType, Node};
5use scirs2_core::random::seq::SliceRandom;
6use std::collections::HashMap;
7use std::hash::Hash;
8
9/// Internal implementation of label propagation algorithm
10#[allow(dead_code)]
11fn label_propagation_internal<N, E, Ix>(
12    graph: &Graph<N, E, Ix>,
13    max_iterations: usize,
14) -> HashMap<N, usize>
15where
16    N: Node + Clone + Hash + Eq + std::fmt::Debug,
17    E: EdgeWeight,
18    Ix: IndexType,
19{
20    let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
21    let n = nodes.len();
22
23    if n == 0 {
24        return HashMap::new();
25    }
26
27    // Initialize each node with its own label
28    let mut labels: Vec<usize> = (0..n).collect();
29    let node_to_idx: HashMap<N, usize> = nodes
30        .iter()
31        .enumerate()
32        .map(|(i, n)| (n.clone(), i))
33        .collect();
34
35    let mut rng = scirs2_core::random::rng();
36    let mut changed = true;
37    let mut _iterations = 0;
38
39    while changed && _iterations < max_iterations {
40        changed = false;
41        _iterations += 1;
42
43        // Process nodes in random order
44        let mut order: Vec<usize> = (0..n).collect();
45        order.shuffle(&mut rng);
46
47        for &i in &order {
48            let node = &nodes[i];
49
50            // Count labels of neighbors
51            let mut label_counts: HashMap<usize, usize> = HashMap::new();
52
53            if let Ok(neighbors) = graph.neighbors(node) {
54                for neighbor in neighbors {
55                    if let Some(&neighbor_idx) = node_to_idx.get(&neighbor) {
56                        let neighbor_label = labels[neighbor_idx];
57                        *label_counts.entry(neighbor_label).or_insert(0) += 1;
58                    }
59                }
60            }
61
62            if label_counts.is_empty() {
63                continue;
64            }
65
66            // Find most frequent label(s)
67            let max_count = *label_counts.values().max().unwrap();
68            let best_labels: Vec<usize> = label_counts
69                .into_iter()
70                .filter(|(_, count)| *count == max_count)
71                .map(|(label, _)| label)
72                .collect();
73
74            // Choose randomly among ties
75            use scirs2_core::random::Rng;
76            let new_label = best_labels[rng.gen_range(0..best_labels.len())];
77
78            if labels[i] != new_label {
79                labels[i] = new_label;
80                changed = true;
81            }
82        }
83    }
84
85    // Convert to final result
86    nodes
87        .into_iter()
88        .enumerate()
89        .map(|(i, node)| (node, labels[i]))
90        .collect()
91}
92
93/// Label propagation algorithm for community detection (legacy API)
94///
95/// **Note**: This function is deprecated in favor of `label_propagation_result`.
96/// It will be removed in version 2.0.
97///
98/// Each node adopts the label that most of its neighbors have, with ties broken randomly.
99/// Returns a mapping from nodes to community labels.
100///
101/// # Time Complexity
102/// O(k * m) where k is the number of iterations (typically small) and m is
103/// the number of edges. The algorithm often converges in 5-10 iterations.
104///
105/// # Space Complexity
106/// O(n) for storing labels and temporary data structures.
107#[deprecated(
108    since = "0.1.0-beta.2",
109    note = "Use `label_propagation_result` instead"
110)]
111#[allow(dead_code)]
112pub fn label_propagation<N, E, Ix>(
113    graph: &Graph<N, E, Ix>,
114    max_iterations: usize,
115) -> HashMap<N, usize>
116where
117    N: Node + Clone + Hash + Eq + std::fmt::Debug,
118    E: EdgeWeight,
119    Ix: IndexType,
120{
121    label_propagation_internal(graph, max_iterations)
122}
123
124/// Label propagation algorithm with standardized CommunityResult return type
125///
126/// This function provides the same functionality as `label_propagation` but returns
127/// a standardized `CommunityResult` type that provides multiple ways to access
128/// the community structure.
129///
130/// # Arguments
131/// * `graph` - The graph to analyze
132/// * `max_iterations` - Maximum number of iterations (default: 100)
133///
134/// # Returns
135/// * A `CommunityResult` with comprehensive community information
136///
137/// # Time Complexity
138/// O(k * m) where k is the number of iterations and m is the number of edges.
139/// In practice, the algorithm usually converges in a few iterations.
140///
141/// # Space Complexity
142/// O(n) for storing node labels and community assignments.
143///
144/// # Example
145/// ```rust
146/// use scirs2_graph::{Graph, label_propagation_result};
147///
148/// let mut graph: Graph<i32, f64> = Graph::new();
149/// // ... add nodes and edges ...
150/// let result = label_propagation_result(&graph, 100);
151///
152/// println!("Found {} communities", result.num_communities);
153/// for (i, community) in result.communities.iter().enumerate() {
154///     println!("Community {}: {} members", i, community.len());
155/// }
156/// ```
157#[allow(dead_code)]
158pub fn label_propagation_result<N, E, Ix>(
159    graph: &Graph<N, E, Ix>,
160    max_iterations: usize,
161) -> CommunityResult<N>
162where
163    N: Node + Clone + Hash + Eq + std::fmt::Debug,
164    E: EdgeWeight,
165    Ix: IndexType,
166{
167    let node_communities = label_propagation_internal(graph, max_iterations);
168    CommunityResult::from_node_map(node_communities)
169}