Skip to main content

sqlite_graphrag/commands/
hybrid_search.rs

1//! Handler for the `hybrid-search` CLI subcommand.
2
3use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::output::{self, JsonOutputFormat, RecallItem};
6use crate::paths::AppPaths;
7use crate::storage::connection::open_ro;
8use crate::storage::memories;
9
10use std::collections::HashMap;
11
/// Arguments for the `hybrid-search` subcommand.
///
/// When `--namespace` is omitted the search runs against the `global` namespace,
/// which is the default namespace used by `remember` when no `--namespace` flag
/// is provided. Pass an explicit `--namespace` value to search a different
/// isolated namespace.
//
// Field notes below deliberately use `//` instead of `///`: clap's derive macro
// turns field doc comments into `--help` text, so `///` would change CLI output.
#[derive(clap::Args)]
pub struct HybridSearchArgs {
    // Query text. `run` embeds it for the vector search and also passes it
    // unchanged to the full-text search.
    pub query: String,
    // Number of results to return. `run` asks each source for twice this many
    // candidates before fusing and truncating.
    #[arg(short = 'k', long, default_value = "10")]
    pub k: usize,
    // RRF constant: each hit contributes `weight / (rrf_k + rank + 1)` with a
    // 0-based rank.
    #[arg(long, default_value = "60")]
    pub rrf_k: u32,
    // Multiplier applied to contributions from the vector (KNN) result list.
    #[arg(long, default_value = "1.0")]
    pub weight_vec: f32,
    // Multiplier applied to contributions from the full-text (FTS) result list.
    #[arg(long, default_value = "1.0")]
    pub weight_fts: f32,
    /// Filter by memory.type. Note: distinct from graph entity_type
    /// (project/tool/person/file/concept/incident/decision/memory/dashboard/issue_tracker)
    /// used in --entities-file.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // Namespace to search; resolved through `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // NOTE(review): `with_graph`, `max_hops` and `min_weight` are not read by
    // `run` in this file; `graph_matches` is always emitted as an empty list.
    #[arg(long)]
    pub with_graph: bool,
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    // Accepted on the command line, but `run` discards the value
    // (`let _ = args.format;`).
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // Optional database path override; also read from `SQLITE_GRAPHRAG_DB_PATH`.
    // Passed to `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    /// Accept `--json` as a no-op because output is already JSON by default.
    #[arg(long, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
}
50
51#[derive(serde::Serialize)]
52pub struct HybridSearchItem {
53    pub memory_id: i64,
54    pub name: String,
55    pub namespace: String,
56    #[serde(rename = "type")]
57    pub memory_type: String,
58    pub description: String,
59    pub body: String,
60    pub combined_score: f64,
61    /// Alias de `combined_score` para contrato documentado em SKILL.md.
62    pub score: f64,
63    /// Fonte do match: sempre "hybrid" (RRF de vec + fts). Adicionado em v2.0.1.
64    pub source: String,
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub vec_rank: Option<usize>,
67    #[serde(skip_serializing_if = "Option::is_none")]
68    pub fts_rank: Option<usize>,
69    /// Combined RRF score — explicit alias of `combined_score` for integration contracts.
70    #[serde(skip_serializing_if = "Option::is_none")]
71    pub rrf_score: Option<f64>,
72}
73
74/// RRF weights used in hybrid search: vec (vector) and fts (text).
75#[derive(serde::Serialize)]
76pub struct Weights {
77    pub vec: f32,
78    pub fts: f32,
79}
80
/// JSON document emitted by `run` for the `hybrid-search` subcommand.
#[derive(serde::Serialize)]
pub struct HybridSearchResponse {
    /// Query text, echoed back from the CLI arguments.
    pub query: String,
    /// Requested result count, echoed back; `results` holds at most this many items.
    pub k: usize,
    /// RRF k parameter used in the combined ranking.
    pub rrf_k: u32,
    /// Weights applied to the vec and fts sources in the RRF.
    pub weights: Weights,
    /// Fused hits, ordered by descending combined score.
    pub results: Vec<HybridSearchItem>,
    /// Graph matches. NOTE(review): `run` in this file always emits an empty list.
    pub graph_matches: Vec<RecallItem>,
    /// Total execution time in milliseconds from handler start to serialisation.
    pub elapsed_ms: u64,
}
94
95pub fn run(args: HybridSearchArgs) -> Result<(), AppError> {
96    let start = std::time::Instant::now();
97    let _ = args.format;
98
99    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
100    let paths = AppPaths::resolve(args.db.as_deref())?;
101    if !paths.db.exists() {
102        return Err(AppError::NotFound(
103            crate::i18n::errors_msg::database_not_found(&paths.db.display().to_string()),
104        ));
105    }
106
107    output::emit_progress_i18n(
108        "Computing query embedding...",
109        "Calculando embedding da consulta...",
110    );
111    let embedding = crate::daemon::embed_query_or_local(&paths.models, &args.query)?;
112
113    let conn = open_ro(&paths.db)?;
114
115    let memory_type_str = args.r#type.map(|t| t.as_str());
116
117    let vec_results = memories::knn_search(
118        &conn,
119        &embedding,
120        &[namespace.clone()],
121        memory_type_str,
122        args.k * 2,
123    )?;
124
125    // Mapear posição de ranking vetorial por memory_id (1-indexed conforme schema)
126    let vec_rank_map: HashMap<i64, usize> = vec_results
127        .iter()
128        .enumerate()
129        .map(|(pos, (id, _))| (*id, pos + 1))
130        .collect();
131
132    let fts_results =
133        memories::fts_search(&conn, &args.query, &namespace, memory_type_str, args.k * 2)?;
134
135    // Mapear posição de ranking FTS por memory_id (1-indexed conforme schema)
136    let fts_rank_map: HashMap<i64, usize> = fts_results
137        .iter()
138        .enumerate()
139        .map(|(pos, row)| (row.id, pos + 1))
140        .collect();
141
142    let rrf_k = args.rrf_k as f64;
143
144    // Acumular scores RRF combinados
145    let mut combined_scores: HashMap<i64, f64> = HashMap::new();
146
147    for (rank, (memory_id, _)) in vec_results.iter().enumerate() {
148        let score = args.weight_vec as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
149        *combined_scores.entry(*memory_id).or_insert(0.0) += score;
150    }
151
152    for (rank, row) in fts_results.iter().enumerate() {
153        let score = args.weight_fts as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
154        *combined_scores.entry(row.id).or_insert(0.0) += score;
155    }
156
157    // Ordenar por score descendente e tomar os top-k
158    let mut ranked: Vec<(i64, f64)> = combined_scores.into_iter().collect();
159    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
160    ranked.truncate(args.k);
161
162    // Coletar todos os IDs para busca em batch (evitar N+1)
163    let top_ids: Vec<i64> = ranked.iter().map(|(id, _)| *id).collect();
164
165    // Buscar dados completos das memórias top
166    let mut memory_data: HashMap<i64, memories::MemoryRow> = HashMap::new();
167    for id in &top_ids {
168        if let Some(row) = memories::read_full(&conn, *id)? {
169            memory_data.insert(*id, row);
170        }
171    }
172
173    // Construir resultados finais na ordem de ranking
174    let results: Vec<HybridSearchItem> = ranked
175        .into_iter()
176        .filter_map(|(memory_id, combined_score)| {
177            memory_data.remove(&memory_id).map(|row| HybridSearchItem {
178                memory_id: row.id,
179                name: row.name,
180                namespace: row.namespace,
181                memory_type: row.memory_type,
182                description: row.description,
183                body: row.body,
184                combined_score,
185                score: combined_score,
186                source: "hybrid".to_string(),
187                vec_rank: vec_rank_map.get(&memory_id).copied(),
188                fts_rank: fts_rank_map.get(&memory_id).copied(),
189                rrf_score: Some(combined_score),
190            })
191        })
192        .collect();
193
194    output::emit_json(&HybridSearchResponse {
195        query: args.query,
196        k: args.k,
197        rrf_k: args.rrf_k,
198        weights: Weights {
199            vec: args.weight_vec,
200            fts: args.weight_fts,
201        },
202        results,
203        graph_matches: vec![],
204        elapsed_ms: start.elapsed().as_millis() as u64,
205    })?;
206
207    Ok(())
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a response with no results and the given ranking parameters.
    fn resposta_vazia(
        k: usize,
        rrf_k: u32,
        weight_vec: f32,
        weight_fts: f32,
    ) -> HybridSearchResponse {
        let weights = Weights {
            vec: weight_vec,
            fts: weight_fts,
        };
        HybridSearchResponse {
            query: String::from("busca teste"),
            k,
            rrf_k,
            weights,
            results: Vec::new(),
            graph_matches: Vec::new(),
            elapsed_ms: 0,
        }
    }

    /// Serialises `value` to a compact JSON string, panicking on failure.
    #[track_caller]
    fn to_json<T: serde::Serialize>(value: &T) -> String {
        serde_json::to_string(value).unwrap()
    }

    /// Fails with `msg` unless `needle` occurs in `json`.
    #[track_caller]
    fn assert_present(json: &str, needle: &str, msg: &str) {
        assert!(json.contains(needle), "{}", msg);
    }

    /// Fails with `msg` if `needle` occurs in `json`.
    #[track_caller]
    fn assert_absent(json: &str, needle: &str, msg: &str) {
        assert!(!json.contains(needle), "{}", msg);
    }

    #[test]
    fn hybrid_search_response_vazia_serializa_campos_corretos() {
        let json = to_json(&resposta_vazia(10, 60, 1.0, 1.0));

        let expected = [
            ("\"results\"", "deve conter campo results"),
            ("\"query\"", "deve conter campo query"),
            ("\"k\"", "deve conter campo k"),
            ("\"graph_matches\"", "deve conter campo graph_matches"),
        ];
        for (needle, msg) in expected {
            assert_present(&json, needle, msg);
        }

        let forbidden = [
            ("\"combined_rank\"", "NÃO deve conter combined_rank"),
            ("\"vec_rank_list\"", "NÃO deve conter vec_rank_list"),
            ("\"fts_rank_list\"", "NÃO deve conter fts_rank_list"),
        ];
        for (needle, msg) in forbidden {
            assert_absent(&json, needle, msg);
        }
    }

    #[test]
    fn hybrid_search_response_serializa_rrf_k_e_weights() {
        let json = to_json(&resposta_vazia(5, 60, 0.7, 0.3));

        let expected = [
            ("\"rrf_k\"", "deve conter campo rrf_k"),
            ("\"weights\"", "deve conter campo weights"),
            ("\"vec\"", "deve conter campo weights.vec"),
            ("\"fts\"", "deve conter campo weights.fts"),
        ];
        for (needle, msg) in expected {
            assert_present(&json, needle, msg);
        }
    }

    #[test]
    fn hybrid_search_response_serializa_elapsed_ms() {
        let resp = HybridSearchResponse {
            elapsed_ms: 123,
            ..resposta_vazia(5, 60, 1.0, 1.0)
        };
        let json = to_json(&resp);

        assert_present(&json, "\"elapsed_ms\"", "deve conter campo elapsed_ms");
        assert_present(&json, "123", "deve serializar valor de elapsed_ms");
    }

    #[test]
    fn weights_struct_serializa_corretamente() {
        let w = Weights { vec: 0.6, fts: 0.4 };
        let json = to_json(&w);
        assert!(json.contains("\"vec\""));
        assert!(json.contains("\"fts\""));
    }

    #[test]
    fn hybrid_search_item_omite_fts_rank_quando_none() {
        let json = to_json(&HybridSearchItem {
            memory_id: 1,
            name: String::from("mem"),
            namespace: String::from("default"),
            memory_type: String::from("user"),
            description: String::from("desc"),
            body: String::from("conteúdo"),
            combined_score: 0.0328,
            score: 0.0328,
            source: String::from("hybrid"),
            vec_rank: Some(1),
            fts_rank: None,
            rrf_score: Some(0.0328),
        });

        assert_present(&json, "\"vec_rank\"", "deve conter vec_rank quando Some");
        assert_absent(
            &json,
            "\"fts_rank\"",
            "NÃO deve conter fts_rank quando None",
        );
    }

    #[test]
    fn hybrid_search_item_omite_vec_rank_quando_none() {
        let json = to_json(&HybridSearchItem {
            memory_id: 2,
            name: String::from("mem2"),
            namespace: String::from("default"),
            memory_type: String::from("fact"),
            description: String::from("desc2"),
            body: String::from("corpo2"),
            combined_score: 0.016,
            score: 0.016,
            source: String::from("hybrid"),
            vec_rank: None,
            fts_rank: Some(2),
            rrf_score: Some(0.016),
        });

        assert_absent(
            &json,
            "\"vec_rank\"",
            "NÃO deve conter vec_rank quando None",
        );
        assert_present(&json, "\"fts_rank\"", "deve conter fts_rank quando Some");
    }

    #[test]
    fn hybrid_search_item_serializa_ambos_ranks_quando_some() {
        let json = to_json(&HybridSearchItem {
            memory_id: 3,
            name: String::from("mem3"),
            namespace: String::from("ns"),
            memory_type: String::from("entity"),
            description: String::from("desc3"),
            body: String::from("corpo3"),
            combined_score: 0.05,
            score: 0.05,
            source: String::from("hybrid"),
            vec_rank: Some(3),
            fts_rank: Some(1),
            rrf_score: Some(0.05),
        });

        assert_present(&json, "\"vec_rank\"", "deve conter vec_rank");
        assert_present(&json, "\"fts_rank\"", "deve conter fts_rank");
        assert_present(&json, "\"type\"", "deve serializar type renomeado");
        assert_absent(&json, "memory_type", "NÃO deve expor memory_type");
    }

    #[test]
    fn hybrid_search_response_serializa_k_corretamente() {
        let json = to_json(&resposta_vazia(5, 60, 1.0, 1.0));
        assert_present(&json, "\"k\":5", "deve serializar k=5");
    }
}