Skip to main content

sqlite_graphrag/commands/
hybrid_search.rs

1use crate::cli::MemoryType;
2use crate::errors::AppError;
3use crate::output::{self, JsonOutputFormat, RecallItem};
4use crate::paths::AppPaths;
5use crate::storage::connection::open_ro;
6use crate::storage::memories;
7
8use std::collections::HashMap;
9
// Command-line arguments for the hybrid (vector + full-text) search command.
//
// Plain `//` comments are used for the notes below on purpose: clap's derive
// turns `///` doc comments on an `Args` struct and its visible fields into
// user-facing help text, and these notes are meant for maintainers only.
#[derive(clap::Args)]
pub struct HybridSearchArgs {
    // Query text. `run` embeds it for the vector search and also passes it
    // verbatim to the full-text search.
    pub query: String,
    // Number of results kept after fusion (`run` truncates the ranking to `k`).
    #[arg(short = 'k', long, default_value = "10")]
    pub k: usize,
    // Constant added to the rank in the RRF denominator.
    #[arg(long, default_value = "60")]
    pub rrf_k: u32,
    // Multiplier applied to contributions from the vector ranking.
    #[arg(long, default_value = "1.0")]
    pub weight_vec: f32,
    // Multiplier applied to contributions from the full-text ranking.
    #[arg(long, default_value = "1.0")]
    pub weight_fts: f32,
    // Optional memory type; forwarded (as a string) to both searches.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // Optional namespace; resolved through `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // NOTE(review): `with_graph`, `max_hops` and `min_weight` are parsed but
    // never read by `run` in this file — confirm whether graph expansion is
    // still planned for this command.
    #[arg(long)]
    pub with_graph: bool,
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    // Parsed, then explicitly discarded by `run` (`let _ = args.format`).
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // Database path override; also read from the `SQLITE_GRAPHRAG_DB_PATH`
    // environment variable. Passed to `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    /// Accepts --json as a no-op: the output is already JSON by default.
    #[arg(long, hide = true)]
    pub json: bool,
}
39
40#[derive(serde::Serialize)]
41pub struct HybridSearchItem {
42    pub memory_id: i64,
43    pub name: String,
44    pub namespace: String,
45    #[serde(rename = "type")]
46    pub memory_type: String,
47    pub description: String,
48    pub body: String,
49    pub combined_score: f64,
50    /// Alias de `combined_score` para contrato documentado em SKILL.md.
51    pub score: f64,
52    /// Fonte do match: sempre "hybrid" (RRF de vec + fts). Adicionado em v2.0.1.
53    pub source: String,
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub vec_rank: Option<usize>,
56    #[serde(skip_serializing_if = "Option::is_none")]
57    pub fts_rank: Option<usize>,
58}
59
60/// Pesos RRF usados na busca híbrida: vec (vetorial) e fts (texto).
61#[derive(serde::Serialize)]
62pub struct Weights {
63    pub vec: f32,
64    pub fts: f32,
65}
66
/// Top-level JSON document emitted by the hybrid search command.
#[derive(serde::Serialize)]
pub struct HybridSearchResponse {
    /// The query text exactly as received on the command line.
    pub query: String,
    /// Maximum number of results requested.
    pub k: usize,
    /// The `k` parameter of RRF used in the combined ranking.
    pub rrf_k: u32,
    /// Weights applied to the vec and fts sources in RRF.
    pub weights: Weights,
    /// Ranked hits, best score first.
    pub results: Vec<HybridSearchItem>,
    /// Graph-derived matches. `run` in this file always emits an empty list.
    pub graph_matches: Vec<RecallItem>,
    /// Total execution time in milliseconds, from the start of the handler
    /// until serialization.
    pub elapsed_ms: u64,
}
80
81pub fn run(args: HybridSearchArgs) -> Result<(), AppError> {
82    let start = std::time::Instant::now();
83    let _ = args.format;
84
85    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
86    let paths = AppPaths::resolve(args.db.as_deref())?;
87
88    output::emit_progress_i18n(
89        "Computing query embedding...",
90        "Calculando embedding da consulta...",
91    );
92    let embedder = crate::embedder::get_embedder(&paths.models)?;
93    let embedding = crate::embedder::embed_query(embedder, &args.query)?;
94
95    let conn = open_ro(&paths.db)?;
96
97    let memory_type_str = args.r#type.map(|t| t.as_str());
98
99    let vec_results =
100        memories::knn_search(&conn, &embedding, &namespace, memory_type_str, args.k * 2)?;
101
102    // Mapear posição de ranking vetorial por memory_id (1-indexed conforme schema)
103    let vec_rank_map: HashMap<i64, usize> = vec_results
104        .iter()
105        .enumerate()
106        .map(|(pos, (id, _))| (*id, pos + 1))
107        .collect();
108
109    let fts_results =
110        memories::fts_search(&conn, &args.query, &namespace, memory_type_str, args.k * 2)?;
111
112    // Mapear posição de ranking FTS por memory_id (1-indexed conforme schema)
113    let fts_rank_map: HashMap<i64, usize> = fts_results
114        .iter()
115        .enumerate()
116        .map(|(pos, row)| (row.id, pos + 1))
117        .collect();
118
119    let rrf_k = args.rrf_k as f64;
120
121    // Acumular scores RRF combinados
122    let mut combined_scores: HashMap<i64, f64> = HashMap::new();
123
124    for (rank, (memory_id, _)) in vec_results.iter().enumerate() {
125        let score = args.weight_vec as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
126        *combined_scores.entry(*memory_id).or_insert(0.0) += score;
127    }
128
129    for (rank, row) in fts_results.iter().enumerate() {
130        let score = args.weight_fts as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
131        *combined_scores.entry(row.id).or_insert(0.0) += score;
132    }
133
134    // Ordenar por score descendente e tomar os top-k
135    let mut ranked: Vec<(i64, f64)> = combined_scores.into_iter().collect();
136    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
137    ranked.truncate(args.k);
138
139    // Coletar todos os IDs para busca em batch (evitar N+1)
140    let top_ids: Vec<i64> = ranked.iter().map(|(id, _)| *id).collect();
141
142    // Buscar dados completos das memórias top
143    let mut memory_data: HashMap<i64, memories::MemoryRow> = HashMap::new();
144    for id in &top_ids {
145        if let Some(row) = memories::read_full(&conn, *id)? {
146            memory_data.insert(*id, row);
147        }
148    }
149
150    // Construir resultados finais na ordem de ranking
151    let results: Vec<HybridSearchItem> = ranked
152        .into_iter()
153        .filter_map(|(memory_id, combined_score)| {
154            memory_data.remove(&memory_id).map(|row| HybridSearchItem {
155                memory_id: row.id,
156                name: row.name,
157                namespace: row.namespace,
158                memory_type: row.memory_type,
159                description: row.description,
160                body: row.body,
161                combined_score,
162                score: combined_score,
163                source: "hybrid".to_string(),
164                vec_rank: vec_rank_map.get(&memory_id).copied(),
165                fts_rank: fts_rank_map.get(&memory_id).copied(),
166            })
167        })
168        .collect();
169
170    output::emit_json(&HybridSearchResponse {
171        query: args.query,
172        k: args.k,
173        rrf_k: args.rrf_k,
174        weights: Weights {
175            vec: args.weight_vec,
176            fts: args.weight_fts,
177        },
178        results,
179        graph_matches: vec![],
180        elapsed_ms: start.elapsed().as_millis() as u64,
181    })?;
182
183    Ok(())
184}
185
/// Serialization contract tests for the hybrid search output types.
///
/// Assertions inspect the parsed JSON structure rather than searching the
/// serialized string, so they check the exact key, nesting level and value.
#[cfg(test)]
mod testes {
    use super::*;

    /// Serializes `value` into a `serde_json::Value` for structural checks.
    fn to_json<T: serde::Serialize>(value: &T) -> serde_json::Value {
        serde_json::to_value(value).unwrap()
    }

    /// Builds a response with no results and a fixed query text.
    fn resposta_vazia(
        k: usize,
        rrf_k: u32,
        weight_vec: f32,
        weight_fts: f32,
    ) -> HybridSearchResponse {
        HybridSearchResponse {
            query: "busca teste".to_string(),
            k,
            rrf_k,
            weights: Weights {
                vec: weight_vec,
                fts: weight_fts,
            },
            results: vec![],
            graph_matches: vec![],
            elapsed_ms: 0,
        }
    }

    #[test]
    fn hybrid_search_response_vazia_serializa_campos_corretos() {
        let json = to_json(&resposta_vazia(10, 60, 1.0, 1.0));
        assert_eq!(
            json["results"],
            serde_json::json!([]),
            "deve conter campo results"
        );
        assert_eq!(json["query"], "busca teste", "deve conter campo query");
        assert_eq!(json["k"], 10, "deve conter campo k");
        assert_eq!(
            json["graph_matches"],
            serde_json::json!([]),
            "deve conter campo graph_matches"
        );
        assert!(
            json.get("combined_rank").is_none(),
            "NÃO deve conter combined_rank"
        );
        assert!(
            json.get("vec_rank_list").is_none(),
            "NÃO deve conter vec_rank_list"
        );
        assert!(
            json.get("fts_rank_list").is_none(),
            "NÃO deve conter fts_rank_list"
        );
    }

    #[test]
    fn hybrid_search_response_serializa_rrf_k_e_weights() {
        let json = to_json(&resposta_vazia(5, 60, 0.7, 0.3));
        assert_eq!(json["rrf_k"], 60, "deve conter campo rrf_k");
        assert!(json["weights"].is_object(), "deve conter campo weights");
        // The weights are `f32`; compare with a tolerance instead of relying
        // on how the value is widened during serialization.
        let vec = json["weights"]["vec"].as_f64();
        let fts = json["weights"]["fts"].as_f64();
        assert!(
            vec.map_or(false, |v| (v - 0.7).abs() < 1e-6),
            "deve conter campo weights.vec"
        );
        assert!(
            fts.map_or(false, |v| (v - 0.3).abs() < 1e-6),
            "deve conter campo weights.fts"
        );
    }

    #[test]
    fn hybrid_search_response_serializa_elapsed_ms() {
        let mut resp = resposta_vazia(5, 60, 1.0, 1.0);
        resp.elapsed_ms = 123;
        let json = to_json(&resp);
        assert!(
            json.get("elapsed_ms").is_some(),
            "deve conter campo elapsed_ms"
        );
        assert_eq!(
            json["elapsed_ms"], 123,
            "deve serializar valor de elapsed_ms"
        );
    }

    #[test]
    fn weights_struct_serializa_corretamente() {
        let json = to_json(&Weights { vec: 0.6, fts: 0.4 });
        let obj = json.as_object().unwrap();
        assert_eq!(obj.len(), 2);
        assert!(obj["vec"].as_f64().map_or(false, |v| (v - 0.6).abs() < 1e-6));
        assert!(obj["fts"].as_f64().map_or(false, |v| (v - 0.4).abs() < 1e-6));
    }

    #[test]
    fn hybrid_search_item_omite_fts_rank_quando_none() {
        let item = HybridSearchItem {
            memory_id: 1,
            name: "mem".to_string(),
            namespace: "default".to_string(),
            memory_type: "user".to_string(),
            description: "desc".to_string(),
            body: "conteúdo".to_string(),
            combined_score: 0.0328,
            score: 0.0328,
            source: "hybrid".to_string(),
            vec_rank: Some(1),
            fts_rank: None,
        };
        let json = to_json(&item);
        assert_eq!(json["vec_rank"], 1, "deve conter vec_rank quando Some");
        assert!(
            json.get("fts_rank").is_none(),
            "NÃO deve conter fts_rank quando None"
        );
    }

    #[test]
    fn hybrid_search_item_omite_vec_rank_quando_none() {
        let item = HybridSearchItem {
            memory_id: 2,
            name: "mem2".to_string(),
            namespace: "default".to_string(),
            memory_type: "fact".to_string(),
            description: "desc2".to_string(),
            body: "corpo2".to_string(),
            combined_score: 0.016,
            score: 0.016,
            source: "hybrid".to_string(),
            vec_rank: None,
            fts_rank: Some(2),
        };
        let json = to_json(&item);
        assert!(
            json.get("vec_rank").is_none(),
            "NÃO deve conter vec_rank quando None"
        );
        assert_eq!(json["fts_rank"], 2, "deve conter fts_rank quando Some");
    }

    #[test]
    fn hybrid_search_item_serializa_ambos_ranks_quando_some() {
        let item = HybridSearchItem {
            memory_id: 3,
            name: "mem3".to_string(),
            namespace: "ns".to_string(),
            memory_type: "entity".to_string(),
            description: "desc3".to_string(),
            body: "corpo3".to_string(),
            combined_score: 0.05,
            score: 0.05,
            source: "hybrid".to_string(),
            vec_rank: Some(3),
            fts_rank: Some(1),
        };
        let json = to_json(&item);
        assert_eq!(json["vec_rank"], 3, "deve conter vec_rank");
        assert_eq!(json["fts_rank"], 1, "deve conter fts_rank");
        assert_eq!(json["type"], "entity", "deve serializar type renomeado");
        assert!(
            json.get("memory_type").is_none(),
            "NÃO deve expor memory_type"
        );
        // `score` is documented as an alias of `combined_score`, and `source`
        // is part of the output contract; pin both.
        assert_eq!(json["memory_id"], 3);
        assert_eq!(json["combined_score"], 0.05);
        assert_eq!(json["score"], json["combined_score"]);
        assert_eq!(json["source"], "hybrid");
    }

    #[test]
    fn hybrid_search_response_serializa_k_corretamente() {
        let json = to_json(&resposta_vazia(5, 60, 1.0, 1.0));
        assert_eq!(json["k"], 5, "deve serializar k=5");
    }
}