Skip to main content

sqlite_graphrag/commands/
hybrid_search.rs

1//! Handler for the `hybrid-search` CLI subcommand.
2
3use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::graph::traverse_from_memories_with_hops;
6use crate::output::{self, JsonOutputFormat, RecallItem};
7use crate::paths::AppPaths;
8use crate::storage::connection::open_ro;
9use crate::storage::entities;
10use crate::storage::memories;
11
12use std::collections::HashMap;
13
14/// Arguments for the `hybrid-search` subcommand.
15///
16/// When `--namespace` is omitted the search runs against the `global` namespace,
17/// which is the default namespace used by `remember` when no `--namespace` flag
18/// is provided. Pass an explicit `--namespace` value to search a different
19/// isolated namespace.
20#[derive(clap::Args)]
21#[command(after_long_help = "EXAMPLES:\n  \
22    # Basic hybrid search combining FTS5 + vector via RRF\n  \
23    sqlite-graphrag hybrid-search \"postgres migration deadlock\" --k 10\n\n  \
24    # Tune RRF weights to favor keyword matches over semantic similarity\n  \
25    sqlite-graphrag hybrid-search \"jwt auth\" --weight-fts 1.5 --weight-vec 0.5 --k 5\n\n  \
26    # Add graph traversal matches (entities connected to top results)\n  \
27    sqlite-graphrag hybrid-search \"frontend architecture\" --with-graph --k 10\n\n  \
28    # Graph traversal with custom depth and minimum edge weight\n  \
29    sqlite-graphrag hybrid-search \"auth design\" --with-graph --max-hops 3 --min-weight 0.5 --k 10\n\n  \
30NOTES:\n  \
31    --with-graph enables entity graph traversal seeded by the top RRF results.\n  \
32    Graph matches appear in the `graph_matches` array (separate from `results`).\n  \
33    Without --with-graph, `graph_matches` is always empty.")]
34pub struct HybridSearchArgs {
35    #[arg(
36        allow_hyphen_values = true,
37        help = "Hybrid search query (vector KNN + FTS5 BM25 fused via RRF)"
38    )]
39    pub query: String,
40    /// Maximum number of fused results to return after RRF combines vector + FTS5 candidates.
41    ///
42    /// Validated to the inclusive range `1..=4096` (the upper bound matches `sqlite-vec`'s knn
43    /// limit). Each underlying search fetches `k * 2` candidates before fusion.
44    #[arg(short = 'k', long, aliases = ["limit", "top-k"], default_value = "10", value_parser = crate::parsers::parse_k_range)]
45    pub k: usize,
46    #[arg(long, default_value = "60")]
47    pub rrf_k: u32,
48    #[arg(long, default_value = "1.0")]
49    pub weight_vec: f32,
50    #[arg(long, default_value = "1.0")]
51    pub weight_fts: f32,
52    /// Filter by memory.type. Note: distinct from graph entity_type
53    /// (project/tool/person/file/concept/incident/decision/memory/dashboard/issue_tracker/organization/location/date)
54    /// used in --entities-file.
55    #[arg(long, value_enum)]
56    pub r#type: Option<MemoryType>,
57    #[arg(long)]
58    pub namespace: Option<String>,
59    #[arg(long)]
60    pub with_graph: bool,
61    #[arg(long, default_value = "2")]
62    pub max_hops: u32,
63    #[arg(long, default_value = "0.3")]
64    pub min_weight: f64,
65    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
66    pub format: JsonOutputFormat,
67    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
68    pub db: Option<String>,
69    /// Accept `--json` as a no-op because output is already JSON by default.
70    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
71    pub json: bool,
72    #[command(flatten)]
73    pub daemon: crate::cli::DaemonOpts,
74}
75
/// One fused hit in `HybridSearchResponse::results`.
///
/// Serde serializes fields in declaration order, so the field order below is the
/// order of keys in the emitted JSON object.
#[derive(serde::Serialize)]
pub struct HybridSearchItem {
    /// Id of the memory row.
    pub memory_id: i64,
    /// Memory name.
    pub name: String,
    /// Namespace the memory row belongs to.
    pub namespace: String,
    /// Memory type; emitted under the JSON key `type`.
    #[serde(rename = "type")]
    pub memory_type: String,
    /// Memory description.
    pub description: String,
    /// Full memory body.
    pub body: String,
    /// Preview of `body`; the `hybrid-search` handler fills it with the first 300 characters.
    pub snippet: String,
    /// Weighted RRF score: sum over the vec/fts lists of `weight / (rrf_k + rank)`, rank 1-based.
    pub combined_score: f64,
    /// Alias of `combined_score` for the documented contract in SKILL.md.
    pub score: f64,
    /// Source of the match: always "hybrid" (RRF of vec + fts). Added in v2.0.1.
    pub source: String,
    /// 1-based position in the vector (KNN) candidate list; omitted when absent from it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vec_rank: Option<usize>,
    /// 1-based position in the FTS5 candidate list; omitted when absent from it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fts_rank: Option<usize>,
    /// Combined RRF score — explicit alias of `combined_score` for integration contracts.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rrf_score: Option<f64>,
    /// RRF score divided by the best achievable score (rank 1 in both lists).
    ///
    /// Falls in [0.0, 1.0] for non-negative weights. The ceiling always includes the
    /// FTS weight, so when FTS5 is skipped or degraded the top score stays below 1.0.
    pub normalized_score: f64,
    /// Raw KNN distance from the vector index (lower = more similar).
    ///
    /// Present when the result came from the vector search path; `None` when the
    /// result appeared only in the FTS5 results and was not ranked by the KNN index.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vec_distance: Option<f64>,
    /// Raw BM25 score from the FTS5 index. Currently always `None`; reserved for
    /// a future release when the FTS5 BM25 score is exposed by the storage layer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fts_bm25: Option<f64>,
}
111
/// RRF weights used in hybrid search: vec (vector) and fts (text).
///
/// Echoes `--weight-vec` / `--weight-fts` back in the response.
#[derive(serde::Serialize)]
pub struct Weights {
    /// Multiplier applied to each vector (KNN) ranking contribution.
    pub vec: f32,
    /// Multiplier applied to each FTS5 ranking contribution.
    pub fts: f32,
}
118
/// JSON envelope printed by the `hybrid-search` subcommand.
#[derive(serde::Serialize)]
pub struct HybridSearchResponse {
    /// The query string as given on the command line.
    pub query: String,
    /// Requested maximum number of fused results (`--k`).
    pub k: usize,
    /// RRF k parameter used in the combined ranking.
    pub rrf_k: u32,
    /// Weights applied to vec and fts sources in the RRF fusion.
    pub weights: Weights,
    /// Fused hits, at most `k`, ordered by descending `combined_score`.
    pub results: Vec<HybridSearchItem>,
    /// Memories reached through graph traversal (`--with-graph`); empty otherwise.
    ///
    /// The handler skips memories already present in `results`.
    pub graph_matches: Vec<RecallItem>,
    /// True when FTS5 failed and the response is vec-only.
    ///
    /// Omitted from JSON when `false` to keep the happy-path envelope clean.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub fts_degraded: bool,
    /// Human-readable description of the FTS5 failure when `fts_degraded` is true.
    ///
    /// Omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fts_error: Option<String>,
    /// True when the FTS5 index was corrupted and successfully auto-rebuilt during this request.
    ///
    /// Omitted from JSON when `false` to keep the happy-path envelope clean.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub fts_auto_rebuilt: bool,
    /// Total execution time in milliseconds from handler start to serialisation.
    pub elapsed_ms: u64,
}
147
148#[tracing::instrument(skip_all, level = "debug", name = "hybrid_search")]
149pub fn run(args: HybridSearchArgs) -> Result<(), AppError> {
150    let start = std::time::Instant::now();
151    let _ = args.format;
152    tracing::debug!(target: "hybrid_search", query = %args.query, k = args.k, "fusing results");
153
154    // G20: reject graph-specific flags when --with-graph is not active
155    if !args.with_graph {
156        if args.max_hops != 2 {
157            return Err(AppError::Validation(
158                "--max-hops requires --with-graph to be active".to_string(),
159            ));
160        }
161        if (args.min_weight - 0.3).abs() > f64::EPSILON {
162            return Err(AppError::Validation(
163                "--min-weight requires --with-graph to be active".to_string(),
164            ));
165        }
166    }
167
168    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
169    let paths = AppPaths::resolve(args.db.as_deref())?;
170    crate::storage::connection::ensure_db_ready(&paths)?;
171
172    output::emit_progress_i18n(
173        "Computing query embedding...",
174        "Calculando embedding da consulta...",
175    );
176    let embedding = crate::daemon::embed_query_or_local(
177        &paths.models,
178        &args.query,
179        args.daemon.autostart_daemon,
180    )?;
181
182    let conn = open_ro(&paths.db)?;
183
184    let memory_type_str = args.r#type.map(|t| t.as_str());
185
186    let vec_results = memories::knn_search(
187        &conn,
188        &embedding,
189        &[namespace.clone()],
190        memory_type_str,
191        args.k * 2,
192    )?;
193
194    // Map vector ranking position by memory_id (1-indexed per schema)
195    let vec_rank_map: HashMap<i64, usize> = vec_results
196        .iter()
197        .enumerate()
198        .map(|(pos, (id, _))| (*id, pos + 1))
199        .collect();
200
201    // Map raw KNN distance by memory_id for GAP-30: vec_distance field.
202    let vec_distance_map: HashMap<i64, f64> = vec_results
203        .iter()
204        .map(|(id, dist)| (*id, *dist as f64))
205        .collect();
206
207    let (fts_results, fts_degraded, fts_error, fts_auto_rebuilt) = if args.weight_fts == 0.0 {
208        (vec![], false, None, false)
209    } else {
210        match memories::fts_search(&conn, &args.query, &namespace, memory_type_str, args.k * 2) {
211            Ok(r) => (r, false, None, false),
212            Err(e) => {
213                let err_msg = e.to_string();
214                let is_malformed = err_msg.contains("malformed") || err_msg.contains("corrupt");
215                if is_malformed {
216                    tracing::warn!(target: "hybrid_search", "FTS5 index corrupted, attempting auto-rebuild");
217                    if conn
218                        .execute_batch("INSERT INTO fts_memories(fts_memories) VALUES('rebuild');")
219                        .is_ok()
220                    {
221                        match memories::fts_search(
222                            &conn,
223                            &args.query,
224                            &namespace,
225                            memory_type_str,
226                            args.k * 2,
227                        ) {
228                            Ok(r) => (r, false, None, true),
229                            Err(e2) => {
230                                tracing::error!(target: "hybrid_search", error = %e2, "FTS5 auto-rebuild failed to recover");
231                                (vec![], true, Some(e2.to_string()), true)
232                            }
233                        }
234                    } else {
235                        (vec![], true, Some(err_msg), false)
236                    }
237                } else {
238                    tracing::warn!(target: "hybrid_search", error = %e, "FTS5 query failed, falling back to vec-only");
239                    (vec![], true, Some(err_msg), false)
240                }
241            }
242        }
243    };
244
245    // Map FTS ranking position by memory_id (1-indexed per schema)
246    let fts_rank_map: HashMap<i64, usize> = fts_results
247        .iter()
248        .enumerate()
249        .map(|(pos, row)| (row.id, pos + 1))
250        .collect();
251
252    let rrf_k = args.rrf_k as f64;
253
254    // Accumulate combined RRF scores
255    let mut combined_scores: crate::hash::AHashMap<i64, f64> =
256        crate::hash::AHashMap::with_capacity_and_hasher(
257            vec_results.len() + fts_results.len(),
258            Default::default(),
259        );
260
261    for (rank, (memory_id, _)) in vec_results.iter().enumerate() {
262        let score = args.weight_vec as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
263        *combined_scores.entry(*memory_id).or_insert(0.0) += score;
264    }
265
266    for (rank, row) in fts_results.iter().enumerate() {
267        let score = args.weight_fts as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
268        *combined_scores.entry(row.id).or_insert(0.0) += score;
269    }
270
271    // Sort by score descending and take the top-k
272    let mut ranked: Vec<(i64, f64)> = combined_scores.into_iter().collect();
273    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
274    ranked.truncate(args.k);
275
276    // Collect all IDs for batch fetch (avoiding N+1)
277    let top_ids: Vec<i64> = ranked.iter().map(|(id, _)| *id).collect();
278
279    // Fetch full data for the top memories
280    let mut memory_data: crate::hash::AHashMap<i64, memories::MemoryRow> =
281        crate::hash::AHashMap::with_capacity_and_hasher(ranked.len(), Default::default());
282    for id in &top_ids {
283        if let Some(row) = memories::read_full(&conn, *id)? {
284            memory_data.insert(*id, row);
285        }
286    }
287
288    let max_possible = args.weight_vec as f64 * (1.0 / (rrf_k + 1.0))
289        + args.weight_fts as f64 * (1.0 / (rrf_k + 1.0));
290
291    // Build final results in ranking order
292    let results: Vec<HybridSearchItem> = ranked
293        .into_iter()
294        .filter_map(|(memory_id, combined_score)| {
295            let normalized_score = if max_possible > 0.0 {
296                combined_score / max_possible
297            } else {
298                0.0
299            };
300            memory_data.remove(&memory_id).map(|row| {
301                let snippet: String = row.body.chars().take(300).collect();
302                HybridSearchItem {
303                    memory_id: row.id,
304                    name: row.name,
305                    namespace: row.namespace,
306                    memory_type: row.memory_type,
307                    description: row.description,
308                    body: row.body,
309                    snippet,
310                    combined_score,
311                    score: combined_score,
312                    source: "hybrid".to_string(),
313                    vec_rank: vec_rank_map.get(&memory_id).copied(),
314                    fts_rank: fts_rank_map.get(&memory_id).copied(),
315                    rrf_score: Some(combined_score),
316                    normalized_score,
317                    vec_distance: vec_distance_map.get(&memory_id).copied(),
318                    fts_bm25: None,
319                }
320            })
321        })
322        .collect();
323
324    // --- Graph traversal (activated by --with-graph) ---
325    let mut graph_matches: Vec<RecallItem> = Vec::with_capacity(8);
326    if args.with_graph && !results.is_empty() {
327        let namespace_for_graph = namespace.clone();
328        let memory_ids: Vec<i64> = results.iter().map(|r| r.memory_id).collect();
329
330        let entity_knn = entities::knn_search(&conn, &embedding, &namespace_for_graph, 5)?;
331        let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
332
333        let all_seed_ids: Vec<i64> = memory_ids
334            .iter()
335            .chain(entity_ids.iter())
336            .copied()
337            .collect();
338
339        if !all_seed_ids.is_empty() {
340            let graph_memory_ids = traverse_from_memories_with_hops(
341                &conn,
342                &all_seed_ids,
343                &namespace_for_graph,
344                args.min_weight,
345                args.max_hops,
346            )?;
347
348            let already_in_results: std::collections::HashSet<i64> =
349                results.iter().map(|r| r.memory_id).collect();
350
351            for (graph_mem_id, hop) in graph_memory_ids {
352                if already_in_results.contains(&graph_mem_id) {
353                    continue;
354                }
355                if let Some(row) = memories::read_full(&conn, graph_mem_id)? {
356                    let snippet: String = row.body.chars().take(300).collect();
357                    let graph_distance = 1.0 - 1.0 / (hop as f32 + 1.0);
358                    graph_matches.push(RecallItem {
359                        memory_id: row.id,
360                        name: row.name,
361                        namespace: row.namespace,
362                        memory_type: row.memory_type,
363                        description: row.description,
364                        snippet,
365                        distance: graph_distance,
366                        score: RecallItem::score_from_distance(graph_distance),
367                        source: "graph".to_string(),
368                        graph_depth: Some(hop),
369                    });
370                }
371            }
372        }
373    }
374
375    output::emit_json(&HybridSearchResponse {
376        query: args.query,
377        k: args.k,
378        rrf_k: args.rrf_k,
379        weights: Weights {
380            vec: args.weight_vec,
381            fts: args.weight_fts,
382        },
383        results,
384        graph_matches,
385        fts_degraded,
386        fts_error,
387        fts_auto_rebuilt,
388        elapsed_ms: start.elapsed().as_millis() as u64,
389    })?;
390
391    Ok(())
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a response with no hits and every FTS status flag at its default.
    fn empty_response(
        k: usize,
        rrf_k: u32,
        weight_vec: f32,
        weight_fts: f32,
    ) -> HybridSearchResponse {
        HybridSearchResponse {
            query: "test query".to_string(),
            k,
            rrf_k,
            weights: Weights {
                vec: weight_vec,
                fts: weight_fts,
            },
            results: vec![],
            graph_matches: vec![],
            fts_degraded: false,
            fts_error: None,
            fts_auto_rebuilt: false,
            elapsed_ms: 0,
        }
    }

    #[test]
    fn hybrid_search_response_empty_serializes_correct_fields() {
        let resp = empty_response(10, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"results\""), "must contain results field");
        assert!(json.contains("\"query\""), "must contain query field");
        assert!(json.contains("\"k\""), "must contain k field");
        assert!(
            json.contains("\"graph_matches\""),
            "must contain graph_matches field"
        );
        assert!(
            !json.contains("\"combined_rank\""),
            "must not contain combined_rank"
        );
        assert!(
            !json.contains("\"vec_rank_list\""),
            "must not contain vec_rank_list"
        );
        assert!(
            !json.contains("\"fts_rank_list\""),
            "must not contain fts_rank_list"
        );
    }

    #[test]
    fn hybrid_search_response_serializes_rrf_k_and_weights() {
        let resp = empty_response(5, 60, 0.7, 0.3);
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["rrf_k"], 60, "must serialize the rrf_k value");
        let vec_w = json["weights"]["vec"]
            .as_f64()
            .expect("weights.vec must be a number");
        let fts_w = json["weights"]["fts"]
            .as_f64()
            .expect("weights.fts must be a number");
        // Weights are stored as f32, so compare with a tolerance after widening to f64.
        assert!((vec_w - 0.7).abs() < 1e-6, "weights.vec must round-trip");
        assert!((fts_w - 0.3).abs() < 1e-6, "weights.fts must round-trip");
    }

    #[test]
    fn hybrid_search_response_serializes_elapsed_ms() {
        let mut resp = empty_response(5, 60, 1.0, 1.0);
        resp.elapsed_ms = 123;
        let json = serde_json::to_value(&resp).unwrap();
        // Check the exact key/value pair rather than a bare substring, which could
        // match any other number in the envelope.
        assert_eq!(
            json["elapsed_ms"], 123,
            "deve serializar valor de elapsed_ms"
        );
    }

    #[test]
    fn weights_struct_serializes_correctly() {
        let w = Weights { vec: 0.6, fts: 0.4 };
        let json = serde_json::to_string(&w).unwrap();
        assert!(json.contains("\"vec\""));
        assert!(json.contains("\"fts\""));
    }

    #[test]
    fn hybrid_search_item_omits_fts_rank_when_none() {
        let item = HybridSearchItem {
            memory_id: 1,
            name: "mem".to_string(),
            namespace: "default".to_string(),
            memory_type: "user".to_string(),
            description: "desc".to_string(),
            body: "content".to_string(),
            snippet: "content".to_string(),
            combined_score: 0.0328,
            score: 0.0328,
            source: "hybrid".to_string(),
            vec_rank: Some(1),
            fts_rank: None,
            rrf_score: Some(0.0328),
            normalized_score: 1.0,
            vec_distance: Some(0.12),
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            json.contains("\"vec_rank\""),
            "must contain vec_rank when Some"
        );
        assert!(
            !json.contains("\"fts_rank\""),
            "must not contain fts_rank when None"
        );
    }

    #[test]
    fn hybrid_search_item_omits_vec_rank_when_none() {
        let item = HybridSearchItem {
            memory_id: 2,
            name: "mem2".to_string(),
            namespace: "default".to_string(),
            memory_type: "fact".to_string(),
            description: "desc2".to_string(),
            body: "corpo2".to_string(),
            snippet: "corpo2".to_string(),
            combined_score: 0.016,
            score: 0.016,
            source: "hybrid".to_string(),
            vec_rank: None,
            fts_rank: Some(2),
            rrf_score: Some(0.016),
            normalized_score: 0.5,
            vec_distance: None,
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            !json.contains("\"vec_rank\""),
            "must not contain vec_rank when None"
        );
        assert!(
            json.contains("\"fts_rank\""),
            "must contain fts_rank when Some"
        );
    }

    #[test]
    fn hybrid_search_item_serializes_both_ranks_when_some() {
        let item = HybridSearchItem {
            memory_id: 3,
            name: "mem3".to_string(),
            namespace: "ns".to_string(),
            memory_type: "entity".to_string(),
            description: "desc3".to_string(),
            body: "corpo3".to_string(),
            snippet: "corpo3".to_string(),
            combined_score: 0.05,
            score: 0.05,
            source: "hybrid".to_string(),
            vec_rank: Some(3),
            fts_rank: Some(1),
            rrf_score: Some(0.05),
            normalized_score: 0.8,
            vec_distance: Some(0.25),
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(json.contains("\"vec_rank\""), "must contain vec_rank");
        assert!(json.contains("\"fts_rank\""), "must contain fts_rank");
        assert!(json.contains("\"type\""), "deve serializar type renomeado");
        assert!(!json.contains("memory_type"), "must not expose memory_type");
    }

    #[test]
    fn hybrid_search_response_serializes_k_correctly() {
        let resp = empty_response(5, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"k\":5"), "deve serializar k=5");
    }

    #[test]
    fn hybrid_search_response_with_graph_matches() {
        use crate::output::RecallItem;
        let resp = HybridSearchResponse {
            query: "test".to_string(),
            k: 5,
            rrf_k: 60,
            weights: Weights { vec: 1.0, fts: 1.0 },
            results: vec![],
            graph_matches: vec![RecallItem {
                memory_id: 1,
                name: "graph-hit".to_string(),
                namespace: "global".to_string(),
                memory_type: "document".to_string(),
                description: "found via graph".to_string(),
                snippet: "graph content".to_string(),
                distance: 0.1,
                score: 0.9,
                source: "graph".to_string(),
                graph_depth: Some(1),
            }],
            fts_degraded: false,
            fts_error: None,
            fts_auto_rebuilt: false,
            elapsed_ms: 42,
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["graph_matches"].as_array().unwrap().len(), 1);
        assert_eq!(json["graph_matches"][0]["source"], "graph");
        assert_eq!(json["graph_matches"][0]["graph_depth"], 1);
    }

    #[test]
    fn fts_degraded_omitted_on_success_present_on_failure() {
        // Happy path: fts_degraded=false must be absent from JSON (skip_serializing_if).
        let ok_resp = empty_response(5, 60, 1.0, 1.0);
        let ok_json = serde_json::to_string(&ok_resp).unwrap();
        assert!(
            !ok_json.contains("\"fts_degraded\""),
            "fts_degraded must be absent when false"
        );
        assert!(
            !ok_json.contains("\"fts_error\""),
            "fts_error must be absent when None"
        );

        // Degraded path: fts_degraded=true and fts_error=Some must appear in JSON.
        let mut degraded_resp = empty_response(5, 60, 1.0, 1.0);
        degraded_resp.fts_degraded = true;
        degraded_resp.fts_error = Some("FTS5 table corrupted".to_string());
        let degraded_json = serde_json::to_string(&degraded_resp).unwrap();
        assert!(
            degraded_json.contains("\"fts_degraded\":true"),
            "fts_degraded must be present and true when degraded"
        );
        assert!(
            degraded_json.contains("\"fts_error\""),
            "fts_error must be present when Some"
        );
        assert!(
            degraded_json.contains("FTS5 table corrupted"),
            "fts_error must contain the error message"
        );
    }

    #[test]
    fn fts_auto_rebuilt_omitted_when_false_present_when_true() {
        // Same skip_serializing_if contract as fts_degraded.
        let ok_json = serde_json::to_value(empty_response(5, 60, 1.0, 1.0)).unwrap();
        assert!(
            ok_json.get("fts_auto_rebuilt").is_none(),
            "fts_auto_rebuilt must be absent when false"
        );

        let mut rebuilt = empty_response(5, 60, 1.0, 1.0);
        rebuilt.fts_auto_rebuilt = true;
        let rebuilt_json = serde_json::to_value(&rebuilt).unwrap();
        assert_eq!(
            rebuilt_json["fts_auto_rebuilt"], true,
            "fts_auto_rebuilt must be present and true after a rebuild"
        );
    }
}