1use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::graph::traverse_from_memories_with_hops;
6use crate::i18n::errors_msg;
7use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
8use crate::paths::AppPaths;
9use crate::storage::connection::open_ro;
10use crate::storage::entities;
11use crate::storage::memories;
12
// Command-line arguments for the `recall` subcommand.
//
// NOTE: plain `//` comments are used on purpose. With clap's derive API, `///`
// doc comments are turned into help text, so adding them here would change the
// CLI's `--help` output.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n \
    # Semantic search for top 5 matches\n \
    sqlite-graphrag recall \"authentication design\" --k 5\n\n \
    # Disable automatic graph expansion\n \
    sqlite-graphrag recall \"JWT tokens\" --k 3 --no-graph\n\n \
    # Limit graph traversal depth and minimum edge weight\n \
    sqlite-graphrag recall \"auth\" --k 5 --max-hops 2 --min-weight 0.3\n\n \
    # Filter by memory type\n \
    sqlite-graphrag recall \"deployment\" --type decision --k 10\n\n \
    # Cap results by distance threshold\n \
    sqlite-graphrag recall \"API design\" --k 5 --max-distance 0.8")]
pub struct RecallArgs {
    // Positional query text. `run` rejects it when it is empty after trimming.
    #[arg(help = "Search query string (semantic vector search via sqlite-vec)")]
    pub query: String,
    // Number of nearest neighbours requested from the memory KNN search
    // (`-k`, `--k`, or the `--limit` alias). Parsed by `parse_k_range`.
    // Ignored for the search itself when `--precise` is set.
    #[arg(short = 'k', long, alias = "limit", default_value = "10", value_parser = crate::parsers::parse_k_range)]
    pub k: usize,
    // Optional memory-type filter; forwarded to the KNN search as a string.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // Namespace to search; resolved through `crate::namespace::resolve_namespace`
    // (which also receives `None` when the flag is absent).
    #[arg(long)]
    pub namespace: Option<String>,
    // When set, the graph-expansion stage is skipped entirely.
    #[arg(long)]
    pub no_graph: bool,
    // When set, `run` searches with k = 100_000 instead of `k`.
    #[arg(long)]
    pub precise: bool,
    // Hop limit handed to the graph traversal.
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    // Minimum edge weight handed to the graph traversal.
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    // Optional cap on how many graph matches are collected; no cap when absent.
    #[arg(long, value_name = "N")]
    pub max_graph_results: Option<usize>,
    // Distance threshold (also reachable as `--min-distance`). When below 1.0,
    // `run` fails with a not-found error unless at least one direct match has
    // `distance <= max_distance`. Individual results are not filtered by it.
    #[arg(long, alias = "min-distance", default_value = "1.0")]
    pub max_distance: f32,
    // Output format selector. `run` reads and discards it.
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // Database path override; falls back to the `SQLITE_GRAPHRAG_DB_PATH`
    // environment variable. Passed to `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // Hidden no-op flag (see its help string); `run` never reads it.
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    // Search without a namespace restriction (an empty namespace list is passed
    // to the memory KNN search). Mutually exclusive with `--namespace`.
    #[arg(long, conflicts_with = "namespace")]
    pub all_namespaces: bool,
    // Shared daemon options; `run` reads `autostart_daemon` from it.
    #[command(flatten)]
    pub daemon: crate::cli::DaemonOpts,
}
92
93pub fn run(args: RecallArgs) -> Result<(), AppError> {
94 let start = std::time::Instant::now();
95 let _ = args.format;
96 if args.query.trim().is_empty() {
97 return Err(AppError::Validation(crate::i18n::validation::empty_query()));
98 }
99 let namespaces: Vec<String> = if args.all_namespaces {
103 Vec::new()
104 } else {
105 vec![crate::namespace::resolve_namespace(
106 args.namespace.as_deref(),
107 )?]
108 };
109 let namespace_for_graph = namespaces
111 .first()
112 .cloned()
113 .unwrap_or_else(|| "global".to_string());
114 let paths = AppPaths::resolve(args.db.as_deref())?;
115
116 crate::storage::connection::ensure_db_ready(&paths)?;
117
118 output::emit_progress_i18n(
119 "Computing query embedding...",
120 "Calculando embedding da consulta...",
121 );
122 let embedding = crate::daemon::embed_query_or_local(
123 &paths.models,
124 &args.query,
125 args.daemon.autostart_daemon,
126 )?;
127
128 let conn = open_ro(&paths.db)?;
129
130 let memory_type_str = args.r#type.map(|t| t.as_str());
131 let effective_k = if args.precise { 100_000 } else { args.k };
134 let knn_results =
135 memories::knn_search(&conn, &embedding, &namespaces, memory_type_str, effective_k)?;
136
137 let mut direct_matches = Vec::with_capacity(effective_k);
138 let mut memory_ids: Vec<i64> = Vec::with_capacity(effective_k);
139 for (memory_id, distance) in knn_results {
140 let row = {
141 let mut stmt = conn.prepare_cached(
142 "SELECT id, namespace, name, type, description, body, body_hash,
143 session_id, source, metadata, created_at, updated_at
144 FROM memories WHERE id=?1 AND deleted_at IS NULL",
145 )?;
146 stmt.query_row(rusqlite::params![memory_id], |r| {
147 Ok(memories::MemoryRow {
148 id: r.get(0)?,
149 namespace: r.get(1)?,
150 name: r.get(2)?,
151 memory_type: r.get(3)?,
152 description: r.get(4)?,
153 body: r.get(5)?,
154 body_hash: r.get(6)?,
155 session_id: r.get(7)?,
156 source: r.get(8)?,
157 metadata: r.get(9)?,
158 created_at: r.get(10)?,
159 updated_at: r.get(11)?,
160 deleted_at: None,
161 })
162 })
163 .ok()
164 };
165 if let Some(row) = row {
166 let snippet: String = row.body.chars().take(300).collect();
167 direct_matches.push(RecallItem {
168 memory_id: row.id,
169 name: row.name,
170 namespace: row.namespace,
171 memory_type: row.memory_type,
172 description: row.description,
173 snippet,
174 distance,
175 score: RecallItem::score_from_distance(distance),
176 source: "direct".to_string(),
177 graph_depth: None,
179 });
180 memory_ids.push(memory_id);
181 }
182 }
183
184 let mut graph_matches = Vec::new();
185 if !args.no_graph {
186 let entity_knn = entities::knn_search(&conn, &embedding, &namespace_for_graph, 5)?;
187 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
188
189 let all_seed_ids: Vec<i64> = memory_ids
190 .iter()
191 .chain(entity_ids.iter())
192 .copied()
193 .collect();
194
195 if !all_seed_ids.is_empty() {
196 let graph_memory_ids = traverse_from_memories_with_hops(
197 &conn,
198 &all_seed_ids,
199 &namespace_for_graph,
200 args.min_weight,
201 args.max_hops,
202 )?;
203
204 for (graph_mem_id, hop) in graph_memory_ids {
205 if let Some(cap) = args.max_graph_results {
208 if graph_matches.len() >= cap {
209 break;
210 }
211 }
212 let row = {
213 let mut stmt = conn.prepare_cached(
214 "SELECT id, namespace, name, type, description, body, body_hash,
215 session_id, source, metadata, created_at, updated_at
216 FROM memories WHERE id=?1 AND deleted_at IS NULL",
217 )?;
218 stmt.query_row(rusqlite::params![graph_mem_id], |r| {
219 Ok(memories::MemoryRow {
220 id: r.get(0)?,
221 namespace: r.get(1)?,
222 name: r.get(2)?,
223 memory_type: r.get(3)?,
224 description: r.get(4)?,
225 body: r.get(5)?,
226 body_hash: r.get(6)?,
227 session_id: r.get(7)?,
228 source: r.get(8)?,
229 metadata: r.get(9)?,
230 created_at: r.get(10)?,
231 updated_at: r.get(11)?,
232 deleted_at: None,
233 })
234 })
235 .ok()
236 };
237 if let Some(row) = row {
238 let snippet: String = row.body.chars().take(300).collect();
239 let graph_distance = 1.0 - 1.0 / (hop as f32 + 1.0);
245 graph_matches.push(RecallItem {
246 memory_id: row.id,
247 name: row.name,
248 namespace: row.namespace,
249 memory_type: row.memory_type,
250 description: row.description,
251 snippet,
252 distance: graph_distance,
253 score: RecallItem::score_from_distance(graph_distance),
254 source: "graph".to_string(),
255 graph_depth: Some(hop),
256 });
257 }
258 }
259 }
260 }
261
262 if args.max_distance < 1.0 {
264 let has_relevant = direct_matches
265 .iter()
266 .any(|item| item.distance <= args.max_distance);
267 if !has_relevant {
268 return Err(AppError::NotFound(errors_msg::no_recall_results(
269 args.max_distance,
270 &args.query,
271 &namespace_for_graph,
272 )));
273 }
274 }
275
276 let results: Vec<RecallItem> = direct_matches
277 .iter()
278 .cloned()
279 .chain(graph_matches.iter().cloned())
280 .collect();
281
282 output::emit_json(&RecallResponse {
283 query: args.query,
284 k: args.k,
285 direct_matches,
286 graph_matches,
287 results,
288 elapsed_ms: start.elapsed().as_millis() as u64,
289 })?;
290
291 Ok(())
292}
293
#[cfg(test)]
mod tests {
    use crate::output::{RecallItem, RecallResponse};

    /// Fixture: a `RecallItem` with fixed id/namespace/type/description/snippet.
    /// Items whose source is `"graph"` get `graph_depth = Some(0)`, others `None`.
    fn make_item(name: &str, distance: f32, source: &str) -> RecallItem {
        let graph_depth = match source {
            "graph" => Some(0),
            _ => None,
        };
        RecallItem {
            memory_id: 1,
            name: name.to_string(),
            namespace: "global".to_string(),
            memory_type: "fact".to_string(),
            description: "desc".to_string(),
            snippet: "snippet".to_string(),
            distance,
            score: RecallItem::score_from_distance(distance),
            source: source.to_string(),
            graph_depth,
        }
    }

    /// Fixture: a response whose `results` is `direct` followed by `graph`.
    fn make_response(
        query: &str,
        k: usize,
        direct: Vec<RecallItem>,
        graph: Vec<RecallItem>,
        elapsed_ms: u64,
    ) -> RecallResponse {
        let mut combined = direct.clone();
        combined.extend(graph.iter().cloned());
        RecallResponse {
            query: query.to_string(),
            k,
            direct_matches: direct,
            graph_matches: graph,
            results: combined,
            elapsed_ms,
        }
    }

    #[test]
    fn recall_item_score_is_present_and_finite_for_direct_match() {
        let serialized =
            serde_json::to_value(make_item("mem", 0.25, "direct")).expect("serialization failed");
        let score = serialized["score"]
            .as_f64()
            .expect("score must be a number");
        assert!(
            (0.0..=1.0).contains(&score),
            "score must be in [0, 1], got {score}"
        );
        assert!(
            (score - 0.75).abs() < 1e-6,
            "score must equal 1 - distance for canonical case"
        );
    }

    #[test]
    fn recall_item_score_clamps_distance_outside_unit_range() {
        // Above the unit range, below it, and NaN.
        assert_eq!(RecallItem::score_from_distance(2.0), 0.0);
        assert_eq!(RecallItem::score_from_distance(-0.5), 1.0);
        assert_eq!(RecallItem::score_from_distance(f32::NAN), 0.0);
    }

    #[test]
    fn recall_response_serializes_required_fields() {
        let response = make_response(
            "rust memory",
            5,
            vec![make_item("mem-a", 0.12, "direct")],
            Vec::new(),
            42,
        );

        let serialized = serde_json::to_value(&response).expect("serialization failed");
        assert_eq!(serialized["query"], "rust memory");
        assert_eq!(serialized["k"], 5);
        assert_eq!(serialized["elapsed_ms"], 42u64);
        assert!(serialized["direct_matches"].is_array());
        assert!(serialized["graph_matches"].is_array());
        assert!(serialized["results"].is_array());
    }

    #[test]
    fn recall_item_serializes_renamed_type() {
        let serialized = serde_json::to_value(make_item("mem-test", 0.25, "direct"))
            .expect("serialization failed");

        // `memory_type` is expected under the JSON key "type".
        assert_eq!(serialized["type"], "fact");
        assert_eq!(serialized["distance"], 0.25f32);
        assert_eq!(serialized["source"], "direct");
    }

    #[test]
    fn recall_response_results_contains_direct_and_graph() {
        let response = make_response(
            "query",
            10,
            vec![make_item("d-mem", 0.10, "direct")],
            vec![make_item("g-mem", 0.0, "graph")],
            10,
        );

        let serialized = serde_json::to_value(&response).expect("serialization failed");
        let array_len = |key: &str| serialized[key].as_array().unwrap().len();
        assert_eq!(array_len("direct_matches"), 1);
        assert_eq!(array_len("graph_matches"), 1);
        assert_eq!(array_len("results"), 2);
        assert_eq!(serialized["results"][0]["source"], "direct");
        assert_eq!(serialized["results"][1]["source"], "graph");
    }

    #[test]
    fn recall_response_empty_serializes_empty_arrays() {
        let response = make_response("nothing", 3, Vec::new(), Vec::new(), 1);

        let serialized = serde_json::to_value(&response).expect("serialization failed");
        assert_eq!(serialized["direct_matches"].as_array().unwrap().len(), 0);
        assert_eq!(serialized["results"].as_array().unwrap().len(), 0);
    }

    #[test]
    fn graph_matches_distance_uses_hop_count_proxy() {
        // (hop, expected distance) pairs for the proxy 1 - 1 / (hop + 1).
        let cases: [(u32, f32); 4] = [(0, 0.0), (1, 0.5), (2, 0.6667), (3, 0.75)];
        for (hop, expected) in cases.iter().copied() {
            let d = 1.0_f32 - 1.0 / (hop as f32 + 1.0);
            assert!(
                (d - expected).abs() < 0.001,
                "hop={hop} expected={expected} got={d}"
            );
        }
    }
}