sqlite_knowledge_graph/algorithms/
louvain.rs1use crate::error::Result;
2use rusqlite::Connection;
4use std::collections::HashMap;
5
/// Outcome of a community-detection run over the knowledge graph.
#[derive(Clone, Debug)]
pub struct CommunityResult {
    /// `(entity_id, community_id)` pairs, one per partitioned entity.
    pub memberships: Vec<(i64, i32)>,
    /// Number of distinct communities referenced by `memberships`.
    pub num_communities: i32,
    /// Modularity score reported for the partition.
    pub modularity: f64,
}
16
17pub fn louvain_communities(conn: &Connection) -> Result<CommunityResult> {
26 let mut init_graph: HashMap<i64, HashMap<i64, f64>> = HashMap::new();
28 let mut total_weight = 0.0;
29
30 let mut stmt = conn.prepare("SELECT source_id, target_id, weight FROM kg_relations")?;
31 let rows = stmt.query_map([], |row| {
32 Ok((
33 row.get::<_, i64>(0)?,
34 row.get::<_, i64>(1)?,
35 row.get::<_, f64>(2)?,
36 ))
37 })?;
38
39 for row in rows {
40 let (from, to, weight) = row?;
41 *init_graph.entry(from).or_default().entry(to).or_default() += weight;
42 init_graph.entry(to).or_default(); total_weight += weight;
44 }
45
46 if init_graph.is_empty() {
47 return Ok(CommunityResult {
48 memberships: Vec::new(),
49 num_communities: 0,
50 modularity: 0.0,
51 });
52 }
53
54 let orig_nodes: Vec<i64> = {
56 let mut v: Vec<i64> = init_graph.keys().copied().collect();
57 v.sort_unstable();
58 v
59 };
60 let n = orig_nodes.len();
61 let id_to_idx: HashMap<i64, usize> = orig_nodes
62 .iter()
63 .enumerate()
64 .map(|(i, &id)| (id, i))
65 .collect();
66
67 let mut work_graph: HashMap<usize, HashMap<usize, f64>> = HashMap::new();
69 for (&from, edges) in &init_graph {
70 let fi = id_to_idx[&from];
71 work_graph.entry(fi).or_default();
72 for (&to, &w) in edges {
73 let ti = id_to_idx[&to];
74 *work_graph.entry(fi).or_default().entry(ti).or_default() += w;
75 }
76 }
77
78 let mut orig_community: Vec<usize> = (0..n).collect();
80
81 let mut sn_members: Vec<Vec<usize>> = (0..n).map(|i| vec![i]).collect();
83
84 loop {
86 let m = sn_members.len(); let mut community: Vec<usize> = (0..m).collect();
91 let work_nodes: Vec<usize> = (0..m).collect();
92
93 let mut any_improved = false;
94 let mut phase_improved = true;
95 let mut iter = 0;
96
97 while phase_improved && iter < 100 {
98 phase_improved = false;
99 iter += 1;
100
101 for &node in &work_nodes {
102 let cur_comm = community[node];
103
104 let neighbors: Vec<usize> = work_graph
105 .get(&node)
106 .map(|e| e.keys().copied().collect())
107 .unwrap_or_default();
108
109 let mut best_comm = cur_comm;
110 let mut best_gain = 0.0_f64;
111
112 for &nbr in &neighbors {
113 let nbr_comm = community[nbr];
114 if nbr_comm == cur_comm {
115 continue;
116 }
117
118 let gain =
119 modularity_gain(&work_graph, node, nbr_comm, &community, total_weight);
120 if gain > best_gain {
121 best_gain = gain;
122 best_comm = nbr_comm;
123 }
124 }
125
126 if best_comm != cur_comm {
127 community[node] = best_comm;
128 phase_improved = true;
129 any_improved = true;
130 }
131 }
132 }
133
134 if !any_improved {
135 break; }
137
138 let mut unique_comms: Vec<usize> = community.clone();
140 unique_comms.sort_unstable();
141 unique_comms.dedup();
142 let num_new = unique_comms.len();
143
144 let mut comm_remap = vec![0usize; m];
147 for (new_id, &old_comm) in unique_comms.iter().enumerate() {
148 comm_remap[old_comm] = new_id;
149 }
150
151 for (sn, members) in sn_members.iter().enumerate() {
153 let new_comm = comm_remap[community[sn]];
154 for &orig in members {
155 orig_community[orig] = new_comm;
156 }
157 }
158
159 if num_new == m {
160 break;
162 }
163
164 let mut new_sn_members: Vec<Vec<usize>> = vec![Vec::new(); num_new];
166 for (sn, members) in sn_members.iter().enumerate() {
167 let new_sn = comm_remap[community[sn]];
168 new_sn_members[new_sn].extend_from_slice(members);
169 }
170
171 let mut new_graph: HashMap<usize, HashMap<usize, f64>> =
172 (0..num_new).map(|i| (i, HashMap::new())).collect();
173 for (&from_sn, edges) in &work_graph {
174 let from_new = comm_remap[community[from_sn]];
175 for (&to_sn, &weight) in edges {
176 let to_new = comm_remap[community[to_sn]];
177 *new_graph
179 .entry(from_new)
180 .or_default()
181 .entry(to_new)
182 .or_default() += weight;
183 }
184 }
185
186 work_graph = new_graph;
187 sn_members = new_sn_members;
188 }
189
190 let mut comm_to_final: HashMap<usize, i32> = HashMap::new();
193 let mut next_id = 0i32;
194
195 let memberships: Vec<(i64, i32)> = orig_nodes
196 .iter()
197 .enumerate()
198 .map(|(i, &entity_id)| {
199 let comm = orig_community[i];
200 let final_comm = *comm_to_final.entry(comm).or_insert_with(|| {
201 let id = next_id;
202 next_id += 1;
203 id
204 });
205 (entity_id, final_comm)
206 })
207 .collect();
208
209 let num_communities = next_id;
210
211 let final_comm_map: HashMap<i64, usize> = orig_nodes
213 .iter()
214 .enumerate()
215 .map(|(i, &id)| (id, orig_community[i]))
216 .collect();
217 let modularity = compute_modularity(&init_graph, &final_comm_map, total_weight);
218
219 Ok(CommunityResult {
220 memberships,
221 num_communities,
222 modularity,
223 })
224}
225
/// Modularity change from inserting `node`, viewed as an isolated node, into
/// `target_community`.
///
/// Uses the Louvain insertion formula `ΔQ = k_i,in / m − Σ_tot · k_i / (2m²)`, where
/// * `k_i` is the sum of `node`'s adjacency weights (its weighted degree),
/// * `k_i,in` is the weight from `node` to members of `target_community`, excluding
///   `node`'s own self-loop,
/// * `Σ_tot` is the summed weighted degree of the other members of `target_community`,
/// * `m` is `total_weight`.
///
/// This is the undirected form: it expects `graph` to store every edge in both
/// directions and `total_weight` to be the total undirected edge weight. `node` itself
/// is always left out of `target_community`, so for its current community the result is
/// the gain of putting it back after removal. `community[x]` is the label of node `x`;
/// every node id appearing in `graph` must be a valid index into `community`.
///
/// Returns `0.0` when `total_weight` is zero.
fn modularity_gain(
    graph: &HashMap<usize, HashMap<usize, f64>>,
    node: usize,
    target_community: usize,
    community: &[usize],
    total_weight: f64,
) -> f64 {
    if total_weight == 0.0 {
        return 0.0;
    }
    let m = total_weight;

    // Weighted degree of `node` and the part of it pointing into the target. The
    // self-loop is excluded from `k_i_in` for the same reason `node` is excluded from
    // `k_tot` below: the node is treated as already removed from every community.
    let (k_i, k_i_in) = graph.get(&node).map_or((0.0, 0.0), |edges| {
        edges
            .iter()
            .fold((0.0_f64, 0.0_f64), |(deg, inside), (&nbr, &w)| {
                if nbr != node && community[nbr] == target_community {
                    (deg + w, inside + w)
                } else {
                    (deg + w, inside)
                }
            })
    });

    let k_tot: f64 = graph
        .iter()
        .filter(|(&id, _)| id != node && community[id] == target_community)
        .map(|(_, edges)| edges.values().sum::<f64>())
        .sum();

    k_i_in / m - k_tot * k_i / (2.0 * m * m)
}
270
/// Newman–Girvan modularity of a partition: `Q = Σ_c [ L_c / W − (D_c / 2W)² ]`.
///
/// Every `(from, to, weight)` entry in `graph` is read as one undirected edge. `W` is
/// the sum of all entry weights, `L_c` is the weight of entries with both endpoints in
/// community `c`, and `D_c` is the summed weighted degree of the nodes in `c`. Every term
/// scales with the stored weight, so a graph stored one way and the same graph stored
/// symmetrically (each edge in both directions) give the same value.
///
/// Nodes missing from `community` still count toward `W` but are never intra-community.
/// `total_weight` only short-circuits zero-weight graphs. The normalisation is derived
/// from `graph` itself so it always matches the representation passed in.
///
/// Returns `0.0` for a graph with zero total weight.
fn compute_modularity(
    graph: &HashMap<i64, HashMap<i64, f64>>,
    community: &HashMap<i64, usize>,
    total_weight: f64,
) -> f64 {
    if total_weight == 0.0 {
        return 0.0;
    }

    let mut weight_sum = 0.0;
    let mut internal: HashMap<usize, f64> = HashMap::new();
    let mut degree: HashMap<usize, f64> = HashMap::new();

    for (&from, edges) in graph {
        let from_comm = community.get(&from).copied();
        for (&to, &w) in edges {
            let to_comm = community.get(&to).copied();
            weight_sum += w;
            // Each undirected edge adds its weight to the degree of both endpoints.
            if let Some(c) = from_comm {
                *degree.entry(c).or_default() += w;
            }
            if let Some(c) = to_comm {
                *degree.entry(c).or_default() += w;
            }
            if let (Some(a), Some(b)) = (from_comm, to_comm) {
                if a == b {
                    *internal.entry(a).or_default() += w;
                }
            }
        }
    }

    if weight_sum == 0.0 {
        return 0.0;
    }

    let two_w = 2.0 * weight_sum;
    degree
        .iter()
        .map(|(c, &d)| {
            let l_c = internal.get(c).copied().unwrap_or(0.0);
            l_c / weight_sum - (d / two_w).powi(2)
        })
        .sum()
}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::entity::{insert_entity, Entity};
    use crate::graph::relation::{insert_relation, Relation};

    /// In-memory database with the schema installed and no data.
    fn blank_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        conn
    }

    /// Two tightly linked chains, 1-2-3 and 4-5-6, joined by a weak 3-4 relation.
    fn setup_test_db() -> Connection {
        let conn = blank_db();
        let node = |label: &str| insert_entity(&conn, &Entity::new("node", label)).unwrap();
        let link = |from, to, weight| {
            insert_relation(&conn, &Relation::new(from, to, "link", weight).unwrap()).unwrap();
        };

        let n1 = node("Node 1");
        let n2 = node("Node 2");
        let n3 = node("Node 3");
        let n4 = node("Node 4");
        let n5 = node("Node 5");
        let n6 = node("Node 6");

        link(n1, n2, 1.0);
        link(n2, n3, 1.0);
        link(n4, n5, 1.0);
        link(n5, n6, 1.0);
        link(n3, n4, 0.1);

        conn
    }

    #[test]
    fn test_louvain() {
        let result = louvain_communities(&setup_test_db()).unwrap();

        assert!(result.num_communities >= 1);
        assert_eq!(result.memberships.len(), 6);
        assert!(result.num_communities <= 2);
    }

    #[test]
    fn test_empty_graph() {
        let result = louvain_communities(&blank_db()).unwrap();
        assert_eq!(result.num_communities, 0);
    }

    #[test]
    fn test_single_community() {
        let conn = blank_db();
        let node = |label: &str| insert_entity(&conn, &Entity::new("node", label)).unwrap();
        let link = |from, to, weight| {
            insert_relation(&conn, &Relation::new(from, to, "link", weight).unwrap()).unwrap();
        };

        let a = node("A");
        let b = node("B");
        let c = node("C");
        link(a, b, 1.0);
        link(b, c, 1.0);
        link(a, c, 1.0);

        let result = louvain_communities(&conn).unwrap();
        assert_eq!(result.memberships.len(), 3);
        assert!(result.num_communities >= 1);
    }
}
356}