1use crate::error::{SparseError, SparseResult};
9use crate::sparray::SparseArray;
10use scirs2_core::numeric::{Float, SparseElement};
11use std::collections::{HashMap, HashSet};
12use std::fmt::Debug;
13
14pub fn louvain_communities<T, S>(
28 graph: &S,
29 resolution: T,
30 max_iter: usize,
31) -> SparseResult<(usize, Vec<usize>)>
32where
33 T: Float + SparseElement + Debug + Copy + std::iter::Sum + 'static,
34 S: SparseArray<T>,
35{
36 let n = graph.shape().0;
37
38 if graph.shape().0 != graph.shape().1 {
39 return Err(SparseError::ValueError(
40 "Graph matrix must be square".to_string(),
41 ));
42 }
43
44 let mut communities = (0..n).collect::<Vec<_>>();
46
47 let mut degrees = vec![T::sparse_zero(); n];
49 let mut sum_all_weights = T::sparse_zero();
51
52 for i in 0..n {
53 for j in 0..n {
54 let weight = graph.get(i, j);
55 if !scirs2_core::SparseElement::is_zero(&weight) {
56 degrees[i] = degrees[i] + weight;
57 sum_all_weights = sum_all_weights + weight;
58 }
59 }
60 }
61
62 let two = T::from(2.0)
64 .ok_or_else(|| SparseError::ComputationError("Cannot convert 2.0".to_string()))?;
65 let m = sum_all_weights / two;
66
67 if scirs2_core::SparseElement::is_zero(&m) {
68 return Ok((n, communities));
69 }
70
71 let mut improvement = true;
72 let mut iteration = 0;
73
74 while improvement && iteration < max_iter {
75 improvement = false;
76 iteration += 1;
77
78 for node in 0..n {
80 let current_community = communities[node];
81
82 let mut neighbor_community_set = HashSet::new();
84
85 for neighbor in 0..n {
86 let weight = graph.get(node, neighbor);
87 if !scirs2_core::SparseElement::is_zero(&weight) && neighbor != node {
88 neighbor_community_set.insert(communities[neighbor]);
89 }
90 }
91 neighbor_community_set.insert(current_community);
93
94 let mut neighbor_communities: Vec<usize> = neighbor_community_set.into_iter().collect();
96 neighbor_communities.sort();
97
98 let mut weight_to_current = T::sparse_zero();
101 let mut sigma_current = T::sparse_zero(); for i in 0..n {
103 if i != node && communities[i] == current_community {
104 let w = graph.get(node, i);
105 weight_to_current = weight_to_current + w;
106 sigma_current = sigma_current + degrees[i];
107 }
108 }
109
110 let k_i = degrees[node];
111
112 let remove_cost = weight_to_current - resolution * k_i * sigma_current / (two * m);
115
116 let mut best_community = current_community;
117 let mut best_delta = T::sparse_zero();
118
119 for &community in &neighbor_communities {
120 if community == current_community {
121 continue;
123 }
124
125 let mut weight_to_target = T::sparse_zero();
127 let mut sigma_target = T::sparse_zero();
128 for i in 0..n {
129 if communities[i] == community {
130 let w = graph.get(node, i);
131 weight_to_target = weight_to_target + w;
132 sigma_target = sigma_target + degrees[i];
133 }
134 }
135
136 let add_gain = weight_to_target - resolution * k_i * sigma_target / (two * m);
138
139 let delta = add_gain - remove_cost;
141
142 if delta > best_delta {
143 best_delta = delta;
144 best_community = community;
145 }
146 }
147
148 if best_community != current_community {
149 communities[node] = best_community;
150 improvement = true;
151 }
152 }
153 }
154
155 let community_map = renumber_communities(&communities);
157 let final_communities: Vec<usize> = communities.iter().map(|&c| community_map[&c]).collect();
158
159 let num_communities = community_map.len();
160
161 Ok((num_communities, final_communities))
162}
163
/// Maps every distinct community id in `communities` to a dense id in `0..k`.
///
/// Ids are assigned in order of first appearance, so the node with index 0 always lands in
/// community 0. The mapping is therefore deterministic, unlike one built from `HashSet`
/// iteration order, which is unspecified.
///
/// Returns an empty map for empty input.
fn renumber_communities(communities: &[usize]) -> HashMap<usize, usize> {
    let mut community_map = HashMap::new();

    for &old_id in communities {
        // `len()` is the next unused dense id; `or_insert` keeps the id of an already-seen label.
        let next_id = community_map.len();
        community_map.entry(old_id).or_insert(next_id);
    }

    community_map
}
175
176pub fn label_propagation<T, S>(graph: &S, max_iter: usize) -> SparseResult<(usize, Vec<usize>)>
190where
191 T: Float + SparseElement + Debug + Copy + 'static,
192 S: SparseArray<T>,
193{
194 let n = graph.shape().0;
195
196 if graph.shape().0 != graph.shape().1 {
197 return Err(SparseError::ValueError(
198 "Graph matrix must be square".to_string(),
199 ));
200 }
201
202 let mut labels = (0..n).collect::<Vec<_>>();
204
205 let mut changed = true;
206 let mut iteration = 0;
207
208 while changed && iteration < max_iter {
209 changed = false;
210 iteration += 1;
211
212 for node in 0..n {
214 let mut label_counts: HashMap<usize, T> = HashMap::new();
216
217 for neighbor in 0..n {
218 let weight = graph.get(node, neighbor);
219 if !scirs2_core::SparseElement::is_zero(&weight) && neighbor != node {
220 let neighbor_label = labels[neighbor];
221 let count = label_counts
222 .entry(neighbor_label)
223 .or_insert(T::sparse_zero());
224 *count = *count + weight;
225 }
226 }
227
228 if label_counts.is_empty() {
229 continue;
230 }
231
232 let mut best_label = labels[node];
234 let mut best_count = T::sparse_zero();
235
236 for (&label, &count) in &label_counts {
237 if count > best_count {
238 best_count = count;
239 best_label = label;
240 }
241 }
242
243 if best_label != labels[node] {
244 labels[node] = best_label;
245 changed = true;
246 }
247 }
248 }
249
250 let community_map = renumber_communities(&labels);
252 let final_communities: Vec<usize> = labels.iter().map(|&c| community_map[&c]).collect();
253 let num_communities = community_map.len();
254
255 Ok((num_communities, final_communities))
256}
257
258pub fn modularity<T, S>(graph: &S, communities: &[usize]) -> SparseResult<T>
271where
272 T: Float + SparseElement + Debug + Copy + std::iter::Sum + 'static,
273 S: SparseArray<T>,
274{
275 let n = graph.shape().0;
276
277 if graph.shape().0 != graph.shape().1 {
278 return Err(SparseError::ValueError(
279 "Graph matrix must be square".to_string(),
280 ));
281 }
282
283 if communities.len() != n {
284 return Err(SparseError::ValueError(
285 "Communities vector must match graph size".to_string(),
286 ));
287 }
288
289 let two = T::from(2.0)
290 .ok_or_else(|| SparseError::ComputationError("Cannot convert 2.0".to_string()))?;
291
292 let mut sum_all_weights = T::sparse_zero();
294 let mut degrees = vec![T::sparse_zero(); n];
295 for i in 0..n {
296 for j in 0..n {
297 let weight = graph.get(i, j);
298 if !scirs2_core::SparseElement::is_zero(&weight) {
299 degrees[i] = degrees[i] + weight;
300 sum_all_weights = sum_all_weights + weight;
301 }
302 }
303 }
304
305 let m = sum_all_weights / two;
307
308 if scirs2_core::SparseElement::is_zero(&m) {
309 return Ok(T::sparse_zero());
310 }
311
312 let two_m = two * m;
313
314 let mut q = T::sparse_zero();
316 for i in 0..n {
317 for j in 0..n {
318 if communities[i] == communities[j] {
319 let aij = graph.get(i, j);
320 let kikj = degrees[i] * degrees[j];
321 q = q + (aij - kikj / two_m);
322 }
323 }
324 }
325
326 q = q / two_m;
327
328 Ok(q)
329}
330
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;
    use std::collections::HashSet;

    /// Two triangles {0, 1, 2} and {3, 4, 5} joined by the bridge edge 2–3.
    ///
    /// Each undirected edge is stored in both directions, so the entries sum to 14 (2m = 14)
    /// and the weighted degrees are [2, 2, 3, 3, 2, 2].
    fn create_two_community_graph() -> CsrArray<f64> {
        let rows = vec![0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 2, 3];
        let cols = vec![1, 2, 0, 2, 0, 1, 4, 5, 3, 5, 3, 4, 3, 2];
        let data = vec![1.0; 14];

        CsrArray::from_triplets(&rows, &cols, &data, (6, 6), false).expect("Failed to create")
    }

    /// Asserts two floats agree up to accumulated rounding error.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-12,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Labels must be dense: exactly `num_communities` distinct ids, all below it.
    fn assert_dense_labels(num_communities: usize, labels: &[usize]) {
        let distinct: HashSet<usize> = labels.iter().copied().collect();
        assert_eq!(distinct.len(), num_communities);
        assert!(labels.iter().all(|&c| c < num_communities));
    }

    #[test]
    fn test_louvain_communities() {
        let graph = create_two_community_graph();
        let (num_communities, communities) = louvain_communities(&graph, 1.0, 10).expect("Failed");

        assert_eq!(communities.len(), 6);
        assert_eq!(num_communities, 2);
        assert_dense_labels(num_communities, &communities);
        // Each triangle forms one community and the single bridge does not merge them.
        assert!(communities[..3].iter().all(|&c| c == communities[0]));
        assert!(communities[3..].iter().all(|&c| c == communities[3]));
        assert_ne!(communities[0], communities[3]);
    }

    #[test]
    fn test_label_propagation() {
        let graph = create_two_community_graph();
        let (num_communities, communities) = label_propagation(&graph, 10).expect("Failed");

        assert_eq!(communities.len(), 6);
        assert!((1..=6).contains(&num_communities));
        assert_dense_labels(num_communities, &communities);
    }

    #[test]
    fn test_modularity() {
        let graph = create_two_community_graph();

        // Split at the bridge: 12 intra-community entries and community degree totals 7 and 7,
        // so Q = (12 - (7² + 7²) / 14) / 14 = 5/14.
        let q: f64 = modularity(&graph, &[0, 0, 0, 1, 1, 1]).expect("Failed");
        assert_close(q, 5.0 / 14.0);

        // Interleaved: only edges 0–2 and 3–5 are intra-community (4 entries), totals 7 and 7,
        // so Q = (4 - 98 / 14) / 14 = -3/14.
        let q_random: f64 = modularity(&graph, &[0, 1, 0, 1, 0, 1]).expect("Failed");
        assert_close(q_random, -3.0 / 14.0);

        assert!(q > q_random);
    }

    #[test]
    fn test_single_node_communities() {
        let graph = create_two_community_graph();

        // No self-loops, so only the null-model term remains: Q = -(Σ k_i²) / (2m)² = -34/196.
        let q: f64 = modularity(&graph, &[0, 1, 2, 3, 4, 5]).expect("Failed");
        assert_close(q, -34.0 / 196.0);
    }

    #[test]
    fn test_all_same_community() {
        let graph = create_two_community_graph();

        // One community holding everything always has Q = (2m - (2m)² / 2m) / 2m = 0.
        let q: f64 = modularity(&graph, &[0, 0, 0, 0, 0, 0]).expect("Failed");
        assert_close(q, 0.0);
    }

    #[test]
    fn test_modularity_rejects_mismatched_communities() {
        let graph = create_two_community_graph();
        assert!(modularity::<f64, _>(&graph, &[0, 0, 0]).is_err());
    }
}