1use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::graph::traverse_from_memories_with_hops;
6use crate::i18n::errors_msg;
7use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
8use crate::paths::AppPaths;
9use crate::storage::connection::open_ro;
10use crate::storage::entities;
11use crate::storage::memories;
12
13#[derive(clap::Args)]
20#[command(after_long_help = "EXAMPLES:\n \
21 # Semantic search for top 5 matches\n \
22 sqlite-graphrag recall \"authentication design\" --k 5\n\n \
23 # Disable automatic graph expansion\n \
24 sqlite-graphrag recall \"JWT tokens\" --k 3 --no-graph\n\n \
25 # Limit graph traversal depth and minimum edge weight\n \
26 sqlite-graphrag recall \"auth\" --k 5 --max-hops 2 --min-weight 0.3\n\n \
27 # Filter by memory type\n \
28 sqlite-graphrag recall \"deployment\" --type decision --k 10\n\n \
29 # Cap results by distance threshold\n \
30 sqlite-graphrag recall \"API design\" --k 5 --max-distance 0.8\n\n \
31NOTES:\n \
32 When --no-graph is active, graph traversal is skipped and every result has\n \
33 source=\"direct\". The source field is therefore redundant with --no-graph and\n \
34 may be ignored by callers in that mode.")]
35pub struct RecallArgs {
36 #[arg(help = "Search query string (semantic vector search via sqlite-vec)")]
37 pub query: String,
38 #[arg(short = 'k', long, alias = "limit", default_value = "10", value_parser = crate::parsers::parse_k_range)]
46 pub k: usize,
47 #[arg(long, value_enum)]
51 pub r#type: Option<MemoryType>,
52 #[arg(long)]
53 pub namespace: Option<String>,
54 #[arg(long)]
55 pub no_graph: bool,
56 #[arg(long)]
62 pub precise: bool,
63 #[arg(long, default_value = "2")]
64 pub max_hops: u32,
65 #[arg(long, default_value = "0.3")]
66 pub min_weight: f64,
67 #[arg(long, value_name = "N")]
73 pub max_graph_results: Option<usize>,
74 #[arg(long, alias = "min-distance", default_value = "1.0")]
79 pub max_distance: f32,
80 #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
81 pub format: JsonOutputFormat,
82 #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
83 pub db: Option<String>,
84 #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
86 pub json: bool,
87 #[arg(long, conflicts_with = "namespace")]
92 pub all_namespaces: bool,
93 #[command(flatten)]
94 pub daemon: crate::cli::DaemonOpts,
95}
96
97pub fn run(args: RecallArgs) -> Result<(), AppError> {
98 let start = std::time::Instant::now();
99 let _ = args.format;
100 if args.query.trim().is_empty() {
101 return Err(AppError::Validation(crate::i18n::validation::empty_query()));
102 }
103 let namespaces: Vec<String> = if args.all_namespaces {
107 Vec::new()
108 } else {
109 vec![crate::namespace::resolve_namespace(
110 args.namespace.as_deref(),
111 )?]
112 };
113 let namespace_for_graph = namespaces
115 .first()
116 .cloned()
117 .unwrap_or_else(|| "global".to_string());
118 let paths = AppPaths::resolve(args.db.as_deref())?;
119
120 crate::storage::connection::ensure_db_ready(&paths)?;
121
122 output::emit_progress_i18n(
123 "Computing query embedding...",
124 "Calculando embedding da consulta...",
125 );
126 let embedding = crate::daemon::embed_query_or_local(
127 &paths.models,
128 &args.query,
129 args.daemon.autostart_daemon,
130 )?;
131
132 let conn = open_ro(&paths.db)?;
133
134 let memory_type_str = args.r#type.map(|t| t.as_str());
135 let effective_k = if args.precise { 100_000 } else { args.k };
138 let knn_results =
139 memories::knn_search(&conn, &embedding, &namespaces, memory_type_str, effective_k)?;
140
141 let mut direct_matches = Vec::with_capacity(effective_k);
142 let mut memory_ids: Vec<i64> = Vec::with_capacity(effective_k);
143 for (memory_id, distance) in knn_results {
144 let row = {
145 let mut stmt = conn.prepare_cached(
146 "SELECT id, namespace, name, type, description, body, body_hash,
147 session_id, source, metadata, created_at, updated_at
148 FROM memories WHERE id=?1 AND deleted_at IS NULL",
149 )?;
150 stmt.query_row(rusqlite::params![memory_id], |r| {
151 Ok(memories::MemoryRow {
152 id: r.get(0)?,
153 namespace: r.get(1)?,
154 name: r.get(2)?,
155 memory_type: r.get(3)?,
156 description: r.get(4)?,
157 body: r.get(5)?,
158 body_hash: r.get(6)?,
159 session_id: r.get(7)?,
160 source: r.get(8)?,
161 metadata: r.get(9)?,
162 created_at: r.get(10)?,
163 updated_at: r.get(11)?,
164 deleted_at: None,
165 })
166 })
167 .ok()
168 };
169 if let Some(row) = row {
170 let snippet: String = row.body.chars().take(300).collect();
171 direct_matches.push(RecallItem {
172 memory_id: row.id,
173 name: row.name,
174 namespace: row.namespace,
175 memory_type: row.memory_type,
176 description: row.description,
177 snippet,
178 distance,
179 score: RecallItem::score_from_distance(distance),
180 source: "direct".to_string(),
181 graph_depth: None,
183 });
184 memory_ids.push(memory_id);
185 }
186 }
187
188 let mut graph_matches = Vec::with_capacity(8);
189 if !args.no_graph {
190 let entity_knn = entities::knn_search(&conn, &embedding, &namespace_for_graph, 5)?;
191 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
192
193 let all_seed_ids: Vec<i64> = memory_ids
194 .iter()
195 .chain(entity_ids.iter())
196 .copied()
197 .collect();
198
199 if !all_seed_ids.is_empty() {
200 let graph_memory_ids = traverse_from_memories_with_hops(
201 &conn,
202 &all_seed_ids,
203 &namespace_for_graph,
204 args.min_weight,
205 args.max_hops,
206 )?;
207
208 for (graph_mem_id, hop) in graph_memory_ids {
209 if let Some(cap) = args.max_graph_results {
212 if graph_matches.len() >= cap {
213 break;
214 }
215 }
216 let row = {
217 let mut stmt = conn.prepare_cached(
218 "SELECT id, namespace, name, type, description, body, body_hash,
219 session_id, source, metadata, created_at, updated_at
220 FROM memories WHERE id=?1 AND deleted_at IS NULL",
221 )?;
222 stmt.query_row(rusqlite::params![graph_mem_id], |r| {
223 Ok(memories::MemoryRow {
224 id: r.get(0)?,
225 namespace: r.get(1)?,
226 name: r.get(2)?,
227 memory_type: r.get(3)?,
228 description: r.get(4)?,
229 body: r.get(5)?,
230 body_hash: r.get(6)?,
231 session_id: r.get(7)?,
232 source: r.get(8)?,
233 metadata: r.get(9)?,
234 created_at: r.get(10)?,
235 updated_at: r.get(11)?,
236 deleted_at: None,
237 })
238 })
239 .ok()
240 };
241 if let Some(row) = row {
242 let snippet: String = row.body.chars().take(300).collect();
243 let graph_distance = 1.0 - 1.0 / (hop as f32 + 1.0);
249 graph_matches.push(RecallItem {
250 memory_id: row.id,
251 name: row.name,
252 namespace: row.namespace,
253 memory_type: row.memory_type,
254 description: row.description,
255 snippet,
256 distance: graph_distance,
257 score: RecallItem::score_from_distance(graph_distance),
258 source: "graph".to_string(),
259 graph_depth: Some(hop),
260 });
261 }
262 }
263 }
264 }
265
266 if args.max_distance < 1.0 {
268 let has_relevant = direct_matches
269 .iter()
270 .any(|item| item.distance <= args.max_distance);
271 if !has_relevant {
272 return Err(AppError::NotFound(errors_msg::no_recall_results(
273 args.max_distance,
274 &args.query,
275 &namespace_for_graph,
276 )));
277 }
278 }
279
280 let results: Vec<RecallItem> = direct_matches
281 .iter()
282 .cloned()
283 .chain(graph_matches.iter().cloned())
284 .collect();
285
286 output::emit_json(&RecallResponse {
287 query: args.query,
288 k: args.k,
289 direct_matches,
290 graph_matches,
291 results,
292 elapsed_ms: start.elapsed().as_millis() as u64,
293 })?;
294
295 Ok(())
296}
297
#[cfg(test)]
mod tests {
    //! Serialization and scoring checks for the `recall` response types.

    use crate::output::{RecallItem, RecallResponse};

    /// Builds a minimal `RecallItem` fixture; graph items get `graph_depth = Some(0)`.
    fn make_item(name: &str, distance: f32, source: &str) -> RecallItem {
        RecallItem {
            memory_id: 1,
            name: name.to_string(),
            namespace: "global".to_string(),
            memory_type: "fact".to_string(),
            description: "desc".to_string(),
            snippet: "snippet".to_string(),
            distance,
            score: RecallItem::score_from_distance(distance),
            source: source.to_string(),
            graph_depth: if source == "graph" { Some(0) } else { None },
        }
    }

    /// The serialized score is a number in [0, 1] and equals `1 - distance` for 0.25.
    #[test]
    fn recall_item_score_is_present_and_finite_for_direct_match() {
        let item = make_item("mem", 0.25, "direct");
        let json = serde_json::to_value(&item).expect("serialization failed");
        let score = json["score"].as_f64().expect("score must be a number");
        assert!(
            (0.0..=1.0).contains(&score),
            "score must be in [0, 1], got {score}"
        );
        assert!(
            (score - 0.75).abs() < 1e-6,
            "score must equal 1 - distance for canonical case"
        );
    }

    /// Out-of-range and NaN distances map to the ends of the score range.
    #[test]
    fn recall_item_score_clamps_distance_outside_unit_range() {
        assert_eq!(RecallItem::score_from_distance(2.0), 0.0);
        assert_eq!(RecallItem::score_from_distance(-0.5), 1.0);
        assert_eq!(RecallItem::score_from_distance(f32::NAN), 0.0);
    }

    /// Every top-level field of the response is present, including empty arrays.
    #[test]
    fn recall_response_serializes_required_fields() {
        let resp = RecallResponse {
            query: "rust memory".to_string(),
            k: 5,
            direct_matches: vec![make_item("mem-a", 0.12, "direct")],
            graph_matches: vec![],
            results: vec![make_item("mem-a", 0.12, "direct")],
            elapsed_ms: 42,
        };

        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["query"], "rust memory");
        assert_eq!(json["k"], 5);
        assert_eq!(json["elapsed_ms"], 42u64);
        assert!(json["direct_matches"].is_array());
        assert!(json["graph_matches"].is_array());
        assert!(json["results"].is_array());
    }

    /// `memory_type` is exposed under the JSON key `type`.
    #[test]
    fn recall_item_serializes_renamed_type() {
        let item = make_item("mem-test", 0.25, "direct");
        let json = serde_json::to_value(&item).expect("serialization failed");

        assert_eq!(json["type"], "fact");
        assert_eq!(json["distance"], 0.25f32);
        assert_eq!(json["source"], "direct");
    }

    /// `results` keeps direct matches first, followed by graph matches.
    #[test]
    fn recall_response_results_contains_direct_and_graph() {
        let direct = make_item("d-mem", 0.10, "direct");
        let graph = make_item("g-mem", 0.0, "graph");

        let resp = RecallResponse {
            query: "query".to_string(),
            k: 10,
            direct_matches: vec![direct.clone()],
            graph_matches: vec![graph.clone()],
            results: vec![direct, graph],
            elapsed_ms: 10,
        };

        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["direct_matches"].as_array().unwrap().len(), 1);
        assert_eq!(json["graph_matches"].as_array().unwrap().len(), 1);
        assert_eq!(json["results"].as_array().unwrap().len(), 2);
        assert_eq!(json["results"][0]["source"], "direct");
        assert_eq!(json["results"][1]["source"], "graph");
    }

    /// An empty response serializes all three match lists as empty arrays.
    #[test]
    fn recall_response_empty_serializes_empty_arrays() {
        let resp = RecallResponse {
            query: "nothing".to_string(),
            k: 3,
            direct_matches: vec![],
            graph_matches: vec![],
            results: vec![],
            elapsed_ms: 1,
        };

        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["direct_matches"].as_array().unwrap().len(), 0);
        // Previously unchecked: the graph list must also be an empty array, not absent.
        assert_eq!(json["graph_matches"].as_array().unwrap().len(), 0);
        assert_eq!(json["results"].as_array().unwrap().len(), 0);
    }

    /// Pins the hop-count distance proxy `1 - 1 / (hop + 1)` used for graph matches.
    ///
    /// NOTE(review): this restates the formula instead of calling production code, so it
    /// documents the expected values but cannot catch a change made inside `run`.
    #[test]
    fn graph_matches_distance_uses_hop_count_proxy() {
        let cases: &[(u32, f32)] = &[(0, 0.0), (1, 0.5), (2, 0.6667), (3, 0.75)];
        for &(hop, expected) in cases {
            let d = 1.0_f32 - 1.0 / (hop as f32 + 1.0);
            assert!(
                (d - expected).abs() < 0.001,
                "hop={hop} expected={expected} got={d}"
            );
        }
    }
}