// Source path: sqlite_graphrag/commands/recall.rs

1use crate::cli::MemoryType;
2use crate::errors::AppError;
3use crate::graph::traverse_from_memories;
4use crate::i18n::erros;
5use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
6use crate::paths::AppPaths;
7use crate::storage::connection::open_ro;
8use crate::storage::entities;
9use crate::storage::memories;
10
11/// Arguments for the `recall` subcommand.
12///
13/// When `--namespace` is omitted the query runs against the `global` namespace,
14/// which is the default namespace used by `remember` when no `--namespace` flag
15/// is provided. Pass an explicit `--namespace` value to search a different
16/// isolated namespace.
17#[derive(clap::Args)]
18pub struct RecallArgs {
19    pub query: String,
20    /// Maximum number of direct vector matches to return.
21    ///
22    /// Note: this flag controls only `direct_matches`. Graph traversal results
23    /// (`graph_matches`) are unbounded by default; use `--max-graph-results` to
24    /// cap them independently. The `results` field aggregates both lists.
25    #[arg(short = 'k', long, default_value = "10")]
26    pub k: usize,
27    /// Filter by memory.type. Note: distinct from graph entity_type
28    /// (project/tool/person/file/concept/incident/decision/memory/dashboard/issue_tracker)
29    /// used in --entities-file.
30    #[arg(long, value_enum)]
31    pub r#type: Option<MemoryType>,
32    #[arg(long)]
33    pub namespace: Option<String>,
34    #[arg(long)]
35    pub no_graph: bool,
36    /// Disable -k cap and return all direct matches without truncation.
37    ///
38    /// When set, the `-k`/`--k` flag is ignored for `direct_matches` and the
39    /// response includes every match above the distance threshold. Useful when
40    /// callers need the complete set rather than a top-N preview.
41    #[arg(long)]
42    pub precise: bool,
43    #[arg(long, default_value = "2")]
44    pub max_hops: u32,
45    #[arg(long, default_value = "0.3")]
46    pub min_weight: f64,
47    /// Cap the size of `graph_matches` to at most N entries.
48    ///
49    /// Defaults to unbounded (`None`) so existing pipelines see the same shape
50    /// as in v1.0.22 and earlier. Set this when a query touches a dense graph
51    /// neighbourhood and the caller only needs a top-N preview. Added in v1.0.23.
52    #[arg(long, value_name = "N")]
53    pub max_graph_results: Option<usize>,
54    /// Filter results by maximum distance. Results with distance greater than this value
55    /// are excluded. If all matches exceed this threshold, the command exits with code 4
56    /// (`not found`) per the documented public contract.
57    /// Default `1.0` disables the filter and preserves the top-k behavior.
58    #[arg(long, alias = "min-distance", default_value = "1.0")]
59    pub max_distance: f32,
60    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
61    pub format: JsonOutputFormat,
62    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
63    pub db: Option<String>,
64    /// Accept `--json` as a no-op because output is already JSON by default.
65    #[arg(long, help = "No-op; JSON is always emitted on stdout")]
66    pub json: bool,
67}
68
69pub fn run(args: RecallArgs) -> Result<(), AppError> {
70    let start = std::time::Instant::now();
71    let _ = args.format;
72    if args.query.trim().is_empty() {
73        return Err(AppError::Validation(
74            "query não pode estar vazia".to_string(),
75        ));
76    }
77    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
78    let paths = AppPaths::resolve(args.db.as_deref())?;
79
80    if !paths.db.exists() {
81        return Err(AppError::NotFound(erros::banco_nao_encontrado(
82            &paths.db.display().to_string(),
83        )));
84    }
85
86    output::emit_progress_i18n(
87        "Computing query embedding...",
88        "Calculando embedding da consulta...",
89    );
90    let embedding = crate::daemon::embed_query_or_local(&paths.models, &args.query)?;
91
92    let conn = open_ro(&paths.db)?;
93
94    let memory_type_str = args.r#type.map(|t| t.as_str());
95    // When --precise is set, lift the -k cap so every match is returned; the
96    // max_distance filter below will trim irrelevant results instead.
97    let effective_k = if args.precise { 100_000 } else { args.k };
98    let knn_results =
99        memories::knn_search(&conn, &embedding, &namespace, memory_type_str, effective_k)?;
100
101    let mut direct_matches = Vec::new();
102    let mut memory_ids: Vec<i64> = Vec::new();
103    for (memory_id, distance) in knn_results {
104        let row = {
105            let mut stmt = conn.prepare_cached(
106                "SELECT id, namespace, name, type, description, body, body_hash,
107                        session_id, source, metadata, created_at, updated_at
108                 FROM memories WHERE id=?1 AND deleted_at IS NULL",
109            )?;
110            stmt.query_row(rusqlite::params![memory_id], |r| {
111                Ok(memories::MemoryRow {
112                    id: r.get(0)?,
113                    namespace: r.get(1)?,
114                    name: r.get(2)?,
115                    memory_type: r.get(3)?,
116                    description: r.get(4)?,
117                    body: r.get(5)?,
118                    body_hash: r.get(6)?,
119                    session_id: r.get(7)?,
120                    source: r.get(8)?,
121                    metadata: r.get(9)?,
122                    created_at: r.get(10)?,
123                    updated_at: r.get(11)?,
124                })
125            })
126            .ok()
127        };
128        if let Some(row) = row {
129            let snippet: String = row.body.chars().take(300).collect();
130            direct_matches.push(RecallItem {
131                memory_id: row.id,
132                name: row.name,
133                namespace: row.namespace,
134                memory_type: row.memory_type,
135                description: row.description,
136                snippet,
137                distance,
138                source: "direct".to_string(),
139                // Direct vector matches do not have a graph depth; rely on `distance`.
140                graph_depth: None,
141            });
142            memory_ids.push(memory_id);
143        }
144    }
145
146    let mut graph_matches = Vec::new();
147    if !args.no_graph {
148        let entity_knn = entities::knn_search(&conn, &embedding, &namespace, 5)?;
149        let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
150
151        let all_seed_ids: Vec<i64> = memory_ids
152            .iter()
153            .chain(entity_ids.iter())
154            .copied()
155            .collect();
156
157        if !all_seed_ids.is_empty() {
158            let graph_memory_ids = traverse_from_memories(
159                &conn,
160                &all_seed_ids,
161                &namespace,
162                args.min_weight,
163                args.max_hops,
164            )?;
165
166            for graph_mem_id in graph_memory_ids {
167                // v1.0.23: respect the optional cap on graph results so dense
168                // neighbourhoods do not flood the response unintentionally.
169                if let Some(cap) = args.max_graph_results {
170                    if graph_matches.len() >= cap {
171                        break;
172                    }
173                }
174                let row = {
175                    let mut stmt = conn.prepare_cached(
176                        "SELECT id, namespace, name, type, description, body, body_hash,
177                                session_id, source, metadata, created_at, updated_at
178                         FROM memories WHERE id=?1 AND deleted_at IS NULL",
179                    )?;
180                    stmt.query_row(rusqlite::params![graph_mem_id], |r| {
181                        Ok(memories::MemoryRow {
182                            id: r.get(0)?,
183                            namespace: r.get(1)?,
184                            name: r.get(2)?,
185                            memory_type: r.get(3)?,
186                            description: r.get(4)?,
187                            body: r.get(5)?,
188                            body_hash: r.get(6)?,
189                            session_id: r.get(7)?,
190                            source: r.get(8)?,
191                            metadata: r.get(9)?,
192                            created_at: r.get(10)?,
193                            updated_at: r.get(11)?,
194                        })
195                    })
196                    .ok()
197                };
198                if let Some(row) = row {
199                    let snippet: String = row.body.chars().take(300).collect();
200                    graph_matches.push(RecallItem {
201                        memory_id: row.id,
202                        name: row.name,
203                        namespace: row.namespace,
204                        memory_type: row.memory_type,
205                        description: row.description,
206                        snippet,
207                        // Kept for backward compatibility; v1.0.23 callers should
208                        // read `graph_depth` instead. Future releases may switch
209                        // this to `f32::NAN` after a deprecation cycle.
210                        distance: 0.0,
211                        source: "graph".to_string(),
212                        // `traverse_from_memories` does not yet expose per-result
213                        // depth, so we report `Some(0)` as a sentinel meaning
214                        // "unknown depth, but reachable via graph traversal".
215                        // A future release should plumb the real hop count through.
216                        graph_depth: Some(0),
217                    });
218                }
219            }
220        }
221    }
222
223    // Filtrar por max_distance se < 1.0 (ativado). Se nenhum hit dentro do threshold, exit 4.
224    if args.max_distance < 1.0 {
225        let has_relevant = direct_matches
226            .iter()
227            .any(|item| item.distance <= args.max_distance);
228        if !has_relevant {
229            return Err(AppError::NotFound(erros::sem_resultados_recall(
230                args.max_distance,
231                &args.query,
232                &namespace,
233            )));
234        }
235    }
236
237    let results: Vec<RecallItem> = direct_matches
238        .iter()
239        .cloned()
240        .chain(graph_matches.iter().cloned())
241        .collect();
242
243    output::emit_json(&RecallResponse {
244        query: args.query,
245        k: args.k,
246        direct_matches,
247        graph_matches,
248        results,
249        elapsed_ms: start.elapsed().as_millis() as u64,
250    })?;
251
252    Ok(())
253}
254
#[cfg(test)]
mod testes {
    use crate::output::{RecallItem, RecallResponse};

    /// Builds a fixture item; only `name`, `distance` and `source` vary.
    /// Items whose source is `"graph"` carry the `Some(0)` depth sentinel,
    /// every other source has no depth.
    fn make_item(name: &str, distance: f32, source: &str) -> RecallItem {
        let graph_depth = match source {
            "graph" => Some(0),
            _ => None,
        };
        RecallItem {
            memory_id: 1,
            name: String::from(name),
            namespace: String::from("global"),
            memory_type: String::from("fact"),
            description: String::from("desc"),
            snippet: String::from("snippet"),
            distance,
            source: String::from(source),
            graph_depth,
        }
    }

    /// Serializes a response to a JSON value, failing the test on error.
    fn response_json(resp: &RecallResponse) -> serde_json::Value {
        serde_json::to_value(resp).expect("serialização falhou")
    }

    /// Length of the JSON array stored under `key`; panics if it is not an array.
    fn array_len(json: &serde_json::Value, key: &str) -> usize {
        json[key].as_array().unwrap().len()
    }

    #[test]
    fn recall_response_serializa_campos_obrigatorios() {
        let json = response_json(&RecallResponse {
            query: String::from("rust memory"),
            k: 5,
            direct_matches: vec![make_item("mem-a", 0.12, "direct")],
            graph_matches: Vec::new(),
            results: vec![make_item("mem-a", 0.12, "direct")],
            elapsed_ms: 42,
        });

        assert_eq!(json["query"], "rust memory");
        assert_eq!(json["k"], 5);
        assert_eq!(json["elapsed_ms"], 42u64);
        for key in ["direct_matches", "graph_matches", "results"] {
            assert!(json[key].is_array());
        }
    }

    #[test]
    fn recall_item_serializa_type_renomeado() {
        let json = serde_json::to_value(&make_item("mem-teste", 0.25, "direct"))
            .expect("serialização falhou");

        // The `memory_type` field is renamed to "type" in the JSON output.
        assert_eq!(json["type"], "fact");
        assert_eq!(json["distance"], 0.25f32);
        assert_eq!(json["source"], "direct");
    }

    #[test]
    fn recall_response_results_contem_direct_e_graph() {
        let direct = make_item("d-mem", 0.10, "direct");
        let graph = make_item("g-mem", 0.0, "graph");
        let combined = vec![direct.clone(), graph.clone()];

        let json = response_json(&RecallResponse {
            query: String::from("query"),
            k: 10,
            direct_matches: vec![direct],
            graph_matches: vec![graph],
            results: combined,
            elapsed_ms: 10,
        });

        assert_eq!(array_len(&json, "direct_matches"), 1);
        assert_eq!(array_len(&json, "graph_matches"), 1);
        assert_eq!(array_len(&json, "results"), 2);
        // Direct matches come first in the aggregated list, then graph ones.
        assert_eq!(json["results"][0]["source"], "direct");
        assert_eq!(json["results"][1]["source"], "graph");
    }

    #[test]
    fn recall_response_vazio_serializa_arrays_vazios() {
        let json = response_json(&RecallResponse {
            query: String::from("nada"),
            k: 3,
            direct_matches: Vec::new(),
            graph_matches: Vec::new(),
            results: Vec::new(),
            elapsed_ms: 1,
        });

        assert_eq!(array_len(&json, "direct_matches"), 0);
        assert_eq!(array_len(&json, "results"), 0);
    }
}
341}