Skip to main content

sqlite_graphrag/commands/
hybrid_search.rs

1//! Handler for the `hybrid-search` CLI subcommand.
2
3use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::output::{self, JsonOutputFormat, RecallItem};
6use crate::paths::AppPaths;
7use crate::storage::connection::open_ro;
8use crate::storage::memories;
9
10use std::collections::HashMap;
11
12/// Arguments for the `hybrid-search` subcommand.
13///
14/// When `--namespace` is omitted the search runs against the `global` namespace,
15/// which is the default namespace used by `remember` when no `--namespace` flag
16/// is provided. Pass an explicit `--namespace` value to search a different
17/// isolated namespace.
18#[derive(clap::Args)]
19pub struct HybridSearchArgs {
20    #[arg(help = "Hybrid search query (vector KNN + FTS5 BM25 fused via RRF)")]
21    pub query: String,
22    /// Maximum number of fused results to return after RRF combines vector + FTS5 candidates.
23    ///
24    /// Validated to the inclusive range `1..=4096` (the upper bound matches `sqlite-vec`'s knn
25    /// limit). Each underlying search fetches `k * 2` candidates before fusion.
26    #[arg(short = 'k', long, alias = "limit", default_value = "10", value_parser = crate::parsers::parse_k_range)]
27    pub k: usize,
28    #[arg(long, default_value = "60")]
29    pub rrf_k: u32,
30    #[arg(long, default_value = "1.0")]
31    pub weight_vec: f32,
32    #[arg(long, default_value = "1.0")]
33    pub weight_fts: f32,
34    /// Filter by memory.type. Note: distinct from graph entity_type
35    /// (project/tool/person/file/concept/incident/decision/memory/dashboard/issue_tracker/organization/location/date)
36    /// used in --entities-file.
37    #[arg(long, value_enum)]
38    pub r#type: Option<MemoryType>,
39    #[arg(long)]
40    pub namespace: Option<String>,
41    #[arg(long)]
42    pub with_graph: bool,
43    #[arg(long, default_value = "2")]
44    pub max_hops: u32,
45    #[arg(long, default_value = "0.3")]
46    pub min_weight: f64,
47    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
48    pub format: JsonOutputFormat,
49    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
50    pub db: Option<String>,
51    /// Accept `--json` as a no-op because output is already JSON by default.
52    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
53    pub json: bool,
54    #[command(flatten)]
55    pub daemon: crate::cli::DaemonOpts,
56}
57
/// One fused search hit, serialised as an element of `HybridSearchResponse::results`.
///
/// `combined_score`, `score` and `rrf_score` all carry the same RRF value; the duplicates exist
/// so that different integration contracts can read the field name they expect.
#[derive(serde::Serialize)]
pub struct HybridSearchItem {
    /// Identifier of the matched memory row.
    pub memory_id: i64,
    /// Name of the matched memory.
    pub name: String,
    /// Namespace the memory row belongs to.
    pub namespace: String,
    /// Memory type; emitted under the JSON key `type`.
    #[serde(rename = "type")]
    pub memory_type: String,
    /// Description stored on the memory row.
    pub description: String,
    /// Full body stored on the memory row.
    pub body: String,
    /// Weighted RRF score summed over the vector and FTS rankings.
    pub combined_score: f64,
    /// Alias of `combined_score` kept for the contract documented in SKILL.md.
    pub score: f64,
    /// Match source: always "hybrid" (RRF of vec + fts). Added in v2.0.1.
    pub source: String,
    /// 1-indexed position in the vector candidate list; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vec_rank: Option<usize>,
    /// 1-indexed position in the FTS candidate list; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fts_rank: Option<usize>,
    /// Combined RRF score — explicit alias of `combined_score` for integration contracts.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rrf_score: Option<f64>,
}
80
/// RRF weights used in hybrid search: vec (vector) and fts (text).
///
/// Serialised as the `weights` object of `HybridSearchResponse` so callers can see which
/// multipliers produced the reported scores.
#[derive(serde::Serialize)]
pub struct Weights {
    /// Multiplier applied to the vector ranking contribution.
    pub vec: f32,
    /// Multiplier applied to the FTS ranking contribution.
    pub fts: f32,
}
87
/// Top-level JSON document emitted by the `hybrid-search` subcommand.
#[derive(serde::Serialize)]
pub struct HybridSearchResponse {
    /// Query string echoed back from the CLI arguments.
    pub query: String,
    /// Requested maximum number of fused results.
    pub k: usize,
    /// RRF k parameter used in the combined ranking.
    pub rrf_k: u32,
    /// Weights applied to the vec and fts sources in the RRF.
    pub weights: Weights,
    /// Fused hits in ranking order.
    pub results: Vec<HybridSearchItem>,
    /// Graph-derived matches; the `run` handler in this file always emits an empty list.
    pub graph_matches: Vec<RecallItem>,
    /// Total execution time in milliseconds from handler start to serialisation.
    pub elapsed_ms: u64,
}
101
102pub fn run(args: HybridSearchArgs) -> Result<(), AppError> {
103    let start = std::time::Instant::now();
104    let _ = args.format;
105
106    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
107    let paths = AppPaths::resolve(args.db.as_deref())?;
108    crate::storage::connection::ensure_db_ready(&paths)?;
109
110    output::emit_progress_i18n(
111        "Computing query embedding...",
112        "Calculando embedding da consulta...",
113    );
114    let embedding = crate::daemon::embed_query_or_local(
115        &paths.models,
116        &args.query,
117        args.daemon.autostart_daemon,
118    )?;
119
120    let conn = open_ro(&paths.db)?;
121
122    let memory_type_str = args.r#type.map(|t| t.as_str());
123
124    let vec_results = memories::knn_search(
125        &conn,
126        &embedding,
127        &[namespace.clone()],
128        memory_type_str,
129        args.k * 2,
130    )?;
131
132    // Map vector ranking position by memory_id (1-indexed per schema)
133    let vec_rank_map: HashMap<i64, usize> = vec_results
134        .iter()
135        .enumerate()
136        .map(|(pos, (id, _))| (*id, pos + 1))
137        .collect();
138
139    let fts_results =
140        memories::fts_search(&conn, &args.query, &namespace, memory_type_str, args.k * 2)?;
141
142    // Map FTS ranking position by memory_id (1-indexed per schema)
143    let fts_rank_map: HashMap<i64, usize> = fts_results
144        .iter()
145        .enumerate()
146        .map(|(pos, row)| (row.id, pos + 1))
147        .collect();
148
149    let rrf_k = args.rrf_k as f64;
150
151    // Accumulate combined RRF scores
152    let mut combined_scores: HashMap<i64, f64> = HashMap::new();
153
154    for (rank, (memory_id, _)) in vec_results.iter().enumerate() {
155        let score = args.weight_vec as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
156        *combined_scores.entry(*memory_id).or_insert(0.0) += score;
157    }
158
159    for (rank, row) in fts_results.iter().enumerate() {
160        let score = args.weight_fts as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
161        *combined_scores.entry(row.id).or_insert(0.0) += score;
162    }
163
164    // Sort by score descending and take the top-k
165    let mut ranked: Vec<(i64, f64)> = combined_scores.into_iter().collect();
166    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
167    ranked.truncate(args.k);
168
169    // Collect all IDs for batch fetch (avoiding N+1)
170    let top_ids: Vec<i64> = ranked.iter().map(|(id, _)| *id).collect();
171
172    // Fetch full data for the top memories
173    let mut memory_data: HashMap<i64, memories::MemoryRow> = HashMap::new();
174    for id in &top_ids {
175        if let Some(row) = memories::read_full(&conn, *id)? {
176            memory_data.insert(*id, row);
177        }
178    }
179
180    // Construir resultados finais na ordem de ranking
181    let results: Vec<HybridSearchItem> = ranked
182        .into_iter()
183        .filter_map(|(memory_id, combined_score)| {
184            memory_data.remove(&memory_id).map(|row| HybridSearchItem {
185                memory_id: row.id,
186                name: row.name,
187                namespace: row.namespace,
188                memory_type: row.memory_type,
189                description: row.description,
190                body: row.body,
191                combined_score,
192                score: combined_score,
193                source: "hybrid".to_string(),
194                vec_rank: vec_rank_map.get(&memory_id).copied(),
195                fts_rank: fts_rank_map.get(&memory_id).copied(),
196                rrf_score: Some(combined_score),
197            })
198        })
199        .collect();
200
201    output::emit_json(&HybridSearchResponse {
202        query: args.query,
203        k: args.k,
204        rrf_k: args.rrf_k,
205        weights: Weights {
206            vec: args.weight_vec,
207            fts: args.weight_fts,
208        },
209        results,
210        graph_matches: vec![],
211        elapsed_ms: start.elapsed().as_millis() as u64,
212    })?;
213
214    Ok(())
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialises any value to a JSON string, panicking on failure (test-only helper).
    fn to_json<T: serde::Serialize>(value: &T) -> String {
        serde_json::to_string(value).unwrap()
    }

    /// Builds a response with no results so that only the envelope fields are exercised.
    fn blank_response(
        k: usize,
        rrf_k: u32,
        weight_vec: f32,
        weight_fts: f32,
    ) -> HybridSearchResponse {
        let weights = Weights {
            vec: weight_vec,
            fts: weight_fts,
        };
        HybridSearchResponse {
            query: String::from("busca teste"),
            k,
            rrf_k,
            weights,
            results: Vec::new(),
            graph_matches: Vec::new(),
            elapsed_ms: 0,
        }
    }

    // The envelope must expose the documented keys and none of the legacy ones.
    #[test]
    fn hybrid_search_response_empty_serializes_correct_fields() {
        let json = to_json(&blank_response(10, 60, 1.0, 1.0));

        let expected = [
            ("\"results\"", "must contain results field"),
            ("\"query\"", "must contain query field"),
            ("\"k\"", "must contain k field"),
            ("\"graph_matches\"", "must contain graph_matches field"),
        ];
        for &(needle, message) in expected.iter() {
            assert!(json.contains(needle), "{}", message);
        }

        let forbidden = [
            ("\"combined_rank\"", "must not contain combined_rank"),
            ("\"vec_rank_list\"", "must not contain vec_rank_list"),
            ("\"fts_rank_list\"", "must not contain fts_rank_list"),
        ];
        for &(needle, message) in forbidden.iter() {
            assert!(!json.contains(needle), "{}", message);
        }
    }

    // The RRF parameters must be echoed back in the envelope.
    #[test]
    fn hybrid_search_response_serializes_rrf_k_and_weights() {
        let json = to_json(&blank_response(5, 60, 0.7, 0.3));

        let expected = [
            ("\"rrf_k\"", "must contain rrf_k field"),
            ("\"weights\"", "must contain weights field"),
            ("\"vec\"", "must contain weights.vec field"),
            ("\"fts\"", "must contain weights.fts field"),
        ];
        for &(needle, message) in expected.iter() {
            assert!(json.contains(needle), "{}", message);
        }
    }

    // The elapsed time must be serialised with both its key and its value.
    #[test]
    fn hybrid_search_response_serializes_elapsed_ms() {
        let mut response = blank_response(5, 60, 1.0, 1.0);
        response.elapsed_ms = 123;
        let json = to_json(&response);

        assert!(
            json.contains("\"elapsed_ms\""),
            "must contain elapsed_ms field"
        );
        assert!(json.contains("123"), "deve serializar valor de elapsed_ms");
    }

    // `Weights` on its own must expose both keys.
    #[test]
    fn weights_struct_serializes_correctly() {
        let json = to_json(&Weights { vec: 0.6, fts: 0.4 });
        assert!(json.contains("\"vec\""));
        assert!(json.contains("\"fts\""));
    }

    // A hit found only by the vector search must not carry an `fts_rank` key.
    #[test]
    fn hybrid_search_item_omits_fts_rank_when_none() {
        let vector_only = HybridSearchItem {
            memory_id: 1,
            name: String::from("mem"),
            namespace: String::from("default"),
            memory_type: String::from("user"),
            description: String::from("desc"),
            body: String::from("content"),
            combined_score: 0.0328,
            score: 0.0328,
            source: String::from("hybrid"),
            vec_rank: Some(1),
            fts_rank: None,
            rrf_score: Some(0.0328),
        };
        let json = to_json(&vector_only);

        assert!(
            json.contains("\"vec_rank\""),
            "must contain vec_rank when Some"
        );
        assert!(
            !json.contains("\"fts_rank\""),
            "must not contain fts_rank when None"
        );
    }

    // A hit found only by FTS must not carry a `vec_rank` key.
    #[test]
    fn hybrid_search_item_omits_vec_rank_when_none() {
        let text_only = HybridSearchItem {
            memory_id: 2,
            name: String::from("mem2"),
            namespace: String::from("default"),
            memory_type: String::from("fact"),
            description: String::from("desc2"),
            body: String::from("corpo2"),
            combined_score: 0.016,
            score: 0.016,
            source: String::from("hybrid"),
            vec_rank: None,
            fts_rank: Some(2),
            rrf_score: Some(0.016),
        };
        let json = to_json(&text_only);

        assert!(
            !json.contains("\"vec_rank\""),
            "must not contain vec_rank when None"
        );
        assert!(
            json.contains("\"fts_rank\""),
            "must contain fts_rank when Some"
        );
    }

    // A hit found by both searches carries both ranks, and `memory_type` is renamed to `type`.
    #[test]
    fn hybrid_search_item_serializes_both_ranks_when_some() {
        let both_sources = HybridSearchItem {
            memory_id: 3,
            name: String::from("mem3"),
            namespace: String::from("ns"),
            memory_type: String::from("entity"),
            description: String::from("desc3"),
            body: String::from("corpo3"),
            combined_score: 0.05,
            score: 0.05,
            source: String::from("hybrid"),
            vec_rank: Some(3),
            fts_rank: Some(1),
            rrf_score: Some(0.05),
        };
        let json = to_json(&both_sources);

        assert!(json.contains("\"vec_rank\""), "must contain vec_rank");
        assert!(json.contains("\"fts_rank\""), "must contain fts_rank");
        assert!(json.contains("\"type\""), "deve serializar type renomeado");
        assert!(!json.contains("memory_type"), "must not expose memory_type");
    }

    // `k` must be serialised as a bare integer under the key `k`.
    #[test]
    fn hybrid_search_response_serializes_k_correctly() {
        let json = to_json(&blank_response(5, 60, 1.0, 1.0));
        assert!(json.contains("\"k\":5"), "deve serializar k=5");
    }
}