Skip to main content

sqlite_graphrag/commands/
related.rs

1//! Handler for the `related` CLI subcommand.
2
3use crate::cli::RelationKind;
4use crate::constants::{
5    DEFAULT_K_RECALL, DEFAULT_MAX_HOPS, DEFAULT_MIN_WEIGHT, TEXT_DESCRIPTION_PREVIEW_LEN,
6};
7use crate::errors::AppError;
8use crate::i18n::errors_msg;
9use crate::output::{self, OutputFormat};
10use crate::paths::AppPaths;
11use crate::storage::connection::open_ro;
12use rusqlite::{params, Connection};
13use serde::Serialize;
14use std::collections::{HashMap, HashSet, VecDeque};
15
/// Identifies whether the seed name resolved to a live memory or to a bare entity.
///
/// Each variant carries the row id of the matched record.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SeedKind {
    /// A non-deleted memory in the requested namespace; traversal starts from its linked entities.
    Memory(i64),
    /// An entity looked up only after no same-named live memory was found; traversal starts
    /// from the entity itself.
    Entity(i64),
}
21
/// Tuple returned by the adjacency fetch: (neighbour_entity_id, source_name,
/// target_name, relation, weight).
///
/// `neighbour_entity_id` is always the endpoint opposite the entity being expanded, while
/// `source_name` / `target_name` keep the edge's stored source/target orientation, so an
/// incoming edge reports the neighbour as its source.
type Neighbour = (i64, String, String, String, f64);
25
// CLI arguments for the `related` subcommand. Reviewer notes below use plain `//` comments on
// purpose: clap's derive turns `///` doc comments on fields into user-visible help text.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n  \
    # List memories connected to a memory via the entity graph (default 2 hops)\n  \
    sqlite-graphrag related onboarding\n\n  \
    # Increase hop distance and filter by relation type\n  \
    sqlite-graphrag related onboarding --max-hops 3 --relation related\n\n  \
    # Cap result count and require minimum edge weight\n  \
    sqlite-graphrag related onboarding --limit 5 --min-weight 0.5")]
pub struct RelatedArgs {
    /// Memory name as a positional argument. Alternative to `--name`.
    #[arg(
        value_name = "NAME",
        conflicts_with = "name",
        help = "Memory name whose neighbours to traverse; alternative to --name"
    )]
    pub name_positional: Option<String>,
    /// Memory name as a flag. Required when the positional form is absent. Also accepts the alias `--from`.
    #[arg(long, alias = "from")]
    pub name: Option<String>,
    /// Maximum graph hop count. Also accepts the alias `--hops`.
    #[arg(long, alias = "hops", default_value_t = DEFAULT_MAX_HOPS)]
    pub max_hops: u32,
    // Optional relation-type filter. `run` converts it with `as_str()`, and traversal matches
    // it against `relationships.relation`.
    #[arg(long, value_enum)]
    pub relation: Option<RelationKind>,
    // Edges whose weight is below this value are not followed during traversal.
    #[arg(long, default_value_t = DEFAULT_MIN_WEIGHT)]
    pub min_weight: f64,
    // Maximum number of related memories returned.
    #[arg(long, default_value_t = DEFAULT_K_RECALL)]
    pub limit: usize,
    // Namespace override; resolved through `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // Output format: json, text, or markdown.
    #[arg(long, value_enum, default_value = "json")]
    pub format: OutputFormat,
    // Hidden compatibility flag that `run` never reads. NOTE(review): the help text claims JSON
    // is always emitted, which does not hold when `--format text` or `--format markdown` is used.
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    // Database path override; can also be supplied via SQLITE_GRAPHRAG_DB_PATH.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
63
/// JSON payload emitted by `related` when `--format json` is selected.
#[derive(Serialize)]
struct RelatedResponse {
    /// Echo of the seed name resolved from `--name` or the positional argument (the name may
    /// have matched a bare entity rather than a memory).
    /// Added in v1.0.35 for input transparency in JSON output.
    name: String,
    /// Echo of the resolved `--max-hops` value (default 2). Added in v1.0.35.
    max_hops: u32,
    /// Related memories in traversal order, capped at `--limit`.
    results: Vec<RelatedMemory>,
    /// Milliseconds elapsed between the start of `run` and building this response.
    elapsed_ms: u64,
}
74
/// One memory reached through the entity graph, plus the edge that first reached the entity
/// linking it.
#[derive(Serialize, Clone)]
struct RelatedMemory {
    /// Row id of the memory in the `memories` table.
    memory_id: i64,
    /// Memory name.
    name: String,
    /// Namespace stored on the memory row.
    namespace: String,
    /// Memory type; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    memory_type: String,
    /// Full description; the text and markdown formats print only a truncated preview.
    description: String,
    /// Hop count from the seed entities to the entity that linked this memory.
    hop_distance: u32,
    /// Name of the stored source entity of the edge that first reached the linking entity.
    /// `None` (serialized as `null`) only if no edge was recorded for that entity.
    source_entity: Option<String>,
    /// Name of the stored target entity of that edge.
    target_entity: Option<String>,
    /// Relation label of that edge.
    relation: Option<String>,
    /// Weight of that edge.
    weight: Option<f64>,
}
89
90pub fn run(args: RelatedArgs) -> Result<(), AppError> {
91    let inicio = std::time::Instant::now();
92    let name = args
93        .name_positional
94        .as_deref()
95        .or(args.name.as_deref())
96        .ok_or_else(|| {
97            AppError::Validation(
98                "name required: pass as positional argument or via --name".to_string(),
99            )
100        })?
101        .to_string();
102
103    if name.trim().is_empty() {
104        return Err(AppError::Validation("name must not be empty".to_string()));
105    }
106
107    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
108    let paths = AppPaths::resolve(args.db.as_deref())?;
109
110    crate::storage::connection::ensure_db_ready(&paths)?;
111
112    let conn = open_ro(&paths.db)?;
113
114    // Locate the seed: try memory first, fall back to bare entity.
115    let seed = match conn.query_row(
116        "SELECT id FROM memories WHERE namespace = ?1 AND name = ?2 AND deleted_at IS NULL",
117        params![namespace, name],
118        |r| r.get::<_, i64>(0),
119    ) {
120        Ok(id) => SeedKind::Memory(id),
121        Err(rusqlite::Error::QueryReturnedNoRows) => {
122            match crate::storage::entities::find_entity_id(&conn, &namespace, &name)? {
123                Some(id) => SeedKind::Entity(id),
124                None => {
125                    return Err(AppError::NotFound(errors_msg::memory_or_entity_not_found(
126                        &name, &namespace,
127                    )))
128                }
129            }
130        }
131        Err(e) => return Err(AppError::Database(e)),
132    };
133
134    // Collect seed entity IDs depending on seed kind.
135    let (seed_memory_id, seed_entity_ids): (i64, Vec<i64>) = match &seed {
136        SeedKind::Memory(id) => {
137            let mem_id = *id;
138            let mut stmt =
139                conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
140            let rows: Vec<i64> = stmt
141                .query_map(params![mem_id], |r| r.get(0))?
142                .collect::<Result<Vec<i64>, _>>()?;
143            (mem_id, rows)
144        }
145        SeedKind::Entity(entity_id) => {
146            // For a bare entity seed there is no corresponding memory to skip.
147            // Use a sentinel -1 so dedup never matches a real memory_id.
148            (-1, vec![*entity_id])
149        }
150    };
151
152    let relation_filter = args.relation.map(|r| r.as_str().to_string());
153    let results = traverse_related(
154        &conn,
155        seed_memory_id,
156        &seed_entity_ids,
157        &namespace,
158        args.max_hops,
159        args.min_weight,
160        relation_filter.as_deref(),
161        args.limit,
162    )?;
163
164    match args.format {
165        OutputFormat::Json => output::emit_json(&RelatedResponse {
166            name: name.clone(),
167            max_hops: args.max_hops,
168            results,
169            elapsed_ms: inicio.elapsed().as_millis() as u64,
170        })?,
171        OutputFormat::Text => {
172            for item in &results {
173                if item.description.is_empty() {
174                    output::emit_text(&format!(
175                        "{}. {} ({})",
176                        item.hop_distance, item.name, item.namespace
177                    ));
178                } else {
179                    let preview: String = item
180                        .description
181                        .chars()
182                        .take(TEXT_DESCRIPTION_PREVIEW_LEN)
183                        .collect();
184                    output::emit_text(&format!(
185                        "{}. {} ({}): {}",
186                        item.hop_distance, item.name, item.namespace, preview
187                    ));
188                }
189            }
190        }
191        OutputFormat::Markdown => {
192            for item in &results {
193                if item.description.is_empty() {
194                    output::emit_text(&format!(
195                        "- **{}** ({}) — hop {}",
196                        item.name, item.namespace, item.hop_distance
197                    ));
198                } else {
199                    let preview: String = item
200                        .description
201                        .chars()
202                        .take(TEXT_DESCRIPTION_PREVIEW_LEN)
203                        .collect();
204                    output::emit_text(&format!(
205                        "- **{}** ({}) — hop {}: {}",
206                        item.name, item.namespace, item.hop_distance, preview
207                    ));
208                }
209            }
210        }
211    }
212
213    Ok(())
214}
215
216#[allow(clippy::too_many_arguments)]
217fn traverse_related(
218    conn: &Connection,
219    seed_memory_id: i64,
220    seed_entity_ids: &[i64],
221    namespace: &str,
222    max_hops: u32,
223    min_weight: f64,
224    relation_filter: Option<&str>,
225    limit: usize,
226) -> Result<Vec<RelatedMemory>, AppError> {
227    if seed_entity_ids.is_empty() || max_hops == 0 {
228        return Ok(Vec::new());
229    }
230
231    // BFS over entities keeping track of hop distance and the (source, target, relation, weight)
232    // of the edge that first reached each entity.
233    let mut visited: HashSet<i64> = seed_entity_ids.iter().copied().collect();
234    let mut entity_hop: HashMap<i64, u32> = HashMap::new();
235    for &e in seed_entity_ids {
236        entity_hop.insert(e, 0);
237    }
238    // Per-entity edge info: source_name, target_name, relation, weight (captures the FIRST edge
239    // that reached this entity — equivalent to BFS shortest path recall edge).
240    let mut entity_edge: HashMap<i64, (String, String, String, f64)> = HashMap::new();
241
242    let mut queue: VecDeque<i64> = seed_entity_ids.iter().copied().collect();
243
244    while let Some(current_entity) = queue.pop_front() {
245        let current_hop = *entity_hop.get(&current_entity).unwrap_or(&0);
246        if current_hop >= max_hops {
247            continue;
248        }
249
250        let neighbours =
251            fetch_neighbours(conn, current_entity, namespace, min_weight, relation_filter)?;
252
253        for (neighbour_id, source_name, target_name, relation, weight) in neighbours {
254            if visited.insert(neighbour_id) {
255                entity_hop.insert(neighbour_id, current_hop + 1);
256                entity_edge.insert(neighbour_id, (source_name, target_name, relation, weight));
257                queue.push_back(neighbour_id);
258            }
259        }
260    }
261
262    // For each discovered entity (hop >= 1) find its memories, skipping the seed memory.
263    let mut out: Vec<RelatedMemory> = Vec::new();
264    let mut dedup_ids: HashSet<i64> = HashSet::new();
265    dedup_ids.insert(seed_memory_id);
266
267    // Sort entities by hop ASC, weight DESC so we emit closer entities first.
268    let mut ordered_entities: Vec<(i64, u32)> = entity_hop
269        .iter()
270        .filter(|(id, _)| !seed_entity_ids.contains(id))
271        .map(|(id, hop)| (*id, *hop))
272        .collect();
273    ordered_entities.sort_by(|a, b| {
274        let weight_a = entity_edge.get(&a.0).map(|e| e.3).unwrap_or(0.0);
275        let weight_b = entity_edge.get(&b.0).map(|e| e.3).unwrap_or(0.0);
276        a.1.cmp(&b.1).then_with(|| {
277            weight_b
278                .partial_cmp(&weight_a)
279                .unwrap_or(std::cmp::Ordering::Equal)
280        })
281    });
282
283    for (entity_id, hop) in ordered_entities {
284        let mut stmt = conn.prepare_cached(
285            "SELECT m.id, m.name, m.namespace, m.type, m.description
286             FROM memory_entities me
287             JOIN memories m ON m.id = me.memory_id
288             WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
289        )?;
290        let rows = stmt
291            .query_map(params![entity_id], |r| {
292                Ok((
293                    r.get::<_, i64>(0)?,
294                    r.get::<_, String>(1)?,
295                    r.get::<_, String>(2)?,
296                    r.get::<_, String>(3)?,
297                    r.get::<_, String>(4)?,
298                ))
299            })?
300            .collect::<Result<Vec<_>, _>>()?;
301
302        for (mid, name, ns, mtype, desc) in rows {
303            if !dedup_ids.insert(mid) {
304                continue;
305            }
306            let edge = entity_edge.get(&entity_id);
307            out.push(RelatedMemory {
308                memory_id: mid,
309                name,
310                namespace: ns,
311                memory_type: mtype,
312                description: desc,
313                hop_distance: hop,
314                source_entity: edge.map(|e| e.0.clone()),
315                target_entity: edge.map(|e| e.1.clone()),
316                relation: edge.map(|e| e.2.clone()),
317                weight: edge.map(|e| e.3),
318            });
319            if out.len() >= limit {
320                return Ok(out);
321            }
322        }
323    }
324
325    Ok(out)
326}
327
328fn fetch_neighbours(
329    conn: &Connection,
330    entity_id: i64,
331    namespace: &str,
332    min_weight: f64,
333    relation_filter: Option<&str>,
334) -> Result<Vec<Neighbour>, AppError> {
335    // Follow edges in both directions: source -> target and target -> source so traversal is
336    // undirected, which is how users typically reason about "related" memories.
337    let base_sql = "\
338        SELECT r.target_id, se.name, te.name, r.relation, r.weight
339        FROM relationships r
340        JOIN entities se ON se.id = r.source_id
341        JOIN entities te ON te.id = r.target_id
342        WHERE r.source_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
343
344    let reverse_sql = "\
345        SELECT r.source_id, se.name, te.name, r.relation, r.weight
346        FROM relationships r
347        JOIN entities se ON se.id = r.source_id
348        JOIN entities te ON te.id = r.target_id
349        WHERE r.target_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
350
351    let mut results: Vec<Neighbour> = Vec::new();
352
353    let forward_sql = match relation_filter {
354        Some(_) => format!("{base_sql} AND r.relation = ?4"),
355        None => base_sql.to_string(),
356    };
357    let rev_sql = match relation_filter {
358        Some(_) => format!("{reverse_sql} AND r.relation = ?4"),
359        None => reverse_sql.to_string(),
360    };
361
362    let mut stmt = conn.prepare_cached(&forward_sql)?;
363    let rows: Vec<_> = if let Some(rel) = relation_filter {
364        stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
365            Ok((
366                r.get::<_, i64>(0)?,
367                r.get::<_, String>(1)?,
368                r.get::<_, String>(2)?,
369                r.get::<_, String>(3)?,
370                r.get::<_, f64>(4)?,
371            ))
372        })?
373        .collect::<Result<Vec<_>, _>>()?
374    } else {
375        stmt.query_map(params![entity_id, min_weight, namespace], |r| {
376            Ok((
377                r.get::<_, i64>(0)?,
378                r.get::<_, String>(1)?,
379                r.get::<_, String>(2)?,
380                r.get::<_, String>(3)?,
381                r.get::<_, f64>(4)?,
382            ))
383        })?
384        .collect::<Result<Vec<_>, _>>()?
385    };
386    results.extend(rows);
387
388    let mut stmt = conn.prepare_cached(&rev_sql)?;
389    let rows: Vec<_> = if let Some(rel) = relation_filter {
390        stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
391            Ok((
392                r.get::<_, i64>(0)?,
393                r.get::<_, String>(1)?,
394                r.get::<_, String>(2)?,
395                r.get::<_, String>(3)?,
396                r.get::<_, f64>(4)?,
397            ))
398        })?
399        .collect::<Result<Vec<_>, _>>()?
400    } else {
401        stmt.query_map(params![entity_id, min_weight, namespace], |r| {
402            Ok((
403                r.get::<_, i64>(0)?,
404                r.get::<_, String>(1)?,
405                r.get::<_, String>(2)?,
406                r.get::<_, String>(3)?,
407                r.get::<_, f64>(4)?,
408            ))
409        })?
410        .collect::<Result<Vec<_>, _>>()?
411    };
412    results.extend(rows);
413
414    Ok(results)
415}
416
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an in-memory database holding the minimal subset of the schema that
    /// `traverse_related` and `fetch_neighbours` query.
    fn setup_related_db() -> rusqlite::Connection {
        let conn = rusqlite::Connection::open_in_memory().expect("failed to open in-memory db");
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                namespace TEXT NOT NULL DEFAULT 'global',
                type TEXT NOT NULL DEFAULT 'fact',
                description TEXT NOT NULL DEFAULT '',
                deleted_at INTEGER
            );
            CREATE TABLE entities (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                name TEXT NOT NULL
            );
            CREATE TABLE relationships (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related_to',
                weight REAL NOT NULL DEFAULT 1.0
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );",
        )
        .expect("failed to create test tables");
        conn
    }

    fn insert_memory(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO memories (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert memory");
        conn.last_insert_rowid()
    }

    fn insert_entity(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO entities (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert entity");
        conn.last_insert_rowid()
    }

    fn link_memory_entity(conn: &rusqlite::Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            rusqlite::params![memory_id, entity_id],
        )
        .expect("failed to link memory-entity");
    }

    fn insert_relationship(
        conn: &rusqlite::Connection,
        namespace: &str,
        source_id: i64,
        target_id: i64,
        relation: &str,
        weight: f64,
    ) {
        conn.execute(
            "INSERT INTO relationships (namespace, source_id, target_id, relation, weight)
             VALUES (?1, ?2, ?3, ?4, ?5)",
            rusqlite::params![namespace, source_id, target_id, relation, weight],
        )
        .expect("failed to insert relationship");
    }

    /// Marks a memory as soft-deleted so traversal must ignore it.
    fn soft_delete_memory(conn: &rusqlite::Connection, memory_id: i64) {
        conn.execute(
            "UPDATE memories SET deleted_at = 1 WHERE id = ?1",
            rusqlite::params![memory_id],
        )
        .expect("failed to soft-delete memory");
    }

    #[test]
    fn related_response_serializes_results_and_elapsed_ms() {
        let resp = RelatedResponse {
            name: "seed-mem".to_string(),
            max_hops: 2,
            results: vec![RelatedMemory {
                memory_id: 1,
                name: "neighbor-mem".to_string(),
                namespace: "global".to_string(),
                memory_type: "document".to_string(),
                description: "desc".to_string(),
                hop_distance: 1,
                source_entity: Some("entity-a".to_string()),
                target_entity: Some("entity-b".to_string()),
                relation: Some("related_to".to_string()),
                weight: Some(0.9),
            }],
            elapsed_ms: 7,
        };

        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert!(json["results"].is_array());
        assert_eq!(json["results"].as_array().unwrap().len(), 1);
        assert_eq!(json["elapsed_ms"], 7u64);
        assert_eq!(json["results"][0]["type"], "document");
        assert_eq!(json["results"][0]["hop_distance"], 1);
    }

    #[test]
    fn traverse_related_returns_empty_without_seed_entities() {
        let conn = setup_related_db();
        let result = traverse_related(&conn, 1, &[], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(
            result.is_empty(),
            "no seed entities must yield empty results"
        );
    }

    #[test]
    fn traverse_related_returns_empty_with_max_hops_zero() {
        let conn = setup_related_db();
        let mem_id = insert_memory(&conn, "seed-mem", "global");
        let ent_id = insert_entity(&conn, "ent-a", "global");
        link_memory_entity(&conn, mem_id, ent_id);

        let result = traverse_related(&conn, mem_id, &[ent_id], "global", 0, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(result.is_empty(), "max_hops=0 must return empty");
    }

    #[test]
    fn traverse_related_discovers_neighbor_memory_via_graph() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");

        assert_eq!(result.len(), 1, "must find 1 neighboring memory");
        assert_eq!(result[0].name, "neighbor-mem");
        assert_eq!(result[0].hop_distance, 1);
    }

    #[test]
    fn traverse_related_respects_limit() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);

        for i in 0..5 {
            let mem_id = insert_memory(&conn, &format!("neighbor-{i}"), "global");
            let ent_id = insert_entity(&conn, &format!("ent-{i}"), "global");
            link_memory_entity(&conn, mem_id, ent_id);
            insert_relationship(&conn, "global", ent_seed, ent_id, "related_to", 1.0);
        }

        let result = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.0, None, 3)
            .expect("traverse_related failed");

        // Five candidates exist, so the limit (not a lack of data) must be what caps the output.
        assert_eq!(result.len(), 3, "limit=3 must return exactly 3 results");
    }

    #[test]
    fn traverse_related_follows_edges_in_reverse_direction() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        // The edge points INTO the seed entity; traversal is undirected so ent-b is still reached.
        insert_relationship(&conn, "global", ent_b, ent_a, "related_to", 0.8);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].name, "neighbor-mem");
        assert_eq!(result[0].source_entity.as_deref(), Some("ent-b"));
        assert_eq!(result[0].target_entity.as_deref(), Some("ent-a"));
        assert_eq!(result[0].relation.as_deref(), Some("related_to"));
        assert_eq!(result[0].weight, Some(0.8));
    }

    #[test]
    fn traverse_related_skips_edges_below_min_weight() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 0.2);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.5, None, 10)
            .expect("traverse_related failed");

        assert!(result.is_empty(), "edges lighter than min_weight must not be followed");
    }

    #[test]
    fn traverse_related_applies_relation_filter() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);

        let mem_x = insert_memory(&conn, "mem-x", "global");
        let ent_x = insert_entity(&conn, "ent-x", "global");
        link_memory_entity(&conn, mem_x, ent_x);
        insert_relationship(&conn, "global", ent_seed, ent_x, "related_to", 1.0);

        let mem_y = insert_memory(&conn, "mem-y", "global");
        let ent_y = insert_entity(&conn, "ent-y", "global");
        link_memory_entity(&conn, mem_y, ent_y);
        insert_relationship(&conn, "global", ent_seed, ent_y, "mentions", 1.0);

        let result = traverse_related(
            &conn,
            seed_id,
            &[ent_seed],
            "global",
            1,
            0.0,
            Some("mentions"),
            10,
        )
        .expect("traverse_related failed");

        assert_eq!(result.len(), 1, "only the 'mentions' edge may be followed");
        assert_eq!(result[0].name, "mem-y");
        assert_eq!(result[0].relation.as_deref(), Some("mentions"));
    }

    #[test]
    fn traverse_related_ignores_relationships_in_other_namespaces() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "other", ent_a, ent_b, "related_to", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");

        assert!(result.is_empty(), "edges from another namespace must be ignored");
    }

    #[test]
    fn traverse_related_skips_soft_deleted_memories() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);
        soft_delete_memory(&conn, neighbor_id);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");

        assert!(result.is_empty(), "soft-deleted memories must not be returned");
    }

    #[test]
    fn traverse_related_excludes_seed_memory() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        link_memory_entity(&conn, seed_id, ent_a);
        // The seed memory is also linked to the discovered entity; it must still be skipped.
        link_memory_entity(&conn, seed_id, ent_b);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].memory_id, neighbor_id);
    }

    #[test]
    fn traverse_related_orders_by_hop_and_stops_at_max_hops() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let mem_b = insert_memory(&conn, "mem-b", "global");
        let mem_c = insert_memory(&conn, "mem-c", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        let ent_c = insert_entity(&conn, "ent-c", "global");
        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, mem_b, ent_b);
        link_memory_entity(&conn, mem_c, ent_c);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);
        insert_relationship(&conn, "global", ent_b, ent_c, "related_to", 1.0);

        let two_hops = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        let names: Vec<&str> = two_hops.iter().map(|m| m.name.as_str()).collect();
        let hops: Vec<u32> = two_hops.iter().map(|m| m.hop_distance).collect();
        assert_eq!(names, vec!["mem-b", "mem-c"]);
        assert_eq!(hops, vec![1, 2]);

        let one_hop = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");
        assert_eq!(one_hop.len(), 1, "max_hops=1 must not reach the second hop");
        assert_eq!(one_hop[0].name, "mem-b");
    }

    #[test]
    fn related_memory_optional_null_fields_serialized() {
        let mem = RelatedMemory {
            memory_id: 99,
            name: "no-relation".to_string(),
            namespace: "ns".to_string(),
            memory_type: "concept".to_string(),
            description: "".to_string(),
            hop_distance: 2,
            source_entity: None,
            target_entity: None,
            relation: None,
            weight: None,
        };

        let json = serde_json::to_value(&mem).expect("serialization failed");
        assert!(json["source_entity"].is_null());
        assert!(json["target_entity"].is_null());
        assert!(json["relation"].is_null());
        assert!(json["weight"].is_null());
        assert_eq!(json["hop_distance"], 2);
    }
}