Skip to main content

sqlite_graphrag/
graph.rs

1// src/graph.rs
2
3use crate::errors::AppError;
4use rusqlite::{params, Connection};
5
6/// Percorre o grafo de entidades por BFS a partir de memórias-semente.
7///
8/// Retorna `memory_id`s alcançáveis pelo grafo de entidades e relacionamentos,
9/// excluindo as próprias sementes. O algoritmo:
10/// 1. Coleta entidades associadas às sementes via `memory_entities`.
11/// 2. Executa BFS sobre `relationships` filtrando por `weight >= min_weight` e `namespace`.
12/// 3. Retorna memórias ligadas às entidades descobertas (excluindo soft-deleted).
13///
14/// # Errors
15///
16/// Propaga [`AppError::Database`] (exit 10) em falhas de consulta SQLite.
17///
18/// # Examples
19///
20/// ```
21/// use rusqlite::Connection;
22/// use sqlite_graphrag::graph::traverse_from_memories;
23///
24/// // Lista de sementes vazia retorna imediatamente sem consultar o banco.
25/// let conn = Connection::open_in_memory().unwrap();
26/// let ids = traverse_from_memories(&conn, &[], "global", 0.5, 3).unwrap();
27/// assert!(ids.is_empty());
28/// ```
29///
30/// ```
31/// use rusqlite::Connection;
32/// use sqlite_graphrag::graph::traverse_from_memories;
33///
34/// // max_hops == 0 retorna imediatamente sem traversal.
35/// let conn = Connection::open_in_memory().unwrap();
36/// let ids = traverse_from_memories(&conn, &[1, 2], "global", 0.5, 0).unwrap();
37/// assert!(ids.is_empty());
38/// ```
39pub fn traverse_from_memories(
40    conn: &Connection,
41    seed_memory_ids: &[i64],
42    namespace: &str,
43    min_weight: f64,
44    max_hops: u32,
45) -> Result<Vec<i64>, AppError> {
46    if seed_memory_ids.is_empty() || max_hops == 0 {
47        return Ok(vec![]);
48    }
49
50    // Step 1: collect seed entity IDs from seed memories
51    let mut seed_entities: Vec<i64> = Vec::new();
52    for &mem_id in seed_memory_ids {
53        let mut stmt =
54            conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
55        let ids: Vec<i64> = stmt
56            .query_map(params![mem_id], |r| r.get(0))?
57            .filter_map(|r| r.ok())
58            .collect();
59        seed_entities.extend(ids);
60    }
61    seed_entities.sort_unstable();
62    seed_entities.dedup();
63
64    if seed_entities.is_empty() {
65        return Ok(vec![]);
66    }
67
68    // Step 2: BFS over relationships
69    use std::collections::HashSet;
70    let mut visited: HashSet<i64> = seed_entities.iter().cloned().collect();
71    let mut frontier = seed_entities.clone();
72
73    for _ in 0..max_hops {
74        if frontier.is_empty() {
75            break;
76        }
77        let mut next_frontier = Vec::new();
78
79        for &entity_id in &frontier {
80            let mut stmt = conn.prepare_cached(
81                "SELECT target_id FROM relationships
82                 WHERE source_id = ?1 AND weight >= ?2 AND namespace = ?3",
83            )?;
84            let neighbors: Vec<i64> = stmt
85                .query_map(params![entity_id, min_weight, namespace], |r| r.get(0))?
86                .filter_map(|r| r.ok())
87                .filter(|id| !visited.contains(id))
88                .collect();
89
90            for id in neighbors {
91                visited.insert(id);
92                next_frontier.push(id);
93            }
94        }
95        frontier = next_frontier;
96    }
97
98    // Step 3: find memories connected to traversed entities (excluding seeds)
99    let seed_set: HashSet<i64> = seed_memory_ids.iter().cloned().collect();
100    let graph_only_entities: Vec<i64> = visited
101        .into_iter()
102        .filter(|id| !seed_entities.contains(id))
103        .collect();
104
105    let mut result_ids: Vec<i64> = Vec::new();
106    for &entity_id in &graph_only_entities {
107        let mut stmt = conn.prepare_cached(
108            "SELECT DISTINCT me.memory_id
109             FROM memory_entities me
110             JOIN memories m ON m.id = me.memory_id
111             WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
112        )?;
113        let mem_ids: Vec<i64> = stmt
114            .query_map(params![entity_id], |r| r.get(0))?
115            .filter_map(|r| r.ok())
116            .filter(|id| !seed_set.contains(id))
117            .collect();
118        result_ids.extend(mem_ids);
119    }
120
121    result_ids.sort_unstable();
122    result_ids.dedup();
123    Ok(result_ids)
124}
125
126/// BFS graph traversal that also returns the hop distance for each reached memory.
127///
128/// Identical to [`traverse_from_memories`] but returns `(memory_id, hop_count)` tuples
129/// instead of bare IDs. `hop_count` is the BFS depth at which the entity was first
130/// discovered, starting from 1 for direct neighbours of the seed entities.
131///
132/// # Errors
133///
134/// Propaga [`AppError::Database`] (exit 10) em falhas de consulta SQLite.
135pub fn traverse_from_memories_with_hops(
136    conn: &Connection,
137    seed_memory_ids: &[i64],
138    namespace: &str,
139    min_weight: f64,
140    max_hops: u32,
141) -> Result<Vec<(i64, u32)>, AppError> {
142    if seed_memory_ids.is_empty() || max_hops == 0 {
143        return Ok(vec![]);
144    }
145
146    // Collect seed entity IDs from seed memories
147    let mut seed_entities: Vec<i64> = Vec::new();
148    for &mem_id in seed_memory_ids {
149        let mut stmt =
150            conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
151        let ids: Vec<i64> = stmt
152            .query_map(params![mem_id], |r| r.get(0))?
153            .filter_map(|r| r.ok())
154            .collect();
155        seed_entities.extend(ids);
156    }
157    seed_entities.sort_unstable();
158    seed_entities.dedup();
159
160    if seed_entities.is_empty() {
161        return Ok(vec![]);
162    }
163
164    // BFS over relationships, tracking depth per entity
165    use std::collections::HashMap;
166    let mut entity_depth: HashMap<i64, u32> = seed_entities.iter().map(|&id| (id, 0)).collect();
167    let mut frontier = seed_entities.clone();
168
169    for hop in 1..=max_hops {
170        if frontier.is_empty() {
171            break;
172        }
173        let mut next_frontier = Vec::new();
174
175        for &entity_id in &frontier {
176            let mut stmt = conn.prepare_cached(
177                "SELECT target_id FROM relationships
178                 WHERE source_id = ?1 AND weight >= ?2 AND namespace = ?3",
179            )?;
180            let neighbors: Vec<i64> = stmt
181                .query_map(params![entity_id, min_weight, namespace], |r| r.get(0))?
182                .filter_map(|r| r.ok())
183                .filter(|id| !entity_depth.contains_key(id))
184                .collect();
185
186            for id in neighbors {
187                entity_depth.insert(id, hop);
188                next_frontier.push(id);
189            }
190        }
191        frontier = next_frontier;
192    }
193
194    // Find memories connected to traversed entities (excluding seeds), preserving hop depth
195    let seed_set: std::collections::HashSet<i64> = seed_memory_ids.iter().cloned().collect();
196    let seed_entity_set: std::collections::HashSet<i64> = seed_entities.iter().cloned().collect();
197
198    let mut result: Vec<(i64, u32)> = Vec::new();
199    let mut seen_memories: std::collections::HashSet<i64> = std::collections::HashSet::new();
200
201    for (&entity_id, &hop) in &entity_depth {
202        if seed_entity_set.contains(&entity_id) {
203            continue;
204        }
205        let mut stmt = conn.prepare_cached(
206            "SELECT DISTINCT me.memory_id
207             FROM memory_entities me
208             JOIN memories m ON m.id = me.memory_id
209             WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
210        )?;
211        let mem_ids: Vec<i64> = stmt
212            .query_map(params![entity_id], |r| r.get(0))?
213            .filter_map(|r| r.ok())
214            .filter(|id| !seed_set.contains(id) && !seen_memories.contains(id))
215            .collect();
216
217        for mem_id in mem_ids {
218            seen_memories.insert(mem_id);
219            result.push((mem_id, hop));
220        }
221    }
222
223    result.sort_unstable_by_key(|&(id, _)| id);
224    Ok(result)
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Creates an in-memory database with the minimal schema that the traversal
    /// queries read: `memories`, `memory_entities` and `relationships`.
    fn setup_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY,
                namespace TEXT NOT NULL,
                deleted_at TEXT
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );
            CREATE TABLE relationships (
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                weight REAL NOT NULL,
                namespace TEXT NOT NULL
            );",
        )
        .unwrap();
        conn
    }

    /// Inserts a memory row; `deleted == true` marks it as soft-deleted.
    fn insert_memory(conn: &Connection, id: i64, namespace: &str, deleted: bool) {
        conn.execute(
            "INSERT INTO memories (id, namespace, deleted_at) VALUES (?1, ?2, ?3)",
            params![
                id,
                namespace,
                if deleted { Some("2024-01-01") } else { None }
            ],
        )
        .unwrap();
    }

    /// Links a memory to an entity.
    fn link_memory_entity(conn: &Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            params![memory_id, entity_id],
        )
        .unwrap();
    }

    /// Inserts a directed relationship `src -> tgt`.
    fn insert_relationship(conn: &Connection, src: i64, tgt: i64, weight: f64, ns: &str) {
        conn.execute(
            "INSERT INTO relationships (source_id, target_id, weight, namespace) VALUES (?1, ?2, ?3, ?4)",
            params![src, tgt, weight, ns],
        )
        .unwrap();
    }

    // --- edge cases that return an empty result ---

    #[test]
    fn retorna_vazio_quando_seeds_vazio() {
        let conn = setup_db();
        let resultado = traverse_from_memories(&conn, &[], "ns", 0.5, 3).unwrap();
        assert!(resultado.is_empty());
    }

    #[test]
    fn retorna_vazio_quando_max_hops_zero() {
        let conn = setup_db();
        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);
        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 0).unwrap();
        assert!(resultado.is_empty());
    }

    #[test]
    fn retorna_vazio_quando_seed_sem_entidades() {
        let conn = setup_db();
        insert_memory(&conn, 1, "ns", false);
        // the memory exists but has no linked entities
        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(resultado.is_empty());
    }

    #[test]
    fn retorna_vazio_quando_sem_relacionamentos() {
        let conn = setup_db();
        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);
        // entity 10 has no relationships
        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(resultado.is_empty());
    }

    // --- basic happy path ---

    #[test]
    fn traversal_basico_um_hop() {
        let conn = setup_db();

        // seed: memory 1 with entity 10
        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        // neighbour: entity 20 linked to memory 2
        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        // relationship 10 -> 20
        insert_relationship(&conn, 10, 20, 1.0, "ns");

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(resultado, vec![2]);
    }

    #[test]
    fn traversal_dois_hops() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        // chain 10 -> 20 -> 30
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        let mut resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 2).unwrap();
        resultado.sort_unstable();
        assert_eq!(resultado, vec![2, 3]);
    }

    #[test]
    fn max_hops_limita_profundidade() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        // with only 1 hop, memory 3 must not appear
        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(resultado, vec![2]);
        assert!(!resultado.contains(&3));
    }

    // --- weight filter ---

    #[test]
    fn relacionamento_com_peso_abaixo_do_minimo_ignorado() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        // weight 0.3 < min_weight 0.5
        insert_relationship(&conn, 10, 20, 0.3, "ns");

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(resultado.is_empty());
    }

    #[test]
    fn relacionamento_com_peso_exatamente_no_minimo_incluido() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 0.5, "ns");

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(resultado, vec![2]);
    }

    // --- namespace isolation ---

    #[test]
    fn relacionamento_de_namespace_diferente_ignorado() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns_a", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns_a", false);
        link_memory_entity(&conn, 2, 20);

        // relationship in the wrong namespace
        insert_relationship(&conn, 10, 20, 1.0, "ns_b");

        let resultado = traverse_from_memories(&conn, &[1], "ns_a", 0.5, 3).unwrap();
        assert!(resultado.is_empty());
    }

    // --- seeds are excluded from the result ---

    #[test]
    fn seeds_nao_aparecem_no_resultado() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        // relationship from 20 back to 10 (cycle)
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 10, 1.0, "ns");

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        // memory 1 must not appear even with a cycle
        assert!(!resultado.contains(&1));
        assert_eq!(resultado, vec![2]);
    }

    // --- soft-deleted memories are excluded ---

    #[test]
    fn memorias_deletadas_nao_incluidas() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        // memory 2 was soft-deleted
        insert_memory(&conn, 2, "ns", true);
        link_memory_entity(&conn, 2, 20);

        insert_relationship(&conn, 10, 20, 1.0, "ns");

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(resultado.is_empty());
    }

    // --- multiple seeds ---

    #[test]
    fn multiplos_seeds_unidos_no_resultado() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_memory(&conn, 4, "ns", false);
        link_memory_entity(&conn, 4, 40);

        insert_relationship(&conn, 10, 30, 1.0, "ns");
        insert_relationship(&conn, 20, 40, 1.0, "ns");

        let mut resultado = traverse_from_memories(&conn, &[1, 2], "ns", 0.5, 1).unwrap();
        resultado.sort_unstable();
        assert_eq!(resultado, vec![3, 4]);
    }

    // --- result deduplication ---

    #[test]
    fn resultado_sem_duplicatas() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);
        link_memory_entity(&conn, 1, 11); // two seed entities on the same memory

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        // both seed entities point to the same entity 20
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 11, 20, 1.0, "ns");

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        // memory 2 must appear only once
        assert_eq!(resultado.len(), 1);
        assert_eq!(resultado, vec![2]);
    }

    #[test]
    fn relacionamentos_duplicados_nao_duplicam_resultado() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        // the same edge stored twice (different weights), then 20 -> 30
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 10, 20, 0.8, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 2).unwrap();
        assert_eq!(resultado, vec![2, 3]);
    }

    // --- single node ---

    #[test]
    fn single_node_sem_vizinhos_retorna_vazio() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);
        // entity 10 has no outgoing relationships

        let resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 5).unwrap();
        assert!(resultado.is_empty());
    }

    // --- graph cycles ---

    #[test]
    fn ciclo_nao_causa_loop_infinito() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        // triangle 10 -> 20 -> 30 -> 10
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");
        insert_relationship(&conn, 30, 10, 1.0, "ns");

        let mut resultado = traverse_from_memories(&conn, &[1], "ns", 0.5, 10).unwrap();
        resultado.sort_unstable();
        // must return 2 and 3 without looping forever
        assert_eq!(resultado, vec![2, 3]);
    }

    // --- traverse_from_memories_with_hops ---

    #[test]
    fn hops_retorna_vazio_quando_seeds_vazio_ou_max_hops_zero() {
        let conn = setup_db();
        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        assert!(traverse_from_memories_with_hops(&conn, &[], "ns", 0.5, 3)
            .unwrap()
            .is_empty());
        assert!(traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 0)
            .unwrap()
            .is_empty());
    }

    #[test]
    fn hops_profundidade_segue_nivel_bfs() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        // chain 10 -> 20 -> 30: memory 2 is one hop away, memory 3 two hops
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        let resultado = traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 2).unwrap();
        assert_eq!(resultado, vec![(2, 1), (3, 2)]);
    }

    #[test]
    fn hops_max_hops_limita_profundidade() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        let resultado = traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(resultado, vec![(2, 1)]);
    }

    #[test]
    fn hops_ciclo_exclui_seed_e_mantem_profundidade() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        insert_memory(&conn, 2, "ns", false);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        // triangle 10 -> 20 -> 30 -> 10: the seed must not be revisited
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");
        insert_relationship(&conn, 30, 10, 1.0, "ns");

        let resultado = traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 10).unwrap();
        assert_eq!(resultado, vec![(2, 1), (3, 2)]);
    }

    #[test]
    fn hops_filtra_peso_namespace_e_deletadas() {
        let conn = setup_db();

        insert_memory(&conn, 1, "ns", false);
        link_memory_entity(&conn, 1, 10);

        // memory 2 is soft-deleted
        insert_memory(&conn, 2, "ns", true);
        link_memory_entity(&conn, 2, 20);

        insert_memory(&conn, 3, "ns", false);
        link_memory_entity(&conn, 3, 30);

        insert_memory(&conn, 4, "ns", false);
        link_memory_entity(&conn, 4, 40);

        insert_relationship(&conn, 10, 20, 1.0, "ns"); // reaches a deleted memory
        insert_relationship(&conn, 10, 30, 0.3, "ns"); // below min_weight
        insert_relationship(&conn, 10, 40, 1.0, "other"); // wrong namespace

        let resultado = traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(resultado.is_empty());
    }

    #[test]
    fn hops_mesmas_memorias_que_traverse_from_memories() {
        let conn = setup_db();

        for (mem, ent) in [(1, 10), (2, 20), (3, 30), (4, 40), (5, 50)] {
            insert_memory(&conn, mem, "ns", false);
            link_memory_entity(&conn, mem, ent);
        }

        // two seeds (1, 2); 50 is only reachable at depth 2, 30 -> 10 points back to a seed
        insert_relationship(&conn, 10, 30, 1.0, "ns");
        insert_relationship(&conn, 20, 40, 1.0, "ns");
        insert_relationship(&conn, 40, 50, 1.0, "ns");
        insert_relationship(&conn, 30, 10, 1.0, "ns");

        let com_hops = traverse_from_memories_with_hops(&conn, &[1, 2], "ns", 0.5, 2).unwrap();
        assert_eq!(com_hops, vec![(3, 1), (4, 1), (5, 2)]);

        let ids: Vec<i64> = com_hops.iter().map(|&(id, _)| id).collect();
        assert_eq!(
            ids,
            traverse_from_memories(&conn, &[1, 2], "ns", 0.5, 2).unwrap()
        );
    }
}
568}