Skip to main content

sqlite_graphrag/commands/
recall.rs

1//! Handler for the `recall` CLI subcommand.
2
3use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::graph::traverse_from_memories_with_hops;
6use crate::i18n::errors_msg;
7use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
8use crate::paths::AppPaths;
9use crate::storage::connection::open_ro;
10use crate::storage::entities;
11use crate::storage::memories;
12
13/// Arguments for the `recall` subcommand.
14///
15/// When `--namespace` is omitted the query runs against the `global` namespace,
16/// which is the default namespace used by `remember` when no `--namespace` flag
17/// is provided. Pass an explicit `--namespace` value to search a different
18/// isolated namespace.
19#[derive(clap::Args)]
20#[command(after_long_help = "EXAMPLES:\n  \
21    # Semantic search for top 5 matches\n  \
22    sqlite-graphrag recall \"authentication design\" --k 5\n\n  \
23    # Disable automatic graph expansion\n  \
24    sqlite-graphrag recall \"JWT tokens\" --k 3 --no-graph\n\n  \
25    # Limit graph traversal depth and minimum edge weight\n  \
26    sqlite-graphrag recall \"auth\" --k 5 --max-hops 2 --min-weight 0.3\n\n  \
27    # Filter by memory type\n  \
28    sqlite-graphrag recall \"deployment\" --type decision --k 10\n\n  \
29    # Cap results by distance threshold\n  \
30    sqlite-graphrag recall \"API design\" --k 5 --max-distance 0.8\n\n  \
31NOTES:\n  \
32    When --no-graph is active, graph traversal is skipped and every result has\n  \
33    source=\"direct\". The source field is therefore redundant with --no-graph and\n  \
34    may be ignored by callers in that mode.")]
35pub struct RecallArgs {
36    #[arg(help = "Search query string (semantic vector search via sqlite-vec)")]
37    pub query: String,
38    /// Maximum number of direct vector matches to return.
39    ///
40    /// Note: this flag controls only `direct_matches`. Graph traversal results
41    /// (`graph_matches`) are unbounded by default; use `--max-graph-results` to
42    /// cap them independently. The `results` field aggregates both lists.
43    /// Validated to the inclusive range `1..=4096` (the upper bound matches
44    /// `sqlite-vec`'s knn limit; out-of-range values are rejected at parse time).
45    #[arg(short = 'k', long, alias = "limit", default_value = "10", value_parser = crate::parsers::parse_k_range)]
46    pub k: usize,
47    /// Filter by memory.type. Note: distinct from graph entity_type
48    /// (project/tool/person/file/concept/incident/decision/memory/dashboard/issue_tracker/organization/location/date)
49    /// used in --entities-file.
50    #[arg(long, value_enum)]
51    pub r#type: Option<MemoryType>,
52    #[arg(long)]
53    pub namespace: Option<String>,
54    #[arg(long)]
55    pub no_graph: bool,
56    /// Disable -k cap and return all direct matches without truncation.
57    ///
58    /// When set, the `-k`/`--k` flag is ignored for `direct_matches` and the
59    /// response includes every match above the distance threshold. Useful when
60    /// callers need the complete set rather than a top-N preview.
61    #[arg(long)]
62    pub precise: bool,
63    #[arg(long, default_value = "2")]
64    pub max_hops: u32,
65    #[arg(long, default_value = "0.3")]
66    pub min_weight: f64,
67    /// Cap the size of `graph_matches` to at most N entries.
68    ///
69    /// Defaults to unbounded (`None`) so existing pipelines see the same shape
70    /// as in v1.0.22 and earlier. Set this when a query touches a dense graph
71    /// neighbourhood and the caller only needs a top-N preview. Added in v1.0.23.
72    #[arg(long, value_name = "N")]
73    pub max_graph_results: Option<usize>,
74    /// Filter results by maximum distance. Results with distance greater than this value
75    /// are excluded. If all matches exceed this threshold, the command exits with code 4
76    /// (`not found`) per the documented public contract.
77    /// Default `1.0` disables the filter and preserves the top-k behavior.
78    #[arg(long, alias = "min-distance", default_value = "1.0")]
79    pub max_distance: f32,
80    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
81    pub format: JsonOutputFormat,
82    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
83    pub db: Option<String>,
84    /// Accept `--json` as a no-op because output is already JSON by default.
85    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
86    pub json: bool,
87    /// Search across all namespaces instead of a single namespace.
88    ///
89    /// Cannot be combined with `--namespace`. When set, the query runs against
90    /// every namespace and results include a `namespace` field to identify origin.
91    #[arg(long, conflicts_with = "namespace")]
92    pub all_namespaces: bool,
93    #[command(flatten)]
94    pub daemon: crate::cli::DaemonOpts,
95}
96
97pub fn run(args: RecallArgs) -> Result<(), AppError> {
98    let start = std::time::Instant::now();
99    let _ = args.format;
100    if args.query.trim().is_empty() {
101        return Err(AppError::Validation(crate::i18n::validation::empty_query()));
102    }
103    // Resolve the list of namespaces to search:
104    // - empty vec  => all namespaces (sentinel used by knn_search)
105    // - single vec => one namespace (default or --namespace value)
106    let namespaces: Vec<String> = if args.all_namespaces {
107        Vec::new()
108    } else {
109        vec![crate::namespace::resolve_namespace(
110            args.namespace.as_deref(),
111        )?]
112    };
113    // Single namespace string used for graph traversal and error messages.
114    let namespace_for_graph = namespaces
115        .first()
116        .cloned()
117        .unwrap_or_else(|| "global".to_string());
118    let paths = AppPaths::resolve(args.db.as_deref())?;
119
120    crate::storage::connection::ensure_db_ready(&paths)?;
121
122    output::emit_progress_i18n(
123        "Computing query embedding...",
124        "Calculando embedding da consulta...",
125    );
126    let embedding = crate::daemon::embed_query_or_local(
127        &paths.models,
128        &args.query,
129        args.daemon.autostart_daemon,
130    )?;
131
132    let conn = open_ro(&paths.db)?;
133
134    let memory_type_str = args.r#type.map(|t| t.as_str());
135    // When --precise is set, lift the -k cap so every match is returned; the
136    // max_distance filter below will trim irrelevant results instead.
137    let effective_k = if args.precise { 100_000 } else { args.k };
138    let knn_results =
139        memories::knn_search(&conn, &embedding, &namespaces, memory_type_str, effective_k)?;
140
141    let mut direct_matches = Vec::with_capacity(effective_k);
142    let mut memory_ids: Vec<i64> = Vec::with_capacity(effective_k);
143    for (memory_id, distance) in knn_results {
144        let row = {
145            let mut stmt = conn.prepare_cached(
146                "SELECT id, namespace, name, type, description, body, body_hash,
147                        session_id, source, metadata, created_at, updated_at
148                 FROM memories WHERE id=?1 AND deleted_at IS NULL",
149            )?;
150            stmt.query_row(rusqlite::params![memory_id], |r| {
151                Ok(memories::MemoryRow {
152                    id: r.get(0)?,
153                    namespace: r.get(1)?,
154                    name: r.get(2)?,
155                    memory_type: r.get(3)?,
156                    description: r.get(4)?,
157                    body: r.get(5)?,
158                    body_hash: r.get(6)?,
159                    session_id: r.get(7)?,
160                    source: r.get(8)?,
161                    metadata: r.get(9)?,
162                    created_at: r.get(10)?,
163                    updated_at: r.get(11)?,
164                    deleted_at: None,
165                })
166            })
167            .ok()
168        };
169        if let Some(row) = row {
170            let snippet: String = row.body.chars().take(300).collect();
171            direct_matches.push(RecallItem {
172                memory_id: row.id,
173                name: row.name,
174                namespace: row.namespace,
175                memory_type: row.memory_type,
176                description: row.description,
177                snippet,
178                distance,
179                score: RecallItem::score_from_distance(distance),
180                source: "direct".to_string(),
181                // Direct vector matches do not have a graph depth; rely on `distance`.
182                graph_depth: None,
183            });
184            memory_ids.push(memory_id);
185        }
186    }
187
188    let mut graph_matches = Vec::with_capacity(8);
189    if !args.no_graph {
190        let entity_knn = entities::knn_search(&conn, &embedding, &namespace_for_graph, 5)?;
191        let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
192
193        let all_seed_ids: Vec<i64> = memory_ids
194            .iter()
195            .chain(entity_ids.iter())
196            .copied()
197            .collect();
198
199        if !all_seed_ids.is_empty() {
200            let graph_memory_ids = traverse_from_memories_with_hops(
201                &conn,
202                &all_seed_ids,
203                &namespace_for_graph,
204                args.min_weight,
205                args.max_hops,
206            )?;
207
208            for (graph_mem_id, hop) in graph_memory_ids {
209                // v1.0.23: respect the optional cap on graph results so dense
210                // neighbourhoods do not flood the response unintentionally.
211                if let Some(cap) = args.max_graph_results {
212                    if graph_matches.len() >= cap {
213                        break;
214                    }
215                }
216                let row = {
217                    let mut stmt = conn.prepare_cached(
218                        "SELECT id, namespace, name, type, description, body, body_hash,
219                                session_id, source, metadata, created_at, updated_at
220                         FROM memories WHERE id=?1 AND deleted_at IS NULL",
221                    )?;
222                    stmt.query_row(rusqlite::params![graph_mem_id], |r| {
223                        Ok(memories::MemoryRow {
224                            id: r.get(0)?,
225                            namespace: r.get(1)?,
226                            name: r.get(2)?,
227                            memory_type: r.get(3)?,
228                            description: r.get(4)?,
229                            body: r.get(5)?,
230                            body_hash: r.get(6)?,
231                            session_id: r.get(7)?,
232                            source: r.get(8)?,
233                            metadata: r.get(9)?,
234                            created_at: r.get(10)?,
235                            updated_at: r.get(11)?,
236                            deleted_at: None,
237                        })
238                    })
239                    .ok()
240                };
241                if let Some(row) = row {
242                    let snippet: String = row.body.chars().take(300).collect();
243                    // Compute approximate distance from graph hop count.
244                    // WARNING: graph_distance is a hop-count proxy, NOT real cosine distance.
245                    // For confident ranking, prefer the `graph_depth` field (set to Some(hop)
246                    // below). Real cosine distance for graph matches would require
247                    // re-embedding (200-500ms latency) and is reserved for v1.0.28.
248                    let graph_distance = 1.0 - 1.0 / (hop as f32 + 1.0);
249                    graph_matches.push(RecallItem {
250                        memory_id: row.id,
251                        name: row.name,
252                        namespace: row.namespace,
253                        memory_type: row.memory_type,
254                        description: row.description,
255                        snippet,
256                        distance: graph_distance,
257                        score: RecallItem::score_from_distance(graph_distance),
258                        source: "graph".to_string(),
259                        graph_depth: Some(hop),
260                    });
261                }
262            }
263        }
264    }
265
266    // Filtrar por max_distance se < 1.0 (ativado). Se nenhum hit dentro do threshold, exit 4.
267    if args.max_distance < 1.0 {
268        let has_relevant = direct_matches
269            .iter()
270            .any(|item| item.distance <= args.max_distance);
271        if !has_relevant {
272            return Err(AppError::NotFound(errors_msg::no_recall_results(
273                args.max_distance,
274                &args.query,
275                &namespace_for_graph,
276            )));
277        }
278    }
279
280    let results: Vec<RecallItem> = direct_matches
281        .iter()
282        .cloned()
283        .chain(graph_matches.iter().cloned())
284        .collect();
285
286    output::emit_json(&RecallResponse {
287        query: args.query,
288        k: args.k,
289        direct_matches,
290        graph_matches,
291        results,
292        elapsed_ms: start.elapsed().as_millis() as u64,
293    })?;
294
295    Ok(())
296}
297
#[cfg(test)]
mod tests {
    use crate::output::{RecallItem, RecallResponse};

    /// Fixture: a `RecallItem` in the `global` namespace whose score is derived
    /// from `distance`. Only graph-sourced items carry a graph depth (0).
    fn make_item(name: &str, distance: f32, source: &str) -> RecallItem {
        let is_graph = source == "graph";
        RecallItem {
            memory_id: 1,
            name: name.to_owned(),
            namespace: String::from("global"),
            memory_type: String::from("fact"),
            description: String::from("desc"),
            snippet: String::from("snippet"),
            distance,
            score: RecallItem::score_from_distance(distance),
            source: source.to_owned(),
            graph_depth: is_graph.then_some(0),
        }
    }

    // Bug M-A5: every RecallItem carries a non-null cosine similarity score.
    #[test]
    fn recall_item_score_is_present_and_finite_for_direct_match() {
        let value = serde_json::to_value(make_item("mem", 0.25, "direct"))
            .expect("serialization failed");
        let score = value["score"].as_f64().expect("score must be a number");

        assert!(
            (0.0..=1.0).contains(&score),
            "score must be in [0, 1], got {score}"
        );
        assert!(
            (score - 0.75).abs() < 1e-6,
            "score must equal 1 - distance for canonical case"
        );
    }

    #[test]
    fn recall_item_score_clamps_distance_outside_unit_range() {
        // Out-of-range and NaN distances must still map into [0, 1].
        for (distance, expected) in [(2.0, 0.0), (-0.5, 1.0), (f32::NAN, 0.0)] {
            assert_eq!(RecallItem::score_from_distance(distance), expected);
        }
    }

    #[test]
    fn recall_response_serializes_required_fields() {
        let direct = vec![make_item("mem-a", 0.12, "direct")];
        let resp = RecallResponse {
            query: "rust memory".to_owned(),
            k: 5,
            results: direct.clone(),
            direct_matches: direct,
            graph_matches: Vec::new(),
            elapsed_ms: 42,
        };

        let json = serde_json::to_value(&resp).expect("serialization failed");
        assert_eq!(json["query"], "rust memory");
        assert_eq!(json["k"], 5);
        assert_eq!(json["elapsed_ms"], 42u64);
        for key in ["direct_matches", "graph_matches", "results"] {
            assert!(json[key].is_array());
        }
    }

    #[test]
    fn recall_item_serializes_renamed_type() {
        let json = serde_json::to_value(make_item("mem-test", 0.25, "direct"))
            .expect("serialization failed");

        // `memory_type` is serialized under the JSON key "type".
        assert_eq!(json["type"], "fact");
        assert_eq!(json["distance"], 0.25f32);
        assert_eq!(json["source"], "direct");
    }

    #[test]
    fn recall_response_results_contains_direct_and_graph() {
        let direct = make_item("d-mem", 0.10, "direct");
        let graph = make_item("g-mem", 0.0, "graph");

        let resp = RecallResponse {
            query: "query".to_owned(),
            k: 10,
            direct_matches: vec![direct.clone()],
            graph_matches: vec![graph.clone()],
            results: vec![direct, graph],
            elapsed_ms: 10,
        };

        let json = serde_json::to_value(&resp).expect("serialization failed");
        let len_of = |key: &str| json[key].as_array().map(Vec::len);
        assert_eq!(len_of("direct_matches"), Some(1));
        assert_eq!(len_of("graph_matches"), Some(1));
        assert_eq!(len_of("results"), Some(2));
        assert_eq!(json["results"][0]["source"], "direct");
        assert_eq!(json["results"][1]["source"], "graph");
    }

    #[test]
    fn recall_response_empty_serializes_empty_arrays() {
        let resp = RecallResponse {
            query: "nothing".to_owned(),
            k: 3,
            direct_matches: Vec::new(),
            graph_matches: Vec::new(),
            results: Vec::new(),
            elapsed_ms: 1,
        };

        let json = serde_json::to_value(&resp).expect("serialization failed");
        for key in ["direct_matches", "results"] {
            assert_eq!(json[key].as_array().map(Vec::len), Some(0));
        }
    }

    #[test]
    fn graph_matches_distance_uses_hop_count_proxy() {
        // Proxy formula: 1.0 - 1.0 / (hop + 1.0)
        // hop 0 -> 0.0 (seed-level entity), 1 -> 0.5, 2 -> ~0.667, 3 -> 0.75.
        let cases: [(u32, f32); 4] = [(0, 0.0), (1, 0.5), (2, 0.6667), (3, 0.75)];
        for (hop, expected) in cases {
            let d = 1.0_f32 - 1.0 / (hop as f32 + 1.0);
            assert!(
                (d - expected).abs() < 0.001,
                "hop={hop} expected={expected} got={d}"
            );
        }
    }
}