Skip to main content

sqlite_knowledge_graph/algorithms/
connected.rs

1use crate::error::Result;
2/// Connected components algorithms
3use std::cmp::Reverse;
4
5use rusqlite::Connection;
6use std::collections::{HashMap, HashSet, VecDeque};
7
8/// Find weakly connected components
9///
10/// Returns a list of components, each being a list of entity IDs.
11pub fn connected_components(conn: &Connection) -> Result<Vec<Vec<i64>>> {
12    // Build undirected adjacency list
13    let mut graph: HashMap<i64, Vec<i64>> = HashMap::new();
14    let mut all_nodes: HashSet<i64> = HashSet::new();
15
16    let mut stmt = conn.prepare("SELECT source_id, target_id FROM kg_relations")?;
17
18    let rows = stmt.query_map([], |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)))?;
19
20    for row in rows {
21        let (from, to) = row?;
22        all_nodes.insert(from);
23        all_nodes.insert(to);
24        graph.entry(from).or_default().push(to);
25        graph.entry(to).or_default().push(from);
26    }
27
28    // Add isolated nodes
29    let mut stmt = conn.prepare("SELECT id FROM kg_entities")?;
30    let entity_rows = stmt.query_map([], |row| row.get::<_, i64>(0))?;
31    for row in entity_rows {
32        let id = row?;
33        all_nodes.insert(id);
34        graph.entry(id).or_default();
35    }
36
37    let mut visited = HashSet::new();
38    let mut components = Vec::new();
39
40    for &start in &all_nodes {
41        if visited.contains(&start) {
42            continue;
43        }
44
45        let mut component = Vec::new();
46        let mut queue = VecDeque::new();
47        queue.push_back(start);
48        visited.insert(start);
49
50        while let Some(node) = queue.pop_front() {
51            component.push(node);
52
53            if let Some(neighbors) = graph.get(&node) {
54                for &neighbor in neighbors {
55                    if !visited.contains(&neighbor) {
56                        visited.insert(neighbor);
57                        queue.push_back(neighbor);
58                    }
59                }
60            }
61        }
62
63        components.push(component);
64    }
65
66    // Sort by size descending
67    components.sort_by_key(|b| Reverse(b.len()));
68
69    Ok(components)
70}
71
72/// Find strongly connected components using Kosaraju's algorithm
73///
74/// Returns a list of strongly connected components.
75pub fn strongly_connected_components(conn: &Connection) -> Result<Vec<Vec<i64>>> {
76    // Build directed adjacency list
77    let mut graph: HashMap<i64, Vec<i64>> = HashMap::new();
78    let mut reverse_graph: HashMap<i64, Vec<i64>> = HashMap::new();
79    let mut all_nodes: HashSet<i64> = HashSet::new();
80
81    let mut stmt = conn.prepare("SELECT source_id, target_id FROM kg_relations")?;
82    let rows = stmt.query_map([], |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)))?;
83
84    for row in rows {
85        let (from, to) = row?;
86        all_nodes.insert(from);
87        all_nodes.insert(to);
88        graph.entry(from).or_default().push(to);
89        reverse_graph.entry(to).or_default().push(from);
90        graph.entry(to).or_default();
91        reverse_graph.entry(from).or_default();
92    }
93
94    // First pass: compute finish order iteratively (avoids stack overflow on large chains)
95    let mut visited = HashSet::new();
96    let mut finish_order = Vec::new();
97
98    for &start in &all_nodes {
99        if visited.contains(&start) {
100            continue;
101        }
102        // Iterative DFS with explicit stack; each entry is (node, iterator_index)
103        let mut stack: Vec<(i64, usize)> = vec![(start, 0)];
104        visited.insert(start);
105        while let Some((node, idx)) = stack.last_mut() {
106            let node = *node;
107            let neighbors = graph.get(&node).map(|v| v.as_slice()).unwrap_or(&[]);
108            if *idx < neighbors.len() {
109                let neighbor = neighbors[*idx];
110                *idx += 1;
111                if !visited.contains(&neighbor) {
112                    visited.insert(neighbor);
113                    stack.push((neighbor, 0));
114                }
115            } else {
116                finish_order.push(node);
117                stack.pop();
118            }
119        }
120    }
121
122    // Second pass: collect SCCs iteratively (reverse graph BFS/DFS)
123    let mut visited = HashSet::new();
124    let mut components = Vec::new();
125
126    for &start in finish_order.iter().rev() {
127        if visited.contains(&start) {
128            continue;
129        }
130        let mut component = Vec::new();
131        let mut stack = vec![start];
132        visited.insert(start);
133        while let Some(node) = stack.pop() {
134            component.push(node);
135            if let Some(neighbors) = reverse_graph.get(&node) {
136                for &neighbor in neighbors {
137                    if !visited.contains(&neighbor) {
138                        visited.insert(neighbor);
139                        stack.push(neighbor);
140                    }
141                }
142            }
143        }
144        components.push(component);
145    }
146
147    // Sort by size descending
148    components.sort_by_key(|b| Reverse(b.len()));
149
150    Ok(components)
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::entity::{insert_entity, Entity};
    use crate::graph::relation::{insert_relation, Relation};

    /// Build an in-memory database holding two disconnected chains: 1-2-3 and 4-5.
    fn setup_test_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();

        // Entities first, relations afterwards.
        let add_node = |name: &str| insert_entity(&conn, &Entity::new("node", name)).unwrap();
        let n1 = add_node("Node 1");
        let n2 = add_node("Node 2");
        let n3 = add_node("Node 3");
        let n4 = add_node("Node 4");
        let n5 = add_node("Node 5");

        let add_link = |from, to| {
            insert_relation(&conn, &Relation::new(from, to, "link", 1.0).unwrap()).unwrap();
        };
        add_link(n1, n2);
        add_link(n2, n3);
        add_link(n4, n5);

        conn
    }

    #[test]
    fn test_connected_components() {
        let conn = setup_test_db();
        let sizes: Vec<usize> = connected_components(&conn)
            .unwrap()
            .iter()
            .map(Vec::len)
            .collect();

        // Two components, largest first.
        assert_eq!(sizes, vec![3, 2]);
    }

    #[test]
    fn test_strongly_connected_components() {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();

        // Directed cycle: 1 -> 2 -> 3 -> 1
        let add_node = |name: &str| insert_entity(&conn, &Entity::new("node", name)).unwrap();
        let n1 = add_node("Node 1");
        let n2 = add_node("Node 2");
        let n3 = add_node("Node 3");

        let add_link = |from, to| {
            insert_relation(&conn, &Relation::new(from, to, "link", 1.0).unwrap()).unwrap();
        };
        add_link(n1, n2);
        add_link(n2, n3);
        add_link(n3, n1);

        let sizes: Vec<usize> = strongly_connected_components(&conn)
            .unwrap()
            .iter()
            .map(Vec::len)
            .collect();

        // All three nodes belong to one strongly connected component.
        assert_eq!(sizes, vec![3]);
    }
}