Skip to main content

sqlite_knowledge_graph/algorithms/
louvain.rs

1use crate::error::Result;
2/// Louvain community detection algorithm (full two-phase)
3use rusqlite::Connection;
4use std::collections::HashMap;
5
/// Outcome of a community-detection run.
///
/// `PartialEq` lets whole results be compared in assertions; `Default` yields
/// the "empty graph" result (no memberships, zero communities, modularity
/// `0.0`).
#[derive(Debug, Clone, Default, PartialEq)]
pub struct CommunityResult {
    /// `(entity_id, community_id)` pairs, one per entity that took part in the
    /// detection.
    pub memberships: Vec<(i64, i32)>,
    /// Number of distinct communities referenced by `memberships`.
    pub num_communities: i32,
    /// Modularity score of the reported partition.
    pub modularity: f64,
}
16
17/// Compute communities using the full two-phase Louvain algorithm.
18///
19/// - Phase 1: each node greedily moves to the adjacent community that
20///   maximises modularity gain (local optimisation).
21/// - Phase 2: all communities are collapsed into super-nodes; the graph is
22///   rebuilt with aggregated edge weights and Phase 1 is repeated.
23///
24/// The two phases alternate until no further improvement is possible.
25pub fn louvain_communities(conn: &Connection) -> Result<CommunityResult> {
26    // ── Build initial weighted graph from DB ──────────────────────────────
27    let mut init_graph: HashMap<i64, HashMap<i64, f64>> = HashMap::new();
28    let mut total_weight = 0.0;
29
30    let mut stmt = conn.prepare("SELECT source_id, target_id, weight FROM kg_relations")?;
31    let rows = stmt.query_map([], |row| {
32        Ok((
33            row.get::<_, i64>(0)?,
34            row.get::<_, i64>(1)?,
35            row.get::<_, f64>(2)?,
36        ))
37    })?;
38
39    for row in rows {
40        let (from, to, weight) = row?;
41        *init_graph.entry(from).or_default().entry(to).or_default() += weight;
42        init_graph.entry(to).or_default(); // ensure target node exists
43        total_weight += weight;
44    }
45
46    if init_graph.is_empty() {
47        return Ok(CommunityResult {
48            memberships: Vec::new(),
49            num_communities: 0,
50            modularity: 0.0,
51        });
52    }
53
54    // Stable ordering so tests are deterministic
55    let orig_nodes: Vec<i64> = {
56        let mut v: Vec<i64> = init_graph.keys().copied().collect();
57        v.sort_unstable();
58        v
59    };
60    let n = orig_nodes.len();
61    let id_to_idx: HashMap<i64, usize> = orig_nodes
62        .iter()
63        .enumerate()
64        .map(|(i, &id)| (id, i))
65        .collect();
66
67    // Convert to usize-keyed graph (work_graph nodes = 0..n-1 initially)
68    let mut work_graph: HashMap<usize, HashMap<usize, f64>> = HashMap::new();
69    for (&from, edges) in &init_graph {
70        let fi = id_to_idx[&from];
71        work_graph.entry(fi).or_default();
72        for (&to, &w) in edges {
73            let ti = id_to_idx[&to];
74            *work_graph.entry(fi).or_default().entry(ti).or_default() += w;
75        }
76    }
77
78    // orig_community[i] = final community index for original node i
79    let mut orig_community: Vec<usize> = (0..n).collect();
80
81    // sn_members[sn] = original node indices belonging to super-node sn
82    let mut sn_members: Vec<Vec<usize>> = (0..n).map(|i| vec![i]).collect();
83
84    // ── Alternating Phase 1 / Phase 2 ─────────────────────────────────────
85    loop {
86        let m = sn_members.len(); // current number of super-nodes / work-nodes
87
88        // Phase 1 ── greedy local moves
89        // Initial assignment: each super-node in its own community
90        let mut community: Vec<usize> = (0..m).collect();
91        let work_nodes: Vec<usize> = (0..m).collect();
92
93        let mut any_improved = false;
94        let mut phase_improved = true;
95        let mut iter = 0;
96
97        while phase_improved && iter < 100 {
98            phase_improved = false;
99            iter += 1;
100
101            for &node in &work_nodes {
102                let cur_comm = community[node];
103
104                let neighbors: Vec<usize> = work_graph
105                    .get(&node)
106                    .map(|e| e.keys().copied().collect())
107                    .unwrap_or_default();
108
109                let mut best_comm = cur_comm;
110                let mut best_gain = 0.0_f64;
111
112                for &nbr in &neighbors {
113                    let nbr_comm = community[nbr];
114                    if nbr_comm == cur_comm {
115                        continue;
116                    }
117
118                    let gain =
119                        modularity_gain(&work_graph, node, nbr_comm, &community, total_weight);
120                    if gain > best_gain {
121                        best_gain = gain;
122                        best_comm = nbr_comm;
123                    }
124                }
125
126                if best_comm != cur_comm {
127                    community[node] = best_comm;
128                    phase_improved = true;
129                    any_improved = true;
130                }
131            }
132        }
133
134        if !any_improved {
135            break; // converged globally
136        }
137
138        // Renumber communities to 0..num_new-1
139        let mut unique_comms: Vec<usize> = community.clone();
140        unique_comms.sort_unstable();
141        unique_comms.dedup();
142        let num_new = unique_comms.len();
143
144        // comm_remap[old_community_id] = new_community_id
145        // All community IDs are in 0..m-1, so a Vec of size m works.
146        let mut comm_remap = vec![0usize; m];
147        for (new_id, &old_comm) in unique_comms.iter().enumerate() {
148            comm_remap[old_comm] = new_id;
149        }
150
151        // Propagate community assignments back to original nodes
152        for (sn, members) in sn_members.iter().enumerate() {
153            let new_comm = comm_remap[community[sn]];
154            for &orig in members {
155                orig_community[orig] = new_comm;
156            }
157        }
158
159        if num_new == m {
160            // Phase 1 didn't reduce the number of communities → done
161            break;
162        }
163
164        // Phase 2 ── aggregate into super-nodes
165        let mut new_sn_members: Vec<Vec<usize>> = vec![Vec::new(); num_new];
166        for (sn, members) in sn_members.iter().enumerate() {
167            let new_sn = comm_remap[community[sn]];
168            new_sn_members[new_sn].extend_from_slice(members);
169        }
170
171        let mut new_graph: HashMap<usize, HashMap<usize, f64>> =
172            (0..num_new).map(|i| (i, HashMap::new())).collect();
173        for (&from_sn, edges) in &work_graph {
174            let from_new = comm_remap[community[from_sn]];
175            for (&to_sn, &weight) in edges {
176                let to_new = comm_remap[community[to_sn]];
177                // Self-loops are included; they affect total degree but not ΔQ.
178                *new_graph
179                    .entry(from_new)
180                    .or_default()
181                    .entry(to_new)
182                    .or_default() += weight;
183            }
184        }
185
186        work_graph = new_graph;
187        sn_members = new_sn_members;
188    }
189
190    // ── Build final result ─────────────────────────────────────────────────
191    // Assign consecutive i32 community IDs for the public API
192    let mut comm_to_final: HashMap<usize, i32> = HashMap::new();
193    let mut next_id = 0i32;
194
195    let memberships: Vec<(i64, i32)> = orig_nodes
196        .iter()
197        .enumerate()
198        .map(|(i, &entity_id)| {
199            let comm = orig_community[i];
200            let final_comm = *comm_to_final.entry(comm).or_insert_with(|| {
201                let id = next_id;
202                next_id += 1;
203                id
204            });
205            (entity_id, final_comm)
206        })
207        .collect();
208
209    let num_communities = next_id;
210
211    // Compute modularity on the original (unaggregated) graph
212    let final_comm_map: HashMap<i64, usize> = orig_nodes
213        .iter()
214        .enumerate()
215        .map(|(i, &id)| (id, orig_community[i]))
216        .collect();
217    let modularity = compute_modularity(&init_graph, &final_comm_map, total_weight);
218
219    Ok(CommunityResult {
220        memberships,
221        num_communities,
222        modularity,
223    })
224}
225
/// Modularity gain ΔQ obtained by placing `node` into `target_community`.
///
/// ΔQ = k_{i,in} / m  −  Σ_tot · k_i / (2m²)
///
/// * `m` — `total_weight`, the total edge weight of the graph;
/// * `k_i` — sum of all edge weights stored for `node`;
/// * `k_{i,in}` — sum of the weights of `node`'s edges whose other end is
///   currently assigned to `target_community`;
/// * `Σ_tot` — summed degrees of every node currently assigned to
///   `target_community`, with `node` itself left out.
///
/// `community` is indexed by node id. Returns `0.0` when `total_weight` is
/// zero; a `node` with no adjacency entry has zero degree.
fn modularity_gain(
    graph: &HashMap<usize, HashMap<usize, f64>>,
    node: usize,
    target_community: usize,
    community: &[usize],
    total_weight: f64,
) -> f64 {
    if total_weight == 0.0 {
        return 0.0;
    }

    let in_target = |id: usize| community[id] == target_community;

    // Degree of `node` and the part of it that points into the target.
    let (degree, weight_into_target): (f64, f64) = match graph.get(&node) {
        Some(adjacent) => (
            adjacent.values().sum(),
            adjacent
                .iter()
                .filter(|&(&other, _)| in_target(other))
                .map(|(_, &w)| w)
                .sum(),
        ),
        None => (0.0, 0.0),
    };

    // Degrees of the nodes already sitting in the target community; `node`
    // is skipped so it never counts against itself.
    let target_degree_sum: f64 = graph
        .iter()
        .filter(|&(&member, _)| member != node && in_target(member))
        .map(|(_, adjacent)| adjacent.values().sum::<f64>())
        .sum();

    weight_into_target / total_weight
        - target_degree_sum * degree / (2.0 * total_weight * total_weight)
}
270
/// Compute the modularity Q of a community assignment on the original graph.
///
/// Q = Σ_c [ L_c / m − (D_c / 2m)² ]
///
/// * `m` — `total_weight`;
/// * `L_c` — weight of the edges whose two endpoints both lie in community `c`;
/// * `D_c` — summed degree of community `c`, where every stored edge adds its
///   weight to the community of each of its two endpoints (edges are treated
///   as undirected).
///
/// The second term is the weight expected inside `c` if edges were placed at
/// random with the same degrees; without it a single all-encompassing
/// community would score 1 instead of 0.
///
/// Nodes missing from `community` all look up as `None` and are therefore
/// grouped into one implicit community. Returns `0.0` when `total_weight` is
/// zero.
fn compute_modularity(
    graph: &HashMap<i64, HashMap<i64, f64>>,
    community: &HashMap<i64, usize>,
    total_weight: f64,
) -> f64 {
    if total_weight == 0.0 {
        return 0.0;
    }

    // Per community: (intra-community weight L_c, total degree D_c).
    let mut totals: HashMap<Option<usize>, (f64, f64)> = HashMap::new();
    for (&from, edges) in graph {
        let from_comm = community.get(&from).copied();
        for (&to, &weight) in edges {
            let to_comm = community.get(&to).copied();
            totals.entry(from_comm).or_default().1 += weight;
            totals.entry(to_comm).or_default().1 += weight;
            if from_comm == to_comm {
                totals.entry(from_comm).or_default().0 += weight;
            }
        }
    }

    let two_m = 2.0 * total_weight;
    totals
        .values()
        .map(|&(intra, degree)| intra / total_weight - (degree / two_m).powi(2))
        .sum()
}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::entity::{insert_entity, Entity};
    use crate::graph::relation::{insert_relation, Relation};

    /// Fresh in-memory database with the knowledge-graph schema installed.
    fn empty_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        conn
    }

    /// Six nodes forming two chains, 1-2-3 and 4-5-6 (weight 1.0 each),
    /// joined by a weak 3→4 link of weight 0.1.
    fn setup_test_db() -> Connection {
        let conn = empty_db();

        let ids: Vec<_> = ["Node 1", "Node 2", "Node 3", "Node 4", "Node 5", "Node 6"]
            .iter()
            .map(|&label| insert_entity(&conn, &Entity::new("node", label)).unwrap())
            .collect();

        // (source index, target index, weight) — the weak cross-link goes last.
        let links = [
            (0usize, 1usize, 1.0),
            (1, 2, 1.0),
            (3, 4, 1.0),
            (4, 5, 1.0),
            (2, 3, 0.1),
        ];
        for &(src, dst, weight) in &links {
            insert_relation(
                &conn,
                &Relation::new(ids[src], ids[dst], "link", weight).unwrap(),
            )
            .unwrap();
        }

        conn
    }

    #[test]
    fn test_louvain() {
        let detected = louvain_communities(&setup_test_db()).unwrap();

        assert!(detected.num_communities >= 1);
        assert_eq!(detected.memberships.len(), 6);
        // Strong intra-cluster edges plus one weak bridge: at most 2 communities
        assert!(detected.num_communities <= 2);
    }

    #[test]
    fn test_empty_graph() {
        let detected = louvain_communities(&empty_db()).unwrap();
        assert_eq!(detected.num_communities, 0);
    }

    #[test]
    fn test_single_community() {
        let conn = empty_db();

        // Triangle A-B-C with unit weights on every side
        let ids: Vec<_> = ["A", "B", "C"]
            .iter()
            .map(|&label| insert_entity(&conn, &Entity::new("node", label)).unwrap())
            .collect();
        for &(src, dst) in &[(0usize, 1usize), (1, 2), (0, 2)] {
            insert_relation(
                &conn,
                &Relation::new(ids[src], ids[dst], "link", 1.0).unwrap(),
            )
            .unwrap();
        }

        let detected = louvain_communities(&conn).unwrap();
        assert_eq!(detected.memberships.len(), 3);
        assert!(detected.num_communities >= 1);
    }
}