Skip to main content

sqlite_graphrag/commands/
related.rs

1//! Handler for the `related` CLI subcommand.
2
3use crate::constants::{
4    DEFAULT_K_RECALL, DEFAULT_MAX_HOPS, DEFAULT_MIN_WEIGHT, TEXT_DESCRIPTION_PREVIEW_LEN,
5};
6use crate::errors::AppError;
7use crate::i18n::errors_msg;
8use crate::output::{self, OutputFormat};
9use crate::paths::AppPaths;
10use crate::storage::connection::open_ro;
11use rusqlite::{params, Connection};
12use serde::Serialize;
13use std::collections::{HashMap, HashSet, VecDeque};
14
/// Identifies whether the seed name resolved to a live memory or to a bare entity.
///
/// The payload is the row id of the matched memory or entity.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SeedKind {
    /// A non-deleted memory in the requested namespace carrying the seed name.
    Memory(i64),
    /// An entity carrying the seed name, used when no such memory exists.
    Entity(i64),
}
20
/// One edge adjacent to the entity being expanded, as returned by `fetch_neighbours`:
/// `(neighbour_entity_id, source_name, target_name, relation, weight)`.
///
/// `source_name` / `target_name` are the names of the relationship's own source and target
/// entities. They therefore keep the stored edge direction even when the edge was followed
/// in reverse (target -> source) to reach the neighbour.
type Neighbour = (i64, String, String, String, f64);
24
// Arguments of the `related` subcommand.
// NOTE: clap derives `--help` text from `///` doc comments, so maintainer-only notes in this
// struct are written as plain `//` comments to leave the CLI help output unchanged.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n  \
    # List memories connected to a memory via the entity graph (default 2 hops)\n  \
    sqlite-graphrag related onboarding\n\n  \
    # Increase hop distance and filter by relation type\n  \
    sqlite-graphrag related onboarding --max-hops 3 --relation related\n\n  \
    # Cap result count and require minimum edge weight\n  \
    sqlite-graphrag related onboarding --limit 5 --min-weight 0.5")]
pub struct RelatedArgs {
    /// Memory name as a positional argument. Alternative to `--name`.
    #[arg(
        value_name = "NAME",
        conflicts_with = "name",
        help = "Memory name whose neighbours to traverse; alternative to --name"
    )]
    pub name_positional: Option<String>,
    /// Memory name as a flag. Required when the positional form is absent. Also accepts the alias `--from`.
    #[arg(long, alias = "from")]
    pub name: Option<String>,
    /// Maximum graph hop count. Also accepts the alias `--hops`.
    #[arg(long, alias = "hops", default_value_t = DEFAULT_MAX_HOPS)]
    pub max_hops: u32,
    // When set, only edges whose `relation` equals this value are followed. The value is
    // parsed/validated by `crate::parsers::parse_relation`; `None` follows every relation.
    #[arg(long, value_parser = crate::parsers::parse_relation)]
    pub relation: Option<String>,
    // Edges whose weight is below this threshold are not traversed.
    #[arg(long, default_value_t = DEFAULT_MIN_WEIGHT)]
    pub min_weight: f64,
    // Maximum number of related memories returned.
    #[arg(long, default_value_t = DEFAULT_K_RECALL)]
    pub limit: usize,
    // Namespace to search; passed to `crate::namespace::resolve_namespace` for resolution.
    #[arg(long)]
    pub namespace: Option<String>,
    // Output format; `run` handles json (default), text and markdown.
    #[arg(long, value_enum, default_value = "json")]
    pub format: OutputFormat,
    // Hidden flag accepted for compatibility; `run` never reads it.
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    // Database path override, also read from SQLITE_GRAPHRAG_DB_PATH; passed to
    // `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
62
/// JSON document emitted by [`run`] when the output format is `json`.
///
/// Field order here is the key order of the serialized JSON object.
#[derive(Serialize)]
struct RelatedResponse {
    /// Echo of the seed memory name resolved from `--name` or the positional argument.
    /// Added in v1.0.35 for input transparency in JSON output.
    name: String,
    /// Echo of the resolved `--max-hops` value (default 2). Added in v1.0.35.
    max_hops: u32,
    /// Related memories, in the order produced by `traverse_related`.
    results: Vec<RelatedMemory>,
    /// Milliseconds elapsed between the start of [`run`] and building this response.
    elapsed_ms: u64,
}
73
/// One memory reached through the entity graph from the seed.
///
/// Field order here is the key order of the serialized JSON object.
#[derive(Serialize, Clone)]
struct RelatedMemory {
    /// Row id of the memory.
    memory_id: i64,
    /// Memory name.
    name: String,
    /// Namespace stored on the memory row.
    namespace: String,
    /// Memory type; serialized under the JSON key `"type"`.
    #[serde(rename = "type")]
    memory_type: String,
    /// Full description. The text/markdown renderers truncate it; JSON output does not.
    description: String,
    /// Number of graph hops from the seed entities to the entity that links this memory.
    hop_distance: u32,
    /// Source-entity name of the edge that first reached the linking entity, if recorded.
    source_entity: Option<String>,
    /// Target-entity name of that edge, if recorded.
    target_entity: Option<String>,
    /// Relation label of that edge, if recorded.
    relation: Option<String>,
    /// Weight of that edge, if recorded.
    weight: Option<f64>,
}
88
89pub fn run(args: RelatedArgs) -> Result<(), AppError> {
90    let inicio = std::time::Instant::now();
91    let name = args
92        .name_positional
93        .as_deref()
94        .or(args.name.as_deref())
95        .ok_or_else(|| {
96            AppError::Validation(
97                "name required: pass as positional argument or via --name".to_string(),
98            )
99        })?
100        .to_string();
101
102    if name.trim().is_empty() {
103        return Err(AppError::Validation("name must not be empty".to_string()));
104    }
105
106    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
107    let paths = AppPaths::resolve(args.db.as_deref())?;
108
109    crate::storage::connection::ensure_db_ready(&paths)?;
110
111    let conn = open_ro(&paths.db)?;
112
113    // Locate the seed: try memory first, fall back to bare entity.
114    let seed = match conn.query_row(
115        "SELECT id FROM memories WHERE namespace = ?1 AND name = ?2 AND deleted_at IS NULL",
116        params![namespace, name],
117        |r| r.get::<_, i64>(0),
118    ) {
119        Ok(id) => SeedKind::Memory(id),
120        Err(rusqlite::Error::QueryReturnedNoRows) => {
121            match crate::storage::entities::find_entity_id(&conn, &namespace, &name)? {
122                Some(id) => SeedKind::Entity(id),
123                None => {
124                    return Err(AppError::NotFound(errors_msg::memory_or_entity_not_found(
125                        &name, &namespace,
126                    )))
127                }
128            }
129        }
130        Err(e) => return Err(AppError::Database(e)),
131    };
132
133    // Collect seed entity IDs depending on seed kind.
134    let (seed_memory_id, seed_entity_ids): (i64, Vec<i64>) = match &seed {
135        SeedKind::Memory(id) => {
136            let mem_id = *id;
137            let mut stmt =
138                conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
139            let rows: Vec<i64> = stmt
140                .query_map(params![mem_id], |r| r.get(0))?
141                .collect::<Result<Vec<i64>, _>>()?;
142            (mem_id, rows)
143        }
144        SeedKind::Entity(entity_id) => {
145            // For a bare entity seed there is no corresponding memory to skip.
146            // Use a sentinel -1 so dedup never matches a real memory_id.
147            (-1, vec![*entity_id])
148        }
149    };
150
151    let relation_filter = args.relation;
152    let results = traverse_related(
153        &conn,
154        seed_memory_id,
155        &seed_entity_ids,
156        &namespace,
157        args.max_hops,
158        args.min_weight,
159        relation_filter.as_deref(),
160        args.limit,
161    )?;
162
163    match args.format {
164        OutputFormat::Json => output::emit_json(&RelatedResponse {
165            name: name.clone(),
166            max_hops: args.max_hops,
167            results,
168            elapsed_ms: inicio.elapsed().as_millis() as u64,
169        })?,
170        OutputFormat::Text => {
171            for item in &results {
172                if item.description.is_empty() {
173                    output::emit_text(&format!(
174                        "{}. {} ({})",
175                        item.hop_distance, item.name, item.namespace
176                    ));
177                } else {
178                    let preview: String = item
179                        .description
180                        .chars()
181                        .take(TEXT_DESCRIPTION_PREVIEW_LEN)
182                        .collect();
183                    output::emit_text(&format!(
184                        "{}. {} ({}): {}",
185                        item.hop_distance, item.name, item.namespace, preview
186                    ));
187                }
188            }
189        }
190        OutputFormat::Markdown => {
191            for item in &results {
192                if item.description.is_empty() {
193                    output::emit_text(&format!(
194                        "- **{}** ({}) — hop {}",
195                        item.name, item.namespace, item.hop_distance
196                    ));
197                } else {
198                    let preview: String = item
199                        .description
200                        .chars()
201                        .take(TEXT_DESCRIPTION_PREVIEW_LEN)
202                        .collect();
203                    output::emit_text(&format!(
204                        "- **{}** ({}) — hop {}: {}",
205                        item.name, item.namespace, item.hop_distance, preview
206                    ));
207                }
208            }
209        }
210    }
211
212    Ok(())
213}
214
215#[allow(clippy::too_many_arguments)]
216fn traverse_related(
217    conn: &Connection,
218    seed_memory_id: i64,
219    seed_entity_ids: &[i64],
220    namespace: &str,
221    max_hops: u32,
222    min_weight: f64,
223    relation_filter: Option<&str>,
224    limit: usize,
225) -> Result<Vec<RelatedMemory>, AppError> {
226    if seed_entity_ids.is_empty() || max_hops == 0 {
227        return Ok(Vec::new());
228    }
229
230    // BFS over entities keeping track of hop distance and the (source, target, relation, weight)
231    // of the edge that first reached each entity.
232    let mut visited: HashSet<i64> = seed_entity_ids.iter().copied().collect();
233    let mut entity_hop: HashMap<i64, u32> = HashMap::new();
234    for &e in seed_entity_ids {
235        entity_hop.insert(e, 0);
236    }
237    // Per-entity edge info: source_name, target_name, relation, weight (captures the FIRST edge
238    // that reached this entity — equivalent to BFS shortest path recall edge).
239    let mut entity_edge: HashMap<i64, (String, String, String, f64)> = HashMap::new();
240
241    let mut queue: VecDeque<i64> = seed_entity_ids.iter().copied().collect();
242
243    while let Some(current_entity) = queue.pop_front() {
244        let current_hop = *entity_hop.get(&current_entity).unwrap_or(&0);
245        if current_hop >= max_hops {
246            continue;
247        }
248
249        let neighbours =
250            fetch_neighbours(conn, current_entity, namespace, min_weight, relation_filter)?;
251
252        for (neighbour_id, source_name, target_name, relation, weight) in neighbours {
253            if visited.insert(neighbour_id) {
254                entity_hop.insert(neighbour_id, current_hop + 1);
255                entity_edge.insert(neighbour_id, (source_name, target_name, relation, weight));
256                queue.push_back(neighbour_id);
257            }
258        }
259    }
260
261    // For each discovered entity (hop >= 1) find its memories, skipping the seed memory.
262    let mut out: Vec<RelatedMemory> = Vec::new();
263    let mut dedup_ids: HashSet<i64> = HashSet::new();
264    dedup_ids.insert(seed_memory_id);
265
266    // Sort entities by hop ASC, weight DESC so we emit closer entities first.
267    let mut ordered_entities: Vec<(i64, u32)> = entity_hop
268        .iter()
269        .filter(|(id, _)| !seed_entity_ids.contains(id))
270        .map(|(id, hop)| (*id, *hop))
271        .collect();
272    ordered_entities.sort_by(|a, b| {
273        let weight_a = entity_edge.get(&a.0).map(|e| e.3).unwrap_or(0.0);
274        let weight_b = entity_edge.get(&b.0).map(|e| e.3).unwrap_or(0.0);
275        a.1.cmp(&b.1).then_with(|| {
276            weight_b
277                .partial_cmp(&weight_a)
278                .unwrap_or(std::cmp::Ordering::Equal)
279        })
280    });
281
282    for (entity_id, hop) in ordered_entities {
283        let mut stmt = conn.prepare_cached(
284            "SELECT m.id, m.name, m.namespace, m.type, m.description
285             FROM memory_entities me
286             JOIN memories m ON m.id = me.memory_id
287             WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
288        )?;
289        let rows = stmt
290            .query_map(params![entity_id], |r| {
291                Ok((
292                    r.get::<_, i64>(0)?,
293                    r.get::<_, String>(1)?,
294                    r.get::<_, String>(2)?,
295                    r.get::<_, String>(3)?,
296                    r.get::<_, String>(4)?,
297                ))
298            })?
299            .collect::<Result<Vec<_>, _>>()?;
300
301        for (mid, name, ns, mtype, desc) in rows {
302            if !dedup_ids.insert(mid) {
303                continue;
304            }
305            let edge = entity_edge.get(&entity_id);
306            out.push(RelatedMemory {
307                memory_id: mid,
308                name,
309                namespace: ns,
310                memory_type: mtype,
311                description: desc,
312                hop_distance: hop,
313                source_entity: edge.map(|e| e.0.clone()),
314                target_entity: edge.map(|e| e.1.clone()),
315                relation: edge.map(|e| e.2.clone()),
316                weight: edge.map(|e| e.3),
317            });
318            if out.len() >= limit {
319                return Ok(out);
320            }
321        }
322    }
323
324    Ok(out)
325}
326
327fn fetch_neighbours(
328    conn: &Connection,
329    entity_id: i64,
330    namespace: &str,
331    min_weight: f64,
332    relation_filter: Option<&str>,
333) -> Result<Vec<Neighbour>, AppError> {
334    // Follow edges in both directions: source -> target and target -> source so traversal is
335    // undirected, which is how users typically reason about "related" memories.
336    let base_sql = "\
337        SELECT r.target_id, se.name, te.name, r.relation, r.weight
338        FROM relationships r
339        JOIN entities se ON se.id = r.source_id
340        JOIN entities te ON te.id = r.target_id
341        WHERE r.source_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
342
343    let reverse_sql = "\
344        SELECT r.source_id, se.name, te.name, r.relation, r.weight
345        FROM relationships r
346        JOIN entities se ON se.id = r.source_id
347        JOIN entities te ON te.id = r.target_id
348        WHERE r.target_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
349
350    let mut results: Vec<Neighbour> = Vec::new();
351
352    let forward_sql = match relation_filter {
353        Some(_) => format!("{base_sql} AND r.relation = ?4"),
354        None => base_sql.to_string(),
355    };
356    let rev_sql = match relation_filter {
357        Some(_) => format!("{reverse_sql} AND r.relation = ?4"),
358        None => reverse_sql.to_string(),
359    };
360
361    let mut stmt = conn.prepare_cached(&forward_sql)?;
362    let rows: Vec<_> = if let Some(rel) = relation_filter {
363        stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
364            Ok((
365                r.get::<_, i64>(0)?,
366                r.get::<_, String>(1)?,
367                r.get::<_, String>(2)?,
368                r.get::<_, String>(3)?,
369                r.get::<_, f64>(4)?,
370            ))
371        })?
372        .collect::<Result<Vec<_>, _>>()?
373    } else {
374        stmt.query_map(params![entity_id, min_weight, namespace], |r| {
375            Ok((
376                r.get::<_, i64>(0)?,
377                r.get::<_, String>(1)?,
378                r.get::<_, String>(2)?,
379                r.get::<_, String>(3)?,
380                r.get::<_, f64>(4)?,
381            ))
382        })?
383        .collect::<Result<Vec<_>, _>>()?
384    };
385    results.extend(rows);
386
387    let mut stmt = conn.prepare_cached(&rev_sql)?;
388    let rows: Vec<_> = if let Some(rel) = relation_filter {
389        stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
390            Ok((
391                r.get::<_, i64>(0)?,
392                r.get::<_, String>(1)?,
393                r.get::<_, String>(2)?,
394                r.get::<_, String>(3)?,
395                r.get::<_, f64>(4)?,
396            ))
397        })?
398        .collect::<Result<Vec<_>, _>>()?
399    } else {
400        stmt.query_map(params![entity_id, min_weight, namespace], |r| {
401            Ok((
402                r.get::<_, i64>(0)?,
403                r.get::<_, String>(1)?,
404                r.get::<_, String>(2)?,
405                r.get::<_, String>(3)?,
406                r.get::<_, f64>(4)?,
407            ))
408        })?
409        .collect::<Result<Vec<_>, _>>()?
410    };
411    results.extend(rows);
412
413    Ok(results)
414}
415
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an in-memory database with the minimal schema `traverse_related` queries.
    fn setup_related_db() -> rusqlite::Connection {
        let conn = rusqlite::Connection::open_in_memory().expect("failed to open in-memory db");
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                namespace TEXT NOT NULL DEFAULT 'global',
                type TEXT NOT NULL DEFAULT 'fact',
                description TEXT NOT NULL DEFAULT '',
                deleted_at INTEGER
            );
            CREATE TABLE entities (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                name TEXT NOT NULL
            );
            CREATE TABLE relationships (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related_to',
                weight REAL NOT NULL DEFAULT 1.0
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );",
        )
        .expect("failed to create test tables");
        conn
    }

    fn insert_memory(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO memories (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert memory");
        conn.last_insert_rowid()
    }

    /// Marks a memory as deleted so traversal must ignore it.
    fn soft_delete_memory(conn: &rusqlite::Connection, memory_id: i64) {
        conn.execute(
            "UPDATE memories SET deleted_at = 1 WHERE id = ?1",
            rusqlite::params![memory_id],
        )
        .expect("failed to soft-delete memory");
    }

    fn insert_entity(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO entities (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert entity");
        conn.last_insert_rowid()
    }

    fn link_memory_entity(conn: &rusqlite::Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            rusqlite::params![memory_id, entity_id],
        )
        .expect("failed to link memory-entity");
    }

    fn insert_relationship(
        conn: &rusqlite::Connection,
        namespace: &str,
        source_id: i64,
        target_id: i64,
        relation: &str,
        weight: f64,
    ) {
        conn.execute(
            "INSERT INTO relationships (namespace, source_id, target_id, relation, weight)
             VALUES (?1, ?2, ?3, ?4, ?5)",
            rusqlite::params![namespace, source_id, target_id, relation, weight],
        )
        .expect("failed to insert relationship");
    }

    #[test]
    fn related_response_serializes_results_and_elapsed_ms() {
        let resp = RelatedResponse {
            name: "seed-mem".to_string(),
            max_hops: 2,
            results: vec![RelatedMemory {
                memory_id: 1,
                name: "neighbor-mem".to_string(),
                namespace: "global".to_string(),
                memory_type: "document".to_string(),
                description: "desc".to_string(),
                hop_distance: 1,
                source_entity: Some("entity-a".to_string()),
                target_entity: Some("entity-b".to_string()),
                relation: Some("related_to".to_string()),
                weight: Some(0.9),
            }],
            elapsed_ms: 7,
        };

        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["name"], "seed-mem");
        assert_eq!(json["max_hops"], 2);
        assert!(json["results"].is_array());
        assert_eq!(json["results"].as_array().unwrap().len(), 1);
        assert_eq!(json["elapsed_ms"], 7u64);
        assert_eq!(json["results"][0]["type"], "document");
        assert_eq!(json["results"][0]["hop_distance"], 1);
    }

    #[test]
    fn traverse_related_returns_empty_without_seed_entities() {
        let conn = setup_related_db();
        let result = traverse_related(&conn, 1, &[], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(
            result.is_empty(),
            "no seed entities must yield empty results"
        );
    }

    #[test]
    fn traverse_related_returns_empty_with_max_hops_zero() {
        let conn = setup_related_db();
        let mem_id = insert_memory(&conn, "seed-mem", "global");
        let ent_id = insert_entity(&conn, "ent-a", "global");
        link_memory_entity(&conn, mem_id, ent_id);

        let result = traverse_related(&conn, mem_id, &[ent_id], "global", 0, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(result.is_empty(), "max_hops=0 must return empty");
    }

    #[test]
    fn traverse_related_discovers_neighbor_memory_via_graph() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");

        assert_eq!(result.len(), 1, "must find 1 neighboring memory");
        assert_eq!(result[0].name, "neighbor-mem");
        assert_eq!(result[0].hop_distance, 1);
    }

    #[test]
    fn traverse_related_follows_edges_in_reverse_direction() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        // Edge points *into* the seed entity; traversal must still reach ent-b.
        insert_relationship(&conn, "global", ent_b, ent_a, "related_to", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");

        assert_eq!(result.len(), 1, "reverse edge must be traversed");
        assert_eq!(result[0].name, "neighbor-mem");
        // Edge names keep the stored direction.
        assert_eq!(result[0].source_entity.as_deref(), Some("ent-b"));
        assert_eq!(result[0].target_entity.as_deref(), Some("ent-a"));
    }

    #[test]
    fn traverse_related_filters_by_min_weight_and_relation() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 0.3);

        let too_light = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.5, None, 10)
            .expect("traverse_related failed");
        assert!(too_light.is_empty(), "edges below min_weight must be skipped");

        let wrong_relation =
            traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, Some("mentions"), 10)
                .expect("traverse_related failed");
        assert!(
            wrong_relation.is_empty(),
            "edges with another relation must be skipped"
        );

        let matching =
            traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, Some("related_to"), 10)
                .expect("traverse_related failed");
        assert_eq!(matching.len(), 1);
        assert_eq!(matching[0].relation.as_deref(), Some("related_to"));
        assert_eq!(matching[0].weight, Some(0.3));
    }

    #[test]
    fn traverse_related_reports_two_hop_distance() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let far_id = insert_memory(&conn, "far-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        let ent_c = insert_entity(&conn, "ent-c", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, far_id, ent_c);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);
        insert_relationship(&conn, "global", ent_b, ent_c, "related_to", 1.0);

        let one_hop = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(one_hop.is_empty(), "ent-c is out of reach with max_hops=1");

        let two_hops = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert_eq!(two_hops.len(), 1);
        assert_eq!(two_hops[0].name, "far-mem");
        assert_eq!(two_hops[0].hop_distance, 2);
        assert_eq!(two_hops[0].source_entity.as_deref(), Some("ent-b"));
        assert_eq!(two_hops[0].target_entity.as_deref(), Some("ent-c"));
    }

    #[test]
    fn traverse_related_skips_seed_and_deleted_memories() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let deleted_id = insert_memory(&conn, "deleted-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        // The seed is also linked to the neighbour entity and must not be reported.
        link_memory_entity(&conn, seed_id, ent_b);
        link_memory_entity(&conn, neighbor_id, ent_b);
        link_memory_entity(&conn, deleted_id, ent_b);
        soft_delete_memory(&conn, deleted_id);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");

        let names: Vec<&str> = result.iter().map(|m| m.name.as_str()).collect();
        assert_eq!(names, vec!["neighbor-mem"]);
    }

    #[test]
    fn traverse_related_ignores_edges_in_other_namespaces() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "other", ent_a, ent_b, "related_to", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(result.is_empty(), "edges from another namespace must be ignored");
    }

    #[test]
    fn traverse_related_supports_bare_entity_seed() {
        let conn = setup_related_db();

        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);

        // `run` passes -1 as the seed memory id when the seed is a bare entity.
        let result = traverse_related(&conn, -1, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].name, "neighbor-mem");
    }

    #[test]
    fn traverse_related_respects_limit() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);

        for i in 0..5 {
            let mem_id = insert_memory(&conn, &format!("neighbor-{i}"), "global");
            let ent_id = insert_entity(&conn, &format!("ent-{i}"), "global");
            link_memory_entity(&conn, mem_id, ent_id);
            insert_relationship(&conn, "global", ent_seed, ent_id, "related_to", 1.0);
        }

        let result = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.0, None, 3)
            .expect("traverse_related failed");

        // Five candidates exist, so the limit must be hit exactly.
        assert_eq!(result.len(), 3, "limit=3 must yield exactly 3 of 5 candidates");
    }

    #[test]
    fn related_memory_optional_null_fields_serialized() {
        let mem = RelatedMemory {
            memory_id: 99,
            name: "no-relation".to_string(),
            namespace: "ns".to_string(),
            memory_type: "concept".to_string(),
            description: "".to_string(),
            hop_distance: 2,
            source_entity: None,
            target_entity: None,
            relation: None,
            weight: None,
        };

        let json = serde_json::to_value(&mem).expect("serialization failed");
        assert!(json["source_entity"].is_null());
        assert!(json["target_entity"].is_null());
        assert!(json["relation"].is_null());
        assert!(json["weight"].is_null());
        assert_eq!(json["hop_distance"], 2);
    }
}