// sqlite_knowledge_graph/algorithms/louvain.rs
use crate::error::Result;
use rusqlite::Connection;
use std::collections::HashMap;
5
/// Outcome of a community-detection run over the relation graph.
///
/// `Default` yields the empty result: no memberships, zero communities and a
/// modularity of `0.0`. `PartialEq` compares `modularity` with plain `f64`
/// equality, so a result containing NaN never equals itself.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct CommunityResult {
    /// `(node id, community id)` pairs, one per node of the graph.
    pub memberships: Vec<(i64, i32)>,
    /// Number of distinct community ids handed out; ids run from `0` up to
    /// (but not including) this value.
    pub num_communities: i32,
    /// Quality score reported for the partition.
    pub modularity: f64,
}
16
17pub fn louvain_communities(conn: &Connection) -> Result<CommunityResult> {
21 let mut graph: HashMap<i64, HashMap<i64, f64>> = HashMap::new();
23 let mut total_weight = 0.0;
24
25 let mut stmt = conn.prepare("SELECT from_id, to_id, weight FROM relations")?;
26
27 let rows = stmt.query_map([], |row| {
28 Ok((
29 row.get::<_, i64>(0)?,
30 row.get::<_, i64>(1)?,
31 row.get::<_, f64>(2)?,
32 ))
33 })?;
34
35 for row in rows {
36 let (from, to, weight) = row?;
37 *graph.entry(from).or_default().entry(to).or_default() += weight;
38 graph.entry(to).or_default(); total_weight += weight;
40 }
41
42 if graph.is_empty() {
43 return Ok(CommunityResult {
44 memberships: Vec::new(),
45 num_communities: 0,
46 modularity: 0.0,
47 });
48 }
49
50 let nodes: Vec<i64> = graph.keys().copied().collect();
51 let _n = nodes.len();
52
53 let mut community: HashMap<i64, i32> = nodes
55 .iter()
56 .enumerate()
57 .map(|(i, &id)| (id, i as i32))
58 .collect();
59 let mut improved = true;
60 let mut iteration = 0;
61
62 while improved && iteration < 100 {
63 improved = false;
64 iteration += 1;
65
66 for &node in &nodes {
67 let current_community = community[&node];
68
69 let neighbors: Vec<i64> = graph
71 .get(&node)
72 .map(|edges| edges.keys().copied().collect())
73 .unwrap_or_default();
74
75 let mut best_community = current_community;
76 let mut best_gain = 0.0;
77
78 for &neighbor in &neighbors {
79 let neighbor_community = community[&neighbor];
80 if neighbor_community == current_community {
81 continue;
82 }
83
84 let gain = calculate_modularity_gain(
86 &graph,
87 node,
88 neighbor_community,
89 &community,
90 total_weight,
91 );
92
93 if gain > best_gain {
94 best_gain = gain;
95 best_community = neighbor_community;
96 }
97 }
98
99 if best_community != current_community {
100 community.insert(node, best_community);
101 improved = true;
102 }
103 }
104 }
105
106 let mut community_map: HashMap<i32, i32> = HashMap::new();
108 let mut next_id = 0i32;
109
110 for &comm in community.values() {
111 if let std::collections::hash_map::Entry::Vacant(e) = community_map.entry(comm) {
112 e.insert(next_id);
113 next_id += 1;
114 }
115 }
116
117 let memberships: Vec<(i64, i32)> = nodes
118 .iter()
119 .map(|&id| (id, community_map[&community[&id]]))
120 .collect();
121
122 let modularity = calculate_modularity(&graph, &community, total_weight);
124
125 Ok(CommunityResult {
126 memberships,
127 num_communities: next_id,
128 modularity,
129 })
130}
131
/// Scores a move of `node` into `target_community`.
///
/// The score is the sum, over `node`'s outgoing edges whose endpoint currently
/// belongs to `target_community`, of `weight / total_weight`.
///
/// NOTE(review): no degree / null-model term is subtracted, so this is a
/// simplified stand-in for the textbook Louvain modularity gain.
///
/// Returns `0.0` when `node` has no adjacency entry, when none of its
/// neighbours is in `target_community`, or when `total_weight` is zero. The
/// zero guard mirrors the one in `calculate_modularity` and avoids a division
/// by zero.
fn calculate_modularity_gain(
    graph: &HashMap<i64, HashMap<i64, f64>>,
    node: i64,
    target_community: i32,
    community: &HashMap<i64, i32>,
    total_weight: f64,
) -> f64 {
    if total_weight == 0.0 {
        return 0.0;
    }

    graph.get(&node).map_or(0.0, |neighbors| {
        neighbors
            .iter()
            .filter(|&(neighbor, _)| community.get(neighbor) == Some(&target_community))
            // Divide per edge (not once at the end) to keep the same rounding
            // as a running `gain += weight / total_weight`.
            .fold(0.0, |gain, (_, &weight)| gain + weight / total_weight)
    })
}
151
/// Computes the quality score reported for a partition.
///
/// The score is the sum of `weight / total_weight` over every stored edge
/// whose two endpoints have equal `community` lookups (two missing entries
/// compare equal as well). A `total_weight` of exactly zero yields `0.0`.
///
/// NOTE(review): no expected-weight (null-model) term is subtracted, so the
/// value is the fraction of edge weight that lies inside communities rather
/// than textbook Newman modularity.
fn calculate_modularity(
    graph: &HashMap<i64, HashMap<i64, f64>>,
    community: &HashMap<i64, i32>,
    total_weight: f64,
) -> f64 {
    if total_weight == 0.0 {
        return 0.0;
    }

    graph
        .iter()
        .flat_map(|(src, out_edges)| {
            let src_label = community.get(src);
            out_edges
                .iter()
                .filter(move |(dst, _)| community.get(*dst) == src_label)
                .map(|(_, weight)| *weight)
        })
        .fold(0.0, |score, weight| score + weight / total_weight)
}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Schema shared by every test database.
    const SCHEMA: &str = "CREATE TABLE entities (id INTEGER PRIMARY KEY);
        CREATE TABLE relations (id INTEGER PRIMARY KEY, from_id INTEGER, to_id INTEGER, relation_type TEXT, weight REAL);";

    /// Opens an in-memory database that has the schema but no rows.
    fn empty_test_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(SCHEMA).unwrap();
        conn
    }

    /// Builds two chains, 1 -> 2 -> 3 and 4 -> 5 -> 6, with weight 1.0 per
    /// edge, joined by a single 3 -> 4 edge of weight 0.1.
    fn setup_test_db() -> Connection {
        let conn = empty_test_db();
        conn.execute_batch(
            "INSERT INTO entities (id) VALUES (1), (2), (3), (4), (5), (6);
             INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES
                 (1, 2, 'link', 1.0),
                 (2, 3, 'link', 1.0),
                 (4, 5, 'link', 1.0),
                 (5, 6, 'link', 1.0),
                 (3, 4, 'link', 0.1);",
        )
        .unwrap();
        conn
    }

    #[test]
    fn test_louvain() {
        let conn = setup_test_db();
        let result = louvain_communities(&conn).unwrap();

        assert!(result.num_communities >= 1);
        assert_eq!(result.memberships.len(), 6);

        // Every node appears exactly once.
        let mut ids: Vec<i64> = result.memberships.iter().map(|&(id, _)| id).collect();
        ids.sort_unstable();
        assert_eq!(ids, vec![1, 2, 3, 4, 5, 6]);

        // Community ids are dense: each lies in 0..num_communities.
        for &(id, comm) in &result.memberships {
            assert!(
                (0..result.num_communities).contains(&comm),
                "node {} has out-of-range community {}",
                id,
                comm
            );
        }
    }

    #[test]
    fn test_empty_graph() {
        let conn = empty_test_db();

        let result = louvain_communities(&conn).unwrap();
        assert_eq!(result.num_communities, 0);
        assert!(result.memberships.is_empty());
        assert_eq!(result.modularity, 0.0);
    }
}