1use crate::cli::RelationKind;
4use crate::constants::{
5 DEFAULT_K_RECALL, DEFAULT_MAX_HOPS, DEFAULT_MIN_WEIGHT, TEXT_DESCRIPTION_PREVIEW_LEN,
6};
7use crate::errors::AppError;
8use crate::i18n::errors_msg;
9use crate::output::{self, OutputFormat};
10use crate::paths::AppPaths;
11use crate::storage::connection::open_ro;
12use rusqlite::{params, Connection};
13use serde::Serialize;
14use std::collections::{HashMap, HashSet, VecDeque};
15
/// One edge touching a queried entity, as returned by `fetch_neighbours`:
/// `(neighbour entity id, source entity name, target entity name, relation, weight)`.
///
/// The neighbour id is the entity at the *other* end of the edge from the one
/// queried, while the two names keep the edge's stored source -> target orientation.
type Neighbour = (i64, String, String, String, f64);
19
// CLI arguments consumed by `run`, which lists memories connected to a seed
// memory through the entity/relationship graph.
//
// Plain `//` comments are used on purpose: clap derive turns `///` doc comments
// on fields into `--help` text, so these notes stay out of the CLI output.
#[derive(clap::Args)]
pub struct RelatedArgs {
    // Seed memory name given positionally; mutually exclusive with `--name`.
    #[arg(value_name = "NAME", conflicts_with = "name")]
    pub name_positional: Option<String>,
    // Seed memory name given as `--name`; `run` requires one of the two forms.
    #[arg(long)]
    pub name: Option<String>,
    // Maximum number of graph hops walked from the seed's entities (alias `--hops`).
    // A value of 0 makes `traverse_related` return no results.
    #[arg(long, alias = "hops", default_value_t = DEFAULT_MAX_HOPS)]
    pub max_hops: u32,
    // When set, only relationships of this kind are followed.
    #[arg(long, value_enum)]
    pub relation: Option<RelationKind>,
    // Relationships whose weight is below this value are not followed.
    #[arg(long, default_value_t = DEFAULT_MIN_WEIGHT)]
    pub min_weight: f64,
    // Maximum number of related memories returned.
    #[arg(long, default_value_t = DEFAULT_K_RECALL)]
    pub limit: usize,
    // Namespace used both to look up the seed memory and to filter relationships;
    // passed through `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // Output format; `json` unless overridden.
    #[arg(long, value_enum, default_value = "json")]
    pub format: OutputFormat,
    // Hidden flag that `run` never reads.
    // NOTE(review): its help text says JSON is always emitted, but
    // `--format text|markdown` produce non-JSON output — confirm intended wording.
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    // Database path; may also come from the SQLITE_GRAPHRAG_DB_PATH environment
    // variable and is resolved through `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
46
/// JSON payload emitted by `run` when `--format json` is selected.
#[derive(Serialize)]
struct RelatedResponse {
    /// Related memories, in the order produced by `traverse_related`.
    results: Vec<RelatedMemory>,
    /// Milliseconds elapsed since `run` started, measured when this payload is built.
    elapsed_ms: u64,
}
52
/// One memory reached from the seed memory through the entity graph.
///
/// The optional edge fields serialize as JSON `null` when absent.
#[derive(Serialize, Clone)]
struct RelatedMemory {
    /// `memories.id` of the related memory.
    memory_id: i64,
    /// `memories.name`.
    name: String,
    /// `memories.namespace`.
    namespace: String,
    /// `memories.type`; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    memory_type: String,
    /// `memories.description` (full text; text/markdown output only shows a preview).
    description: String,
    /// Hops from the seed memory's entities to the entity this memory is linked to.
    hop_distance: u32,
    /// Source entity name of the edge through which that entity was first reached.
    source_entity: Option<String>,
    /// Target entity name of that edge.
    target_entity: Option<String>,
    /// Relation label of that edge.
    relation: Option<String>,
    /// Weight of that edge.
    weight: Option<f64>,
}
67
68pub fn run(args: RelatedArgs) -> Result<(), AppError> {
69 let inicio = std::time::Instant::now();
70 let name = args
71 .name_positional
72 .as_deref()
73 .or(args.name.as_deref())
74 .ok_or_else(|| {
75 AppError::Validation(
76 "name required: pass as positional argument or via --name".to_string(),
77 )
78 })?
79 .to_string();
80
81 if name.trim().is_empty() {
82 return Err(AppError::Validation("name must not be empty".to_string()));
83 }
84
85 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
86 let paths = AppPaths::resolve(args.db.as_deref())?;
87
88 if !paths.db.exists() {
89 return Err(AppError::NotFound(errors_msg::database_not_found(
90 &paths.db.display().to_string(),
91 )));
92 }
93
94 let conn = open_ro(&paths.db)?;
95
96 let seed_id: i64 = match conn.query_row(
98 "SELECT id FROM memories
99 WHERE namespace = ?1 AND name = ?2 AND deleted_at IS NULL",
100 params![namespace, name],
101 |r| r.get(0),
102 ) {
103 Ok(id) => id,
104 Err(rusqlite::Error::QueryReturnedNoRows) => {
105 return Err(AppError::NotFound(errors_msg::memory_not_found(
106 &name, &namespace,
107 )));
108 }
109 Err(e) => return Err(AppError::Database(e)),
110 };
111
112 let seed_entity_ids: Vec<i64> = {
114 let mut stmt =
115 conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
116 let rows: Vec<i64> = stmt
117 .query_map(params![seed_id], |r| r.get(0))?
118 .collect::<Result<Vec<i64>, _>>()?;
119 rows
120 };
121
122 let relation_filter = args.relation.map(|r| r.as_str().to_string());
123 let results = traverse_related(
124 &conn,
125 seed_id,
126 &seed_entity_ids,
127 &namespace,
128 args.max_hops,
129 args.min_weight,
130 relation_filter.as_deref(),
131 args.limit,
132 )?;
133
134 match args.format {
135 OutputFormat::Json => output::emit_json(&RelatedResponse {
136 results,
137 elapsed_ms: inicio.elapsed().as_millis() as u64,
138 })?,
139 OutputFormat::Text => {
140 for item in &results {
141 if item.description.is_empty() {
142 output::emit_text(&format!(
143 "{}. {} ({})",
144 item.hop_distance, item.name, item.namespace
145 ));
146 } else {
147 let preview: String = item
148 .description
149 .chars()
150 .take(TEXT_DESCRIPTION_PREVIEW_LEN)
151 .collect();
152 output::emit_text(&format!(
153 "{}. {} ({}): {}",
154 item.hop_distance, item.name, item.namespace, preview
155 ));
156 }
157 }
158 }
159 OutputFormat::Markdown => {
160 for item in &results {
161 if item.description.is_empty() {
162 output::emit_text(&format!(
163 "- **{}** ({}) — hop {}",
164 item.name, item.namespace, item.hop_distance
165 ));
166 } else {
167 let preview: String = item
168 .description
169 .chars()
170 .take(TEXT_DESCRIPTION_PREVIEW_LEN)
171 .collect();
172 output::emit_text(&format!(
173 "- **{}** ({}) — hop {}: {}",
174 item.name, item.namespace, item.hop_distance, preview
175 ));
176 }
177 }
178 }
179 }
180
181 Ok(())
182}
183
184#[allow(clippy::too_many_arguments)]
185fn traverse_related(
186 conn: &Connection,
187 seed_memory_id: i64,
188 seed_entity_ids: &[i64],
189 namespace: &str,
190 max_hops: u32,
191 min_weight: f64,
192 relation_filter: Option<&str>,
193 limit: usize,
194) -> Result<Vec<RelatedMemory>, AppError> {
195 if seed_entity_ids.is_empty() || max_hops == 0 {
196 return Ok(Vec::new());
197 }
198
199 let mut visited: HashSet<i64> = seed_entity_ids.iter().copied().collect();
202 let mut entity_hop: HashMap<i64, u32> = HashMap::new();
203 for &e in seed_entity_ids {
204 entity_hop.insert(e, 0);
205 }
206 let mut entity_edge: HashMap<i64, (String, String, String, f64)> = HashMap::new();
209
210 let mut queue: VecDeque<i64> = seed_entity_ids.iter().copied().collect();
211
212 while let Some(current_entity) = queue.pop_front() {
213 let current_hop = *entity_hop.get(¤t_entity).unwrap_or(&0);
214 if current_hop >= max_hops {
215 continue;
216 }
217
218 let neighbours =
219 fetch_neighbours(conn, current_entity, namespace, min_weight, relation_filter)?;
220
221 for (neighbour_id, source_name, target_name, relation, weight) in neighbours {
222 if visited.insert(neighbour_id) {
223 entity_hop.insert(neighbour_id, current_hop + 1);
224 entity_edge.insert(neighbour_id, (source_name, target_name, relation, weight));
225 queue.push_back(neighbour_id);
226 }
227 }
228 }
229
230 let mut out: Vec<RelatedMemory> = Vec::new();
232 let mut dedup_ids: HashSet<i64> = HashSet::new();
233 dedup_ids.insert(seed_memory_id);
234
235 let mut ordered_entities: Vec<(i64, u32)> = entity_hop
237 .iter()
238 .filter(|(id, _)| !seed_entity_ids.contains(id))
239 .map(|(id, hop)| (*id, *hop))
240 .collect();
241 ordered_entities.sort_by(|a, b| {
242 let weight_a = entity_edge.get(&a.0).map(|e| e.3).unwrap_or(0.0);
243 let weight_b = entity_edge.get(&b.0).map(|e| e.3).unwrap_or(0.0);
244 a.1.cmp(&b.1).then_with(|| {
245 weight_b
246 .partial_cmp(&weight_a)
247 .unwrap_or(std::cmp::Ordering::Equal)
248 })
249 });
250
251 for (entity_id, hop) in ordered_entities {
252 let mut stmt = conn.prepare_cached(
253 "SELECT m.id, m.name, m.namespace, m.type, m.description
254 FROM memory_entities me
255 JOIN memories m ON m.id = me.memory_id
256 WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
257 )?;
258 let rows = stmt
259 .query_map(params![entity_id], |r| {
260 Ok((
261 r.get::<_, i64>(0)?,
262 r.get::<_, String>(1)?,
263 r.get::<_, String>(2)?,
264 r.get::<_, String>(3)?,
265 r.get::<_, String>(4)?,
266 ))
267 })?
268 .collect::<Result<Vec<_>, _>>()?;
269
270 for (mid, name, ns, mtype, desc) in rows {
271 if !dedup_ids.insert(mid) {
272 continue;
273 }
274 let edge = entity_edge.get(&entity_id);
275 out.push(RelatedMemory {
276 memory_id: mid,
277 name,
278 namespace: ns,
279 memory_type: mtype,
280 description: desc,
281 hop_distance: hop,
282 source_entity: edge.map(|e| e.0.clone()),
283 target_entity: edge.map(|e| e.1.clone()),
284 relation: edge.map(|e| e.2.clone()),
285 weight: edge.map(|e| e.3),
286 });
287 if out.len() >= limit {
288 return Ok(out);
289 }
290 }
291 }
292
293 Ok(out)
294}
295
296fn fetch_neighbours(
297 conn: &Connection,
298 entity_id: i64,
299 namespace: &str,
300 min_weight: f64,
301 relation_filter: Option<&str>,
302) -> Result<Vec<Neighbour>, AppError> {
303 let base_sql = "\
306 SELECT r.target_id, se.name, te.name, r.relation, r.weight
307 FROM relationships r
308 JOIN entities se ON se.id = r.source_id
309 JOIN entities te ON te.id = r.target_id
310 WHERE r.source_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
311
312 let reverse_sql = "\
313 SELECT r.source_id, se.name, te.name, r.relation, r.weight
314 FROM relationships r
315 JOIN entities se ON se.id = r.source_id
316 JOIN entities te ON te.id = r.target_id
317 WHERE r.target_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
318
319 let mut results: Vec<Neighbour> = Vec::new();
320
321 let forward_sql = match relation_filter {
322 Some(_) => format!("{base_sql} AND r.relation = ?4"),
323 None => base_sql.to_string(),
324 };
325 let rev_sql = match relation_filter {
326 Some(_) => format!("{reverse_sql} AND r.relation = ?4"),
327 None => reverse_sql.to_string(),
328 };
329
330 let mut stmt = conn.prepare_cached(&forward_sql)?;
331 let rows: Vec<_> = if let Some(rel) = relation_filter {
332 stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
333 Ok((
334 r.get::<_, i64>(0)?,
335 r.get::<_, String>(1)?,
336 r.get::<_, String>(2)?,
337 r.get::<_, String>(3)?,
338 r.get::<_, f64>(4)?,
339 ))
340 })?
341 .collect::<Result<Vec<_>, _>>()?
342 } else {
343 stmt.query_map(params![entity_id, min_weight, namespace], |r| {
344 Ok((
345 r.get::<_, i64>(0)?,
346 r.get::<_, String>(1)?,
347 r.get::<_, String>(2)?,
348 r.get::<_, String>(3)?,
349 r.get::<_, f64>(4)?,
350 ))
351 })?
352 .collect::<Result<Vec<_>, _>>()?
353 };
354 results.extend(rows);
355
356 let mut stmt = conn.prepare_cached(&rev_sql)?;
357 let rows: Vec<_> = if let Some(rel) = relation_filter {
358 stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
359 Ok((
360 r.get::<_, i64>(0)?,
361 r.get::<_, String>(1)?,
362 r.get::<_, String>(2)?,
363 r.get::<_, String>(3)?,
364 r.get::<_, f64>(4)?,
365 ))
366 })?
367 .collect::<Result<Vec<_>, _>>()?
368 } else {
369 stmt.query_map(params![entity_id, min_weight, namespace], |r| {
370 Ok((
371 r.get::<_, i64>(0)?,
372 r.get::<_, String>(1)?,
373 r.get::<_, String>(2)?,
374 r.get::<_, String>(3)?,
375 r.get::<_, f64>(4)?,
376 ))
377 })?
378 .collect::<Result<Vec<_>, _>>()?
379 };
380 results.extend(rows);
381
382 Ok(results)
383}
384
// Unit tests for the response serialization and for `traverse_related`, run
// against a minimal in-memory SQLite schema.
#[cfg(test)]
mod tests {
    use super::*;

    // Creates an in-memory database containing only the tables and columns that
    // `traverse_related` and `fetch_neighbours` query.
    fn setup_related_db() -> rusqlite::Connection {
        let conn = rusqlite::Connection::open_in_memory().expect("falha ao abrir banco em memória");
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                namespace TEXT NOT NULL DEFAULT 'global',
                type TEXT NOT NULL DEFAULT 'fact',
                description TEXT NOT NULL DEFAULT '',
                deleted_at INTEGER
            );
            CREATE TABLE entities (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                name TEXT NOT NULL
            );
            CREATE TABLE relationships (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related_to',
                weight REAL NOT NULL DEFAULT 1.0
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );",
        )
        .expect("falha ao criar tabelas de teste");
        conn
    }

    // Inserts a live memory (default type and description) and returns its row id.
    fn insert_memory(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO memories (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("falha ao inserir memória");
        conn.last_insert_rowid()
    }

    // Inserts an entity and returns its row id.
    fn insert_entity(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO entities (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("falha ao inserir entidade");
        conn.last_insert_rowid()
    }

    // Links a memory to an entity through the `memory_entities` join table.
    fn link_memory_entity(conn: &rusqlite::Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            rusqlite::params![memory_id, entity_id],
        )
        .expect("falha ao vincular memória-entidade");
    }

    // Inserts a directed `source_id -> target_id` relationship.
    fn insert_relationship(
        conn: &rusqlite::Connection,
        namespace: &str,
        source_id: i64,
        target_id: i64,
        relation: &str,
        weight: f64,
    ) {
        conn.execute(
            "INSERT INTO relationships (namespace, source_id, target_id, relation, weight)
             VALUES (?1, ?2, ?3, ?4, ?5)",
            rusqlite::params![namespace, source_id, target_id, relation, weight],
        )
        .expect("falha ao inserir relacionamento");
    }

    // The response serializes `results` as an array, keeps `elapsed_ms`, and
    // renames `memory_type` to the JSON key `type`.
    #[test]
    fn related_response_serializa_results_e_elapsed_ms() {
        let resp = RelatedResponse {
            results: vec![RelatedMemory {
                memory_id: 1,
                name: "mem-vizinha".to_string(),
                namespace: "global".to_string(),
                memory_type: "fact".to_string(),
                description: "desc".to_string(),
                hop_distance: 1,
                source_entity: Some("entidade-a".to_string()),
                target_entity: Some("entidade-b".to_string()),
                relation: Some("related_to".to_string()),
                weight: Some(0.9),
            }],
            elapsed_ms: 7,
        };

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        assert!(json["results"].is_array());
        assert_eq!(json["results"].as_array().unwrap().len(), 1);
        assert_eq!(json["elapsed_ms"], 7u64);
        assert_eq!(json["results"][0]["type"], "fact");
        assert_eq!(json["results"][0]["hop_distance"], 1);
    }

    // No seed entities: the traversal returns nothing (and needs no data).
    #[test]
    fn traverse_related_retorna_vazio_sem_entidades_seed() {
        let conn = setup_related_db();
        let resultado = traverse_related(&conn, 1, &[], "global", 2, 0.0, None, 10)
            .expect("traverse_related falhou");
        assert!(
            resultado.is_empty(),
            "sem entidades seed deve retornar vazio"
        );
    }

    // `max_hops == 0`: the traversal returns nothing even with a linked seed entity.
    #[test]
    fn traverse_related_retorna_vazio_com_max_hops_zero() {
        let conn = setup_related_db();
        let mem_id = insert_memory(&conn, "seed-mem", "global");
        let ent_id = insert_entity(&conn, "ent-a", "global");
        link_memory_entity(&conn, mem_id, ent_id);

        let resultado = traverse_related(&conn, mem_id, &[ent_id], "global", 0, 0.0, None, 10)
            .expect("traverse_related falhou");
        assert!(resultado.is_empty(), "max_hops=0 deve retornar vazio");
    }

    // seed-mem -> ent-a --related_to--> ent-b <- vizinha-mem: the neighbour memory
    // is found one hop away and the seed memory itself is not reported.
    #[test]
    fn traverse_related_descobre_memoria_vizinha_por_grafo() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let vizinha_id = insert_memory(&conn, "vizinha-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, vizinha_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);

        let resultado = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related falhou");

        assert_eq!(resultado.len(), 1, "deve encontrar 1 memória vizinha");
        assert_eq!(resultado[0].name, "vizinha-mem");
        assert_eq!(resultado[0].hop_distance, 1);
    }

    // Five reachable neighbour memories with `limit = 3` yield at most 3 results.
    #[test]
    fn traverse_related_respeita_limite() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);

        for i in 0..5 {
            let mem_id = insert_memory(&conn, &format!("vizinha-{i}"), "global");
            let ent_id = insert_entity(&conn, &format!("ent-{i}"), "global");
            link_memory_entity(&conn, mem_id, ent_id);
            insert_relationship(&conn, "global", ent_seed, ent_id, "related_to", 1.0);
        }

        let resultado = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.0, None, 3)
            .expect("traverse_related falhou");

        assert!(
            resultado.len() <= 3,
            "limite=3 deve restringir a no máximo 3 resultados"
        );
    }

    // `None` edge fields serialize as JSON null rather than being omitted.
    #[test]
    fn related_memory_campos_opcionais_nulos_serializados() {
        let mem = RelatedMemory {
            memory_id: 99,
            name: "sem-relacao".to_string(),
            namespace: "ns".to_string(),
            memory_type: "concept".to_string(),
            description: "".to_string(),
            hop_distance: 2,
            source_entity: None,
            target_entity: None,
            relation: None,
            weight: None,
        };

        let json = serde_json::to_value(&mem).expect("serialização falhou");
        assert!(json["source_entity"].is_null());
        assert!(json["target_entity"].is_null());
        assert!(json["relation"].is_null());
        assert!(json["weight"].is_null());
        assert_eq!(json["hop_distance"], 2);
    }
}