Skip to main content

uni_algo/algo/algorithms/
louvain.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright 2024-2026 Dragonscale Team

//! Louvain Community Detection Algorithm.
5
6use crate::algo::GraphProjection;
7use crate::algo::algorithms::Algorithm;
8use std::collections::HashMap;
9use uni_common::core::id::Vid;
10
/// Louvain community detection, exposed through the [`Algorithm`] trait.
///
/// This is a stateless marker type; all tuning is passed in via [`LouvainConfig`].
#[derive(Debug, Clone, Copy)]
pub struct Louvain;
12
/// Tuning parameters for a [`Louvain`] run.
#[derive(Debug, Clone, PartialEq)]
pub struct LouvainConfig {
    /// Multiplier applied to the degree-product penalty term, both in the
    /// per-move gain and in the final modularity score.
    pub resolution: f64,
    /// Upper bound on the number of full sweeps over all vertices. The run
    /// stops earlier as soon as a sweep moves no vertex.
    pub max_iterations: usize,
    /// A vertex changes community only when the gain of its best candidate
    /// move is strictly greater than this value.
    pub min_modularity_gain: f64,
}

impl Default for LouvainConfig {
    /// Returns `resolution = 1.0`, `max_iterations = 10` and
    /// `min_modularity_gain = 1e-4`.
    fn default() -> Self {
        Self {
            resolution: 1.0,
            max_iterations: 10,
            min_modularity_gain: 1e-4,
        }
    }
}
29
30pub struct LouvainResult {
31    pub communities: Vec<(Vid, u64)>,
32    pub modularity: f64,
33    pub community_count: usize,
34}
35
36impl Algorithm for Louvain {
37    type Config = LouvainConfig;
38    type Result = LouvainResult;
39
40    fn name() -> &'static str {
41        "louvain"
42    }
43
44    fn run(graph: &GraphProjection, config: Self::Config) -> Self::Result {
45        let n = graph.vertex_count();
46        if n == 0 {
47            return LouvainResult {
48                communities: Vec::new(),
49                modularity: 0.0,
50                community_count: 0,
51            };
52        }
53
54        // Initialize: each node in its own community
55        let mut community: Vec<u32> = (0..n as u32).collect();
56
57        // Total edge weight (m)
58        // For unweighted graph, m = edge_count / 2 (since each edge is counted twice if bidirectional)
59        // But Uni GraphProjection might be directed.
60        // Louvain usually works on undirected graphs.
61        // We treat it as undirected by summing all out_degrees.
62        let mut m: f64 = 0.0;
63        let mut node_weights = vec![0.0; n];
64        for v in 0..n as u32 {
65            let mut deg = graph.out_degree(v) as f64;
66            if graph.has_reverse() {
67                deg += graph.in_degree(v) as f64;
68            }
69            m += deg;
70            node_weights[v as usize] = deg;
71        }
72        m /= 2.0;
73
74        if m == 0.0 {
75            return LouvainResult {
76                communities: community
77                    .into_iter()
78                    .enumerate()
79                    .map(|(i, c)| (graph.to_vid(i as u32), c as u64))
80                    .collect(),
81                modularity: 0.0,
82                community_count: n,
83            };
84        }
85
86        // Track community total weights (Sigma_tot)
87        let mut community_weights = node_weights.clone();
88
89        for _ in 0..config.max_iterations {
90            let mut improved = false;
91
92            // Phase 1: Local moves
93            for v in 0..n as u32 {
94                let v_idx = v as usize;
95                let current_comm = community[v_idx];
96                let v_weight = node_weights[v_idx];
97
98                // Find neighbor communities and weights to them (k_i,in)
99                let mut neighbor_comm_weights: HashMap<u32, f64> = HashMap::new();
100                for &u in graph.out_neighbors(v) {
101                    let u_comm = community[u as usize];
102                    *neighbor_comm_weights.entry(u_comm).or_insert(0.0) += 1.0;
103                }
104                if graph.has_reverse() {
105                    for &u in graph.in_neighbors(v) {
106                        let u_comm = community[u as usize];
107                        *neighbor_comm_weights.entry(u_comm).or_insert(0.0) += 1.0;
108                    }
109                }
110
111                let mut best_comm = current_comm;
112                let mut max_gain = 0.0;
113
114                // Remove v from current community
115                community_weights[current_comm as usize] -= v_weight;
116
117                for (&target_comm, &k_i_in) in &neighbor_comm_weights {
118                    let target_comm_weight = community_weights[target_comm as usize];
119
120                    // Modularity gain formula:
121                    // delta_Q = [ (Sigma_tot + k_i,in) / 2m - ((Sigma_tot + k_i)/2m)^2 ] - [ Sigma_tot/2m - (Sigma_tot/2m)^2 - (k_i/2m)^2 ]
122                    // Simplified: delta_Q = (1/2m) * (k_i,in - (Sigma_tot * k_i) / m)
123                    let gain =
124                        k_i_in - (target_comm_weight * v_weight * config.resolution) / (2.0 * m);
125
126                    if gain > max_gain {
127                        max_gain = gain;
128                        best_comm = target_comm;
129                    }
130                }
131
132                if max_gain > config.min_modularity_gain && best_comm != current_comm {
133                    community[v_idx] = best_comm;
134                    improved = true;
135                }
136
137                // Add v to best community
138                community_weights[community[v_idx] as usize] += v_weight;
139            }
140
141            if !improved {
142                break;
143            }
144        }
145
146        // Final modularity calculation
147        let q = compute_modularity(graph, &community, m, config.resolution);
148
149        // Map back to VIDs and renumber communities
150        let mut comm_map: HashMap<u32, u64> = HashMap::new();
151        let mut next_id = 0u64;
152        let mut results = Vec::with_capacity(n);
153        for (i, &comm) in community.iter().enumerate() {
154            let id = *comm_map.entry(comm).or_insert_with(|| {
155                let val = next_id;
156                next_id += 1;
157                val
158            });
159            results.push((graph.to_vid(i as u32), id));
160        }
161
162        LouvainResult {
163            communities: results,
164            modularity: q,
165            community_count: comm_map.len(),
166        }
167    }
168}
169
170fn compute_modularity(graph: &GraphProjection, community: &[u32], m: f64, resolution: f64) -> f64 {
171    let n = graph.vertex_count();
172    let mut q = 0.0;
173
174    // Sum over communities
175    let mut comm_internal_weights: HashMap<u32, f64> = HashMap::new();
176    let mut comm_total_weights: HashMap<u32, f64> = HashMap::new();
177
178    for v in 0..n as u32 {
179        let v_comm = community[v as usize];
180        let v_deg = graph.out_degree(v) as f64;
181        *comm_total_weights.entry(v_comm).or_insert(0.0) += v_deg;
182
183        for &u in graph.out_neighbors(v) {
184            if community[u as usize] == v_comm {
185                *comm_internal_weights.entry(v_comm).or_insert(0.0) += 1.0;
186            }
187        }
188    }
189
190    for (&comm, &internal) in &comm_internal_weights {
191        let total = comm_total_weights[&comm];
192        q += (internal / (2.0 * m)) - resolution * (total / (2.0 * m)).powi(2);
193    }
194
195    q
196}