1use crate::cli::RelationKind;
4use crate::constants::{
5 DEFAULT_K_RECALL, DEFAULT_MAX_HOPS, DEFAULT_MIN_WEIGHT, TEXT_DESCRIPTION_PREVIEW_LEN,
6};
7use crate::errors::AppError;
8use crate::i18n::errors_msg;
9use crate::output::{self, OutputFormat};
10use crate::paths::AppPaths;
11use crate::storage::connection::open_ro;
12use rusqlite::{params, Connection};
13use serde::Serialize;
14use std::collections::{HashMap, HashSet, VecDeque};
15
16enum SeedKind {
18 Memory(i64),
19 Entity(i64),
20}
21
22type Neighbour = (i64, String, String, String, f64);
25
26#[derive(clap::Args)]
27#[command(after_long_help = "EXAMPLES:\n \
28 # List memories connected to a memory via the entity graph (default 2 hops)\n \
29 sqlite-graphrag related onboarding\n\n \
30 # Increase hop distance and filter by relation type\n \
31 sqlite-graphrag related onboarding --max-hops 3 --relation related\n\n \
32 # Cap result count and require minimum edge weight\n \
33 sqlite-graphrag related onboarding --limit 5 --min-weight 0.5")]
34pub struct RelatedArgs {
35 #[arg(
37 value_name = "NAME",
38 conflicts_with = "name",
39 help = "Memory name whose neighbours to traverse; alternative to --name"
40 )]
41 pub name_positional: Option<String>,
42 #[arg(long, alias = "from")]
44 pub name: Option<String>,
45 #[arg(long, alias = "hops", default_value_t = DEFAULT_MAX_HOPS)]
47 pub max_hops: u32,
48 #[arg(long, value_enum)]
49 pub relation: Option<RelationKind>,
50 #[arg(long, default_value_t = DEFAULT_MIN_WEIGHT)]
51 pub min_weight: f64,
52 #[arg(long, default_value_t = DEFAULT_K_RECALL)]
53 pub limit: usize,
54 #[arg(long)]
55 pub namespace: Option<String>,
56 #[arg(long, value_enum, default_value = "json")]
57 pub format: OutputFormat,
58 #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
59 pub json: bool,
60 #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
61 pub db: Option<String>,
62}
63
64#[derive(Serialize)]
65struct RelatedResponse {
66 name: String,
69 max_hops: u32,
71 results: Vec<RelatedMemory>,
72 elapsed_ms: u64,
73}
74
75#[derive(Serialize, Clone)]
76struct RelatedMemory {
77 memory_id: i64,
78 name: String,
79 namespace: String,
80 #[serde(rename = "type")]
81 memory_type: String,
82 description: String,
83 hop_distance: u32,
84 source_entity: Option<String>,
85 target_entity: Option<String>,
86 relation: Option<String>,
87 weight: Option<f64>,
88}
89
90pub fn run(args: RelatedArgs) -> Result<(), AppError> {
91 let inicio = std::time::Instant::now();
92 let name = args
93 .name_positional
94 .as_deref()
95 .or(args.name.as_deref())
96 .ok_or_else(|| {
97 AppError::Validation(
98 "name required: pass as positional argument or via --name".to_string(),
99 )
100 })?
101 .to_string();
102
103 if name.trim().is_empty() {
104 return Err(AppError::Validation("name must not be empty".to_string()));
105 }
106
107 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
108 let paths = AppPaths::resolve(args.db.as_deref())?;
109
110 crate::storage::connection::ensure_db_ready(&paths)?;
111
112 let conn = open_ro(&paths.db)?;
113
114 let seed = match conn.query_row(
116 "SELECT id FROM memories WHERE namespace = ?1 AND name = ?2 AND deleted_at IS NULL",
117 params![namespace, name],
118 |r| r.get::<_, i64>(0),
119 ) {
120 Ok(id) => SeedKind::Memory(id),
121 Err(rusqlite::Error::QueryReturnedNoRows) => {
122 match crate::storage::entities::find_entity_id(&conn, &namespace, &name)? {
123 Some(id) => SeedKind::Entity(id),
124 None => {
125 return Err(AppError::NotFound(errors_msg::memory_or_entity_not_found(
126 &name, &namespace,
127 )))
128 }
129 }
130 }
131 Err(e) => return Err(AppError::Database(e)),
132 };
133
134 let (seed_memory_id, seed_entity_ids): (i64, Vec<i64>) = match &seed {
136 SeedKind::Memory(id) => {
137 let mem_id = *id;
138 let mut stmt =
139 conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
140 let rows: Vec<i64> = stmt
141 .query_map(params![mem_id], |r| r.get(0))?
142 .collect::<Result<Vec<i64>, _>>()?;
143 (mem_id, rows)
144 }
145 SeedKind::Entity(entity_id) => {
146 (-1, vec![*entity_id])
149 }
150 };
151
152 let relation_filter = args.relation.map(|r| r.as_str().to_string());
153 let results = traverse_related(
154 &conn,
155 seed_memory_id,
156 &seed_entity_ids,
157 &namespace,
158 args.max_hops,
159 args.min_weight,
160 relation_filter.as_deref(),
161 args.limit,
162 )?;
163
164 match args.format {
165 OutputFormat::Json => output::emit_json(&RelatedResponse {
166 name: name.clone(),
167 max_hops: args.max_hops,
168 results,
169 elapsed_ms: inicio.elapsed().as_millis() as u64,
170 })?,
171 OutputFormat::Text => {
172 for item in &results {
173 if item.description.is_empty() {
174 output::emit_text(&format!(
175 "{}. {} ({})",
176 item.hop_distance, item.name, item.namespace
177 ));
178 } else {
179 let preview: String = item
180 .description
181 .chars()
182 .take(TEXT_DESCRIPTION_PREVIEW_LEN)
183 .collect();
184 output::emit_text(&format!(
185 "{}. {} ({}): {}",
186 item.hop_distance, item.name, item.namespace, preview
187 ));
188 }
189 }
190 }
191 OutputFormat::Markdown => {
192 for item in &results {
193 if item.description.is_empty() {
194 output::emit_text(&format!(
195 "- **{}** ({}) — hop {}",
196 item.name, item.namespace, item.hop_distance
197 ));
198 } else {
199 let preview: String = item
200 .description
201 .chars()
202 .take(TEXT_DESCRIPTION_PREVIEW_LEN)
203 .collect();
204 output::emit_text(&format!(
205 "- **{}** ({}) — hop {}: {}",
206 item.name, item.namespace, item.hop_distance, preview
207 ));
208 }
209 }
210 }
211 }
212
213 Ok(())
214}
215
216#[allow(clippy::too_many_arguments)]
217fn traverse_related(
218 conn: &Connection,
219 seed_memory_id: i64,
220 seed_entity_ids: &[i64],
221 namespace: &str,
222 max_hops: u32,
223 min_weight: f64,
224 relation_filter: Option<&str>,
225 limit: usize,
226) -> Result<Vec<RelatedMemory>, AppError> {
227 if seed_entity_ids.is_empty() || max_hops == 0 {
228 return Ok(Vec::new());
229 }
230
231 let mut visited: HashSet<i64> = seed_entity_ids.iter().copied().collect();
234 let mut entity_hop: HashMap<i64, u32> = HashMap::new();
235 for &e in seed_entity_ids {
236 entity_hop.insert(e, 0);
237 }
238 let mut entity_edge: HashMap<i64, (String, String, String, f64)> = HashMap::new();
241
242 let mut queue: VecDeque<i64> = seed_entity_ids.iter().copied().collect();
243
244 while let Some(current_entity) = queue.pop_front() {
245 let current_hop = *entity_hop.get(¤t_entity).unwrap_or(&0);
246 if current_hop >= max_hops {
247 continue;
248 }
249
250 let neighbours =
251 fetch_neighbours(conn, current_entity, namespace, min_weight, relation_filter)?;
252
253 for (neighbour_id, source_name, target_name, relation, weight) in neighbours {
254 if visited.insert(neighbour_id) {
255 entity_hop.insert(neighbour_id, current_hop + 1);
256 entity_edge.insert(neighbour_id, (source_name, target_name, relation, weight));
257 queue.push_back(neighbour_id);
258 }
259 }
260 }
261
262 let mut out: Vec<RelatedMemory> = Vec::new();
264 let mut dedup_ids: HashSet<i64> = HashSet::new();
265 dedup_ids.insert(seed_memory_id);
266
267 let mut ordered_entities: Vec<(i64, u32)> = entity_hop
269 .iter()
270 .filter(|(id, _)| !seed_entity_ids.contains(id))
271 .map(|(id, hop)| (*id, *hop))
272 .collect();
273 ordered_entities.sort_by(|a, b| {
274 let weight_a = entity_edge.get(&a.0).map(|e| e.3).unwrap_or(0.0);
275 let weight_b = entity_edge.get(&b.0).map(|e| e.3).unwrap_or(0.0);
276 a.1.cmp(&b.1).then_with(|| {
277 weight_b
278 .partial_cmp(&weight_a)
279 .unwrap_or(std::cmp::Ordering::Equal)
280 })
281 });
282
283 for (entity_id, hop) in ordered_entities {
284 let mut stmt = conn.prepare_cached(
285 "SELECT m.id, m.name, m.namespace, m.type, m.description
286 FROM memory_entities me
287 JOIN memories m ON m.id = me.memory_id
288 WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
289 )?;
290 let rows = stmt
291 .query_map(params![entity_id], |r| {
292 Ok((
293 r.get::<_, i64>(0)?,
294 r.get::<_, String>(1)?,
295 r.get::<_, String>(2)?,
296 r.get::<_, String>(3)?,
297 r.get::<_, String>(4)?,
298 ))
299 })?
300 .collect::<Result<Vec<_>, _>>()?;
301
302 for (mid, name, ns, mtype, desc) in rows {
303 if !dedup_ids.insert(mid) {
304 continue;
305 }
306 let edge = entity_edge.get(&entity_id);
307 out.push(RelatedMemory {
308 memory_id: mid,
309 name,
310 namespace: ns,
311 memory_type: mtype,
312 description: desc,
313 hop_distance: hop,
314 source_entity: edge.map(|e| e.0.clone()),
315 target_entity: edge.map(|e| e.1.clone()),
316 relation: edge.map(|e| e.2.clone()),
317 weight: edge.map(|e| e.3),
318 });
319 if out.len() >= limit {
320 return Ok(out);
321 }
322 }
323 }
324
325 Ok(out)
326}
327
328fn fetch_neighbours(
329 conn: &Connection,
330 entity_id: i64,
331 namespace: &str,
332 min_weight: f64,
333 relation_filter: Option<&str>,
334) -> Result<Vec<Neighbour>, AppError> {
335 let base_sql = "\
338 SELECT r.target_id, se.name, te.name, r.relation, r.weight
339 FROM relationships r
340 JOIN entities se ON se.id = r.source_id
341 JOIN entities te ON te.id = r.target_id
342 WHERE r.source_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
343
344 let reverse_sql = "\
345 SELECT r.source_id, se.name, te.name, r.relation, r.weight
346 FROM relationships r
347 JOIN entities se ON se.id = r.source_id
348 JOIN entities te ON te.id = r.target_id
349 WHERE r.target_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
350
351 let mut results: Vec<Neighbour> = Vec::new();
352
353 let forward_sql = match relation_filter {
354 Some(_) => format!("{base_sql} AND r.relation = ?4"),
355 None => base_sql.to_string(),
356 };
357 let rev_sql = match relation_filter {
358 Some(_) => format!("{reverse_sql} AND r.relation = ?4"),
359 None => reverse_sql.to_string(),
360 };
361
362 let mut stmt = conn.prepare_cached(&forward_sql)?;
363 let rows: Vec<_> = if let Some(rel) = relation_filter {
364 stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
365 Ok((
366 r.get::<_, i64>(0)?,
367 r.get::<_, String>(1)?,
368 r.get::<_, String>(2)?,
369 r.get::<_, String>(3)?,
370 r.get::<_, f64>(4)?,
371 ))
372 })?
373 .collect::<Result<Vec<_>, _>>()?
374 } else {
375 stmt.query_map(params![entity_id, min_weight, namespace], |r| {
376 Ok((
377 r.get::<_, i64>(0)?,
378 r.get::<_, String>(1)?,
379 r.get::<_, String>(2)?,
380 r.get::<_, String>(3)?,
381 r.get::<_, f64>(4)?,
382 ))
383 })?
384 .collect::<Result<Vec<_>, _>>()?
385 };
386 results.extend(rows);
387
388 let mut stmt = conn.prepare_cached(&rev_sql)?;
389 let rows: Vec<_> = if let Some(rel) = relation_filter {
390 stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
391 Ok((
392 r.get::<_, i64>(0)?,
393 r.get::<_, String>(1)?,
394 r.get::<_, String>(2)?,
395 r.get::<_, String>(3)?,
396 r.get::<_, f64>(4)?,
397 ))
398 })?
399 .collect::<Result<Vec<_>, _>>()?
400 } else {
401 stmt.query_map(params![entity_id, min_weight, namespace], |r| {
402 Ok((
403 r.get::<_, i64>(0)?,
404 r.get::<_, String>(1)?,
405 r.get::<_, String>(2)?,
406 r.get::<_, String>(3)?,
407 r.get::<_, f64>(4)?,
408 ))
409 })?
410 .collect::<Result<Vec<_>, _>>()?
411 };
412 results.extend(rows);
413
414 Ok(results)
415}
416
417#[cfg(test)]
418mod tests {
419 use super::*;
420
421 fn setup_related_db() -> rusqlite::Connection {
422 let conn = rusqlite::Connection::open_in_memory().expect("failed to open in-memory db");
423 conn.execute_batch(
424 "CREATE TABLE memories (
425 id INTEGER PRIMARY KEY AUTOINCREMENT,
426 name TEXT NOT NULL,
427 namespace TEXT NOT NULL DEFAULT 'global',
428 type TEXT NOT NULL DEFAULT 'fact',
429 description TEXT NOT NULL DEFAULT '',
430 deleted_at INTEGER
431 );
432 CREATE TABLE entities (
433 id INTEGER PRIMARY KEY AUTOINCREMENT,
434 namespace TEXT NOT NULL,
435 name TEXT NOT NULL
436 );
437 CREATE TABLE relationships (
438 id INTEGER PRIMARY KEY AUTOINCREMENT,
439 namespace TEXT NOT NULL,
440 source_id INTEGER NOT NULL,
441 target_id INTEGER NOT NULL,
442 relation TEXT NOT NULL DEFAULT 'related_to',
443 weight REAL NOT NULL DEFAULT 1.0
444 );
445 CREATE TABLE memory_entities (
446 memory_id INTEGER NOT NULL,
447 entity_id INTEGER NOT NULL
448 );",
449 )
450 .expect("failed to create test tables");
451 conn
452 }
453
454 fn insert_memory(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
455 conn.execute(
456 "INSERT INTO memories (name, namespace) VALUES (?1, ?2)",
457 rusqlite::params![name, namespace],
458 )
459 .expect("failed to insert memory");
460 conn.last_insert_rowid()
461 }
462
463 fn insert_entity(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
464 conn.execute(
465 "INSERT INTO entities (name, namespace) VALUES (?1, ?2)",
466 rusqlite::params![name, namespace],
467 )
468 .expect("failed to insert entity");
469 conn.last_insert_rowid()
470 }
471
472 fn link_memory_entity(conn: &rusqlite::Connection, memory_id: i64, entity_id: i64) {
473 conn.execute(
474 "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
475 rusqlite::params![memory_id, entity_id],
476 )
477 .expect("failed to link memory-entity");
478 }
479
480 fn insert_relationship(
481 conn: &rusqlite::Connection,
482 namespace: &str,
483 source_id: i64,
484 target_id: i64,
485 relation: &str,
486 weight: f64,
487 ) {
488 conn.execute(
489 "INSERT INTO relationships (namespace, source_id, target_id, relation, weight)
490 VALUES (?1, ?2, ?3, ?4, ?5)",
491 rusqlite::params![namespace, source_id, target_id, relation, weight],
492 )
493 .expect("failed to insert relationship");
494 }
495
496 #[test]
497 fn related_response_serializes_results_and_elapsed_ms() {
498 let resp = RelatedResponse {
499 name: "seed-mem".to_string(),
500 max_hops: 2,
501 results: vec![RelatedMemory {
502 memory_id: 1,
503 name: "neighbor-mem".to_string(),
504 namespace: "global".to_string(),
505 memory_type: "document".to_string(),
506 description: "desc".to_string(),
507 hop_distance: 1,
508 source_entity: Some("entity-a".to_string()),
509 target_entity: Some("entity-b".to_string()),
510 relation: Some("related_to".to_string()),
511 weight: Some(0.9),
512 }],
513 elapsed_ms: 7,
514 };
515
516 let json = serde_json::to_value(&resp).expect("serialization failed");
517 assert!(json["results"].is_array());
518 assert_eq!(json["results"].as_array().unwrap().len(), 1);
519 assert_eq!(json["elapsed_ms"], 7u64);
520 assert_eq!(json["results"][0]["type"], "document");
521 assert_eq!(json["results"][0]["hop_distance"], 1);
522 }
523
524 #[test]
525 fn traverse_related_returns_empty_without_seed_entities() {
526 let conn = setup_related_db();
527 let result = traverse_related(&conn, 1, &[], "global", 2, 0.0, None, 10)
528 .expect("traverse_related failed");
529 assert!(
530 result.is_empty(),
531 "no seed entities must yield empty results"
532 );
533 }
534
535 #[test]
536 fn traverse_related_returns_empty_with_max_hops_zero() {
537 let conn = setup_related_db();
538 let mem_id = insert_memory(&conn, "seed-mem", "global");
539 let ent_id = insert_entity(&conn, "ent-a", "global");
540 link_memory_entity(&conn, mem_id, ent_id);
541
542 let result = traverse_related(&conn, mem_id, &[ent_id], "global", 0, 0.0, None, 10)
543 .expect("traverse_related failed");
544 assert!(result.is_empty(), "max_hops=0 must return empty");
545 }
546
547 #[test]
548 fn traverse_related_discovers_neighbor_memory_via_graph() {
549 let conn = setup_related_db();
550
551 let seed_id = insert_memory(&conn, "seed-mem", "global");
552 let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
553 let ent_a = insert_entity(&conn, "ent-a", "global");
554 let ent_b = insert_entity(&conn, "ent-b", "global");
555
556 link_memory_entity(&conn, seed_id, ent_a);
557 link_memory_entity(&conn, neighbor_id, ent_b);
558 insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);
559
560 let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
561 .expect("traverse_related failed");
562
563 assert_eq!(result.len(), 1, "must find 1 neighboring memory");
564 assert_eq!(result[0].name, "neighbor-mem");
565 assert_eq!(result[0].hop_distance, 1);
566 }
567
568 #[test]
569 fn traverse_related_respects_limit() {
570 let conn = setup_related_db();
571
572 let seed_id = insert_memory(&conn, "seed", "global");
573 let ent_seed = insert_entity(&conn, "ent-seed", "global");
574 link_memory_entity(&conn, seed_id, ent_seed);
575
576 for i in 0..5 {
577 let mem_id = insert_memory(&conn, &format!("neighbor-{i}"), "global");
578 let ent_id = insert_entity(&conn, &format!("ent-{i}"), "global");
579 link_memory_entity(&conn, mem_id, ent_id);
580 insert_relationship(&conn, "global", ent_seed, ent_id, "related_to", 1.0);
581 }
582
583 let result = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.0, None, 3)
584 .expect("traverse_related failed");
585
586 assert!(
587 result.len() <= 3,
588 "limit=3 must constrain to at most 3 results"
589 );
590 }
591
592 #[test]
593 fn related_memory_optional_null_fields_serialized() {
594 let mem = RelatedMemory {
595 memory_id: 99,
596 name: "no-relation".to_string(),
597 namespace: "ns".to_string(),
598 memory_type: "concept".to_string(),
599 description: "".to_string(),
600 hop_distance: 2,
601 source_entity: None,
602 target_entity: None,
603 relation: None,
604 weight: None,
605 };
606
607 let json = serde_json::to_value(&mem).expect("serialization failed");
608 assert!(json["source_entity"].is_null());
609 assert!(json["target_entity"].is_null());
610 assert!(json["relation"].is_null());
611 assert!(json["weight"].is_null());
612 assert_eq!(json["hop_distance"], 2);
613 }
614}