//! Component algorithms for the knowledge graph (`sqlite_knowledge_graph/algorithms/connected.rs`).

use std::cmp::Reverse;
use std::collections::{HashMap, HashSet, VecDeque};

use rusqlite::Connection;

use crate::error::Result;

8pub fn connected_components(conn: &Connection) -> Result<Vec<Vec<i64>>> {
12 let mut graph: HashMap<i64, Vec<i64>> = HashMap::new();
14 let mut all_nodes: HashSet<i64> = HashSet::new();
15
16 let mut stmt = conn.prepare("SELECT from_id, to_id FROM relations")?;
17
18 let rows = stmt.query_map([], |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)))?;
19
20 for row in rows {
21 let (from, to) = row?;
22 all_nodes.insert(from);
23 all_nodes.insert(to);
24 graph.entry(from).or_default().push(to);
25 graph.entry(to).or_default().push(from);
26 }
27
28 let mut stmt = conn.prepare("SELECT id FROM entities")?;
30 let entity_rows = stmt.query_map([], |row| row.get::<_, i64>(0))?;
31 for row in entity_rows {
32 let id = row?;
33 all_nodes.insert(id);
34 graph.entry(id).or_default();
35 }
36
37 let mut visited = HashSet::new();
38 let mut components = Vec::new();
39
40 for &start in &all_nodes {
41 if visited.contains(&start) {
42 continue;
43 }
44
45 let mut component = Vec::new();
46 let mut queue = VecDeque::new();
47 queue.push_back(start);
48 visited.insert(start);
49
50 while let Some(node) = queue.pop_front() {
51 component.push(node);
52
53 if let Some(neighbors) = graph.get(&node) {
54 for &neighbor in neighbors {
55 if !visited.contains(&neighbor) {
56 visited.insert(neighbor);
57 queue.push_back(neighbor);
58 }
59 }
60 }
61 }
62
63 components.push(component);
64 }
65
66 components.sort_by_key(|b| Reverse(b.len()));
68
69 Ok(components)
70}
71
72pub fn strongly_connected_components(conn: &Connection) -> Result<Vec<Vec<i64>>> {
76 let mut graph: HashMap<i64, Vec<i64>> = HashMap::new();
78 let mut reverse_graph: HashMap<i64, Vec<i64>> = HashMap::new();
79 let mut all_nodes: HashSet<i64> = HashSet::new();
80
81 let mut stmt = conn.prepare("SELECT from_id, to_id FROM relations")?;
82 let rows = stmt.query_map([], |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)))?;
83
84 for row in rows {
85 let (from, to) = row?;
86 all_nodes.insert(from);
87 all_nodes.insert(to);
88 graph.entry(from).or_default().push(to);
89 reverse_graph.entry(to).or_default().push(from);
90 graph.entry(to).or_default();
91 reverse_graph.entry(from).or_default();
92 }
93
94 let mut visited = HashSet::new();
96 let mut finish_order = Vec::new();
97
98 fn dfs1(
99 node: i64,
100 graph: &HashMap<i64, Vec<i64>>,
101 visited: &mut HashSet<i64>,
102 finish_order: &mut Vec<i64>,
103 ) {
104 visited.insert(node);
105 if let Some(neighbors) = graph.get(&node) {
106 for &neighbor in neighbors {
107 if !visited.contains(&neighbor) {
108 dfs1(neighbor, graph, visited, finish_order);
109 }
110 }
111 }
112 finish_order.push(node);
113 }
114
115 for &node in &all_nodes {
116 if !visited.contains(&node) {
117 dfs1(node, &graph, &mut visited, &mut finish_order);
118 }
119 }
120
121 let mut visited = HashSet::new();
123 let mut components = Vec::new();
124
125 fn dfs2(
126 node: i64,
127 reverse_graph: &HashMap<i64, Vec<i64>>,
128 visited: &mut HashSet<i64>,
129 component: &mut Vec<i64>,
130 ) {
131 visited.insert(node);
132 component.push(node);
133 if let Some(neighbors) = reverse_graph.get(&node) {
134 for &neighbor in neighbors {
135 if !visited.contains(&neighbor) {
136 dfs2(neighbor, reverse_graph, visited, component);
137 }
138 }
139 }
140 }
141
142 for &node in finish_order.iter().rev() {
143 if !visited.contains(&node) {
144 let mut component = Vec::new();
145 dfs2(node, &reverse_graph, &mut visited, &mut component);
146 components.push(component);
147 }
148 }
149
150 components.sort_by_key(|b| Reverse(b.len()));
152
153 Ok(components)
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens an in-memory database with the `entities` / `relations` schema the
    /// algorithms query.
    fn empty_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(
            "CREATE TABLE entities (id INTEGER PRIMARY KEY);
             CREATE TABLE relations (id INTEGER PRIMARY KEY, from_id INTEGER, to_id INTEGER, relation_type TEXT, weight REAL);",
        )
        .unwrap();
        conn
    }

    /// Builds a database containing `entity_ids` and one `'link'` relation per edge.
    fn db_with(entity_ids: &[i64], edges: &[(i64, i64)]) -> Connection {
        let conn = empty_db();
        for &id in entity_ids {
            conn.execute("INSERT INTO entities (id) VALUES (?1)", rusqlite::params![id])
                .unwrap();
        }
        for &(from, to) in edges {
            conn.execute(
                "INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (?1, ?2, 'link', 1.0)",
                rusqlite::params![from, to],
            )
            .unwrap();
        }
        conn
    }

    /// Sorts the ids inside each component so assertions do not depend on traversal
    /// order; the order of the components themselves is left untouched.
    fn with_sorted_members(mut components: Vec<Vec<i64>>) -> Vec<Vec<i64>> {
        for component in &mut components {
            component.sort_unstable();
        }
        components
    }

    /// Two undirected islands: {1, 2, 3} and {4, 5}.
    fn setup_test_db() -> Connection {
        db_with(&[1, 2, 3, 4, 5], &[(1, 2), (2, 3), (4, 5)])
    }

    #[test]
    fn test_connected_components() {
        let components = connected_components(&setup_test_db()).unwrap();

        assert_eq!(components.len(), 2);
        assert_eq!(components[0].len(), 3);
        assert_eq!(components[1].len(), 2);
        assert_eq!(with_sorted_members(components), vec![vec![1, 2, 3], vec![4, 5]]);
    }

    #[test]
    fn test_connected_components_includes_isolated_entities() {
        // Entity 3 has no relations but must still appear as its own component.
        let conn = db_with(&[1, 2, 3], &[(1, 2)]);
        let components = with_sorted_members(connected_components(&conn).unwrap());

        assert_eq!(components, vec![vec![1, 2], vec![3]]);
    }

    #[test]
    fn test_strongly_connected_components() {
        let conn = db_with(&[1, 2, 3], &[(1, 2), (2, 3), (3, 1)]);
        let components = strongly_connected_components(&conn).unwrap();

        assert_eq!(components.len(), 1);
        assert_eq!(components[0].len(), 3);
    }

    #[test]
    fn test_strongly_connected_components_splits_one_way_edges() {
        // 1 <-> 2 is a cycle; 2 -> 3 has no way back, so 3 stands alone.
        let conn = db_with(&[1, 2, 3], &[(1, 2), (2, 1), (2, 3)]);
        let components = with_sorted_members(strongly_connected_components(&conn).unwrap());

        assert_eq!(components, vec![vec![1, 2], vec![3]]);
    }

    #[test]
    fn test_strongly_connected_components_of_a_chain_are_singletons() {
        let conn = db_with(&[1, 2, 3], &[(1, 2), (2, 3)]);
        let mut components = with_sorted_members(strongly_connected_components(&conn).unwrap());
        components.sort_unstable();

        assert_eq!(components, vec![vec![1], vec![2], vec![3]]);
    }
}