1use crate::cli::MemoryType;
2use crate::errors::AppError;
3use crate::graph::traverse_from_memories;
4use crate::i18n::erros;
5use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
6use crate::paths::AppPaths;
7use crate::storage::connection::open_ro;
8use crate::storage::entities;
9use crate::storage::memories;
10
// Command-line arguments for the `recall` command, consumed by `run` below.
//
// NOTE: plain `//` comments are used deliberately. clap's derive turns `///`
// doc comments into help text, so adding them would change `--help` output.
#[derive(clap::Args)]
pub struct RecallArgs {
    // Free-text query. It is embedded for the vector searches and echoed back
    // in the JSON response.
    pub query: String,
    // Number of neighbours requested from the memory KNN search.
    #[arg(short = 'k', long, default_value = "10")]
    pub k: usize,
    // Optional memory type, passed as a string to the memory KNN search
    // (presumably as a filter; confirm in `memories::knn_search`).
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // Optional namespace. It is always passed through
    // `crate::namespace::resolve_namespace`, which picks the effective namespace.
    #[arg(long)]
    pub namespace: Option<String>,
    // Skip the entity search and graph-traversal expansion entirely.
    #[arg(long)]
    pub no_graph: bool,
    // NOTE(review): not read by `run` in this file; confirm whether it is used
    // elsewhere or is a placeholder.
    #[arg(long)]
    pub precise: bool,
    // Forwarded to `traverse_from_memories` (presumably the maximum traversal
    // depth; confirm).
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    // Forwarded to `traverse_from_memories` (presumably the minimum edge weight
    // to follow; confirm).
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    // Relevance gate. When below 1.0, `run` fails with NotFound unless at least
    // one direct match has distance <= this value. Matches are not filtered by it.
    // `--min-distance` is accepted as an alias (presumably for backward
    // compatibility).
    #[arg(long, alias = "min-distance", default_value = "1.0")]
    pub max_distance: f32,
    // Accepted but ignored by `run`: the response is always emitted as JSON.
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // Optional database path override (also read from SQLITE_GRAPHRAG_DB_PATH),
    // handed to `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // Compatibility flag that `run` never reads.
    #[arg(long, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
}
42
43pub fn run(args: RecallArgs) -> Result<(), AppError> {
44 let start = std::time::Instant::now();
45 let _ = args.format;
46 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
47 let paths = AppPaths::resolve(args.db.as_deref())?;
48
49 if !paths.db.exists() {
50 return Err(AppError::NotFound(erros::banco_nao_encontrado(
51 &paths.db.display().to_string(),
52 )));
53 }
54
55 output::emit_progress_i18n(
56 "Computing query embedding...",
57 "Calculando embedding da consulta...",
58 );
59 let embedding = crate::daemon::embed_query_or_local(&paths.models, &args.query)?;
60
61 let conn = open_ro(&paths.db)?;
62
63 let memory_type_str = args.r#type.map(|t| t.as_str());
64 let knn_results = memories::knn_search(&conn, &embedding, &namespace, memory_type_str, args.k)?;
65
66 let mut direct_matches = Vec::new();
67 let mut memory_ids: Vec<i64> = Vec::new();
68 for (memory_id, distance) in knn_results {
69 let row = {
70 let mut stmt = conn.prepare_cached(
71 "SELECT id, namespace, name, type, description, body, body_hash,
72 session_id, source, metadata, created_at, updated_at
73 FROM memories WHERE id=?1 AND deleted_at IS NULL",
74 )?;
75 stmt.query_row(rusqlite::params![memory_id], |r| {
76 Ok(memories::MemoryRow {
77 id: r.get(0)?,
78 namespace: r.get(1)?,
79 name: r.get(2)?,
80 memory_type: r.get(3)?,
81 description: r.get(4)?,
82 body: r.get(5)?,
83 body_hash: r.get(6)?,
84 session_id: r.get(7)?,
85 source: r.get(8)?,
86 metadata: r.get(9)?,
87 created_at: r.get(10)?,
88 updated_at: r.get(11)?,
89 })
90 })
91 .ok()
92 };
93 if let Some(row) = row {
94 let snippet: String = row.body.chars().take(300).collect();
95 direct_matches.push(RecallItem {
96 memory_id: row.id,
97 name: row.name,
98 namespace: row.namespace,
99 memory_type: row.memory_type,
100 description: row.description,
101 snippet,
102 distance,
103 source: "direct".to_string(),
104 });
105 memory_ids.push(memory_id);
106 }
107 }
108
109 let mut graph_matches = Vec::new();
110 if !args.no_graph {
111 let entity_knn = entities::knn_search(&conn, &embedding, &namespace, 5)?;
112 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
113
114 let all_seed_ids: Vec<i64> = memory_ids
115 .iter()
116 .chain(entity_ids.iter())
117 .copied()
118 .collect();
119
120 if !all_seed_ids.is_empty() {
121 let graph_memory_ids = traverse_from_memories(
122 &conn,
123 &all_seed_ids,
124 &namespace,
125 args.min_weight,
126 args.max_hops,
127 )?;
128
129 for graph_mem_id in graph_memory_ids {
130 let row = {
131 let mut stmt = conn.prepare_cached(
132 "SELECT id, namespace, name, type, description, body, body_hash,
133 session_id, source, metadata, created_at, updated_at
134 FROM memories WHERE id=?1 AND deleted_at IS NULL",
135 )?;
136 stmt.query_row(rusqlite::params![graph_mem_id], |r| {
137 Ok(memories::MemoryRow {
138 id: r.get(0)?,
139 namespace: r.get(1)?,
140 name: r.get(2)?,
141 memory_type: r.get(3)?,
142 description: r.get(4)?,
143 body: r.get(5)?,
144 body_hash: r.get(6)?,
145 session_id: r.get(7)?,
146 source: r.get(8)?,
147 metadata: r.get(9)?,
148 created_at: r.get(10)?,
149 updated_at: r.get(11)?,
150 })
151 })
152 .ok()
153 };
154 if let Some(row) = row {
155 let snippet: String = row.body.chars().take(300).collect();
156 graph_matches.push(RecallItem {
157 memory_id: row.id,
158 name: row.name,
159 namespace: row.namespace,
160 memory_type: row.memory_type,
161 description: row.description,
162 snippet,
163 distance: 0.0,
164 source: "graph".to_string(),
165 });
166 }
167 }
168 }
169 }
170
171 if args.max_distance < 1.0 {
173 let has_relevant = direct_matches
174 .iter()
175 .any(|item| item.distance <= args.max_distance);
176 if !has_relevant {
177 return Err(AppError::NotFound(erros::sem_resultados_recall(
178 args.max_distance,
179 &args.query,
180 &namespace,
181 )));
182 }
183 }
184
185 let results: Vec<RecallItem> = direct_matches
186 .iter()
187 .cloned()
188 .chain(graph_matches.iter().cloned())
189 .collect();
190
191 output::emit_json(&RecallResponse {
192 query: args.query,
193 k: args.k,
194 direct_matches,
195 graph_matches,
196 results,
197 elapsed_ms: start.elapsed().as_millis() as u64,
198 })?;
199
200 Ok(())
201}
202
#[cfg(test)]
mod testes {
    use crate::output::{RecallItem, RecallResponse};

    /// Builds a `RecallItem` fixture.
    ///
    /// The item has fixed id, namespace, type, description and snippet; the
    /// caller supplies the name, distance and origin tag.
    fn make_item(name: &str, distance: f32, source: &str) -> RecallItem {
        RecallItem {
            memory_id: 1,
            name: String::from(name),
            namespace: String::from("global"),
            memory_type: String::from("fact"),
            description: String::from("desc"),
            snippet: String::from("snippet"),
            distance,
            source: String::from(source),
        }
    }

    /// Serializes a response to a JSON value, panicking if serialization fails.
    fn response_json(resp: &RecallResponse) -> serde_json::Value {
        serde_json::to_value(resp).expect("serialização falhou")
    }

    /// Length of the JSON array stored under `key`; panics if it is not an array.
    fn array_len(json: &serde_json::Value, key: &str) -> usize {
        json[key].as_array().unwrap().len()
    }

    #[test]
    fn recall_response_serializa_campos_obrigatorios() {
        let json = response_json(&RecallResponse {
            query: String::from("rust memory"),
            k: 5,
            direct_matches: vec![make_item("mem-a", 0.12, "direct")],
            graph_matches: Vec::new(),
            results: vec![make_item("mem-a", 0.12, "direct")],
            elapsed_ms: 42,
        });

        assert_eq!(json["query"], "rust memory");
        assert_eq!(json["k"], 5);
        assert_eq!(json["elapsed_ms"], 42u64);
        for key in ["direct_matches", "graph_matches", "results"] {
            assert!(json[key].is_array());
        }
    }

    #[test]
    fn recall_item_serializa_type_renomeado() {
        let json = serde_json::to_value(make_item("mem-teste", 0.25, "direct"))
            .expect("serialização falhou");

        assert_eq!(json["type"], "fact");
        assert_eq!(json["distance"], 0.25f32);
        assert_eq!(json["source"], "direct");
    }

    #[test]
    fn recall_response_results_contem_direct_e_graph() {
        let direct = make_item("d-mem", 0.10, "direct");
        let graph = make_item("g-mem", 0.0, "graph");

        let json = response_json(&RecallResponse {
            query: String::from("query"),
            k: 10,
            direct_matches: vec![direct.clone()],
            graph_matches: vec![graph.clone()],
            results: vec![direct, graph],
            elapsed_ms: 10,
        });

        assert_eq!(array_len(&json, "direct_matches"), 1);
        assert_eq!(array_len(&json, "graph_matches"), 1);
        assert_eq!(array_len(&json, "results"), 2);

        let results = &json["results"];
        assert_eq!(results[0]["source"], "direct");
        assert_eq!(results[1]["source"], "graph");
    }

    #[test]
    fn recall_response_vazio_serializa_arrays_vazios() {
        let json = response_json(&RecallResponse {
            query: String::from("nada"),
            k: 3,
            direct_matches: Vec::new(),
            graph_matches: Vec::new(),
            results: Vec::new(),
            elapsed_ms: 1,
        });

        assert_eq!(array_len(&json, "direct_matches"), 0);
        assert_eq!(array_len(&json, "results"), 0);
    }
}