Skip to main content

sqlite_graphrag/
graph.rs

//! Entity graph traversal (BFS over `memory_entities` + `relationships`).
//!
//! Queries the SQLite `memory_entities`, `relationships` and `memories` tables
//! to expand the neighbourhood sets used by the `related` and `recall` commands.
5
6// src/graph.rs
7
8use crate::errors::AppError;
9use rusqlite::{params, Connection};
10
11/// Traverses the entity graph by BFS from seed memories.
12///
13/// Returns `memory_id`s reachable through entity and relationship edges,
14/// excluding the seeds themselves. The algorithm:
15/// 1. Collects entities associated with seeds via `memory_entities`.
16/// 2. Runs BFS over `relationships` filtered by `weight >= min_weight` and `namespace`.
17/// 3. Returns memories linked to discovered entities (excluding soft-deleted).
18///
19/// # Errors
20///
21/// Propaga [`AppError::Database`] (exit 10) em falhas de consulta SQLite.
22///
23/// # Examples
24///
25/// ```
26/// use rusqlite::Connection;
27/// use sqlite_graphrag::graph::traverse_from_memories;
28///
29/// // Lista de sementes vazia retorna imediatamente sem consultar o banco.
30/// let conn = Connection::open_in_memory().unwrap();
31/// let ids = traverse_from_memories(&conn, &[], "global", 0.5, 3).unwrap();
32/// assert!(ids.is_empty());
33/// ```
34///
35/// ```
36/// use rusqlite::Connection;
37/// use sqlite_graphrag::graph::traverse_from_memories;
38///
39/// // max_hops == 0 retorna imediatamente sem traversal.
40/// let conn = Connection::open_in_memory().unwrap();
41/// let ids = traverse_from_memories(&conn, &[1, 2], "global", 0.5, 0).unwrap();
42/// assert!(ids.is_empty());
43/// ```
44pub fn traverse_from_memories(
45    conn: &Connection,
46    seed_memory_ids: &[i64],
47    namespace: &str,
48    min_weight: f64,
49    max_hops: u32,
50) -> Result<Vec<i64>, AppError> {
51    if seed_memory_ids.is_empty() || max_hops == 0 {
52        return Ok(vec![]);
53    }
54
55    // Step 1: collect seed entity IDs from seed memories
56    let mut seed_entities: Vec<i64> = Vec::new();
57    for &mem_id in seed_memory_ids {
58        let mut stmt =
59            conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
60        let ids: Vec<i64> = stmt
61            .query_map(params![mem_id], |r| r.get(0))?
62            .filter_map(|r| r.ok())
63            .collect();
64        seed_entities.extend(ids);
65    }
66    seed_entities.sort_unstable();
67    seed_entities.dedup();
68
69    if seed_entities.is_empty() {
70        return Ok(vec![]);
71    }
72
73    // Step 2: BFS over relationships
74    use std::collections::HashSet;
75    let mut visited: HashSet<i64> = seed_entities.iter().cloned().collect();
76    let mut frontier = seed_entities.clone();
77
78    for _ in 0..max_hops {
79        if frontier.is_empty() {
80            break;
81        }
82        let mut next_frontier = Vec::new();
83
84        for &entity_id in &frontier {
85            let mut stmt = conn.prepare_cached(
86                "SELECT target_id FROM relationships
87                 WHERE source_id = ?1 AND weight >= ?2 AND namespace = ?3",
88            )?;
89            let neighbors: Vec<i64> = stmt
90                .query_map(params![entity_id, min_weight, namespace], |r| r.get(0))?
91                .filter_map(|r| r.ok())
92                .filter(|id| !visited.contains(id))
93                .collect();
94
95            for id in neighbors {
96                visited.insert(id);
97                next_frontier.push(id);
98            }
99        }
100        frontier = next_frontier;
101    }
102
103    // Step 3: find memories connected to traversed entities (excluding seeds)
104    let seed_set: HashSet<i64> = seed_memory_ids.iter().cloned().collect();
105    let graph_only_entities: Vec<i64> = visited
106        .into_iter()
107        .filter(|id| !seed_entities.contains(id))
108        .collect();
109
110    let mut result_ids: Vec<i64> = Vec::new();
111    for &entity_id in &graph_only_entities {
112        let mut stmt = conn.prepare_cached(
113            "SELECT DISTINCT me.memory_id
114             FROM memory_entities me
115             JOIN memories m ON m.id = me.memory_id
116             WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
117        )?;
118        let mem_ids: Vec<i64> = stmt
119            .query_map(params![entity_id], |r| r.get(0))?
120            .filter_map(|r| r.ok())
121            .filter(|id| !seed_set.contains(id))
122            .collect();
123        result_ids.extend(mem_ids);
124    }
125
126    result_ids.sort_unstable();
127    result_ids.dedup();
128    Ok(result_ids)
129}
130
131/// BFS graph traversal that also returns the hop distance for each reached memory.
132///
133/// Identical to [`traverse_from_memories`] but returns `(memory_id, hop_count)` tuples
134/// instead of bare IDs. `hop_count` is the BFS depth at which the entity was first
135/// discovered, starting from 1 for direct neighbours of the seed entities.
136///
137/// # Errors
138///
139/// Propaga [`AppError::Database`] (exit 10) em falhas de consulta SQLite.
140pub fn traverse_from_memories_with_hops(
141    conn: &Connection,
142    seed_memory_ids: &[i64],
143    namespace: &str,
144    min_weight: f64,
145    max_hops: u32,
146) -> Result<Vec<(i64, u32)>, AppError> {
147    if seed_memory_ids.is_empty() || max_hops == 0 {
148        return Ok(vec![]);
149    }
150
151    // Collect seed entity IDs from seed memories
152    let mut seed_entities: Vec<i64> = Vec::new();
153    for &mem_id in seed_memory_ids {
154        let mut stmt =
155            conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
156        let ids: Vec<i64> = stmt
157            .query_map(params![mem_id], |r| r.get(0))?
158            .filter_map(|r| r.ok())
159            .collect();
160        seed_entities.extend(ids);
161    }
162    seed_entities.sort_unstable();
163    seed_entities.dedup();
164
165    if seed_entities.is_empty() {
166        return Ok(vec![]);
167    }
168
169    // BFS over relationships, tracking depth per entity
170    use std::collections::HashMap;
171    let mut entity_depth: HashMap<i64, u32> = seed_entities.iter().map(|&id| (id, 0)).collect();
172    let mut frontier = seed_entities.clone();
173
174    for hop in 1..=max_hops {
175        if frontier.is_empty() {
176            break;
177        }
178        let mut next_frontier = Vec::new();
179
180        for &entity_id in &frontier {
181            let mut stmt = conn.prepare_cached(
182                "SELECT target_id FROM relationships
183                 WHERE source_id = ?1 AND weight >= ?2 AND namespace = ?3",
184            )?;
185            let neighbors: Vec<i64> = stmt
186                .query_map(params![entity_id, min_weight, namespace], |r| r.get(0))?
187                .filter_map(|r| r.ok())
188                .filter(|id| !entity_depth.contains_key(id))
189                .collect();
190
191            for id in neighbors {
192                entity_depth.insert(id, hop);
193                next_frontier.push(id);
194            }
195        }
196        frontier = next_frontier;
197    }
198
199    // Find memories connected to traversed entities (excluding seeds), preserving hop depth
200    let seed_set: std::collections::HashSet<i64> = seed_memory_ids.iter().cloned().collect();
201    let seed_entity_set: std::collections::HashSet<i64> = seed_entities.iter().cloned().collect();
202
203    let mut result: Vec<(i64, u32)> = Vec::new();
204    let mut seen_memories: std::collections::HashSet<i64> = std::collections::HashSet::new();
205
206    for (&entity_id, &hop) in &entity_depth {
207        if seed_entity_set.contains(&entity_id) {
208            continue;
209        }
210        let mut stmt = conn.prepare_cached(
211            "SELECT DISTINCT me.memory_id
212             FROM memory_entities me
213             JOIN memories m ON m.id = me.memory_id
214             WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
215        )?;
216        let mem_ids: Vec<i64> = stmt
217            .query_map(params![entity_id], |r| r.get(0))?
218            .filter_map(|r| r.ok())
219            .filter(|id| !seed_set.contains(id) && !seen_memories.contains(id))
220            .collect();
221
222        for mem_id in mem_ids {
223            seen_memories.insert(mem_id);
224            result.push((mem_id, hop));
225        }
226    }
227
228    result.sort_unstable_by_key(|&(id, _)| id);
229    Ok(result)
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Creates an in-memory database with the minimal schema the traversal
    /// queries touch: `memories`, `memory_entities` and `relationships`.
    fn setup_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY,
                namespace TEXT NOT NULL,
                deleted_at TEXT
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );
            CREATE TABLE relationships (
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                weight REAL NOT NULL,
                namespace TEXT NOT NULL
            );",
        )
        .unwrap();
        conn
    }

    /// Inserts a memory row; `deleted == true` marks it as soft-deleted.
    fn insert_memory(conn: &Connection, id: i64, namespace: &str, deleted: bool) {
        conn.execute(
            "INSERT INTO memories (id, namespace, deleted_at) VALUES (?1, ?2, ?3)",
            params![
                id,
                namespace,
                if deleted { Some("2024-01-01") } else { None }
            ],
        )
        .unwrap();
    }

    /// Associates a memory with an entity.
    fn link_memory_entity(conn: &Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            params![memory_id, entity_id],
        )
        .unwrap();
    }

    /// Inserts a directed relationship `src -> tgt` with the given weight and namespace.
    fn insert_relationship(conn: &Connection, src: i64, tgt: i64, weight: f64, ns: &str) {
        conn.execute(
            "INSERT INTO relationships (source_id, target_id, weight, namespace) VALUES (?1, ?2, ?3, ?4)",
            params![src, tgt, weight, ns],
        )
        .unwrap();
    }

    // --- edge cases returning empty ---

    #[test]
    fn returns_empty_when_seeds_empty() {
        let conn = setup_db();
        let result = traverse_from_memories(&conn, &[], "ns", 0.5, 3).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn returns_empty_when_max_hops_zero() {
        let conn = setup_db();
        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);
        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 0).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn returns_empty_when_seed_has_no_entities() {
        let conn = setup_db();
        insert_memory(&conn, 1, "ns", false);
        // memory exists but has no associated entities
        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn returns_empty_when_no_relationships() {
        let conn = setup_db();
        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);
        // entity 10 has no relationships
        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(result.is_empty());
    }

    // --- basic happy path ---

    #[test]
    fn traversal_basic_one_hop() {
        let conn = setup_db();

        // seed: memory 1 with entity 10
        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        // neighbour: entity 20 linked to memory 2
        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        // relationship 10 -> 20
        insert_relationship(&conn, 10, 20, 1.0, "ns");

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(result, vec![2]);
    }

    #[test]
    fn traversal_two_hops() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        // chain 10 -> 20 -> 30
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        let mut result = traverse_from_memories(&conn, &[1], "ns", 0.5, 2).unwrap();
        result.sort_unstable();
        assert_eq!(result, vec![2, 3]);
    }

    #[test]
    fn max_hops_limits_depth() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        // with only 1 hop, memory 3 must not appear
        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(result, vec![2]);
        assert!(!result.contains(&3));
    }

    // --- weight filter ---

    #[test]
    fn relationship_with_weight_below_min_ignored() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        // weight 0.3 < min_weight 0.5
        insert_relationship(&conn, 10, 20, 0.3, "ns");

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn relationship_with_weight_exactly_at_min_included() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 0.5, "ns");

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(result, vec![2]);
    }

    // --- namespace isolation ---

    #[test]
    fn relationship_from_different_namespace_ignored() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns_a", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns_a", false);
        link_memory_entity(&conn, 2, 20);

        // relationship lives in the wrong namespace
        insert_relationship(&conn, 10, 20, 1.0, "ns_b");

        let result = traverse_from_memories(&conn, &[1], "ns_a", 0.5, 3).unwrap();
        assert!(result.is_empty());
    }

    // --- exclude seeds from result ---

    #[test]
    fn seeds_do_not_appear_in_result() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        // relationship from 20 back to 10 (cycle)
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 10, 1.0, "ns");

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        // memory 1 must not appear even with a cycle
        assert!(!result.contains(&1));
        assert_eq!(result, vec![2]);
    }

    // --- soft-deleted memories excluded ---

    #[test]
    fn deleted_memories_not_included() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        // memory 2 has been soft-deleted
        insert_memory(&conn, 2, "ns", true);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 1.0, "ns");

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(result.is_empty());
    }

    // --- multiple seeds ---

    #[test]
    fn multiple_seeds_merged_in_result() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_memory(&conn, 4, "ns", false);
        link_memory_entity(&conn, 4, 40);

        insert_relationship(&conn, 10, 30, 1.0, "ns");
        insert_relationship(&conn, 20, 40, 1.0, "ns");

        let mut result = traverse_from_memories(&conn, &[1, 2], "ns", 0.5, 1).unwrap();
        result.sort_unstable();
        assert_eq!(result, vec![3, 4]);
    }

    // --- result deduplication ---

    #[test]
    fn result_without_duplicates() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);
        link_memory_entity(&conn, 1, 11); // two seed entities on the same memory

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        // both seed entities point to the same entity 20
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 11, 20, 1.0, "ns");

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        // memory 2 must appear exactly once
        assert_eq!(result.len(), 1);
        assert_eq!(result, vec![2]);
    }

    // --- single node ---

    #[test]
    fn single_node_without_neighbors_returns_empty() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);
        // entity 10 has no outgoing relationships

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 5).unwrap();
        assert!(result.is_empty());
    }

    // --- cycles in the graph ---

    #[test]
    fn cycle_does_not_cause_infinite_loop() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        // triangle 10 -> 20 -> 30 -> 10
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");
        insert_relationship(&conn, 30, 10, 1.0, "ns");

        let mut result = traverse_from_memories(&conn, &[1], "ns", 0.5, 10).unwrap();
        result.sort_unstable();
        // must return 2 and 3 without looping forever
        assert_eq!(result, vec![2, 3]);
    }

    // --- traverse_from_memories_with_hops ---

    #[test]
    fn with_hops_returns_empty_when_seeds_empty() {
        let conn = setup_db();
        let result = traverse_from_memories_with_hops(&conn, &[], "ns", 0.5, 3).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn with_hops_returns_empty_when_max_hops_zero() {
        let conn = setup_db();
        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);
        let result = traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 0).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn with_hops_reports_depth_along_chain() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        // chain 10 -> 20 -> 30
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        let result = traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 2).unwrap();
        // sorted by memory id; memory 2 is one hop away, memory 3 two hops
        assert_eq!(result, vec![(2, 1), (3, 2)]);
    }

    #[test]
    fn with_hops_respects_max_hops() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        // with only 1 hop, memory 3 must not appear
        let result = traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(result, vec![(2, 1)]);
    }

    #[test]
    fn with_hops_excludes_seeds_and_deleted_memories() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        // memory 3 has been soft-deleted
        insert_memory(&conn, 3, "ns", true);
        link_memory_entity(&conn, 3, 20);

        // cycle 10 -> 20 -> 10
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 10, 1.0, "ns");

        let result = traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert_eq!(result, vec![(2, 1)]);
    }
}