Skip to main content

sqlite_graphrag/commands/
related.rs

1//! Handler for the `related` CLI subcommand.
2
3use crate::cli::RelationKind;
4use crate::constants::{
5    DEFAULT_K_RECALL, DEFAULT_MAX_HOPS, DEFAULT_MIN_WEIGHT, TEXT_DESCRIPTION_PREVIEW_LEN,
6};
7use crate::errors::AppError;
8use crate::i18n::errors_msg;
9use crate::output::{self, OutputFormat};
10use crate::paths::AppPaths;
11use crate::storage::connection::open_ro;
12use rusqlite::{params, Connection};
13use serde::Serialize;
14use std::collections::{HashMap, HashSet, VecDeque};
15
/// Row shape produced by the adjacency queries in `fetch_neighbours`.
///
/// Fields, in order: `(neighbour_entity_id, source_name, target_name, relation, weight)`.
/// `source_name` and `target_name` are always the names of the edge's stored source and
/// target entities, regardless of the direction in which the edge was followed.
type Neighbour = (i64, String, String, String, f64);
19
20#[derive(clap::Args)]
21#[command(after_long_help = "EXAMPLES:\n  \
22    # List memories connected to a memory via the entity graph (default 2 hops)\n  \
23    sqlite-graphrag related onboarding\n\n  \
24    # Increase hop distance and filter by relation type\n  \
25    sqlite-graphrag related onboarding --max-hops 3 --relation related\n\n  \
26    # Cap result count and require minimum edge weight\n  \
27    sqlite-graphrag related onboarding --limit 5 --min-weight 0.5")]
28pub struct RelatedArgs {
29    /// Memory name as a positional argument. Alternative to `--name`.
30    #[arg(
31        value_name = "NAME",
32        conflicts_with = "name",
33        help = "Memory name whose neighbours to traverse; alternative to --name"
34    )]
35    pub name_positional: Option<String>,
36    /// Memory name as a flag. Required when the positional form is absent.
37    #[arg(long)]
38    pub name: Option<String>,
39    /// Maximum graph hop count. Also accepts the alias `--hops`.
40    #[arg(long, alias = "hops", default_value_t = DEFAULT_MAX_HOPS)]
41    pub max_hops: u32,
42    #[arg(long, value_enum)]
43    pub relation: Option<RelationKind>,
44    #[arg(long, default_value_t = DEFAULT_MIN_WEIGHT)]
45    pub min_weight: f64,
46    #[arg(long, default_value_t = DEFAULT_K_RECALL)]
47    pub limit: usize,
48    #[arg(long)]
49    pub namespace: Option<String>,
50    #[arg(long, value_enum, default_value = "json")]
51    pub format: OutputFormat,
52    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
53    pub json: bool,
54    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
55    pub db: Option<String>,
56}
57
/// JSON envelope emitted by the `related` subcommand when `--format json` is selected.
///
/// Field order here is the order in which the keys are serialized.
#[derive(Serialize)]
struct RelatedResponse {
    /// Echo of the seed memory name resolved from `--name` or the positional argument.
    /// Added in v1.0.35 for input transparency in JSON output.
    name: String,
    /// Echo of the resolved `--max-hops` value (default 2). Added in v1.0.35.
    max_hops: u32,
    /// Related memories in the order produced by `traverse_related`.
    results: Vec<RelatedMemory>,
    /// Time elapsed since the handler started, in milliseconds, sampled when the
    /// response is built.
    elapsed_ms: u64,
}
68
/// One memory reached from the seed memory by walking the entity graph.
///
/// The four optional edge fields describe the edge through which the memory's entity was
/// first reached during traversal; they are `None` when no such edge was recorded.
#[derive(Serialize, Clone)]
struct RelatedMemory {
    /// Row id of the memory (`memories.id`).
    memory_id: i64,
    /// Memory name (`memories.name`).
    name: String,
    /// Namespace the memory is stored under (`memories.namespace`).
    namespace: String,
    /// Memory type (`memories.type`); serialized under the JSON key `type`.
    #[serde(rename = "type")]
    memory_type: String,
    /// Memory description (`memories.description`); may be empty.
    description: String,
    /// Hop count of the entity through which this memory was found.
    hop_distance: u32,
    /// Name of the source entity of the recorded edge.
    source_entity: Option<String>,
    /// Name of the target entity of the recorded edge.
    target_entity: Option<String>,
    /// Relation label of the recorded edge.
    relation: Option<String>,
    /// Weight of the recorded edge.
    weight: Option<f64>,
}
83
84pub fn run(args: RelatedArgs) -> Result<(), AppError> {
85    let inicio = std::time::Instant::now();
86    let name = args
87        .name_positional
88        .as_deref()
89        .or(args.name.as_deref())
90        .ok_or_else(|| {
91            AppError::Validation(
92                "name required: pass as positional argument or via --name".to_string(),
93            )
94        })?
95        .to_string();
96
97    if name.trim().is_empty() {
98        return Err(AppError::Validation("name must not be empty".to_string()));
99    }
100
101    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
102    let paths = AppPaths::resolve(args.db.as_deref())?;
103
104    crate::storage::connection::ensure_db_ready(&paths)?;
105
106    let conn = open_ro(&paths.db)?;
107
108    // Locate the seed memory.
109    let seed_id: i64 = match conn.query_row(
110        "SELECT id FROM memories
111         WHERE namespace = ?1 AND name = ?2 AND deleted_at IS NULL",
112        params![namespace, name],
113        |r| r.get(0),
114    ) {
115        Ok(id) => id,
116        Err(rusqlite::Error::QueryReturnedNoRows) => {
117            return Err(AppError::NotFound(errors_msg::memory_not_found(
118                &name, &namespace,
119            )));
120        }
121        Err(e) => return Err(AppError::Database(e)),
122    };
123
124    // Collect seed entity IDs from seed memory.
125    let seed_entity_ids: Vec<i64> = {
126        let mut stmt =
127            conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
128        let rows: Vec<i64> = stmt
129            .query_map(params![seed_id], |r| r.get(0))?
130            .collect::<Result<Vec<i64>, _>>()?;
131        rows
132    };
133
134    let relation_filter = args.relation.map(|r| r.as_str().to_string());
135    let results = traverse_related(
136        &conn,
137        seed_id,
138        &seed_entity_ids,
139        &namespace,
140        args.max_hops,
141        args.min_weight,
142        relation_filter.as_deref(),
143        args.limit,
144    )?;
145
146    match args.format {
147        OutputFormat::Json => output::emit_json(&RelatedResponse {
148            name: name.clone(),
149            max_hops: args.max_hops,
150            results,
151            elapsed_ms: inicio.elapsed().as_millis() as u64,
152        })?,
153        OutputFormat::Text => {
154            for item in &results {
155                if item.description.is_empty() {
156                    output::emit_text(&format!(
157                        "{}. {} ({})",
158                        item.hop_distance, item.name, item.namespace
159                    ));
160                } else {
161                    let preview: String = item
162                        .description
163                        .chars()
164                        .take(TEXT_DESCRIPTION_PREVIEW_LEN)
165                        .collect();
166                    output::emit_text(&format!(
167                        "{}. {} ({}): {}",
168                        item.hop_distance, item.name, item.namespace, preview
169                    ));
170                }
171            }
172        }
173        OutputFormat::Markdown => {
174            for item in &results {
175                if item.description.is_empty() {
176                    output::emit_text(&format!(
177                        "- **{}** ({}) — hop {}",
178                        item.name, item.namespace, item.hop_distance
179                    ));
180                } else {
181                    let preview: String = item
182                        .description
183                        .chars()
184                        .take(TEXT_DESCRIPTION_PREVIEW_LEN)
185                        .collect();
186                    output::emit_text(&format!(
187                        "- **{}** ({}) — hop {}: {}",
188                        item.name, item.namespace, item.hop_distance, preview
189                    ));
190                }
191            }
192        }
193    }
194
195    Ok(())
196}
197
198#[allow(clippy::too_many_arguments)]
199fn traverse_related(
200    conn: &Connection,
201    seed_memory_id: i64,
202    seed_entity_ids: &[i64],
203    namespace: &str,
204    max_hops: u32,
205    min_weight: f64,
206    relation_filter: Option<&str>,
207    limit: usize,
208) -> Result<Vec<RelatedMemory>, AppError> {
209    if seed_entity_ids.is_empty() || max_hops == 0 {
210        return Ok(Vec::new());
211    }
212
213    // BFS over entities keeping track of hop distance and the (source, target, relation, weight)
214    // of the edge that first reached each entity.
215    let mut visited: HashSet<i64> = seed_entity_ids.iter().copied().collect();
216    let mut entity_hop: HashMap<i64, u32> = HashMap::new();
217    for &e in seed_entity_ids {
218        entity_hop.insert(e, 0);
219    }
220    // Per-entity edge info: source_name, target_name, relation, weight (captures the FIRST edge
221    // that reached this entity — equivalent to BFS shortest path recall edge).
222    let mut entity_edge: HashMap<i64, (String, String, String, f64)> = HashMap::new();
223
224    let mut queue: VecDeque<i64> = seed_entity_ids.iter().copied().collect();
225
226    while let Some(current_entity) = queue.pop_front() {
227        let current_hop = *entity_hop.get(&current_entity).unwrap_or(&0);
228        if current_hop >= max_hops {
229            continue;
230        }
231
232        let neighbours =
233            fetch_neighbours(conn, current_entity, namespace, min_weight, relation_filter)?;
234
235        for (neighbour_id, source_name, target_name, relation, weight) in neighbours {
236            if visited.insert(neighbour_id) {
237                entity_hop.insert(neighbour_id, current_hop + 1);
238                entity_edge.insert(neighbour_id, (source_name, target_name, relation, weight));
239                queue.push_back(neighbour_id);
240            }
241        }
242    }
243
244    // For each discovered entity (hop >= 1) find its memories, skipping the seed memory.
245    let mut out: Vec<RelatedMemory> = Vec::new();
246    let mut dedup_ids: HashSet<i64> = HashSet::new();
247    dedup_ids.insert(seed_memory_id);
248
249    // Sort entities by hop ASC, weight DESC so we emit closer entities first.
250    let mut ordered_entities: Vec<(i64, u32)> = entity_hop
251        .iter()
252        .filter(|(id, _)| !seed_entity_ids.contains(id))
253        .map(|(id, hop)| (*id, *hop))
254        .collect();
255    ordered_entities.sort_by(|a, b| {
256        let weight_a = entity_edge.get(&a.0).map(|e| e.3).unwrap_or(0.0);
257        let weight_b = entity_edge.get(&b.0).map(|e| e.3).unwrap_or(0.0);
258        a.1.cmp(&b.1).then_with(|| {
259            weight_b
260                .partial_cmp(&weight_a)
261                .unwrap_or(std::cmp::Ordering::Equal)
262        })
263    });
264
265    for (entity_id, hop) in ordered_entities {
266        let mut stmt = conn.prepare_cached(
267            "SELECT m.id, m.name, m.namespace, m.type, m.description
268             FROM memory_entities me
269             JOIN memories m ON m.id = me.memory_id
270             WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
271        )?;
272        let rows = stmt
273            .query_map(params![entity_id], |r| {
274                Ok((
275                    r.get::<_, i64>(0)?,
276                    r.get::<_, String>(1)?,
277                    r.get::<_, String>(2)?,
278                    r.get::<_, String>(3)?,
279                    r.get::<_, String>(4)?,
280                ))
281            })?
282            .collect::<Result<Vec<_>, _>>()?;
283
284        for (mid, name, ns, mtype, desc) in rows {
285            if !dedup_ids.insert(mid) {
286                continue;
287            }
288            let edge = entity_edge.get(&entity_id);
289            out.push(RelatedMemory {
290                memory_id: mid,
291                name,
292                namespace: ns,
293                memory_type: mtype,
294                description: desc,
295                hop_distance: hop,
296                source_entity: edge.map(|e| e.0.clone()),
297                target_entity: edge.map(|e| e.1.clone()),
298                relation: edge.map(|e| e.2.clone()),
299                weight: edge.map(|e| e.3),
300            });
301            if out.len() >= limit {
302                return Ok(out);
303            }
304        }
305    }
306
307    Ok(out)
308}
309
310fn fetch_neighbours(
311    conn: &Connection,
312    entity_id: i64,
313    namespace: &str,
314    min_weight: f64,
315    relation_filter: Option<&str>,
316) -> Result<Vec<Neighbour>, AppError> {
317    // Follow edges in both directions: source -> target and target -> source so traversal is
318    // undirected, which is how users typically reason about "related" memories.
319    let base_sql = "\
320        SELECT r.target_id, se.name, te.name, r.relation, r.weight
321        FROM relationships r
322        JOIN entities se ON se.id = r.source_id
323        JOIN entities te ON te.id = r.target_id
324        WHERE r.source_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
325
326    let reverse_sql = "\
327        SELECT r.source_id, se.name, te.name, r.relation, r.weight
328        FROM relationships r
329        JOIN entities se ON se.id = r.source_id
330        JOIN entities te ON te.id = r.target_id
331        WHERE r.target_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
332
333    let mut results: Vec<Neighbour> = Vec::new();
334
335    let forward_sql = match relation_filter {
336        Some(_) => format!("{base_sql} AND r.relation = ?4"),
337        None => base_sql.to_string(),
338    };
339    let rev_sql = match relation_filter {
340        Some(_) => format!("{reverse_sql} AND r.relation = ?4"),
341        None => reverse_sql.to_string(),
342    };
343
344    let mut stmt = conn.prepare_cached(&forward_sql)?;
345    let rows: Vec<_> = if let Some(rel) = relation_filter {
346        stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
347            Ok((
348                r.get::<_, i64>(0)?,
349                r.get::<_, String>(1)?,
350                r.get::<_, String>(2)?,
351                r.get::<_, String>(3)?,
352                r.get::<_, f64>(4)?,
353            ))
354        })?
355        .collect::<Result<Vec<_>, _>>()?
356    } else {
357        stmt.query_map(params![entity_id, min_weight, namespace], |r| {
358            Ok((
359                r.get::<_, i64>(0)?,
360                r.get::<_, String>(1)?,
361                r.get::<_, String>(2)?,
362                r.get::<_, String>(3)?,
363                r.get::<_, f64>(4)?,
364            ))
365        })?
366        .collect::<Result<Vec<_>, _>>()?
367    };
368    results.extend(rows);
369
370    let mut stmt = conn.prepare_cached(&rev_sql)?;
371    let rows: Vec<_> = if let Some(rel) = relation_filter {
372        stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
373            Ok((
374                r.get::<_, i64>(0)?,
375                r.get::<_, String>(1)?,
376                r.get::<_, String>(2)?,
377                r.get::<_, String>(3)?,
378                r.get::<_, f64>(4)?,
379            ))
380        })?
381        .collect::<Result<Vec<_>, _>>()?
382    } else {
383        stmt.query_map(params![entity_id, min_weight, namespace], |r| {
384            Ok((
385                r.get::<_, i64>(0)?,
386                r.get::<_, String>(1)?,
387                r.get::<_, String>(2)?,
388                r.get::<_, String>(3)?,
389                r.get::<_, f64>(4)?,
390            ))
391        })?
392        .collect::<Result<Vec<_>, _>>()?
393    };
394    results.extend(rows);
395
396    Ok(results)
397}
398
/// Unit tests for response serialization and for the graph traversal, run against an
/// in-memory SQLite database that contains only the tables and columns the traversal reads.
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an in-memory database with a minimal `memories` / `entities` /
    /// `relationships` / `memory_entities` schema.
    fn setup_related_db() -> rusqlite::Connection {
        let conn = rusqlite::Connection::open_in_memory().expect("failed to open in-memory db");
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                namespace TEXT NOT NULL DEFAULT 'global',
                type TEXT NOT NULL DEFAULT 'fact',
                description TEXT NOT NULL DEFAULT '',
                deleted_at INTEGER
            );
            CREATE TABLE entities (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                name TEXT NOT NULL
            );
            CREATE TABLE relationships (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related_to',
                weight REAL NOT NULL DEFAULT 1.0
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );",
        )
        .expect("failed to create test tables");
        conn
    }

    /// Inserts a memory with default type/description and returns its row id.
    fn insert_memory(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO memories (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert memory");
        conn.last_insert_rowid()
    }

    /// Inserts an entity and returns its row id.
    fn insert_entity(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO entities (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert entity");
        conn.last_insert_rowid()
    }

    /// Attaches an entity to a memory.
    fn link_memory_entity(conn: &rusqlite::Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            rusqlite::params![memory_id, entity_id],
        )
        .expect("failed to link memory-entity");
    }

    /// Inserts a directed edge `source_id -> target_id` with the given relation and weight.
    fn insert_relationship(
        conn: &rusqlite::Connection,
        namespace: &str,
        source_id: i64,
        target_id: i64,
        relation: &str,
        weight: f64,
    ) {
        conn.execute(
            "INSERT INTO relationships (namespace, source_id, target_id, relation, weight)
             VALUES (?1, ?2, ?3, ?4, ?5)",
            rusqlite::params![namespace, source_id, target_id, relation, weight],
        )
        .expect("failed to insert relationship");
    }

    /// The JSON envelope exposes `results` and `elapsed_ms`, and `memory_type` is
    /// serialized under the key `type`.
    #[test]
    fn related_response_serializes_results_and_elapsed_ms() {
        let resp = RelatedResponse {
            name: "seed-mem".to_string(),
            max_hops: 2,
            results: vec![RelatedMemory {
                memory_id: 1,
                name: "neighbor-mem".to_string(),
                namespace: "global".to_string(),
                memory_type: "document".to_string(),
                description: "desc".to_string(),
                hop_distance: 1,
                source_entity: Some("entity-a".to_string()),
                target_entity: Some("entity-b".to_string()),
                relation: Some("related_to".to_string()),
                weight: Some(0.9),
            }],
            elapsed_ms: 7,
        };

        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert!(json["results"].is_array());
        assert_eq!(json["results"].as_array().unwrap().len(), 1);
        assert_eq!(json["elapsed_ms"], 7u64);
        assert_eq!(json["results"][0]["type"], "document");
        assert_eq!(json["results"][0]["hop_distance"], 1);
    }

    /// With no seed entities there is nothing to traverse from.
    #[test]
    fn traverse_related_returns_empty_without_seed_entities() {
        let conn = setup_related_db();
        let result = traverse_related(&conn, 1, &[], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(
            result.is_empty(),
            "no seed entities must yield empty results"
        );
    }

    /// A hop budget of zero never leaves the seed entities.
    #[test]
    fn traverse_related_returns_empty_with_max_hops_zero() {
        let conn = setup_related_db();
        let mem_id = insert_memory(&conn, "seed-mem", "global");
        let ent_id = insert_entity(&conn, "ent-a", "global");
        link_memory_entity(&conn, mem_id, ent_id);

        let result = traverse_related(&conn, mem_id, &[ent_id], "global", 0, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(result.is_empty(), "max_hops=0 must return empty");
    }

    /// seed-mem -- ent-a --related_to--> ent-b -- neighbor-mem is found at hop 1.
    #[test]
    fn traverse_related_discovers_neighbor_memory_via_graph() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");

        assert_eq!(result.len(), 1, "must find 1 neighboring memory");
        assert_eq!(result[0].name, "neighbor-mem");
        assert_eq!(result[0].hop_distance, 1);
    }

    /// Five distinct one-hop neighbours with `limit = 3` must yield exactly three results.
    #[test]
    fn traverse_related_respects_limit() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);

        for i in 0..5 {
            let mem_id = insert_memory(&conn, &format!("neighbor-{i}"), "global");
            let ent_id = insert_entity(&conn, &format!("ent-{i}"), "global");
            link_memory_entity(&conn, mem_id, ent_id);
            insert_relationship(&conn, "global", ent_seed, ent_id, "related_to", 1.0);
        }

        let result = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.0, None, 3)
            .expect("traverse_related failed");

        // An upper-bound check alone would also pass on an empty result; with five
        // candidates available the cap must be reached exactly.
        assert_eq!(
            result.len(),
            3,
            "limit=3 must return exactly 3 of the 5 candidates"
        );
    }

    /// `None` edge fields are serialized as explicit JSON nulls rather than omitted.
    #[test]
    fn related_memory_optional_null_fields_serialized() {
        let mem = RelatedMemory {
            memory_id: 99,
            name: "no-relation".to_string(),
            namespace: "ns".to_string(),
            memory_type: "concept".to_string(),
            description: "".to_string(),
            hop_distance: 2,
            source_entity: None,
            target_entity: None,
            relation: None,
            weight: None,
        };

        let json = serde_json::to_value(&mem).expect("serialization failed");
        assert!(json["source_entity"].is_null());
        assert!(json["target_entity"].is_null());
        assert!(json["relation"].is_null());
        assert!(json["weight"].is_null());
        assert_eq!(json["hop_distance"], 2);
    }
}