Skip to main content

sqlite_graphrag/commands/
related.rs

1//! Handler for the `related` CLI subcommand.
2
3use crate::constants::{
4    DEFAULT_K_RECALL, DEFAULT_MAX_HOPS, DEFAULT_MIN_WEIGHT, TEXT_DESCRIPTION_PREVIEW_LEN,
5};
6use crate::errors::AppError;
7use crate::i18n::errors_msg;
8use crate::output::{self, OutputFormat};
9use crate::paths::AppPaths;
10use crate::storage::connection::open_ro;
11use rusqlite::{params, Connection};
12use serde::Serialize;
13use std::collections::{HashSet, VecDeque};
14
/// Identifies whether the seed resolved to a memory or a bare entity.
///
/// `run` tries the `memories` table first and only falls back to an entity
/// lookup when no memory row matches the requested name.
enum SeedKind {
    /// Row id from the `memories` table; traversal starts from every entity
    /// linked to this memory through `memory_entities`.
    Memory(i64),
    /// Entity id returned by the entity lookup; traversal starts from this
    /// single entity and there is no seed memory to exclude from the results.
    Entity(i64),
}
20
/// Tuple returned by the adjacency fetch: (neighbour_entity_id, source_name,
/// target_name, relation, weight).
///
/// `source_name` / `target_name` always follow the direction stored in the
/// `relationships` row. Only the neighbour id depends on which end of the edge
/// the queried entity sits on (`target_id` for outgoing edges, `source_id` for
/// incoming ones).
type Neighbour = (i64, String, String, String, f64);
24
// Command-line arguments of the `related` subcommand.
//
// NOTE: with clap's derive API, `///` doc comments on fields are turned into
// `--help` text. Maintainer-only notes in this struct therefore use plain `//`
// comments so the rendered help output is not affected.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n  \
    # List memories connected to a memory via the entity graph (default 2 hops)\n  \
    sqlite-graphrag related onboarding\n\n  \
    # Increase hop distance and filter by relation type\n  \
    sqlite-graphrag related onboarding --max-hops 3 --relation related\n\n  \
    # Cap result count and require minimum edge weight\n  \
    sqlite-graphrag related onboarding --limit 5 --min-weight 0.5")]
pub struct RelatedArgs {
    /// Memory name as a positional argument. Alternative to `--name`.
    // `run` prefers this value over `--name`; clap rejects passing both
    // (`conflicts_with`). The explicit `help` below overrides the doc comment.
    #[arg(
        value_name = "NAME",
        conflicts_with = "name",
        help = "Memory name whose neighbours to traverse; alternative to --name"
    )]
    pub name_positional: Option<String>,
    /// Memory name as a flag. Required when the positional form is absent. Also accepts the alias `--from`.
    // The "required" part is enforced in `run`, not by clap: both forms are `Option`.
    #[arg(long, alias = "from")]
    pub name: Option<String>,
    /// Maximum graph hop count. Also accepts the alias `--hops`.
    // A value of 0 makes the traversal return no results at all.
    #[arg(long, alias = "hops", default_value_t = DEFAULT_MAX_HOPS)]
    pub max_hops: u32,
    /// Filter results to a specific relation type. Canonical values:
    /// applies-to, uses, depends-on, causes, fixes, contradicts, supports,
    /// follows, related, mentions, replaces, tracked-in.
    /// Any kebab-case or snake_case string is also accepted as a custom relation.
    // Parsed by `crate::parsers::parse_relation`; `run` additionally calls
    // `warn_if_non_canonical` on the parsed value.
    #[arg(long, value_parser = crate::parsers::parse_relation)]
    pub relation: Option<String>,
    // Minimum edge weight: only relationships with `weight >= min_weight` are followed.
    #[arg(long, default_value_t = DEFAULT_MIN_WEIGHT)]
    pub min_weight: f64,
    // Maximum number of related memories to return.
    #[arg(long, default_value_t = DEFAULT_K_RECALL)]
    pub limit: usize,
    // Namespace override, handed to `crate::namespace::resolve_namespace`, which
    // decides the effective namespace when the flag is omitted.
    #[arg(long)]
    pub namespace: Option<String>,
    // Output format. JSON emits a single `RelatedResponse`; the text and markdown
    // formats emit one line per related memory.
    #[arg(long, value_enum, default_value = "json")]
    pub format: OutputFormat,
    // Hidden flag that `run` never reads; output selection is driven solely by `format`.
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    // Database path override, also settable through the `SQLITE_GRAPHRAG_DB_PATH`
    // environment variable; handed to `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
66
/// JSON envelope emitted by the `related` subcommand when `--format json` is used.
///
/// Field order here is the key order of the serialized object.
#[derive(Serialize)]
struct RelatedResponse {
    /// Echo of the seed memory name resolved from `--name` or the positional argument.
    /// Added in v1.0.35 for input transparency in JSON output.
    name: String,
    /// Echo of the resolved `--max-hops` value (default 2). Added in v1.0.35.
    max_hops: u32,
    /// Related memories in the order produced by the graph traversal.
    results: Vec<RelatedMemory>,
    /// Semantic alias of `results` following the v1.0.66 alias pattern (list has items/memories).
    /// Always a copy of `results`, so both keys carry identical content.
    related_memories: Vec<RelatedMemory>,
    /// Wall-clock time in milliseconds measured from the start of `run` until the
    /// response is built.
    elapsed_ms: u64,
}
79
/// One memory reached through the entity graph, together with the edge that
/// first led the traversal to the entity this memory is attached to.
#[derive(Serialize, Clone)]
struct RelatedMemory {
    /// Row id of the memory in the `memories` table.
    memory_id: i64,
    /// Memory name.
    name: String,
    /// Namespace stored on the memory row.
    namespace: String,
    /// Memory type; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    memory_type: String,
    /// Memory description; the text and markdown formats print only a preview of it.
    description: String,
    /// Number of relationship edges between a seed entity and the entity this
    /// memory is attached to.
    hop_distance: u32,
    /// Name of the source entity of the recorded edge; serialized as `null` when absent.
    source_entity: Option<String>,
    /// Name of the target entity of the recorded edge; serialized as `null` when absent.
    target_entity: Option<String>,
    /// Alias of `source_entity` for cross-command consistency (graph, link, deep-research use from/to).
    /// Unlike `source_entity`, the key is omitted entirely when the value is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    from: Option<String>,
    /// Alias of `target_entity` for cross-command consistency.
    /// Unlike `target_entity`, the key is omitted entirely when the value is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    to: Option<String>,
    /// Relation label of the recorded edge.
    relation: Option<String>,
    /// Weight of the recorded edge.
    weight: Option<f64>,
}
100
101pub fn run(args: RelatedArgs) -> Result<(), AppError> {
102    let inicio = std::time::Instant::now();
103    let name = args
104        .name_positional
105        .as_deref()
106        .or(args.name.as_deref())
107        .ok_or_else(|| {
108            AppError::Validation(
109                "name required: pass as positional argument or via --name".to_string(),
110            )
111        })?
112        .to_string();
113
114    if name.trim().is_empty() {
115        return Err(AppError::Validation("name must not be empty".to_string()));
116    }
117
118    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
119    let paths = AppPaths::resolve(args.db.as_deref())?;
120
121    crate::storage::connection::ensure_db_ready(&paths)?;
122
123    let conn = open_ro(&paths.db)?;
124
125    // Locate the seed: try memory first, fall back to bare entity.
126    let seed = match conn.query_row(
127        "SELECT id FROM memories WHERE namespace = ?1 AND name = ?2 AND deleted_at IS NULL",
128        params![namespace, name],
129        |r| r.get::<_, i64>(0),
130    ) {
131        Ok(id) => SeedKind::Memory(id),
132        Err(rusqlite::Error::QueryReturnedNoRows) => {
133            match crate::storage::entities::find_entity_id(&conn, &namespace, &name)? {
134                Some(id) => SeedKind::Entity(id),
135                None => {
136                    return Err(AppError::NotFound(errors_msg::memory_or_entity_not_found(
137                        &name, &namespace,
138                    )))
139                }
140            }
141        }
142        Err(e) => return Err(AppError::Database(e)),
143    };
144
145    // Collect seed entity IDs depending on seed kind.
146    let (seed_memory_id, seed_entity_ids): (i64, Vec<i64>) = match &seed {
147        SeedKind::Memory(id) => {
148            let mem_id = *id;
149            let mut stmt =
150                conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
151            let rows: Vec<i64> = stmt
152                .query_map(params![mem_id], |r| r.get(0))?
153                .collect::<Result<Vec<i64>, _>>()?;
154            (mem_id, rows)
155        }
156        SeedKind::Entity(entity_id) => {
157            // For a bare entity seed there is no corresponding memory to skip.
158            // Use a sentinel -1 so dedup never matches a real memory_id.
159            (-1, vec![*entity_id])
160        }
161    };
162
163    let relation_filter = args.relation;
164    if let Some(ref r) = relation_filter {
165        crate::parsers::warn_if_non_canonical(r);
166    }
167    let results = traverse_related(
168        &conn,
169        seed_memory_id,
170        &seed_entity_ids,
171        &namespace,
172        args.max_hops,
173        args.min_weight,
174        relation_filter.as_deref(),
175        args.limit,
176    )?;
177
178    match args.format {
179        OutputFormat::Json => {
180            let related_memories = results.clone();
181            output::emit_json(&RelatedResponse {
182                name: name.clone(),
183                max_hops: args.max_hops,
184                results,
185                related_memories,
186                elapsed_ms: inicio.elapsed().as_millis() as u64,
187            })?;
188        }
189        OutputFormat::Text => {
190            for item in &results {
191                if item.description.is_empty() {
192                    output::emit_text(&format!(
193                        "{}. {} ({})",
194                        item.hop_distance, item.name, item.namespace
195                    ));
196                } else {
197                    let preview: String = item
198                        .description
199                        .chars()
200                        .take(TEXT_DESCRIPTION_PREVIEW_LEN)
201                        .collect();
202                    output::emit_text(&format!(
203                        "{}. {} ({}): {}",
204                        item.hop_distance, item.name, item.namespace, preview
205                    ));
206                }
207            }
208        }
209        OutputFormat::Markdown => {
210            for item in &results {
211                if item.description.is_empty() {
212                    output::emit_text(&format!(
213                        "- **{}** ({}) — hop {}",
214                        item.name, item.namespace, item.hop_distance
215                    ));
216                } else {
217                    let preview: String = item
218                        .description
219                        .chars()
220                        .take(TEXT_DESCRIPTION_PREVIEW_LEN)
221                        .collect();
222                    output::emit_text(&format!(
223                        "- **{}** ({}) — hop {}: {}",
224                        item.name, item.namespace, item.hop_distance, preview
225                    ));
226                }
227            }
228        }
229    }
230
231    Ok(())
232}
233
234#[allow(clippy::too_many_arguments)]
235fn traverse_related(
236    conn: &Connection,
237    seed_memory_id: i64,
238    seed_entity_ids: &[i64],
239    namespace: &str,
240    max_hops: u32,
241    min_weight: f64,
242    relation_filter: Option<&str>,
243    limit: usize,
244) -> Result<Vec<RelatedMemory>, AppError> {
245    if seed_entity_ids.is_empty() || max_hops == 0 {
246        return Ok(Vec::new());
247    }
248
249    // BFS over entities keeping track of hop distance and the (source, target, relation, weight)
250    // of the edge that first reached each entity.
251    let mut visited: HashSet<i64> = seed_entity_ids.iter().copied().collect();
252    let mut entity_hop: crate::hash::AHashMap<i64, u32> =
253        crate::hash::AHashMap::with_capacity_and_hasher(max_hops as usize * 10, Default::default());
254    for &e in seed_entity_ids {
255        entity_hop.insert(e, 0);
256    }
257    // Per-entity edge info: source_name, target_name, relation, weight (captures the FIRST edge
258    // that reached this entity — equivalent to BFS shortest path recall edge).
259    let mut entity_edge: crate::hash::AHashMap<i64, (String, String, String, f64)> =
260        crate::hash::AHashMap::with_capacity_and_hasher(max_hops as usize * 10, Default::default());
261
262    let mut queue: VecDeque<i64> = seed_entity_ids.iter().copied().collect();
263
264    while let Some(current_entity) = queue.pop_front() {
265        let current_hop = *entity_hop.get(&current_entity).unwrap_or(&0);
266        if current_hop >= max_hops {
267            continue;
268        }
269
270        let neighbours =
271            fetch_neighbours(conn, current_entity, namespace, min_weight, relation_filter)?;
272
273        for (neighbour_id, source_name, target_name, relation, weight) in neighbours {
274            if visited.insert(neighbour_id) {
275                entity_hop.insert(neighbour_id, current_hop + 1);
276                entity_edge.insert(neighbour_id, (source_name, target_name, relation, weight));
277                queue.push_back(neighbour_id);
278            }
279        }
280    }
281
282    // For each discovered entity (hop >= 1) find its memories, skipping the seed memory.
283    let mut out: Vec<RelatedMemory> = Vec::with_capacity(limit);
284    let mut dedup_ids: crate::hash::AHashSet<i64> =
285        crate::hash::AHashSet::with_capacity_and_hasher(limit, Default::default());
286    dedup_ids.insert(seed_memory_id);
287
288    // Sort entities by hop ASC, weight DESC so we emit closer entities first.
289    let mut ordered_entities: Vec<(i64, u32)> = entity_hop
290        .iter()
291        .filter(|(id, _)| !seed_entity_ids.contains(id))
292        .map(|(id, hop)| (*id, *hop))
293        .collect();
294    ordered_entities.sort_by(|a, b| {
295        let weight_a = entity_edge.get(&a.0).map(|e| e.3).unwrap_or(0.0);
296        let weight_b = entity_edge.get(&b.0).map(|e| e.3).unwrap_or(0.0);
297        a.1.cmp(&b.1).then_with(|| {
298            weight_b
299                .partial_cmp(&weight_a)
300                .unwrap_or(std::cmp::Ordering::Equal)
301        })
302    });
303
304    for (entity_id, hop) in ordered_entities {
305        let mut stmt = conn.prepare_cached(
306            "SELECT m.id, m.name, m.namespace, m.type, m.description
307             FROM memory_entities me
308             JOIN memories m ON m.id = me.memory_id
309             WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
310        )?;
311        let rows = stmt
312            .query_map(params![entity_id], |r| {
313                Ok((
314                    r.get::<_, i64>(0)?,
315                    r.get::<_, String>(1)?,
316                    r.get::<_, String>(2)?,
317                    r.get::<_, String>(3)?,
318                    r.get::<_, String>(4)?,
319                ))
320            })?
321            .collect::<Result<Vec<_>, _>>()?;
322
323        for (mid, name, ns, mtype, desc) in rows {
324            if !dedup_ids.insert(mid) {
325                continue;
326            }
327            let edge = entity_edge.get(&entity_id);
328            let src = edge.map(|e| e.0.clone());
329            let tgt = edge.map(|e| e.1.clone());
330            out.push(RelatedMemory {
331                memory_id: mid,
332                name,
333                namespace: ns,
334                memory_type: mtype,
335                description: desc,
336                hop_distance: hop,
337                source_entity: src.clone(),
338                target_entity: tgt.clone(),
339                from: src,
340                to: tgt,
341                relation: edge.map(|e| e.2.clone()),
342                weight: edge.map(|e| e.3),
343            });
344            if out.len() >= limit {
345                return Ok(out);
346            }
347        }
348    }
349
350    Ok(out)
351}
352
353fn fetch_neighbours(
354    conn: &Connection,
355    entity_id: i64,
356    namespace: &str,
357    min_weight: f64,
358    relation_filter: Option<&str>,
359) -> Result<Vec<Neighbour>, AppError> {
360    // Follow edges in both directions: source -> target and target -> source so traversal is
361    // undirected, which is how users typically reason about "related" memories.
362    let base_sql = "\
363        SELECT r.target_id, se.name, te.name, r.relation, r.weight
364        FROM relationships r
365        JOIN entities se ON se.id = r.source_id
366        JOIN entities te ON te.id = r.target_id
367        WHERE r.source_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
368
369    let reverse_sql = "\
370        SELECT r.source_id, se.name, te.name, r.relation, r.weight
371        FROM relationships r
372        JOIN entities se ON se.id = r.source_id
373        JOIN entities te ON te.id = r.target_id
374        WHERE r.target_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
375
376    let mut results: Vec<Neighbour> = Vec::with_capacity(16);
377
378    let forward_sql = match relation_filter {
379        Some(_) => format!("{base_sql} AND r.relation = ?4"),
380        None => base_sql.to_string(),
381    };
382    let rev_sql = match relation_filter {
383        Some(_) => format!("{reverse_sql} AND r.relation = ?4"),
384        None => reverse_sql.to_string(),
385    };
386
387    let mut stmt = conn.prepare_cached(&forward_sql)?;
388    let rows: Vec<_> = if let Some(rel) = relation_filter {
389        stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
390            Ok((
391                r.get::<_, i64>(0)?,
392                r.get::<_, String>(1)?,
393                r.get::<_, String>(2)?,
394                r.get::<_, String>(3)?,
395                r.get::<_, f64>(4)?,
396            ))
397        })?
398        .collect::<Result<Vec<_>, _>>()?
399    } else {
400        stmt.query_map(params![entity_id, min_weight, namespace], |r| {
401            Ok((
402                r.get::<_, i64>(0)?,
403                r.get::<_, String>(1)?,
404                r.get::<_, String>(2)?,
405                r.get::<_, String>(3)?,
406                r.get::<_, f64>(4)?,
407            ))
408        })?
409        .collect::<Result<Vec<_>, _>>()?
410    };
411    results.extend(rows);
412
413    let mut stmt = conn.prepare_cached(&rev_sql)?;
414    let rows: Vec<_> = if let Some(rel) = relation_filter {
415        stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
416            Ok((
417                r.get::<_, i64>(0)?,
418                r.get::<_, String>(1)?,
419                r.get::<_, String>(2)?,
420                r.get::<_, String>(3)?,
421                r.get::<_, f64>(4)?,
422            ))
423        })?
424        .collect::<Result<Vec<_>, _>>()?
425    } else {
426        stmt.query_map(params![entity_id, min_weight, namespace], |r| {
427            Ok((
428                r.get::<_, i64>(0)?,
429                r.get::<_, String>(1)?,
430                r.get::<_, String>(2)?,
431                r.get::<_, String>(3)?,
432                r.get::<_, f64>(4)?,
433            ))
434        })?
435        .collect::<Result<Vec<_>, _>>()?
436    };
437    results.extend(rows);
438
439    Ok(results)
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an in-memory database holding only the tables and columns the
    /// traversal queries touch.
    fn setup_related_db() -> rusqlite::Connection {
        let conn = rusqlite::Connection::open_in_memory().expect("failed to open in-memory db");
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                namespace TEXT NOT NULL DEFAULT 'global',
                type TEXT NOT NULL DEFAULT 'fact',
                description TEXT NOT NULL DEFAULT '',
                deleted_at INTEGER
            );
            CREATE TABLE entities (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                name TEXT NOT NULL
            );
            CREATE TABLE relationships (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related_to',
                weight REAL NOT NULL DEFAULT 1.0
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );",
        )
        .expect("failed to create test tables");
        conn
    }

    /// Inserts a memory and returns its row id.
    fn insert_memory(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO memories (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert memory");
        conn.last_insert_rowid()
    }

    /// Inserts an entity and returns its row id.
    fn insert_entity(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO entities (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert entity");
        conn.last_insert_rowid()
    }

    /// Attaches an entity to a memory.
    fn link_memory_entity(conn: &rusqlite::Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            rusqlite::params![memory_id, entity_id],
        )
        .expect("failed to link memory-entity");
    }

    /// Inserts a directed relationship `source_id -> target_id`.
    fn insert_relationship(
        conn: &rusqlite::Connection,
        namespace: &str,
        source_id: i64,
        target_id: i64,
        relation: &str,
        weight: f64,
    ) {
        conn.execute(
            "INSERT INTO relationships (namespace, source_id, target_id, relation, weight)
             VALUES (?1, ?2, ?3, ?4, ?5)",
            rusqlite::params![namespace, source_id, target_id, relation, weight],
        )
        .expect("failed to insert relationship");
    }

    #[test]
    fn related_response_serializes_results_and_elapsed_ms() {
        let mem = RelatedMemory {
            memory_id: 1,
            name: "neighbor-mem".to_string(),
            namespace: "global".to_string(),
            memory_type: "document".to_string(),
            description: "desc".to_string(),
            hop_distance: 1,
            source_entity: Some("entity-a".to_string()),
            target_entity: Some("entity-b".to_string()),
            from: Some("entity-a".to_string()),
            to: Some("entity-b".to_string()),
            relation: Some("related_to".to_string()),
            weight: Some(0.9),
        };
        let resp = RelatedResponse {
            name: "seed-mem".to_string(),
            max_hops: 2,
            related_memories: vec![mem.clone()],
            results: vec![mem],
            elapsed_ms: 7,
        };

        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["name"], "seed-mem");
        assert_eq!(json["max_hops"], 2);
        assert!(json["results"].is_array());
        assert_eq!(json["results"].as_array().unwrap().len(), 1);
        assert_eq!(json["elapsed_ms"], 7u64);
        assert_eq!(json["results"][0]["type"], "document");
        assert_eq!(json["results"][0]["hop_distance"], 1);
        // The `from` / `to` aliases are emitted whenever they carry a value.
        assert_eq!(json["results"][0]["from"], "entity-a");
        assert_eq!(json["results"][0]["to"], "entity-b");
        // The alias list must mirror `results` exactly.
        assert_eq!(json["related_memories"], json["results"]);
    }

    #[test]
    fn traverse_related_returns_empty_without_seed_entities() {
        let conn = setup_related_db();
        let result = traverse_related(&conn, 1, &[], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(
            result.is_empty(),
            "no seed entities must yield empty results"
        );
    }

    #[test]
    fn traverse_related_returns_empty_with_max_hops_zero() {
        let conn = setup_related_db();
        let mem_id = insert_memory(&conn, "seed-mem", "global");
        let ent_id = insert_entity(&conn, "ent-a", "global");
        link_memory_entity(&conn, mem_id, ent_id);

        let result = traverse_related(&conn, mem_id, &[ent_id], "global", 0, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(result.is_empty(), "max_hops=0 must return empty");
    }

    #[test]
    fn traverse_related_discovers_neighbor_memory_via_graph() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");

        assert_eq!(result.len(), 1, "must find 1 neighboring memory");
        assert_eq!(result[0].memory_id, neighbor_id);
        assert_eq!(result[0].name, "neighbor-mem");
        assert_eq!(result[0].hop_distance, 1);
        // The edge that reached the neighbour entity is reported alongside the memory.
        assert_eq!(result[0].source_entity.as_deref(), Some("ent-a"));
        assert_eq!(result[0].target_entity.as_deref(), Some("ent-b"));
        assert_eq!(result[0].from.as_deref(), Some("ent-a"));
        assert_eq!(result[0].to.as_deref(), Some("ent-b"));
        assert_eq!(result[0].relation.as_deref(), Some("related_to"));
        assert_eq!(result[0].weight, Some(1.0));
    }

    #[test]
    fn traverse_related_follows_reverse_edges() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        insert_memory(&conn, "upstream-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, seed_id + 1, ent_b);
        // The edge points *towards* the seed entity; traversal must still cross it.
        insert_relationship(&conn, "global", ent_b, ent_a, "related_to", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");

        assert_eq!(result.len(), 1, "incoming edges must be traversed too");
        assert_eq!(result[0].name, "upstream-mem");
        // Names keep the stored direction of the edge, not the traversal direction.
        assert_eq!(result[0].source_entity.as_deref(), Some("ent-b"));
        assert_eq!(result[0].target_entity.as_deref(), Some("ent-a"));
    }

    #[test]
    fn traverse_related_ignores_edges_below_min_weight() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 0.2);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.5, None, 10)
            .expect("traverse_related failed");

        assert!(
            result.is_empty(),
            "edges lighter than min_weight must not be followed"
        );
    }

    #[test]
    fn traverse_related_respects_limit() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);

        for i in 0..5 {
            let mem_id = insert_memory(&conn, &format!("neighbor-{i}"), "global");
            let ent_id = insert_entity(&conn, &format!("ent-{i}"), "global");
            link_memory_entity(&conn, mem_id, ent_id);
            insert_relationship(&conn, "global", ent_seed, ent_id, "related_to", 1.0);
        }

        let result = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.0, None, 3)
            .expect("traverse_related failed");

        // Five candidates exist, so the cap must be hit exactly; `<= 3` would
        // also accept an (incorrect) empty result.
        assert_eq!(
            result.len(),
            3,
            "limit=3 must return exactly 3 of the 5 candidates"
        );
    }

    #[test]
    fn related_memory_optional_null_fields_serialized() {
        let mem = RelatedMemory {
            memory_id: 99,
            name: "no-relation".to_string(),
            namespace: "ns".to_string(),
            memory_type: "concept".to_string(),
            description: "".to_string(),
            hop_distance: 2,
            source_entity: None,
            target_entity: None,
            from: None,
            to: None,
            relation: None,
            weight: None,
        };

        let json = serde_json::to_value(&mem).expect("serialization failed");
        assert!(json["source_entity"].is_null());
        assert!(json["target_entity"].is_null());
        assert!(json["relation"].is_null());
        assert!(json["weight"].is_null());
        assert_eq!(json["hop_distance"], 2);
        // `from` / `to` are skipped entirely when empty, whereas the canonical
        // fields above are present with an explicit null.
        assert!(json.get("source_entity").is_some());
        assert!(json.get("from").is_none());
        assert!(json.get("to").is_none());
    }
}