1use crate::error::Result;
2use rusqlite::Connection;
6use std::collections::{HashMap, HashSet, VecDeque};
7
/// One entity reported by a graph traversal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TraversalNode {
    /// Id of the entity (the `id` column of `kg_entities`).
    pub entity_id: i64,
    /// The entity's `entity_type` column from `kg_entities`.
    pub entity_type: String,
    /// Number of hops from the traversal's start entity at which this entity was
    /// recorded; the start entity itself has depth 0.
    pub depth: u32,
}
15
/// A single relation traversed along a [`TraversalPath`].
#[derive(Debug, Clone, PartialEq)]
pub struct PathStep {
    /// Entity the step leaves from (the relation's `source_id`).
    pub from_id: i64,
    /// Entity the step arrives at (the relation's `target_id`).
    pub to_id: i64,
    /// The relation's `rel_type` column.
    pub relation_type: String,
    /// The relation's `weight` column.
    pub weight: f64,
}

/// A path between two entities, as produced by `find_shortest_path`.
#[derive(Debug, Clone, PartialEq)]
pub struct TraversalPath {
    /// Entity the path starts at.
    pub start_id: i64,
    /// Entity the path ends at.
    pub end_id: i64,
    /// Relations in travel order, from `start_id` towards `end_id`; empty when
    /// start and end are the same entity.
    pub steps: Vec<PathStep>,
    /// Sum of the `weight` of every step.
    pub total_weight: f64,
}
33
/// Whole-graph summary statistics returned by `compute_graph_stats`.
#[derive(Debug, Clone, PartialEq)]
pub struct GraphStats {
    /// Number of rows in `kg_entities`.
    pub total_entities: i64,
    /// Number of rows in `kg_relations`.
    pub total_relations: i64,
    /// `2 * total_relations / total_entities`, or `0.0` when there are no entities.
    pub avg_degree: f64,
    /// Largest per-entity relation count reported by `compute_graph_stats`
    /// (`0` when there are no relations).
    pub max_degree: i64,
    /// `total_relations / (total_entities * (total_entities - 1))`, or `0.0` with
    /// fewer than two entities.
    pub density: f64,
}
43
/// Which end of each relation a traversal follows from the current entity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Direction {
    /// Follow relations whose `source_id` is the current entity, towards `target_id`.
    Outgoing,
    /// Follow relations whose `target_id` is the current entity, towards `source_id`.
    Incoming,
    /// Follow relations in either direction.
    Both,
}

/// Options for `bfs_traversal` / `dfs_traversal`.
#[derive(Debug, Clone, PartialEq)]
pub struct TraversalQuery {
    /// Which side of each relation to follow.
    pub direction: Direction,
    /// Optional allow-list of relation types (`None` = any type).
    /// NOTE(review): only effective if the neighbor lookup applies it — confirm.
    pub rel_types: Option<Vec<String>>,
    /// Optional lower bound on relation weight (`None` = no bound).
    /// NOTE(review): only effective if the neighbor lookup applies it — confirm.
    pub min_weight: Option<f64>,
    /// Maximum number of hops from the start entity (the start itself is depth 0).
    pub max_depth: u32,
}

impl Default for TraversalQuery {
    /// Both directions, no relation-type or weight filters, up to 3 hops.
    fn default() -> Self {
        Self {
            direction: Direction::Both,
            rel_types: None,
            min_weight: None,
            max_depth: 3,
        }
    }
}
71
72pub fn bfs_traversal(
76 conn: &Connection,
77 start_id: i64,
78 query: TraversalQuery,
79) -> Result<Vec<TraversalNode>> {
80 let mut result = Vec::new();
81 let mut visited = HashSet::new();
82 let mut queue = VecDeque::new();
83
84 let start_type: String = conn.query_row(
86 "SELECT entity_type FROM kg_entities WHERE id = ?1",
87 [start_id],
88 |row| row.get(0),
89 )?;
90
91 queue.push_back((start_id, start_type, 0u32));
92 visited.insert(start_id);
93
94 while let Some((entity_id, _entity_type, depth)) = queue.pop_front() {
95 if depth > query.max_depth {
96 continue;
97 }
98
99 result.push(TraversalNode {
100 entity_id,
101 entity_type: _entity_type.clone(),
102 depth,
103 });
104
105 if depth == query.max_depth {
106 continue;
107 }
108
109 let neighbors = get_neighbors(conn, entity_id, &query)?;
111
112 for (neighbor_id, neighbor_type) in neighbors {
113 if !visited.contains(&neighbor_id) {
114 visited.insert(neighbor_id);
115 queue.push_back((neighbor_id, neighbor_type, depth + 1));
116 }
117 }
118 }
119
120 Ok(result)
121}
122
123pub fn dfs_traversal(
127 conn: &Connection,
128 start_id: i64,
129 query: TraversalQuery,
130) -> Result<Vec<TraversalNode>> {
131 let mut result = Vec::new();
132 let mut visited = HashSet::new();
133
134 let start_type: String = conn.query_row(
136 "SELECT entity_type FROM kg_entities WHERE id = ?1",
137 [start_id],
138 |row| row.get(0),
139 )?;
140
141 dfs_visit(
142 conn,
143 start_id,
144 start_type,
145 0,
146 &query,
147 &mut visited,
148 &mut result,
149 )?;
150
151 Ok(result)
152}
153
154fn dfs_visit(
155 conn: &Connection,
156 entity_id: i64,
157 entity_type: String,
158 depth: u32,
159 query: &TraversalQuery,
160 visited: &mut HashSet<i64>,
161 result: &mut Vec<TraversalNode>,
162) -> Result<()> {
163 if visited.contains(&entity_id) || depth > query.max_depth {
164 return Ok(());
165 }
166
167 visited.insert(entity_id);
168 result.push(TraversalNode {
169 entity_id,
170 entity_type: entity_type.clone(),
171 depth,
172 });
173
174 if depth == query.max_depth {
175 return Ok(());
176 }
177
178 let neighbors = get_neighbors(conn, entity_id, query)?;
179
180 for (neighbor_id, neighbor_type) in neighbors {
181 dfs_visit(
182 conn,
183 neighbor_id,
184 neighbor_type,
185 depth + 1,
186 query,
187 visited,
188 result,
189 )?;
190 }
191
192 Ok(())
193}
194
195pub fn find_shortest_path(
199 conn: &Connection,
200 from_id: i64,
201 to_id: i64,
202 max_depth: u32,
203) -> Result<Option<TraversalPath>> {
204 if from_id == to_id {
205 return Ok(Some(TraversalPath {
206 start_id: from_id,
207 end_id: to_id,
208 steps: Vec::new(),
209 total_weight: 0.0,
210 }));
211 }
212
213 let mut visited = HashMap::new(); let mut queue: VecDeque<(i64, u32)> = VecDeque::new(); queue.push_back((from_id, 0));
217 visited.insert(from_id, None);
218
219 while let Some((current_id, current_depth)) = queue.pop_front() {
220 if current_depth >= max_depth {
221 continue;
222 }
223
224 let relations = get_outgoing_relations(conn, current_id)?;
226
227 for (target_id, rel_type, weight) in relations {
228 if let std::collections::hash_map::Entry::Vacant(e) = visited.entry(target_id) {
229 e.insert(Some((current_id, rel_type.clone(), weight)));
230
231 if target_id == to_id {
232 return Ok(Some(reconstruct_path(from_id, to_id, &visited)?));
234 }
235
236 queue.push_back((target_id, current_depth + 1));
237 }
238 }
239 }
240
241 Ok(None)
242}
243
244pub fn compute_graph_stats(conn: &Connection) -> Result<GraphStats> {
246 let total_entities: i64 =
247 conn.query_row("SELECT COUNT(*) FROM kg_entities", [], |row| row.get(0))?;
248
249 let total_relations: i64 =
250 conn.query_row("SELECT COUNT(*) FROM kg_relations", [], |row| row.get(0))?;
251
252 let max_degree: i64 = conn.query_row(
253 "SELECT COALESCE(MAX(cnt), 0) FROM (
254 SELECT source_id as id, COUNT(*) as cnt FROM kg_relations GROUP BY source_id
255 UNION ALL
256 SELECT target_id as id, COUNT(*) as cnt FROM kg_relations GROUP BY target_id
257 )",
258 [],
259 |row| row.get(0),
260 )?;
261
262 let avg_degree = if total_entities > 0 {
263 (total_relations as f64 * 2.0) / (total_entities as f64)
264 } else {
265 0.0
266 };
267
268 let density = if total_entities > 1 {
269 let possible_edges = total_entities * (total_entities - 1);
270 total_relations as f64 / possible_edges as f64
271 } else {
272 0.0
273 };
274
275 Ok(GraphStats {
276 total_entities,
277 total_relations,
278 avg_degree,
279 max_degree,
280 density,
281 })
282}
283
284fn get_neighbors(
287 conn: &Connection,
288 entity_id: i64,
289 query: &TraversalQuery,
290) -> Result<Vec<(i64, String)>> {
291 let mut neighbors = Vec::new();
292
293 let sql = match query.direction {
294 Direction::Outgoing => {
295 "SELECT r.target_id, e.entity_type FROM kg_relations r
296 JOIN kg_entities e ON r.target_id = e.id
297 WHERE r.source_id = ?1"
298 }
299 Direction::Incoming => {
300 "SELECT r.source_id, e.entity_type FROM kg_relations r
301 JOIN kg_entities e ON r.source_id = e.id
302 WHERE r.target_id = ?1"
303 }
304 Direction::Both => {
305 "SELECT r.target_id, e.entity_type FROM kg_relations r
306 JOIN kg_entities e ON r.target_id = e.id
307 WHERE r.source_id = ?1
308 UNION
309 SELECT r.source_id, e.entity_type FROM kg_relations r
310 JOIN kg_entities e ON r.source_id = e.id
311 WHERE r.target_id = ?1"
312 }
313 };
314
315 let mut stmt = conn.prepare(sql)?;
316
317 let rows = stmt.query_map([entity_id], |row| {
318 Ok((row.get::<_, i64>(0)?, row.get::<_, String>(1)?))
319 })?;
320
321 for row in rows {
322 let (id, entity_type) = row?;
323 neighbors.push((id, entity_type));
324 }
325
326 Ok(neighbors)
327}
328
329fn get_outgoing_relations(conn: &Connection, entity_id: i64) -> Result<Vec<(i64, String, f64)>> {
330 let mut relations = Vec::new();
331
332 let mut stmt =
333 conn.prepare("SELECT target_id, rel_type, weight FROM kg_relations WHERE source_id = ?1")?;
334
335 let rows = stmt.query_map([entity_id], |row| {
336 Ok((
337 row.get::<_, i64>(0)?,
338 row.get::<_, String>(1)?,
339 row.get::<_, f64>(2)?,
340 ))
341 })?;
342
343 for row in rows {
344 relations.push(row?);
345 }
346
347 Ok(relations)
348}
349
350fn reconstruct_path(
351 from_id: i64,
352 to_id: i64,
353 visited: &HashMap<i64, Option<(i64, String, f64)>>,
354) -> Result<TraversalPath> {
355 let mut steps = Vec::new();
356 let mut current = to_id;
357 let mut total_weight = 0.0;
358
359 while current != from_id {
360 if let Some(Some((from, rel_type, weight))) = visited.get(¤t) {
361 steps.push(PathStep {
362 from_id: *from,
363 to_id: current,
364 relation_type: rel_type.clone(),
365 weight: *weight,
366 });
367 total_weight += weight;
368 current = *from;
369 } else {
370 break;
371 }
372 }
373
374 steps.reverse();
375
376 Ok(TraversalPath {
377 start_id: from_id,
378 end_id: to_id,
379 steps,
380 total_weight,
381 })
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Builds an in-memory graph with four "paper" entities A..D and the "cites"
    /// relations A -> B (1.0), B -> C (1.0) and A -> D (0.5).
    ///
    /// The tests address A..D as ids 1..4, so they assume a fresh schema hands out
    /// ids sequentially from 1 in insertion order.
    fn setup_test_db() -> Connection {
        use crate::graph::entity::{insert_entity, Entity};
        use crate::graph::relation::{insert_relation, Relation};

        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();

        let ids: Vec<_> = ["A", "B", "C", "D"]
            .iter()
            .copied()
            .map(|title| insert_entity(&conn, &Entity::new("paper", title)).unwrap())
            .collect();

        // (source index, target index, weight) into `ids`, inserted in this order.
        let citations = [(0usize, 1usize, 1.0), (1, 2, 1.0), (0, 3, 0.5)];
        for &(source, target, weight) in citations.iter() {
            let relation = Relation::new(ids[source], ids[target], "cites", weight).unwrap();
            insert_relation(&conn, &relation).unwrap();
        }

        conn
    }

    #[test]
    fn test_bfs_traversal() {
        let conn = setup_test_db();
        let outgoing_two_hops = TraversalQuery {
            direction: Direction::Outgoing,
            max_depth: 2,
            ..Default::default()
        };

        let nodes = bfs_traversal(&conn, 1, outgoing_two_hops).unwrap();
        assert_eq!(nodes.len(), 4);

        let expected = [(1, 0), (2, 1), (3, 2), (4, 1)];
        for &(id, depth) in expected.iter() {
            assert!(
                nodes.iter().any(|n| n.entity_id == id && n.depth == depth),
                "missing entity {} at depth {}",
                id,
                depth
            );
        }
    }

    #[test]
    fn test_dfs_traversal() {
        let conn = setup_test_db();
        let outgoing_two_hops = TraversalQuery {
            direction: Direction::Outgoing,
            max_depth: 2,
            ..Default::default()
        };

        let nodes = dfs_traversal(&conn, 1, outgoing_two_hops).unwrap();

        assert_eq!(nodes.len(), 4);
        assert_eq!(nodes[0].entity_id, 1);
    }

    #[test]
    fn test_shortest_path() {
        let conn = setup_test_db();

        let two_hops = find_shortest_path(&conn, 1, 3, 5).unwrap();
        assert!(two_hops.is_some());
        let two_hops = two_hops.unwrap();
        assert_eq!(two_hops.start_id, 1);
        assert_eq!(two_hops.end_id, 3);
        assert_eq!(two_hops.steps.len(), 2);

        let one_hop = find_shortest_path(&conn, 1, 4, 5).unwrap();
        assert!(one_hop.is_some());
        assert_eq!(one_hop.unwrap().steps.len(), 1);
    }

    #[test]
    fn test_no_path() {
        let conn = setup_test_db();

        // D has no outgoing relations, so nothing is reachable from it.
        assert!(find_shortest_path(&conn, 4, 1, 5).unwrap().is_none());
    }

    #[test]
    fn test_graph_stats() {
        let conn = setup_test_db();

        let stats = compute_graph_stats(&conn).unwrap();

        assert_eq!(stats.total_entities, 4);
        assert_eq!(stats.total_relations, 3);
        assert_eq!(stats.max_degree, 2);
    }
}