Skip to main content

sqlite_graphrag/commands/
related.rs

1//! Handler for the `related` CLI subcommand.
2
3use crate::cli::RelationKind;
4use crate::constants::{
5    DEFAULT_K_RECALL, DEFAULT_MAX_HOPS, DEFAULT_MIN_WEIGHT, TEXT_DESCRIPTION_PREVIEW_LEN,
6};
7use crate::errors::AppError;
8use crate::i18n::errors_msg;
9use crate::output::{self, OutputFormat};
10use crate::paths::AppPaths;
11use crate::storage::connection::open_ro;
12use rusqlite::{params, Connection};
13use serde::Serialize;
14use std::collections::{HashMap, HashSet, VecDeque};
15
/// Row produced by the adjacency fetch, as a positional tuple:
/// `(neighbour_entity_id, source_name, target_name, relation, weight)`.
///
/// `source_name` and `target_name` are the names of the edge's own source and target
/// entities; they keep that meaning whether the edge was followed forwards or backwards.
type Neighbour = (i64, String, String, String, f64);
19
// Command-line arguments accepted by the `related` subcommand.
//
// The notes added below deliberately use plain `//` comments: clap's derive macros turn `///`
// doc comments on an argument into its `--help` text, so only the pre-existing `///` lines are
// meant to be user-visible.
#[derive(clap::Args)]
pub struct RelatedArgs {
    /// Memory name as a positional argument. Alternative to `--name`.
    #[arg(value_name = "NAME", conflicts_with = "name")]
    pub name_positional: Option<String>,
    /// Memory name as a flag. Required when the positional form is absent.
    #[arg(long)]
    pub name: Option<String>,
    /// Maximum graph hop count. Also accepts the alias `--hops`.
    #[arg(long, alias = "hops", default_value_t = DEFAULT_MAX_HOPS)]
    pub max_hops: u32,
    // Optional relation filter: when set, only edges whose `relation` column equals this kind
    // are followed during traversal.
    #[arg(long, value_enum)]
    pub relation: Option<RelationKind>,
    // Minimum edge weight: edges with `weight` below this value are not followed.
    #[arg(long, default_value_t = DEFAULT_MIN_WEIGHT)]
    pub min_weight: f64,
    // Upper bound on the number of related memories returned.
    #[arg(long, default_value_t = DEFAULT_K_RECALL)]
    pub limit: usize,
    // Namespace override; the effective value is decided by `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // Output format (JSON unless specified otherwise).
    #[arg(long, value_enum, default_value = "json")]
    pub format: OutputFormat,
    // Hidden flag that `run` never reads; see its help string.
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    // Database path; may also be supplied through the SQLITE_GRAPHRAG_DB_PATH environment
    // variable. The final location is resolved by `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
46
/// JSON envelope emitted by the `related` command when the output format is JSON.
#[derive(Serialize)]
struct RelatedResponse {
    /// Related memories, in the order produced by the graph traversal.
    results: Vec<RelatedMemory>,
    /// Wall-clock time in milliseconds, measured from the start of `run`.
    elapsed_ms: u64,
}
52
/// One memory reached from the seed memory through the entity relationship graph.
#[derive(Serialize, Clone)]
struct RelatedMemory {
    /// Row id of the memory in the `memories` table.
    memory_id: i64,
    /// Memory name (`memories.name`).
    name: String,
    /// Namespace stored on the memory row (`memories.namespace`).
    namespace: String,
    /// Memory type (`memories.type`); serialized under the JSON key `type`.
    #[serde(rename = "type")]
    memory_type: String,
    /// Memory description (`memories.description`).
    description: String,
    /// Number of relationship hops between a seed entity and the entity this memory is
    /// linked to.
    hop_distance: u32,
    /// Name of the source entity of the edge that first reached the linked entity.
    /// Serialized as `null` when absent.
    source_entity: Option<String>,
    /// Name of the target entity of that same edge. Serialized as `null` when absent.
    target_entity: Option<String>,
    /// Relation label of that same edge. Serialized as `null` when absent.
    relation: Option<String>,
    /// Weight of that same edge. Serialized as `null` when absent.
    weight: Option<f64>,
}
67
68pub fn run(args: RelatedArgs) -> Result<(), AppError> {
69    let inicio = std::time::Instant::now();
70    let name = args
71        .name_positional
72        .as_deref()
73        .or(args.name.as_deref())
74        .ok_or_else(|| {
75            AppError::Validation(
76                "name required: pass as positional argument or via --name".to_string(),
77            )
78        })?
79        .to_string();
80
81    if name.trim().is_empty() {
82        return Err(AppError::Validation("name must not be empty".to_string()));
83    }
84
85    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
86    let paths = AppPaths::resolve(args.db.as_deref())?;
87
88    if !paths.db.exists() {
89        return Err(AppError::NotFound(errors_msg::database_not_found(
90            &paths.db.display().to_string(),
91        )));
92    }
93
94    let conn = open_ro(&paths.db)?;
95
96    // Locate the seed memory.
97    let seed_id: i64 = match conn.query_row(
98        "SELECT id FROM memories
99         WHERE namespace = ?1 AND name = ?2 AND deleted_at IS NULL",
100        params![namespace, name],
101        |r| r.get(0),
102    ) {
103        Ok(id) => id,
104        Err(rusqlite::Error::QueryReturnedNoRows) => {
105            return Err(AppError::NotFound(errors_msg::memory_not_found(
106                &name, &namespace,
107            )));
108        }
109        Err(e) => return Err(AppError::Database(e)),
110    };
111
112    // Collect seed entity IDs from seed memory.
113    let seed_entity_ids: Vec<i64> = {
114        let mut stmt =
115            conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
116        let rows: Vec<i64> = stmt
117            .query_map(params![seed_id], |r| r.get(0))?
118            .collect::<Result<Vec<i64>, _>>()?;
119        rows
120    };
121
122    let relation_filter = args.relation.map(|r| r.as_str().to_string());
123    let results = traverse_related(
124        &conn,
125        seed_id,
126        &seed_entity_ids,
127        &namespace,
128        args.max_hops,
129        args.min_weight,
130        relation_filter.as_deref(),
131        args.limit,
132    )?;
133
134    match args.format {
135        OutputFormat::Json => output::emit_json(&RelatedResponse {
136            results,
137            elapsed_ms: inicio.elapsed().as_millis() as u64,
138        })?,
139        OutputFormat::Text => {
140            for item in &results {
141                if item.description.is_empty() {
142                    output::emit_text(&format!(
143                        "{}. {} ({})",
144                        item.hop_distance, item.name, item.namespace
145                    ));
146                } else {
147                    let preview: String = item
148                        .description
149                        .chars()
150                        .take(TEXT_DESCRIPTION_PREVIEW_LEN)
151                        .collect();
152                    output::emit_text(&format!(
153                        "{}. {} ({}): {}",
154                        item.hop_distance, item.name, item.namespace, preview
155                    ));
156                }
157            }
158        }
159        OutputFormat::Markdown => {
160            for item in &results {
161                if item.description.is_empty() {
162                    output::emit_text(&format!(
163                        "- **{}** ({}) — hop {}",
164                        item.name, item.namespace, item.hop_distance
165                    ));
166                } else {
167                    let preview: String = item
168                        .description
169                        .chars()
170                        .take(TEXT_DESCRIPTION_PREVIEW_LEN)
171                        .collect();
172                    output::emit_text(&format!(
173                        "- **{}** ({}) — hop {}: {}",
174                        item.name, item.namespace, item.hop_distance, preview
175                    ));
176                }
177            }
178        }
179    }
180
181    Ok(())
182}
183
184#[allow(clippy::too_many_arguments)]
185fn traverse_related(
186    conn: &Connection,
187    seed_memory_id: i64,
188    seed_entity_ids: &[i64],
189    namespace: &str,
190    max_hops: u32,
191    min_weight: f64,
192    relation_filter: Option<&str>,
193    limit: usize,
194) -> Result<Vec<RelatedMemory>, AppError> {
195    if seed_entity_ids.is_empty() || max_hops == 0 {
196        return Ok(Vec::new());
197    }
198
199    // BFS over entities keeping track of hop distance and the (source, target, relation, weight)
200    // of the edge that first reached each entity.
201    let mut visited: HashSet<i64> = seed_entity_ids.iter().copied().collect();
202    let mut entity_hop: HashMap<i64, u32> = HashMap::new();
203    for &e in seed_entity_ids {
204        entity_hop.insert(e, 0);
205    }
206    // Per-entity edge info: source_name, target_name, relation, weight (captures the FIRST edge
207    // that reached this entity — equivalent to BFS shortest path recall edge).
208    let mut entity_edge: HashMap<i64, (String, String, String, f64)> = HashMap::new();
209
210    let mut queue: VecDeque<i64> = seed_entity_ids.iter().copied().collect();
211
212    while let Some(current_entity) = queue.pop_front() {
213        let current_hop = *entity_hop.get(&current_entity).unwrap_or(&0);
214        if current_hop >= max_hops {
215            continue;
216        }
217
218        let neighbours =
219            fetch_neighbours(conn, current_entity, namespace, min_weight, relation_filter)?;
220
221        for (neighbour_id, source_name, target_name, relation, weight) in neighbours {
222            if visited.insert(neighbour_id) {
223                entity_hop.insert(neighbour_id, current_hop + 1);
224                entity_edge.insert(neighbour_id, (source_name, target_name, relation, weight));
225                queue.push_back(neighbour_id);
226            }
227        }
228    }
229
230    // For each discovered entity (hop >= 1) find its memories, skipping the seed memory.
231    let mut out: Vec<RelatedMemory> = Vec::new();
232    let mut dedup_ids: HashSet<i64> = HashSet::new();
233    dedup_ids.insert(seed_memory_id);
234
235    // Sort entities by hop ASC, weight DESC so we emit closer entities first.
236    let mut ordered_entities: Vec<(i64, u32)> = entity_hop
237        .iter()
238        .filter(|(id, _)| !seed_entity_ids.contains(id))
239        .map(|(id, hop)| (*id, *hop))
240        .collect();
241    ordered_entities.sort_by(|a, b| {
242        let weight_a = entity_edge.get(&a.0).map(|e| e.3).unwrap_or(0.0);
243        let weight_b = entity_edge.get(&b.0).map(|e| e.3).unwrap_or(0.0);
244        a.1.cmp(&b.1).then_with(|| {
245            weight_b
246                .partial_cmp(&weight_a)
247                .unwrap_or(std::cmp::Ordering::Equal)
248        })
249    });
250
251    for (entity_id, hop) in ordered_entities {
252        let mut stmt = conn.prepare_cached(
253            "SELECT m.id, m.name, m.namespace, m.type, m.description
254             FROM memory_entities me
255             JOIN memories m ON m.id = me.memory_id
256             WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
257        )?;
258        let rows = stmt
259            .query_map(params![entity_id], |r| {
260                Ok((
261                    r.get::<_, i64>(0)?,
262                    r.get::<_, String>(1)?,
263                    r.get::<_, String>(2)?,
264                    r.get::<_, String>(3)?,
265                    r.get::<_, String>(4)?,
266                ))
267            })?
268            .collect::<Result<Vec<_>, _>>()?;
269
270        for (mid, name, ns, mtype, desc) in rows {
271            if !dedup_ids.insert(mid) {
272                continue;
273            }
274            let edge = entity_edge.get(&entity_id);
275            out.push(RelatedMemory {
276                memory_id: mid,
277                name,
278                namespace: ns,
279                memory_type: mtype,
280                description: desc,
281                hop_distance: hop,
282                source_entity: edge.map(|e| e.0.clone()),
283                target_entity: edge.map(|e| e.1.clone()),
284                relation: edge.map(|e| e.2.clone()),
285                weight: edge.map(|e| e.3),
286            });
287            if out.len() >= limit {
288                return Ok(out);
289            }
290        }
291    }
292
293    Ok(out)
294}
295
296fn fetch_neighbours(
297    conn: &Connection,
298    entity_id: i64,
299    namespace: &str,
300    min_weight: f64,
301    relation_filter: Option<&str>,
302) -> Result<Vec<Neighbour>, AppError> {
303    // Follow edges in both directions: source -> target and target -> source so traversal is
304    // undirected, which is how users typically reason about "related" memories.
305    let base_sql = "\
306        SELECT r.target_id, se.name, te.name, r.relation, r.weight
307        FROM relationships r
308        JOIN entities se ON se.id = r.source_id
309        JOIN entities te ON te.id = r.target_id
310        WHERE r.source_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
311
312    let reverse_sql = "\
313        SELECT r.source_id, se.name, te.name, r.relation, r.weight
314        FROM relationships r
315        JOIN entities se ON se.id = r.source_id
316        JOIN entities te ON te.id = r.target_id
317        WHERE r.target_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
318
319    let mut results: Vec<Neighbour> = Vec::new();
320
321    let forward_sql = match relation_filter {
322        Some(_) => format!("{base_sql} AND r.relation = ?4"),
323        None => base_sql.to_string(),
324    };
325    let rev_sql = match relation_filter {
326        Some(_) => format!("{reverse_sql} AND r.relation = ?4"),
327        None => reverse_sql.to_string(),
328    };
329
330    let mut stmt = conn.prepare_cached(&forward_sql)?;
331    let rows: Vec<_> = if let Some(rel) = relation_filter {
332        stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
333            Ok((
334                r.get::<_, i64>(0)?,
335                r.get::<_, String>(1)?,
336                r.get::<_, String>(2)?,
337                r.get::<_, String>(3)?,
338                r.get::<_, f64>(4)?,
339            ))
340        })?
341        .collect::<Result<Vec<_>, _>>()?
342    } else {
343        stmt.query_map(params![entity_id, min_weight, namespace], |r| {
344            Ok((
345                r.get::<_, i64>(0)?,
346                r.get::<_, String>(1)?,
347                r.get::<_, String>(2)?,
348                r.get::<_, String>(3)?,
349                r.get::<_, f64>(4)?,
350            ))
351        })?
352        .collect::<Result<Vec<_>, _>>()?
353    };
354    results.extend(rows);
355
356    let mut stmt = conn.prepare_cached(&rev_sql)?;
357    let rows: Vec<_> = if let Some(rel) = relation_filter {
358        stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
359            Ok((
360                r.get::<_, i64>(0)?,
361                r.get::<_, String>(1)?,
362                r.get::<_, String>(2)?,
363                r.get::<_, String>(3)?,
364                r.get::<_, f64>(4)?,
365            ))
366        })?
367        .collect::<Result<Vec<_>, _>>()?
368    } else {
369        stmt.query_map(params![entity_id, min_weight, namespace], |r| {
370            Ok((
371                r.get::<_, i64>(0)?,
372                r.get::<_, String>(1)?,
373                r.get::<_, String>(2)?,
374                r.get::<_, String>(3)?,
375                r.get::<_, f64>(4)?,
376            ))
377        })?
378        .collect::<Result<Vec<_>, _>>()?
379    };
380    results.extend(rows);
381
382    Ok(results)
383}
384
// Unit tests for the response serialization and for `traverse_related`, run against an
// in-memory SQLite database with a minimal schema.
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens an in-memory database and creates the `memories`, `entities`, `relationships`
    /// and `memory_entities` tables used by the traversal queries.
    fn setup_related_db() -> rusqlite::Connection {
        let conn = rusqlite::Connection::open_in_memory().expect("falha ao abrir banco em memória");
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                namespace TEXT NOT NULL DEFAULT 'global',
                type TEXT NOT NULL DEFAULT 'fact',
                description TEXT NOT NULL DEFAULT '',
                deleted_at INTEGER
            );
            CREATE TABLE entities (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                name TEXT NOT NULL
            );
            CREATE TABLE relationships (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related_to',
                weight REAL NOT NULL DEFAULT 1.0
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );",
        )
        .expect("falha ao criar tabelas de teste");
        conn
    }

    /// Inserts a memory row (other columns take their defaults) and returns its row id.
    fn insert_memory(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO memories (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("falha ao inserir memória");
        conn.last_insert_rowid()
    }

    /// Inserts an entity row and returns its row id.
    fn insert_entity(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO entities (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("falha ao inserir entidade");
        conn.last_insert_rowid()
    }

    /// Links a memory to an entity through `memory_entities`.
    fn link_memory_entity(conn: &rusqlite::Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            rusqlite::params![memory_id, entity_id],
        )
        .expect("falha ao vincular memória-entidade");
    }

    /// Inserts a directed edge `source_id -> target_id` with the given relation and weight.
    fn insert_relationship(
        conn: &rusqlite::Connection,
        namespace: &str,
        source_id: i64,
        target_id: i64,
        relation: &str,
        weight: f64,
    ) {
        conn.execute(
            "INSERT INTO relationships (namespace, source_id, target_id, relation, weight)
             VALUES (?1, ?2, ?3, ?4, ?5)",
            rusqlite::params![namespace, source_id, target_id, relation, weight],
        )
        .expect("falha ao inserir relacionamento");
    }

    // `RelatedResponse` serializes `results` and `elapsed_ms`, and `memory_type` is emitted
    // under the JSON key `type`.
    #[test]
    fn related_response_serializa_results_e_elapsed_ms() {
        let resp = RelatedResponse {
            results: vec![RelatedMemory {
                memory_id: 1,
                name: "mem-vizinha".to_string(),
                namespace: "global".to_string(),
                memory_type: "fact".to_string(),
                description: "desc".to_string(),
                hop_distance: 1,
                source_entity: Some("entidade-a".to_string()),
                target_entity: Some("entidade-b".to_string()),
                relation: Some("related_to".to_string()),
                weight: Some(0.9),
            }],
            elapsed_ms: 7,
        };

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        assert!(json["results"].is_array());
        assert_eq!(json["results"].as_array().unwrap().len(), 1);
        assert_eq!(json["elapsed_ms"], 7u64);
        assert_eq!(json["results"][0]["type"], "fact");
        assert_eq!(json["results"][0]["hop_distance"], 1);
    }

    // With no seed entities the traversal returns an empty result.
    #[test]
    fn traverse_related_retorna_vazio_sem_entidades_seed() {
        let conn = setup_related_db();
        let resultado = traverse_related(&conn, 1, &[], "global", 2, 0.0, None, 10)
            .expect("traverse_related falhou");
        assert!(
            resultado.is_empty(),
            "sem entidades seed deve retornar vazio"
        );
    }

    // `max_hops = 0` returns an empty result even when the seed memory has an entity.
    #[test]
    fn traverse_related_retorna_vazio_com_max_hops_zero() {
        let conn = setup_related_db();
        let mem_id = insert_memory(&conn, "seed-mem", "global");
        let ent_id = insert_entity(&conn, "ent-a", "global");
        link_memory_entity(&conn, mem_id, ent_id);

        let resultado = traverse_related(&conn, mem_id, &[ent_id], "global", 0, 0.0, None, 10)
            .expect("traverse_related falhou");
        assert!(resultado.is_empty(), "max_hops=0 deve retornar vazio");
    }

    // Graph: seed-mem — ent-a —related_to→ ent-b — vizinha-mem.
    // The neighbouring memory must be found at hop distance 1.
    #[test]
    fn traverse_related_descobre_memoria_vizinha_por_grafo() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let vizinha_id = insert_memory(&conn, "vizinha-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, vizinha_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);

        let resultado = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related falhou");

        assert_eq!(resultado.len(), 1, "deve encontrar 1 memória vizinha");
        assert_eq!(resultado[0].name, "vizinha-mem");
        assert_eq!(resultado[0].hop_distance, 1);
    }

    // Five neighbours sit one hop away from the seed; with `limit = 3` at most three memories
    // may be returned.
    #[test]
    fn traverse_related_respeita_limite() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);

        for i in 0..5 {
            let mem_id = insert_memory(&conn, &format!("vizinha-{i}"), "global");
            let ent_id = insert_entity(&conn, &format!("ent-{i}"), "global");
            link_memory_entity(&conn, mem_id, ent_id);
            insert_relationship(&conn, "global", ent_seed, ent_id, "related_to", 1.0);
        }

        let resultado = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.0, None, 3)
            .expect("traverse_related falhou");

        assert!(
            resultado.len() <= 3,
            "limite=3 deve restringir a no máximo 3 resultados"
        );
    }

    // Optional edge fields that are `None` are serialized as JSON `null` (not omitted).
    #[test]
    fn related_memory_campos_opcionais_nulos_serializados() {
        let mem = RelatedMemory {
            memory_id: 99,
            name: "sem-relacao".to_string(),
            namespace: "ns".to_string(),
            memory_type: "concept".to_string(),
            description: "".to_string(),
            hop_distance: 2,
            source_entity: None,
            target_entity: None,
            relation: None,
            weight: None,
        };

        let json = serde_json::to_value(&mem).expect("serialização falhou");
        assert!(json["source_entity"].is_null());
        assert!(json["target_entity"].is_null());
        assert!(json["relation"].is_null());
        assert!(json["weight"].is_null());
        assert_eq!(json["hop_distance"], 2);
    }
}