Skip to main content

sqlite_knowledge_graph/algorithms/
connected.rs

1use crate::error::Result;
2/// Connected components algorithms
3use std::cmp::Reverse;
4
5use rusqlite::Connection;
6use std::collections::{HashMap, HashSet, VecDeque};
7
8/// Find weakly connected components
9///
10/// Returns a list of components, each being a list of entity IDs.
11pub fn connected_components(conn: &Connection) -> Result<Vec<Vec<i64>>> {
12    // Build undirected adjacency list
13    let mut graph: HashMap<i64, Vec<i64>> = HashMap::new();
14    let mut all_nodes: HashSet<i64> = HashSet::new();
15
16    let mut stmt = conn.prepare("SELECT from_id, to_id FROM relations")?;
17
18    let rows = stmt.query_map([], |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)))?;
19
20    for row in rows {
21        let (from, to) = row?;
22        all_nodes.insert(from);
23        all_nodes.insert(to);
24        graph.entry(from).or_default().push(to);
25        graph.entry(to).or_default().push(from);
26    }
27
28    // Add isolated nodes
29    let mut stmt = conn.prepare("SELECT id FROM entities")?;
30    let entity_rows = stmt.query_map([], |row| row.get::<_, i64>(0))?;
31    for row in entity_rows {
32        let id = row?;
33        all_nodes.insert(id);
34        graph.entry(id).or_default();
35    }
36
37    let mut visited = HashSet::new();
38    let mut components = Vec::new();
39
40    for &start in &all_nodes {
41        if visited.contains(&start) {
42            continue;
43        }
44
45        let mut component = Vec::new();
46        let mut queue = VecDeque::new();
47        queue.push_back(start);
48        visited.insert(start);
49
50        while let Some(node) = queue.pop_front() {
51            component.push(node);
52
53            if let Some(neighbors) = graph.get(&node) {
54                for &neighbor in neighbors {
55                    if !visited.contains(&neighbor) {
56                        visited.insert(neighbor);
57                        queue.push_back(neighbor);
58                    }
59                }
60            }
61        }
62
63        components.push(component);
64    }
65
66    // Sort by size descending
67    components.sort_by_key(|b| Reverse(b.len()));
68
69    Ok(components)
70}
71
72/// Find strongly connected components using Kosaraju's algorithm
73///
74/// Returns a list of strongly connected components.
75pub fn strongly_connected_components(conn: &Connection) -> Result<Vec<Vec<i64>>> {
76    // Build directed adjacency list
77    let mut graph: HashMap<i64, Vec<i64>> = HashMap::new();
78    let mut reverse_graph: HashMap<i64, Vec<i64>> = HashMap::new();
79    let mut all_nodes: HashSet<i64> = HashSet::new();
80
81    let mut stmt = conn.prepare("SELECT from_id, to_id FROM relations")?;
82    let rows = stmt.query_map([], |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)))?;
83
84    for row in rows {
85        let (from, to) = row?;
86        all_nodes.insert(from);
87        all_nodes.insert(to);
88        graph.entry(from).or_default().push(to);
89        reverse_graph.entry(to).or_default().push(from);
90        graph.entry(to).or_default();
91        reverse_graph.entry(from).or_default();
92    }
93
94    // First pass: order by finish time
95    let mut visited = HashSet::new();
96    let mut finish_order = Vec::new();
97
98    fn dfs1(
99        node: i64,
100        graph: &HashMap<i64, Vec<i64>>,
101        visited: &mut HashSet<i64>,
102        finish_order: &mut Vec<i64>,
103    ) {
104        visited.insert(node);
105        if let Some(neighbors) = graph.get(&node) {
106            for &neighbor in neighbors {
107                if !visited.contains(&neighbor) {
108                    dfs1(neighbor, graph, visited, finish_order);
109                }
110            }
111        }
112        finish_order.push(node);
113    }
114
115    for &node in &all_nodes {
116        if !visited.contains(&node) {
117            dfs1(node, &graph, &mut visited, &mut finish_order);
118        }
119    }
120
121    // Second pass: collect SCCs
122    let mut visited = HashSet::new();
123    let mut components = Vec::new();
124
125    fn dfs2(
126        node: i64,
127        reverse_graph: &HashMap<i64, Vec<i64>>,
128        visited: &mut HashSet<i64>,
129        component: &mut Vec<i64>,
130    ) {
131        visited.insert(node);
132        component.push(node);
133        if let Some(neighbors) = reverse_graph.get(&node) {
134            for &neighbor in neighbors {
135                if !visited.contains(&neighbor) {
136                    dfs2(neighbor, reverse_graph, visited, component);
137                }
138            }
139        }
140    }
141
142    for &node in finish_order.iter().rev() {
143        if !visited.contains(&node) {
144            let mut component = Vec::new();
145            dfs2(node, &reverse_graph, &mut visited, &mut component);
146            components.push(component);
147        }
148    }
149
150    // Sort by size descending
151    components.sort_by_key(|b| Reverse(b.len()));
152
153    Ok(components)
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    // Schema used by every test: an `entities` table and a `relations` table
    // holding directed, typed, weighted edges between entity IDs.
    const SCHEMA: &str = "CREATE TABLE entities (id INTEGER PRIMARY KEY);\n             CREATE TABLE relations (id INTEGER PRIMARY KEY, from_id INTEGER, to_id INTEGER, relation_type TEXT, weight REAL);";

    /// Opens an in-memory database with empty `entities` and `relations` tables.
    fn empty_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(SCHEMA).unwrap();
        conn
    }

    /// Runs each statement in order, panicking on the first failure.
    fn run_all(conn: &Connection, statements: &[&str]) {
        for sql in statements {
            conn.execute(sql, []).unwrap();
        }
    }

    /// Builds a database with two disconnected components: 1-2-3 and 4-5.
    fn setup_test_db() -> Connection {
        let conn = empty_db();
        run_all(
            &conn,
            &[
                "INSERT INTO entities (id) VALUES (1), (2), (3), (4), (5)",
                "INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (1, 2, 'link', 1.0)",
                "INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (2, 3, 'link', 1.0)",
                "INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (4, 5, 'link', 1.0)",
            ],
        );
        conn
    }

    #[test]
    fn test_connected_components() {
        let conn = setup_test_db();
        let components = connected_components(&conn).unwrap();

        // Components come back largest first.
        assert_eq!(components.len(), 2);
        assert_eq!(components[0].len(), 3);
        assert_eq!(components[1].len(), 2);
    }

    #[test]
    fn test_strongly_connected_components() {
        // A single directed cycle: 1 -> 2 -> 3 -> 1
        let conn = empty_db();
        run_all(
            &conn,
            &[
                "INSERT INTO entities (id) VALUES (1), (2), (3)",
                "INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (1, 2, 'link', 1.0)",
                "INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (2, 3, 'link', 1.0)",
                "INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (3, 1, 'link', 1.0)",
            ],
        );

        let components = strongly_connected_components(&conn).unwrap();

        // The whole cycle collapses into one SCC containing all three nodes.
        assert_eq!(components.len(), 1);
        assert_eq!(components[0].len(), 3);
    }
}