1use crate::cli::MemoryType;
2use crate::errors::AppError;
3use crate::graph::traverse_from_memories;
4use crate::i18n::erros;
5use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
6use crate::paths::AppPaths;
7use crate::storage::connection::open_ro;
8use crate::storage::entities;
9use crate::storage::memories;
10
11#[derive(clap::Args)]
12pub struct RecallArgs {
13 pub query: String,
14 #[arg(short = 'k', long, default_value = "10")]
20 pub k: usize,
21 #[arg(long, value_enum)]
22 pub r#type: Option<MemoryType>,
23 #[arg(long)]
24 pub namespace: Option<String>,
25 #[arg(long)]
26 pub no_graph: bool,
27 #[arg(long)]
28 pub precise: bool,
29 #[arg(long, default_value = "2")]
30 pub max_hops: u32,
31 #[arg(long, default_value = "0.3")]
32 pub min_weight: f64,
33 #[arg(long, value_name = "N")]
39 pub max_graph_results: Option<usize>,
40 #[arg(long, alias = "min-distance", default_value = "1.0")]
45 pub max_distance: f32,
46 #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
47 pub format: JsonOutputFormat,
48 #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
49 pub db: Option<String>,
50 #[arg(long, help = "No-op; JSON is always emitted on stdout")]
52 pub json: bool,
53}
54
55pub fn run(args: RecallArgs) -> Result<(), AppError> {
56 let start = std::time::Instant::now();
57 let _ = args.format;
58 if args.query.trim().is_empty() {
59 return Err(AppError::Validation(
60 "query não pode estar vazia".to_string(),
61 ));
62 }
63 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
64 let paths = AppPaths::resolve(args.db.as_deref())?;
65
66 if !paths.db.exists() {
67 return Err(AppError::NotFound(erros::banco_nao_encontrado(
68 &paths.db.display().to_string(),
69 )));
70 }
71
72 output::emit_progress_i18n(
73 "Computing query embedding...",
74 "Calculando embedding da consulta...",
75 );
76 let embedding = crate::daemon::embed_query_or_local(&paths.models, &args.query)?;
77
78 let conn = open_ro(&paths.db)?;
79
80 let memory_type_str = args.r#type.map(|t| t.as_str());
81 let knn_results = memories::knn_search(&conn, &embedding, &namespace, memory_type_str, args.k)?;
82
83 let mut direct_matches = Vec::new();
84 let mut memory_ids: Vec<i64> = Vec::new();
85 for (memory_id, distance) in knn_results {
86 let row = {
87 let mut stmt = conn.prepare_cached(
88 "SELECT id, namespace, name, type, description, body, body_hash,
89 session_id, source, metadata, created_at, updated_at
90 FROM memories WHERE id=?1 AND deleted_at IS NULL",
91 )?;
92 stmt.query_row(rusqlite::params![memory_id], |r| {
93 Ok(memories::MemoryRow {
94 id: r.get(0)?,
95 namespace: r.get(1)?,
96 name: r.get(2)?,
97 memory_type: r.get(3)?,
98 description: r.get(4)?,
99 body: r.get(5)?,
100 body_hash: r.get(6)?,
101 session_id: r.get(7)?,
102 source: r.get(8)?,
103 metadata: r.get(9)?,
104 created_at: r.get(10)?,
105 updated_at: r.get(11)?,
106 })
107 })
108 .ok()
109 };
110 if let Some(row) = row {
111 let snippet: String = row.body.chars().take(300).collect();
112 direct_matches.push(RecallItem {
113 memory_id: row.id,
114 name: row.name,
115 namespace: row.namespace,
116 memory_type: row.memory_type,
117 description: row.description,
118 snippet,
119 distance,
120 source: "direct".to_string(),
121 graph_depth: None,
123 });
124 memory_ids.push(memory_id);
125 }
126 }
127
128 let mut graph_matches = Vec::new();
129 if !args.no_graph {
130 let entity_knn = entities::knn_search(&conn, &embedding, &namespace, 5)?;
131 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
132
133 let all_seed_ids: Vec<i64> = memory_ids
134 .iter()
135 .chain(entity_ids.iter())
136 .copied()
137 .collect();
138
139 if !all_seed_ids.is_empty() {
140 let graph_memory_ids = traverse_from_memories(
141 &conn,
142 &all_seed_ids,
143 &namespace,
144 args.min_weight,
145 args.max_hops,
146 )?;
147
148 for graph_mem_id in graph_memory_ids {
149 if let Some(cap) = args.max_graph_results {
152 if graph_matches.len() >= cap {
153 break;
154 }
155 }
156 let row = {
157 let mut stmt = conn.prepare_cached(
158 "SELECT id, namespace, name, type, description, body, body_hash,
159 session_id, source, metadata, created_at, updated_at
160 FROM memories WHERE id=?1 AND deleted_at IS NULL",
161 )?;
162 stmt.query_row(rusqlite::params![graph_mem_id], |r| {
163 Ok(memories::MemoryRow {
164 id: r.get(0)?,
165 namespace: r.get(1)?,
166 name: r.get(2)?,
167 memory_type: r.get(3)?,
168 description: r.get(4)?,
169 body: r.get(5)?,
170 body_hash: r.get(6)?,
171 session_id: r.get(7)?,
172 source: r.get(8)?,
173 metadata: r.get(9)?,
174 created_at: r.get(10)?,
175 updated_at: r.get(11)?,
176 })
177 })
178 .ok()
179 };
180 if let Some(row) = row {
181 let snippet: String = row.body.chars().take(300).collect();
182 graph_matches.push(RecallItem {
183 memory_id: row.id,
184 name: row.name,
185 namespace: row.namespace,
186 memory_type: row.memory_type,
187 description: row.description,
188 snippet,
189 distance: 0.0,
193 source: "graph".to_string(),
194 graph_depth: Some(0),
199 });
200 }
201 }
202 }
203 }
204
205 if args.max_distance < 1.0 {
207 let has_relevant = direct_matches
208 .iter()
209 .any(|item| item.distance <= args.max_distance);
210 if !has_relevant {
211 return Err(AppError::NotFound(erros::sem_resultados_recall(
212 args.max_distance,
213 &args.query,
214 &namespace,
215 )));
216 }
217 }
218
219 let results: Vec<RecallItem> = direct_matches
220 .iter()
221 .cloned()
222 .chain(graph_matches.iter().cloned())
223 .collect();
224
225 output::emit_json(&RecallResponse {
226 query: args.query,
227 k: args.k,
228 direct_matches,
229 graph_matches,
230 results,
231 elapsed_ms: start.elapsed().as_millis() as u64,
232 })?;
233
234 Ok(())
235}
236
#[cfg(test)]
mod testes {
    use crate::output::{RecallItem, RecallResponse};

    /// Builds a fixture item; only `name`, `distance` and `source` vary.
    /// `graph_depth` is `Some(0)` for the "graph" source and `None` otherwise.
    fn make_item(name: &str, distance: f32, source: &str) -> RecallItem {
        let graph_depth = match source {
            "graph" => Some(0),
            _ => None,
        };
        RecallItem {
            memory_id: 1,
            name: name.to_string(),
            namespace: "global".to_string(),
            memory_type: "fact".to_string(),
            description: "desc".to_string(),
            snippet: "snippet".to_string(),
            distance,
            source: source.to_string(),
            graph_depth,
        }
    }

    /// Serializes a response into a JSON value, panicking on failure.
    fn response_json(resp: &RecallResponse) -> serde_json::Value {
        serde_json::to_value(resp).expect("serialização falhou")
    }

    /// Length of the JSON array stored under `key`; panics if it is not an array.
    fn array_len(json: &serde_json::Value, key: &str) -> usize {
        json[key].as_array().unwrap().len()
    }

    /// The response exposes its scalar fields and all three result arrays.
    #[test]
    fn recall_response_serializa_campos_obrigatorios() {
        let json = response_json(&RecallResponse {
            query: "rust memory".to_string(),
            k: 5,
            direct_matches: vec![make_item("mem-a", 0.12, "direct")],
            graph_matches: Vec::new(),
            results: vec![make_item("mem-a", 0.12, "direct")],
            elapsed_ms: 42,
        });

        assert_eq!(json["query"], "rust memory");
        assert_eq!(json["k"], 5);
        assert_eq!(json["elapsed_ms"], 42u64);
        for key in ["direct_matches", "graph_matches", "results"] {
            assert!(json[key].is_array());
        }
    }

    /// An item serializes its memory type under the `type` key.
    #[test]
    fn recall_item_serializa_type_renomeado() {
        let json = serde_json::to_value(&make_item("mem-teste", 0.25, "direct"))
            .expect("serialização falhou");

        assert_eq!(json["type"], "fact");
        assert_eq!(json["distance"], 0.25f32);
        assert_eq!(json["source"], "direct");
    }

    /// `results` keeps direct matches ahead of graph matches.
    #[test]
    fn recall_response_results_contem_direct_e_graph() {
        let direct = make_item("d-mem", 0.10, "direct");
        let graph = make_item("g-mem", 0.0, "graph");

        let json = response_json(&RecallResponse {
            query: "query".to_string(),
            k: 10,
            direct_matches: vec![direct.clone()],
            graph_matches: vec![graph.clone()],
            results: vec![direct, graph],
            elapsed_ms: 10,
        });

        assert_eq!(array_len(&json, "direct_matches"), 1);
        assert_eq!(array_len(&json, "graph_matches"), 1);
        assert_eq!(array_len(&json, "results"), 2);
        assert_eq!(json["results"][0]["source"], "direct");
        assert_eq!(json["results"][1]["source"], "graph");
    }

    /// An empty response still serializes its lists as (empty) arrays.
    #[test]
    fn recall_response_vazio_serializa_arrays_vazios() {
        let json = response_json(&RecallResponse {
            query: "nada".to_string(),
            k: 3,
            direct_matches: Vec::new(),
            graph_matches: Vec::new(),
            results: Vec::new(),
            elapsed_ms: 1,
        });

        assert!(json["direct_matches"].as_array().unwrap().is_empty());
        assert!(json["results"].as_array().unwrap().is_empty());
    }
}