1use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::graph::traverse_from_memories_with_hops;
6use crate::i18n::errors_msg;
7use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
8use crate::paths::AppPaths;
9use crate::storage::connection::open_ro;
10use crate::storage::entities;
11use crate::storage::memories;
12
13#[derive(clap::Args)]
20pub struct RecallArgs {
21 #[arg(help = "Search query string (semantic vector search via sqlite-vec)")]
22 pub query: String,
23 #[arg(short = 'k', long, alias = "limit", default_value = "10", value_parser = crate::parsers::parse_k_range)]
31 pub k: usize,
32 #[arg(long, value_enum)]
36 pub r#type: Option<MemoryType>,
37 #[arg(long)]
38 pub namespace: Option<String>,
39 #[arg(long)]
40 pub no_graph: bool,
41 #[arg(long)]
47 pub precise: bool,
48 #[arg(long, default_value = "2")]
49 pub max_hops: u32,
50 #[arg(long, default_value = "0.3")]
51 pub min_weight: f64,
52 #[arg(long, value_name = "N")]
58 pub max_graph_results: Option<usize>,
59 #[arg(long, alias = "min-distance", default_value = "1.0")]
64 pub max_distance: f32,
65 #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
66 pub format: JsonOutputFormat,
67 #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
68 pub db: Option<String>,
69 #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
71 pub json: bool,
72 #[arg(long, conflicts_with = "namespace")]
77 pub all_namespaces: bool,
78 #[command(flatten)]
79 pub daemon: crate::cli::DaemonOpts,
80}
81
82pub fn run(args: RecallArgs) -> Result<(), AppError> {
83 let start = std::time::Instant::now();
84 let _ = args.format;
85 if args.query.trim().is_empty() {
86 return Err(AppError::Validation(crate::i18n::validation::empty_query()));
87 }
88 let namespaces: Vec<String> = if args.all_namespaces {
92 Vec::new()
93 } else {
94 vec![crate::namespace::resolve_namespace(
95 args.namespace.as_deref(),
96 )?]
97 };
98 let namespace_for_graph = namespaces
100 .first()
101 .cloned()
102 .unwrap_or_else(|| "global".to_string());
103 let paths = AppPaths::resolve(args.db.as_deref())?;
104
105 crate::storage::connection::ensure_db_ready(&paths)?;
106
107 output::emit_progress_i18n(
108 "Computing query embedding...",
109 "Calculando embedding da consulta...",
110 );
111 let embedding = crate::daemon::embed_query_or_local(
112 &paths.models,
113 &args.query,
114 args.daemon.autostart_daemon,
115 )?;
116
117 let conn = open_ro(&paths.db)?;
118
119 let memory_type_str = args.r#type.map(|t| t.as_str());
120 let effective_k = if args.precise { 100_000 } else { args.k };
123 let knn_results =
124 memories::knn_search(&conn, &embedding, &namespaces, memory_type_str, effective_k)?;
125
126 let mut direct_matches = Vec::new();
127 let mut memory_ids: Vec<i64> = Vec::new();
128 for (memory_id, distance) in knn_results {
129 let row = {
130 let mut stmt = conn.prepare_cached(
131 "SELECT id, namespace, name, type, description, body, body_hash,
132 session_id, source, metadata, created_at, updated_at
133 FROM memories WHERE id=?1 AND deleted_at IS NULL",
134 )?;
135 stmt.query_row(rusqlite::params![memory_id], |r| {
136 Ok(memories::MemoryRow {
137 id: r.get(0)?,
138 namespace: r.get(1)?,
139 name: r.get(2)?,
140 memory_type: r.get(3)?,
141 description: r.get(4)?,
142 body: r.get(5)?,
143 body_hash: r.get(6)?,
144 session_id: r.get(7)?,
145 source: r.get(8)?,
146 metadata: r.get(9)?,
147 created_at: r.get(10)?,
148 updated_at: r.get(11)?,
149 deleted_at: None,
150 })
151 })
152 .ok()
153 };
154 if let Some(row) = row {
155 let snippet: String = row.body.chars().take(300).collect();
156 direct_matches.push(RecallItem {
157 memory_id: row.id,
158 name: row.name,
159 namespace: row.namespace,
160 memory_type: row.memory_type,
161 description: row.description,
162 snippet,
163 distance,
164 score: RecallItem::score_from_distance(distance),
165 source: "direct".to_string(),
166 graph_depth: None,
168 });
169 memory_ids.push(memory_id);
170 }
171 }
172
173 let mut graph_matches = Vec::new();
174 if !args.no_graph {
175 let entity_knn = entities::knn_search(&conn, &embedding, &namespace_for_graph, 5)?;
176 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
177
178 let all_seed_ids: Vec<i64> = memory_ids
179 .iter()
180 .chain(entity_ids.iter())
181 .copied()
182 .collect();
183
184 if !all_seed_ids.is_empty() {
185 let graph_memory_ids = traverse_from_memories_with_hops(
186 &conn,
187 &all_seed_ids,
188 &namespace_for_graph,
189 args.min_weight,
190 args.max_hops,
191 )?;
192
193 for (graph_mem_id, hop) in graph_memory_ids {
194 if let Some(cap) = args.max_graph_results {
197 if graph_matches.len() >= cap {
198 break;
199 }
200 }
201 let row = {
202 let mut stmt = conn.prepare_cached(
203 "SELECT id, namespace, name, type, description, body, body_hash,
204 session_id, source, metadata, created_at, updated_at
205 FROM memories WHERE id=?1 AND deleted_at IS NULL",
206 )?;
207 stmt.query_row(rusqlite::params![graph_mem_id], |r| {
208 Ok(memories::MemoryRow {
209 id: r.get(0)?,
210 namespace: r.get(1)?,
211 name: r.get(2)?,
212 memory_type: r.get(3)?,
213 description: r.get(4)?,
214 body: r.get(5)?,
215 body_hash: r.get(6)?,
216 session_id: r.get(7)?,
217 source: r.get(8)?,
218 metadata: r.get(9)?,
219 created_at: r.get(10)?,
220 updated_at: r.get(11)?,
221 deleted_at: None,
222 })
223 })
224 .ok()
225 };
226 if let Some(row) = row {
227 let snippet: String = row.body.chars().take(300).collect();
228 let graph_distance = 1.0 - 1.0 / (hop as f32 + 1.0);
234 graph_matches.push(RecallItem {
235 memory_id: row.id,
236 name: row.name,
237 namespace: row.namespace,
238 memory_type: row.memory_type,
239 description: row.description,
240 snippet,
241 distance: graph_distance,
242 score: RecallItem::score_from_distance(graph_distance),
243 source: "graph".to_string(),
244 graph_depth: Some(hop),
245 });
246 }
247 }
248 }
249 }
250
251 if args.max_distance < 1.0 {
253 let has_relevant = direct_matches
254 .iter()
255 .any(|item| item.distance <= args.max_distance);
256 if !has_relevant {
257 return Err(AppError::NotFound(errors_msg::no_recall_results(
258 args.max_distance,
259 &args.query,
260 &namespace_for_graph,
261 )));
262 }
263 }
264
265 let results: Vec<RecallItem> = direct_matches
266 .iter()
267 .cloned()
268 .chain(graph_matches.iter().cloned())
269 .collect();
270
271 output::emit_json(&RecallResponse {
272 query: args.query,
273 k: args.k,
274 direct_matches,
275 graph_matches,
276 results,
277 elapsed_ms: start.elapsed().as_millis() as u64,
278 })?;
279
280 Ok(())
281}
282
#[cfg(test)]
mod tests {
    use crate::output::{RecallItem, RecallResponse};

    /// Builds a fixture `RecallItem`; items whose source is "graph" get a
    /// graph depth of 0, every other source gets none.
    fn make_item(name: &str, distance: f32, source: &str) -> RecallItem {
        let graph_depth = match source {
            "graph" => Some(0),
            _ => None,
        };
        let score = RecallItem::score_from_distance(distance);
        RecallItem {
            memory_id: 1,
            name: name.to_owned(),
            namespace: "global".to_owned(),
            memory_type: "fact".to_owned(),
            description: "desc".to_owned(),
            snippet: "snippet".to_owned(),
            distance,
            score,
            source: source.to_owned(),
            graph_depth,
        }
    }

    // The serialized score must be a number in [0, 1] equal to 1 - distance.
    #[test]
    fn recall_item_score_is_present_and_finite_for_direct_match() {
        let serialized =
            serde_json::to_value(&make_item("mem", 0.25, "direct")).expect("serialization failed");
        let score = serialized["score"]
            .as_f64()
            .expect("score must be a number");
        assert!(
            (0.0..=1.0).contains(&score),
            "score must be in [0, 1], got {score}"
        );
        assert!(
            (score - 0.75).abs() < 1e-6,
            "score must equal 1 - distance for canonical case"
        );
    }

    // Out-of-range and NaN distances are clamped to the ends of the score range.
    #[test]
    fn recall_item_score_clamps_distance_outside_unit_range() {
        // Distance above 1 -> lowest score.
        assert_eq!(RecallItem::score_from_distance(2.0), 0.0);
        // Negative distance -> highest score.
        assert_eq!(RecallItem::score_from_distance(-0.5), 1.0);
        // NaN -> lowest score.
        assert_eq!(RecallItem::score_from_distance(f32::NAN), 0.0);
    }

    // Every top-level field of the response shows up in the JSON.
    #[test]
    fn recall_response_serializes_required_fields() {
        let only_hit = make_item("mem-a", 0.12, "direct");
        let response = RecallResponse {
            query: "rust memory".to_owned(),
            k: 5,
            direct_matches: vec![only_hit],
            graph_matches: Vec::new(),
            results: vec![make_item("mem-a", 0.12, "direct")],
            elapsed_ms: 42,
        };

        let serialized = serde_json::to_value(&response).expect("serialization failed");
        assert_eq!(serialized["query"], "rust memory");
        assert_eq!(serialized["k"], 5);
        assert_eq!(serialized["elapsed_ms"], 42u64);
        assert!(serialized["direct_matches"].is_array());
        assert!(serialized["graph_matches"].is_array());
        assert!(serialized["results"].is_array());
    }

    // The memory type is emitted under the JSON key "type".
    #[test]
    fn recall_item_serializes_renamed_type() {
        let serialized = serde_json::to_value(&make_item("mem-test", 0.25, "direct"))
            .expect("serialization failed");

        assert_eq!(serialized["type"], "fact");
        assert_eq!(serialized["distance"], 0.25f32);
        assert_eq!(serialized["source"], "direct");
    }

    // `results` carries the direct item first and the graph item second.
    #[test]
    fn recall_response_results_contains_direct_and_graph() {
        let direct_item = make_item("d-mem", 0.10, "direct");
        let graph_item = make_item("g-mem", 0.0, "graph");

        let response = RecallResponse {
            query: "query".to_owned(),
            k: 10,
            direct_matches: vec![direct_item.clone()],
            graph_matches: vec![graph_item.clone()],
            results: vec![direct_item, graph_item],
            elapsed_ms: 10,
        };

        let serialized = serde_json::to_value(&response).expect("serialization failed");
        let len_of = |key: &str| serialized[key].as_array().unwrap().len();
        assert_eq!(len_of("direct_matches"), 1);
        assert_eq!(len_of("graph_matches"), 1);
        assert_eq!(len_of("results"), 2);
        assert_eq!(serialized["results"][0]["source"], "direct");
        assert_eq!(serialized["results"][1]["source"], "graph");
    }

    // Empty collections serialize as empty JSON arrays.
    #[test]
    fn recall_response_empty_serializes_empty_arrays() {
        let response = RecallResponse {
            query: "nothing".to_owned(),
            k: 3,
            direct_matches: Vec::new(),
            graph_matches: Vec::new(),
            results: Vec::new(),
            elapsed_ms: 1,
        };

        let serialized = serde_json::to_value(&response).expect("serialization failed");
        let len_of = |key: &str| serialized[key].as_array().unwrap().len();
        assert_eq!(len_of("direct_matches"), 0);
        assert_eq!(len_of("results"), 0);
    }

    // Pins the hop-count formula 1 - 1/(hop + 1) that `run` uses for graph matches.
    #[test]
    fn graph_matches_distance_uses_hop_count_proxy() {
        let table: [(u32, f32); 4] = [(0, 0.0), (1, 0.5), (2, 0.6667), (3, 0.75)];
        for (hop, expected) in table.iter().copied() {
            let d = 1.0_f32 - 1.0 / (hop as f32 + 1.0);
            let error = (d - expected).abs();
            assert!(error < 0.001, "hop={hop} expected={expected} got={d}");
        }
    }
}