1use crate::constants::{
4 DEFAULT_K_RECALL, DEFAULT_MAX_HOPS, DEFAULT_MIN_WEIGHT, TEXT_DESCRIPTION_PREVIEW_LEN,
5};
6use crate::errors::AppError;
7use crate::i18n::errors_msg;
8use crate::output::{self, OutputFormat};
9use crate::paths::AppPaths;
10use crate::storage::connection::open_ro;
11use rusqlite::{params, Connection};
12use serde::Serialize;
13use std::collections::{HashSet, VecDeque};
14
/// What the user-supplied name resolved to when looking up the traversal seed.
///
/// `run` first looks for a non-deleted memory with the given name and only
/// falls back to an entity lookup when no such memory exists.
enum SeedKind {
    /// Row id of the matching row in the `memories` table.
    Memory(i64),
    /// Entity id returned by `crate::storage::entities::find_entity_id`.
    Entity(i64),
}
20
/// One edge as returned by `fetch_neighbours`, in column order:
/// `(neighbour entity id, source entity name, target entity name, relation, weight)`.
///
/// The neighbour id is the edge endpoint opposite to the entity being
/// expanded; the two names always follow the stored source/target direction.
type Neighbour = (i64, String, String, String, f64);
24
// Command-line arguments of the `related` subcommand.
//
// NOTE(review): plain `//` comments are used on purpose throughout this
// struct; with clap's derive, `///` doc comments are turned into help text
// and would change what the CLI prints.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n \
    # List memories connected to a memory via the entity graph (default 2 hops)\n \
    sqlite-graphrag related onboarding\n\n \
    # Increase hop distance and filter by relation type\n \
    sqlite-graphrag related onboarding --max-hops 3 --relation related\n\n \
    # Cap result count and require minimum edge weight\n \
    sqlite-graphrag related onboarding --limit 5 --min-weight 0.5")]
pub struct RelatedArgs {
    // Seed name given positionally; declared as conflicting with `--name`.
    #[arg(
        value_name = "NAME",
        conflicts_with = "name",
        help = "Memory name whose neighbours to traverse; alternative to --name"
    )]
    pub name_positional: Option<String>,
    // Seed name given as `--name` (also accepted as `--from`).
    #[arg(long, alias = "from")]
    pub name: Option<String>,
    // Maximum hop distance explored from the seed entities (`--max-hops`,
    // alias `--hops`). A value of 0 makes the traversal return nothing.
    #[arg(long, alias = "hops", default_value_t = DEFAULT_MAX_HOPS)]
    pub max_hops: u32,
    // Optional relation filter; when set, only edges whose `relation` column
    // equals this value are followed. Parsed by `crate::parsers::parse_relation`.
    #[arg(long, value_parser = crate::parsers::parse_relation)]
    pub relation: Option<String>,
    // Only edges with `weight >= min_weight` are followed.
    #[arg(long, default_value_t = DEFAULT_MIN_WEIGHT)]
    pub min_weight: f64,
    // Upper bound on the number of related memories requested from the traversal.
    #[arg(long, default_value_t = DEFAULT_K_RECALL)]
    pub limit: usize,
    // Namespace override; resolved through `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // Output format; defaults to `json`.
    #[arg(long, value_enum, default_value = "json")]
    pub format: OutputFormat,
    // Hidden flag that `run` never reads.
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    // Database path override (`--db` or the SQLITE_GRAPHRAG_DB_PATH environment
    // variable); passed to `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
66
/// JSON payload emitted by `run` when the output format is JSON.
#[derive(Serialize)]
struct RelatedResponse {
    /// Seed name exactly as supplied on the command line.
    name: String,
    /// Hop limit that was applied to the traversal.
    max_hops: u32,
    /// Related memories in traversal order.
    results: Vec<RelatedMemory>,
    /// Second copy of the same list; `run` fills it with a clone of `results`.
    /// NOTE(review): presumably kept as an alias for older consumers — confirm.
    related_memories: Vec<RelatedMemory>,
    /// Wall-clock time since `run` started, in milliseconds.
    elapsed_ms: u64,
}
79
/// A memory reached through the entity graph, together with the edge that
/// led to the entity it is attached to.
#[derive(Serialize, Clone)]
struct RelatedMemory {
    /// Row id of the memory.
    memory_id: i64,
    /// Memory name.
    name: String,
    /// Namespace stored on the memory row.
    namespace: String,
    /// Value of the memory's `type` column; serialized under the key `type`.
    #[serde(rename = "type")]
    memory_type: String,
    /// Memory description (may be empty).
    description: String,
    /// Hop distance of the entity through which this memory was found.
    hop_distance: u32,
    /// Name of the source entity of the recorded edge; serialized as `null` when absent.
    source_entity: Option<String>,
    /// Name of the target entity of the recorded edge; serialized as `null` when absent.
    target_entity: Option<String>,
    /// Same value as `source_entity`, but omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    from: Option<String>,
    /// Same value as `target_entity`, but omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    to: Option<String>,
    /// Relation label of the recorded edge.
    relation: Option<String>,
    /// Weight of the recorded edge.
    weight: Option<f64>,
}
100
101pub fn run(args: RelatedArgs) -> Result<(), AppError> {
102 let inicio = std::time::Instant::now();
103 let name = args
104 .name_positional
105 .as_deref()
106 .or(args.name.as_deref())
107 .ok_or_else(|| {
108 AppError::Validation(
109 "name required: pass as positional argument or via --name".to_string(),
110 )
111 })?
112 .to_string();
113
114 if name.trim().is_empty() {
115 return Err(AppError::Validation("name must not be empty".to_string()));
116 }
117
118 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
119 let paths = AppPaths::resolve(args.db.as_deref())?;
120
121 crate::storage::connection::ensure_db_ready(&paths)?;
122
123 let conn = open_ro(&paths.db)?;
124
125 let seed = match conn.query_row(
127 "SELECT id FROM memories WHERE namespace = ?1 AND name = ?2 AND deleted_at IS NULL",
128 params![namespace, name],
129 |r| r.get::<_, i64>(0),
130 ) {
131 Ok(id) => SeedKind::Memory(id),
132 Err(rusqlite::Error::QueryReturnedNoRows) => {
133 match crate::storage::entities::find_entity_id(&conn, &namespace, &name)? {
134 Some(id) => SeedKind::Entity(id),
135 None => {
136 return Err(AppError::NotFound(errors_msg::memory_or_entity_not_found(
137 &name, &namespace,
138 )))
139 }
140 }
141 }
142 Err(e) => return Err(AppError::Database(e)),
143 };
144
145 let (seed_memory_id, seed_entity_ids): (i64, Vec<i64>) = match &seed {
147 SeedKind::Memory(id) => {
148 let mem_id = *id;
149 let mut stmt =
150 conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
151 let rows: Vec<i64> = stmt
152 .query_map(params![mem_id], |r| r.get(0))?
153 .collect::<Result<Vec<i64>, _>>()?;
154 (mem_id, rows)
155 }
156 SeedKind::Entity(entity_id) => {
157 (-1, vec![*entity_id])
160 }
161 };
162
163 let relation_filter = args.relation;
164 if let Some(ref r) = relation_filter {
165 crate::parsers::warn_if_non_canonical(r);
166 }
167 let results = traverse_related(
168 &conn,
169 seed_memory_id,
170 &seed_entity_ids,
171 &namespace,
172 args.max_hops,
173 args.min_weight,
174 relation_filter.as_deref(),
175 args.limit,
176 )?;
177
178 match args.format {
179 OutputFormat::Json => {
180 let related_memories = results.clone();
181 output::emit_json(&RelatedResponse {
182 name: name.clone(),
183 max_hops: args.max_hops,
184 results,
185 related_memories,
186 elapsed_ms: inicio.elapsed().as_millis() as u64,
187 })?;
188 }
189 OutputFormat::Text => {
190 for item in &results {
191 if item.description.is_empty() {
192 output::emit_text(&format!(
193 "{}. {} ({})",
194 item.hop_distance, item.name, item.namespace
195 ));
196 } else {
197 let preview: String = item
198 .description
199 .chars()
200 .take(TEXT_DESCRIPTION_PREVIEW_LEN)
201 .collect();
202 output::emit_text(&format!(
203 "{}. {} ({}): {}",
204 item.hop_distance, item.name, item.namespace, preview
205 ));
206 }
207 }
208 }
209 OutputFormat::Markdown => {
210 for item in &results {
211 if item.description.is_empty() {
212 output::emit_text(&format!(
213 "- **{}** ({}) — hop {}",
214 item.name, item.namespace, item.hop_distance
215 ));
216 } else {
217 let preview: String = item
218 .description
219 .chars()
220 .take(TEXT_DESCRIPTION_PREVIEW_LEN)
221 .collect();
222 output::emit_text(&format!(
223 "- **{}** ({}) — hop {}: {}",
224 item.name, item.namespace, item.hop_distance, preview
225 ));
226 }
227 }
228 }
229 }
230
231 Ok(())
232}
233
234#[allow(clippy::too_many_arguments)]
235fn traverse_related(
236 conn: &Connection,
237 seed_memory_id: i64,
238 seed_entity_ids: &[i64],
239 namespace: &str,
240 max_hops: u32,
241 min_weight: f64,
242 relation_filter: Option<&str>,
243 limit: usize,
244) -> Result<Vec<RelatedMemory>, AppError> {
245 if seed_entity_ids.is_empty() || max_hops == 0 {
246 return Ok(Vec::new());
247 }
248
249 let mut visited: HashSet<i64> = seed_entity_ids.iter().copied().collect();
252 let mut entity_hop: crate::hash::AHashMap<i64, u32> =
253 crate::hash::AHashMap::with_capacity_and_hasher(max_hops as usize * 10, Default::default());
254 for &e in seed_entity_ids {
255 entity_hop.insert(e, 0);
256 }
257 let mut entity_edge: crate::hash::AHashMap<i64, (String, String, String, f64)> =
260 crate::hash::AHashMap::with_capacity_and_hasher(max_hops as usize * 10, Default::default());
261
262 let mut queue: VecDeque<i64> = seed_entity_ids.iter().copied().collect();
263
264 while let Some(current_entity) = queue.pop_front() {
265 let current_hop = *entity_hop.get(¤t_entity).unwrap_or(&0);
266 if current_hop >= max_hops {
267 continue;
268 }
269
270 let neighbours =
271 fetch_neighbours(conn, current_entity, namespace, min_weight, relation_filter)?;
272
273 for (neighbour_id, source_name, target_name, relation, weight) in neighbours {
274 if visited.insert(neighbour_id) {
275 entity_hop.insert(neighbour_id, current_hop + 1);
276 entity_edge.insert(neighbour_id, (source_name, target_name, relation, weight));
277 queue.push_back(neighbour_id);
278 }
279 }
280 }
281
282 let mut out: Vec<RelatedMemory> = Vec::with_capacity(limit);
284 let mut dedup_ids: crate::hash::AHashSet<i64> =
285 crate::hash::AHashSet::with_capacity_and_hasher(limit, Default::default());
286 dedup_ids.insert(seed_memory_id);
287
288 let mut ordered_entities: Vec<(i64, u32)> = entity_hop
290 .iter()
291 .filter(|(id, _)| !seed_entity_ids.contains(id))
292 .map(|(id, hop)| (*id, *hop))
293 .collect();
294 ordered_entities.sort_by(|a, b| {
295 let weight_a = entity_edge.get(&a.0).map(|e| e.3).unwrap_or(0.0);
296 let weight_b = entity_edge.get(&b.0).map(|e| e.3).unwrap_or(0.0);
297 a.1.cmp(&b.1).then_with(|| {
298 weight_b
299 .partial_cmp(&weight_a)
300 .unwrap_or(std::cmp::Ordering::Equal)
301 })
302 });
303
304 for (entity_id, hop) in ordered_entities {
305 let mut stmt = conn.prepare_cached(
306 "SELECT m.id, m.name, m.namespace, m.type, m.description
307 FROM memory_entities me
308 JOIN memories m ON m.id = me.memory_id
309 WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
310 )?;
311 let rows = stmt
312 .query_map(params![entity_id], |r| {
313 Ok((
314 r.get::<_, i64>(0)?,
315 r.get::<_, String>(1)?,
316 r.get::<_, String>(2)?,
317 r.get::<_, String>(3)?,
318 r.get::<_, String>(4)?,
319 ))
320 })?
321 .collect::<Result<Vec<_>, _>>()?;
322
323 for (mid, name, ns, mtype, desc) in rows {
324 if !dedup_ids.insert(mid) {
325 continue;
326 }
327 let edge = entity_edge.get(&entity_id);
328 let src = edge.map(|e| e.0.clone());
329 let tgt = edge.map(|e| e.1.clone());
330 out.push(RelatedMemory {
331 memory_id: mid,
332 name,
333 namespace: ns,
334 memory_type: mtype,
335 description: desc,
336 hop_distance: hop,
337 source_entity: src.clone(),
338 target_entity: tgt.clone(),
339 from: src,
340 to: tgt,
341 relation: edge.map(|e| e.2.clone()),
342 weight: edge.map(|e| e.3),
343 });
344 if out.len() >= limit {
345 return Ok(out);
346 }
347 }
348 }
349
350 Ok(out)
351}
352
353fn fetch_neighbours(
354 conn: &Connection,
355 entity_id: i64,
356 namespace: &str,
357 min_weight: f64,
358 relation_filter: Option<&str>,
359) -> Result<Vec<Neighbour>, AppError> {
360 let base_sql = "\
363 SELECT r.target_id, se.name, te.name, r.relation, r.weight
364 FROM relationships r
365 JOIN entities se ON se.id = r.source_id
366 JOIN entities te ON te.id = r.target_id
367 WHERE r.source_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
368
369 let reverse_sql = "\
370 SELECT r.source_id, se.name, te.name, r.relation, r.weight
371 FROM relationships r
372 JOIN entities se ON se.id = r.source_id
373 JOIN entities te ON te.id = r.target_id
374 WHERE r.target_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
375
376 let mut results: Vec<Neighbour> = Vec::with_capacity(16);
377
378 let forward_sql = match relation_filter {
379 Some(_) => format!("{base_sql} AND r.relation = ?4"),
380 None => base_sql.to_string(),
381 };
382 let rev_sql = match relation_filter {
383 Some(_) => format!("{reverse_sql} AND r.relation = ?4"),
384 None => reverse_sql.to_string(),
385 };
386
387 let mut stmt = conn.prepare_cached(&forward_sql)?;
388 let rows: Vec<_> = if let Some(rel) = relation_filter {
389 stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
390 Ok((
391 r.get::<_, i64>(0)?,
392 r.get::<_, String>(1)?,
393 r.get::<_, String>(2)?,
394 r.get::<_, String>(3)?,
395 r.get::<_, f64>(4)?,
396 ))
397 })?
398 .collect::<Result<Vec<_>, _>>()?
399 } else {
400 stmt.query_map(params![entity_id, min_weight, namespace], |r| {
401 Ok((
402 r.get::<_, i64>(0)?,
403 r.get::<_, String>(1)?,
404 r.get::<_, String>(2)?,
405 r.get::<_, String>(3)?,
406 r.get::<_, f64>(4)?,
407 ))
408 })?
409 .collect::<Result<Vec<_>, _>>()?
410 };
411 results.extend(rows);
412
413 let mut stmt = conn.prepare_cached(&rev_sql)?;
414 let rows: Vec<_> = if let Some(rel) = relation_filter {
415 stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
416 Ok((
417 r.get::<_, i64>(0)?,
418 r.get::<_, String>(1)?,
419 r.get::<_, String>(2)?,
420 r.get::<_, String>(3)?,
421 r.get::<_, f64>(4)?,
422 ))
423 })?
424 .collect::<Result<Vec<_>, _>>()?
425 } else {
426 stmt.query_map(params![entity_id, min_weight, namespace], |r| {
427 Ok((
428 r.get::<_, i64>(0)?,
429 r.get::<_, String>(1)?,
430 r.get::<_, String>(2)?,
431 r.get::<_, String>(3)?,
432 r.get::<_, f64>(4)?,
433 ))
434 })?
435 .collect::<Result<Vec<_>, _>>()?
436 };
437 results.extend(rows);
438
439 Ok(results)
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an in-memory database holding only the tables and columns the
    /// traversal queries touch.
    fn setup_related_db() -> rusqlite::Connection {
        let conn = rusqlite::Connection::open_in_memory().expect("failed to open in-memory db");
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                namespace TEXT NOT NULL DEFAULT 'global',
                type TEXT NOT NULL DEFAULT 'fact',
                description TEXT NOT NULL DEFAULT '',
                deleted_at INTEGER
            );
            CREATE TABLE entities (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                name TEXT NOT NULL
            );
            CREATE TABLE relationships (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related_to',
                weight REAL NOT NULL DEFAULT 1.0
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );",
        )
        .expect("failed to create test tables");
        conn
    }

    /// Inserts a memory row and returns its id.
    fn insert_memory(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO memories (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert memory");
        conn.last_insert_rowid()
    }

    /// Inserts an entity row and returns its id.
    fn insert_entity(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO entities (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert entity");
        conn.last_insert_rowid()
    }

    /// Attaches an entity to a memory.
    fn link_memory_entity(conn: &rusqlite::Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            rusqlite::params![memory_id, entity_id],
        )
        .expect("failed to link memory-entity");
    }

    /// Inserts a directed edge `source_id -> target_id`.
    fn insert_relationship(
        conn: &rusqlite::Connection,
        namespace: &str,
        source_id: i64,
        target_id: i64,
        relation: &str,
        weight: f64,
    ) {
        conn.execute(
            "INSERT INTO relationships (namespace, source_id, target_id, relation, weight)
             VALUES (?1, ?2, ?3, ?4, ?5)",
            rusqlite::params![namespace, source_id, target_id, relation, weight],
        )
        .expect("failed to insert relationship");
    }

    #[test]
    fn related_response_serializes_results_and_elapsed_ms() {
        let mem = RelatedMemory {
            memory_id: 1,
            name: "neighbor-mem".to_string(),
            namespace: "global".to_string(),
            memory_type: "document".to_string(),
            description: "desc".to_string(),
            hop_distance: 1,
            source_entity: Some("entity-a".to_string()),
            target_entity: Some("entity-b".to_string()),
            from: Some("entity-a".to_string()),
            to: Some("entity-b".to_string()),
            relation: Some("related_to".to_string()),
            weight: Some(0.9),
        };
        let resp = RelatedResponse {
            name: "seed-mem".to_string(),
            max_hops: 2,
            related_memories: vec![mem.clone()],
            results: vec![mem],
            elapsed_ms: 7,
        };

        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert!(json["results"].is_array());
        assert_eq!(json["results"].as_array().unwrap().len(), 1);
        assert_eq!(json["elapsed_ms"], 7u64);
        assert_eq!(json["results"][0]["type"], "document");
        assert_eq!(json["results"][0]["hop_distance"], 1);
        // The list is also exposed under `related_memories`, and the
        // `from`/`to` keys are present whenever they carry a value.
        assert_eq!(json["related_memories"], json["results"]);
        assert_eq!(json["results"][0]["from"], "entity-a");
        assert_eq!(json["results"][0]["to"], "entity-b");
    }

    #[test]
    fn traverse_related_returns_empty_without_seed_entities() {
        let conn = setup_related_db();
        let result = traverse_related(&conn, 1, &[], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(
            result.is_empty(),
            "no seed entities must yield empty results"
        );
    }

    #[test]
    fn traverse_related_returns_empty_with_max_hops_zero() {
        let conn = setup_related_db();
        let mem_id = insert_memory(&conn, "seed-mem", "global");
        let ent_id = insert_entity(&conn, "ent-a", "global");
        link_memory_entity(&conn, mem_id, ent_id);

        let result = traverse_related(&conn, mem_id, &[ent_id], "global", 0, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(result.is_empty(), "max_hops=0 must return empty");
    }

    #[test]
    fn traverse_related_discovers_neighbor_memory_via_graph() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");

        assert_eq!(result.len(), 1, "must find 1 neighboring memory");
        assert_eq!(result[0].name, "neighbor-mem");
        assert_eq!(result[0].hop_distance, 1);
        assert_eq!(result[0].relation.as_deref(), Some("related_to"));
        assert_eq!(result[0].weight, Some(1.0));
    }

    #[test]
    fn traverse_related_follows_edges_in_reverse_direction() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        // The edge points *at* the seed entity rather than away from it.
        insert_relationship(&conn, "global", ent_b, ent_a, "related_to", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");

        assert_eq!(result.len(), 1, "incoming edges must be traversed too");
        assert_eq!(result[0].memory_id, neighbor_id);
        // Names keep the stored orientation of the edge.
        assert_eq!(result[0].source_entity.as_deref(), Some("ent-b"));
        assert_eq!(result[0].target_entity.as_deref(), Some("ent-a"));
    }

    #[test]
    fn traverse_related_skips_soft_deleted_memories() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);
        conn.execute(
            "UPDATE memories SET deleted_at = 1 WHERE id = ?1",
            rusqlite::params![neighbor_id],
        )
        .expect("failed to soft-delete memory");

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");

        assert!(
            result.is_empty(),
            "memories with deleted_at set must not be returned"
        );
    }

    #[test]
    fn traverse_related_honours_min_weight() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 0.2);

        let filtered = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.5, None, 10)
            .expect("traverse_related failed");
        assert!(
            filtered.is_empty(),
            "edges below min_weight must not be followed"
        );

        let kept = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.2, None, 10)
            .expect("traverse_related failed");
        assert_eq!(kept.len(), 1, "edges at exactly min_weight are followed");
    }

    #[test]
    fn traverse_related_respects_limit() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);

        for i in 0..5 {
            let mem_id = insert_memory(&conn, &format!("neighbor-{i}"), "global");
            let ent_id = insert_entity(&conn, &format!("ent-{i}"), "global");
            link_memory_entity(&conn, mem_id, ent_id);
            insert_relationship(&conn, "global", ent_seed, ent_id, "related_to", 1.0);
        }

        let result = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.0, None, 3)
            .expect("traverse_related failed");

        // Five candidates exist, so the limit must be hit exactly; a mere
        // `<= 3` check would also pass on an empty result.
        assert_eq!(
            result.len(),
            3,
            "limit=3 must return exactly 3 of the 5 candidates"
        );
    }

    #[test]
    fn related_memory_optional_null_fields_serialized() {
        let mem = RelatedMemory {
            memory_id: 99,
            name: "no-relation".to_string(),
            namespace: "ns".to_string(),
            memory_type: "concept".to_string(),
            description: "".to_string(),
            hop_distance: 2,
            source_entity: None,
            target_entity: None,
            from: None,
            to: None,
            relation: None,
            weight: None,
        };

        let json = serde_json::to_value(&mem).expect("serialization failed");
        assert!(json["source_entity"].is_null());
        assert!(json["target_entity"].is_null());
        assert!(json["relation"].is_null());
        assert!(json["weight"].is_null());
        assert_eq!(json["hop_distance"], 2);
        // `from` / `to` are skipped entirely instead of being emitted as null.
        assert!(json.get("from").is_none());
        assert!(json.get("to").is_none());
    }
}