Skip to main content

sqlite_knowledge_graph/graph/
traversal.rs

1use crate::error::Result;
2/// Graph traversal algorithms for sqlite-knowledge-graph
3///
4/// Provides BFS/DFS traversal, shortest path, and graph statistics.
5use rusqlite::Connection;
6use std::collections::{HashMap, HashSet, VecDeque};
7
/// An entity reached during a traversal, together with the hop count at which
/// it was recorded.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TraversalNode {
    /// Id of the entity (`kg_entities.id`).
    pub entity_id: i64,
    /// The entity's `entity_type` column value.
    pub entity_type: String,
    /// Number of hops from the traversal's start entity (the start is depth 0).
    pub depth: u32,
}
15
/// One edge of a path: a relation leading from `from_id` to `to_id`.
#[derive(Debug, Clone, PartialEq)]
pub struct PathStep {
    /// Entity the relation starts at.
    pub from_id: i64,
    /// Entity the relation ends at.
    pub to_id: i64,
    /// The relation's `rel_type`.
    pub relation_type: String,
    /// The relation's `weight`.
    pub weight: f64,
}

/// A complete path from `start_id` to `end_id`.
#[derive(Debug, Clone, PartialEq)]
pub struct TraversalPath {
    /// First entity of the path.
    pub start_id: i64,
    /// Last entity of the path.
    pub end_id: i64,
    /// Edges in order from `start_id` to `end_id`; empty when start and end coincide.
    pub steps: Vec<PathStep>,
    /// Sum of the weights of `steps`.
    pub total_weight: f64,
}
33
/// Whole-graph summary statistics, as produced by `compute_graph_stats`.
#[derive(Debug, Clone, PartialEq)]
pub struct GraphStats {
    /// Number of rows in `kg_entities`.
    pub total_entities: i64,
    /// Number of rows in `kg_relations`.
    pub total_relations: i64,
    /// Average number of relation endpoints per entity (`2 * relations / entities`).
    pub avg_degree: f64,
    /// Largest per-entity degree reported by `compute_graph_stats`.
    pub max_degree: i64,
    /// Fraction of possible directed edges present (`relations / (n * (n - 1))`).
    pub density: f64,
}
43
/// Which relations a traversal step may follow from an entity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Direction {
    /// Follow relations whose `source_id` is the current entity (to their `target_id`).
    Outgoing,
    /// Follow relations whose `target_id` is the current entity (back to their `source_id`).
    Incoming,
    /// Follow relations in either direction.
    Both,
}

/// Query parameters for BFS/DFS traversal.
#[derive(Debug, Clone, PartialEq)]
pub struct TraversalQuery {
    /// Which way relations are followed.
    pub direction: Direction,
    /// Optional allow-list of relation types (`rel_type`) for edges to follow;
    /// `None` means no restriction. Whether it is honoured depends on the
    /// neighbor lookup (`get_neighbors`).
    pub rel_types: Option<Vec<String>>,
    /// Optional lower bound on relation `weight` for edges to follow; `None`
    /// means no bound. Honoured only if the neighbor lookup applies it.
    pub min_weight: Option<f64>,
    /// Maximum number of hops from the start entity (the start is depth 0).
    pub max_depth: u32,
}

impl Default for TraversalQuery {
    /// Both directions, no filters, up to 3 hops.
    fn default() -> Self {
        Self {
            direction: Direction::Both,
            rel_types: None,
            min_weight: None,
            max_depth: 3,
        }
    }
}
71
72/// BFS traversal from a starting entity
73///
74/// Returns all reachable entities within max_depth, with their depth information.
75pub fn bfs_traversal(
76    conn: &Connection,
77    start_id: i64,
78    query: TraversalQuery,
79) -> Result<Vec<TraversalNode>> {
80    let mut result = Vec::new();
81    let mut visited = HashSet::new();
82    let mut queue = VecDeque::new();
83
84    // Get start entity type
85    let start_type: String = conn.query_row(
86        "SELECT entity_type FROM kg_entities WHERE id = ?1",
87        [start_id],
88        |row| row.get(0),
89    )?;
90
91    queue.push_back((start_id, start_type, 0u32));
92    visited.insert(start_id);
93
94    while let Some((entity_id, _entity_type, depth)) = queue.pop_front() {
95        if depth > query.max_depth {
96            continue;
97        }
98
99        result.push(TraversalNode {
100            entity_id,
101            entity_type: _entity_type.clone(),
102            depth,
103        });
104
105        if depth == query.max_depth {
106            continue;
107        }
108
109        // Get neighbors based on direction
110        let neighbors = get_neighbors(conn, entity_id, &query)?;
111
112        for (neighbor_id, neighbor_type) in neighbors {
113            if !visited.contains(&neighbor_id) {
114                visited.insert(neighbor_id);
115                queue.push_back((neighbor_id, neighbor_type, depth + 1));
116            }
117        }
118    }
119
120    Ok(result)
121}
122
123/// DFS traversal from a starting entity
124///
125/// Returns all reachable entities within max_depth using depth-first search.
126pub fn dfs_traversal(
127    conn: &Connection,
128    start_id: i64,
129    query: TraversalQuery,
130) -> Result<Vec<TraversalNode>> {
131    let mut result = Vec::new();
132    let mut visited = HashSet::new();
133
134    // Get start entity type
135    let start_type: String = conn.query_row(
136        "SELECT entity_type FROM kg_entities WHERE id = ?1",
137        [start_id],
138        |row| row.get(0),
139    )?;
140
141    dfs_visit(
142        conn,
143        start_id,
144        start_type,
145        0,
146        &query,
147        &mut visited,
148        &mut result,
149    )?;
150
151    Ok(result)
152}
153
154fn dfs_visit(
155    conn: &Connection,
156    entity_id: i64,
157    entity_type: String,
158    depth: u32,
159    query: &TraversalQuery,
160    visited: &mut HashSet<i64>,
161    result: &mut Vec<TraversalNode>,
162) -> Result<()> {
163    if visited.contains(&entity_id) || depth > query.max_depth {
164        return Ok(());
165    }
166
167    visited.insert(entity_id);
168    result.push(TraversalNode {
169        entity_id,
170        entity_type: entity_type.clone(),
171        depth,
172    });
173
174    if depth == query.max_depth {
175        return Ok(());
176    }
177
178    let neighbors = get_neighbors(conn, entity_id, query)?;
179
180    for (neighbor_id, neighbor_type) in neighbors {
181        dfs_visit(
182            conn,
183            neighbor_id,
184            neighbor_type,
185            depth + 1,
186            query,
187            visited,
188            result,
189        )?;
190    }
191
192    Ok(())
193}
194
195/// Find shortest path between two entities using BFS
196///
197/// Returns the shortest path (if exists) with all intermediate steps.
198pub fn find_shortest_path(
199    conn: &Connection,
200    from_id: i64,
201    to_id: i64,
202    max_depth: u32,
203) -> Result<Option<TraversalPath>> {
204    if from_id == to_id {
205        return Ok(Some(TraversalPath {
206            start_id: from_id,
207            end_id: to_id,
208            steps: Vec::new(),
209            total_weight: 0.0,
210        }));
211    }
212
213    let mut visited = HashMap::new(); // entity_id -> (from_id, relation_type, weight)
214    let mut queue: VecDeque<(i64, u32)> = VecDeque::new(); // (entity_id, depth)
215
216    queue.push_back((from_id, 0));
217    visited.insert(from_id, None);
218
219    while let Some((current_id, current_depth)) = queue.pop_front() {
220        if current_depth >= max_depth {
221            continue;
222        }
223
224        // Get outgoing relations
225        let relations = get_outgoing_relations(conn, current_id)?;
226
227        for (target_id, rel_type, weight) in relations {
228            if let std::collections::hash_map::Entry::Vacant(e) = visited.entry(target_id) {
229                e.insert(Some((current_id, rel_type.clone(), weight)));
230
231                if target_id == to_id {
232                    // Reconstruct path
233                    return Ok(Some(reconstruct_path(from_id, to_id, &visited)?));
234                }
235
236                queue.push_back((target_id, current_depth + 1));
237            }
238        }
239    }
240
241    Ok(None)
242}
243
244/// Compute graph statistics
245pub fn compute_graph_stats(conn: &Connection) -> Result<GraphStats> {
246    let total_entities: i64 =
247        conn.query_row("SELECT COUNT(*) FROM kg_entities", [], |row| row.get(0))?;
248
249    let total_relations: i64 =
250        conn.query_row("SELECT COUNT(*) FROM kg_relations", [], |row| row.get(0))?;
251
252    let max_degree: i64 = conn.query_row(
253        "SELECT COALESCE(MAX(cnt), 0) FROM (
254            SELECT source_id as id, COUNT(*) as cnt FROM kg_relations GROUP BY source_id
255            UNION ALL
256            SELECT target_id as id, COUNT(*) as cnt FROM kg_relations GROUP BY target_id
257        )",
258        [],
259        |row| row.get(0),
260    )?;
261
262    let avg_degree = if total_entities > 0 {
263        (total_relations as f64 * 2.0) / (total_entities as f64)
264    } else {
265        0.0
266    };
267
268    let density = if total_entities > 1 {
269        let possible_edges = total_entities * (total_entities - 1);
270        total_relations as f64 / possible_edges as f64
271    } else {
272        0.0
273    };
274
275    Ok(GraphStats {
276        total_entities,
277        total_relations,
278        avg_degree,
279        max_degree,
280        density,
281    })
282}
283
284// Helper functions
285
286fn get_neighbors(
287    conn: &Connection,
288    entity_id: i64,
289    query: &TraversalQuery,
290) -> Result<Vec<(i64, String)>> {
291    let mut neighbors = Vec::new();
292
293    let sql = match query.direction {
294        Direction::Outgoing => {
295            "SELECT r.target_id, e.entity_type FROM kg_relations r
296             JOIN kg_entities e ON r.target_id = e.id
297             WHERE r.source_id = ?1"
298        }
299        Direction::Incoming => {
300            "SELECT r.source_id, e.entity_type FROM kg_relations r
301             JOIN kg_entities e ON r.source_id = e.id
302             WHERE r.target_id = ?1"
303        }
304        Direction::Both => {
305            "SELECT r.target_id, e.entity_type FROM kg_relations r
306             JOIN kg_entities e ON r.target_id = e.id
307             WHERE r.source_id = ?1
308             UNION
309             SELECT r.source_id, e.entity_type FROM kg_relations r
310             JOIN kg_entities e ON r.source_id = e.id
311             WHERE r.target_id = ?1"
312        }
313    };
314
315    let mut stmt = conn.prepare(sql)?;
316
317    let rows = stmt.query_map([entity_id], |row| {
318        Ok((row.get::<_, i64>(0)?, row.get::<_, String>(1)?))
319    })?;
320
321    for row in rows {
322        let (id, entity_type) = row?;
323        neighbors.push((id, entity_type));
324    }
325
326    Ok(neighbors)
327}
328
329fn get_outgoing_relations(conn: &Connection, entity_id: i64) -> Result<Vec<(i64, String, f64)>> {
330    let mut relations = Vec::new();
331
332    let mut stmt =
333        conn.prepare("SELECT target_id, rel_type, weight FROM kg_relations WHERE source_id = ?1")?;
334
335    let rows = stmt.query_map([entity_id], |row| {
336        Ok((
337            row.get::<_, i64>(0)?,
338            row.get::<_, String>(1)?,
339            row.get::<_, f64>(2)?,
340        ))
341    })?;
342
343    for row in rows {
344        relations.push(row?);
345    }
346
347    Ok(relations)
348}
349
350fn reconstruct_path(
351    from_id: i64,
352    to_id: i64,
353    visited: &HashMap<i64, Option<(i64, String, f64)>>,
354) -> Result<TraversalPath> {
355    let mut steps = Vec::new();
356    let mut current = to_id;
357    let mut total_weight = 0.0;
358
359    while current != from_id {
360        if let Some(Some((from, rel_type, weight))) = visited.get(&current) {
361            steps.push(PathStep {
362                from_id: *from,
363                to_id: current,
364                relation_type: rel_type.clone(),
365                weight: *weight,
366            });
367            total_weight += weight;
368            current = *from;
369        } else {
370            break;
371        }
372    }
373
374    steps.reverse();
375
376    Ok(TraversalPath {
377        start_id: from_id,
378        end_id: to_id,
379        steps,
380        total_weight,
381    })
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Build an in-memory graph with relations `A -> B -> C` and `A -> D`
    /// (all `cites`; A->D has weight 0.5, the others 1.0).
    ///
    /// Returns the connection plus the ids `[a, b, c, d]` produced by the
    /// inserts, so tests do not depend on how entity ids are allocated.
    fn setup_test_db() -> (Connection, [i64; 4]) {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();

        use crate::graph::entity::{insert_entity, Entity};
        use crate::graph::relation::{insert_relation, Relation};

        let id_a = insert_entity(&conn, &Entity::new("paper", "A")).unwrap();
        let id_b = insert_entity(&conn, &Entity::new("paper", "B")).unwrap();
        let id_c = insert_entity(&conn, &Entity::new("paper", "C")).unwrap();
        let id_d = insert_entity(&conn, &Entity::new("paper", "D")).unwrap();

        insert_relation(&conn, &Relation::new(id_a, id_b, "cites", 1.0).unwrap()).unwrap();
        insert_relation(&conn, &Relation::new(id_b, id_c, "cites", 1.0).unwrap()).unwrap();
        insert_relation(&conn, &Relation::new(id_a, id_d, "cites", 0.5).unwrap()).unwrap();

        (conn, [id_a, id_b, id_c, id_d])
    }

    #[test]
    fn test_bfs_traversal() {
        let (conn, [a, b, c, d]) = setup_test_db();
        let query = TraversalQuery {
            direction: Direction::Outgoing,
            max_depth: 2,
            ..Default::default()
        };

        let result = bfs_traversal(&conn, a, query).unwrap();

        assert_eq!(result.len(), 4); // A, B, C, D
        assert!(result.iter().any(|n| n.entity_id == a && n.depth == 0));
        assert!(result.iter().any(|n| n.entity_id == b && n.depth == 1));
        assert!(result.iter().any(|n| n.entity_id == c && n.depth == 2));
        assert!(result.iter().any(|n| n.entity_id == d && n.depth == 1));
    }

    #[test]
    fn test_dfs_traversal() {
        let (conn, [a, _, _, _]) = setup_test_db();
        let query = TraversalQuery {
            direction: Direction::Outgoing,
            max_depth: 2,
            ..Default::default()
        };

        let result = dfs_traversal(&conn, a, query).unwrap();

        assert_eq!(result.len(), 4);
        assert_eq!(result[0].entity_id, a); // DFS visits start first
    }

    #[test]
    fn test_shortest_path() {
        let (conn, [a, _, c, d]) = setup_test_db();

        // Path A -> B -> C
        let path = find_shortest_path(&conn, a, c, 5)
            .unwrap()
            .expect("A reaches C via B");
        assert_eq!(path.start_id, a);
        assert_eq!(path.end_id, c);
        assert_eq!(path.steps.len(), 2); // A->B, B->C

        // Direct path A -> D
        let path = find_shortest_path(&conn, a, d, 5)
            .unwrap()
            .expect("A reaches D directly");
        assert_eq!(path.steps.len(), 1);
    }

    #[test]
    fn test_no_path() {
        let (conn, [a, _, _, d]) = setup_test_db();

        // Only outgoing relations are followed, so D cannot reach A.
        let path = find_shortest_path(&conn, d, a, 5).unwrap();
        assert!(path.is_none());
    }

    #[test]
    fn test_graph_stats() {
        let (conn, _) = setup_test_db();

        let stats = compute_graph_stats(&conn).unwrap();

        assert_eq!(stats.total_entities, 4);
        assert_eq!(stats.total_relations, 3);
        // A has two outgoing edges; no entity touches more than two relations.
        assert_eq!(stats.max_degree, 2);
    }
}