// sqlite_knowledge_graph/algorithms/connected.rs
use std::cmp::Reverse;
use std::collections::{BTreeSet, HashMap, HashSet, VecDeque};

use rusqlite::Connection;

use crate::error::Result;

8pub fn connected_components(conn: &Connection) -> Result<Vec<Vec<i64>>> {
12 let mut graph: HashMap<i64, Vec<i64>> = HashMap::new();
14 let mut all_nodes: HashSet<i64> = HashSet::new();
15
16 let mut stmt = conn.prepare("SELECT source_id, target_id FROM kg_relations")?;
17
18 let rows = stmt.query_map([], |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)))?;
19
20 for row in rows {
21 let (from, to) = row?;
22 all_nodes.insert(from);
23 all_nodes.insert(to);
24 graph.entry(from).or_default().push(to);
25 graph.entry(to).or_default().push(from);
26 }
27
28 let mut stmt = conn.prepare("SELECT id FROM kg_entities")?;
30 let entity_rows = stmt.query_map([], |row| row.get::<_, i64>(0))?;
31 for row in entity_rows {
32 let id = row?;
33 all_nodes.insert(id);
34 graph.entry(id).or_default();
35 }
36
37 let mut visited = HashSet::new();
38 let mut components = Vec::new();
39
40 for &start in &all_nodes {
41 if visited.contains(&start) {
42 continue;
43 }
44
45 let mut component = Vec::new();
46 let mut queue = VecDeque::new();
47 queue.push_back(start);
48 visited.insert(start);
49
50 while let Some(node) = queue.pop_front() {
51 component.push(node);
52
53 if let Some(neighbors) = graph.get(&node) {
54 for &neighbor in neighbors {
55 if !visited.contains(&neighbor) {
56 visited.insert(neighbor);
57 queue.push_back(neighbor);
58 }
59 }
60 }
61 }
62
63 components.push(component);
64 }
65
66 components.sort_by_key(|b| Reverse(b.len()));
68
69 Ok(components)
70}
71
72pub fn strongly_connected_components(conn: &Connection) -> Result<Vec<Vec<i64>>> {
76 let mut graph: HashMap<i64, Vec<i64>> = HashMap::new();
78 let mut reverse_graph: HashMap<i64, Vec<i64>> = HashMap::new();
79 let mut all_nodes: HashSet<i64> = HashSet::new();
80
81 let mut stmt = conn.prepare("SELECT source_id, target_id FROM kg_relations")?;
82 let rows = stmt.query_map([], |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)))?;
83
84 for row in rows {
85 let (from, to) = row?;
86 all_nodes.insert(from);
87 all_nodes.insert(to);
88 graph.entry(from).or_default().push(to);
89 reverse_graph.entry(to).or_default().push(from);
90 graph.entry(to).or_default();
91 reverse_graph.entry(from).or_default();
92 }
93
94 let mut visited = HashSet::new();
96 let mut finish_order = Vec::new();
97
98 for &start in &all_nodes {
99 if visited.contains(&start) {
100 continue;
101 }
102 let mut stack: Vec<(i64, usize)> = vec![(start, 0)];
104 visited.insert(start);
105 while let Some((node, idx)) = stack.last_mut() {
106 let node = *node;
107 let neighbors = graph.get(&node).map(|v| v.as_slice()).unwrap_or(&[]);
108 if *idx < neighbors.len() {
109 let neighbor = neighbors[*idx];
110 *idx += 1;
111 if !visited.contains(&neighbor) {
112 visited.insert(neighbor);
113 stack.push((neighbor, 0));
114 }
115 } else {
116 finish_order.push(node);
117 stack.pop();
118 }
119 }
120 }
121
122 let mut visited = HashSet::new();
124 let mut components = Vec::new();
125
126 for &start in finish_order.iter().rev() {
127 if visited.contains(&start) {
128 continue;
129 }
130 let mut component = Vec::new();
131 let mut stack = vec![start];
132 visited.insert(start);
133 while let Some(node) = stack.pop() {
134 component.push(node);
135 if let Some(neighbors) = reverse_graph.get(&node) {
136 for &neighbor in neighbors {
137 if !visited.contains(&neighbor) {
138 visited.insert(neighbor);
139 stack.push(neighbor);
140 }
141 }
142 }
143 }
144 components.push(component);
145 }
146
147 components.sort_by_key(|b| Reverse(b.len()));
149
150 Ok(components)
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::entity::{insert_entity, Entity};
    use crate::graph::relation::{insert_relation, Relation};

    /// Opens an in-memory database with the crate schema and two undirected
    /// islands: `Node 1 - Node 2 - Node 3` and `Node 4 - Node 5`.
    fn setup_test_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();

        let add_node = |name: &str| insert_entity(&conn, &Entity::new("node", name)).unwrap();
        let add_link = |from, to| {
            insert_relation(&conn, &Relation::new(from, to, "link", 1.0).unwrap()).unwrap();
        };

        let n1 = add_node("Node 1");
        let n2 = add_node("Node 2");
        let n3 = add_node("Node 3");
        let n4 = add_node("Node 4");
        let n5 = add_node("Node 5");

        add_link(n1, n2);
        add_link(n2, n3);
        add_link(n4, n5);

        conn
    }

    #[test]
    fn test_connected_components() {
        let components = connected_components(&setup_test_db()).unwrap();

        // Largest island first, then the smaller one.
        let sizes: Vec<usize> = components.iter().map(Vec::len).collect();
        assert_eq!(sizes, vec![3, 2]);
    }

    #[test]
    fn test_strongly_connected_components() {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();

        let ids: Vec<_> = ["Node 1", "Node 2", "Node 3"]
            .iter()
            .copied()
            .map(|name| insert_entity(&conn, &Entity::new("node", name)).unwrap())
            .collect();

        // Close the directed cycle Node 1 -> Node 2 -> Node 3 -> Node 1.
        for (i, &from) in ids.iter().enumerate() {
            let to = ids[(i + 1) % ids.len()];
            insert_relation(&conn, &Relation::new(from, to, "link", 1.0).unwrap()).unwrap();
        }

        let components = strongly_connected_components(&conn).unwrap();

        let sizes: Vec<usize> = components.iter().map(Vec::len).collect();
        assert_eq!(sizes, vec![3]);
    }
}