Skip to main content

sqlite_graphrag/commands/
hybrid_search.rs

1use crate::cli::MemoryType;
2use crate::errors::AppError;
3use crate::output::{self, JsonOutputFormat, RecallItem};
4use crate::paths::AppPaths;
5use crate::storage::connection::open_ro;
6use crate::storage::memories;
7
8use std::collections::HashMap;
9
10/// Arguments for the `hybrid-search` subcommand.
11///
12/// When `--namespace` is omitted the search runs against the `global` namespace,
13/// which is the default namespace used by `remember` when no `--namespace` flag
14/// is provided. Pass an explicit `--namespace` value to search a different
15/// isolated namespace.
16#[derive(clap::Args)]
17pub struct HybridSearchArgs {
18    pub query: String,
19    #[arg(short = 'k', long, default_value = "10")]
20    pub k: usize,
21    #[arg(long, default_value = "60")]
22    pub rrf_k: u32,
23    #[arg(long, default_value = "1.0")]
24    pub weight_vec: f32,
25    #[arg(long, default_value = "1.0")]
26    pub weight_fts: f32,
27    /// Filter by memory.type. Note: distinct from graph entity_type
28    /// (project/tool/person/file/concept/incident/decision/memory/dashboard/issue_tracker)
29    /// used in --entities-file.
30    #[arg(long, value_enum)]
31    pub r#type: Option<MemoryType>,
32    #[arg(long)]
33    pub namespace: Option<String>,
34    #[arg(long)]
35    pub with_graph: bool,
36    #[arg(long, default_value = "2")]
37    pub max_hops: u32,
38    #[arg(long, default_value = "0.3")]
39    pub min_weight: f64,
40    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
41    pub format: JsonOutputFormat,
42    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
43    pub db: Option<String>,
44    /// Accept `--json` as a no-op because output is already JSON by default.
45    #[arg(long, help = "No-op; JSON is always emitted on stdout")]
46    pub json: bool,
47}
48
49#[derive(serde::Serialize)]
50pub struct HybridSearchItem {
51    pub memory_id: i64,
52    pub name: String,
53    pub namespace: String,
54    #[serde(rename = "type")]
55    pub memory_type: String,
56    pub description: String,
57    pub body: String,
58    pub combined_score: f64,
59    /// Alias de `combined_score` para contrato documentado em SKILL.md.
60    pub score: f64,
61    /// Fonte do match: sempre "hybrid" (RRF de vec + fts). Adicionado em v2.0.1.
62    pub source: String,
63    #[serde(skip_serializing_if = "Option::is_none")]
64    pub vec_rank: Option<usize>,
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub fts_rank: Option<usize>,
67    /// Score RRF combinado — alias explícito de `combined_score` para contratos de integração.
68    #[serde(skip_serializing_if = "Option::is_none")]
69    pub rrf_score: Option<f64>,
70}
71
72/// Pesos RRF usados na busca híbrida: vec (vetorial) e fts (texto).
73#[derive(serde::Serialize)]
74pub struct Weights {
75    pub vec: f32,
76    pub fts: f32,
77}
78
/// Top-level JSON document emitted on stdout by the `hybrid-search` subcommand.
#[derive(serde::Serialize)]
pub struct HybridSearchResponse {
    /// The query text as given on the command line.
    pub query: String,
    /// Requested maximum number of results.
    pub k: usize,
    /// RRF k parameter used in the combined ranking.
    pub rrf_k: u32,
    /// Weights applied to the vec and fts sources in RRF.
    pub weights: Weights,
    /// Fused hits, ordered by descending combined score.
    pub results: Vec<HybridSearchItem>,
    /// Graph-based matches; `run` in this file currently always emits an empty list.
    pub graph_matches: Vec<RecallItem>,
    /// Total execution time in milliseconds, from handler start until serialization.
    pub elapsed_ms: u64,
}
92
93pub fn run(args: HybridSearchArgs) -> Result<(), AppError> {
94    let start = std::time::Instant::now();
95    let _ = args.format;
96
97    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
98    let paths = AppPaths::resolve(args.db.as_deref())?;
99    if !paths.db.exists() {
100        return Err(AppError::NotFound(
101            crate::i18n::erros::banco_nao_encontrado(&paths.db.display().to_string()),
102        ));
103    }
104
105    output::emit_progress_i18n(
106        "Computing query embedding...",
107        "Calculando embedding da consulta...",
108    );
109    let embedding = crate::daemon::embed_query_or_local(&paths.models, &args.query)?;
110
111    let conn = open_ro(&paths.db)?;
112
113    let memory_type_str = args.r#type.map(|t| t.as_str());
114
115    let vec_results = memories::knn_search(
116        &conn,
117        &embedding,
118        &[namespace.clone()],
119        memory_type_str,
120        args.k * 2,
121    )?;
122
123    // Mapear posição de ranking vetorial por memory_id (1-indexed conforme schema)
124    let vec_rank_map: HashMap<i64, usize> = vec_results
125        .iter()
126        .enumerate()
127        .map(|(pos, (id, _))| (*id, pos + 1))
128        .collect();
129
130    let fts_results =
131        memories::fts_search(&conn, &args.query, &namespace, memory_type_str, args.k * 2)?;
132
133    // Mapear posição de ranking FTS por memory_id (1-indexed conforme schema)
134    let fts_rank_map: HashMap<i64, usize> = fts_results
135        .iter()
136        .enumerate()
137        .map(|(pos, row)| (row.id, pos + 1))
138        .collect();
139
140    let rrf_k = args.rrf_k as f64;
141
142    // Acumular scores RRF combinados
143    let mut combined_scores: HashMap<i64, f64> = HashMap::new();
144
145    for (rank, (memory_id, _)) in vec_results.iter().enumerate() {
146        let score = args.weight_vec as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
147        *combined_scores.entry(*memory_id).or_insert(0.0) += score;
148    }
149
150    for (rank, row) in fts_results.iter().enumerate() {
151        let score = args.weight_fts as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
152        *combined_scores.entry(row.id).or_insert(0.0) += score;
153    }
154
155    // Ordenar por score descendente e tomar os top-k
156    let mut ranked: Vec<(i64, f64)> = combined_scores.into_iter().collect();
157    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
158    ranked.truncate(args.k);
159
160    // Coletar todos os IDs para busca em batch (evitar N+1)
161    let top_ids: Vec<i64> = ranked.iter().map(|(id, _)| *id).collect();
162
163    // Buscar dados completos das memórias top
164    let mut memory_data: HashMap<i64, memories::MemoryRow> = HashMap::new();
165    for id in &top_ids {
166        if let Some(row) = memories::read_full(&conn, *id)? {
167            memory_data.insert(*id, row);
168        }
169    }
170
171    // Construir resultados finais na ordem de ranking
172    let results: Vec<HybridSearchItem> = ranked
173        .into_iter()
174        .filter_map(|(memory_id, combined_score)| {
175            memory_data.remove(&memory_id).map(|row| HybridSearchItem {
176                memory_id: row.id,
177                name: row.name,
178                namespace: row.namespace,
179                memory_type: row.memory_type,
180                description: row.description,
181                body: row.body,
182                combined_score,
183                score: combined_score,
184                source: "hybrid".to_string(),
185                vec_rank: vec_rank_map.get(&memory_id).copied(),
186                fts_rank: fts_rank_map.get(&memory_id).copied(),
187                rrf_score: Some(combined_score),
188            })
189        })
190        .collect();
191
192    output::emit_json(&HybridSearchResponse {
193        query: args.query,
194        k: args.k,
195        rrf_k: args.rrf_k,
196        weights: Weights {
197            vec: args.weight_vec,
198            fts: args.weight_fts,
199        },
200        results,
201        graph_matches: vec![],
202        elapsed_ms: start.elapsed().as_millis() as u64,
203    })?;
204
205    Ok(())
206}
207
#[cfg(test)]
mod testes {
    use super::*;
    use serde_json::Value;

    fn resposta_vazia(
        k: usize,
        rrf_k: u32,
        weight_vec: f32,
        weight_fts: f32,
    ) -> HybridSearchResponse {
        HybridSearchResponse {
            query: "busca teste".to_string(),
            k,
            rrf_k,
            weights: Weights {
                vec: weight_vec,
                fts: weight_fts,
            },
            results: vec![],
            graph_matches: vec![],
            elapsed_ms: 0,
        }
    }

    /// Serializes `value` into a JSON tree so tests assert on structure and exact values
    /// rather than on raw substrings of the serialized text.
    fn to_json<T: serde::Serialize>(value: &T) -> Value {
        serde_json::to_value(value).expect("serialization to JSON must not fail")
    }

    /// Asserts that `actual` is a JSON number within 1e-6 of `expected`. An exact comparison
    /// would be too strict, because f32 fields widen to f64 with extra digits.
    fn assert_close(actual: &Value, expected: f64) {
        let got = actual
            .as_f64()
            .unwrap_or_else(|| panic!("expected a JSON number, got {actual}"));
        assert!(
            (got - expected).abs() < 1e-6,
            "expected ~{expected}, got {got}"
        );
    }

    #[test]
    fn hybrid_search_response_vazia_serializa_campos_corretos() {
        let resp = resposta_vazia(10, 60, 1.0, 1.0);
        let json = to_json(&resp);
        let obj = json
            .as_object()
            .expect("response must serialize as a JSON object");
        assert!(obj.contains_key("results"), "deve conter campo results");
        assert!(obj.contains_key("query"), "deve conter campo query");
        assert!(obj.contains_key("k"), "deve conter campo k");
        assert!(
            obj.contains_key("graph_matches"),
            "deve conter campo graph_matches"
        );
        assert!(
            !obj.contains_key("combined_rank"),
            "NÃO deve conter combined_rank"
        );
        assert!(
            !obj.contains_key("vec_rank_list"),
            "NÃO deve conter vec_rank_list"
        );
        assert!(
            !obj.contains_key("fts_rank_list"),
            "NÃO deve conter fts_rank_list"
        );
        assert_eq!(json["query"], "busca teste");
        assert_eq!(json["results"], Value::Array(vec![]));
        assert_eq!(json["graph_matches"], Value::Array(vec![]));
    }

    #[test]
    fn hybrid_search_response_serializa_rrf_k_e_weights() {
        let resp = resposta_vazia(5, 60, 0.7, 0.3);
        let json = to_json(&resp);
        assert_eq!(json["rrf_k"], 60, "deve conter campo rrf_k");
        assert!(json["weights"].is_object(), "deve conter campo weights");
        assert_close(&json["weights"]["vec"], 0.7);
        assert_close(&json["weights"]["fts"], 0.3);
    }

    #[test]
    fn hybrid_search_response_serializa_elapsed_ms() {
        let mut resp = resposta_vazia(5, 60, 1.0, 1.0);
        resp.elapsed_ms = 123;
        let json = to_json(&resp);
        assert_eq!(
            json["elapsed_ms"], 123,
            "deve serializar valor de elapsed_ms"
        );
    }

    #[test]
    fn weights_struct_serializa_corretamente() {
        let w = Weights { vec: 0.6, fts: 0.4 };
        let json = to_json(&w);
        assert_eq!(json.as_object().map(|o| o.len()), Some(2));
        assert_close(&json["vec"], 0.6);
        assert_close(&json["fts"], 0.4);
    }

    #[test]
    fn hybrid_search_item_omite_fts_rank_quando_none() {
        let item = HybridSearchItem {
            memory_id: 1,
            name: "mem".to_string(),
            namespace: "default".to_string(),
            memory_type: "user".to_string(),
            description: "desc".to_string(),
            body: "conteúdo".to_string(),
            combined_score: 0.0328,
            score: 0.0328,
            source: "hybrid".to_string(),
            vec_rank: Some(1),
            fts_rank: None,
            rrf_score: Some(0.0328),
        };
        let json = to_json(&item);
        assert_eq!(json["vec_rank"], 1, "deve conter vec_rank quando Some");
        assert!(
            json.get("fts_rank").is_none(),
            "NÃO deve conter fts_rank quando None"
        );
    }

    #[test]
    fn hybrid_search_item_omite_vec_rank_quando_none() {
        let item = HybridSearchItem {
            memory_id: 2,
            name: "mem2".to_string(),
            namespace: "default".to_string(),
            memory_type: "fact".to_string(),
            description: "desc2".to_string(),
            body: "corpo2".to_string(),
            combined_score: 0.016,
            score: 0.016,
            source: "hybrid".to_string(),
            vec_rank: None,
            fts_rank: Some(2),
            rrf_score: Some(0.016),
        };
        let json = to_json(&item);
        assert!(
            json.get("vec_rank").is_none(),
            "NÃO deve conter vec_rank quando None"
        );
        assert_eq!(json["fts_rank"], 2, "deve conter fts_rank quando Some");
    }

    #[test]
    fn hybrid_search_item_serializa_ambos_ranks_quando_some() {
        let item = HybridSearchItem {
            memory_id: 3,
            name: "mem3".to_string(),
            namespace: "ns".to_string(),
            memory_type: "entity".to_string(),
            description: "desc3".to_string(),
            body: "corpo3".to_string(),
            combined_score: 0.05,
            score: 0.05,
            source: "hybrid".to_string(),
            vec_rank: Some(3),
            fts_rank: Some(1),
            rrf_score: Some(0.05),
        };
        let json = to_json(&item);
        assert_eq!(json["vec_rank"], 3, "deve conter vec_rank");
        assert_eq!(json["fts_rank"], 1, "deve conter fts_rank");
        assert_eq!(json["type"], "entity", "deve serializar type renomeado");
        assert!(
            json.get("memory_type").is_none(),
            "NÃO deve expor memory_type"
        );
        assert_eq!(json["source"], "hybrid");
        // The three score fields are aliases of one another in the JSON contract.
        assert_close(&json["combined_score"], 0.05);
        assert_close(&json["score"], 0.05);
        assert_close(&json["rrf_score"], 0.05);
    }

    #[test]
    fn hybrid_search_response_serializa_k_corretamente() {
        let resp = resposta_vazia(5, 60, 1.0, 1.0);
        let json = to_json(&resp);
        assert_eq!(json["k"], 5, "deve serializar k=5");
    }
}
370}