1use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::graph::traverse_from_memories_with_hops;
6use crate::i18n::errors_msg;
7use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
8use crate::paths::AppPaths;
9use crate::storage::connection::open_ro;
10use crate::storage::entities;
11use crate::storage::memories;
12
13#[derive(clap::Args)]
20pub struct RecallArgs {
21 pub query: String,
22 #[arg(short = 'k', long, default_value = "10")]
28 pub k: usize,
29 #[arg(long, value_enum)]
33 pub r#type: Option<MemoryType>,
34 #[arg(long)]
35 pub namespace: Option<String>,
36 #[arg(long)]
37 pub no_graph: bool,
38 #[arg(long)]
44 pub precise: bool,
45 #[arg(long, default_value = "2")]
46 pub max_hops: u32,
47 #[arg(long, default_value = "0.3")]
48 pub min_weight: f64,
49 #[arg(long, value_name = "N")]
55 pub max_graph_results: Option<usize>,
56 #[arg(long, alias = "min-distance", default_value = "1.0")]
61 pub max_distance: f32,
62 #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
63 pub format: JsonOutputFormat,
64 #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
65 pub db: Option<String>,
66 #[arg(long, help = "No-op; JSON is always emitted on stdout")]
68 pub json: bool,
69 #[arg(long, conflicts_with = "namespace")]
74 pub all_namespaces: bool,
75}
76
77pub fn run(args: RecallArgs) -> Result<(), AppError> {
78 let start = std::time::Instant::now();
79 let _ = args.format;
80 if args.query.trim().is_empty() {
81 return Err(AppError::Validation(
82 "query não pode estar vazia".to_string(),
83 ));
84 }
85 let namespaces: Vec<String> = if args.all_namespaces {
89 Vec::new()
90 } else {
91 vec![crate::namespace::resolve_namespace(
92 args.namespace.as_deref(),
93 )?]
94 };
95 let namespace_for_graph = namespaces
97 .first()
98 .cloned()
99 .unwrap_or_else(|| "global".to_string());
100 let paths = AppPaths::resolve(args.db.as_deref())?;
101
102 if !paths.db.exists() {
103 return Err(AppError::NotFound(errors_msg::database_not_found(
104 &paths.db.display().to_string(),
105 )));
106 }
107
108 output::emit_progress_i18n(
109 "Computing query embedding...",
110 "Calculando embedding da consulta...",
111 );
112 let embedding = crate::daemon::embed_query_or_local(&paths.models, &args.query)?;
113
114 let conn = open_ro(&paths.db)?;
115
116 let memory_type_str = args.r#type.map(|t| t.as_str());
117 let effective_k = if args.precise { 100_000 } else { args.k };
120 let knn_results =
121 memories::knn_search(&conn, &embedding, &namespaces, memory_type_str, effective_k)?;
122
123 let mut direct_matches = Vec::new();
124 let mut memory_ids: Vec<i64> = Vec::new();
125 for (memory_id, distance) in knn_results {
126 let row = {
127 let mut stmt = conn.prepare_cached(
128 "SELECT id, namespace, name, type, description, body, body_hash,
129 session_id, source, metadata, created_at, updated_at
130 FROM memories WHERE id=?1 AND deleted_at IS NULL",
131 )?;
132 stmt.query_row(rusqlite::params![memory_id], |r| {
133 Ok(memories::MemoryRow {
134 id: r.get(0)?,
135 namespace: r.get(1)?,
136 name: r.get(2)?,
137 memory_type: r.get(3)?,
138 description: r.get(4)?,
139 body: r.get(5)?,
140 body_hash: r.get(6)?,
141 session_id: r.get(7)?,
142 source: r.get(8)?,
143 metadata: r.get(9)?,
144 created_at: r.get(10)?,
145 updated_at: r.get(11)?,
146 })
147 })
148 .ok()
149 };
150 if let Some(row) = row {
151 let snippet: String = row.body.chars().take(300).collect();
152 direct_matches.push(RecallItem {
153 memory_id: row.id,
154 name: row.name,
155 namespace: row.namespace,
156 memory_type: row.memory_type,
157 description: row.description,
158 snippet,
159 distance,
160 source: "direct".to_string(),
161 graph_depth: None,
163 });
164 memory_ids.push(memory_id);
165 }
166 }
167
168 let mut graph_matches = Vec::new();
169 if !args.no_graph {
170 let entity_knn = entities::knn_search(&conn, &embedding, &namespace_for_graph, 5)?;
171 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
172
173 let all_seed_ids: Vec<i64> = memory_ids
174 .iter()
175 .chain(entity_ids.iter())
176 .copied()
177 .collect();
178
179 if !all_seed_ids.is_empty() {
180 let graph_memory_ids = traverse_from_memories_with_hops(
181 &conn,
182 &all_seed_ids,
183 &namespace_for_graph,
184 args.min_weight,
185 args.max_hops,
186 )?;
187
188 for (graph_mem_id, hop) in graph_memory_ids {
189 if let Some(cap) = args.max_graph_results {
192 if graph_matches.len() >= cap {
193 break;
194 }
195 }
196 let row = {
197 let mut stmt = conn.prepare_cached(
198 "SELECT id, namespace, name, type, description, body, body_hash,
199 session_id, source, metadata, created_at, updated_at
200 FROM memories WHERE id=?1 AND deleted_at IS NULL",
201 )?;
202 stmt.query_row(rusqlite::params![graph_mem_id], |r| {
203 Ok(memories::MemoryRow {
204 id: r.get(0)?,
205 namespace: r.get(1)?,
206 name: r.get(2)?,
207 memory_type: r.get(3)?,
208 description: r.get(4)?,
209 body: r.get(5)?,
210 body_hash: r.get(6)?,
211 session_id: r.get(7)?,
212 source: r.get(8)?,
213 metadata: r.get(9)?,
214 created_at: r.get(10)?,
215 updated_at: r.get(11)?,
216 })
217 })
218 .ok()
219 };
220 if let Some(row) = row {
221 let snippet: String = row.body.chars().take(300).collect();
222 let graph_distance = 1.0 - 1.0 / (hop as f32 + 1.0);
228 graph_matches.push(RecallItem {
229 memory_id: row.id,
230 name: row.name,
231 namespace: row.namespace,
232 memory_type: row.memory_type,
233 description: row.description,
234 snippet,
235 distance: graph_distance,
236 source: "graph".to_string(),
237 graph_depth: Some(hop),
238 });
239 }
240 }
241 }
242 }
243
244 if args.max_distance < 1.0 {
246 let has_relevant = direct_matches
247 .iter()
248 .any(|item| item.distance <= args.max_distance);
249 if !has_relevant {
250 return Err(AppError::NotFound(errors_msg::no_recall_results(
251 args.max_distance,
252 &args.query,
253 &namespace_for_graph,
254 )));
255 }
256 }
257
258 let results: Vec<RecallItem> = direct_matches
259 .iter()
260 .cloned()
261 .chain(graph_matches.iter().cloned())
262 .collect();
263
264 output::emit_json(&RecallResponse {
265 query: args.query,
266 k: args.k,
267 direct_matches,
268 graph_matches,
269 results,
270 elapsed_ms: start.elapsed().as_millis() as u64,
271 })?;
272
273 Ok(())
274}
275
#[cfg(test)]
mod tests {
    use crate::output::{RecallItem, RecallResponse};

    /// Builds a `RecallItem` fixture; only graph-sourced items carry a depth (0).
    fn make_item(name: &str, distance: f32, source: &str) -> RecallItem {
        let graph_depth = match source {
            "graph" => Some(0),
            _ => None,
        };
        RecallItem {
            memory_id: 1,
            name: name.into(),
            namespace: String::from("global"),
            memory_type: String::from("fact"),
            description: String::from("desc"),
            snippet: String::from("snippet"),
            distance,
            source: source.to_owned(),
            graph_depth,
        }
    }

    /// Serializes a response to a JSON value, panicking if serialization fails.
    fn response_json(resp: &RecallResponse) -> serde_json::Value {
        serde_json::to_value(resp).expect("serialização falhou")
    }

    #[test]
    fn recall_response_serializa_campos_obrigatorios() {
        let only = || make_item("mem-a", 0.12, "direct");
        let json = response_json(&RecallResponse {
            query: String::from("rust memory"),
            k: 5,
            direct_matches: vec![only()],
            graph_matches: Vec::new(),
            results: vec![only()],
            elapsed_ms: 42,
        });

        assert_eq!(json["query"], "rust memory");
        assert_eq!(json["k"], 5);
        assert_eq!(json["elapsed_ms"], 42u64);
        for key in ["direct_matches", "graph_matches", "results"] {
            assert!(json[key].is_array(), "{key} should be an array");
        }
    }

    #[test]
    fn recall_item_serializa_type_renomeado() {
        let json = serde_json::to_value(make_item("mem-teste", 0.25, "direct"))
            .expect("serialização falhou");

        // The `memory_type` field is expected to serialize under the key "type".
        assert_eq!(json["type"], "fact");
        assert_eq!(json["distance"], 0.25f32);
        assert_eq!(json["source"], "direct");
    }

    #[test]
    fn recall_response_results_contem_direct_e_graph() {
        let direct = make_item("d-mem", 0.10, "direct");
        let graph = make_item("g-mem", 0.0, "graph");

        let json = response_json(&RecallResponse {
            query: "query".into(),
            k: 10,
            direct_matches: vec![direct.clone()],
            graph_matches: vec![graph.clone()],
            results: vec![direct, graph],
            elapsed_ms: 10,
        });

        let len_of = |key: &str| json[key].as_array().map(Vec::len);
        assert_eq!(len_of("direct_matches"), Some(1));
        assert_eq!(len_of("graph_matches"), Some(1));
        assert_eq!(len_of("results"), Some(2));

        let results = &json["results"];
        assert_eq!(results[0]["source"], "direct");
        assert_eq!(results[1]["source"], "graph");
    }

    #[test]
    fn recall_response_vazio_serializa_arrays_vazios() {
        let json = response_json(&RecallResponse {
            query: "nada".to_owned(),
            k: 3,
            direct_matches: Vec::new(),
            graph_matches: Vec::new(),
            results: Vec::new(),
            elapsed_ms: 1,
        });

        for key in ["direct_matches", "results"] {
            assert_eq!(json[key].as_array().map(Vec::len), Some(0));
        }
    }

    #[test]
    fn graph_matches_distance_uses_hop_count_proxy() {
        // (hop, expected distance) pairs for the proxy 1 - 1 / (hop + 1).
        let cases = [(0_u32, 0.0_f32), (1, 0.5), (2, 0.6667), (3, 0.75)];
        for (hop, expected) in cases {
            let got = 1.0_f32 - 1.0 / (hop as f32 + 1.0);
            assert!(
                (got - expected).abs() < 0.001,
                "hop={hop} expected={expected} got={got}"
            );
        }
    }
}