1use crate::error::Result;
2use rusqlite::Connection;
6use std::collections::{HashMap, HashSet, VecDeque};
7
/// An entity reached during a graph traversal.
///
/// Besides `Debug` and `Clone`, the type supports equality and hashing so
/// traversal results can be compared in tests or deduplicated in sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TraversalNode {
    /// Identifier of the entity (the `id` column of the `entities` table).
    pub entity_id: i64,
    /// Value of the entity's `entity_type` column.
    pub entity_type: String,
    /// Number of hops from the traversal start at which the entity was
    /// reported; the start entity itself has depth 0.
    pub depth: u32,
}
15
/// A single relation followed while walking from one entity to the next.
///
/// `PartialEq` is derived so steps can be compared directly; `Eq` is not
/// available because `weight` is a floating-point value.
#[derive(Debug, Clone, PartialEq)]
pub struct PathStep {
    /// Entity the relation starts at.
    pub from_id: i64,
    /// Entity the relation points to.
    pub to_id: i64,
    /// Value of the relation's `relation_type` column.
    pub relation_type: String,
    /// Value of the relation's `weight` column.
    pub weight: f64,
}
24
/// A route between two entities, expressed as the relations followed.
#[derive(Debug, Clone)]
pub struct TraversalPath {
    /// Identifier of the entity the path starts at.
    pub start_id: i64,
    /// Identifier of the entity the path ends at.
    pub end_id: i64,
    /// Relations followed, ordered from `start_id` towards `end_id`.
    /// Empty when start and end are the same entity.
    pub steps: Vec<PathStep>,
    /// Sum of the weights of all `steps` (0.0 for an empty path).
    pub total_weight: f64,
}
33
/// Aggregate figures describing the whole entity/relation graph.
///
/// `PartialEq` is derived so two snapshots can be compared directly; `Eq` is
/// not available because of the floating-point fields.
#[derive(Debug, Clone, PartialEq)]
pub struct GraphStats {
    /// Number of rows in the `entities` table.
    pub total_entities: i64,
    /// Number of rows in the `relations` table.
    pub total_relations: i64,
    /// Average number of relation endpoints per entity
    /// (`2 * total_relations / total_entities`); 0.0 for an empty graph.
    pub avg_degree: f64,
    /// Largest per-entity relation count, as computed by
    /// `compute_graph_stats`; 0 when there are no relations.
    pub max_degree: i64,
    /// `total_relations` divided by the number of ordered entity pairs
    /// (`n * (n - 1)`); 0.0 when there are fewer than two entities.
    pub density: f64,
}
43
/// Which relations to follow, relative to the entity being expanded.
///
/// The enum is a plain fieldless value, so it derives the full set of
/// comparison and hashing traits and can be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Direction {
    /// Follow relations whose `from_id` is the current entity.
    Outgoing,
    /// Follow relations whose `to_id` is the current entity.
    Incoming,
    /// Follow relations in either direction.
    Both,
}
51
52#[derive(Debug, Clone)]
54pub struct TraversalQuery {
55 pub direction: Direction,
56 pub rel_types: Option<Vec<String>>,
57 pub min_weight: Option<f64>,
58 pub max_depth: u32,
59}
60
61impl Default for TraversalQuery {
62 fn default() -> Self {
63 Self {
64 direction: Direction::Both,
65 rel_types: None,
66 min_weight: None,
67 max_depth: 3,
68 }
69 }
70}
71
72pub fn bfs_traversal(
76 conn: &Connection,
77 start_id: i64,
78 query: TraversalQuery,
79) -> Result<Vec<TraversalNode>> {
80 let mut result = Vec::new();
81 let mut visited = HashSet::new();
82 let mut queue = VecDeque::new();
83
84 let start_type: String = conn.query_row(
86 "SELECT entity_type FROM entities WHERE id = ?1",
87 [start_id],
88 |row| row.get(0),
89 )?;
90
91 queue.push_back((start_id, start_type, 0u32));
92 visited.insert(start_id);
93
94 while let Some((entity_id, _entity_type, depth)) = queue.pop_front() {
95 if depth > query.max_depth {
96 continue;
97 }
98
99 result.push(TraversalNode {
100 entity_id,
101 entity_type: _entity_type.clone(),
102 depth,
103 });
104
105 if depth == query.max_depth {
106 continue;
107 }
108
109 let neighbors = get_neighbors(conn, entity_id, &query)?;
111
112 for (neighbor_id, neighbor_type) in neighbors {
113 if !visited.contains(&neighbor_id) {
114 visited.insert(neighbor_id);
115 queue.push_back((neighbor_id, neighbor_type, depth + 1));
116 }
117 }
118 }
119
120 Ok(result)
121}
122
123pub fn dfs_traversal(
127 conn: &Connection,
128 start_id: i64,
129 query: TraversalQuery,
130) -> Result<Vec<TraversalNode>> {
131 let mut result = Vec::new();
132 let mut visited = HashSet::new();
133
134 let start_type: String = conn.query_row(
136 "SELECT entity_type FROM entities WHERE id = ?1",
137 [start_id],
138 |row| row.get(0),
139 )?;
140
141 dfs_visit(
142 conn,
143 start_id,
144 start_type,
145 0,
146 &query,
147 &mut visited,
148 &mut result,
149 )?;
150
151 Ok(result)
152}
153
154fn dfs_visit(
155 conn: &Connection,
156 entity_id: i64,
157 entity_type: String,
158 depth: u32,
159 query: &TraversalQuery,
160 visited: &mut HashSet<i64>,
161 result: &mut Vec<TraversalNode>,
162) -> Result<()> {
163 if visited.contains(&entity_id) || depth > query.max_depth {
164 return Ok(());
165 }
166
167 visited.insert(entity_id);
168 result.push(TraversalNode {
169 entity_id,
170 entity_type: entity_type.clone(),
171 depth,
172 });
173
174 if depth == query.max_depth {
175 return Ok(());
176 }
177
178 let neighbors = get_neighbors(conn, entity_id, query)?;
179
180 for (neighbor_id, neighbor_type) in neighbors {
181 dfs_visit(
182 conn,
183 neighbor_id,
184 neighbor_type,
185 depth + 1,
186 query,
187 visited,
188 result,
189 )?;
190 }
191
192 Ok(())
193}
194
195pub fn find_shortest_path(
199 conn: &Connection,
200 from_id: i64,
201 to_id: i64,
202 max_depth: u32,
203) -> Result<Option<TraversalPath>> {
204 if from_id == to_id {
205 return Ok(Some(TraversalPath {
206 start_id: from_id,
207 end_id: to_id,
208 steps: Vec::new(),
209 total_weight: 0.0,
210 }));
211 }
212
213 let mut visited = HashMap::new(); let mut queue = VecDeque::new();
215
216 queue.push_back(from_id);
217 visited.insert(from_id, None);
218
219 while let Some(current_id) = queue.pop_front() {
220 let current_depth = count_depth(&visited, current_id);
221
222 if current_depth >= max_depth {
223 continue;
224 }
225
226 let relations = get_outgoing_relations(conn, current_id)?;
228
229 for (target_id, rel_type, weight) in relations {
230 if let std::collections::hash_map::Entry::Vacant(e) = visited.entry(target_id) {
231 e.insert(Some((current_id, rel_type.clone(), weight)));
232
233 if target_id == to_id {
234 return Ok(Some(reconstruct_path(from_id, to_id, &visited)?));
236 }
237
238 queue.push_back(target_id);
239 }
240 }
241 }
242
243 Ok(None)
244}
245
246pub fn compute_graph_stats(conn: &Connection) -> Result<GraphStats> {
248 let total_entities: i64 =
249 conn.query_row("SELECT COUNT(*) FROM entities", [], |row| row.get(0))?;
250
251 let total_relations: i64 =
252 conn.query_row("SELECT COUNT(*) FROM relations", [], |row| row.get(0))?;
253
254 let max_degree: i64 = conn.query_row(
255 "SELECT COALESCE(MAX(cnt), 0) FROM (
256 SELECT from_id as id, COUNT(*) as cnt FROM relations GROUP BY from_id
257 UNION ALL
258 SELECT to_id as id, COUNT(*) as cnt FROM relations GROUP BY to_id
259 )",
260 [],
261 |row| row.get(0),
262 )?;
263
264 let avg_degree = if total_entities > 0 {
265 (total_relations as f64 * 2.0) / (total_entities as f64)
266 } else {
267 0.0
268 };
269
270 let density = if total_entities > 1 {
271 let possible_edges = total_entities * (total_entities - 1);
272 total_relations as f64 / possible_edges as f64
273 } else {
274 0.0
275 };
276
277 Ok(GraphStats {
278 total_entities,
279 total_relations,
280 avg_degree,
281 max_degree,
282 density,
283 })
284}
285
286fn get_neighbors(
289 conn: &Connection,
290 entity_id: i64,
291 query: &TraversalQuery,
292) -> Result<Vec<(i64, String)>> {
293 let mut neighbors = Vec::new();
294
295 let sql = match query.direction {
296 Direction::Outgoing => {
297 "SELECT r.to_id, e.entity_type FROM relations r
298 JOIN entities e ON r.to_id = e.id
299 WHERE r.from_id = ?1"
300 }
301 Direction::Incoming => {
302 "SELECT r.from_id, e.entity_type FROM relations r
303 JOIN entities e ON r.from_id = e.id
304 WHERE r.to_id = ?1"
305 }
306 Direction::Both => {
307 "SELECT r.to_id, e.entity_type FROM relations r
308 JOIN entities e ON r.to_id = e.id
309 WHERE r.from_id = ?1
310 UNION
311 SELECT r.from_id, e.entity_type FROM relations r
312 JOIN entities e ON r.from_id = e.id
313 WHERE r.to_id = ?1"
314 }
315 };
316
317 let mut stmt = conn.prepare(sql)?;
318
319 let rows = stmt.query_map([entity_id], |row| {
320 Ok((row.get::<_, i64>(0)?, row.get::<_, String>(1)?))
321 })?;
322
323 for row in rows {
324 let (id, entity_type) = row?;
325 neighbors.push((id, entity_type));
326 }
327
328 Ok(neighbors)
329}
330
331fn get_outgoing_relations(conn: &Connection, entity_id: i64) -> Result<Vec<(i64, String, f64)>> {
332 let mut relations = Vec::new();
333
334 let mut stmt =
335 conn.prepare("SELECT to_id, relation_type, weight FROM relations WHERE from_id = ?1")?;
336
337 let rows = stmt.query_map([entity_id], |row| {
338 Ok((
339 row.get::<_, i64>(0)?,
340 row.get::<_, String>(1)?,
341 row.get::<_, f64>(2)?,
342 ))
343 })?;
344
345 for row in rows {
346 relations.push(row?);
347 }
348
349 Ok(relations)
350}
351
/// Counts how many predecessor links separate `entity_id` from the root of
/// the predecessor map built by `find_shortest_path`.
///
/// `visited` maps an entity to the `(predecessor id, relation type, weight)`
/// it was reached by, or to `None` for the root. The walk stops at the root
/// or at the first id that has no entry, so unknown ids yield 0.
///
/// The walk is bounded by the number of entries in the map rather than by a
/// fixed constant: a chain without repetition cannot contain more links than
/// there are entries, so the bound never truncates a genuine chain (a fixed
/// cap would make depth limits above the cap ineffective) while still
/// guaranteeing termination should the map ever contain a cycle.
fn count_depth(visited: &HashMap<i64, Option<(i64, String, f64)>>, entity_id: i64) -> u32 {
    let max_links = visited.len().min(u32::MAX as usize) as u32;

    let mut depth = 0u32;
    let mut current = entity_id;

    while depth < max_links {
        match visited.get(&current) {
            Some(Some((parent, _, _))) => {
                depth += 1;
                current = *parent;
            }
            // Reached the root (`None` predecessor) or an unknown id.
            _ => break,
        }
    }

    depth
}
366
367fn reconstruct_path(
368 from_id: i64,
369 to_id: i64,
370 visited: &HashMap<i64, Option<(i64, String, f64)>>,
371) -> Result<TraversalPath> {
372 let mut steps = Vec::new();
373 let mut current = to_id;
374 let mut total_weight = 0.0;
375
376 while current != from_id {
377 if let Some(Some((from, rel_type, weight))) = visited.get(¤t) {
378 steps.push(PathStep {
379 from_id: *from,
380 to_id: current,
381 relation_type: rel_type.clone(),
382 weight: *weight,
383 });
384 total_weight += weight;
385 current = *from;
386 } else {
387 break;
388 }
389 }
390
391 steps.reverse();
392
393 Ok(TraversalPath {
394 start_id: from_id,
395 end_id: to_id,
396 steps,
397 total_weight,
398 })
399}
400
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Creates an in-memory database holding four `paper` entities (1-4)
    /// and the `cites` relations 1 -> 2, 2 -> 3 and 1 -> 4.
    fn setup_test_db() -> Connection {
        let db = Connection::open_in_memory().unwrap();

        db.execute_batch(
            "CREATE TABLE entities (
                id INTEGER PRIMARY KEY,
                entity_type TEXT NOT NULL,
                name TEXT,
                metadata TEXT
            );
            CREATE TABLE relations (
                id INTEGER PRIMARY KEY,
                from_id INTEGER NOT NULL,
                to_id INTEGER NOT NULL,
                relation_type TEXT NOT NULL,
                weight REAL DEFAULT 1.0,
                confidence REAL DEFAULT 1.0
            );
            ",
        )
        .unwrap();

        // Entities first, then the relations that reference them.
        let seed_statements = [
            "INSERT INTO entities (id, entity_type, name) VALUES (1, 'paper', 'A')",
            "INSERT INTO entities (id, entity_type, name) VALUES (2, 'paper', 'B')",
            "INSERT INTO entities (id, entity_type, name) VALUES (3, 'paper', 'C')",
            "INSERT INTO entities (id, entity_type, name) VALUES (4, 'paper', 'D')",
            "INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (1, 2, 'cites', 1.0)",
            "INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (2, 3, 'cites', 1.0)",
            "INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (1, 4, 'cites', 0.5)",
        ];
        for statement in seed_statements.iter() {
            db.execute(*statement, []).unwrap();
        }

        db
    }

    /// Query that follows outgoing relations only, up to `max_depth` hops.
    fn outgoing_query(max_depth: u32) -> TraversalQuery {
        TraversalQuery {
            direction: Direction::Outgoing,
            max_depth,
            ..Default::default()
        }
    }

    #[test]
    fn test_bfs_traversal() {
        let db = setup_test_db();

        let nodes = bfs_traversal(&db, 1, outgoing_query(2)).unwrap();

        assert_eq!(nodes.len(), 4);
        // (entity id, expected depth) for every entity reachable from 1.
        let expected: [(i64, u32); 4] = [(1, 0), (2, 1), (3, 2), (4, 1)];
        for &(id, depth) in expected.iter() {
            assert!(nodes.iter().any(|n| n.entity_id == id && n.depth == depth));
        }
    }

    #[test]
    fn test_dfs_traversal() {
        let db = setup_test_db();

        let nodes = dfs_traversal(&db, 1, outgoing_query(2)).unwrap();

        assert_eq!(nodes.len(), 4);
        assert_eq!(nodes[0].entity_id, 1);
    }

    #[test]
    fn test_shortest_path() {
        let db = setup_test_db();

        // Two hops: 1 -> 2 -> 3.
        let two_hops = find_shortest_path(&db, 1, 3, 5).unwrap();
        assert!(two_hops.is_some());
        let two_hops = two_hops.unwrap();
        assert_eq!(two_hops.start_id, 1);
        assert_eq!(two_hops.end_id, 3);
        assert_eq!(two_hops.steps.len(), 2);

        // One hop: 1 -> 4.
        let one_hop = find_shortest_path(&db, 1, 4, 5).unwrap();
        assert!(one_hop.is_some());
        assert_eq!(one_hop.unwrap().steps.len(), 1);
    }

    #[test]
    fn test_no_path() {
        let db = setup_test_db();

        // Relations are directed, so nothing leads from 4 back to 1.
        let missing = find_shortest_path(&db, 4, 1, 5).unwrap();
        assert!(missing.is_none());
    }

    #[test]
    fn test_graph_stats() {
        let db = setup_test_db();

        let stats = compute_graph_stats(&db).unwrap();

        assert_eq!(stats.total_entities, 4);
        assert_eq!(stats.total_relations, 3);
        assert_eq!(stats.max_degree, 2);
    }
}