Skip to main content

sqlite_graphrag/commands/
recall.rs

1use crate::cli::MemoryType;
2use crate::errors::AppError;
3use crate::graph::traverse_from_memories;
4use crate::i18n::erros;
5use crate::output::{self, JsonOutputFormat, RecallItem, RecallResponse};
6use crate::paths::AppPaths;
7use crate::storage::connection::open_ro;
8use crate::storage::entities;
9use crate::storage::memories;
10
11#[derive(clap::Args)]
12pub struct RecallArgs {
13    pub query: String,
14    #[arg(short = 'k', long, default_value = "10")]
15    pub k: usize,
16    #[arg(long, value_enum)]
17    pub r#type: Option<MemoryType>,
18    #[arg(long)]
19    pub namespace: Option<String>,
20    #[arg(long)]
21    pub no_graph: bool,
22    #[arg(long)]
23    pub precise: bool,
24    #[arg(long, default_value = "2")]
25    pub max_hops: u32,
26    #[arg(long, default_value = "0.3")]
27    pub min_weight: f64,
28    /// Filter results by maximum distance. Results with distance greater than this value
29    /// are excluded. If all matches exceed this threshold, the command exits with code 4
30    /// (`not found`) per the documented public contract.
31    /// Default `1.0` disables the filter and preserves the top-k behavior.
32    #[arg(long, alias = "min-distance", default_value = "1.0")]
33    pub max_distance: f32,
34    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
35    pub format: JsonOutputFormat,
36    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
37    pub db: Option<String>,
38    /// Accept `--json` as a no-op because output is already JSON by default.
39    #[arg(long, help = "No-op; JSON is always emitted on stdout")]
40    pub json: bool,
41}
42
43pub fn run(args: RecallArgs) -> Result<(), AppError> {
44    let start = std::time::Instant::now();
45    let _ = args.format;
46    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
47    let paths = AppPaths::resolve(args.db.as_deref())?;
48
49    if !paths.db.exists() {
50        return Err(AppError::NotFound(erros::banco_nao_encontrado(
51            &paths.db.display().to_string(),
52        )));
53    }
54
55    output::emit_progress_i18n(
56        "Computing query embedding...",
57        "Calculando embedding da consulta...",
58    );
59    let embedding = crate::daemon::embed_query_or_local(&paths.models, &args.query)?;
60
61    let conn = open_ro(&paths.db)?;
62
63    let memory_type_str = args.r#type.map(|t| t.as_str());
64    let knn_results = memories::knn_search(&conn, &embedding, &namespace, memory_type_str, args.k)?;
65
66    let mut direct_matches = Vec::new();
67    let mut memory_ids: Vec<i64> = Vec::new();
68    for (memory_id, distance) in knn_results {
69        let row = {
70            let mut stmt = conn.prepare_cached(
71                "SELECT id, namespace, name, type, description, body, body_hash,
72                        session_id, source, metadata, created_at, updated_at
73                 FROM memories WHERE id=?1 AND deleted_at IS NULL",
74            )?;
75            stmt.query_row(rusqlite::params![memory_id], |r| {
76                Ok(memories::MemoryRow {
77                    id: r.get(0)?,
78                    namespace: r.get(1)?,
79                    name: r.get(2)?,
80                    memory_type: r.get(3)?,
81                    description: r.get(4)?,
82                    body: r.get(5)?,
83                    body_hash: r.get(6)?,
84                    session_id: r.get(7)?,
85                    source: r.get(8)?,
86                    metadata: r.get(9)?,
87                    created_at: r.get(10)?,
88                    updated_at: r.get(11)?,
89                })
90            })
91            .ok()
92        };
93        if let Some(row) = row {
94            let snippet: String = row.body.chars().take(300).collect();
95            direct_matches.push(RecallItem {
96                memory_id: row.id,
97                name: row.name,
98                namespace: row.namespace,
99                memory_type: row.memory_type,
100                description: row.description,
101                snippet,
102                distance,
103                source: "direct".to_string(),
104            });
105            memory_ids.push(memory_id);
106        }
107    }
108
109    let mut graph_matches = Vec::new();
110    if !args.no_graph {
111        let entity_knn = entities::knn_search(&conn, &embedding, &namespace, 5)?;
112        let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
113
114        let all_seed_ids: Vec<i64> = memory_ids
115            .iter()
116            .chain(entity_ids.iter())
117            .copied()
118            .collect();
119
120        if !all_seed_ids.is_empty() {
121            let graph_memory_ids = traverse_from_memories(
122                &conn,
123                &all_seed_ids,
124                &namespace,
125                args.min_weight,
126                args.max_hops,
127            )?;
128
129            for graph_mem_id in graph_memory_ids {
130                let row = {
131                    let mut stmt = conn.prepare_cached(
132                        "SELECT id, namespace, name, type, description, body, body_hash,
133                                session_id, source, metadata, created_at, updated_at
134                         FROM memories WHERE id=?1 AND deleted_at IS NULL",
135                    )?;
136                    stmt.query_row(rusqlite::params![graph_mem_id], |r| {
137                        Ok(memories::MemoryRow {
138                            id: r.get(0)?,
139                            namespace: r.get(1)?,
140                            name: r.get(2)?,
141                            memory_type: r.get(3)?,
142                            description: r.get(4)?,
143                            body: r.get(5)?,
144                            body_hash: r.get(6)?,
145                            session_id: r.get(7)?,
146                            source: r.get(8)?,
147                            metadata: r.get(9)?,
148                            created_at: r.get(10)?,
149                            updated_at: r.get(11)?,
150                        })
151                    })
152                    .ok()
153                };
154                if let Some(row) = row {
155                    let snippet: String = row.body.chars().take(300).collect();
156                    graph_matches.push(RecallItem {
157                        memory_id: row.id,
158                        name: row.name,
159                        namespace: row.namespace,
160                        memory_type: row.memory_type,
161                        description: row.description,
162                        snippet,
163                        distance: 0.0,
164                        source: "graph".to_string(),
165                    });
166                }
167            }
168        }
169    }
170
171    // Filtrar por max_distance se < 1.0 (ativado). Se nenhum hit dentro do threshold, exit 4.
172    if args.max_distance < 1.0 {
173        let has_relevant = direct_matches
174            .iter()
175            .any(|item| item.distance <= args.max_distance);
176        if !has_relevant {
177            return Err(AppError::NotFound(erros::sem_resultados_recall(
178                args.max_distance,
179                &args.query,
180                &namespace,
181            )));
182        }
183    }
184
185    let results: Vec<RecallItem> = direct_matches
186        .iter()
187        .cloned()
188        .chain(graph_matches.iter().cloned())
189        .collect();
190
191    output::emit_json(&RecallResponse {
192        query: args.query,
193        k: args.k,
194        direct_matches,
195        graph_matches,
196        results,
197        elapsed_ms: start.elapsed().as_millis() as u64,
198    })?;
199
200    Ok(())
201}
202
#[cfg(test)]
mod testes {
    // Serialization tests for the JSON payload produced by `recall`. They build the output
    // structs directly, so no database or embedding model is needed.
    use crate::output::{RecallItem, RecallResponse};

    /// Fixture: a `RecallItem` with fixed id/namespace/type/text and a caller-chosen name,
    /// distance and origin tag.
    fn make_item(name: &str, distance: f32, source: &str) -> RecallItem {
        let namespace = String::from("global");
        let memory_type = String::from("fact");
        RecallItem {
            memory_id: 1,
            name: String::from(name),
            namespace,
            memory_type,
            description: String::from("desc"),
            snippet: String::from("snippet"),
            distance,
            source: String::from(source),
        }
    }

    /// Fixture: a response whose `results` is `direct` followed by `graph`.
    fn make_response(
        query: &str,
        k: usize,
        direct: Vec<RecallItem>,
        graph: Vec<RecallItem>,
        elapsed_ms: u64,
    ) -> RecallResponse {
        let results = direct.iter().chain(&graph).cloned().collect();
        RecallResponse {
            query: query.to_owned(),
            k,
            direct_matches: direct,
            graph_matches: graph,
            results,
            elapsed_ms,
        }
    }

    /// Length of the JSON array stored under `key`; panics if it is not an array.
    fn array_len(json: &serde_json::Value, key: &str) -> usize {
        json[key].as_array().unwrap().len()
    }

    #[test]
    fn recall_response_serializa_campos_obrigatorios() {
        let direct = vec![make_item("mem-a", 0.12, "direct")];
        let resp = make_response("rust memory", 5, direct, vec![], 42);

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        assert_eq!(json["query"], "rust memory");
        assert_eq!(json["k"], 5);
        assert_eq!(json["elapsed_ms"], 42u64);
        for key in ["direct_matches", "graph_matches", "results"] {
            assert!(json[key].is_array());
        }
    }

    #[test]
    fn recall_item_serializa_type_renomeado() {
        let json = serde_json::to_value(make_item("mem-teste", 0.25, "direct"))
            .expect("serialização falhou");

        // `memory_type` is serialized under the JSON key "type".
        assert_eq!(json["type"], "fact");
        assert_eq!(json["distance"], 0.25f32);
        assert_eq!(json["source"], "direct");
    }

    #[test]
    fn recall_response_results_contem_direct_e_graph() {
        let resp = make_response(
            "query",
            10,
            vec![make_item("d-mem", 0.10, "direct")],
            vec![make_item("g-mem", 0.0, "graph")],
            10,
        );

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        assert_eq!(array_len(&json, "direct_matches"), 1);
        assert_eq!(array_len(&json, "graph_matches"), 1);
        assert_eq!(array_len(&json, "results"), 2);

        let merged = &json["results"];
        assert_eq!(merged[0]["source"], "direct");
        assert_eq!(merged[1]["source"], "graph");
    }

    #[test]
    fn recall_response_vazio_serializa_arrays_vazios() {
        let resp = make_response("nada", 3, vec![], vec![], 1);

        let json = serde_json::to_value(&resp).expect("serialização falhou");
        assert_eq!(array_len(&json, "direct_matches"), 0);
        assert_eq!(array_len(&json, "results"), 0);
    }
}