1use crate::cli::MemoryType;
2use crate::errors::AppError;
3use crate::graph::traverse_from_memories;
4use crate::i18n::erros;
5use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
6use crate::paths::AppPaths;
7use crate::storage::connection::open_ro;
8use crate::storage::entities;
9use crate::storage::memories;
10
// Command-line arguments of the recall command, parsed through clap's derive API.
//
// Plain `//` comments are used here on purpose: `///` doc comments on a
// clap-derived struct are picked up as help text and would alter `--help`.
#[derive(clap::Args)]
pub struct RecallArgs {
    // Positional query text. `run` rejects it when it is empty after trimming
    // and otherwise turns it into the embedding used for every search.
    pub query: String,
    // `-k` / `--k`: number of neighbours requested from the memory KNN search
    // (default 10). Echoed back in the response as `k`.
    #[arg(short = 'k', long, default_value = "10")]
    pub k: usize,
    // `--type`: optional memory-type filter, forwarded to the memory KNN
    // search as a string via `MemoryType::as_str`.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // `--namespace`: optional namespace, resolved by
    // `crate::namespace::resolve_namespace` before any search runs.
    #[arg(long)]
    pub namespace: Option<String>,
    // `--no-graph`: skips the entity search and graph traversal entirely, so
    // `graph_matches` stays empty.
    #[arg(long)]
    pub no_graph: bool,
    // `--precise`: makes `run` request 100_000 neighbours from the memory KNN
    // search instead of `k`.
    #[arg(long)]
    pub precise: bool,
    // `--max-hops`: forwarded unchanged to `traverse_from_memories` (default 2).
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    // `--min-weight`: forwarded unchanged to `traverse_from_memories` (default 0.3).
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    // `--max-graph-results N`: optional cap on how many graph matches are
    // collected; no cap when absent.
    #[arg(long, value_name = "N")]
    pub max_graph_results: Option<usize>,
    // `--max-distance` (also accepted as `--min-distance`): when set below 1.0,
    // `run` fails with a not-found error unless at least one direct match has
    // a distance less than or equal to this value. The default 1.0 disables
    // that check.
    #[arg(long, alias = "min-distance", default_value = "1.0")]
    pub max_distance: f32,
    // `--format`: parsed and then discarded by `run`; the response is emitted
    // through `output::emit_json` regardless.
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // `--db`: optional database path, also read from the
    // `SQLITE_GRAPHRAG_DB_PATH` environment variable; passed to
    // `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // `--json`: accepted but never read by `run` (see the help text).
    #[arg(long, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
}
68
69pub fn run(args: RecallArgs) -> Result<(), AppError> {
70 let start = std::time::Instant::now();
71 let _ = args.format;
72 if args.query.trim().is_empty() {
73 return Err(AppError::Validation(
74 "query não pode estar vazia".to_string(),
75 ));
76 }
77 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
78 let paths = AppPaths::resolve(args.db.as_deref())?;
79
80 if !paths.db.exists() {
81 return Err(AppError::NotFound(erros::banco_nao_encontrado(
82 &paths.db.display().to_string(),
83 )));
84 }
85
86 output::emit_progress_i18n(
87 "Computing query embedding...",
88 "Calculando embedding da consulta...",
89 );
90 let embedding = crate::daemon::embed_query_or_local(&paths.models, &args.query)?;
91
92 let conn = open_ro(&paths.db)?;
93
94 let memory_type_str = args.r#type.map(|t| t.as_str());
95 let effective_k = if args.precise { 100_000 } else { args.k };
98 let knn_results =
99 memories::knn_search(&conn, &embedding, &namespace, memory_type_str, effective_k)?;
100
101 let mut direct_matches = Vec::new();
102 let mut memory_ids: Vec<i64> = Vec::new();
103 for (memory_id, distance) in knn_results {
104 let row = {
105 let mut stmt = conn.prepare_cached(
106 "SELECT id, namespace, name, type, description, body, body_hash,
107 session_id, source, metadata, created_at, updated_at
108 FROM memories WHERE id=?1 AND deleted_at IS NULL",
109 )?;
110 stmt.query_row(rusqlite::params![memory_id], |r| {
111 Ok(memories::MemoryRow {
112 id: r.get(0)?,
113 namespace: r.get(1)?,
114 name: r.get(2)?,
115 memory_type: r.get(3)?,
116 description: r.get(4)?,
117 body: r.get(5)?,
118 body_hash: r.get(6)?,
119 session_id: r.get(7)?,
120 source: r.get(8)?,
121 metadata: r.get(9)?,
122 created_at: r.get(10)?,
123 updated_at: r.get(11)?,
124 })
125 })
126 .ok()
127 };
128 if let Some(row) = row {
129 let snippet: String = row.body.chars().take(300).collect();
130 direct_matches.push(RecallItem {
131 memory_id: row.id,
132 name: row.name,
133 namespace: row.namespace,
134 memory_type: row.memory_type,
135 description: row.description,
136 snippet,
137 distance,
138 source: "direct".to_string(),
139 graph_depth: None,
141 });
142 memory_ids.push(memory_id);
143 }
144 }
145
146 let mut graph_matches = Vec::new();
147 if !args.no_graph {
148 let entity_knn = entities::knn_search(&conn, &embedding, &namespace, 5)?;
149 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
150
151 let all_seed_ids: Vec<i64> = memory_ids
152 .iter()
153 .chain(entity_ids.iter())
154 .copied()
155 .collect();
156
157 if !all_seed_ids.is_empty() {
158 let graph_memory_ids = traverse_from_memories(
159 &conn,
160 &all_seed_ids,
161 &namespace,
162 args.min_weight,
163 args.max_hops,
164 )?;
165
166 for graph_mem_id in graph_memory_ids {
167 if let Some(cap) = args.max_graph_results {
170 if graph_matches.len() >= cap {
171 break;
172 }
173 }
174 let row = {
175 let mut stmt = conn.prepare_cached(
176 "SELECT id, namespace, name, type, description, body, body_hash,
177 session_id, source, metadata, created_at, updated_at
178 FROM memories WHERE id=?1 AND deleted_at IS NULL",
179 )?;
180 stmt.query_row(rusqlite::params![graph_mem_id], |r| {
181 Ok(memories::MemoryRow {
182 id: r.get(0)?,
183 namespace: r.get(1)?,
184 name: r.get(2)?,
185 memory_type: r.get(3)?,
186 description: r.get(4)?,
187 body: r.get(5)?,
188 body_hash: r.get(6)?,
189 session_id: r.get(7)?,
190 source: r.get(8)?,
191 metadata: r.get(9)?,
192 created_at: r.get(10)?,
193 updated_at: r.get(11)?,
194 })
195 })
196 .ok()
197 };
198 if let Some(row) = row {
199 let snippet: String = row.body.chars().take(300).collect();
200 graph_matches.push(RecallItem {
201 memory_id: row.id,
202 name: row.name,
203 namespace: row.namespace,
204 memory_type: row.memory_type,
205 description: row.description,
206 snippet,
207 distance: 0.0,
211 source: "graph".to_string(),
212 graph_depth: Some(0),
217 });
218 }
219 }
220 }
221 }
222
223 if args.max_distance < 1.0 {
225 let has_relevant = direct_matches
226 .iter()
227 .any(|item| item.distance <= args.max_distance);
228 if !has_relevant {
229 return Err(AppError::NotFound(erros::sem_resultados_recall(
230 args.max_distance,
231 &args.query,
232 &namespace,
233 )));
234 }
235 }
236
237 let results: Vec<RecallItem> = direct_matches
238 .iter()
239 .cloned()
240 .chain(graph_matches.iter().cloned())
241 .collect();
242
243 output::emit_json(&RecallResponse {
244 query: args.query,
245 k: args.k,
246 direct_matches,
247 graph_matches,
248 results,
249 elapsed_ms: start.elapsed().as_millis() as u64,
250 })?;
251
252 Ok(())
253}
254
#[cfg(test)]
mod testes {
    use crate::output::{RecallItem, RecallResponse};

    /// Builds a `RecallItem` fixture in which only `name`, `distance` and
    /// `source` vary. Items whose source is `"graph"` get
    /// `graph_depth = Some(0)`; every other source gets `None`.
    fn make_item(name: &str, distance: f32, source: &str) -> RecallItem {
        let graph_depth = match source {
            "graph" => Some(0),
            _ => None,
        };
        RecallItem {
            memory_id: 1,
            name: name.to_owned(),
            namespace: String::from("global"),
            memory_type: String::from("fact"),
            description: String::from("desc"),
            snippet: String::from("snippet"),
            distance,
            source: source.to_owned(),
            graph_depth,
        }
    }

    /// The scalar fields keep their values and the three lists serialize as
    /// JSON arrays.
    #[test]
    fn recall_response_serializa_campos_obrigatorios() {
        let item = make_item("mem-a", 0.12, "direct");
        let response = RecallResponse {
            query: String::from("rust memory"),
            k: 5,
            direct_matches: vec![item.clone()],
            graph_matches: Vec::new(),
            results: vec![item],
            elapsed_ms: 42,
        };

        let value = serde_json::to_value(&response).expect("serialização falhou");
        assert_eq!(value["query"], "rust memory");
        assert_eq!(value["k"], 5);
        assert_eq!(value["elapsed_ms"], 42u64);
        for key in ["direct_matches", "graph_matches", "results"] {
            assert!(value[key].is_array());
        }
    }

    /// The memory type is exposed under the JSON key `type`.
    #[test]
    fn recall_item_serializa_type_renomeado() {
        let value = serde_json::to_value(&make_item("mem-teste", 0.25, "direct"))
            .expect("serialização falhou");

        assert_eq!(value["type"], "fact");
        assert_eq!(value["distance"], 0.25f32);
        assert_eq!(value["source"], "direct");
    }

    /// `results` holds the direct item followed by the graph item.
    #[test]
    fn recall_response_results_contem_direct_e_graph() {
        let direct_item = make_item("d-mem", 0.10, "direct");
        let graph_item = make_item("g-mem", 0.0, "graph");

        let response = RecallResponse {
            query: String::from("query"),
            k: 10,
            direct_matches: vec![direct_item.clone()],
            graph_matches: vec![graph_item.clone()],
            results: vec![direct_item, graph_item],
            elapsed_ms: 10,
        };

        let value = serde_json::to_value(&response).expect("serialização falhou");
        let expected_lengths = [("direct_matches", 1), ("graph_matches", 1), ("results", 2)];
        for (key, expected) in expected_lengths {
            assert_eq!(value[key].as_array().map(Vec::len), Some(expected));
        }
        assert_eq!(value["results"][0]["source"], "direct");
        assert_eq!(value["results"][1]["source"], "graph");
    }

    /// Empty lists serialize as empty arrays rather than being omitted.
    #[test]
    fn recall_response_vazio_serializa_arrays_vazios() {
        let response = RecallResponse {
            query: String::from("nada"),
            k: 3,
            direct_matches: Vec::new(),
            graph_matches: Vec::new(),
            results: Vec::new(),
            elapsed_ms: 1,
        };

        let value = serde_json::to_value(&response).expect("serialização falhou");
        for key in ["direct_matches", "results"] {
            assert_eq!(value[key].as_array().map(Vec::len), Some(0));
        }
    }
}