Skip to main content

sqlite_knowledge_graph/algorithms/
louvain.rs

1use crate::error::Result;
2/// Louvain community detection algorithm
3use rusqlite::Connection;
4use std::collections::HashMap;
5
/// Outcome of a community-detection run over the relation graph.
#[derive(Debug, Clone)]
pub struct CommunityResult {
    /// `(entity id, community id)` pairs, one per node that appears in at
    /// least one relation. Community ids are renumbered to be consecutive,
    /// starting at 0.
    pub memberships: Vec<(i64, i32)>,
    /// Number of distinct communities referenced by `memberships`
    /// (0 when the graph has no relations).
    pub num_communities: i32,
    /// Modularity score computed for the final partition
    /// (0.0 when the graph has no relations).
    pub modularity: f64,
}
16
17/// Compute communities using a simplified single-phase Louvain algorithm.
18///
19/// This implements Phase 1 of the Louvain method (local node moves) only.
20/// Phase 2 (community aggregation into super-nodes) is not implemented.
21/// For graphs with deep hierarchical community structure, results may be
22/// sub-optimal compared to the full two-phase algorithm.
23///
24/// Returns community memberships and modularity score.
25pub fn louvain_communities(conn: &Connection) -> Result<CommunityResult> {
26    // Build adjacency list with weights
27    let mut graph: HashMap<i64, HashMap<i64, f64>> = HashMap::new();
28    let mut total_weight = 0.0;
29
30    let mut stmt = conn.prepare("SELECT source_id, target_id, weight FROM kg_relations")?;
31
32    let rows = stmt.query_map([], |row| {
33        Ok((
34            row.get::<_, i64>(0)?,
35            row.get::<_, i64>(1)?,
36            row.get::<_, f64>(2)?,
37        ))
38    })?;
39
40    for row in rows {
41        let (from, to, weight) = row?;
42        *graph.entry(from).or_default().entry(to).or_default() += weight;
43        graph.entry(to).or_default(); // Ensure target node exists
44        total_weight += weight;
45    }
46
47    if graph.is_empty() {
48        return Ok(CommunityResult {
49            memberships: Vec::new(),
50            num_communities: 0,
51            modularity: 0.0,
52        });
53    }
54
55    let nodes: Vec<i64> = graph.keys().copied().collect();
56    let _n = nodes.len();
57
58    // Initialize: each node in its own community
59    let mut community: HashMap<i64, i32> = nodes
60        .iter()
61        .enumerate()
62        .map(|(i, &id)| (id, i as i32))
63        .collect();
64    let mut improved = true;
65    let mut iteration = 0;
66
67    while improved && iteration < 100 {
68        improved = false;
69        iteration += 1;
70
71        for &node in &nodes {
72            let current_community = community[&node];
73
74            // Find neighboring communities
75            let neighbors: Vec<i64> = graph
76                .get(&node)
77                .map(|edges| edges.keys().copied().collect())
78                .unwrap_or_default();
79
80            let mut best_community = current_community;
81            let mut best_gain = 0.0;
82
83            for &neighbor in &neighbors {
84                let neighbor_community = community[&neighbor];
85                if neighbor_community == current_community {
86                    continue;
87                }
88
89                // Calculate modularity gain (simplified)
90                let gain = calculate_modularity_gain(
91                    &graph,
92                    node,
93                    neighbor_community,
94                    &community,
95                    total_weight,
96                );
97
98                if gain > best_gain {
99                    best_gain = gain;
100                    best_community = neighbor_community;
101                }
102            }
103
104            if best_community != current_community {
105                community.insert(node, best_community);
106                improved = true;
107            }
108        }
109    }
110
111    // Renumber communities consecutively
112    let mut community_map: HashMap<i32, i32> = HashMap::new();
113    let mut next_id = 0i32;
114
115    for &comm in community.values() {
116        if let std::collections::hash_map::Entry::Vacant(e) = community_map.entry(comm) {
117            e.insert(next_id);
118            next_id += 1;
119        }
120    }
121
122    let memberships: Vec<(i64, i32)> = nodes
123        .iter()
124        .map(|&id| (id, community_map[&community[&id]]))
125        .collect();
126
127    // Calculate final modularity
128    let modularity = calculate_modularity(&graph, &community, total_weight);
129
130    Ok(CommunityResult {
131        memberships,
132        num_communities: next_id,
133        modularity,
134    })
135}
136
/// Modularity gain of inserting `node` into `target_community`.
///
/// `graph` stores each relation once as `source -> target -> weight`, and
/// `total_weight` is the sum of all stored weights (`m`). The graph is
/// treated as undirected, so a node's degree counts both its outgoing and
/// its incoming weights (a self-loop counts twice).
///
/// Uses the standard Louvain expression
/// `ΔQ = k_i_in / m - k_tot * k_i / (2 * m²)` where
/// * `k_i` is the weighted degree of `node`,
/// * `k_i_in` is the weight of edges joining `node` to other members of
///   `target_community`,
/// * `k_tot` is the summed weighted degree of the members of
///   `target_community`, excluding `node` itself.
///
/// Because `node` is excluded from the community's totals, the function can
/// also be called with the node's current community to score "staying put".
///
/// Returns 0.0 when `total_weight` is 0.
fn calculate_modularity_gain(
    graph: &HashMap<i64, HashMap<i64, f64>>,
    node: i64,
    target_community: i32,
    community: &HashMap<i64, i32>,
    total_weight: f64,
) -> f64 {
    if total_weight == 0.0 {
        return 0.0;
    }

    let m = total_weight;

    let mut k_i = 0.0;
    let mut k_i_in = 0.0;
    let mut k_tot = 0.0;

    // One pass over every stored edge; each edge contributes to the degree
    // of both of its endpoints.
    for (&from, edges) in graph {
        let from_in_target = from != node && community.get(&from) == Some(&target_community);

        for (&to, &weight) in edges {
            let to_in_target = to != node && community.get(&to) == Some(&target_community);

            if from == node {
                k_i += weight;
                if to_in_target {
                    k_i_in += weight;
                }
            }
            if to == node {
                k_i += weight;
                if from_in_target {
                    k_i_in += weight;
                }
            }
            if from_in_target {
                k_tot += weight;
            }
            if to_in_target {
                k_tot += weight;
            }
        }
    }

    // Standard Louvain ΔQ = k_i_in / m  -  k_tot * k_i / (2 * m²)
    k_i_in / m - k_tot * k_i / (2.0 * m * m)
}
178
/// Newman modularity of the partition described by `community`.
///
/// `graph` stores each relation once as `source -> target -> weight`, and
/// `total_weight` is the sum of all stored weights (`m`). The graph is
/// treated as undirected, so every edge adds its weight to the degree of
/// both endpoints.
///
/// For each community `c` the contribution is
/// `in_c / m - (tot_c / (2 * m))²`, where `in_c` is the weight of edges with
/// both endpoints in `c` and `tot_c` is the summed degree of its members.
/// The second term is the expected intra-community weight under the
/// configuration null model; without it the score would merely be the
/// fraction of intra-community weight and a single all-encompassing
/// community would always score 1.
///
/// Nodes missing from `community` are grouped into one implicit community.
///
/// Returns 0.0 when `total_weight` is 0.
fn calculate_modularity(
    graph: &HashMap<i64, HashMap<i64, f64>>,
    community: &HashMap<i64, i32>,
    total_weight: f64,
) -> f64 {
    if total_weight == 0.0 {
        return 0.0;
    }

    let two_m = 2.0 * total_weight;

    // community -> (internal edge weight, summed degree of members)
    let mut per_community: HashMap<Option<i32>, (f64, f64)> = HashMap::new();

    for (&from, edges) in graph {
        let from_comm = community.get(&from).copied();

        for (&to, &weight) in edges {
            let to_comm = community.get(&to).copied();

            if from_comm == to_comm {
                // Both endpoints add to the same community's degree total.
                let stats = per_community.entry(from_comm).or_default();
                stats.0 += weight;
                stats.1 += 2.0 * weight;
            } else {
                per_community.entry(from_comm).or_default().1 += weight;
                per_community.entry(to_comm).or_default().1 += weight;
            }
        }
    }

    per_community
        .values()
        .map(|&(internal, degree)| internal / total_weight - (degree / two_m).powi(2))
        .sum::<f64>()
}
200
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Builds an in-memory database holding two three-node chains
    /// (1-2-3 and 4-5-6) joined by a single weak relation (3-4, weight 0.1).
    fn setup_test_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();

        // Create two communities: 1-2-3 and 4-5-6, with weak link between
        use crate::graph::entity::{insert_entity, Entity};
        use crate::graph::relation::{insert_relation, Relation};
        let id1 = insert_entity(&conn, &Entity::new("node", "Node 1")).unwrap();
        let id2 = insert_entity(&conn, &Entity::new("node", "Node 2")).unwrap();
        let id3 = insert_entity(&conn, &Entity::new("node", "Node 3")).unwrap();
        let id4 = insert_entity(&conn, &Entity::new("node", "Node 4")).unwrap();
        let id5 = insert_entity(&conn, &Entity::new("node", "Node 5")).unwrap();
        let id6 = insert_entity(&conn, &Entity::new("node", "Node 6")).unwrap();
        insert_relation(&conn, &Relation::new(id1, id2, "link", 1.0).unwrap()).unwrap();
        insert_relation(&conn, &Relation::new(id2, id3, "link", 1.0).unwrap()).unwrap();
        insert_relation(&conn, &Relation::new(id4, id5, "link", 1.0).unwrap()).unwrap();
        insert_relation(&conn, &Relation::new(id5, id6, "link", 1.0).unwrap()).unwrap();
        insert_relation(&conn, &Relation::new(id3, id4, "link", 0.1).unwrap()).unwrap();

        conn
    }

    /// The result must be structurally consistent: every node listed once,
    /// community ids dense in `0..num_communities`, and a finite score.
    #[test]
    fn test_louvain() {
        let conn = setup_test_db();
        let result = louvain_communities(&conn).unwrap();

        assert_eq!(result.memberships.len(), 6);
        assert!(result.num_communities >= 1);
        assert!(result.num_communities <= 6);

        let entities: HashSet<i64> = result.memberships.iter().map(|&(id, _)| id).collect();
        assert_eq!(entities.len(), 6, "each entity must appear exactly once");

        let labels: HashSet<i32> = result.memberships.iter().map(|&(_, c)| c).collect();
        assert_eq!(
            labels.len(),
            result.num_communities as usize,
            "num_communities must match the distinct labels in memberships"
        );
        assert!(
            labels
                .iter()
                .all(|c| (0..result.num_communities).contains(c)),
            "community ids must be consecutive starting at 0"
        );

        assert!(result.modularity.is_finite());
    }

    /// With no relations there is nothing to cluster.
    #[test]
    fn test_empty_graph() {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();

        let result = louvain_communities(&conn).unwrap();
        assert_eq!(result.num_communities, 0);
        assert!(result.memberships.is_empty());
        assert_eq!(result.modularity, 0.0);
    }
}