// sqlite_knowledge_graph/algorithms/louvain.rs

//! Louvain community detection algorithm

use crate::error::Result;
use rusqlite::Connection;
use std::collections::{BTreeMap, HashMap};
5
/// Community detection result
#[derive(Debug, Clone, PartialEq)]
pub struct CommunityResult {
    /// `(entity id, community id)` pairs, one per entity that appears in a relation
    pub memberships: Vec<(i64, i32)>,
    /// Number of distinct communities; community ids lie in `0..num_communities`
    pub num_communities: i32,
    /// Modularity score reported for the partition in `memberships`
    pub modularity: f64,
}
16
17/// Compute communities using Louvain algorithm
18///
19/// Returns community memberships and modularity score.
20pub fn louvain_communities(conn: &Connection) -> Result<CommunityResult> {
21    // Build adjacency list with weights
22    let mut graph: HashMap<i64, HashMap<i64, f64>> = HashMap::new();
23    let mut total_weight = 0.0;
24
25    let mut stmt = conn.prepare("SELECT from_id, to_id, weight FROM relations")?;
26
27    let rows = stmt.query_map([], |row| {
28        Ok((
29            row.get::<_, i64>(0)?,
30            row.get::<_, i64>(1)?,
31            row.get::<_, f64>(2)?,
32        ))
33    })?;
34
35    for row in rows {
36        let (from, to, weight) = row?;
37        *graph.entry(from).or_default().entry(to).or_default() += weight;
38        graph.entry(to).or_default(); // Ensure target node exists
39        total_weight += weight;
40    }
41
42    if graph.is_empty() {
43        return Ok(CommunityResult {
44            memberships: Vec::new(),
45            num_communities: 0,
46            modularity: 0.0,
47        });
48    }
49
50    let nodes: Vec<i64> = graph.keys().copied().collect();
51    let _n = nodes.len();
52
53    // Initialize: each node in its own community
54    let mut community: HashMap<i64, i32> = nodes
55        .iter()
56        .enumerate()
57        .map(|(i, &id)| (id, i as i32))
58        .collect();
59    let mut improved = true;
60    let mut iteration = 0;
61
62    while improved && iteration < 100 {
63        improved = false;
64        iteration += 1;
65
66        for &node in &nodes {
67            let current_community = community[&node];
68
69            // Find neighboring communities
70            let neighbors: Vec<i64> = graph
71                .get(&node)
72                .map(|edges| edges.keys().copied().collect())
73                .unwrap_or_default();
74
75            let mut best_community = current_community;
76            let mut best_gain = 0.0;
77
78            for &neighbor in &neighbors {
79                let neighbor_community = community[&neighbor];
80                if neighbor_community == current_community {
81                    continue;
82                }
83
84                // Calculate modularity gain (simplified)
85                let gain = calculate_modularity_gain(
86                    &graph,
87                    node,
88                    neighbor_community,
89                    &community,
90                    total_weight,
91                );
92
93                if gain > best_gain {
94                    best_gain = gain;
95                    best_community = neighbor_community;
96                }
97            }
98
99            if best_community != current_community {
100                community.insert(node, best_community);
101                improved = true;
102            }
103        }
104    }
105
106    // Renumber communities consecutively
107    let mut community_map: HashMap<i32, i32> = HashMap::new();
108    let mut next_id = 0i32;
109
110    for &comm in community.values() {
111        if let std::collections::hash_map::Entry::Vacant(e) = community_map.entry(comm) {
112            e.insert(next_id);
113            next_id += 1;
114        }
115    }
116
117    let memberships: Vec<(i64, i32)> = nodes
118        .iter()
119        .map(|&id| (id, community_map[&community[&id]]))
120        .collect();
121
122    // Calculate final modularity
123    let modularity = calculate_modularity(&graph, &community, total_weight);
124
125    Ok(CommunityResult {
126        memberships,
127        num_communities: next_id,
128        modularity,
129    })
130}
131
132fn calculate_modularity_gain(
133    graph: &HashMap<i64, HashMap<i64, f64>>,
134    node: i64,
135    target_community: i32,
136    community: &HashMap<i64, i32>,
137    total_weight: f64,
138) -> f64 {
139    let mut gain = 0.0;
140
141    if let Some(neighbors) = graph.get(&node) {
142        for (&neighbor, &weight) in neighbors {
143            if community.get(&neighbor) == Some(&target_community) {
144                gain += weight / total_weight;
145            }
146        }
147    }
148
149    gain
150}
151
/// Modularity of a partition of an undirected weighted graph.
///
/// `graph` holds each edge once, as `graph[from][to] = weight`, and is read as
/// undirected. `total_weight` must be the sum of all those weights (`m`). The result
/// is `sum over communities of (L_c / m - (D_c / (2 * m))^2)`, where `L_c` is the
/// weight of edges with both endpoints in community `c` and `D_c` is the summed
/// degree of its members (a self-loop counts twice towards its node's degree).
///
/// Nodes missing from `community` are grouped into one implicit "unassigned"
/// community. Returns `0.0` when `total_weight` is not a positive finite number,
/// because modularity is undefined there.
fn calculate_modularity(
    graph: &HashMap<i64, HashMap<i64, f64>>,
    community: &HashMap<i64, i32>,
    total_weight: f64,
) -> f64 {
    if !total_weight.is_finite() || total_weight <= 0.0 {
        return 0.0;
    }

    // Per community: (internal edge weight, summed degree of members).
    let mut totals: HashMap<Option<i32>, (f64, f64)> = HashMap::new();

    for (from, edges) in graph {
        let from_community = community.get(from).copied();
        for (to, &weight) in edges {
            let to_community = community.get(to).copied();

            // Each edge adds its weight to the degree of both endpoints.
            totals.entry(from_community).or_default().1 += weight;
            totals.entry(to_community).or_default().1 += weight;

            if from_community == to_community {
                totals.entry(from_community).or_default().0 += weight;
            }
        }
    }

    let two_m = 2.0 * total_weight;
    totals
        .values()
        .map(|&(internal, degree)| internal / total_weight - (degree / two_m).powi(2))
        .sum()
}
173
#[cfg(test)]
mod tests {
    use super::*;

    const SCHEMA: &str = "CREATE TABLE entities (id INTEGER PRIMARY KEY);
         CREATE TABLE relations (id INTEGER PRIMARY KEY, from_id INTEGER, to_id INTEGER, relation_type TEXT, weight REAL);";

    /// In-memory database holding the given `(from, to, weight)` edges, with an
    /// `entities` row for every endpoint.
    fn db_with_edges(edges: &[(i64, i64, f64)]) -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(SCHEMA).unwrap();

        for &(from, to, weight) in edges {
            for id in [from, to] {
                conn.execute(
                    "INSERT OR IGNORE INTO entities (id) VALUES (?1)",
                    rusqlite::params![id],
                )
                .unwrap();
            }
            conn.execute(
                "INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (?1, ?2, 'link', ?3)",
                rusqlite::params![from, to, weight],
            )
            .unwrap();
        }

        conn
    }

    /// Two chains, 1-2-3 and 4-5-6, joined by a weak 3-4 link.
    fn setup_test_db() -> Connection {
        db_with_edges(&[
            (1, 2, 1.0),
            (2, 3, 1.0),
            (4, 5, 1.0),
            (5, 6, 1.0),
            (3, 4, 0.1),
        ])
    }

    /// Community label assigned to `id`; panics if the entity is missing.
    fn label_of(result: &CommunityResult, id: i64) -> i32 {
        result
            .memberships
            .iter()
            .find(|&&(entity, _)| entity == id)
            .map(|&(_, label)| label)
            .unwrap()
    }

    #[test]
    fn test_louvain() {
        let conn = setup_test_db();
        let result = louvain_communities(&conn).unwrap();

        assert!(result.num_communities >= 1);
        assert_eq!(result.memberships.len(), 6);
        assert!(result.modularity.is_finite());

        // Every entity is reported exactly once.
        let mut ids: Vec<i64> = result.memberships.iter().map(|&(id, _)| id).collect();
        ids.sort_unstable();
        assert_eq!(ids, vec![1, 2, 3, 4, 5, 6]);

        // Labels are consecutive ids starting at zero.
        for &(_, label) in &result.memberships {
            assert!((0..result.num_communities).contains(&label));
        }
    }

    #[test]
    fn test_disconnected_components() {
        let conn = db_with_edges(&[(1, 2, 1.0), (3, 4, 1.0)]);
        let result = louvain_communities(&conn).unwrap();

        assert_eq!(result.num_communities, 2);
        assert_eq!(label_of(&result, 1), label_of(&result, 2));
        assert_eq!(label_of(&result, 3), label_of(&result, 4));
        assert_ne!(label_of(&result, 1), label_of(&result, 3));
    }

    #[test]
    fn test_empty_graph() {
        let conn = db_with_edges(&[]);

        let result = louvain_communities(&conn).unwrap();
        assert_eq!(result.num_communities, 0);
        assert!(result.memberships.is_empty());
        assert_eq!(result.modularity, 0.0);
    }
}