Skip to main content

sqlite_graphrag/commands/
recall.rs

1use crate::cli::MemoryType;
2use crate::errors::AppError;
3use crate::graph::traverse_from_memories;
4use crate::i18n::erros;
5use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
6use crate::paths::AppPaths;
7use crate::storage::connection::open_ro;
8use crate::storage::entities;
9use crate::storage::memories;
10
// Command-line arguments for the `recall` subcommand.
//
// Plain `//` comments are used for the notes added here because clap's derive
// turns `///` doc comments on these fields into `--help` text.
#[derive(clap::Args)]
pub struct RecallArgs {
    // Free-text query; `run` rejects it when blank and embeds it for the
    // vector search.
    pub query: String,
    /// Maximum number of direct vector matches to return.
    ///
    /// Note: this flag controls only `direct_matches`. Graph traversal results
    /// (`graph_matches`) are unbounded by default; use `--max-graph-results` to
    /// cap them independently. The `results` field aggregates both lists.
    #[arg(short = 'k', long, default_value = "10")]
    pub k: usize,
    // Optional memory-type filter, forwarded (as a string) to the vector search.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // Namespace override; resolved through `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // Skip graph traversal entirely, so only direct vector matches are produced.
    #[arg(long)]
    pub no_graph: bool,
    // NOTE(review): `run` in this file never reads `precise`; confirm whether it
    // is intentionally a no-op or should influence the search.
    #[arg(long)]
    pub precise: bool,
    // Hop limit passed to graph traversal.
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    // Edge-weight threshold passed to graph traversal (presumably a lower bound
    // on edge weight — confirm in `traverse_from_memories`).
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    /// Cap the size of `graph_matches` to at most N entries.
    ///
    /// Defaults to unbounded (`None`) so existing pipelines see the same shape
    /// as in v1.0.22 and earlier. Set this when a query touches a dense graph
    /// neighbourhood and the caller only needs a top-N preview. Added in v1.0.23.
    #[arg(long, value_name = "N")]
    pub max_graph_results: Option<usize>,
    /// Filter results by maximum distance. Results with distance greater than this value
    /// are excluded. If all matches exceed this threshold, the command exits with code 4
    /// (`not found`) per the documented public contract.
    /// Default `1.0` disables the filter and preserves the top-k behavior.
    // NOTE(review): the `min-distance` alias reads like the opposite bound; it is
    // presumably kept for backward compatibility.
    #[arg(long, alias = "min-distance", default_value = "1.0")]
    pub max_distance: f32,
    // Output format selector; `run` always emits JSON via `output::emit_json`
    // regardless of this value.
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // Database path; may also be supplied via `SQLITE_GRAPHRAG_DB_PATH`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    /// Accept `--json` as a no-op because output is already JSON by default.
    #[arg(long, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
}
54
55pub fn run(args: RecallArgs) -> Result<(), AppError> {
56    let start = std::time::Instant::now();
57    let _ = args.format;
58    if args.query.trim().is_empty() {
59        return Err(AppError::Validation(
60            "query não pode estar vazia".to_string(),
61        ));
62    }
63    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
64    let paths = AppPaths::resolve(args.db.as_deref())?;
65
66    if !paths.db.exists() {
67        return Err(AppError::NotFound(erros::banco_nao_encontrado(
68            &paths.db.display().to_string(),
69        )));
70    }
71
72    output::emit_progress_i18n(
73        "Computing query embedding...",
74        "Calculando embedding da consulta...",
75    );
76    let embedding = crate::daemon::embed_query_or_local(&paths.models, &args.query)?;
77
78    let conn = open_ro(&paths.db)?;
79
80    let memory_type_str = args.r#type.map(|t| t.as_str());
81    let knn_results = memories::knn_search(&conn, &embedding, &namespace, memory_type_str, args.k)?;
82
83    let mut direct_matches = Vec::new();
84    let mut memory_ids: Vec<i64> = Vec::new();
85    for (memory_id, distance) in knn_results {
86        let row = {
87            let mut stmt = conn.prepare_cached(
88                "SELECT id, namespace, name, type, description, body, body_hash,
89                        session_id, source, metadata, created_at, updated_at
90                 FROM memories WHERE id=?1 AND deleted_at IS NULL",
91            )?;
92            stmt.query_row(rusqlite::params![memory_id], |r| {
93                Ok(memories::MemoryRow {
94                    id: r.get(0)?,
95                    namespace: r.get(1)?,
96                    name: r.get(2)?,
97                    memory_type: r.get(3)?,
98                    description: r.get(4)?,
99                    body: r.get(5)?,
100                    body_hash: r.get(6)?,
101                    session_id: r.get(7)?,
102                    source: r.get(8)?,
103                    metadata: r.get(9)?,
104                    created_at: r.get(10)?,
105                    updated_at: r.get(11)?,
106                })
107            })
108            .ok()
109        };
110        if let Some(row) = row {
111            let snippet: String = row.body.chars().take(300).collect();
112            direct_matches.push(RecallItem {
113                memory_id: row.id,
114                name: row.name,
115                namespace: row.namespace,
116                memory_type: row.memory_type,
117                description: row.description,
118                snippet,
119                distance,
120                source: "direct".to_string(),
121                // Direct vector matches do not have a graph depth; rely on `distance`.
122                graph_depth: None,
123            });
124            memory_ids.push(memory_id);
125        }
126    }
127
128    let mut graph_matches = Vec::new();
129    if !args.no_graph {
130        let entity_knn = entities::knn_search(&conn, &embedding, &namespace, 5)?;
131        let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
132
133        let all_seed_ids: Vec<i64> = memory_ids
134            .iter()
135            .chain(entity_ids.iter())
136            .copied()
137            .collect();
138
139        if !all_seed_ids.is_empty() {
140            let graph_memory_ids = traverse_from_memories(
141                &conn,
142                &all_seed_ids,
143                &namespace,
144                args.min_weight,
145                args.max_hops,
146            )?;
147
148            for graph_mem_id in graph_memory_ids {
149                // v1.0.23: respect the optional cap on graph results so dense
150                // neighbourhoods do not flood the response unintentionally.
151                if let Some(cap) = args.max_graph_results {
152                    if graph_matches.len() >= cap {
153                        break;
154                    }
155                }
156                let row = {
157                    let mut stmt = conn.prepare_cached(
158                        "SELECT id, namespace, name, type, description, body, body_hash,
159                                session_id, source, metadata, created_at, updated_at
160                         FROM memories WHERE id=?1 AND deleted_at IS NULL",
161                    )?;
162                    stmt.query_row(rusqlite::params![graph_mem_id], |r| {
163                        Ok(memories::MemoryRow {
164                            id: r.get(0)?,
165                            namespace: r.get(1)?,
166                            name: r.get(2)?,
167                            memory_type: r.get(3)?,
168                            description: r.get(4)?,
169                            body: r.get(5)?,
170                            body_hash: r.get(6)?,
171                            session_id: r.get(7)?,
172                            source: r.get(8)?,
173                            metadata: r.get(9)?,
174                            created_at: r.get(10)?,
175                            updated_at: r.get(11)?,
176                        })
177                    })
178                    .ok()
179                };
180                if let Some(row) = row {
181                    let snippet: String = row.body.chars().take(300).collect();
182                    graph_matches.push(RecallItem {
183                        memory_id: row.id,
184                        name: row.name,
185                        namespace: row.namespace,
186                        memory_type: row.memory_type,
187                        description: row.description,
188                        snippet,
189                        // Kept for backward compatibility; v1.0.23 callers should
190                        // read `graph_depth` instead. Future releases may switch
191                        // this to `f32::NAN` after a deprecation cycle.
192                        distance: 0.0,
193                        source: "graph".to_string(),
194                        // `traverse_from_memories` does not yet expose per-result
195                        // depth, so we report `Some(0)` as a sentinel meaning
196                        // "unknown depth, but reachable via graph traversal".
197                        // A future release should plumb the real hop count through.
198                        graph_depth: Some(0),
199                    });
200                }
201            }
202        }
203    }
204
205    // Filtrar por max_distance se < 1.0 (ativado). Se nenhum hit dentro do threshold, exit 4.
206    if args.max_distance < 1.0 {
207        let has_relevant = direct_matches
208            .iter()
209            .any(|item| item.distance <= args.max_distance);
210        if !has_relevant {
211            return Err(AppError::NotFound(erros::sem_resultados_recall(
212                args.max_distance,
213                &args.query,
214                &namespace,
215            )));
216        }
217    }
218
219    let results: Vec<RecallItem> = direct_matches
220        .iter()
221        .cloned()
222        .chain(graph_matches.iter().cloned())
223        .collect();
224
225    output::emit_json(&RecallResponse {
226        query: args.query,
227        k: args.k,
228        direct_matches,
229        graph_matches,
230        results,
231        elapsed_ms: start.elapsed().as_millis() as u64,
232    })?;
233
234    Ok(())
235}
236
#[cfg(test)]
mod testes {
    use crate::output::{RecallItem, RecallResponse};

    /// Builds a `RecallItem` fixture in which only the name, distance and
    /// origin vary. Graph items carry the `Some(0)` depth sentinel; direct
    /// items carry no depth.
    fn item_fixture(name: &str, distance: f32, source: &str) -> RecallItem {
        let graph_depth = match source {
            "graph" => Some(0),
            _ => None,
        };
        RecallItem {
            memory_id: 1,
            name: String::from(name),
            namespace: String::from("global"),
            memory_type: String::from("fact"),
            description: String::from("desc"),
            snippet: String::from("snippet"),
            distance,
            source: String::from(source),
            graph_depth,
        }
    }

    /// Length of the JSON array stored under `key` (panics if it is not an array).
    fn array_len(value: &serde_json::Value, key: &str) -> usize {
        value[key].as_array().unwrap().len()
    }

    #[test]
    fn recall_response_serializa_campos_obrigatorios() {
        let only_direct = item_fixture("mem-a", 0.12, "direct");
        let resp = RecallResponse {
            query: String::from("rust memory"),
            k: 5,
            direct_matches: vec![only_direct.clone()],
            graph_matches: Vec::new(),
            results: vec![only_direct],
            elapsed_ms: 42,
        };

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        assert_eq!(json["query"], "rust memory");
        assert_eq!(json["k"], 5);
        assert_eq!(json["elapsed_ms"], 42u64);
        for key in ["direct_matches", "graph_matches", "results"].iter() {
            assert!(json[*key].is_array());
        }
    }

    #[test]
    fn recall_item_serializa_type_renomeado() {
        let json = serde_json::to_value(&item_fixture("mem-teste", 0.25, "direct"))
            .expect("serialização falhou");

        // `memory_type` is serialized under the `type` key.
        assert_eq!(json["type"], "fact");
        assert_eq!(json["distance"], 0.25f32);
        assert_eq!(json["source"], "direct");
    }

    #[test]
    fn recall_response_results_contem_direct_e_graph() {
        let direct = item_fixture("d-mem", 0.10, "direct");
        let graph = item_fixture("g-mem", 0.0, "graph");

        let resp = RecallResponse {
            query: String::from("query"),
            k: 10,
            direct_matches: vec![direct.clone()],
            graph_matches: vec![graph.clone()],
            results: vec![direct, graph],
            elapsed_ms: 10,
        };

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        assert_eq!(array_len(&json, "direct_matches"), 1);
        assert_eq!(array_len(&json, "graph_matches"), 1);
        assert_eq!(array_len(&json, "results"), 2);

        // `results` keeps direct matches ahead of graph matches.
        let merged = &json["results"];
        assert_eq!(merged[0]["source"], "direct");
        assert_eq!(merged[1]["source"], "graph");
    }

    #[test]
    fn recall_response_vazio_serializa_arrays_vazios() {
        let resp = RecallResponse {
            query: String::from("nada"),
            k: 3,
            direct_matches: Vec::new(),
            graph_matches: Vec::new(),
            results: Vec::new(),
            elapsed_ms: 1,
        };

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        assert_eq!(array_len(&json, "direct_matches"), 0);
        assert_eq!(array_len(&json, "results"), 0);
    }
}