Skip to main content

sqlite_graphrag/commands/
hybrid_search.rs

1//! Handler for the `hybrid-search` CLI subcommand.
2
3use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::graph::traverse_from_memories_with_hops;
6use crate::output::{self, JsonOutputFormat, RecallItem};
7use crate::paths::AppPaths;
8use crate::storage::connection::open_ro;
9use crate::storage::entities;
10use crate::storage::memories;
11
12use std::collections::HashMap;
13
14/// Arguments for the `hybrid-search` subcommand.
15///
16/// When `--namespace` is omitted the search runs against the `global` namespace,
17/// which is the default namespace used by `remember` when no `--namespace` flag
18/// is provided. Pass an explicit `--namespace` value to search a different
19/// isolated namespace.
20#[derive(clap::Args)]
21#[command(after_long_help = "EXAMPLES:\n  \
22    # Basic hybrid search combining FTS5 + vector via RRF\n  \
23    sqlite-graphrag hybrid-search \"postgres migration deadlock\" --k 10\n\n  \
24    # Tune RRF weights to favor keyword matches over semantic similarity\n  \
25    sqlite-graphrag hybrid-search \"jwt auth\" --weight-fts 1.5 --weight-vec 0.5 --k 5\n\n  \
26    # Add graph traversal matches (entities connected to top results)\n  \
27    sqlite-graphrag hybrid-search \"frontend architecture\" --with-graph --k 10\n\n  \
28    # Graph traversal with custom depth and minimum edge weight\n  \
29    sqlite-graphrag hybrid-search \"auth design\" --with-graph --max-hops 3 --min-weight 0.5 --k 10\n\n  \
30NOTES:\n  \
31    --with-graph enables entity graph traversal seeded by the top RRF results.\n  \
32    Graph matches appear in the `graph_matches` array (separate from `results`).\n  \
33    Without --with-graph, `graph_matches` is always empty.")]
34pub struct HybridSearchArgs {
35    #[arg(
36        allow_hyphen_values = true,
37        help = "Hybrid search query (vector KNN + FTS5 BM25 fused via RRF)"
38    )]
39    pub query: String,
40    /// Maximum number of fused results to return after RRF combines vector + FTS5 candidates.
41    ///
42    /// Validated to the inclusive range `1..=4096` (the upper bound matches `sqlite-vec`'s knn
43    /// limit). Each underlying search fetches `k * 2` candidates before fusion.
44    #[arg(short = 'k', long, aliases = ["limit", "top-k"], default_value = "10", value_parser = crate::parsers::parse_k_range)]
45    pub k: usize,
46    #[arg(long, default_value = "60")]
47    pub rrf_k: u32,
48    #[arg(long, default_value = "1.0")]
49    pub weight_vec: f32,
50    #[arg(long, default_value = "1.0")]
51    pub weight_fts: f32,
52    /// Filter by memory.type. Note: distinct from graph entity_type
53    /// (project/tool/person/file/concept/incident/decision/memory/dashboard/issue_tracker/organization/location/date)
54    /// used in --entities-file.
55    #[arg(long, value_enum)]
56    pub r#type: Option<MemoryType>,
57    #[arg(long)]
58    pub namespace: Option<String>,
59    #[arg(long)]
60    pub with_graph: bool,
61    #[arg(long, default_value = "2")]
62    pub max_hops: u32,
63    #[arg(long, default_value = "0.3")]
64    pub min_weight: f64,
65    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
66    pub format: JsonOutputFormat,
67    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
68    pub db: Option<String>,
69    /// Accept `--json` as a no-op because output is already JSON by default.
70    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
71    pub json: bool,
72}
73
74#[derive(serde::Serialize)]
75pub struct HybridSearchItem {
76    pub memory_id: i64,
77    pub name: String,
78    pub namespace: String,
79    #[serde(rename = "type")]
80    pub memory_type: String,
81    pub description: String,
82    pub body: String,
83    pub snippet: String,
84    pub combined_score: f64,
85    /// Alias of `combined_score` for the documented contract in SKILL.md.
86    pub score: f64,
87    /// Source of the match: always "hybrid" (RRF of vec + fts). Added in v2.0.1.
88    pub source: String,
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub vec_rank: Option<usize>,
91    #[serde(skip_serializing_if = "Option::is_none")]
92    pub fts_rank: Option<usize>,
93    /// Combined RRF score — explicit alias of `combined_score` for integration contracts.
94    #[serde(skip_serializing_if = "Option::is_none")]
95    pub rrf_score: Option<f64>,
96    /// RRF score normalized to [0.0, 1.0] for cross-method comparability.
97    pub normalized_score: f64,
98    /// Raw KNN distance from the vector index (lower = more similar).
99    ///
100    /// Present when the result came from the vector search path; `None` when the
101    /// result appeared only in the FTS5 results and was not ranked by the KNN index.
102    #[serde(skip_serializing_if = "Option::is_none")]
103    pub vec_distance: Option<f64>,
104    /// Raw BM25 score from the FTS5 index. Currently always `None`; reserved for
105    /// a future release when the FTS5 BM25 score is exposed by the storage layer.
106    #[serde(skip_serializing_if = "Option::is_none")]
107    pub fts_bm25: Option<f64>,
108}
109
110/// RRF weights used in hybrid search: vec (vector) and fts (text).
111#[derive(serde::Serialize)]
112pub struct Weights {
113    pub vec: f32,
114    pub fts: f32,
115}
116
/// JSON envelope emitted by the `hybrid-search` command.
///
/// Field declaration order is the JSON key order, so keep it stable.
#[derive(serde::Serialize)]
pub struct HybridSearchResponse {
    /// The query string as passed on the command line.
    pub query: String,
    /// Requested maximum number of fused results (`--k`).
    pub k: usize,
    /// RRF k parameter used in the combined ranking.
    pub rrf_k: u32,
    /// Weights applied to vec and fts sources in the RRF fusion.
    pub weights: Weights,
    /// Fused results, highest combined score first; at most `k` entries.
    pub results: Vec<HybridSearchItem>,
    /// Memories reached by entity-graph traversal that are not already in
    /// `results`; always empty unless `--with-graph` is set.
    pub graph_matches: Vec<RecallItem>,
    /// True when FTS5 failed and the response is vec-only.
    ///
    /// Omitted from JSON when `false` to keep the happy-path envelope clean.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub fts_degraded: bool,
    /// Human-readable description of the FTS5 failure when `fts_degraded` is true.
    ///
    /// Omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fts_error: Option<String>,
    /// True when the FTS5 index looked corrupted and the FTS5 `rebuild` command
    /// succeeded during this request. The retried FTS5 query may still have
    /// failed, in which case `fts_degraded` is also true.
    ///
    /// Omitted from JSON when `false` to keep the happy-path envelope clean.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub fts_auto_rebuilt: bool,
    /// Total execution time in milliseconds from handler start to serialisation.
    pub elapsed_ms: u64,
}
145
146#[tracing::instrument(skip_all, level = "debug", name = "hybrid_search")]
147pub fn run(args: HybridSearchArgs) -> Result<(), AppError> {
148    let start = std::time::Instant::now();
149    let _ = args.format;
150    tracing::debug!(target: "hybrid_search", query = %args.query, k = args.k, "fusing results");
151
152    // G20: reject graph-specific flags when --with-graph is not active
153    if !args.with_graph {
154        if args.max_hops != 2 {
155            return Err(AppError::Validation(
156                "--max-hops requires --with-graph to be active".to_string(),
157            ));
158        }
159        if (args.min_weight - 0.3).abs() > f64::EPSILON {
160            return Err(AppError::Validation(
161                "--min-weight requires --with-graph to be active".to_string(),
162            ));
163        }
164    }
165
166    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
167    let paths = AppPaths::resolve(args.db.as_deref())?;
168    crate::storage::connection::ensure_db_ready(&paths)?;
169
170    output::emit_progress_i18n(
171        "Computing query embedding...",
172        "Calculando embedding da consulta...",
173    );
174    let embedding = crate::embedder::embed_query_local(&paths.models, &args.query)?;
175
176    let conn = open_ro(&paths.db)?;
177
178    let memory_type_str = args.r#type.map(|t| t.as_str());
179
180    let vec_results = memories::knn_search(
181        &conn,
182        &embedding,
183        &[namespace.clone()],
184        memory_type_str,
185        args.k * 2,
186    )?;
187
188    // Map vector ranking position by memory_id (1-indexed per schema)
189    let vec_rank_map: HashMap<i64, usize> = vec_results
190        .iter()
191        .enumerate()
192        .map(|(pos, (id, _))| (*id, pos + 1))
193        .collect();
194
195    // Map raw KNN distance by memory_id for GAP-30: vec_distance field.
196    let vec_distance_map: HashMap<i64, f64> = vec_results
197        .iter()
198        .map(|(id, dist)| (*id, *dist as f64))
199        .collect();
200
201    let (fts_results, fts_degraded, fts_error, fts_auto_rebuilt) = if args.weight_fts == 0.0 {
202        (vec![], false, None, false)
203    } else {
204        match memories::fts_search(&conn, &args.query, &namespace, memory_type_str, args.k * 2) {
205            Ok(r) => (r, false, None, false),
206            Err(e) => {
207                let err_msg = e.to_string();
208                let is_malformed = err_msg.contains("malformed") || err_msg.contains("corrupt");
209                if is_malformed {
210                    tracing::warn!(target: "hybrid_search", "FTS5 index corrupted, attempting auto-rebuild");
211                    if conn
212                        .execute_batch("INSERT INTO fts_memories(fts_memories) VALUES('rebuild');")
213                        .is_ok()
214                    {
215                        match memories::fts_search(
216                            &conn,
217                            &args.query,
218                            &namespace,
219                            memory_type_str,
220                            args.k * 2,
221                        ) {
222                            Ok(r) => (r, false, None, true),
223                            Err(e2) => {
224                                tracing::error!(target: "hybrid_search", error = %e2, "FTS5 auto-rebuild failed to recover");
225                                (vec![], true, Some(e2.to_string()), true)
226                            }
227                        }
228                    } else {
229                        (vec![], true, Some(err_msg), false)
230                    }
231                } else {
232                    tracing::warn!(target: "hybrid_search", error = %e, "FTS5 query failed, falling back to vec-only");
233                    (vec![], true, Some(err_msg), false)
234                }
235            }
236        }
237    };
238
239    // Map FTS ranking position by memory_id (1-indexed per schema)
240    let fts_rank_map: HashMap<i64, usize> = fts_results
241        .iter()
242        .enumerate()
243        .map(|(pos, row)| (row.id, pos + 1))
244        .collect();
245
246    let rrf_k = args.rrf_k as f64;
247
248    // Accumulate combined RRF scores
249    let mut combined_scores: crate::hash::AHashMap<i64, f64> =
250        crate::hash::AHashMap::with_capacity_and_hasher(
251            vec_results.len() + fts_results.len(),
252            Default::default(),
253        );
254
255    for (rank, (memory_id, _)) in vec_results.iter().enumerate() {
256        let score = args.weight_vec as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
257        *combined_scores.entry(*memory_id).or_insert(0.0) += score;
258    }
259
260    for (rank, row) in fts_results.iter().enumerate() {
261        let score = args.weight_fts as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
262        *combined_scores.entry(row.id).or_insert(0.0) += score;
263    }
264
265    // Sort by score descending and take the top-k
266    let mut ranked: Vec<(i64, f64)> = combined_scores.into_iter().collect();
267    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
268    ranked.truncate(args.k);
269
270    // Collect all IDs for batch fetch (avoiding N+1)
271    let top_ids: Vec<i64> = ranked.iter().map(|(id, _)| *id).collect();
272
273    // Fetch full data for the top memories
274    let mut memory_data: crate::hash::AHashMap<i64, memories::MemoryRow> =
275        crate::hash::AHashMap::with_capacity_and_hasher(ranked.len(), Default::default());
276    for id in &top_ids {
277        if let Some(row) = memories::read_full(&conn, *id)? {
278            memory_data.insert(*id, row);
279        }
280    }
281
282    let max_possible = args.weight_vec as f64 * (1.0 / (rrf_k + 1.0))
283        + args.weight_fts as f64 * (1.0 / (rrf_k + 1.0));
284
285    // Build final results in ranking order
286    let results: Vec<HybridSearchItem> = ranked
287        .into_iter()
288        .filter_map(|(memory_id, combined_score)| {
289            let normalized_score = if max_possible > 0.0 {
290                combined_score / max_possible
291            } else {
292                0.0
293            };
294            memory_data.remove(&memory_id).map(|row| {
295                let snippet: String = row.body.chars().take(300).collect();
296                HybridSearchItem {
297                    memory_id: row.id,
298                    name: row.name,
299                    namespace: row.namespace,
300                    memory_type: row.memory_type,
301                    description: row.description,
302                    body: row.body,
303                    snippet,
304                    combined_score,
305                    score: combined_score,
306                    source: "hybrid".to_string(),
307                    vec_rank: vec_rank_map.get(&memory_id).copied(),
308                    fts_rank: fts_rank_map.get(&memory_id).copied(),
309                    rrf_score: Some(combined_score),
310                    normalized_score,
311                    vec_distance: vec_distance_map.get(&memory_id).copied(),
312                    fts_bm25: None,
313                }
314            })
315        })
316        .collect();
317
318    // --- Graph traversal (activated by --with-graph) ---
319    let mut graph_matches: Vec<RecallItem> = Vec::with_capacity(8);
320    if args.with_graph && !results.is_empty() {
321        let namespace_for_graph = namespace.clone();
322        let memory_ids: Vec<i64> = results.iter().map(|r| r.memory_id).collect();
323
324        let entity_knn = entities::knn_search(&conn, &embedding, &namespace_for_graph, 5)?;
325        let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
326
327        let all_seed_ids: Vec<i64> = memory_ids
328            .iter()
329            .chain(entity_ids.iter())
330            .copied()
331            .collect();
332
333        if !all_seed_ids.is_empty() {
334            let graph_memory_ids = traverse_from_memories_with_hops(
335                &conn,
336                &all_seed_ids,
337                &namespace_for_graph,
338                args.min_weight,
339                args.max_hops,
340            )?;
341
342            let already_in_results: std::collections::HashSet<i64> =
343                results.iter().map(|r| r.memory_id).collect();
344
345            for (graph_mem_id, hop) in graph_memory_ids {
346                if already_in_results.contains(&graph_mem_id) {
347                    continue;
348                }
349                if let Some(row) = memories::read_full(&conn, graph_mem_id)? {
350                    let snippet: String = row.body.chars().take(300).collect();
351                    let graph_distance = 1.0 - 1.0 / (hop as f32 + 1.0);
352                    graph_matches.push(RecallItem {
353                        memory_id: row.id,
354                        name: row.name,
355                        namespace: row.namespace,
356                        memory_type: row.memory_type,
357                        description: row.description,
358                        snippet,
359                        distance: graph_distance,
360                        score: RecallItem::score_from_distance(graph_distance),
361                        source: "graph".to_string(),
362                        graph_depth: Some(hop),
363                    });
364                }
365            }
366        }
367    }
368
369    output::emit_json(&HybridSearchResponse {
370        query: args.query,
371        k: args.k,
372        rrf_k: args.rrf_k,
373        weights: Weights {
374            vec: args.weight_vec,
375            fts: args.weight_fts,
376        },
377        results,
378        graph_matches,
379        fts_degraded,
380        fts_error,
381        fts_auto_rebuilt,
382        elapsed_ms: start.elapsed().as_millis() as u64,
383    })?;
384
385    Ok(())
386}
387
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a response with no results or graph matches and a healthy FTS5 state.
    fn empty_response(
        k: usize,
        rrf_k: u32,
        weight_vec: f32,
        weight_fts: f32,
    ) -> HybridSearchResponse {
        HybridSearchResponse {
            query: "test query".to_string(),
            k,
            rrf_k,
            weights: Weights {
                vec: weight_vec,
                fts: weight_fts,
            },
            results: vec![],
            graph_matches: vec![],
            fts_degraded: false,
            fts_error: None,
            fts_auto_rebuilt: false,
            elapsed_ms: 0,
        }
    }

    #[test]
    fn hybrid_search_response_empty_serializes_correct_fields() {
        let resp = empty_response(10, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"results\""), "must contain results field");
        assert!(json.contains("\"query\""), "must contain query field");
        assert!(json.contains("\"k\""), "must contain k field");
        assert!(
            json.contains("\"graph_matches\""),
            "must contain graph_matches field"
        );
        assert!(
            !json.contains("\"combined_rank\""),
            "must not contain combined_rank"
        );
        assert!(
            !json.contains("\"vec_rank_list\""),
            "must not contain vec_rank_list"
        );
        assert!(
            !json.contains("\"fts_rank_list\""),
            "must not contain fts_rank_list"
        );
    }

    #[test]
    fn hybrid_search_response_serializes_rrf_k_and_weights() {
        let resp = empty_response(5, 60, 0.7, 0.3);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"rrf_k\""), "must contain rrf_k field");
        assert!(json.contains("\"weights\""), "must contain weights field");
        assert!(json.contains("\"vec\""), "must contain weights.vec field");
        assert!(json.contains("\"fts\""), "must contain weights.fts field");

        // Pin the values too, so swapped or mis-nested fields are caught.
        let value = serde_json::to_value(&resp).unwrap();
        assert_eq!(value["rrf_k"], 60, "rrf_k must serialize its value");
        let vec_weight = value["weights"]["vec"].as_f64().unwrap();
        let fts_weight = value["weights"]["fts"].as_f64().unwrap();
        assert!((vec_weight - 0.7).abs() < 1e-6, "weights.vec must be 0.7");
        assert!((fts_weight - 0.3).abs() < 1e-6, "weights.fts must be 0.3");
    }

    #[test]
    fn hybrid_search_response_serializes_elapsed_ms() {
        let mut resp = empty_response(5, 60, 1.0, 1.0);
        resp.elapsed_ms = 123;
        let json = serde_json::to_string(&resp).unwrap();
        assert!(
            json.contains("\"elapsed_ms\""),
            "must contain elapsed_ms field"
        );
        // Match key and value together; a bare "123" could come from any field.
        assert!(
            json.contains("\"elapsed_ms\":123"),
            "deve serializar valor de elapsed_ms"
        );
    }

    #[test]
    fn weights_struct_serializes_correctly() {
        let w = Weights { vec: 0.6, fts: 0.4 };
        let json = serde_json::to_string(&w).unwrap();
        assert!(json.contains("\"vec\""));
        assert!(json.contains("\"fts\""));
    }

    #[test]
    fn hybrid_search_item_omits_fts_rank_when_none() {
        let item = HybridSearchItem {
            memory_id: 1,
            name: "mem".to_string(),
            namespace: "default".to_string(),
            memory_type: "user".to_string(),
            description: "desc".to_string(),
            body: "content".to_string(),
            snippet: "content".to_string(),
            combined_score: 0.0328,
            score: 0.0328,
            source: "hybrid".to_string(),
            vec_rank: Some(1),
            fts_rank: None,
            rrf_score: Some(0.0328),
            normalized_score: 1.0,
            vec_distance: Some(0.12),
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            json.contains("\"vec_rank\""),
            "must contain vec_rank when Some"
        );
        assert!(
            !json.contains("\"fts_rank\""),
            "must not contain fts_rank when None"
        );
    }

    #[test]
    fn hybrid_search_item_omits_vec_rank_when_none() {
        let item = HybridSearchItem {
            memory_id: 2,
            name: "mem2".to_string(),
            namespace: "default".to_string(),
            memory_type: "fact".to_string(),
            description: "desc2".to_string(),
            body: "corpo2".to_string(),
            snippet: "corpo2".to_string(),
            combined_score: 0.016,
            score: 0.016,
            source: "hybrid".to_string(),
            vec_rank: None,
            fts_rank: Some(2),
            rrf_score: Some(0.016),
            normalized_score: 0.5,
            vec_distance: None,
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            !json.contains("\"vec_rank\""),
            "must not contain vec_rank when None"
        );
        assert!(
            json.contains("\"fts_rank\""),
            "must contain fts_rank when Some"
        );
    }

    #[test]
    fn hybrid_search_item_serializes_both_ranks_when_some() {
        let item = HybridSearchItem {
            memory_id: 3,
            name: "mem3".to_string(),
            namespace: "ns".to_string(),
            memory_type: "entity".to_string(),
            description: "desc3".to_string(),
            body: "corpo3".to_string(),
            snippet: "corpo3".to_string(),
            combined_score: 0.05,
            score: 0.05,
            source: "hybrid".to_string(),
            vec_rank: Some(3),
            fts_rank: Some(1),
            rrf_score: Some(0.05),
            normalized_score: 0.8,
            vec_distance: Some(0.25),
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(json.contains("\"vec_rank\""), "must contain vec_rank");
        assert!(json.contains("\"fts_rank\""), "must contain fts_rank");
        assert!(json.contains("\"type\""), "deve serializar type renomeado");
        assert!(!json.contains("memory_type"), "must not expose memory_type");
    }

    #[test]
    fn hybrid_search_item_omits_optional_scores_when_none() {
        let item = HybridSearchItem {
            memory_id: 4,
            name: "mem4".to_string(),
            namespace: "ns".to_string(),
            memory_type: "fact".to_string(),
            description: "desc4".to_string(),
            body: "body4".to_string(),
            snippet: "body4".to_string(),
            combined_score: 0.01,
            score: 0.01,
            source: "hybrid".to_string(),
            vec_rank: None,
            fts_rank: Some(1),
            rrf_score: None,
            normalized_score: 0.5,
            vec_distance: None,
            fts_bm25: None,
        };
        let json = serde_json::to_value(&item).unwrap();
        assert!(
            json.get("vec_distance").is_none(),
            "vec_distance must be absent when None"
        );
        assert!(
            json.get("fts_bm25").is_none(),
            "fts_bm25 must be absent when None"
        );
        assert!(
            json.get("rrf_score").is_none(),
            "rrf_score must be absent when None"
        );
        assert_eq!(
            json["normalized_score"], 0.5,
            "normalized_score is always serialized"
        );
        assert_eq!(json["source"], "hybrid", "source must be serialized");
    }

    #[test]
    fn hybrid_search_response_serializes_k_correctly() {
        let resp = empty_response(5, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"k\":5"), "deve serializar k=5");
    }

    #[test]
    fn hybrid_search_response_with_graph_matches() {
        use crate::output::RecallItem;
        let resp = HybridSearchResponse {
            query: "test".to_string(),
            k: 5,
            rrf_k: 60,
            weights: Weights { vec: 1.0, fts: 1.0 },
            results: vec![],
            graph_matches: vec![RecallItem {
                memory_id: 1,
                name: "graph-hit".to_string(),
                namespace: "global".to_string(),
                memory_type: "document".to_string(),
                description: "found via graph".to_string(),
                snippet: "graph content".to_string(),
                distance: 0.1,
                score: 0.9,
                source: "graph".to_string(),
                graph_depth: Some(1),
            }],
            fts_degraded: false,
            fts_error: None,
            fts_auto_rebuilt: false,
            elapsed_ms: 42,
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["graph_matches"].as_array().unwrap().len(), 1);
        assert_eq!(json["graph_matches"][0]["source"], "graph");
        assert_eq!(json["graph_matches"][0]["graph_depth"], 1);
    }

    #[test]
    fn fts_degraded_omitted_on_success_present_on_failure() {
        // Happy path: fts_degraded=false must be absent from JSON (skip_serializing_if).
        let ok_resp = empty_response(5, 60, 1.0, 1.0);
        let ok_json = serde_json::to_string(&ok_resp).unwrap();
        assert!(
            !ok_json.contains("\"fts_degraded\""),
            "fts_degraded must be absent when false"
        );
        assert!(
            !ok_json.contains("\"fts_error\""),
            "fts_error must be absent when None"
        );

        // Degraded path: fts_degraded=true and fts_error=Some must appear in JSON.
        let mut degraded_resp = empty_response(5, 60, 1.0, 1.0);
        degraded_resp.fts_degraded = true;
        degraded_resp.fts_error = Some("FTS5 table corrupted".to_string());
        let degraded_json = serde_json::to_string(&degraded_resp).unwrap();
        assert!(
            degraded_json.contains("\"fts_degraded\":true"),
            "fts_degraded must be present and true when degraded"
        );
        assert!(
            degraded_json.contains("\"fts_error\""),
            "fts_error must be present when Some"
        );
        assert!(
            degraded_json.contains("FTS5 table corrupted"),
            "fts_error must contain the error message"
        );
    }

    #[test]
    fn fts_auto_rebuilt_omitted_when_false_present_when_true() {
        let ok_json = serde_json::to_string(&empty_response(5, 60, 1.0, 1.0)).unwrap();
        assert!(
            !ok_json.contains("\"fts_auto_rebuilt\""),
            "fts_auto_rebuilt must be absent when false"
        );

        let mut rebuilt = empty_response(5, 60, 1.0, 1.0);
        rebuilt.fts_auto_rebuilt = true;
        let rebuilt_json = serde_json::to_string(&rebuilt).unwrap();
        assert!(
            rebuilt_json.contains("\"fts_auto_rebuilt\":true"),
            "fts_auto_rebuilt must be present and true when set"
        );
    }
}