1use crate::constants::{
4 DEFAULT_K_RECALL, DEFAULT_MAX_HOPS, DEFAULT_MIN_WEIGHT, TEXT_DESCRIPTION_PREVIEW_LEN,
5};
6use crate::errors::AppError;
7use crate::i18n::errors_msg;
8use crate::output::{self, OutputFormat};
9use crate::paths::AppPaths;
10use crate::storage::connection::open_ro;
11use rusqlite::{params, Connection};
12use serde::Serialize;
13use std::collections::{HashMap, HashSet, VecDeque};
14
/// What the user-supplied name resolved to: a live memory or, failing that, an entity.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SeedKind {
    /// Id of a memory (`memories.id`); the entities linked to it seed the traversal.
    Memory(i64),
    /// Id of an entity that seeds the traversal directly.
    Entity(i64),
}

/// One edge found by `fetch_neighbours`:
/// `(neighbour_entity_id, source_entity_name, target_entity_name, relation, weight)`.
/// The names describe the edge in its stored direction, whichever side the neighbour is on.
type Neighbour = (i64, String, String, String, f64);
24
// Command-line arguments for the `related` subcommand.
// Plain `//` comments are used here on purpose: clap's derive turns `///` doc comments
// into user-visible help text, which would change the CLI output.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n \
    # List memories connected to a memory via the entity graph (default 2 hops)\n \
    sqlite-graphrag related onboarding\n\n \
    # Increase hop distance and filter by relation type\n \
    sqlite-graphrag related onboarding --max-hops 3 --relation related\n\n \
    # Cap result count and require minimum edge weight\n \
    sqlite-graphrag related onboarding --limit 5 --min-weight 0.5")]
pub struct RelatedArgs {
    // Seed name given positionally; clap rejects it when `--name` is also supplied.
    #[arg(
        value_name = "NAME",
        conflicts_with = "name",
        help = "Memory name whose neighbours to traverse; alternative to --name"
    )]
    pub name_positional: Option<String>,
    // Seed name given as a flag (alias `--from`); `run` only falls back to it when no
    // positional name was given.
    #[arg(long, alias = "from")]
    pub name: Option<String>,
    // Maximum number of relationship edges followed from the seed entities (alias `--hops`).
    #[arg(long, alias = "hops", default_value_t = DEFAULT_MAX_HOPS)]
    pub max_hops: u32,
    // When set, only relationships whose `relation` equals this value are traversed.
    #[arg(long, value_parser = crate::parsers::parse_relation)]
    pub relation: Option<String>,
    // Relationships with a weight below this threshold are not traversed.
    #[arg(long, default_value_t = DEFAULT_MIN_WEIGHT)]
    pub min_weight: f64,
    // Maximum number of related memories reported.
    #[arg(long, default_value_t = DEFAULT_K_RECALL)]
    pub limit: usize,
    // Namespace override, resolved through `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // Output format; `run` handles json (the default), text and markdown.
    #[arg(long, value_enum, default_value = "json")]
    pub format: OutputFormat,
    // Hidden compatibility flag that `run` never reads.
    // NOTE(review): its help text claims JSON is always emitted, but `--format text` or
    // `--format markdown` produce non-JSON output — consider rewording.
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    // Database path override; can also come from the SQLITE_GRAPHRAG_DB_PATH env var.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
62
/// Payload emitted by `run` when `--format json` is selected.
///
/// serde's derive serializes fields in declaration order, so keep this order stable if
/// consumers depend on key order.
#[derive(Serialize)]
struct RelatedResponse {
    /// The seed name exactly as supplied on the command line (not trimmed).
    name: String,
    /// The `--max-hops` value used for the traversal.
    max_hops: u32,
    /// Related memories in the order produced by `traverse_related`.
    results: Vec<RelatedMemory>,
    /// Milliseconds from the start of `run` until this response was built.
    elapsed_ms: u64,
}
73
/// One memory reached from the seed through the entity graph.
#[derive(Serialize, Clone)]
struct RelatedMemory {
    /// `memories.id` of the related memory.
    memory_id: i64,
    /// `memories.name`.
    name: String,
    /// `memories.namespace`.
    namespace: String,
    /// `memories.type`; serialized under the key `type`.
    #[serde(rename = "type")]
    memory_type: String,
    /// Full description; the text/markdown renderers in `run` truncate it for display.
    description: String,
    /// Number of relationship edges between a seed entity and the entity through which
    /// this memory was found.
    hop_distance: u32,
    /// Source entity name of the edge through which that entity was first reached.
    source_entity: Option<String>,
    /// Target entity name of that same edge.
    target_entity: Option<String>,
    /// Relation label of that edge.
    relation: Option<String>,
    /// Weight of that edge; the four edge fields are `None` only if no edge was recorded.
    weight: Option<f64>,
}
88
89pub fn run(args: RelatedArgs) -> Result<(), AppError> {
90 let inicio = std::time::Instant::now();
91 let name = args
92 .name_positional
93 .as_deref()
94 .or(args.name.as_deref())
95 .ok_or_else(|| {
96 AppError::Validation(
97 "name required: pass as positional argument or via --name".to_string(),
98 )
99 })?
100 .to_string();
101
102 if name.trim().is_empty() {
103 return Err(AppError::Validation("name must not be empty".to_string()));
104 }
105
106 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
107 let paths = AppPaths::resolve(args.db.as_deref())?;
108
109 crate::storage::connection::ensure_db_ready(&paths)?;
110
111 let conn = open_ro(&paths.db)?;
112
113 let seed = match conn.query_row(
115 "SELECT id FROM memories WHERE namespace = ?1 AND name = ?2 AND deleted_at IS NULL",
116 params![namespace, name],
117 |r| r.get::<_, i64>(0),
118 ) {
119 Ok(id) => SeedKind::Memory(id),
120 Err(rusqlite::Error::QueryReturnedNoRows) => {
121 match crate::storage::entities::find_entity_id(&conn, &namespace, &name)? {
122 Some(id) => SeedKind::Entity(id),
123 None => {
124 return Err(AppError::NotFound(errors_msg::memory_or_entity_not_found(
125 &name, &namespace,
126 )))
127 }
128 }
129 }
130 Err(e) => return Err(AppError::Database(e)),
131 };
132
133 let (seed_memory_id, seed_entity_ids): (i64, Vec<i64>) = match &seed {
135 SeedKind::Memory(id) => {
136 let mem_id = *id;
137 let mut stmt =
138 conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
139 let rows: Vec<i64> = stmt
140 .query_map(params![mem_id], |r| r.get(0))?
141 .collect::<Result<Vec<i64>, _>>()?;
142 (mem_id, rows)
143 }
144 SeedKind::Entity(entity_id) => {
145 (-1, vec![*entity_id])
148 }
149 };
150
151 let relation_filter = args.relation;
152 let results = traverse_related(
153 &conn,
154 seed_memory_id,
155 &seed_entity_ids,
156 &namespace,
157 args.max_hops,
158 args.min_weight,
159 relation_filter.as_deref(),
160 args.limit,
161 )?;
162
163 match args.format {
164 OutputFormat::Json => output::emit_json(&RelatedResponse {
165 name: name.clone(),
166 max_hops: args.max_hops,
167 results,
168 elapsed_ms: inicio.elapsed().as_millis() as u64,
169 })?,
170 OutputFormat::Text => {
171 for item in &results {
172 if item.description.is_empty() {
173 output::emit_text(&format!(
174 "{}. {} ({})",
175 item.hop_distance, item.name, item.namespace
176 ));
177 } else {
178 let preview: String = item
179 .description
180 .chars()
181 .take(TEXT_DESCRIPTION_PREVIEW_LEN)
182 .collect();
183 output::emit_text(&format!(
184 "{}. {} ({}): {}",
185 item.hop_distance, item.name, item.namespace, preview
186 ));
187 }
188 }
189 }
190 OutputFormat::Markdown => {
191 for item in &results {
192 if item.description.is_empty() {
193 output::emit_text(&format!(
194 "- **{}** ({}) — hop {}",
195 item.name, item.namespace, item.hop_distance
196 ));
197 } else {
198 let preview: String = item
199 .description
200 .chars()
201 .take(TEXT_DESCRIPTION_PREVIEW_LEN)
202 .collect();
203 output::emit_text(&format!(
204 "- **{}** ({}) — hop {}: {}",
205 item.name, item.namespace, item.hop_distance, preview
206 ));
207 }
208 }
209 }
210 }
211
212 Ok(())
213}
214
215#[allow(clippy::too_many_arguments)]
216fn traverse_related(
217 conn: &Connection,
218 seed_memory_id: i64,
219 seed_entity_ids: &[i64],
220 namespace: &str,
221 max_hops: u32,
222 min_weight: f64,
223 relation_filter: Option<&str>,
224 limit: usize,
225) -> Result<Vec<RelatedMemory>, AppError> {
226 if seed_entity_ids.is_empty() || max_hops == 0 {
227 return Ok(Vec::new());
228 }
229
230 let mut visited: HashSet<i64> = seed_entity_ids.iter().copied().collect();
233 let mut entity_hop: HashMap<i64, u32> = HashMap::new();
234 for &e in seed_entity_ids {
235 entity_hop.insert(e, 0);
236 }
237 let mut entity_edge: HashMap<i64, (String, String, String, f64)> = HashMap::new();
240
241 let mut queue: VecDeque<i64> = seed_entity_ids.iter().copied().collect();
242
243 while let Some(current_entity) = queue.pop_front() {
244 let current_hop = *entity_hop.get(¤t_entity).unwrap_or(&0);
245 if current_hop >= max_hops {
246 continue;
247 }
248
249 let neighbours =
250 fetch_neighbours(conn, current_entity, namespace, min_weight, relation_filter)?;
251
252 for (neighbour_id, source_name, target_name, relation, weight) in neighbours {
253 if visited.insert(neighbour_id) {
254 entity_hop.insert(neighbour_id, current_hop + 1);
255 entity_edge.insert(neighbour_id, (source_name, target_name, relation, weight));
256 queue.push_back(neighbour_id);
257 }
258 }
259 }
260
261 let mut out: Vec<RelatedMemory> = Vec::new();
263 let mut dedup_ids: HashSet<i64> = HashSet::new();
264 dedup_ids.insert(seed_memory_id);
265
266 let mut ordered_entities: Vec<(i64, u32)> = entity_hop
268 .iter()
269 .filter(|(id, _)| !seed_entity_ids.contains(id))
270 .map(|(id, hop)| (*id, *hop))
271 .collect();
272 ordered_entities.sort_by(|a, b| {
273 let weight_a = entity_edge.get(&a.0).map(|e| e.3).unwrap_or(0.0);
274 let weight_b = entity_edge.get(&b.0).map(|e| e.3).unwrap_or(0.0);
275 a.1.cmp(&b.1).then_with(|| {
276 weight_b
277 .partial_cmp(&weight_a)
278 .unwrap_or(std::cmp::Ordering::Equal)
279 })
280 });
281
282 for (entity_id, hop) in ordered_entities {
283 let mut stmt = conn.prepare_cached(
284 "SELECT m.id, m.name, m.namespace, m.type, m.description
285 FROM memory_entities me
286 JOIN memories m ON m.id = me.memory_id
287 WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
288 )?;
289 let rows = stmt
290 .query_map(params![entity_id], |r| {
291 Ok((
292 r.get::<_, i64>(0)?,
293 r.get::<_, String>(1)?,
294 r.get::<_, String>(2)?,
295 r.get::<_, String>(3)?,
296 r.get::<_, String>(4)?,
297 ))
298 })?
299 .collect::<Result<Vec<_>, _>>()?;
300
301 for (mid, name, ns, mtype, desc) in rows {
302 if !dedup_ids.insert(mid) {
303 continue;
304 }
305 let edge = entity_edge.get(&entity_id);
306 out.push(RelatedMemory {
307 memory_id: mid,
308 name,
309 namespace: ns,
310 memory_type: mtype,
311 description: desc,
312 hop_distance: hop,
313 source_entity: edge.map(|e| e.0.clone()),
314 target_entity: edge.map(|e| e.1.clone()),
315 relation: edge.map(|e| e.2.clone()),
316 weight: edge.map(|e| e.3),
317 });
318 if out.len() >= limit {
319 return Ok(out);
320 }
321 }
322 }
323
324 Ok(out)
325}
326
327fn fetch_neighbours(
328 conn: &Connection,
329 entity_id: i64,
330 namespace: &str,
331 min_weight: f64,
332 relation_filter: Option<&str>,
333) -> Result<Vec<Neighbour>, AppError> {
334 let base_sql = "\
337 SELECT r.target_id, se.name, te.name, r.relation, r.weight
338 FROM relationships r
339 JOIN entities se ON se.id = r.source_id
340 JOIN entities te ON te.id = r.target_id
341 WHERE r.source_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
342
343 let reverse_sql = "\
344 SELECT r.source_id, se.name, te.name, r.relation, r.weight
345 FROM relationships r
346 JOIN entities se ON se.id = r.source_id
347 JOIN entities te ON te.id = r.target_id
348 WHERE r.target_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
349
350 let mut results: Vec<Neighbour> = Vec::new();
351
352 let forward_sql = match relation_filter {
353 Some(_) => format!("{base_sql} AND r.relation = ?4"),
354 None => base_sql.to_string(),
355 };
356 let rev_sql = match relation_filter {
357 Some(_) => format!("{reverse_sql} AND r.relation = ?4"),
358 None => reverse_sql.to_string(),
359 };
360
361 let mut stmt = conn.prepare_cached(&forward_sql)?;
362 let rows: Vec<_> = if let Some(rel) = relation_filter {
363 stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
364 Ok((
365 r.get::<_, i64>(0)?,
366 r.get::<_, String>(1)?,
367 r.get::<_, String>(2)?,
368 r.get::<_, String>(3)?,
369 r.get::<_, f64>(4)?,
370 ))
371 })?
372 .collect::<Result<Vec<_>, _>>()?
373 } else {
374 stmt.query_map(params![entity_id, min_weight, namespace], |r| {
375 Ok((
376 r.get::<_, i64>(0)?,
377 r.get::<_, String>(1)?,
378 r.get::<_, String>(2)?,
379 r.get::<_, String>(3)?,
380 r.get::<_, f64>(4)?,
381 ))
382 })?
383 .collect::<Result<Vec<_>, _>>()?
384 };
385 results.extend(rows);
386
387 let mut stmt = conn.prepare_cached(&rev_sql)?;
388 let rows: Vec<_> = if let Some(rel) = relation_filter {
389 stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
390 Ok((
391 r.get::<_, i64>(0)?,
392 r.get::<_, String>(1)?,
393 r.get::<_, String>(2)?,
394 r.get::<_, String>(3)?,
395 r.get::<_, f64>(4)?,
396 ))
397 })?
398 .collect::<Result<Vec<_>, _>>()?
399 } else {
400 stmt.query_map(params![entity_id, min_weight, namespace], |r| {
401 Ok((
402 r.get::<_, i64>(0)?,
403 r.get::<_, String>(1)?,
404 r.get::<_, String>(2)?,
405 r.get::<_, String>(3)?,
406 r.get::<_, f64>(4)?,
407 ))
408 })?
409 .collect::<Result<Vec<_>, _>>()?
410 };
411 results.extend(rows);
412
413 Ok(results)
414}
415
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an in-memory database holding only the tables and columns that
    /// `traverse_related` and `fetch_neighbours` read.
    fn setup_related_db() -> rusqlite::Connection {
        let conn = rusqlite::Connection::open_in_memory().expect("failed to open in-memory db");
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                namespace TEXT NOT NULL DEFAULT 'global',
                type TEXT NOT NULL DEFAULT 'fact',
                description TEXT NOT NULL DEFAULT '',
                deleted_at INTEGER
            );
            CREATE TABLE entities (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                name TEXT NOT NULL
            );
            CREATE TABLE relationships (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related_to',
                weight REAL NOT NULL DEFAULT 1.0
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );",
        )
        .expect("failed to create test tables");
        conn
    }

    fn insert_memory(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO memories (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert memory");
        conn.last_insert_rowid()
    }

    fn insert_entity(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO entities (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("failed to insert entity");
        conn.last_insert_rowid()
    }

    fn link_memory_entity(conn: &rusqlite::Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            rusqlite::params![memory_id, entity_id],
        )
        .expect("failed to link memory-entity");
    }

    fn insert_relationship(
        conn: &rusqlite::Connection,
        namespace: &str,
        source_id: i64,
        target_id: i64,
        relation: &str,
        weight: f64,
    ) {
        conn.execute(
            "INSERT INTO relationships (namespace, source_id, target_id, relation, weight)
             VALUES (?1, ?2, ?3, ?4, ?5)",
            rusqlite::params![namespace, source_id, target_id, relation, weight],
        )
        .expect("failed to insert relationship");
    }

    fn soft_delete_memory(conn: &rusqlite::Connection, memory_id: i64) {
        conn.execute(
            "UPDATE memories SET deleted_at = 1 WHERE id = ?1",
            rusqlite::params![memory_id],
        )
        .expect("failed to soft-delete memory");
    }

    /// Memory names of `results`, in result order.
    fn names(results: &[RelatedMemory]) -> Vec<&str> {
        results.iter().map(|m| m.name.as_str()).collect()
    }

    #[test]
    fn related_response_serializes_results_and_elapsed_ms() {
        let resp = RelatedResponse {
            name: "seed-mem".to_string(),
            max_hops: 2,
            results: vec![RelatedMemory {
                memory_id: 1,
                name: "neighbor-mem".to_string(),
                namespace: "global".to_string(),
                memory_type: "document".to_string(),
                description: "desc".to_string(),
                hop_distance: 1,
                source_entity: Some("entity-a".to_string()),
                target_entity: Some("entity-b".to_string()),
                relation: Some("related_to".to_string()),
                weight: Some(0.9),
            }],
            elapsed_ms: 7,
        };

        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["name"], "seed-mem");
        assert_eq!(json["max_hops"], 2);
        assert!(json["results"].is_array());
        assert_eq!(json["results"].as_array().unwrap().len(), 1);
        assert_eq!(json["elapsed_ms"], 7u64);
        assert_eq!(json["results"][0]["type"], "document");
        assert_eq!(json["results"][0]["hop_distance"], 1);
    }

    #[test]
    fn traverse_related_returns_empty_without_seed_entities() {
        let conn = setup_related_db();
        let result = traverse_related(&conn, 1, &[], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(
            result.is_empty(),
            "no seed entities must yield empty results"
        );
    }

    #[test]
    fn traverse_related_returns_empty_with_max_hops_zero() {
        let conn = setup_related_db();
        let mem_id = insert_memory(&conn, "seed-mem", "global");
        let ent_id = insert_entity(&conn, "ent-a", "global");
        link_memory_entity(&conn, mem_id, ent_id);

        let result = traverse_related(&conn, mem_id, &[ent_id], "global", 0, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(result.is_empty(), "max_hops=0 must return empty");
    }

    #[test]
    fn traverse_related_discovers_neighbor_memory_via_graph() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");

        assert_eq!(result.len(), 1, "must find 1 neighboring memory");
        assert_eq!(result[0].name, "neighbor-mem");
        assert_eq!(result[0].hop_distance, 1);
    }

    #[test]
    fn traverse_related_follows_edges_in_reverse_direction() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        // The edge points *into* the seed entity.
        insert_relationship(&conn, "global", ent_b, ent_a, "depends_on", 0.7);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");

        assert_eq!(names(&result), vec!["neighbor-mem"]);
        assert_eq!(result[0].source_entity.as_deref(), Some("ent-b"));
        assert_eq!(result[0].target_entity.as_deref(), Some("ent-a"));
        assert_eq!(result[0].relation.as_deref(), Some("depends_on"));
        assert_eq!(result[0].weight, Some(0.7));
    }

    #[test]
    fn traverse_related_respects_limit() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);

        for i in 0..5 {
            let mem_id = insert_memory(&conn, &format!("neighbor-{i}"), "global");
            let ent_id = insert_entity(&conn, &format!("ent-{i}"), "global");
            link_memory_entity(&conn, mem_id, ent_id);
            insert_relationship(&conn, "global", ent_seed, ent_id, "related_to", 1.0);
        }

        let result = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.0, None, 3)
            .expect("traverse_related failed");

        assert_eq!(
            result.len(),
            3,
            "limit=3 must cap the 5 reachable neighbours at exactly 3 results"
        );
    }

    #[test]
    fn traverse_related_applies_weight_relation_and_namespace_filters() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);

        // (memory name, relationship namespace, relation, weight)
        let edges = [
            ("kept", "global", "related_to", 0.9),
            ("too-weak", "global", "related_to", 0.2),
            ("wrong-relation", "global", "mentions", 0.9),
            ("wrong-namespace", "other", "related_to", 0.9),
        ];
        for &(mem_name, rel_ns, relation, weight) in &edges {
            let mem_id = insert_memory(&conn, mem_name, "global");
            let ent_id = insert_entity(&conn, &format!("ent-{mem_name}"), "global");
            link_memory_entity(&conn, mem_id, ent_id);
            insert_relationship(&conn, rel_ns, ent_seed, ent_id, relation, weight);
        }

        let result = traverse_related(
            &conn,
            seed_id,
            &[ent_seed],
            "global",
            1,
            0.5,
            Some("related_to"),
            10,
        )
        .expect("traverse_related failed");

        assert_eq!(names(&result), vec!["kept"]);
    }

    #[test]
    fn traverse_related_orders_by_hop_then_descending_weight() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);

        let light_mem = insert_memory(&conn, "light", "global");
        let heavy_mem = insert_memory(&conn, "heavy", "global");
        let far_mem = insert_memory(&conn, "far", "global");
        let ent_light = insert_entity(&conn, "ent-light", "global");
        let ent_heavy = insert_entity(&conn, "ent-heavy", "global");
        let ent_far = insert_entity(&conn, "ent-far", "global");
        link_memory_entity(&conn, light_mem, ent_light);
        link_memory_entity(&conn, heavy_mem, ent_heavy);
        link_memory_entity(&conn, far_mem, ent_far);
        insert_relationship(&conn, "global", ent_seed, ent_light, "related_to", 0.3);
        insert_relationship(&conn, "global", ent_seed, ent_heavy, "related_to", 0.9);
        insert_relationship(&conn, "global", ent_heavy, ent_far, "related_to", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_seed], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert_eq!(names(&result), vec!["heavy", "light", "far"]);
        let hops: Vec<u32> = result.iter().map(|m| m.hop_distance).collect();
        assert_eq!(hops, vec![1, 1, 2]);

        let one_hop = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");
        assert_eq!(names(&one_hop), vec!["heavy", "light"]);
    }

    #[test]
    fn traverse_related_skips_soft_deleted_memories() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "deleted-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);
        soft_delete_memory(&conn, neighbor_id);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(result.is_empty(), "soft-deleted memories must not be returned");
    }

    #[test]
    fn traverse_related_never_returns_the_seed_memory() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let neighbor_id = insert_memory(&conn, "neighbor-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        // The seed memory is also linked to the entity reached at hop 1.
        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, seed_id, ent_b);
        link_memory_entity(&conn, neighbor_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");
        assert_eq!(names(&result), vec!["neighbor-mem"]);
    }

    #[test]
    fn traverse_related_reports_each_memory_once_at_its_closest_hop() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let shared_id = insert_memory(&conn, "shared", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        let ent_near = insert_entity(&conn, "ent-near", "global");
        let ent_far = insert_entity(&conn, "ent-far", "global");
        link_memory_entity(&conn, seed_id, ent_seed);
        link_memory_entity(&conn, shared_id, ent_near);
        link_memory_entity(&conn, shared_id, ent_far);
        insert_relationship(&conn, "global", ent_seed, ent_near, "related_to", 1.0);
        insert_relationship(&conn, "global", ent_near, ent_far, "related_to", 1.0);

        let result = traverse_related(&conn, seed_id, &[ent_seed], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert_eq!(names(&result), vec!["shared"]);
        assert_eq!(result[0].hop_distance, 1);
    }

    #[test]
    fn related_memory_optional_null_fields_serialized() {
        let mem = RelatedMemory {
            memory_id: 99,
            name: "no-relation".to_string(),
            namespace: "ns".to_string(),
            memory_type: "concept".to_string(),
            description: "".to_string(),
            hop_distance: 2,
            source_entity: None,
            target_entity: None,
            relation: None,
            weight: None,
        };

        let json = serde_json::to_value(&mem).expect("serialization failed");
        assert!(json["source_entity"].is_null());
        assert!(json["target_entity"].is_null());
        assert!(json["relation"].is_null());
        assert!(json["weight"].is_null());
        assert_eq!(json["hop_distance"], 2);
    }
}