Skip to main content

sqlite_knowledge_graph/graph/
traversal.rs

1use crate::error::Result;
2/// Graph traversal algorithms for sqlite-knowledge-graph
3///
4/// Provides BFS/DFS traversal, shortest path, and graph statistics.
5use rusqlite::Connection;
6use std::collections::{HashMap, HashSet, VecDeque};
7
/// One entity produced by a graph traversal.
///
/// Pairs the entity's identity with the depth at which the traversal
/// recorded it (the starting entity has depth 0).
#[derive(Debug, Clone)]
pub struct TraversalNode {
    /// Row id of the entity in the `entities` table.
    pub entity_id: i64,
    /// Value of the entity's `entity_type` column.
    pub entity_type: String,
    /// Number of hops from the traversal's starting entity.
    pub depth: u32,
}
15
/// A single hop of a path: the relation followed and the two entities it joins.
#[derive(Debug, Clone)]
pub struct PathStep {
    /// Entity the hop starts from.
    pub from_id: i64,
    /// Entity the hop arrives at.
    pub to_id: i64,
    /// Value of the relation's `relation_type` column.
    pub relation_type: String,
    /// Value of the relation's `weight` column.
    pub weight: f64,
}
24
25/// Complete path from source to target
26#[derive(Debug, Clone)]
27pub struct TraversalPath {
28    pub start_id: i64,
29    pub end_id: i64,
30    pub steps: Vec<PathStep>,
31    pub total_weight: f64,
32}
33
/// Summary figures describing the whole graph.
#[derive(Debug, Clone)]
pub struct GraphStats {
    /// Number of rows in the `entities` table.
    pub total_entities: i64,
    /// Number of rows in the `relations` table.
    pub total_relations: i64,
    /// Mean degree per entity (each relation counts once for each endpoint).
    pub avg_degree: f64,
    /// Largest degree observed for any single entity.
    pub max_degree: i64,
    /// Ratio of existing relations to the number of possible directed pairs.
    pub density: f64,
}
43
/// Which way relations are followed when looking for an entity's neighbours.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Direction {
    /// Follow relations leaving the entity (`from_id` is the entity).
    Outgoing,
    /// Follow relations arriving at the entity (`to_id` is the entity).
    Incoming,
    /// Follow relations in either direction.
    Both,
}
51
52/// Query parameters for traversal
53#[derive(Debug, Clone)]
54pub struct TraversalQuery {
55    pub direction: Direction,
56    pub rel_types: Option<Vec<String>>,
57    pub min_weight: Option<f64>,
58    pub max_depth: u32,
59}
60
61impl Default for TraversalQuery {
62    fn default() -> Self {
63        Self {
64            direction: Direction::Both,
65            rel_types: None,
66            min_weight: None,
67            max_depth: 3,
68        }
69    }
70}
71
72/// BFS traversal from a starting entity
73///
74/// Returns all reachable entities within max_depth, with their depth information.
75pub fn bfs_traversal(
76    conn: &Connection,
77    start_id: i64,
78    query: TraversalQuery,
79) -> Result<Vec<TraversalNode>> {
80    let mut result = Vec::new();
81    let mut visited = HashSet::new();
82    let mut queue = VecDeque::new();
83
84    // Get start entity type
85    let start_type: String = conn.query_row(
86        "SELECT entity_type FROM entities WHERE id = ?1",
87        [start_id],
88        |row| row.get(0),
89    )?;
90
91    queue.push_back((start_id, start_type, 0u32));
92    visited.insert(start_id);
93
94    while let Some((entity_id, _entity_type, depth)) = queue.pop_front() {
95        if depth > query.max_depth {
96            continue;
97        }
98
99        result.push(TraversalNode {
100            entity_id,
101            entity_type: _entity_type.clone(),
102            depth,
103        });
104
105        if depth == query.max_depth {
106            continue;
107        }
108
109        // Get neighbors based on direction
110        let neighbors = get_neighbors(conn, entity_id, &query)?;
111
112        for (neighbor_id, neighbor_type) in neighbors {
113            if !visited.contains(&neighbor_id) {
114                visited.insert(neighbor_id);
115                queue.push_back((neighbor_id, neighbor_type, depth + 1));
116            }
117        }
118    }
119
120    Ok(result)
121}
122
123/// DFS traversal from a starting entity
124///
125/// Returns all reachable entities within max_depth using depth-first search.
126pub fn dfs_traversal(
127    conn: &Connection,
128    start_id: i64,
129    query: TraversalQuery,
130) -> Result<Vec<TraversalNode>> {
131    let mut result = Vec::new();
132    let mut visited = HashSet::new();
133
134    // Get start entity type
135    let start_type: String = conn.query_row(
136        "SELECT entity_type FROM entities WHERE id = ?1",
137        [start_id],
138        |row| row.get(0),
139    )?;
140
141    dfs_visit(
142        conn,
143        start_id,
144        start_type,
145        0,
146        &query,
147        &mut visited,
148        &mut result,
149    )?;
150
151    Ok(result)
152}
153
154fn dfs_visit(
155    conn: &Connection,
156    entity_id: i64,
157    entity_type: String,
158    depth: u32,
159    query: &TraversalQuery,
160    visited: &mut HashSet<i64>,
161    result: &mut Vec<TraversalNode>,
162) -> Result<()> {
163    if visited.contains(&entity_id) || depth > query.max_depth {
164        return Ok(());
165    }
166
167    visited.insert(entity_id);
168    result.push(TraversalNode {
169        entity_id,
170        entity_type: entity_type.clone(),
171        depth,
172    });
173
174    if depth == query.max_depth {
175        return Ok(());
176    }
177
178    let neighbors = get_neighbors(conn, entity_id, query)?;
179
180    for (neighbor_id, neighbor_type) in neighbors {
181        dfs_visit(
182            conn,
183            neighbor_id,
184            neighbor_type,
185            depth + 1,
186            query,
187            visited,
188            result,
189        )?;
190    }
191
192    Ok(())
193}
194
195/// Find shortest path between two entities using BFS
196///
197/// Returns the shortest path (if exists) with all intermediate steps.
198pub fn find_shortest_path(
199    conn: &Connection,
200    from_id: i64,
201    to_id: i64,
202    max_depth: u32,
203) -> Result<Option<TraversalPath>> {
204    if from_id == to_id {
205        return Ok(Some(TraversalPath {
206            start_id: from_id,
207            end_id: to_id,
208            steps: Vec::new(),
209            total_weight: 0.0,
210        }));
211    }
212
213    let mut visited = HashMap::new(); // entity_id -> (from_id, relation_type, weight)
214    let mut queue = VecDeque::new();
215
216    queue.push_back(from_id);
217    visited.insert(from_id, None);
218
219    while let Some(current_id) = queue.pop_front() {
220        let current_depth = count_depth(&visited, current_id);
221
222        if current_depth >= max_depth {
223            continue;
224        }
225
226        // Get outgoing relations
227        let relations = get_outgoing_relations(conn, current_id)?;
228
229        for (target_id, rel_type, weight) in relations {
230            if let std::collections::hash_map::Entry::Vacant(e) = visited.entry(target_id) {
231                e.insert(Some((current_id, rel_type.clone(), weight)));
232
233                if target_id == to_id {
234                    // Reconstruct path
235                    return Ok(Some(reconstruct_path(from_id, to_id, &visited)?));
236                }
237
238                queue.push_back(target_id);
239            }
240        }
241    }
242
243    Ok(None)
244}
245
246/// Compute graph statistics
247pub fn compute_graph_stats(conn: &Connection) -> Result<GraphStats> {
248    let total_entities: i64 =
249        conn.query_row("SELECT COUNT(*) FROM entities", [], |row| row.get(0))?;
250
251    let total_relations: i64 =
252        conn.query_row("SELECT COUNT(*) FROM relations", [], |row| row.get(0))?;
253
254    let max_degree: i64 = conn.query_row(
255        "SELECT COALESCE(MAX(cnt), 0) FROM (
256            SELECT from_id as id, COUNT(*) as cnt FROM relations GROUP BY from_id
257            UNION ALL
258            SELECT to_id as id, COUNT(*) as cnt FROM relations GROUP BY to_id
259        )",
260        [],
261        |row| row.get(0),
262    )?;
263
264    let avg_degree = if total_entities > 0 {
265        (total_relations as f64 * 2.0) / (total_entities as f64)
266    } else {
267        0.0
268    };
269
270    let density = if total_entities > 1 {
271        let possible_edges = total_entities * (total_entities - 1);
272        total_relations as f64 / possible_edges as f64
273    } else {
274        0.0
275    };
276
277    Ok(GraphStats {
278        total_entities,
279        total_relations,
280        avg_degree,
281        max_degree,
282        density,
283    })
284}
285
286// Helper functions
287
288fn get_neighbors(
289    conn: &Connection,
290    entity_id: i64,
291    query: &TraversalQuery,
292) -> Result<Vec<(i64, String)>> {
293    let mut neighbors = Vec::new();
294
295    let sql = match query.direction {
296        Direction::Outgoing => {
297            "SELECT r.to_id, e.entity_type FROM relations r
298             JOIN entities e ON r.to_id = e.id
299             WHERE r.from_id = ?1"
300        }
301        Direction::Incoming => {
302            "SELECT r.from_id, e.entity_type FROM relations r
303             JOIN entities e ON r.from_id = e.id
304             WHERE r.to_id = ?1"
305        }
306        Direction::Both => {
307            "SELECT r.to_id, e.entity_type FROM relations r
308             JOIN entities e ON r.to_id = e.id
309             WHERE r.from_id = ?1
310             UNION
311             SELECT r.from_id, e.entity_type FROM relations r
312             JOIN entities e ON r.from_id = e.id
313             WHERE r.to_id = ?1"
314        }
315    };
316
317    let mut stmt = conn.prepare(sql)?;
318
319    let rows = stmt.query_map([entity_id], |row| {
320        Ok((row.get::<_, i64>(0)?, row.get::<_, String>(1)?))
321    })?;
322
323    for row in rows {
324        let (id, entity_type) = row?;
325        neighbors.push((id, entity_type));
326    }
327
328    Ok(neighbors)
329}
330
331fn get_outgoing_relations(conn: &Connection, entity_id: i64) -> Result<Vec<(i64, String, f64)>> {
332    let mut relations = Vec::new();
333
334    let mut stmt =
335        conn.prepare("SELECT to_id, relation_type, weight FROM relations WHERE from_id = ?1")?;
336
337    let rows = stmt.query_map([entity_id], |row| {
338        Ok((
339            row.get::<_, i64>(0)?,
340            row.get::<_, String>(1)?,
341            row.get::<_, f64>(2)?,
342        ))
343    })?;
344
345    for row in rows {
346        relations.push(row?);
347    }
348
349    Ok(relations)
350}
351
/// Counts the hops from `entity_id` back to the root of a predecessor map.
///
/// `visited` maps each entity to `Some((previous_id, relation_type, weight))`,
/// or to `None` for the root. The function follows the `previous_id` links
/// until it reaches the root or an id that is not in the map, and returns the
/// number of links followed. The root and unknown ids therefore have depth 0.
///
/// A well-formed chain can never contain more links than the map has entries,
/// so the walk is capped at `visited.len()` hops. That keeps the function
/// terminating if the map ever contained a cycle, without truncating the depth
/// of long but valid chains.
fn count_depth(visited: &HashMap<i64, Option<(i64, String, f64)>>, entity_id: i64) -> u32 {
    let max_hops = visited.len();
    let mut depth = 0u32;
    let mut current = entity_id;

    while let Some(Some((from_id, _, _))) = visited.get(&current) {
        depth += 1;
        current = *from_id;
        // Only reachable when the links form a cycle.
        if depth as usize >= max_hops {
            break;
        }
    }

    depth
}
366
367fn reconstruct_path(
368    from_id: i64,
369    to_id: i64,
370    visited: &HashMap<i64, Option<(i64, String, f64)>>,
371) -> Result<TraversalPath> {
372    let mut steps = Vec::new();
373    let mut current = to_id;
374    let mut total_weight = 0.0;
375
376    while current != from_id {
377        if let Some(Some((from, rel_type, weight))) = visited.get(&current) {
378            steps.push(PathStep {
379                from_id: *from,
380                to_id: current,
381                relation_type: rel_type.clone(),
382                weight: *weight,
383            });
384            total_weight += weight;
385            current = *from;
386        } else {
387            break;
388        }
389    }
390
391    steps.reverse();
392
393    Ok(TraversalPath {
394        start_id: from_id,
395        end_id: to_id,
396        steps,
397        total_weight,
398    })
399}
400
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Builds an in-memory database with four `paper` entities (ids 1-4,
    /// named A-D) and the `cites` relations A -> B, B -> C and A -> D.
    fn setup_test_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();

        conn.execute_batch(
            "CREATE TABLE entities (
                id INTEGER PRIMARY KEY,
                entity_type TEXT NOT NULL,
                name TEXT,
                metadata TEXT
            );
            CREATE TABLE relations (
                id INTEGER PRIMARY KEY,
                from_id INTEGER NOT NULL,
                to_id INTEGER NOT NULL,
                relation_type TEXT NOT NULL,
                weight REAL DEFAULT 1.0,
                confidence REAL DEFAULT 1.0
            );
            ",
        )
        .unwrap();

        // Entities first, then the relations that refer to them.
        let seed_statements = [
            "INSERT INTO entities (id, entity_type, name) VALUES (1, 'paper', 'A')",
            "INSERT INTO entities (id, entity_type, name) VALUES (2, 'paper', 'B')",
            "INSERT INTO entities (id, entity_type, name) VALUES (3, 'paper', 'C')",
            "INSERT INTO entities (id, entity_type, name) VALUES (4, 'paper', 'D')",
            "INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (1, 2, 'cites', 1.0)",
            "INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (2, 3, 'cites', 1.0)",
            "INSERT INTO relations (from_id, to_id, relation_type, weight) VALUES (1, 4, 'cites', 0.5)",
        ];
        for statement in seed_statements.iter() {
            conn.execute(*statement, []).unwrap();
        }

        conn
    }

    /// Query that follows outgoing relations only, down to `max_depth`.
    fn outgoing_query(max_depth: u32) -> TraversalQuery {
        TraversalQuery {
            direction: Direction::Outgoing,
            max_depth,
            ..Default::default()
        }
    }

    #[test]
    fn test_bfs_traversal() {
        let conn = setup_test_db();

        let nodes = bfs_traversal(&conn, 1, outgoing_query(2)).unwrap();

        // A, B, C and D are all within two hops of A.
        assert_eq!(nodes.len(), 4);
        let expected: [(i64, u32); 4] = [(1, 0), (2, 1), (3, 2), (4, 1)];
        for (id, depth) in expected.iter() {
            assert!(nodes
                .iter()
                .any(|node| node.entity_id == *id && node.depth == *depth));
        }
    }

    #[test]
    fn test_dfs_traversal() {
        let conn = setup_test_db();

        let nodes = dfs_traversal(&conn, 1, outgoing_query(2)).unwrap();

        assert_eq!(nodes.len(), 4);
        // The start entity is always reported first.
        assert_eq!(nodes[0].entity_id, 1);
    }

    #[test]
    fn test_shortest_path() {
        let conn = setup_test_db();

        // Two hops: A -> B -> C.
        let two_hops = find_shortest_path(&conn, 1, 3, 5).unwrap();
        assert!(two_hops.is_some());
        let two_hops = two_hops.unwrap();
        assert_eq!(two_hops.start_id, 1);
        assert_eq!(two_hops.end_id, 3);
        assert_eq!(two_hops.steps.len(), 2);

        // One hop: A -> D.
        let one_hop = find_shortest_path(&conn, 1, 4, 5).unwrap();
        assert!(one_hop.is_some());
        assert_eq!(one_hop.unwrap().steps.len(), 1);
    }

    #[test]
    fn test_no_path() {
        let conn = setup_test_db();

        // D has no outgoing relations, so A cannot be reached from it.
        let missing = find_shortest_path(&conn, 4, 1, 5).unwrap();
        assert!(missing.is_none());
    }

    #[test]
    fn test_graph_stats() {
        let conn = setup_test_db();

        let stats = compute_graph_stats(&conn).unwrap();

        assert_eq!(stats.total_entities, 4);
        assert_eq!(stats.total_relations, 3);
        assert_eq!(stats.max_degree, 2); // Entity 1 has 2 outgoing edges
    }
}