1use crate::cli::RelationKind;
2use crate::constants::{
3 DEFAULT_K_RECALL, DEFAULT_MAX_HOPS, DEFAULT_MIN_WEIGHT, TEXT_DESCRIPTION_PREVIEW_LEN,
4};
5use crate::errors::AppError;
6use crate::i18n::erros;
7use crate::output::{self, OutputFormat};
8use crate::paths::AppPaths;
9use crate::storage::connection::open_ro;
10use rusqlite::{params, Connection};
11use serde::Serialize;
12use std::collections::{HashMap, HashSet, VecDeque};
13
/// One relationship edge adjacent to an entity, as produced by `fetch_neighbours`:
/// `(neighbour_entity_id, source_entity_name, target_entity_name, relation, weight)`.
///
/// `neighbour_entity_id` is the entity on the *other* end of the edge from the
/// entity being expanded. The two names always follow the stored edge direction
/// (`relationships.source_id` -> `relationships.target_id`), whichever direction
/// the edge was traversed in.
type Neighbour = (i64, String, String, String, f64);
17
// Command-line arguments for the `related` command, which finds memories
// connected to a seed memory through its entities and their relationships.
//
// NOTE: plain `//` comments are used on purpose. `///` doc comments on a clap
// derive struct/field become `--help` text and would change the CLI output.
#[derive(clap::Args)]
pub struct RelatedArgs {
    // Seed memory name given positionally; mutually exclusive with `--name`.
    #[arg(value_name = "NAME", conflicts_with = "name")]
    pub name_positional: Option<String>,
    // Seed memory name given as a flag; mutually exclusive with the positional form.
    #[arg(long)]
    pub name: Option<String>,
    // Maximum number of relationship hops to follow from the seed memory's
    // entities (also accepted as `--hops`). A value of 0 yields no results.
    #[arg(long, alias = "hops", default_value_t = DEFAULT_MAX_HOPS)]
    pub max_hops: u32,
    // Optional filter: only relationships of this kind are followed.
    #[arg(long, value_enum)]
    pub relation: Option<RelationKind>,
    // Relationships whose weight is below this threshold are not followed.
    #[arg(long, default_value_t = DEFAULT_MIN_WEIGHT)]
    pub min_weight: f64,
    // Upper bound on the number of related memories reported.
    #[arg(long, default_value_t = DEFAULT_K_RECALL)]
    pub limit: usize,
    // Namespace to search; passed to `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // Output format: json (default), text or markdown.
    #[arg(long, value_enum, default_value = "json")]
    pub format: OutputFormat,
    // Accepted but never read by `run`.
    // NOTE(review): the help text says JSON is always emitted, yet `run` emits
    // text/markdown lines when `--format` selects them; consider rewording.
    #[arg(long, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    // Optional database path override (also read from SQLITE_GRAPHRAG_DB_PATH);
    // passed to `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}
44
/// JSON envelope printed by the `related` command when `--format json` is selected.
#[derive(Serialize)]
struct RelatedResponse {
    /// Related memories, ordered by hop distance and then by edge weight (descending).
    results: Vec<RelatedMemory>,
    /// Milliseconds elapsed since the command started, measured just before output.
    elapsed_ms: u64,
}
50
/// A memory reached by graph traversal from the seed memory.
#[derive(Serialize, Clone)]
struct RelatedMemory {
    /// Row id in the `memories` table.
    memory_id: i64,
    name: String,
    namespace: String,
    /// Memory type; serialized under the JSON key `"type"`.
    #[serde(rename = "type")]
    memory_type: String,
    /// Full description; text and markdown output only print a truncated preview.
    description: String,
    /// Number of relationship hops from a seed entity to the entity through
    /// which this memory was found (>= 1 for traversal results).
    hop_distance: u32,
    /// Source entity name of the edge through which the linking entity was first reached.
    source_entity: Option<String>,
    /// Target entity name of that same edge.
    target_entity: Option<String>,
    /// Relation label of that edge.
    relation: Option<String>,
    /// Weight of that edge.
    weight: Option<f64>,
}
65
66pub fn run(args: RelatedArgs) -> Result<(), AppError> {
67 let inicio = std::time::Instant::now();
68 let name = args
69 .name_positional
70 .as_deref()
71 .or(args.name.as_deref())
72 .ok_or_else(|| {
73 AppError::Validation(
74 "name required: pass as positional argument or via --name".to_string(),
75 )
76 })?
77 .to_string();
78
79 if name.trim().is_empty() {
80 return Err(AppError::Validation("name must not be empty".to_string()));
81 }
82
83 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
84 let paths = AppPaths::resolve(args.db.as_deref())?;
85
86 if !paths.db.exists() {
87 return Err(AppError::NotFound(erros::banco_nao_encontrado(
88 &paths.db.display().to_string(),
89 )));
90 }
91
92 let conn = open_ro(&paths.db)?;
93
94 let seed_id: i64 = match conn.query_row(
96 "SELECT id FROM memories
97 WHERE namespace = ?1 AND name = ?2 AND deleted_at IS NULL",
98 params![namespace, name],
99 |r| r.get(0),
100 ) {
101 Ok(id) => id,
102 Err(rusqlite::Error::QueryReturnedNoRows) => {
103 return Err(AppError::NotFound(erros::memoria_nao_encontrada(
104 &name, &namespace,
105 )));
106 }
107 Err(e) => return Err(AppError::Database(e)),
108 };
109
110 let seed_entity_ids: Vec<i64> = {
112 let mut stmt =
113 conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
114 let rows: Vec<i64> = stmt
115 .query_map(params![seed_id], |r| r.get(0))?
116 .collect::<Result<Vec<i64>, _>>()?;
117 rows
118 };
119
120 let relation_filter = args.relation.map(|r| r.as_str().to_string());
121 let results = traverse_related(
122 &conn,
123 seed_id,
124 &seed_entity_ids,
125 &namespace,
126 args.max_hops,
127 args.min_weight,
128 relation_filter.as_deref(),
129 args.limit,
130 )?;
131
132 match args.format {
133 OutputFormat::Json => output::emit_json(&RelatedResponse {
134 results,
135 elapsed_ms: inicio.elapsed().as_millis() as u64,
136 })?,
137 OutputFormat::Text => {
138 for item in &results {
139 if item.description.is_empty() {
140 output::emit_text(&format!(
141 "{}. {} ({})",
142 item.hop_distance, item.name, item.namespace
143 ));
144 } else {
145 let preview: String = item
146 .description
147 .chars()
148 .take(TEXT_DESCRIPTION_PREVIEW_LEN)
149 .collect();
150 output::emit_text(&format!(
151 "{}. {} ({}): {}",
152 item.hop_distance, item.name, item.namespace, preview
153 ));
154 }
155 }
156 }
157 OutputFormat::Markdown => {
158 for item in &results {
159 if item.description.is_empty() {
160 output::emit_text(&format!(
161 "- **{}** ({}) — hop {}",
162 item.name, item.namespace, item.hop_distance
163 ));
164 } else {
165 let preview: String = item
166 .description
167 .chars()
168 .take(TEXT_DESCRIPTION_PREVIEW_LEN)
169 .collect();
170 output::emit_text(&format!(
171 "- **{}** ({}) — hop {}: {}",
172 item.name, item.namespace, item.hop_distance, preview
173 ));
174 }
175 }
176 }
177 }
178
179 Ok(())
180}
181
182#[allow(clippy::too_many_arguments)]
183fn traverse_related(
184 conn: &Connection,
185 seed_memory_id: i64,
186 seed_entity_ids: &[i64],
187 namespace: &str,
188 max_hops: u32,
189 min_weight: f64,
190 relation_filter: Option<&str>,
191 limit: usize,
192) -> Result<Vec<RelatedMemory>, AppError> {
193 if seed_entity_ids.is_empty() || max_hops == 0 {
194 return Ok(Vec::new());
195 }
196
197 let mut visited: HashSet<i64> = seed_entity_ids.iter().copied().collect();
200 let mut entity_hop: HashMap<i64, u32> = HashMap::new();
201 for &e in seed_entity_ids {
202 entity_hop.insert(e, 0);
203 }
204 let mut entity_edge: HashMap<i64, (String, String, String, f64)> = HashMap::new();
207
208 let mut queue: VecDeque<i64> = seed_entity_ids.iter().copied().collect();
209
210 while let Some(current_entity) = queue.pop_front() {
211 let current_hop = *entity_hop.get(¤t_entity).unwrap_or(&0);
212 if current_hop >= max_hops {
213 continue;
214 }
215
216 let neighbours =
217 fetch_neighbours(conn, current_entity, namespace, min_weight, relation_filter)?;
218
219 for (neighbour_id, source_name, target_name, relation, weight) in neighbours {
220 if visited.insert(neighbour_id) {
221 entity_hop.insert(neighbour_id, current_hop + 1);
222 entity_edge.insert(neighbour_id, (source_name, target_name, relation, weight));
223 queue.push_back(neighbour_id);
224 }
225 }
226 }
227
228 let mut out: Vec<RelatedMemory> = Vec::new();
230 let mut dedup_ids: HashSet<i64> = HashSet::new();
231 dedup_ids.insert(seed_memory_id);
232
233 let mut ordered_entities: Vec<(i64, u32)> = entity_hop
235 .iter()
236 .filter(|(id, _)| !seed_entity_ids.contains(id))
237 .map(|(id, hop)| (*id, *hop))
238 .collect();
239 ordered_entities.sort_by(|a, b| {
240 let weight_a = entity_edge.get(&a.0).map(|e| e.3).unwrap_or(0.0);
241 let weight_b = entity_edge.get(&b.0).map(|e| e.3).unwrap_or(0.0);
242 a.1.cmp(&b.1).then_with(|| {
243 weight_b
244 .partial_cmp(&weight_a)
245 .unwrap_or(std::cmp::Ordering::Equal)
246 })
247 });
248
249 for (entity_id, hop) in ordered_entities {
250 let mut stmt = conn.prepare_cached(
251 "SELECT m.id, m.name, m.namespace, m.type, m.description
252 FROM memory_entities me
253 JOIN memories m ON m.id = me.memory_id
254 WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
255 )?;
256 let rows = stmt
257 .query_map(params![entity_id], |r| {
258 Ok((
259 r.get::<_, i64>(0)?,
260 r.get::<_, String>(1)?,
261 r.get::<_, String>(2)?,
262 r.get::<_, String>(3)?,
263 r.get::<_, String>(4)?,
264 ))
265 })?
266 .collect::<Result<Vec<_>, _>>()?;
267
268 for (mid, name, ns, mtype, desc) in rows {
269 if !dedup_ids.insert(mid) {
270 continue;
271 }
272 let edge = entity_edge.get(&entity_id);
273 out.push(RelatedMemory {
274 memory_id: mid,
275 name,
276 namespace: ns,
277 memory_type: mtype,
278 description: desc,
279 hop_distance: hop,
280 source_entity: edge.map(|e| e.0.clone()),
281 target_entity: edge.map(|e| e.1.clone()),
282 relation: edge.map(|e| e.2.clone()),
283 weight: edge.map(|e| e.3),
284 });
285 if out.len() >= limit {
286 return Ok(out);
287 }
288 }
289 }
290
291 Ok(out)
292}
293
294fn fetch_neighbours(
295 conn: &Connection,
296 entity_id: i64,
297 namespace: &str,
298 min_weight: f64,
299 relation_filter: Option<&str>,
300) -> Result<Vec<Neighbour>, AppError> {
301 let base_sql = "\
304 SELECT r.target_id, se.name, te.name, r.relation, r.weight
305 FROM relationships r
306 JOIN entities se ON se.id = r.source_id
307 JOIN entities te ON te.id = r.target_id
308 WHERE r.source_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
309
310 let reverse_sql = "\
311 SELECT r.source_id, se.name, te.name, r.relation, r.weight
312 FROM relationships r
313 JOIN entities se ON se.id = r.source_id
314 JOIN entities te ON te.id = r.target_id
315 WHERE r.target_id = ?1 AND r.weight >= ?2 AND r.namespace = ?3";
316
317 let mut results: Vec<Neighbour> = Vec::new();
318
319 let forward_sql = match relation_filter {
320 Some(_) => format!("{base_sql} AND r.relation = ?4"),
321 None => base_sql.to_string(),
322 };
323 let rev_sql = match relation_filter {
324 Some(_) => format!("{reverse_sql} AND r.relation = ?4"),
325 None => reverse_sql.to_string(),
326 };
327
328 let mut stmt = conn.prepare_cached(&forward_sql)?;
329 let rows: Vec<_> = if let Some(rel) = relation_filter {
330 stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
331 Ok((
332 r.get::<_, i64>(0)?,
333 r.get::<_, String>(1)?,
334 r.get::<_, String>(2)?,
335 r.get::<_, String>(3)?,
336 r.get::<_, f64>(4)?,
337 ))
338 })?
339 .collect::<Result<Vec<_>, _>>()?
340 } else {
341 stmt.query_map(params![entity_id, min_weight, namespace], |r| {
342 Ok((
343 r.get::<_, i64>(0)?,
344 r.get::<_, String>(1)?,
345 r.get::<_, String>(2)?,
346 r.get::<_, String>(3)?,
347 r.get::<_, f64>(4)?,
348 ))
349 })?
350 .collect::<Result<Vec<_>, _>>()?
351 };
352 results.extend(rows);
353
354 let mut stmt = conn.prepare_cached(&rev_sql)?;
355 let rows: Vec<_> = if let Some(rel) = relation_filter {
356 stmt.query_map(params![entity_id, min_weight, namespace, rel], |r| {
357 Ok((
358 r.get::<_, i64>(0)?,
359 r.get::<_, String>(1)?,
360 r.get::<_, String>(2)?,
361 r.get::<_, String>(3)?,
362 r.get::<_, f64>(4)?,
363 ))
364 })?
365 .collect::<Result<Vec<_>, _>>()?
366 } else {
367 stmt.query_map(params![entity_id, min_weight, namespace], |r| {
368 Ok((
369 r.get::<_, i64>(0)?,
370 r.get::<_, String>(1)?,
371 r.get::<_, String>(2)?,
372 r.get::<_, String>(3)?,
373 r.get::<_, f64>(4)?,
374 ))
375 })?
376 .collect::<Result<Vec<_>, _>>()?
377 };
378 results.extend(rows);
379
380 Ok(results)
381}
382
#[cfg(test)]
mod testes {
    use super::*;

    /// Creates an in-memory SQLite database holding only the tables and
    /// columns that `traverse_related` and `fetch_neighbours` query.
    fn setup_related_db() -> rusqlite::Connection {
        let conn = rusqlite::Connection::open_in_memory().expect("falha ao abrir banco em memória");
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                namespace TEXT NOT NULL DEFAULT 'global',
                type TEXT NOT NULL DEFAULT 'fact',
                description TEXT NOT NULL DEFAULT '',
                deleted_at INTEGER
            );
            CREATE TABLE entities (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                name TEXT NOT NULL
            );
            CREATE TABLE relationships (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                namespace TEXT NOT NULL,
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related_to',
                weight REAL NOT NULL DEFAULT 1.0
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );",
        )
        .expect("falha ao criar tabelas de teste");
        conn
    }

    /// Inserts a live (not soft-deleted) memory and returns its row id.
    fn insert_memory(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO memories (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("falha ao inserir memória");
        conn.last_insert_rowid()
    }

    /// Inserts an entity and returns its row id.
    fn insert_entity(conn: &rusqlite::Connection, name: &str, namespace: &str) -> i64 {
        conn.execute(
            "INSERT INTO entities (name, namespace) VALUES (?1, ?2)",
            rusqlite::params![name, namespace],
        )
        .expect("falha ao inserir entidade");
        conn.last_insert_rowid()
    }

    /// Links a memory to an entity through `memory_entities`.
    fn link_memory_entity(conn: &rusqlite::Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            rusqlite::params![memory_id, entity_id],
        )
        .expect("falha ao vincular memória-entidade");
    }

    /// Inserts a directed relationship `source_id -> target_id`.
    fn insert_relationship(
        conn: &rusqlite::Connection,
        namespace: &str,
        source_id: i64,
        target_id: i64,
        relation: &str,
        weight: f64,
    ) {
        conn.execute(
            "INSERT INTO relationships (namespace, source_id, target_id, relation, weight)
             VALUES (?1, ?2, ?3, ?4, ?5)",
            rusqlite::params![namespace, source_id, target_id, relation, weight],
        )
        .expect("falha ao inserir relacionamento");
    }

    /// Soft-deletes a memory by setting `deleted_at`.
    fn soft_delete_memory(conn: &rusqlite::Connection, memory_id: i64) {
        conn.execute(
            "UPDATE memories SET deleted_at = 1 WHERE id = ?1",
            rusqlite::params![memory_id],
        )
        .expect("failed to soft-delete memory");
    }

    // The response envelope exposes `results`, `elapsed_ms` and the renamed `type` key.
    #[test]
    fn related_response_serializa_results_e_elapsed_ms() {
        let resp = RelatedResponse {
            results: vec![RelatedMemory {
                memory_id: 1,
                name: "mem-vizinha".to_string(),
                namespace: "global".to_string(),
                memory_type: "fact".to_string(),
                description: "desc".to_string(),
                hop_distance: 1,
                source_entity: Some("entidade-a".to_string()),
                target_entity: Some("entidade-b".to_string()),
                relation: Some("related_to".to_string()),
                weight: Some(0.9),
            }],
            elapsed_ms: 7,
        };

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        assert!(json["results"].is_array());
        assert_eq!(json["results"].as_array().unwrap().len(), 1);
        assert_eq!(json["elapsed_ms"], 7u64);
        assert_eq!(json["results"][0]["type"], "fact");
        assert_eq!(json["results"][0]["hop_distance"], 1);
    }

    // No seed entities means nothing to traverse.
    #[test]
    fn traverse_related_retorna_vazio_sem_entidades_seed() {
        let conn = setup_related_db();
        let resultado = traverse_related(&conn, 1, &[], "global", 2, 0.0, None, 10)
            .expect("traverse_related falhou");
        assert!(
            resultado.is_empty(),
            "sem entidades seed deve retornar vazio"
        );
    }

    // max_hops = 0 must not expand any entity.
    #[test]
    fn traverse_related_retorna_vazio_com_max_hops_zero() {
        let conn = setup_related_db();
        let mem_id = insert_memory(&conn, "seed-mem", "global");
        let ent_id = insert_entity(&conn, "ent-a", "global");
        link_memory_entity(&conn, mem_id, ent_id);

        let resultado = traverse_related(&conn, mem_id, &[ent_id], "global", 0, 0.0, None, 10)
            .expect("traverse_related falhou");
        assert!(resultado.is_empty(), "max_hops=0 deve retornar vazio");
    }

    // A single outgoing edge leads to the neighbouring memory at hop 1.
    #[test]
    fn traverse_related_descobre_memoria_vizinha_por_grafo() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let vizinha_id = insert_memory(&conn, "vizinha-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");

        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, vizinha_id, ent_b);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);

        let resultado = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related falhou");

        assert_eq!(resultado.len(), 1, "deve encontrar 1 memória vizinha");
        assert_eq!(resultado[0].name, "vizinha-mem");
        assert_eq!(resultado[0].hop_distance, 1);
    }

    // Five reachable memories with limit = 3 must yield exactly three.
    #[test]
    fn traverse_related_respeita_limite() {
        let conn = setup_related_db();

        let seed_id = insert_memory(&conn, "seed", "global");
        let ent_seed = insert_entity(&conn, "ent-seed", "global");
        link_memory_entity(&conn, seed_id, ent_seed);

        for i in 0..5 {
            let mem_id = insert_memory(&conn, &format!("vizinha-{i}"), "global");
            let ent_id = insert_entity(&conn, &format!("ent-{i}"), "global");
            link_memory_entity(&conn, mem_id, ent_id);
            insert_relationship(&conn, "global", ent_seed, ent_id, "related_to", 1.0);
        }

        let resultado = traverse_related(&conn, seed_id, &[ent_seed], "global", 1, 0.0, None, 3)
            .expect("traverse_related falhou");

        assert_eq!(
            resultado.len(),
            3,
            "limit=3 with 5 reachable memories must return exactly 3 results"
        );
    }

    // Optional edge fields serialize as JSON null when absent.
    #[test]
    fn related_memory_campos_opcionais_nulos_serializados() {
        let mem = RelatedMemory {
            memory_id: 99,
            name: "sem-relacao".to_string(),
            namespace: "ns".to_string(),
            memory_type: "concept".to_string(),
            description: "".to_string(),
            hop_distance: 2,
            source_entity: None,
            target_entity: None,
            relation: None,
            weight: None,
        };

        let json = serde_json::to_value(&mem).expect("serialização falhou");
        assert!(json["source_entity"].is_null());
        assert!(json["target_entity"].is_null());
        assert!(json["relation"].is_null());
        assert!(json["weight"].is_null());
        assert_eq!(json["hop_distance"], 2);
    }

    // Edges pointing *into* a seed entity are followed too, and the reported
    // source/target names keep the stored edge direction.
    #[test]
    fn traverse_related_follows_incoming_edges() {
        let conn = setup_related_db();
        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let other_id = insert_memory(&conn, "other-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, other_id, ent_b);
        insert_relationship(&conn, "global", ent_b, ent_a, "related_to", 1.0);

        let found = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");

        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name, "other-mem");
        assert_eq!(found[0].source_entity.as_deref(), Some("ent-b"));
        assert_eq!(found[0].target_entity.as_deref(), Some("ent-a"));
    }

    // a -> b -> c: the memory on c is two hops away and is cut off by max_hops = 1.
    #[test]
    fn traverse_related_counts_hops_and_respects_max_hops() {
        let conn = setup_related_db();
        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let far_id = insert_memory(&conn, "far-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        let ent_c = insert_entity(&conn, "ent-c", "global");
        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, far_id, ent_c);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);
        insert_relationship(&conn, "global", ent_b, ent_c, "related_to", 1.0);

        let two_hops = traverse_related(&conn, seed_id, &[ent_a], "global", 2, 0.0, None, 10)
            .expect("traverse_related failed");
        assert_eq!(two_hops.len(), 1);
        assert_eq!(two_hops[0].name, "far-mem");
        assert_eq!(two_hops[0].hop_distance, 2);

        let one_hop = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");
        assert!(one_hop.is_empty(), "c is beyond max_hops=1");
    }

    // Relation and minimum-weight filters restrict which edges are followed.
    #[test]
    fn traverse_related_applies_relation_and_weight_filters() {
        let conn = setup_related_db();
        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let uses_id = insert_memory(&conn, "uses-mem", "global");
        let weak_id = insert_memory(&conn, "weak-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        let ent_c = insert_entity(&conn, "ent-c", "global");
        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, uses_id, ent_b);
        link_memory_entity(&conn, weak_id, ent_c);
        insert_relationship(&conn, "global", ent_a, ent_b, "uses", 1.0);
        insert_relationship(&conn, "global", ent_a, ent_c, "related_to", 0.2);

        let by_relation =
            traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, Some("uses"), 10)
                .expect("traverse_related failed");
        assert_eq!(by_relation.len(), 1);
        assert_eq!(by_relation[0].name, "uses-mem");
        assert_eq!(by_relation[0].relation.as_deref(), Some("uses"));

        let other_relation =
            traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, Some("related_to"), 10)
                .expect("traverse_related failed");
        assert_eq!(other_relation.len(), 1);
        assert_eq!(other_relation[0].name, "weak-mem");

        let by_weight = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.5, None, 10)
            .expect("traverse_related failed");
        assert_eq!(by_weight.len(), 1);
        assert_eq!(by_weight[0].name, "uses-mem");
    }

    // Equal hops are ordered by edge weight (heaviest first), and soft-deleted
    // memories are never returned.
    #[test]
    fn traverse_related_orders_by_weight_and_skips_deleted() {
        let conn = setup_related_db();
        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let light_id = insert_memory(&conn, "light-mem", "global");
        let heavy_id = insert_memory(&conn, "heavy-mem", "global");
        let deleted_id = insert_memory(&conn, "deleted-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        let ent_c = insert_entity(&conn, "ent-c", "global");
        let ent_d = insert_entity(&conn, "ent-d", "global");
        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, light_id, ent_b);
        link_memory_entity(&conn, heavy_id, ent_c);
        link_memory_entity(&conn, deleted_id, ent_d);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 0.5);
        insert_relationship(&conn, "global", ent_a, ent_c, "related_to", 0.9);
        insert_relationship(&conn, "global", ent_a, ent_d, "related_to", 1.0);
        soft_delete_memory(&conn, deleted_id);

        let found = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");
        let names: Vec<&str> = found.iter().map(|m| m.name.as_str()).collect();
        assert_eq!(names, vec!["heavy-mem", "light-mem"]);
    }

    // The seed memory is never reported, and a memory reachable through two
    // entities is reported once, via the heaviest edge.
    #[test]
    fn traverse_related_excludes_seed_and_deduplicates() {
        let conn = setup_related_db();
        let seed_id = insert_memory(&conn, "seed-mem", "global");
        let shared_id = insert_memory(&conn, "shared-mem", "global");
        let ent_a = insert_entity(&conn, "ent-a", "global");
        let ent_b = insert_entity(&conn, "ent-b", "global");
        let ent_c = insert_entity(&conn, "ent-c", "global");
        link_memory_entity(&conn, seed_id, ent_a);
        link_memory_entity(&conn, seed_id, ent_b);
        link_memory_entity(&conn, shared_id, ent_b);
        link_memory_entity(&conn, shared_id, ent_c);
        insert_relationship(&conn, "global", ent_a, ent_b, "related_to", 1.0);
        insert_relationship(&conn, "global", ent_a, ent_c, "related_to", 0.5);

        let found = traverse_related(&conn, seed_id, &[ent_a], "global", 1, 0.0, None, 10)
            .expect("traverse_related failed");
        assert_eq!(found.len(), 1, "seed excluded and shared memory reported once");
        assert_eq!(found[0].name, "shared-mem");
        assert_eq!(found[0].weight, Some(1.0));
    }
}
578}