Skip to main content

sqlite_graphrag/commands/
recall.rs

1//! Handler for the `recall` CLI subcommand.
2
3use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::graph::traverse_from_memories_with_hops;
6use crate::i18n::errors_msg;
7use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
8use crate::paths::AppPaths;
9use crate::storage::connection::open_ro;
10use crate::storage::entities;
11use crate::storage::memories;
12
/// Arguments for the `recall` subcommand.
///
/// When `--namespace` is omitted the query runs against the `global` namespace,
/// which is the default namespace used by `remember` when no `--namespace` flag
/// is provided. Pass an explicit `--namespace` value to search a different
/// isolated namespace.
#[derive(clap::Args)]
pub struct RecallArgs {
    /// Free-text query to embed and search for. Must not be blank.
    pub query: String,
    /// Maximum number of direct vector matches to return.
    ///
    /// Note: this flag controls only `direct_matches`. Graph traversal results
    /// (`graph_matches`) are unbounded by default; use `--max-graph-results` to
    /// cap them independently. The `results` field aggregates both lists.
    #[arg(short = 'k', long, default_value = "10")]
    pub k: usize,
    /// Filter by memory.type. Note: distinct from graph entity_type
    /// (project/tool/person/file/concept/incident/decision/memory/dashboard/issue_tracker)
    /// used in --entities-file.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    /// Namespace to search. Cannot be combined with `--all-namespaces`.
    #[arg(long)]
    pub namespace: Option<String>,
    /// Skip graph traversal and return only direct vector matches.
    #[arg(long)]
    pub no_graph: bool,
    /// Disable -k cap and return all direct matches without truncation.
    ///
    /// When set, the `-k`/`--k` flag is ignored for `direct_matches` and the
    /// response includes every match above the distance threshold. Useful when
    /// callers need the complete set rather than a top-N preview.
    // NOTE(review): `run` implements this by requesting up to 100_000
    // neighbours from the KNN search rather than a truly unbounded set.
    #[arg(long)]
    pub precise: bool,
    /// Hop limit forwarded to graph traversal.
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    /// Minimum weight forwarded to graph traversal.
    // NOTE(review): presumably an edge-weight floor; confirm against
    // `traverse_from_memories_with_hops`.
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    /// Cap the size of `graph_matches` to at most N entries.
    ///
    /// Defaults to unbounded (`None`) so existing pipelines see the same shape
    /// as in v1.0.22 and earlier. Set this when a query touches a dense graph
    /// neighbourhood and the caller only needs a top-N preview. Added in v1.0.23.
    #[arg(long, value_name = "N")]
    pub max_graph_results: Option<usize>,
    /// Filter results by maximum distance. Results with distance greater than this value
    /// are excluded. If all matches exceed this threshold, the command exits with code 4
    /// (`not found`) per the documented public contract.
    /// Default `1.0` disables the filter and preserves the top-k behavior.
    // `--min-distance` is accepted as an alias for this flag.
    #[arg(long, alias = "min-distance", default_value = "1.0")]
    pub max_distance: f32,
    /// Output format selector.
    // NOTE(review): `run` currently discards this value (`let _ = args.format`).
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    /// Path to the database file. Falls back to the `SQLITE_GRAPHRAG_DB_PATH`
    /// environment variable when the flag is not given.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    /// Accept `--json` as a no-op because output is already JSON by default.
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    /// Search across all namespaces instead of a single namespace.
    ///
    /// Cannot be combined with `--namespace`. When set, the query runs against
    /// every namespace and results include a `namespace` field to identify origin.
    #[arg(long, conflicts_with = "namespace")]
    pub all_namespaces: bool,
}
76
77pub fn run(args: RecallArgs) -> Result<(), AppError> {
78    let start = std::time::Instant::now();
79    let _ = args.format;
80    if args.query.trim().is_empty() {
81        return Err(AppError::Validation(
82            "query não pode estar vazia".to_string(),
83        ));
84    }
85    // Resolve the list of namespaces to search:
86    // - empty vec  => all namespaces (sentinel used by knn_search)
87    // - single vec => one namespace (default or --namespace value)
88    let namespaces: Vec<String> = if args.all_namespaces {
89        Vec::new()
90    } else {
91        vec![crate::namespace::resolve_namespace(
92            args.namespace.as_deref(),
93        )?]
94    };
95    // Single namespace string used for graph traversal and error messages.
96    let namespace_for_graph = namespaces
97        .first()
98        .cloned()
99        .unwrap_or_else(|| "global".to_string());
100    let paths = AppPaths::resolve(args.db.as_deref())?;
101
102    if !paths.db.exists() {
103        return Err(AppError::NotFound(errors_msg::database_not_found(
104            &paths.db.display().to_string(),
105        )));
106    }
107
108    output::emit_progress_i18n(
109        "Computing query embedding...",
110        "Calculando embedding da consulta...",
111    );
112    let embedding = crate::daemon::embed_query_or_local(&paths.models, &args.query)?;
113
114    let conn = open_ro(&paths.db)?;
115
116    let memory_type_str = args.r#type.map(|t| t.as_str());
117    // When --precise is set, lift the -k cap so every match is returned; the
118    // max_distance filter below will trim irrelevant results instead.
119    let effective_k = if args.precise { 100_000 } else { args.k };
120    let knn_results =
121        memories::knn_search(&conn, &embedding, &namespaces, memory_type_str, effective_k)?;
122
123    let mut direct_matches = Vec::new();
124    let mut memory_ids: Vec<i64> = Vec::new();
125    for (memory_id, distance) in knn_results {
126        let row = {
127            let mut stmt = conn.prepare_cached(
128                "SELECT id, namespace, name, type, description, body, body_hash,
129                        session_id, source, metadata, created_at, updated_at
130                 FROM memories WHERE id=?1 AND deleted_at IS NULL",
131            )?;
132            stmt.query_row(rusqlite::params![memory_id], |r| {
133                Ok(memories::MemoryRow {
134                    id: r.get(0)?,
135                    namespace: r.get(1)?,
136                    name: r.get(2)?,
137                    memory_type: r.get(3)?,
138                    description: r.get(4)?,
139                    body: r.get(5)?,
140                    body_hash: r.get(6)?,
141                    session_id: r.get(7)?,
142                    source: r.get(8)?,
143                    metadata: r.get(9)?,
144                    created_at: r.get(10)?,
145                    updated_at: r.get(11)?,
146                })
147            })
148            .ok()
149        };
150        if let Some(row) = row {
151            let snippet: String = row.body.chars().take(300).collect();
152            direct_matches.push(RecallItem {
153                memory_id: row.id,
154                name: row.name,
155                namespace: row.namespace,
156                memory_type: row.memory_type,
157                description: row.description,
158                snippet,
159                distance,
160                source: "direct".to_string(),
161                // Direct vector matches do not have a graph depth; rely on `distance`.
162                graph_depth: None,
163            });
164            memory_ids.push(memory_id);
165        }
166    }
167
168    let mut graph_matches = Vec::new();
169    if !args.no_graph {
170        let entity_knn = entities::knn_search(&conn, &embedding, &namespace_for_graph, 5)?;
171        let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
172
173        let all_seed_ids: Vec<i64> = memory_ids
174            .iter()
175            .chain(entity_ids.iter())
176            .copied()
177            .collect();
178
179        if !all_seed_ids.is_empty() {
180            let graph_memory_ids = traverse_from_memories_with_hops(
181                &conn,
182                &all_seed_ids,
183                &namespace_for_graph,
184                args.min_weight,
185                args.max_hops,
186            )?;
187
188            for (graph_mem_id, hop) in graph_memory_ids {
189                // v1.0.23: respect the optional cap on graph results so dense
190                // neighbourhoods do not flood the response unintentionally.
191                if let Some(cap) = args.max_graph_results {
192                    if graph_matches.len() >= cap {
193                        break;
194                    }
195                }
196                let row = {
197                    let mut stmt = conn.prepare_cached(
198                        "SELECT id, namespace, name, type, description, body, body_hash,
199                                session_id, source, metadata, created_at, updated_at
200                         FROM memories WHERE id=?1 AND deleted_at IS NULL",
201                    )?;
202                    stmt.query_row(rusqlite::params![graph_mem_id], |r| {
203                        Ok(memories::MemoryRow {
204                            id: r.get(0)?,
205                            namespace: r.get(1)?,
206                            name: r.get(2)?,
207                            memory_type: r.get(3)?,
208                            description: r.get(4)?,
209                            body: r.get(5)?,
210                            body_hash: r.get(6)?,
211                            session_id: r.get(7)?,
212                            source: r.get(8)?,
213                            metadata: r.get(9)?,
214                            created_at: r.get(10)?,
215                            updated_at: r.get(11)?,
216                        })
217                    })
218                    .ok()
219                };
220                if let Some(row) = row {
221                    let snippet: String = row.body.chars().take(300).collect();
222                    // Compute approximate distance from graph hop count.
223                    // WARNING: graph_distance is a hop-count proxy, NOT real cosine distance.
224                    // For confident ranking, prefer the `graph_depth` field (set to Some(hop)
225                    // below). Real cosine distance for graph matches would require
226                    // re-embedding (200-500ms latency) and is reserved for v1.0.28.
227                    let graph_distance = 1.0 - 1.0 / (hop as f32 + 1.0);
228                    graph_matches.push(RecallItem {
229                        memory_id: row.id,
230                        name: row.name,
231                        namespace: row.namespace,
232                        memory_type: row.memory_type,
233                        description: row.description,
234                        snippet,
235                        distance: graph_distance,
236                        source: "graph".to_string(),
237                        graph_depth: Some(hop),
238                    });
239                }
240            }
241        }
242    }
243
244    // Filtrar por max_distance se < 1.0 (ativado). Se nenhum hit dentro do threshold, exit 4.
245    if args.max_distance < 1.0 {
246        let has_relevant = direct_matches
247            .iter()
248            .any(|item| item.distance <= args.max_distance);
249        if !has_relevant {
250            return Err(AppError::NotFound(errors_msg::no_recall_results(
251                args.max_distance,
252                &args.query,
253                &namespace_for_graph,
254            )));
255        }
256    }
257
258    let results: Vec<RecallItem> = direct_matches
259        .iter()
260        .cloned()
261        .chain(graph_matches.iter().cloned())
262        .collect();
263
264    output::emit_json(&RecallResponse {
265        query: args.query,
266        k: args.k,
267        direct_matches,
268        graph_matches,
269        results,
270        elapsed_ms: start.elapsed().as_millis() as u64,
271    })?;
272
273    Ok(())
274}
275
#[cfg(test)]
mod tests {
    use crate::output::{RecallItem, RecallResponse};

    /// Builds a fixture `RecallItem` in which only `name`, `distance` and
    /// `source` vary. Items whose source is `"graph"` get `graph_depth = Some(0)`;
    /// every other source gets `None`.
    fn make_item(name: &str, distance: f32, source: &str) -> RecallItem {
        let graph_depth = match source {
            "graph" => Some(0),
            _ => None,
        };
        RecallItem {
            memory_id: 1,
            name: name.to_owned(),
            namespace: String::from("global"),
            memory_type: String::from("fact"),
            description: String::from("desc"),
            snippet: String::from("snippet"),
            distance,
            source: source.to_owned(),
            graph_depth,
        }
    }

    /// Length of the JSON array stored under `key`; panics if it is not an array.
    fn array_len(json: &serde_json::Value, key: &str) -> usize {
        json[key].as_array().unwrap().len()
    }

    /// The serialized response exposes every required top-level field.
    #[test]
    fn recall_response_serializa_campos_obrigatorios() {
        let only_hit = make_item("mem-a", 0.12, "direct");
        let resp = RecallResponse {
            query: String::from("rust memory"),
            k: 5,
            direct_matches: vec![only_hit.clone()],
            graph_matches: Vec::new(),
            results: vec![only_hit],
            elapsed_ms: 42,
        };

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        assert_eq!(json["query"], "rust memory");
        assert_eq!(json["k"], 5);
        assert_eq!(json["elapsed_ms"], 42u64);
        for key in &["direct_matches", "graph_matches", "results"] {
            assert!(json[*key].is_array());
        }
    }

    /// `memory_type` is serialized under the JSON key `type`.
    #[test]
    fn recall_item_serializa_type_renomeado() {
        let json = serde_json::to_value(&make_item("mem-teste", 0.25, "direct"))
            .expect("serialização falhou");

        // The `memory_type` field is renamed to "type" in the JSON output.
        assert_eq!(json["type"], "fact");
        assert_eq!(json["distance"], 0.25f32);
        assert_eq!(json["source"], "direct");
    }

    /// `results` carries the direct matches first, followed by the graph matches.
    #[test]
    fn recall_response_results_contem_direct_e_graph() {
        let direct = make_item("d-mem", 0.10, "direct");
        let graph = make_item("g-mem", 0.0, "graph");

        let resp = RecallResponse {
            query: String::from("query"),
            k: 10,
            direct_matches: vec![direct.clone()],
            graph_matches: vec![graph.clone()],
            results: vec![direct, graph],
            elapsed_ms: 10,
        };

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        assert_eq!(array_len(&json, "direct_matches"), 1);
        assert_eq!(array_len(&json, "graph_matches"), 1);
        assert_eq!(array_len(&json, "results"), 2);
        assert_eq!(json["results"][0]["source"], "direct");
        assert_eq!(json["results"][1]["source"], "graph");
    }

    /// An empty response still serializes its lists as (empty) arrays.
    #[test]
    fn recall_response_vazio_serializa_arrays_vazios() {
        let resp = RecallResponse {
            query: String::from("nada"),
            k: 3,
            direct_matches: Vec::new(),
            graph_matches: Vec::new(),
            results: Vec::new(),
            elapsed_ms: 1,
        };

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        assert_eq!(array_len(&json, "direct_matches"), 0);
        assert_eq!(array_len(&json, "results"), 0);
    }

    /// Pins the hop-count proxy formula `1.0 - 1.0 / (hop + 1.0)`:
    /// hop=0 → 0.0, hop=1 → 0.5, hop=2 → ≈ 0.667, hop=3 → 0.75.
    #[test]
    fn graph_matches_distance_uses_hop_count_proxy() {
        let proxy = |hop: u32| 1.0_f32 - 1.0 / (hop as f32 + 1.0);
        let cases: &[(u32, f32)] = &[(0, 0.0), (1, 0.5), (2, 0.6667), (3, 0.75)];
        for &(hop, expected) in cases {
            let d = proxy(hop);
            assert!(
                (d - expected).abs() < 0.001,
                "hop={hop} expected={expected} got={d}"
            );
        }
    }
}