Skip to main content

sqlite_graphrag/commands/
hybrid_search.rs

1use crate::cli::MemoryType;
2use crate::errors::AppError;
3use crate::output::{self, JsonOutputFormat, RecallItem};
4use crate::paths::AppPaths;
5use crate::storage::connection::open_ro;
6use crate::storage::memories;
7
8use std::collections::HashMap;
9
// Command-line arguments for the hybrid (vector + full-text) search command.
//
// Plain `//` comments are used on purpose: clap's derive turns `///` doc comments into
// `--help` text, so documenting the fields with `///` would change the CLI output.
#[derive(clap::Args)]
pub struct HybridSearchArgs {
    // Query text; used both to compute the query embedding and as the full-text search input.
    pub query: String,
    // Maximum number of fused results to return. Each source (vector and FTS) is asked for
    // `k * 2` candidates before fusion.
    #[arg(short = 'k', long, default_value = "10")]
    pub k: usize,
    // RRF smoothing constant: a candidate at 1-based position `r` in a source list
    // contributes `weight / (rrf_k + r)` to its fused score.
    #[arg(long, default_value = "60")]
    pub rrf_k: u32,
    // Multiplier applied to the vector (KNN) contribution of the fused score.
    #[arg(long, default_value = "1.0")]
    pub weight_vec: f32,
    // Multiplier applied to the full-text (FTS) contribution of the fused score.
    #[arg(long, default_value = "1.0")]
    pub weight_fts: f32,
    // Optional memory-type filter, passed to both the vector and the full-text search.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // Namespace override; resolved through `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // NOTE(review): `with_graph`, `max_hops` and `min_weight` are not read by `run` in this
    // file (`graph_matches` is always emitted empty); presumably reserved for graph
    // expansion — confirm before relying on them.
    #[arg(long)]
    pub with_graph: bool,
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    // Output format; `run` currently ignores it and always emits JSON.
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // Database path override; can also be supplied via the `SQLITE_GRAPHRAG_DB_PATH`
    // environment variable.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    /// Accept `--json` as a no-op because output is already JSON by default.
    #[arg(long, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
}
39
/// One fused hit in the hybrid search output.
#[derive(serde::Serialize)]
pub struct HybridSearchItem {
    /// Id of the memory row.
    pub memory_id: i64,
    pub name: String,
    pub namespace: String,
    /// Memory type; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub memory_type: String,
    pub description: String,
    pub body: String,
    /// Weighted Reciprocal Rank Fusion score (sum of the vec and fts contributions).
    pub combined_score: f64,
    /// Alias of `combined_score`, kept for the contract documented in SKILL.md.
    pub score: f64,
    /// Match source: always "hybrid" (RRF of vec + fts). Added in v2.0.1.
    pub source: String,
    /// 1-based position in the vector (KNN) candidate list; omitted from the JSON when the
    /// memory was not among the vector candidates.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vec_rank: Option<usize>,
    /// 1-based position in the full-text (FTS) candidate list; omitted from the JSON when the
    /// memory was not among the FTS candidates.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fts_rank: Option<usize>,
    /// Combined RRF score — explicit alias of `combined_score` for integration contracts.
    /// Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rrf_score: Option<f64>,
}
62
/// RRF weights used by hybrid search: `vec` (vector similarity) and `fts` (full-text).
#[derive(serde::Serialize)]
pub struct Weights {
    /// Multiplier applied to the vector (KNN) contribution.
    pub vec: f32,
    /// Multiplier applied to the full-text (FTS) contribution.
    pub fts: f32,
}
69
/// Top-level JSON document emitted by the hybrid search command.
#[derive(serde::Serialize)]
pub struct HybridSearchResponse {
    /// Query text as received on the command line.
    pub query: String,
    /// Maximum number of fused results that was requested.
    pub k: usize,
    /// RRF `k` parameter used for the combined ranking.
    pub rrf_k: u32,
    /// Weights applied to the vec and fts sources in RRF.
    pub weights: Weights,
    /// Fused hits, best score first.
    pub results: Vec<HybridSearchItem>,
    /// Graph-expanded matches; `run` currently always leaves this empty.
    pub graph_matches: Vec<RecallItem>,
    /// Total execution time in milliseconds, measured from the start of the handler until the
    /// response is built for serialization.
    pub elapsed_ms: u64,
}
83
84pub fn run(args: HybridSearchArgs) -> Result<(), AppError> {
85    let start = std::time::Instant::now();
86    let _ = args.format;
87
88    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
89    let paths = AppPaths::resolve(args.db.as_deref())?;
90    if !paths.db.exists() {
91        return Err(AppError::NotFound(
92            crate::i18n::erros::banco_nao_encontrado(&paths.db.display().to_string()),
93        ));
94    }
95
96    output::emit_progress_i18n(
97        "Computing query embedding...",
98        "Calculando embedding da consulta...",
99    );
100    let embedding = crate::daemon::embed_query_or_local(&paths.models, &args.query)?;
101
102    let conn = open_ro(&paths.db)?;
103
104    let memory_type_str = args.r#type.map(|t| t.as_str());
105
106    let vec_results =
107        memories::knn_search(&conn, &embedding, &namespace, memory_type_str, args.k * 2)?;
108
109    // Mapear posição de ranking vetorial por memory_id (1-indexed conforme schema)
110    let vec_rank_map: HashMap<i64, usize> = vec_results
111        .iter()
112        .enumerate()
113        .map(|(pos, (id, _))| (*id, pos + 1))
114        .collect();
115
116    let fts_results =
117        memories::fts_search(&conn, &args.query, &namespace, memory_type_str, args.k * 2)?;
118
119    // Mapear posição de ranking FTS por memory_id (1-indexed conforme schema)
120    let fts_rank_map: HashMap<i64, usize> = fts_results
121        .iter()
122        .enumerate()
123        .map(|(pos, row)| (row.id, pos + 1))
124        .collect();
125
126    let rrf_k = args.rrf_k as f64;
127
128    // Acumular scores RRF combinados
129    let mut combined_scores: HashMap<i64, f64> = HashMap::new();
130
131    for (rank, (memory_id, _)) in vec_results.iter().enumerate() {
132        let score = args.weight_vec as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
133        *combined_scores.entry(*memory_id).or_insert(0.0) += score;
134    }
135
136    for (rank, row) in fts_results.iter().enumerate() {
137        let score = args.weight_fts as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
138        *combined_scores.entry(row.id).or_insert(0.0) += score;
139    }
140
141    // Ordenar por score descendente e tomar os top-k
142    let mut ranked: Vec<(i64, f64)> = combined_scores.into_iter().collect();
143    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
144    ranked.truncate(args.k);
145
146    // Coletar todos os IDs para busca em batch (evitar N+1)
147    let top_ids: Vec<i64> = ranked.iter().map(|(id, _)| *id).collect();
148
149    // Buscar dados completos das memórias top
150    let mut memory_data: HashMap<i64, memories::MemoryRow> = HashMap::new();
151    for id in &top_ids {
152        if let Some(row) = memories::read_full(&conn, *id)? {
153            memory_data.insert(*id, row);
154        }
155    }
156
157    // Construir resultados finais na ordem de ranking
158    let results: Vec<HybridSearchItem> = ranked
159        .into_iter()
160        .filter_map(|(memory_id, combined_score)| {
161            memory_data.remove(&memory_id).map(|row| HybridSearchItem {
162                memory_id: row.id,
163                name: row.name,
164                namespace: row.namespace,
165                memory_type: row.memory_type,
166                description: row.description,
167                body: row.body,
168                combined_score,
169                score: combined_score,
170                source: "hybrid".to_string(),
171                vec_rank: vec_rank_map.get(&memory_id).copied(),
172                fts_rank: fts_rank_map.get(&memory_id).copied(),
173                rrf_score: Some(combined_score),
174            })
175        })
176        .collect();
177
178    output::emit_json(&HybridSearchResponse {
179        query: args.query,
180        k: args.k,
181        rrf_k: args.rrf_k,
182        weights: Weights {
183            vec: args.weight_vec,
184            fts: args.weight_fts,
185        },
186        results,
187        graph_matches: vec![],
188        elapsed_ms: start.elapsed().as_millis() as u64,
189    })?;
190
191    Ok(())
192}
193
#[cfg(test)]
mod testes {
    use super::*;

    /// Builds a response with no hits so tests can focus on the top-level fields.
    fn resposta_vazia(
        k: usize,
        rrf_k: u32,
        weight_vec: f32,
        weight_fts: f32,
    ) -> HybridSearchResponse {
        HybridSearchResponse {
            query: "busca teste".to_string(),
            k,
            rrf_k,
            weights: Weights {
                vec: weight_vec,
                fts: weight_fts,
            },
            results: vec![],
            graph_matches: vec![],
            elapsed_ms: 0,
        }
    }

    /// True when `value` is a JSON number equal to `expected` at `f32` precision.
    ///
    /// Converting back to `f32` keeps the check valid whether serde_json writes the shortest
    /// `f32` form or a widened `f64` form.
    fn approx_eq_f32(value: &serde_json::Value, expected: f32) -> bool {
        matches!(value.as_f64(), Some(v) if (v as f32 - expected).abs() < 1e-6)
    }

    #[test]
    fn hybrid_search_response_vazia_serializa_campos_corretos() {
        let resp = resposta_vazia(10, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"results\""), "deve conter campo results");
        assert!(json.contains("\"query\""), "deve conter campo query");
        assert!(json.contains("\"k\""), "deve conter campo k");
        assert!(
            json.contains("\"graph_matches\""),
            "deve conter campo graph_matches"
        );
        assert!(
            !json.contains("\"combined_rank\""),
            "NÃO deve conter combined_rank"
        );
        assert!(
            !json.contains("\"vec_rank_list\""),
            "NÃO deve conter vec_rank_list"
        );
        assert!(
            !json.contains("\"fts_rank_list\""),
            "NÃO deve conter fts_rank_list"
        );

        // Structural checks: the lists must be present as empty arrays, not merely as keys.
        let v: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(v["query"].as_str(), Some("busca teste"));
        assert_eq!(v["k"].as_u64(), Some(10));
        assert_eq!(v["results"].as_array().map(Vec::len), Some(0));
        assert_eq!(v["graph_matches"].as_array().map(Vec::len), Some(0));
    }

    #[test]
    fn hybrid_search_response_serializa_rrf_k_e_weights() {
        let resp = resposta_vazia(5, 60, 0.7, 0.3);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"rrf_k\""), "deve conter campo rrf_k");
        assert!(json.contains("\"weights\""), "deve conter campo weights");
        assert!(json.contains("\"vec\""), "deve conter campo weights.vec");
        assert!(json.contains("\"fts\""), "deve conter campo weights.fts");

        // Values and nesting: `vec`/`fts` must live under `weights`.
        let v: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(v["rrf_k"].as_u64(), Some(60));
        assert!(approx_eq_f32(&v["weights"]["vec"], 0.7));
        assert!(approx_eq_f32(&v["weights"]["fts"], 0.3));
    }

    #[test]
    fn hybrid_search_response_serializa_elapsed_ms() {
        let mut resp = resposta_vazia(5, 60, 1.0, 1.0);
        resp.elapsed_ms = 123;
        let json = serde_json::to_string(&resp).unwrap();
        assert!(
            json.contains("\"elapsed_ms\""),
            "deve conter campo elapsed_ms"
        );
        // Check the value of the field itself rather than "123" appearing anywhere.
        let v: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(
            v["elapsed_ms"].as_u64(),
            Some(123),
            "deve serializar valor de elapsed_ms"
        );
    }

    #[test]
    fn weights_struct_serializa_corretamente() {
        let w = Weights { vec: 0.6, fts: 0.4 };
        let json = serde_json::to_string(&w).unwrap();
        assert!(json.contains("\"vec\""));
        assert!(json.contains("\"fts\""));

        let v: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(approx_eq_f32(&v["vec"], 0.6));
        assert!(approx_eq_f32(&v["fts"], 0.4));
    }

    #[test]
    fn hybrid_search_item_omite_fts_rank_quando_none() {
        let item = HybridSearchItem {
            memory_id: 1,
            name: "mem".to_string(),
            namespace: "default".to_string(),
            memory_type: "user".to_string(),
            description: "desc".to_string(),
            body: "conteúdo".to_string(),
            combined_score: 0.0328,
            score: 0.0328,
            source: "hybrid".to_string(),
            vec_rank: Some(1),
            fts_rank: None,
            rrf_score: Some(0.0328),
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            json.contains("\"vec_rank\""),
            "deve conter vec_rank quando Some"
        );
        assert!(
            !json.contains("\"fts_rank\""),
            "NÃO deve conter fts_rank quando None"
        );

        let v: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(v["vec_rank"].as_u64(), Some(1));
        assert!(v.get("fts_rank").is_none());
    }

    #[test]
    fn hybrid_search_item_omite_vec_rank_quando_none() {
        let item = HybridSearchItem {
            memory_id: 2,
            name: "mem2".to_string(),
            namespace: "default".to_string(),
            memory_type: "fact".to_string(),
            description: "desc2".to_string(),
            body: "corpo2".to_string(),
            combined_score: 0.016,
            score: 0.016,
            source: "hybrid".to_string(),
            vec_rank: None,
            fts_rank: Some(2),
            rrf_score: Some(0.016),
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            !json.contains("\"vec_rank\""),
            "NÃO deve conter vec_rank quando None"
        );
        assert!(
            json.contains("\"fts_rank\""),
            "deve conter fts_rank quando Some"
        );

        let v: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(v.get("vec_rank").is_none());
        assert_eq!(v["fts_rank"].as_u64(), Some(2));
    }

    #[test]
    fn hybrid_search_item_serializa_ambos_ranks_quando_some() {
        let item = HybridSearchItem {
            memory_id: 3,
            name: "mem3".to_string(),
            namespace: "ns".to_string(),
            memory_type: "entity".to_string(),
            description: "desc3".to_string(),
            body: "corpo3".to_string(),
            combined_score: 0.05,
            score: 0.05,
            source: "hybrid".to_string(),
            vec_rank: Some(3),
            fts_rank: Some(1),
            rrf_score: Some(0.05),
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(json.contains("\"vec_rank\""), "deve conter vec_rank");
        assert!(json.contains("\"fts_rank\""), "deve conter fts_rank");
        assert!(json.contains("\"type\""), "deve serializar type renomeado");
        assert!(!json.contains("memory_type"), "NÃO deve expor memory_type");

        // Exact values, and the documented aliases of `combined_score`.
        let v: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(v["memory_id"].as_i64(), Some(3));
        assert_eq!(v["type"].as_str(), Some("entity"));
        assert_eq!(v["source"].as_str(), Some("hybrid"));
        assert_eq!(v["vec_rank"].as_u64(), Some(3));
        assert_eq!(v["fts_rank"].as_u64(), Some(1));
        assert_eq!(v["score"], v["combined_score"]);
        assert_eq!(v["rrf_score"], v["combined_score"]);
    }

    #[test]
    fn hybrid_search_item_omite_rrf_score_quando_none() {
        let item = HybridSearchItem {
            memory_id: 4,
            name: "mem4".to_string(),
            namespace: "ns".to_string(),
            memory_type: "fact".to_string(),
            description: "desc4".to_string(),
            body: "corpo4".to_string(),
            combined_score: 0.01,
            score: 0.01,
            source: "hybrid".to_string(),
            vec_rank: Some(1),
            fts_rank: Some(1),
            rrf_score: None,
        };
        let v = serde_json::to_value(&item).unwrap();
        assert!(v.get("rrf_score").is_none());
        assert!(v.get("combined_score").is_some());
    }

    #[test]
    fn hybrid_search_response_serializa_k_corretamente() {
        let resp = resposta_vazia(5, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"k\":5"), "deve serializar k=5");

        let v: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(v["k"].as_u64(), Some(5));
    }
}