// scirs2_graph/algorithms/community/infomap.rs
use super::modularity::modularity;
use super::types::CommunityResult;
use crate::base::{EdgeWeight, Graph, IndexType, Node};
use std::collections::BTreeSet;
use std::collections::HashMap;
use std::hash::Hash;

9#[derive(Debug, Clone)]
11pub struct InfomapResult<N: Node> {
12 pub node_communities: HashMap<N, usize>,
14 pub code_length: f64,
16 pub modularity: f64,
18}
19
20#[allow(dead_code)]
44pub fn infomap_communities<N, E, Ix>(
45 graph: &Graph<N, E, Ix>,
46 max_iterations: usize,
47 tolerance: f64,
48) -> InfomapResult<N>
49where
50 N: Node + Clone + Hash + Eq + std::fmt::Debug,
51 E: EdgeWeight + Into<f64> + Copy,
52 Ix: IndexType,
53{
54 let nodes: Vec<N> = graph.nodes().into_iter().cloned().collect();
55 let n = nodes.len();
56
57 if n == 0 {
58 return InfomapResult {
59 node_communities: HashMap::new(),
60 code_length: 0.0,
61 modularity: 0.0,
62 };
63 }
64
65 let (transition_matrix, node_weights) = build_transition_matrix(graph, &nodes);
67 let stationary_probs = compute_stationary_distribution(&transition_matrix, &node_weights);
68
69 let mut communities: HashMap<N, usize> = nodes
71 .iter()
72 .enumerate()
73 .map(|(i, node)| (node.clone(), i))
74 .collect();
75
76 let mut current_code_length = calculate_map_equation(
77 graph,
78 &communities,
79 &transition_matrix,
80 &stationary_probs,
81 &nodes,
82 );
83 let mut best_communities = communities.clone();
84 let mut best_code_length = current_code_length;
85
86 let mut rng = scirs2_core::random::rng();
87
88 for _iteration in 0..max_iterations {
89 let mut improved = false;
90
91 for node in &nodes {
93 let current_community = communities[node];
94 let mut best_community = current_community;
95 let mut best_local_code_length = current_code_length;
96
97 let mut neighboring_communities = std::collections::HashSet::new();
99 if let Ok(neighbors) = graph.neighbors(node) {
100 for neighbor in neighbors {
101 if let Some(&comm) = communities.get(&neighbor) {
102 neighboring_communities.insert(comm);
103 }
104 }
105 }
106
107 for &candidate_community in &neighboring_communities {
109 if candidate_community != current_community {
110 communities.insert(node.clone(), candidate_community);
111 let new_code_length = calculate_map_equation(
112 graph,
113 &communities,
114 &transition_matrix,
115 &stationary_probs,
116 &nodes,
117 );
118
119 if new_code_length < best_local_code_length {
120 best_local_code_length = new_code_length;
121 best_community = candidate_community;
122 }
123 }
124 }
125
126 if best_community != current_community {
128 communities.insert(node.clone(), best_community);
129 current_code_length = best_local_code_length;
130 improved = true;
131
132 if current_code_length < best_code_length {
133 best_code_length = current_code_length;
134 best_communities = communities.clone();
135 }
136 } else {
137 communities.insert(node.clone(), current_community);
139 }
140 }
141
142 if !improved || (best_code_length - current_code_length).abs() < tolerance {
144 break;
145 }
146 }
147
148 let mut community_map: HashMap<usize, usize> = HashMap::new();
150 let mut next_id = 0;
151 for &comm in best_communities.values() {
152 if let std::collections::hash_map::Entry::Vacant(e) = community_map.entry(comm) {
153 e.insert(next_id);
154 next_id += 1;
155 }
156 }
157
158 for (_, comm) in best_communities.iter_mut() {
160 *comm = community_map[comm];
161 }
162
163 let final_modularity = modularity(graph, &best_communities);
164
165 InfomapResult {
166 node_communities: best_communities,
167 code_length: best_code_length,
168 modularity: final_modularity,
169 }
170}
171
172#[allow(dead_code)]
174fn build_transition_matrix<N, E, Ix>(
175 graph: &Graph<N, E, Ix>,
176 nodes: &[N],
177) -> (Vec<Vec<f64>>, Vec<f64>)
178where
179 N: Node + std::fmt::Debug,
180 E: EdgeWeight + Into<f64> + Copy,
181 Ix: IndexType,
182{
183 let n = nodes.len();
184 let mut transition_matrix = vec![vec![0.0; n]; n];
185 let mut node_weights = vec![0.0; n];
186
187 let node_to_idx: HashMap<&N, usize> = nodes.iter().enumerate().map(|(i, n)| (n, i)).collect();
189
190 for (i, node) in nodes.iter().enumerate() {
192 let mut total_weight = 0.0;
193
194 if let Ok(neighbors) = graph.neighbors(node) {
196 for neighbor in neighbors {
197 if let Ok(weight) = graph.edge_weight(node, &neighbor) {
198 total_weight += weight.into();
199 }
200 }
201 }
202
203 node_weights[i] = total_weight;
204
205 if total_weight > 0.0 {
207 if let Ok(neighbors) = graph.neighbors(node) {
208 for neighbor in neighbors {
209 if let Some(&j) = node_to_idx.get(&neighbor) {
210 if let Ok(weight) = graph.edge_weight(node, &neighbor) {
211 transition_matrix[i][j] = weight.into() / total_weight;
212 }
213 }
214 }
215 }
216 } else {
217 for j in 0..n {
219 transition_matrix[i][j] = 1.0 / n as f64;
220 }
221 }
222 }
223
224 (transition_matrix, node_weights)
225}
226
/// Computes the stationary distribution `π` of the random walk given by `transition_matrix`.
///
/// The matrix is row-stochastic: `transition_matrix[j][i]` is the probability of stepping
/// from state `j` to state `i`.
///
/// Power iteration starts from `node_weights` normalised to sum to 1. It starts from the
/// uniform vector instead if the weights sum to zero or their length differs from the
/// matrix size. Each step applies the *lazy* walk `π ← ½π + ½πT`. This has exactly the
/// same stationary distribution as `T` but is aperiodic, so the iteration also converges
/// on periodic chains (e.g. a directed 2-cycle), where plain `π ← πT` would oscillate
/// forever.
///
/// Iteration stops when the L1 change between steps drops below `1e-10`, or after 1000
/// steps. An empty matrix yields an empty vector.
#[allow(dead_code)]
fn compute_stationary_distribution(
    transition_matrix: &[Vec<f64>],
    node_weights: &[f64],
) -> Vec<f64> {
    const MAX_ITERATIONS: usize = 1000;
    const CONVERGENCE_TOLERANCE: f64 = 1e-10;
    // Probability of staying put in one lazy step.
    const LAZINESS: f64 = 0.5;

    let n = transition_matrix.len();
    if n == 0 {
        return Vec::new();
    }

    let total_weight: f64 = node_weights.iter().sum();
    let mut pi: Vec<f64> = if node_weights.len() == n && total_weight > 0.0 {
        node_weights.iter().map(|&w| w / total_weight).collect()
    } else {
        vec![1.0 / n as f64; n]
    };

    for _ in 0..MAX_ITERATIONS {
        // next = LAZINESS * pi + (1 - LAZINESS) * pi * T
        let mut next: Vec<f64> = pi.iter().map(|&p| LAZINESS * p).collect();
        for (&p_j, row) in pi.iter().zip(transition_matrix) {
            let moving = (1.0 - LAZINESS) * p_j;
            for (next_i, &t_ji) in next.iter_mut().zip(row) {
                *next_i += moving * t_ji;
            }
        }

        // Renormalise to counter floating-point drift.
        let sum: f64 = next.iter().sum();
        if sum > 0.0 {
            for p in next.iter_mut() {
                *p /= sum;
            }
        }

        let diff: f64 = pi.iter().zip(&next).map(|(old, new)| (old - new).abs()).sum();
        pi = next;

        if diff < CONVERGENCE_TOLERANCE {
            break;
        }
    }

    pi
}

281#[allow(dead_code)]
283fn calculate_map_equation<N, E, Ix>(
284 graph: &Graph<N, E, Ix>,
285 communities: &HashMap<N, usize>,
286 transition_matrix: &[Vec<f64>],
287 stationary_probs: &[f64],
288 nodes: &[N],
289) -> f64
290where
291 N: Node + std::fmt::Debug,
292 E: EdgeWeight + Into<f64> + Copy,
293 Ix: IndexType,
294{
295 let n = nodes.len();
296 if n == 0 {
297 return 0.0;
298 }
299
300 let node_to_idx: HashMap<&N, usize> = nodes.iter().enumerate().map(|(i, n)| (n, i)).collect();
302
303 let mut community_exit_prob: HashMap<usize, f64> = HashMap::new();
305 let mut community_flow: HashMap<usize, f64> = HashMap::new();
306
307 for &comm in communities.values() {
309 community_exit_prob.insert(comm, 0.0);
310 community_flow.insert(comm, 0.0);
311 }
312
313 for (node, &comm) in communities {
315 if let Some(&i) = node_to_idx.get(node) {
316 let pi_i = stationary_probs[i];
317 *community_flow.get_mut(&comm).unwrap() += pi_i;
318
319 if let Ok(neighbors) = graph.neighbors(node) {
321 for neighbor in neighbors {
322 if let Some(&neighbor_comm) = communities.get(&neighbor) {
323 if neighbor_comm != comm {
324 if let Some(&j) = node_to_idx.get(&neighbor) {
325 *community_exit_prob.get_mut(&comm).unwrap() +=
326 pi_i * transition_matrix[i][j];
327 }
328 }
329 }
330 }
331 }
332 }
333 }
334
335 let mut code_length = 0.0;
337
338 let total_exit_flow: f64 = community_exit_prob.values().sum();
340 if total_exit_flow > 0.0 {
341 for &q_alpha in community_exit_prob.values() {
342 if q_alpha > 0.0 {
343 code_length -= q_alpha * (q_alpha / total_exit_flow).ln();
344 }
345 }
346 }
347
348 for (&comm, &q_alpha) in &community_exit_prob {
350 let p_alpha = community_flow[&comm];
351 let total_alpha = q_alpha + p_alpha;
352
353 if total_alpha > 0.0 {
354 let mut h_alpha = 0.0;
356
357 if q_alpha > 0.0 {
359 h_alpha -= (q_alpha / total_alpha) * (q_alpha / total_alpha).ln();
360 }
361
362 for (node, &node_comm) in communities {
364 if node_comm == comm {
365 if let Some(&i) = node_to_idx.get(node) {
366 let pi_i = stationary_probs[i];
367 if pi_i > 0.0 {
368 let prob_in_module = pi_i / total_alpha;
369 h_alpha -= prob_in_module * prob_in_module.ln();
370 }
371 }
372 }
373 }
374
375 code_length += total_alpha * h_alpha;
376 }
377 }
378
379 code_length
380}