1use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::graph::traverse_from_memories_with_hops;
6use crate::i18n::errors_msg;
7use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
8use crate::paths::AppPaths;
9use crate::storage::connection::open_ro;
10use crate::storage::entities;
11use crate::storage::memories;
12
// CLI arguments for the `recall` command, consumed by `run` below.
// NOTE: plain `//` comments are used deliberately; `///` doc comments would become
// clap help text and change `--help` output.
#[derive(clap::Args)]
pub struct RecallArgs {
    // Free-text query; it is embedded and used for vector KNN search. `run` rejects
    // queries that are empty after trimming whitespace.
    #[arg(help = "Search query string (semantic vector search via sqlite-vec)")]
    pub query: String,
    // Number of nearest direct matches to request (also accepted as `--limit`).
    // Bounds are enforced by `crate::parsers::parse_k_range`.
    #[arg(short = 'k', long, alias = "limit", default_value = "10", value_parser = crate::parsers::parse_k_range)]
    pub k: usize,
    // Optional memory-type filter, forwarded to the KNN search as a string.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // Namespace to search; resolved through `crate::namespace::resolve_namespace`
    // when `--all-namespaces` is not set.
    #[arg(long)]
    pub namespace: Option<String>,
    // Skip graph expansion entirely; only direct KNN matches are returned.
    #[arg(long)]
    pub no_graph: bool,
    // When set, `run` requests 100_000 KNN candidates instead of `k`.
    #[arg(long)]
    pub precise: bool,
    // Maximum traversal depth passed to the graph expansion (default 2).
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    // Minimum weight passed to the graph traversal (default 0.3); presumably an
    // edge-weight threshold — confirm in `traverse_from_memories_with_hops`.
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    // Optional cap on the number of graph matches emitted; `None` means no cap.
    #[arg(long, value_name = "N")]
    pub max_graph_results: Option<usize>,
    // Relevance threshold (also accepted as `--min-distance`). When set below 1.0, `run`
    // returns NotFound unless at least one direct match has distance <= this value;
    // matches above the threshold are NOT filtered out of the output.
    // NOTE(review): the `min-distance` alias reads as the opposite bound — presumably
    // kept for backward compatibility.
    #[arg(long, alias = "min-distance", default_value = "1.0")]
    pub max_distance: f32,
    // Accepted for CLI compatibility but ignored by `run`; output is always JSON.
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // Database path override; also read from `SQLITE_GRAPHRAG_DB_PATH`.
    // Resolved by `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // Hidden no-op flag kept so scripts passing `--json` keep working.
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    // Search direct matches across every namespace (cannot be combined with
    // `--namespace`). Graph expansion in `run` still uses the "global" namespace.
    #[arg(long, conflicts_with = "namespace")]
    pub all_namespaces: bool,
}
79
80pub fn run(args: RecallArgs) -> Result<(), AppError> {
81 let start = std::time::Instant::now();
82 let _ = args.format;
83 if args.query.trim().is_empty() {
84 return Err(AppError::Validation(crate::i18n::validation::empty_query()));
85 }
86 let namespaces: Vec<String> = if args.all_namespaces {
90 Vec::new()
91 } else {
92 vec![crate::namespace::resolve_namespace(
93 args.namespace.as_deref(),
94 )?]
95 };
96 let namespace_for_graph = namespaces
98 .first()
99 .cloned()
100 .unwrap_or_else(|| "global".to_string());
101 let paths = AppPaths::resolve(args.db.as_deref())?;
102
103 crate::storage::connection::ensure_db_ready(&paths)?;
104
105 output::emit_progress_i18n(
106 "Computing query embedding...",
107 "Calculando embedding da consulta...",
108 );
109 let embedding = crate::daemon::embed_query_or_local(&paths.models, &args.query)?;
110
111 let conn = open_ro(&paths.db)?;
112
113 let memory_type_str = args.r#type.map(|t| t.as_str());
114 let effective_k = if args.precise { 100_000 } else { args.k };
117 let knn_results =
118 memories::knn_search(&conn, &embedding, &namespaces, memory_type_str, effective_k)?;
119
120 let mut direct_matches = Vec::new();
121 let mut memory_ids: Vec<i64> = Vec::new();
122 for (memory_id, distance) in knn_results {
123 let row = {
124 let mut stmt = conn.prepare_cached(
125 "SELECT id, namespace, name, type, description, body, body_hash,
126 session_id, source, metadata, created_at, updated_at
127 FROM memories WHERE id=?1 AND deleted_at IS NULL",
128 )?;
129 stmt.query_row(rusqlite::params![memory_id], |r| {
130 Ok(memories::MemoryRow {
131 id: r.get(0)?,
132 namespace: r.get(1)?,
133 name: r.get(2)?,
134 memory_type: r.get(3)?,
135 description: r.get(4)?,
136 body: r.get(5)?,
137 body_hash: r.get(6)?,
138 session_id: r.get(7)?,
139 source: r.get(8)?,
140 metadata: r.get(9)?,
141 created_at: r.get(10)?,
142 updated_at: r.get(11)?,
143 })
144 })
145 .ok()
146 };
147 if let Some(row) = row {
148 let snippet: String = row.body.chars().take(300).collect();
149 direct_matches.push(RecallItem {
150 memory_id: row.id,
151 name: row.name,
152 namespace: row.namespace,
153 memory_type: row.memory_type,
154 description: row.description,
155 snippet,
156 distance,
157 source: "direct".to_string(),
158 graph_depth: None,
160 });
161 memory_ids.push(memory_id);
162 }
163 }
164
165 let mut graph_matches = Vec::new();
166 if !args.no_graph {
167 let entity_knn = entities::knn_search(&conn, &embedding, &namespace_for_graph, 5)?;
168 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
169
170 let all_seed_ids: Vec<i64> = memory_ids
171 .iter()
172 .chain(entity_ids.iter())
173 .copied()
174 .collect();
175
176 if !all_seed_ids.is_empty() {
177 let graph_memory_ids = traverse_from_memories_with_hops(
178 &conn,
179 &all_seed_ids,
180 &namespace_for_graph,
181 args.min_weight,
182 args.max_hops,
183 )?;
184
185 for (graph_mem_id, hop) in graph_memory_ids {
186 if let Some(cap) = args.max_graph_results {
189 if graph_matches.len() >= cap {
190 break;
191 }
192 }
193 let row = {
194 let mut stmt = conn.prepare_cached(
195 "SELECT id, namespace, name, type, description, body, body_hash,
196 session_id, source, metadata, created_at, updated_at
197 FROM memories WHERE id=?1 AND deleted_at IS NULL",
198 )?;
199 stmt.query_row(rusqlite::params![graph_mem_id], |r| {
200 Ok(memories::MemoryRow {
201 id: r.get(0)?,
202 namespace: r.get(1)?,
203 name: r.get(2)?,
204 memory_type: r.get(3)?,
205 description: r.get(4)?,
206 body: r.get(5)?,
207 body_hash: r.get(6)?,
208 session_id: r.get(7)?,
209 source: r.get(8)?,
210 metadata: r.get(9)?,
211 created_at: r.get(10)?,
212 updated_at: r.get(11)?,
213 })
214 })
215 .ok()
216 };
217 if let Some(row) = row {
218 let snippet: String = row.body.chars().take(300).collect();
219 let graph_distance = 1.0 - 1.0 / (hop as f32 + 1.0);
225 graph_matches.push(RecallItem {
226 memory_id: row.id,
227 name: row.name,
228 namespace: row.namespace,
229 memory_type: row.memory_type,
230 description: row.description,
231 snippet,
232 distance: graph_distance,
233 source: "graph".to_string(),
234 graph_depth: Some(hop),
235 });
236 }
237 }
238 }
239 }
240
241 if args.max_distance < 1.0 {
243 let has_relevant = direct_matches
244 .iter()
245 .any(|item| item.distance <= args.max_distance);
246 if !has_relevant {
247 return Err(AppError::NotFound(errors_msg::no_recall_results(
248 args.max_distance,
249 &args.query,
250 &namespace_for_graph,
251 )));
252 }
253 }
254
255 let results: Vec<RecallItem> = direct_matches
256 .iter()
257 .cloned()
258 .chain(graph_matches.iter().cloned())
259 .collect();
260
261 output::emit_json(&RecallResponse {
262 query: args.query,
263 k: args.k,
264 direct_matches,
265 graph_matches,
266 results,
267 elapsed_ms: start.elapsed().as_millis() as u64,
268 })?;
269
270 Ok(())
271}
272
#[cfg(test)]
mod tests {
    use crate::output::{RecallItem, RecallResponse};

    /// Fixture for a single recall hit; graph-sourced items get depth 0, others none.
    fn fixture(name: &str, distance: f32, source: &str) -> RecallItem {
        let graph_depth = if source == "graph" { Some(0) } else { None };
        RecallItem {
            memory_id: 1,
            name: name.to_string(),
            namespace: "global".to_string(),
            memory_type: "fact".to_string(),
            description: "desc".to_string(),
            snippet: "snippet".to_string(),
            distance,
            source: source.to_string(),
            graph_depth,
        }
    }

    /// Fixture for a whole response; `results` is direct matches followed by graph matches.
    fn response(
        query: &str,
        k: usize,
        direct: Vec<RecallItem>,
        graph: Vec<RecallItem>,
        elapsed_ms: u64,
    ) -> RecallResponse {
        let results: Vec<RecallItem> = direct.iter().chain(graph.iter()).cloned().collect();
        RecallResponse {
            query: query.to_string(),
            k,
            direct_matches: direct,
            graph_matches: graph,
            results,
            elapsed_ms,
        }
    }

    #[test]
    fn recall_response_serializes_required_fields() {
        let direct = vec![fixture("mem-a", 0.12, "direct")];
        let json = serde_json::to_value(&response("rust memory", 5, direct, Vec::new(), 42))
            .expect("serialization failed");

        assert_eq!(json["query"], "rust memory");
        assert_eq!(json["k"], 5);
        assert_eq!(json["elapsed_ms"], 42u64);
        for key in ["direct_matches", "graph_matches", "results"] {
            assert!(json[key].is_array(), "{key} should be an array");
        }
    }

    #[test]
    fn recall_item_serializes_renamed_type() {
        let json = serde_json::to_value(&fixture("mem-test", 0.25, "direct"))
            .expect("serialization failed");

        assert_eq!(json["type"], "fact");
        assert_eq!(json["distance"], 0.25f32);
        assert_eq!(json["source"], "direct");
    }

    #[test]
    fn recall_response_results_contains_direct_and_graph() {
        let resp = response(
            "query",
            10,
            vec![fixture("d-mem", 0.10, "direct")],
            vec![fixture("g-mem", 0.0, "graph")],
            10,
        );
        let json = serde_json::to_value(&resp).expect("serialization failed");
        let len_of = |key: &str| json[key].as_array().map(Vec::len);

        assert_eq!(len_of("direct_matches"), Some(1));
        assert_eq!(len_of("graph_matches"), Some(1));
        assert_eq!(len_of("results"), Some(2));
        assert_eq!(json["results"][0]["source"], "direct");
        assert_eq!(json["results"][1]["source"], "graph");
    }

    #[test]
    fn recall_response_empty_serializes_empty_arrays() {
        let json = serde_json::to_value(&response("nothing", 3, Vec::new(), Vec::new(), 1))
            .expect("serialization failed");

        for key in ["direct_matches", "results"] {
            assert_eq!(json[key].as_array().map(Vec::len), Some(0), "{key} should be empty");
        }
    }

    #[test]
    fn graph_matches_distance_uses_hop_count_proxy() {
        let expectations: [(u32, f32); 4] = [(0, 0.0), (1, 0.5), (2, 0.6667), (3, 0.75)];
        for (hop, expected) in expectations {
            let proxy = 1.0_f32 - 1.0 / (hop as f32 + 1.0);
            assert!(
                (proxy - expected).abs() < 0.001,
                "hop={hop} expected={expected} got={proxy}"
            );
        }
    }
}