1use crate::cli::RelationKind;
4use crate::constants::{
5 DEFAULT_K_RECALL, DEFAULT_MAX_HOPS, DEFAULT_MIN_WEIGHT, TEXT_DESCRIPTION_PREVIEW_LEN,
6};
7use crate::errors::AppError;
8use crate::i18n::errors_msg;
9use crate::output::{self, OutputFormat};
10use crate::paths::AppPaths;
11use crate::storage::connection::open_ro;
12use rusqlite::{params, Connection};
13use serde::Serialize;
14use std::collections::{HashMap, HashSet, VecDeque};
15
/// One edge incident to an entity, as produced by `fetch_neighbours`:
/// `(neighbour_entity_id, source_entity_name, target_entity_name, relation, weight)`.
///
/// The two names always follow the edge's stored direction
/// (`relationships.source_id` -> `relationships.target_id`), whichever endpoint
/// the neighbour happens to be.
type Neighbour = (i64, String, String, String, f64);
19
// CLI arguments for the `related` subcommand.
//
// NOTE: plain `//` comments are used on this struct and its fields on purpose:
// clap's derive turns `///` doc comments into help text, so adding them would
// change the `--help` output.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n \
    # List memories connected to a memory via the entity graph (default 2 hops)\n \
    sqlite-graphrag related onboarding\n\n \
    # Increase hop distance and filter by relation type\n \
    sqlite-graphrag related onboarding --max-hops 3 --relation related\n\n \
    # Cap result count and require minimum edge weight\n \
    sqlite-graphrag related onboarding --limit 5 --min-weight 0.5")]
pub struct RelatedArgs {
    // Seed memory name given positionally. `run` prefers it over `--name`, and
    // `conflicts_with` makes clap reject supplying both.
    #[arg(
        value_name = "NAME",
        conflicts_with = "name",
        help = "Memory name whose neighbours to traverse; alternative to --name"
    )]
    pub name_positional: Option<String>,
    // Seed memory name via `--name`; `run` only uses it when no positional name is given.
    #[arg(long)]
    pub name: Option<String>,
    // Maximum number of relationship hops walked from the seed memory's
    // entities (`--hops` is an alias); 0 yields no results.
    #[arg(long, alias = "hops", default_value_t = DEFAULT_MAX_HOPS)]
    pub max_hops: u32,
    // When set, only relationships whose `relation` column equals this kind are followed.
    #[arg(long, value_enum)]
    pub relation: Option<RelationKind>,
    // Relationships with a `weight` below this value are not followed.
    #[arg(long, default_value_t = DEFAULT_MIN_WEIGHT)]
    pub min_weight: f64,
    // Maximum number of related memories reported.
    #[arg(long, default_value_t = DEFAULT_K_RECALL)]
    pub limit: usize,
    // Namespace to search; passed through `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // Output format; JSON unless overridden (text and markdown are also rendered by `run`).
    #[arg(long, value_enum, default_value = "json")]
    pub format: OutputFormat,
    // Hidden flag that `run` never reads; the output is selected by `--format` alone.
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    // Database path override (can also come from SQLITE_GRAPHRAG_DB_PATH);
    // passed to `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
57
/// JSON payload printed by `related` when `--format json` is selected.
///
/// Field order and names form the JSON output contract.
#[derive(Serialize)]
struct RelatedResponse {
    /// Seed memory name exactly as supplied on the command line.
    name: String,
    /// The `--max-hops` value used for the traversal.
    max_hops: u32,
    /// Related memories, nearest hop first.
    results: Vec<RelatedMemory>,
    /// Milliseconds elapsed since `run` started, measured just before output.
    elapsed_ms: u64,
}
68
/// One memory reached through the entity graph.
///
/// Field order and names form the JSON output contract. The `Option` fields
/// serialize as JSON `null` when `None`.
#[derive(Serialize, Clone)]
struct RelatedMemory {
    /// `memories.id` of the related memory.
    memory_id: i64,
    name: String,
    namespace: String,
    /// Memory type; emitted under the JSON key `type`.
    #[serde(rename = "type")]
    memory_type: String,
    description: String,
    /// Number of relationship hops from the seed memory's entities to the
    /// entity linking this memory.
    hop_distance: u32,
    /// Source entity name of the edge that first reached the linking entity.
    source_entity: Option<String>,
    /// Target entity name of that same edge.
    target_entity: Option<String>,
    /// Relation label of that edge.
    relation: Option<String>,
    /// Weight of that edge.
    weight: Option<f64>,
}
83
84pub fn run(args: RelatedArgs) -> Result<(), AppError> {
85 let inicio = std::time::Instant::now();
86 let name = args
87 .name_positional
88 .as_deref()
89 .or(args.name.as_deref())
90 .ok_or_else(|| {
91 AppError::Validation(
92 "name required: pass as positional argument or via --name".to_string(),
93 )
94 })?
95 .to_string();
96
97 if name.trim().is_empty() {
98 return Err(AppError::Validation("name must not be empty".to_string()));
99 }
100
101 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
102 let paths = AppPaths::resolve(args.db.as_deref())?;
103
104 crate::storage::connection::ensure_db_ready(&paths)?;
105
106 let conn = open_ro(&paths.db)?;
107
108 let seed_id: i64 = match conn.query_row(
110 "SELECT id FROM memories
111 WHERE namespace = ?1 AND name = ?2 AND deleted_at IS NULL",
112 params![namespace, name],
113 |r| r.get(0),
114 ) {
115 Ok(id) => id,
116 Err(rusqlite::Error::QueryReturnedNoRows) => {
117 return Err(AppError::NotFound(errors_msg::memory_not_found(
118 &name, &namespace,
119 )));
120 }
121 Err(e) => return Err(AppError::Database(e)),
122 };
123
124 let seed_entity_ids: Vec<i64> = {
126 let mut stmt =
127 conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
128 let rows: Vec<i64> = stmt
129 .query_map(params![seed_id], |r| r.get(0))?
130 .collect::<Result<Vec<i64>, _>>()?;
131 rows
132 };
133
134 let relation_filter = args.relation.map(|r| r.as_str().to_string());
135 let results = traverse_related(
136 &conn,
137 seed_id,
138 &seed_entity_ids,
139 &namespace,
140 args.max_hops,
141 args.min_weight,
142 relation_filter.as_deref(),
143 args.limit,
144 )?;
145
146 match args.format {
147 OutputFormat::Json => output::emit_json(&RelatedResponse {
148 name: name.clone(),
149 max_hops: args.max_hops,
150 results,
151 elapsed_ms: inicio.elapsed().as_millis() as u64,
152 })?,
153 OutputFormat::Text => {
154 for item in &results {
155 if item.description.is_empty() {
156 output::emit_text(&format!(
157 "{}. {} ({})",
158 item.hop_distance, item.name, item.namespace
159 ));
160 } else {
161 let preview: String = item
162 .description
163 .chars()
164 .take(TEXT_DESCRIPTION_PREVIEW_LEN)
165 .collect();
166 output::emit_text(&format!(
167 "{}. {} ({}): {}",
168 item.hop_distance, item.name, item.namespace, preview
169 ));
170 }
171 }
172 }
173 OutputFormat::Markdown => {
174 for item in &results {
175 if item.description.is_empty() {
176 output::emit_text(&format!(
177 "- **{}** ({}) — hop {}",
178 item.name, item.namespace, item.hop_distance
179 ));
180 } else {
181 let preview: String = item
182 .description
183 .chars()
184 .take(TEXT_DESCRIPTION_PREVIEW_LEN)
185 .collect();
186 output::emit_text(&format!(
187 "- **{}** ({}) — hop {}: {}",
188 item.name, item.namespace, item.hop_distance, preview
189 ));
190 }
191 }
192 }
193 }
194
195 Ok(())
196}
197
198#[allow(clippy::too_many_arguments)]
199fn traverse_related(
200 conn: &Connection,
201 seed_memory_id: i64,
202 seed_entity_ids: &[i64],
203 namespace: &str,
204 max_hops: u32,
205 min_weight: f64,
206 relation_filter: Option<&str>,
207 limit: usize,
208) -> Result<Vec<RelatedMemory>, AppError> {
209 if seed_entity_ids.is_empty() || max_hops == 0 {
210 return Ok(Vec::new());
211 }
212
213 let mut visited: HashSet<i64> = seed_entity_ids.iter().copied().collect();
216 let mut entity_hop: HashMap<i64, u32> = HashMap::new();
217 for &e in seed_entity_ids {
218 entity_hop.insert(e, 0);
219 }
220 let mut entity_edge: HashMap<i64, (String, String, String, f64)> = HashMap::new();
223
224 let mut queue: VecDeque<i64> = seed_entity_ids.iter().copied().collect();
225
226 while let Some(current_entity) = queue.pop_front() {
227 let current_hop = *entity_hop.get(¤t_entity).unwrap_or(&0);
228 if current_hop >= max_hops {
229 continue;
230 }
231
232 let neighbours =
233 fetch_neighbours(conn, current_entity, namespace, min_weight, relation_filter)?;
234
235 for (neighbour_id, source_name, target_name, relation, weight) in neighbours {
236 if visited.insert(neighbour_id) {
237 entity_hop.insert(neighbour_id, current_hop + 1);
238 entity_edge.insert(neighbour_id, (source_name, target_name, relation, weight));
239 queue.push_back(neighbour_id);
240 }
241 }
242 }
243
244 let mut out: Vec<RelatedMemory> = Vec::new();
246 let mut dedup_ids: HashSet<i64> = HashSet::new();
247 dedup_ids.insert(seed_memory_id);
248
249 let mut ordered_entities: Vec<(i64, u32)> = entity_hop
251 .iter()
252 .filter(|(id, _)| !seed_entity_ids.contains(id))
253 .map(|(id, hop)| (*id, *hop))
254 .collect();
255 ordered_entities.sort_by(|a, b| {
256 let weight_a = entity_edge.get(&a.0).map(|e| e.3).unwrap_or(0.0);
257 let weight_b = entity_edge.get(&b.0).map(|e| e.3).unwrap_or(0.0);
258 a.1.cmp(&b.1).then_with(|| {
259 weight_b
260 .partial_cmp(&weight_a)
261 .unwrap_or(std::cmp::Ordering::Equal)
262 })
263 });
264
265 for (entity_id, hop) in ordered_entities {
266 let mut stmt = conn.prepare_cached(
267 "SELECT m.id, m.name, m.namespace, m.type, m.description
268 FROM memory_entities me
269 JOIN memories m ON m.id = me.memory_id
270 WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
271 )?;
272 let rows = stmt
273 .query_map(params![entity_id], |r| {
274 Ok((
275 r.get::<_, i64>(0)?,
276 r.get::<_, String>(1)?,
277 r.get::<_, String>(2)?,
278 r.get::<_, String>(3)?,
279 r.get::<_, String>(4)?,
280 ))
281 })?
282 .collect::<Result<Vec<_>, _>>()?;
283
284 for (mid, name, ns, mtype, desc) in rows {
285 if !dedup_ids.insert(mid) {
286 continue;
287 }
288 let edge = entity_edge.get(&entity_id);
289 out.push(RelatedMemory {
290 memory_id: mid,
291 name,
292 namespace: ns,
293 memory_type: mtype,
294 description: desc,
295 hop_distance: hop,
296 source_entity: edge.map(|e| e.0.clone()),
297 target_entity: edge.map(|e| e.1.clone()),
298 relation: edge.map(|e| e.2.clone()),
299 weight: edge.map(|e| e.3),
300 });
301 if out.len() >= limit {
302 return Ok(out);
303 }
304 }
305 }
306
307 Ok(out)
308}
309
310fn fetch_neighbours(
311 conn: &Connection,
312 entity_id: i64,
313 namespace: &str,
314 min_weight: f64,
315 relation_filter: Option<&str>,
316) -> Result<Vec<Neighbour>, AppError> {
317 let base_sql = "\
320 SELECT r.target_id, se.name, te.name, r.relation, r.weight
321 FROM relationships r
322 JOIN entities se ON se.id = r.source_id
323 JOIN entities te ON te.id = r.target_id
324 WHERE r.source_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
325
326 let reverse_sql = "\
327 SELECT r.source_id, se.name, te.name, r.relation, r.weight
328 FROM relationships r
329 JOIN entities se ON se.id = r.source_id
330 JOIN entities te ON te.id = r.target_id
331 WHERE r.target_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
332
333 let mut results: Vec<Neighbour> = Vec::new();
334
335 let forward_sql = match relation_filter {
336 Some(_) => format!("{base_sql} AND r.relation = ?4"),
337 None => base_sql.to_string(),
338 };
339 let rev_sql = match relation_filter {
340 Some(_) => format!("{reverse_sql} AND r.relation = ?4"),
341 None => reverse_sql.to_string(),
342 };
343
344 let mut stmt = conn.prepare_cached(&forward_sql)?;
345 let rows: Vec<_> = if let Some(rel) = relation_filter {
346 stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
347 Ok((
348 r.get::<_, i64>(0)?,
349 r.get::<_, String>(1)?,
350 r.get::<_, String>(2)?,
351 r.get::<_, String>(3)?,
352 r.get::<_, f64>(4)?,
353 ))
354 })?
355 .collect::<Result<Vec<_>, _>>()?
356 } else {
357 stmt.query_map(params![entity_id, min_weight, namespace], |r| {
358 Ok((
359 r.get::<_, i64>(0)?,
360 r.get::<_, String>(1)?,
361 r.get::<_, String>(2)?,
362 r.get::<_, String>(3)?,
363 r.get::<_, f64>(4)?,
364 ))
365 })?
366 .collect::<Result<Vec<_>, _>>()?
367 };
368 results.extend(rows);
369
370 let mut stmt = conn.prepare_cached(&rev_sql)?;
371 let rows: Vec<_> = if let Some(rel) = relation_filter {
372 stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
373 Ok((
374 r.get::<_, i64>(0)?,
375 r.get::<_, String>(1)?,
376 r.get::<_, String>(2)?,
377 r.get::<_, String>(3)?,
378 r.get::<_, f64>(4)?,
379 ))
380 })?
381 .collect::<Result<Vec<_>, _>>()?
382 } else {
383 stmt.query_map(params![entity_id, min_weight, namespace], |r| {
384 Ok((
385 r.get::<_, i64>(0)?,
386 r.get::<_, String>(1)?,
387 r.get::<_, String>(2)?,
388 r.get::<_, String>(3)?,
389 r.get::<_, f64>(4)?,
390 ))
391 })?
392 .collect::<Result<Vec<_>, _>>()?
393 };
394 results.extend(rows);
395
396 Ok(results)
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an in-memory database with the subset of the schema that the
    /// traversal queries touch.
    fn setup_related_db() -> rusqlite::Connection {
        let conn = rusqlite::Connection::open_in_memory().expect("failed to open in-memory db");
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                namespace TEXT NOT NULL DEFAULT 'global',
                type TEXT NOT NULL DEFAULT 'fact',
                description TEXT NOT NULL DEFAULT '',
                deleted_at INTEGER
            );
            CREATE TABLE entities (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                name TEXT NOT NULL
            );
            CREATE TABLE relationships (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related_to',
                weight REAL NOT NULL DEFAULT 1.0
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );",
        )
        .expect("failed to create test tables");
        conn
    }

    fn insert_memory(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO memories (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert memory");
        conn.last_insert_rowid()
    }

    fn insert_entity(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO entities (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert entity");
        conn.last_insert_rowid()
    }

    fn link_memory_entity(conn: &rusqlite::Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            rusqlite::params![memory_id, entity_id],
        )
        .expect("failed to link memory-entity");
    }

    fn insert_relationship(
        conn: &rusqlite::Connection,
        namespace: &str,
        source_id: i64,
        target_id: i64,
        relation: &str,
        weight: f64,
    ) {
        conn.execute(
            "INSERT INTO relationships (namespace, source_id, target_id, relation, weight)
             VALUES (?1, ?2, ?3, ?4, ?5)",
            rusqlite::params![namespace, source_id, target_id, relation, weight],
        )
        .expect("failed to insert relationship");
    }

    /// Marks a memory as soft-deleted so live-only queries skip it.
    fn soft_delete_memory(conn: &rusqlite::Connection, memory_id: i64) {
        conn.execute(
            "UPDATE memories SET deleted_at = 1 WHERE id = ?1",
            rusqlite::params![memory_id],
        )
        .expect("failed to soft-delete memory");
    }

    /// Names of the returned memories, in result order.
    fn names(results: &[RelatedMemory]) -> Vec<&str> {
        results.iter().map(|m| m.name.as_str()).collect()
    }

    #[test]
    fn related_response_serializes_results_and_elapsed_ms() {
        let resp = RelatedResponse {
            name: "seed-mem".to_string(),
            max_hops: 2,
            results: vec![RelatedMemory {
                memory_id: 1,
                name: "neighbor-mem".to_string(),
                namespace: "global".to_string(),
                memory_type: "document".to_string(),
                description: "desc".to_string(),
                hop_distance: 1,
                source_entity: Some("entity-a".to_string()),
                target_entity: Some("entity-b".to_string()),
                relation: Some("related_to".to_string()),
                weight: Some(0.9),
            }],
            elapsed_ms: 7,
        };

        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert!(json["results"].is_array());
        assert_eq!(json["results"].as_array().unwrap().len(), 1);
        assert_eq!(json["elapsed_ms"], 7u64);
        assert_eq!(json["results"][0]["type"], "document");
        assert_eq!(json["results"][0]["hop_distance"], 1);
    }

    #[test]
    fn traverse_related_returns_empty_without_seed_entities() {
        let conn = setup_related_db();
        let result = traverse_related(&conn, 1, &[], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(
            result.is_empty(),
            "no seed entities must yield empty results"
        );
    }

    #[test]
    fn traverse_related_returns_empty_with_max_hops_zero() {
        let conn = setup_related_db();
        let mem_id = insert_memory(&conn, "seed-mem", "global");
        let ent_id = insert_entity(&conn, "ent-a", "global");
        link_memory_entity(&conn, mem_id, ent_id);

        let result = traverse_related(&conn, mem_id, &[ent_id], "global", 0, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(result.is_empty(), "max_hops=0 must return empty");
    }

    #[test]
    fn traverse_related_discovers_neighbor_memory_via_graph() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");

        assert_eq!(result.len(), 1, "must find 1 neighboring memory");
        assert_eq!(result[0].name, "neighbor-mem");
        assert_eq!(result[0].hop_distance, 1);
    }

    #[test]
    fn traverse_related_respects_limit() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);

        for i in 0..5 {
            let mem_id = insert_memory(&conn, &format!("neighbor-{i}"), "global");
            let ent_id = insert_entity(&conn, &format!("ent-{i}"), "global");
            link_memory_entity(&conn, mem_id, ent_id);
            insert_relationship(&conn, "global", ent_seed, ent_id, "related_to", 1.0);
        }

        let result = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.0, None, 3)
            .expect("traverse_related failed");

        // Five memories are reachable, so the cap must be hit exactly.
        assert_eq!(
            result.len(),
            3,
            "limit=3 must cap the 5 reachable memories at exactly 3 results"
        );
    }

    #[test]
    fn traverse_related_follows_incoming_edges() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        // The only edge points *into* the seed entity.
        insert_relationship(&conn, "global", ent_b, ent_a, "depends_on", 0.7);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");

        assert_eq!(names(&result), ["neighbor-mem"]);
        // Entity names keep the stored edge direction, not the walk direction.
        assert_eq!(result[0].source_entity.as_deref(), Some("ent-b"));
        assert_eq!(result[0].target_entity.as_deref(), Some("ent-a"));
        assert_eq!(result[0].relation.as_deref(), Some("depends_on"));
        assert_eq!(result[0].weight, Some(0.7));
    }

    #[test]
    fn traverse_related_filters_by_weight_relation_and_namespace() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);

        let strong_id = insert_memory(&conn, "strong", "global");
        let ent_strong = insert_entity(&conn, "ent-strong", "global");
        link_memory_entity(&conn, strong_id, ent_strong);
        insert_relationship(&conn, "global", ent_seed, ent_strong, "uses", 0.9);

        let weak_id = insert_memory(&conn, "weak", "global");
        let ent_weak = insert_entity(&conn, "ent-weak", "global");
        link_memory_entity(&conn, weak_id, ent_weak);
        insert_relationship(&conn, "global", ent_seed, ent_weak, "related_to", 0.2);

        // An edge recorded under another namespace must never be followed.
        let foreign_id = insert_memory(&conn, "foreign", "global");
        let ent_foreign = insert_entity(&conn, "ent-foreign", "global");
        link_memory_entity(&conn, foreign_id, ent_foreign);
        insert_relationship(&conn, "other", ent_seed, ent_foreign, "uses", 1.0);

        let all = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");
        assert_eq!(names(&all), ["strong", "weak"], "same hop: heavier edge first");

        let heavy = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.5, None, 10)
            .expect("traverse_related failed");
        assert_eq!(names(&heavy), ["strong"]);

        let related_only = traverse_related(
            &conn,
            seed_id,
            &[ent_seed],
            "global",
            1,
            0.0,
            Some("related_to"),
            10,
        )
        .expect("traverse_related failed");
        assert_eq!(names(&related_only), ["weak"]);
    }

    #[test]
    fn traverse_related_reports_multi_hop_distances_in_order() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let near_id = insert_memory(&conn, "near", "global");
        let far_id = insert_memory(&conn, "far", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        let ent_c = insert_entity(&conn, "ent-c", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, near_id, ent_b);
        link_memory_entity(&conn, far_id, ent_c);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);
        insert_relationship(&conn, "global", ent_b, ent_c, "related_to", 1.0);

        let two_hops = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        let hops: Vec<u32> = two_hops.iter().map(|m| m.hop_distance).collect();
        assert_eq!(names(&two_hops), ["near", "far"]);
        assert_eq!(hops, [1, 2]);

        let one_hop = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");
        assert_eq!(names(&one_hop), ["near"]);
    }

    #[test]
    fn traverse_related_excludes_seed_and_soft_deleted_memories() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let live_id = insert_memory(&conn, "live", "global");
        let gone_id = insert_memory(&conn, "gone", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        // The seed also mentions the reached entity; it must still not be reported.
        link_memory_entity(&conn, seed_id, ent_b);
        link_memory_entity(&conn, live_id, ent_b);
        link_memory_entity(&conn, gone_id, ent_b);
        soft_delete_memory(&conn, gone_id);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");

        assert_eq!(names(&result), ["live"]);
    }

    #[test]
    fn related_memory_optional_null_fields_serialized() {
        let mem = RelatedMemory {
            memory_id: 99,
            name: "no-relation".to_string(),
            namespace: "ns".to_string(),
            memory_type: "concept".to_string(),
            description: "".to_string(),
            hop_distance: 2,
            source_entity: None,
            target_entity: None,
            relation: None,
            weight: None,
        };

        let json = serde_json::to_value(&mem).expect("serialization failed");
        assert!(json["source_entity"].is_null());
        assert!(json["target_entity"].is_null());
        assert!(json["relation"].is_null());
        assert!(json["weight"].is_null());
        assert_eq!(json["hop_distance"], 2);
    }
}