Skip to main content

sqlite_graphrag/commands/
recall.rs

//! Handler for the `recall` CLI subcommand.
2
3use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::graph::traverse_from_memories_with_hops;
6use crate::i18n::errors_msg;
7use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
8use crate::paths::AppPaths;
9use crate::storage::connection::open_ro;
10use crate::storage::entities;
11use crate::storage::memories;
12
13/// Arguments for the `recall` subcommand.
14///
15/// When `--namespace` is omitted the query runs against the `global` namespace,
16/// which is the default namespace used by `remember` when no `--namespace` flag
17/// is provided. Pass an explicit `--namespace` value to search a different
18/// isolated namespace.
19#[derive(clap::Args)]
20pub struct RecallArgs {
21    #[arg(help = "Search query string (semantic vector search via sqlite-vec)")]
22    pub query: String,
23    /// Maximum number of direct vector matches to return.
24    ///
25    /// Note: this flag controls only `direct_matches`. Graph traversal results
26    /// (`graph_matches`) are unbounded by default; use `--max-graph-results` to
27    /// cap them independently. The `results` field aggregates both lists.
28    /// Validated to the inclusive range `1..=4096` (the upper bound matches
29    /// `sqlite-vec`'s knn limit; out-of-range values are rejected at parse time).
30    #[arg(short = 'k', long, alias = "limit", default_value = "10", value_parser = crate::parsers::parse_k_range)]
31    pub k: usize,
32    /// Filter by memory.type. Note: distinct from graph entity_type
33    /// (project/tool/person/file/concept/incident/decision/memory/dashboard/issue_tracker/organization/location/date)
34    /// used in --entities-file.
35    #[arg(long, value_enum)]
36    pub r#type: Option<MemoryType>,
37    #[arg(long)]
38    pub namespace: Option<String>,
39    #[arg(long)]
40    pub no_graph: bool,
41    /// Disable -k cap and return all direct matches without truncation.
42    ///
43    /// When set, the `-k`/`--k` flag is ignored for `direct_matches` and the
44    /// response includes every match above the distance threshold. Useful when
45    /// callers need the complete set rather than a top-N preview.
46    #[arg(long)]
47    pub precise: bool,
48    #[arg(long, default_value = "2")]
49    pub max_hops: u32,
50    #[arg(long, default_value = "0.3")]
51    pub min_weight: f64,
52    /// Cap the size of `graph_matches` to at most N entries.
53    ///
54    /// Defaults to unbounded (`None`) so existing pipelines see the same shape
55    /// as in v1.0.22 and earlier. Set this when a query touches a dense graph
56    /// neighbourhood and the caller only needs a top-N preview. Added in v1.0.23.
57    #[arg(long, value_name = "N")]
58    pub max_graph_results: Option<usize>,
59    /// Filter results by maximum distance. Results with distance greater than this value
60    /// are excluded. If all matches exceed this threshold, the command exits with code 4
61    /// (`not found`) per the documented public contract.
62    /// Default `1.0` disables the filter and preserves the top-k behavior.
63    #[arg(long, alias = "min-distance", default_value = "1.0")]
64    pub max_distance: f32,
65    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
66    pub format: JsonOutputFormat,
67    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
68    pub db: Option<String>,
69    /// Accept `--json` as a no-op because output is already JSON by default.
70    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
71    pub json: bool,
72    /// Search across all namespaces instead of a single namespace.
73    ///
74    /// Cannot be combined with `--namespace`. When set, the query runs against
75    /// every namespace and results include a `namespace` field to identify origin.
76    #[arg(long, conflicts_with = "namespace")]
77    pub all_namespaces: bool,
78    #[command(flatten)]
79    pub daemon: crate::cli::DaemonOpts,
80}
81
82pub fn run(args: RecallArgs) -> Result<(), AppError> {
83    let start = std::time::Instant::now();
84    let _ = args.format;
85    if args.query.trim().is_empty() {
86        return Err(AppError::Validation(crate::i18n::validation::empty_query()));
87    }
88    // Resolve the list of namespaces to search:
89    // - empty vec  => all namespaces (sentinel used by knn_search)
90    // - single vec => one namespace (default or --namespace value)
91    let namespaces: Vec<String> = if args.all_namespaces {
92        Vec::new()
93    } else {
94        vec![crate::namespace::resolve_namespace(
95            args.namespace.as_deref(),
96        )?]
97    };
98    // Single namespace string used for graph traversal and error messages.
99    let namespace_for_graph = namespaces
100        .first()
101        .cloned()
102        .unwrap_or_else(|| "global".to_string());
103    let paths = AppPaths::resolve(args.db.as_deref())?;
104
105    crate::storage::connection::ensure_db_ready(&paths)?;
106
107    output::emit_progress_i18n(
108        "Computing query embedding...",
109        "Calculando embedding da consulta...",
110    );
111    let embedding = crate::daemon::embed_query_or_local(
112        &paths.models,
113        &args.query,
114        args.daemon.autostart_daemon,
115    )?;
116
117    let conn = open_ro(&paths.db)?;
118
119    let memory_type_str = args.r#type.map(|t| t.as_str());
120    // When --precise is set, lift the -k cap so every match is returned; the
121    // max_distance filter below will trim irrelevant results instead.
122    let effective_k = if args.precise { 100_000 } else { args.k };
123    let knn_results =
124        memories::knn_search(&conn, &embedding, &namespaces, memory_type_str, effective_k)?;
125
126    let mut direct_matches = Vec::new();
127    let mut memory_ids: Vec<i64> = Vec::new();
128    for (memory_id, distance) in knn_results {
129        let row = {
130            let mut stmt = conn.prepare_cached(
131                "SELECT id, namespace, name, type, description, body, body_hash,
132                        session_id, source, metadata, created_at, updated_at
133                 FROM memories WHERE id=?1 AND deleted_at IS NULL",
134            )?;
135            stmt.query_row(rusqlite::params![memory_id], |r| {
136                Ok(memories::MemoryRow {
137                    id: r.get(0)?,
138                    namespace: r.get(1)?,
139                    name: r.get(2)?,
140                    memory_type: r.get(3)?,
141                    description: r.get(4)?,
142                    body: r.get(5)?,
143                    body_hash: r.get(6)?,
144                    session_id: r.get(7)?,
145                    source: r.get(8)?,
146                    metadata: r.get(9)?,
147                    created_at: r.get(10)?,
148                    updated_at: r.get(11)?,
149                    deleted_at: None,
150                })
151            })
152            .ok()
153        };
154        if let Some(row) = row {
155            let snippet: String = row.body.chars().take(300).collect();
156            direct_matches.push(RecallItem {
157                memory_id: row.id,
158                name: row.name,
159                namespace: row.namespace,
160                memory_type: row.memory_type,
161                description: row.description,
162                snippet,
163                distance,
164                source: "direct".to_string(),
165                // Direct vector matches do not have a graph depth; rely on `distance`.
166                graph_depth: None,
167            });
168            memory_ids.push(memory_id);
169        }
170    }
171
172    let mut graph_matches = Vec::new();
173    if !args.no_graph {
174        let entity_knn = entities::knn_search(&conn, &embedding, &namespace_for_graph, 5)?;
175        let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
176
177        let all_seed_ids: Vec<i64> = memory_ids
178            .iter()
179            .chain(entity_ids.iter())
180            .copied()
181            .collect();
182
183        if !all_seed_ids.is_empty() {
184            let graph_memory_ids = traverse_from_memories_with_hops(
185                &conn,
186                &all_seed_ids,
187                &namespace_for_graph,
188                args.min_weight,
189                args.max_hops,
190            )?;
191
192            for (graph_mem_id, hop) in graph_memory_ids {
193                // v1.0.23: respect the optional cap on graph results so dense
194                // neighbourhoods do not flood the response unintentionally.
195                if let Some(cap) = args.max_graph_results {
196                    if graph_matches.len() >= cap {
197                        break;
198                    }
199                }
200                let row = {
201                    let mut stmt = conn.prepare_cached(
202                        "SELECT id, namespace, name, type, description, body, body_hash,
203                                session_id, source, metadata, created_at, updated_at
204                         FROM memories WHERE id=?1 AND deleted_at IS NULL",
205                    )?;
206                    stmt.query_row(rusqlite::params![graph_mem_id], |r| {
207                        Ok(memories::MemoryRow {
208                            id: r.get(0)?,
209                            namespace: r.get(1)?,
210                            name: r.get(2)?,
211                            memory_type: r.get(3)?,
212                            description: r.get(4)?,
213                            body: r.get(5)?,
214                            body_hash: r.get(6)?,
215                            session_id: r.get(7)?,
216                            source: r.get(8)?,
217                            metadata: r.get(9)?,
218                            created_at: r.get(10)?,
219                            updated_at: r.get(11)?,
220                            deleted_at: None,
221                        })
222                    })
223                    .ok()
224                };
225                if let Some(row) = row {
226                    let snippet: String = row.body.chars().take(300).collect();
227                    // Compute approximate distance from graph hop count.
228                    // WARNING: graph_distance is a hop-count proxy, NOT real cosine distance.
229                    // For confident ranking, prefer the `graph_depth` field (set to Some(hop)
230                    // below). Real cosine distance for graph matches would require
231                    // re-embedding (200-500ms latency) and is reserved for v1.0.28.
232                    let graph_distance = 1.0 - 1.0 / (hop as f32 + 1.0);
233                    graph_matches.push(RecallItem {
234                        memory_id: row.id,
235                        name: row.name,
236                        namespace: row.namespace,
237                        memory_type: row.memory_type,
238                        description: row.description,
239                        snippet,
240                        distance: graph_distance,
241                        source: "graph".to_string(),
242                        graph_depth: Some(hop),
243                    });
244                }
245            }
246        }
247    }
248
249    // Filtrar por max_distance se < 1.0 (ativado). Se nenhum hit dentro do threshold, exit 4.
250    if args.max_distance < 1.0 {
251        let has_relevant = direct_matches
252            .iter()
253            .any(|item| item.distance <= args.max_distance);
254        if !has_relevant {
255            return Err(AppError::NotFound(errors_msg::no_recall_results(
256                args.max_distance,
257                &args.query,
258                &namespace_for_graph,
259            )));
260        }
261    }
262
263    let results: Vec<RecallItem> = direct_matches
264        .iter()
265        .cloned()
266        .chain(graph_matches.iter().cloned())
267        .collect();
268
269    output::emit_json(&RecallResponse {
270        query: args.query,
271        k: args.k,
272        direct_matches,
273        graph_matches,
274        results,
275        elapsed_ms: start.elapsed().as_millis() as u64,
276    })?;
277
278    Ok(())
279}
280
#[cfg(test)]
mod tests {
    use crate::output::{RecallItem, RecallResponse};

    /// Builds a `RecallItem` fixture with fixed id/namespace/type/description/snippet.
    /// Items whose source is "graph" get a depth of zero; all others get no depth.
    fn make_item(name: &str, distance: f32, source: &str) -> RecallItem {
        let graph_depth = match source {
            "graph" => Some(0),
            _ => None,
        };
        RecallItem {
            memory_id: 1,
            name: String::from(name),
            namespace: String::from("global"),
            memory_type: String::from("fact"),
            description: String::from("desc"),
            snippet: String::from("snippet"),
            distance,
            source: String::from(source),
            graph_depth,
        }
    }

    /// The top-level response exposes the scalar fields and all three arrays.
    #[test]
    fn recall_response_serializes_required_fields() {
        let only_hit = make_item("mem-a", 0.12, "direct");
        let response = RecallResponse {
            query: String::from("rust memory"),
            k: 5,
            direct_matches: vec![only_hit.clone()],
            graph_matches: Vec::new(),
            results: vec![only_hit],
            elapsed_ms: 42,
        };

        let value = serde_json::to_value(&response).expect("serialization failed");
        assert_eq!(value["query"], "rust memory");
        assert_eq!(value["k"], 5);
        assert_eq!(value["elapsed_ms"], 42u64);
        for key in ["direct_matches", "graph_matches", "results"].iter() {
            assert!(value[*key].is_array());
        }
    }

    /// `memory_type` must appear under the JSON key "type".
    #[test]
    fn recall_item_serializes_renamed_type() {
        let value = serde_json::to_value(&make_item("mem-test", 0.25, "direct"))
            .expect("serialization failed");

        assert_eq!(value["type"], "fact");
        assert_eq!(value["distance"], 0.25f32);
        assert_eq!(value["source"], "direct");
    }

    /// `results` carries the direct item first and the graph item second.
    #[test]
    fn recall_response_results_contains_direct_and_graph() {
        let from_vector = make_item("d-mem", 0.10, "direct");
        let from_graph = make_item("g-mem", 0.0, "graph");

        let response = RecallResponse {
            query: String::from("query"),
            k: 10,
            direct_matches: vec![from_vector.clone()],
            graph_matches: vec![from_graph.clone()],
            results: vec![from_vector, from_graph],
            elapsed_ms: 10,
        };

        let value = serde_json::to_value(&response).expect("serialization failed");
        let len_of = |key: &str| value[key].as_array().unwrap().len();
        assert_eq!(len_of("direct_matches"), 1);
        assert_eq!(len_of("graph_matches"), 1);
        assert_eq!(len_of("results"), 2);
        assert_eq!(value["results"][0]["source"], "direct");
        assert_eq!(value["results"][1]["source"], "graph");
    }

    /// Empty vectors serialize as empty JSON arrays rather than being omitted.
    #[test]
    fn recall_response_empty_serializes_empty_arrays() {
        let response = RecallResponse {
            query: String::from("nothing"),
            k: 3,
            direct_matches: Vec::new(),
            graph_matches: Vec::new(),
            results: Vec::new(),
            elapsed_ms: 1,
        };

        let value = serde_json::to_value(&response).expect("serialization failed");
        assert_eq!(value["direct_matches"].as_array().unwrap().len(), 0);
        assert_eq!(value["results"].as_array().unwrap().len(), 0);
    }

    /// Pins the hop-count proxy formula `1.0 - 1.0 / (hop + 1.0)`:
    /// hop 0 gives 0.0, hop 1 gives 0.5, hop 2 gives about 0.667, hop 3 gives 0.75.
    /// The formula is restated here, so this checks the arithmetic only.
    #[test]
    fn graph_matches_distance_uses_hop_count_proxy() {
        let expectations: [(u32, f32); 4] = [(0, 0.0), (1, 0.5), (2, 0.6667), (3, 0.75)];
        for (hop, expected) in expectations.iter().copied() {
            let d = 1.0_f32 - 1.0 / (hop as f32 + 1.0);
            let error = (d - expected).abs();
            assert!(error < 0.001, "hop={hop} expected={expected} got={d}");
        }
    }
}