Skip to main content

sqlite_graphrag/commands/
hybrid_search.rs

//! Handler for the `hybrid-search` CLI subcommand.

use crate::cli::MemoryType;
use crate::errors::AppError;
use crate::graph::traverse_from_memories_with_hops;
use crate::output::{self, JsonOutputFormat, RecallItem};
use crate::paths::AppPaths;
use crate::storage::connection::open_ro;
use crate::storage::entities;
use crate::storage::memories;

use std::collections::{HashMap, HashSet};

14/// Arguments for the `hybrid-search` subcommand.
15///
16/// When `--namespace` is omitted the search runs against the `global` namespace,
17/// which is the default namespace used by `remember` when no `--namespace` flag
18/// is provided. Pass an explicit `--namespace` value to search a different
19/// isolated namespace.
20#[derive(clap::Args)]
21#[command(after_long_help = "EXAMPLES:\n  \
22    # Basic hybrid search combining FTS5 + vector via RRF\n  \
23    sqlite-graphrag hybrid-search \"postgres migration deadlock\" --k 10\n\n  \
24    # Tune RRF weights to favor keyword matches over semantic similarity\n  \
25    sqlite-graphrag hybrid-search \"jwt auth\" --weight-fts 1.5 --weight-vec 0.5 --k 5\n\n  \
26    # Add graph traversal matches (entities connected to top results)\n  \
27    sqlite-graphrag hybrid-search \"frontend architecture\" --with-graph --k 10\n\n  \
28    # Graph traversal with custom depth and minimum edge weight\n  \
29    sqlite-graphrag hybrid-search \"auth design\" --with-graph --max-hops 3 --min-weight 0.5 --k 10\n\n  \
30NOTES:\n  \
31    --with-graph enables entity graph traversal seeded by the top RRF results.\n  \
32    Graph matches appear in the `graph_matches` array (separate from `results`).\n  \
33    Without --with-graph, `graph_matches` is always empty.")]
34pub struct HybridSearchArgs {
35    #[arg(help = "Hybrid search query (vector KNN + FTS5 BM25 fused via RRF)")]
36    pub query: String,
37    /// Maximum number of fused results to return after RRF combines vector + FTS5 candidates.
38    ///
39    /// Validated to the inclusive range `1..=4096` (the upper bound matches `sqlite-vec`'s knn
40    /// limit). Each underlying search fetches `k * 2` candidates before fusion.
41    #[arg(short = 'k', long, alias = "limit", default_value = "10", value_parser = crate::parsers::parse_k_range)]
42    pub k: usize,
43    #[arg(long, default_value = "60")]
44    pub rrf_k: u32,
45    #[arg(long, default_value = "1.0")]
46    pub weight_vec: f32,
47    #[arg(long, default_value = "1.0")]
48    pub weight_fts: f32,
49    /// Filter by memory.type. Note: distinct from graph entity_type
50    /// (project/tool/person/file/concept/incident/decision/memory/dashboard/issue_tracker/organization/location/date)
51    /// used in --entities-file.
52    #[arg(long, value_enum)]
53    pub r#type: Option<MemoryType>,
54    #[arg(long)]
55    pub namespace: Option<String>,
56    #[arg(long)]
57    pub with_graph: bool,
58    #[arg(long, default_value = "2")]
59    pub max_hops: u32,
60    #[arg(long, default_value = "0.3")]
61    pub min_weight: f64,
62    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
63    pub format: JsonOutputFormat,
64    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
65    pub db: Option<String>,
66    /// Accept `--json` as a no-op because output is already JSON by default.
67    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
68    pub json: bool,
69    #[command(flatten)]
70    pub daemon: crate::cli::DaemonOpts,
71}
72
73#[derive(serde::Serialize)]
74pub struct HybridSearchItem {
75    pub memory_id: i64,
76    pub name: String,
77    pub namespace: String,
78    #[serde(rename = "type")]
79    pub memory_type: String,
80    pub description: String,
81    pub body: String,
82    pub combined_score: f64,
83    /// Alias of `combined_score` for the documented contract in SKILL.md.
84    pub score: f64,
85    /// Source of the match: always "hybrid" (RRF of vec + fts). Added in v2.0.1.
86    pub source: String,
87    #[serde(skip_serializing_if = "Option::is_none")]
88    pub vec_rank: Option<usize>,
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub fts_rank: Option<usize>,
91    /// Combined RRF score — explicit alias of `combined_score` for integration contracts.
92    #[serde(skip_serializing_if = "Option::is_none")]
93    pub rrf_score: Option<f64>,
94    /// RRF score normalized to [0.0, 1.0] for cross-method comparability.
95    pub normalized_score: f64,
96    /// Raw KNN distance from the vector index (lower = more similar).
97    ///
98    /// Present when the result came from the vector search path; `None` when the
99    /// result appeared only in the FTS5 results and was not ranked by the KNN index.
100    #[serde(skip_serializing_if = "Option::is_none")]
101    pub vec_distance: Option<f64>,
102    /// Raw BM25 score from the FTS5 index. Currently always `None`; reserved for
103    /// a future release when the FTS5 BM25 score is exposed by the storage layer.
104    #[serde(skip_serializing_if = "Option::is_none")]
105    pub fts_bm25: Option<f64>,
106}
107
108/// RRF weights used in hybrid search: vec (vector) and fts (text).
109#[derive(serde::Serialize)]
110pub struct Weights {
111    pub vec: f32,
112    pub fts: f32,
113}
114
115#[derive(serde::Serialize)]
116pub struct HybridSearchResponse {
117    pub query: String,
118    pub k: usize,
119    /// RRF k parameter used in the combined ranking.
120    pub rrf_k: u32,
121    /// Weights applied to vec and fts sources in the RRF fusion.
122    pub weights: Weights,
123    pub results: Vec<HybridSearchItem>,
124    pub graph_matches: Vec<RecallItem>,
125    /// True when FTS5 failed and the response is vec-only.
126    ///
127    /// Omitted from JSON when `false` to keep the happy-path envelope clean.
128    #[serde(skip_serializing_if = "std::ops::Not::not")]
129    pub fts_degraded: bool,
130    /// Human-readable description of the FTS5 failure when `fts_degraded` is true.
131    ///
132    /// Omitted from JSON when `None`.
133    #[serde(skip_serializing_if = "Option::is_none")]
134    pub fts_error: Option<String>,
135    /// True when the FTS5 index was corrupted and successfully auto-rebuilt during this request.
136    ///
137    /// Omitted from JSON when `false` to keep the happy-path envelope clean.
138    #[serde(skip_serializing_if = "std::ops::Not::not")]
139    pub fts_auto_rebuilt: bool,
140    /// Total execution time in milliseconds from handler start to serialisation.
141    pub elapsed_ms: u64,
142}
143
144pub fn run(args: HybridSearchArgs) -> Result<(), AppError> {
145    let start = std::time::Instant::now();
146    let _ = args.format;
147
148    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
149    let paths = AppPaths::resolve(args.db.as_deref())?;
150    crate::storage::connection::ensure_db_ready(&paths)?;
151
152    output::emit_progress_i18n(
153        "Computing query embedding...",
154        "Calculando embedding da consulta...",
155    );
156    let embedding = crate::daemon::embed_query_or_local(
157        &paths.models,
158        &args.query,
159        args.daemon.autostart_daemon,
160    )?;
161
162    let conn = open_ro(&paths.db)?;
163
164    let memory_type_str = args.r#type.map(|t| t.as_str());
165
166    let vec_results = memories::knn_search(
167        &conn,
168        &embedding,
169        &[namespace.clone()],
170        memory_type_str,
171        args.k * 2,
172    )?;
173
174    // Map vector ranking position by memory_id (1-indexed per schema)
175    let vec_rank_map: HashMap<i64, usize> = vec_results
176        .iter()
177        .enumerate()
178        .map(|(pos, (id, _))| (*id, pos + 1))
179        .collect();
180
181    // Map raw KNN distance by memory_id for GAP-30: vec_distance field.
182    let vec_distance_map: HashMap<i64, f64> = vec_results
183        .iter()
184        .map(|(id, dist)| (*id, *dist as f64))
185        .collect();
186
187    let (fts_results, fts_degraded, fts_error, fts_auto_rebuilt) = if args.weight_fts == 0.0 {
188        (vec![], false, None, false)
189    } else {
190        match memories::fts_search(&conn, &args.query, &namespace, memory_type_str, args.k * 2) {
191            Ok(r) => (r, false, None, false),
192            Err(e) => {
193                let err_msg = e.to_string();
194                let is_malformed = err_msg.contains("malformed") || err_msg.contains("corrupt");
195                if is_malformed {
196                    tracing::warn!("FTS5 index corrupted, attempting auto-rebuild");
197                    if conn
198                        .execute_batch("INSERT INTO fts_memories(fts_memories) VALUES('rebuild');")
199                        .is_ok()
200                    {
201                        match memories::fts_search(
202                            &conn,
203                            &args.query,
204                            &namespace,
205                            memory_type_str,
206                            args.k * 2,
207                        ) {
208                            Ok(r) => (r, false, None, true),
209                            Err(e2) => {
210                                tracing::error!("FTS5 auto-rebuild failed to recover: {e2}");
211                                (vec![], true, Some(e2.to_string()), true)
212                            }
213                        }
214                    } else {
215                        (vec![], true, Some(err_msg), false)
216                    }
217                } else {
218                    tracing::warn!("FTS5 query failed, falling back to vec-only: {e}");
219                    (vec![], true, Some(err_msg), false)
220                }
221            }
222        }
223    };
224
225    // Map FTS ranking position by memory_id (1-indexed per schema)
226    let fts_rank_map: HashMap<i64, usize> = fts_results
227        .iter()
228        .enumerate()
229        .map(|(pos, row)| (row.id, pos + 1))
230        .collect();
231
232    let rrf_k = args.rrf_k as f64;
233
234    // Accumulate combined RRF scores
235    let mut combined_scores: HashMap<i64, f64> = HashMap::new();
236
237    for (rank, (memory_id, _)) in vec_results.iter().enumerate() {
238        let score = args.weight_vec as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
239        *combined_scores.entry(*memory_id).or_insert(0.0) += score;
240    }
241
242    for (rank, row) in fts_results.iter().enumerate() {
243        let score = args.weight_fts as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
244        *combined_scores.entry(row.id).or_insert(0.0) += score;
245    }
246
247    // Sort by score descending and take the top-k
248    let mut ranked: Vec<(i64, f64)> = combined_scores.into_iter().collect();
249    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
250    ranked.truncate(args.k);
251
252    // Collect all IDs for batch fetch (avoiding N+1)
253    let top_ids: Vec<i64> = ranked.iter().map(|(id, _)| *id).collect();
254
255    // Fetch full data for the top memories
256    let mut memory_data: HashMap<i64, memories::MemoryRow> = HashMap::new();
257    for id in &top_ids {
258        if let Some(row) = memories::read_full(&conn, *id)? {
259            memory_data.insert(*id, row);
260        }
261    }
262
263    let max_possible = args.weight_vec as f64 * (1.0 / (rrf_k + 1.0))
264        + args.weight_fts as f64 * (1.0 / (rrf_k + 1.0));
265
266    // Build final results in ranking order
267    let results: Vec<HybridSearchItem> = ranked
268        .into_iter()
269        .filter_map(|(memory_id, combined_score)| {
270            let normalized_score = if max_possible > 0.0 {
271                combined_score / max_possible
272            } else {
273                0.0
274            };
275            memory_data.remove(&memory_id).map(|row| HybridSearchItem {
276                memory_id: row.id,
277                name: row.name,
278                namespace: row.namespace,
279                memory_type: row.memory_type,
280                description: row.description,
281                body: row.body,
282                combined_score,
283                score: combined_score,
284                source: "hybrid".to_string(),
285                vec_rank: vec_rank_map.get(&memory_id).copied(),
286                fts_rank: fts_rank_map.get(&memory_id).copied(),
287                rrf_score: Some(combined_score),
288                normalized_score,
289                vec_distance: vec_distance_map.get(&memory_id).copied(),
290                fts_bm25: None,
291            })
292        })
293        .collect();
294
295    // --- Graph traversal (activated by --with-graph) ---
296    let mut graph_matches: Vec<RecallItem> = Vec::with_capacity(8);
297    if args.with_graph && !results.is_empty() {
298        let namespace_for_graph = namespace.clone();
299        let memory_ids: Vec<i64> = results.iter().map(|r| r.memory_id).collect();
300
301        let entity_knn = entities::knn_search(&conn, &embedding, &namespace_for_graph, 5)?;
302        let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
303
304        let all_seed_ids: Vec<i64> = memory_ids
305            .iter()
306            .chain(entity_ids.iter())
307            .copied()
308            .collect();
309
310        if !all_seed_ids.is_empty() {
311            let graph_memory_ids = traverse_from_memories_with_hops(
312                &conn,
313                &all_seed_ids,
314                &namespace_for_graph,
315                args.min_weight,
316                args.max_hops,
317            )?;
318
319            let already_in_results: std::collections::HashSet<i64> =
320                results.iter().map(|r| r.memory_id).collect();
321
322            for (graph_mem_id, hop) in graph_memory_ids {
323                if already_in_results.contains(&graph_mem_id) {
324                    continue;
325                }
326                if let Some(row) = memories::read_full(&conn, graph_mem_id)? {
327                    let snippet: String = row.body.chars().take(300).collect();
328                    let graph_distance = 1.0 - 1.0 / (hop as f32 + 1.0);
329                    graph_matches.push(RecallItem {
330                        memory_id: row.id,
331                        name: row.name,
332                        namespace: row.namespace,
333                        memory_type: row.memory_type,
334                        description: row.description,
335                        snippet,
336                        distance: graph_distance,
337                        score: RecallItem::score_from_distance(graph_distance),
338                        source: "graph".to_string(),
339                        graph_depth: Some(hop),
340                    });
341                }
342            }
343        }
344    }
345
346    output::emit_json(&HybridSearchResponse {
347        query: args.query,
348        k: args.k,
349        rrf_k: args.rrf_k,
350        weights: Weights {
351            vec: args.weight_vec,
352            fts: args.weight_fts,
353        },
354        results,
355        graph_matches,
356        fts_degraded,
357        fts_error,
358        fts_auto_rebuilt,
359        elapsed_ms: start.elapsed().as_millis() as u64,
360    })?;
361
362    Ok(())
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a response with no results, no graph matches and no FTS diagnostics.
    fn empty_response(
        k: usize,
        rrf_k: u32,
        weight_vec: f32,
        weight_fts: f32,
    ) -> HybridSearchResponse {
        HybridSearchResponse {
            query: "test query".to_string(),
            k,
            rrf_k,
            weights: Weights {
                vec: weight_vec,
                fts: weight_fts,
            },
            results: vec![],
            graph_matches: vec![],
            fts_degraded: false,
            fts_error: None,
            fts_auto_rebuilt: false,
            elapsed_ms: 0,
        }
    }

    #[test]
    fn hybrid_search_response_empty_serializes_correct_fields() {
        let resp = empty_response(10, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"results\""), "must contain results field");
        assert!(json.contains("\"query\""), "must contain query field");
        assert!(json.contains("\"k\""), "must contain k field");
        assert!(
            json.contains("\"graph_matches\""),
            "must contain graph_matches field"
        );
        assert!(
            !json.contains("\"combined_rank\""),
            "must not contain combined_rank"
        );
        assert!(
            !json.contains("\"vec_rank_list\""),
            "must not contain vec_rank_list"
        );
        assert!(
            !json.contains("\"fts_rank_list\""),
            "must not contain fts_rank_list"
        );
    }

    #[test]
    fn hybrid_search_response_serializes_rrf_k_and_weights() {
        let resp = empty_response(5, 60, 0.7, 0.3);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"rrf_k\""), "must contain rrf_k field");
        assert!(json.contains("\"weights\""), "must contain weights field");
        assert!(json.contains("\"vec\""), "must contain weights.vec field");
        assert!(json.contains("\"fts\""), "must contain weights.fts field");
        let value = serde_json::to_value(&resp).unwrap();
        assert_eq!(value["rrf_k"], 60, "rrf_k must carry the configured value");
    }

    #[test]
    fn hybrid_search_response_serializes_elapsed_ms() {
        let mut resp = empty_response(5, 60, 1.0, 1.0);
        resp.elapsed_ms = 123;
        let json = serde_json::to_value(&resp).unwrap();
        assert!(
            json.get("elapsed_ms").is_some(),
            "must contain elapsed_ms field"
        );
        // Compare the field itself: a substring search for "123" could match any
        // other value in the envelope.
        assert_eq!(json["elapsed_ms"], 123, "deve serializar valor de elapsed_ms");
    }

    #[test]
    fn weights_struct_serializes_correctly() {
        let w = Weights { vec: 0.6, fts: 0.4 };
        let json = serde_json::to_string(&w).unwrap();
        assert!(json.contains("\"vec\""));
        assert!(json.contains("\"fts\""));
    }

    #[test]
    fn hybrid_search_item_omits_fts_rank_when_none() {
        let item = HybridSearchItem {
            memory_id: 1,
            name: "mem".to_string(),
            namespace: "default".to_string(),
            memory_type: "user".to_string(),
            description: "desc".to_string(),
            body: "content".to_string(),
            combined_score: 0.0328,
            score: 0.0328,
            source: "hybrid".to_string(),
            vec_rank: Some(1),
            fts_rank: None,
            rrf_score: Some(0.0328),
            normalized_score: 1.0,
            vec_distance: Some(0.12),
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            json.contains("\"vec_rank\""),
            "must contain vec_rank when Some"
        );
        assert!(
            !json.contains("\"fts_rank\""),
            "must not contain fts_rank when None"
        );
    }

    #[test]
    fn hybrid_search_item_omits_vec_rank_when_none() {
        let item = HybridSearchItem {
            memory_id: 2,
            name: "mem2".to_string(),
            namespace: "default".to_string(),
            memory_type: "fact".to_string(),
            description: "desc2".to_string(),
            body: "corpo2".to_string(),
            combined_score: 0.016,
            score: 0.016,
            source: "hybrid".to_string(),
            vec_rank: None,
            fts_rank: Some(2),
            rrf_score: Some(0.016),
            normalized_score: 0.5,
            vec_distance: None,
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            !json.contains("\"vec_rank\""),
            "must not contain vec_rank when None"
        );
        assert!(
            json.contains("\"fts_rank\""),
            "must contain fts_rank when Some"
        );
    }

    #[test]
    fn hybrid_search_item_omits_unset_distance_and_bm25() {
        let item = HybridSearchItem {
            memory_id: 2,
            name: "mem2".to_string(),
            namespace: "default".to_string(),
            memory_type: "fact".to_string(),
            description: "desc2".to_string(),
            body: "body2".to_string(),
            combined_score: 0.016,
            score: 0.016,
            source: "hybrid".to_string(),
            vec_rank: None,
            fts_rank: Some(2),
            rrf_score: Some(0.016),
            normalized_score: 0.5,
            vec_distance: None,
            fts_bm25: None,
        };
        let json = serde_json::to_value(&item).unwrap();
        assert!(
            json.get("vec_distance").is_none(),
            "must not contain vec_distance when None"
        );
        assert!(
            json.get("fts_bm25").is_none(),
            "must not contain fts_bm25 when None"
        );
        assert!(
            json.get("rrf_score").is_some(),
            "must contain rrf_score when Some"
        );
        assert_eq!(json["memory_id"], 2);
        assert_eq!(json["source"], "hybrid");
        assert_eq!(
            json["score"], json["combined_score"],
            "score must mirror combined_score"
        );
    }

    #[test]
    fn hybrid_search_item_serializes_both_ranks_when_some() {
        let item = HybridSearchItem {
            memory_id: 3,
            name: "mem3".to_string(),
            namespace: "ns".to_string(),
            memory_type: "entity".to_string(),
            description: "desc3".to_string(),
            body: "corpo3".to_string(),
            combined_score: 0.05,
            score: 0.05,
            source: "hybrid".to_string(),
            vec_rank: Some(3),
            fts_rank: Some(1),
            rrf_score: Some(0.05),
            normalized_score: 0.8,
            vec_distance: Some(0.25),
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(json.contains("\"vec_rank\""), "must contain vec_rank");
        assert!(json.contains("\"fts_rank\""), "must contain fts_rank");
        assert!(json.contains("\"type\""), "deve serializar type renomeado");
        assert!(!json.contains("memory_type"), "must not expose memory_type");
    }

    #[test]
    fn hybrid_search_response_serializes_k_correctly() {
        let resp = empty_response(5, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"k\":5"), "deve serializar k=5");
    }

    #[test]
    fn hybrid_search_response_with_graph_matches() {
        use crate::output::RecallItem;
        let resp = HybridSearchResponse {
            query: "test".to_string(),
            k: 5,
            rrf_k: 60,
            weights: Weights { vec: 1.0, fts: 1.0 },
            results: vec![],
            graph_matches: vec![RecallItem {
                memory_id: 1,
                name: "graph-hit".to_string(),
                namespace: "global".to_string(),
                memory_type: "document".to_string(),
                description: "found via graph".to_string(),
                snippet: "graph content".to_string(),
                distance: 0.1,
                score: 0.9,
                source: "graph".to_string(),
                graph_depth: Some(1),
            }],
            fts_degraded: false,
            fts_error: None,
            fts_auto_rebuilt: false,
            elapsed_ms: 42,
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["graph_matches"].as_array().unwrap().len(), 1);
        assert_eq!(json["graph_matches"][0]["source"], "graph");
        assert_eq!(json["graph_matches"][0]["graph_depth"], 1);
    }

    #[test]
    fn fts_degraded_omitted_on_success_present_on_failure() {
        // Happy path: fts_degraded=false must be absent from JSON (skip_serializing_if).
        let ok_resp = empty_response(5, 60, 1.0, 1.0);
        let ok_json = serde_json::to_string(&ok_resp).unwrap();
        assert!(
            !ok_json.contains("\"fts_degraded\""),
            "fts_degraded must be absent when false"
        );
        assert!(
            !ok_json.contains("\"fts_error\""),
            "fts_error must be absent when None"
        );

        // Degraded path: fts_degraded=true and fts_error=Some must appear in JSON.
        let mut degraded_resp = empty_response(5, 60, 1.0, 1.0);
        degraded_resp.fts_degraded = true;
        degraded_resp.fts_error = Some("FTS5 table corrupted".to_string());
        let degraded_json = serde_json::to_string(&degraded_resp).unwrap();
        assert!(
            degraded_json.contains("\"fts_degraded\":true"),
            "fts_degraded must be present and true when degraded"
        );
        assert!(
            degraded_json.contains("\"fts_error\""),
            "fts_error must be present when Some"
        );
        assert!(
            degraded_json.contains("FTS5 table corrupted"),
            "fts_error must contain the error message"
        );
    }

    #[test]
    fn fts_auto_rebuilt_omitted_when_false_present_when_true() {
        let ok_json = serde_json::to_string(&empty_response(5, 60, 1.0, 1.0)).unwrap();
        assert!(
            !ok_json.contains("\"fts_auto_rebuilt\""),
            "fts_auto_rebuilt must be absent when false"
        );

        let mut rebuilt = empty_response(5, 60, 1.0, 1.0);
        rebuilt.fts_auto_rebuilt = true;
        let rebuilt_json = serde_json::to_value(&rebuilt).unwrap();
        assert_eq!(
            rebuilt_json["fts_auto_rebuilt"], true,
            "fts_auto_rebuilt must be present and true after a rebuild"
        );
    }
}