sqlite_knowledge_graph/algorithms/
louvain.rs1use crate::error::Result;
2use rusqlite::Connection;
4use std::collections::HashMap;
5
/// Outcome of a community-detection run.
#[derive(Debug, Clone)]
pub struct CommunityResult {
    /// One `(node id, community id)` pair per node that took part in the run.
    pub memberships: Vec<(i64, i32)>,
    /// Number of distinct communities the nodes were grouped into.
    pub num_communities: i32,
    /// Quality score reported for the partition.
    pub modularity: f64,
}
16
17pub fn louvain_communities(conn: &Connection) -> Result<CommunityResult> {
26 let mut graph: HashMap<i64, HashMap<i64, f64>> = HashMap::new();
28 let mut total_weight = 0.0;
29
30 let mut stmt = conn.prepare("SELECT source_id, target_id, weight FROM kg_relations")?;
31
32 let rows = stmt.query_map([], |row| {
33 Ok((
34 row.get::<_, i64>(0)?,
35 row.get::<_, i64>(1)?,
36 row.get::<_, f64>(2)?,
37 ))
38 })?;
39
40 for row in rows {
41 let (from, to, weight) = row?;
42 *graph.entry(from).or_default().entry(to).or_default() += weight;
43 graph.entry(to).or_default(); total_weight += weight;
45 }
46
47 if graph.is_empty() {
48 return Ok(CommunityResult {
49 memberships: Vec::new(),
50 num_communities: 0,
51 modularity: 0.0,
52 });
53 }
54
55 let nodes: Vec<i64> = graph.keys().copied().collect();
56 let _n = nodes.len();
57
58 let mut community: HashMap<i64, i32> = nodes
60 .iter()
61 .enumerate()
62 .map(|(i, &id)| (id, i as i32))
63 .collect();
64 let mut improved = true;
65 let mut iteration = 0;
66
67 while improved && iteration < 100 {
68 improved = false;
69 iteration += 1;
70
71 for &node in &nodes {
72 let current_community = community[&node];
73
74 let neighbors: Vec<i64> = graph
76 .get(&node)
77 .map(|edges| edges.keys().copied().collect())
78 .unwrap_or_default();
79
80 let mut best_community = current_community;
81 let mut best_gain = 0.0;
82
83 for &neighbor in &neighbors {
84 let neighbor_community = community[&neighbor];
85 if neighbor_community == current_community {
86 continue;
87 }
88
89 let gain = calculate_modularity_gain(
91 &graph,
92 node,
93 neighbor_community,
94 &community,
95 total_weight,
96 );
97
98 if gain > best_gain {
99 best_gain = gain;
100 best_community = neighbor_community;
101 }
102 }
103
104 if best_community != current_community {
105 community.insert(node, best_community);
106 improved = true;
107 }
108 }
109 }
110
111 let mut community_map: HashMap<i32, i32> = HashMap::new();
113 let mut next_id = 0i32;
114
115 for &comm in community.values() {
116 if let std::collections::hash_map::Entry::Vacant(e) = community_map.entry(comm) {
117 e.insert(next_id);
118 next_id += 1;
119 }
120 }
121
122 let memberships: Vec<(i64, i32)> = nodes
123 .iter()
124 .map(|&id| (id, community_map[&community[&id]]))
125 .collect();
126
127 let modularity = calculate_modularity(&graph, &community, total_weight);
129
130 Ok(CommunityResult {
131 memberships,
132 num_communities: next_id,
133 modularity,
134 })
135}
136
/// Scores the insertion of `node` into `target_community`.
///
/// The score is `k_in / m - k_tot * k_i / (2 * m^2)` where
/// * `m` is `total_weight`,
/// * `k_i` is the summed weight of the edges listed under `node`,
/// * `k_in` is the part of `k_i` whose other endpoint is assigned to
///   `target_community`,
/// * `k_tot` is the summed edge weight listed under every *other* node that is
///   assigned to `target_community`.
///
/// Returns `0.0` when `total_weight` is zero. A node that is absent from
/// `graph` contributes no edge weight.
fn calculate_modularity_gain(
    graph: &HashMap<i64, HashMap<i64, f64>>,
    node: i64,
    target_community: i32,
    community: &HashMap<i64, i32>,
    total_weight: f64,
) -> f64 {
    if total_weight == 0.0 {
        return 0.0;
    }

    let in_target = |id: &i64| community.get(id) == Some(&target_community);
    let own_edges = graph.get(&node);

    let node_degree: f64 = own_edges.map_or(0.0, |edges| edges.values().sum());

    let weight_into_target: f64 = own_edges.map_or(0.0, |edges| {
        edges
            .iter()
            .filter(|&(nbr, _)| in_target(nbr))
            .map(|(_, &w)| w)
            .sum()
    });

    // The node itself is left out so the score is that of inserting it afresh.
    let target_degree: f64 = graph
        .iter()
        .filter(|&(id, _)| *id != node && in_target(id))
        .map(|(_, edges)| edges.values().sum::<f64>())
        .sum();

    let attraction = weight_into_target / total_weight;
    let expected = target_degree * node_degree / (2.0 * total_weight * total_weight);

    attraction - expected
}
178
/// Computes the modularity of a partition of a weighted graph.
///
/// Every `(from, to, weight)` entry stored in `graph` is treated as one
/// undirected edge occurrence: it adds `weight` to the degree of both of its
/// endpoints and, when both endpoints share a community, to that community's
/// internal weight. With `M = total_weight` the result is the sum over
/// communities of
///
/// ```text
/// internal / M - (degree / (2 * M))^2
/// ```
///
/// The second term is the weight expected inside the community if edges were
/// placed at random; without it a partition that lumps every node together
/// would score 1 instead of 0.
///
/// `total_weight` must be the sum of all weights stored in `graph`. Under that
/// condition the value is the same whether each edge is stored once or under
/// both of its endpoints.
///
/// Nodes missing from `community` are grouped together as one unassigned
/// community. Returns `0.0` when `total_weight` is zero.
fn calculate_modularity(
    graph: &HashMap<i64, HashMap<i64, f64>>,
    community: &HashMap<i64, i32>,
    total_weight: f64,
) -> f64 {
    if total_weight == 0.0 {
        return 0.0;
    }

    // Per community: (weight with both endpoints inside, summed endpoint degree).
    let mut groups: HashMap<Option<i32>, (f64, f64)> = HashMap::new();

    for (from, edges) in graph {
        let from_group = community.get(from).copied();

        for (to, &weight) in edges {
            let to_group = community.get(to).copied();

            groups.entry(from_group).or_default().1 += weight;
            groups.entry(to_group).or_default().1 += weight;

            if from_group == to_group {
                groups.entry(from_group).or_default().0 += weight;
            }
        }
    }

    groups
        .values()
        .map(|&(internal, degree)| {
            let expected_share = degree / (2.0 * total_weight);
            internal / total_weight - expected_share * expected_share
        })
        .sum()
}
200
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an in-memory database holding two three-node chains
    /// (1-2-3 and 4-5-6, weight 1.0) joined by a single 0.1-weight link 3-4.
    fn setup_test_db() -> Connection {
        use crate::graph::entity::{insert_entity, Entity};
        use crate::graph::relation::{insert_relation, Relation};

        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();

        let names = ["Node 1", "Node 2", "Node 3", "Node 4", "Node 5", "Node 6"];
        let ids: Vec<_> = names
            .iter()
            .map(|&name| insert_entity(&conn, &Entity::new("node", name)).unwrap())
            .collect();

        // (source index, target index, weight), inserted in this order.
        let links = [
            (0, 1, 1.0),
            (1, 2, 1.0),
            (3, 4, 1.0),
            (4, 5, 1.0),
            (2, 3, 0.1),
        ];
        for &(source, target, weight) in links.iter() {
            let relation = Relation::new(ids[source], ids[target], "link", weight).unwrap();
            insert_relation(&conn, &relation).unwrap();
        }

        conn
    }

    /// Every inserted node receives a membership and at least one community exists.
    #[test]
    fn test_louvain() {
        let conn = setup_test_db();
        let result = louvain_communities(&conn).unwrap();

        assert!(result.num_communities >= 1);
        assert_eq!(result.memberships.len(), 6);
    }

    /// A schema with no relations produces zero communities.
    #[test]
    fn test_empty_graph() {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();

        let result = louvain_communities(&conn).unwrap();
        assert_eq!(result.num_communities, 0);
    }
}