Skip to main content

sqlite_graphrag/commands/
hybrid_search.rs

1//! Handler for the `hybrid-search` CLI subcommand.
2
3use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::graph::traverse_from_memories_with_hops;
6use crate::output::{self, JsonOutputFormat, RecallItem};
7use crate::paths::AppPaths;
8use crate::storage::connection::open_ro;
9use crate::storage::entities;
10use crate::storage::memories;
11
12use std::collections::HashMap;
13
14/// Arguments for the `hybrid-search` subcommand.
15///
16/// When `--namespace` is omitted the search runs against the `global` namespace,
17/// which is the default namespace used by `remember` when no `--namespace` flag
18/// is provided. Pass an explicit `--namespace` value to search a different
19/// isolated namespace.
20#[derive(clap::Args)]
21#[command(after_long_help = "EXAMPLES:\n  \
22    # Basic hybrid search combining FTS5 + vector via RRF\n  \
23    sqlite-graphrag hybrid-search \"postgres migration deadlock\" --k 10\n\n  \
24    # Tune RRF weights to favor keyword matches over semantic similarity\n  \
25    sqlite-graphrag hybrid-search \"jwt auth\" --weight-fts 1.5 --weight-vec 0.5 --k 5\n\n  \
26    # Add graph traversal matches (entities connected to top results)\n  \
27    sqlite-graphrag hybrid-search \"frontend architecture\" --with-graph --k 10\n\n  \
28    # Graph traversal with custom depth and minimum edge weight\n  \
29    sqlite-graphrag hybrid-search \"auth design\" --with-graph --max-hops 3 --min-weight 0.5 --k 10\n\n  \
30NOTES:\n  \
31    --with-graph enables entity graph traversal seeded by the top RRF results.\n  \
32    Graph matches appear in the `graph_matches` array (separate from `results`).\n  \
33    Without --with-graph, `graph_matches` is always empty.")]
34pub struct HybridSearchArgs {
35    #[arg(help = "Hybrid search query (vector KNN + FTS5 BM25 fused via RRF)")]
36    pub query: String,
37    /// Maximum number of fused results to return after RRF combines vector + FTS5 candidates.
38    ///
39    /// Validated to the inclusive range `1..=4096` (the upper bound matches `sqlite-vec`'s knn
40    /// limit). Each underlying search fetches `k * 2` candidates before fusion.
41    #[arg(short = 'k', long, alias = "limit", default_value = "10", value_parser = crate::parsers::parse_k_range)]
42    pub k: usize,
43    #[arg(long, default_value = "60")]
44    pub rrf_k: u32,
45    #[arg(long, default_value = "1.0")]
46    pub weight_vec: f32,
47    #[arg(long, default_value = "1.0")]
48    pub weight_fts: f32,
49    /// Filter by memory.type. Note: distinct from graph entity_type
50    /// (project/tool/person/file/concept/incident/decision/memory/dashboard/issue_tracker/organization/location/date)
51    /// used in --entities-file.
52    #[arg(long, value_enum)]
53    pub r#type: Option<MemoryType>,
54    #[arg(long)]
55    pub namespace: Option<String>,
56    #[arg(long)]
57    pub with_graph: bool,
58    #[arg(long, default_value = "2")]
59    pub max_hops: u32,
60    #[arg(long, default_value = "0.3")]
61    pub min_weight: f64,
62    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
63    pub format: JsonOutputFormat,
64    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
65    pub db: Option<String>,
66    /// Accept `--json` as a no-op because output is already JSON by default.
67    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
68    pub json: bool,
69    #[command(flatten)]
70    pub daemon: crate::cli::DaemonOpts,
71}
72
73#[derive(serde::Serialize)]
74pub struct HybridSearchItem {
75    pub memory_id: i64,
76    pub name: String,
77    pub namespace: String,
78    #[serde(rename = "type")]
79    pub memory_type: String,
80    pub description: String,
81    pub body: String,
82    pub combined_score: f64,
83    /// Alias of `combined_score` for the documented contract in SKILL.md.
84    pub score: f64,
85    /// Source of the match: always "hybrid" (RRF of vec + fts). Added in v2.0.1.
86    pub source: String,
87    #[serde(skip_serializing_if = "Option::is_none")]
88    pub vec_rank: Option<usize>,
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub fts_rank: Option<usize>,
91    /// Combined RRF score — explicit alias of `combined_score` for integration contracts.
92    #[serde(skip_serializing_if = "Option::is_none")]
93    pub rrf_score: Option<f64>,
94}
95
96/// RRF weights used in hybrid search: vec (vector) and fts (text).
97#[derive(serde::Serialize)]
98pub struct Weights {
99    pub vec: f32,
100    pub fts: f32,
101}
102
103#[derive(serde::Serialize)]
104pub struct HybridSearchResponse {
105    pub query: String,
106    pub k: usize,
107    /// RRF k parameter used in the combined ranking.
108    pub rrf_k: u32,
109    /// Pesos aplicados às fontes vec e fts no RRF.
110    pub weights: Weights,
111    pub results: Vec<HybridSearchItem>,
112    pub graph_matches: Vec<RecallItem>,
113    /// Total execution time in milliseconds from handler start to serialisation.
114    pub elapsed_ms: u64,
115}
116
117pub fn run(args: HybridSearchArgs) -> Result<(), AppError> {
118    let start = std::time::Instant::now();
119    let _ = args.format;
120
121    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
122    let paths = AppPaths::resolve(args.db.as_deref())?;
123    crate::storage::connection::ensure_db_ready(&paths)?;
124
125    output::emit_progress_i18n(
126        "Computing query embedding...",
127        "Calculando embedding da consulta...",
128    );
129    let embedding = crate::daemon::embed_query_or_local(
130        &paths.models,
131        &args.query,
132        args.daemon.autostart_daemon,
133    )?;
134
135    let conn = open_ro(&paths.db)?;
136
137    let memory_type_str = args.r#type.map(|t| t.as_str());
138
139    let vec_results = memories::knn_search(
140        &conn,
141        &embedding,
142        &[namespace.clone()],
143        memory_type_str,
144        args.k * 2,
145    )?;
146
147    // Map vector ranking position by memory_id (1-indexed per schema)
148    let vec_rank_map: HashMap<i64, usize> = vec_results
149        .iter()
150        .enumerate()
151        .map(|(pos, (id, _))| (*id, pos + 1))
152        .collect();
153
154    let fts_results =
155        memories::fts_search(&conn, &args.query, &namespace, memory_type_str, args.k * 2)?;
156
157    // Map FTS ranking position by memory_id (1-indexed per schema)
158    let fts_rank_map: HashMap<i64, usize> = fts_results
159        .iter()
160        .enumerate()
161        .map(|(pos, row)| (row.id, pos + 1))
162        .collect();
163
164    let rrf_k = args.rrf_k as f64;
165
166    // Accumulate combined RRF scores
167    let mut combined_scores: HashMap<i64, f64> = HashMap::new();
168
169    for (rank, (memory_id, _)) in vec_results.iter().enumerate() {
170        let score = args.weight_vec as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
171        *combined_scores.entry(*memory_id).or_insert(0.0) += score;
172    }
173
174    for (rank, row) in fts_results.iter().enumerate() {
175        let score = args.weight_fts as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
176        *combined_scores.entry(row.id).or_insert(0.0) += score;
177    }
178
179    // Sort by score descending and take the top-k
180    let mut ranked: Vec<(i64, f64)> = combined_scores.into_iter().collect();
181    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
182    ranked.truncate(args.k);
183
184    // Collect all IDs for batch fetch (avoiding N+1)
185    let top_ids: Vec<i64> = ranked.iter().map(|(id, _)| *id).collect();
186
187    // Fetch full data for the top memories
188    let mut memory_data: HashMap<i64, memories::MemoryRow> = HashMap::new();
189    for id in &top_ids {
190        if let Some(row) = memories::read_full(&conn, *id)? {
191            memory_data.insert(*id, row);
192        }
193    }
194
195    // Construir resultados finais na ordem de ranking
196    let results: Vec<HybridSearchItem> = ranked
197        .into_iter()
198        .filter_map(|(memory_id, combined_score)| {
199            memory_data.remove(&memory_id).map(|row| HybridSearchItem {
200                memory_id: row.id,
201                name: row.name,
202                namespace: row.namespace,
203                memory_type: row.memory_type,
204                description: row.description,
205                body: row.body,
206                combined_score,
207                score: combined_score,
208                source: "hybrid".to_string(),
209                vec_rank: vec_rank_map.get(&memory_id).copied(),
210                fts_rank: fts_rank_map.get(&memory_id).copied(),
211                rrf_score: Some(combined_score),
212            })
213        })
214        .collect();
215
216    // --- Graph traversal (activated by --with-graph) ---
217    let mut graph_matches: Vec<RecallItem> = Vec::new();
218    if args.with_graph && !results.is_empty() {
219        let namespace_for_graph = namespace.clone();
220        let memory_ids: Vec<i64> = results.iter().map(|r| r.memory_id).collect();
221
222        let entity_knn = entities::knn_search(&conn, &embedding, &namespace_for_graph, 5)?;
223        let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
224
225        let all_seed_ids: Vec<i64> = memory_ids
226            .iter()
227            .chain(entity_ids.iter())
228            .copied()
229            .collect();
230
231        if !all_seed_ids.is_empty() {
232            let graph_memory_ids = traverse_from_memories_with_hops(
233                &conn,
234                &all_seed_ids,
235                &namespace_for_graph,
236                args.min_weight,
237                args.max_hops,
238            )?;
239
240            let already_in_results: std::collections::HashSet<i64> =
241                results.iter().map(|r| r.memory_id).collect();
242
243            for (graph_mem_id, hop) in graph_memory_ids {
244                if already_in_results.contains(&graph_mem_id) {
245                    continue;
246                }
247                if let Some(row) = memories::read_full(&conn, graph_mem_id)? {
248                    let snippet: String = row.body.chars().take(300).collect();
249                    let graph_distance = 1.0 - 1.0 / (hop as f32 + 1.0);
250                    graph_matches.push(RecallItem {
251                        memory_id: row.id,
252                        name: row.name,
253                        namespace: row.namespace,
254                        memory_type: row.memory_type,
255                        description: row.description,
256                        snippet,
257                        distance: graph_distance,
258                        score: RecallItem::score_from_distance(graph_distance),
259                        source: "graph".to_string(),
260                        graph_depth: Some(hop),
261                    });
262                }
263            }
264        }
265    }
266
267    output::emit_json(&HybridSearchResponse {
268        query: args.query,
269        k: args.k,
270        rrf_k: args.rrf_k,
271        weights: Weights {
272            vec: args.weight_vec,
273            fts: args.weight_fts,
274        },
275        results,
276        graph_matches,
277        elapsed_ms: start.elapsed().as_millis() as u64,
278    })?;
279
280    Ok(())
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a response with no `results` and no `graph_matches` for serialization tests.
    fn empty_response(
        k: usize,
        rrf_k: u32,
        weight_vec: f32,
        weight_fts: f32,
    ) -> HybridSearchResponse {
        HybridSearchResponse {
            query: "busca teste".to_string(),
            k,
            rrf_k,
            weights: Weights {
                vec: weight_vec,
                fts: weight_fts,
            },
            results: vec![],
            graph_matches: vec![],
            elapsed_ms: 0,
        }
    }

    #[test]
    fn hybrid_search_response_empty_serializes_correct_fields() {
        let resp = empty_response(10, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"results\""), "must contain results field");
        assert!(json.contains("\"query\""), "must contain query field");
        assert!(json.contains("\"k\""), "must contain k field");
        assert!(
            json.contains("\"graph_matches\""),
            "must contain graph_matches field"
        );
        assert!(
            !json.contains("\"combined_rank\""),
            "must not contain combined_rank"
        );
        assert!(
            !json.contains("\"vec_rank_list\""),
            "must not contain vec_rank_list"
        );
        assert!(
            !json.contains("\"fts_rank_list\""),
            "must not contain fts_rank_list"
        );
    }

    #[test]
    fn hybrid_search_response_serializes_rrf_k_and_weights() {
        let resp = empty_response(5, 60, 0.7, 0.3);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"rrf_k\""), "must contain rrf_k field");
        assert!(json.contains("\"weights\""), "must contain weights field");
        assert!(json.contains("\"vec\""), "must contain weights.vec field");
        assert!(json.contains("\"fts\""), "must contain weights.fts field");

        // Check the serialized values as well, not only the key names.
        let value = serde_json::to_value(&resp).unwrap();
        assert_eq!(value["rrf_k"], 60);
        // Round-trip through f32 so the comparison is independent of how f32 widens to f64.
        assert_eq!(value["weights"]["vec"].as_f64().unwrap() as f32, 0.7_f32);
        assert_eq!(value["weights"]["fts"].as_f64().unwrap() as f32, 0.3_f32);
    }

    #[test]
    fn hybrid_search_response_serializes_elapsed_ms() {
        let mut resp = empty_response(5, 60, 1.0, 1.0);
        resp.elapsed_ms = 123;
        let json = serde_json::to_string(&resp).unwrap();
        assert!(
            json.contains("\"elapsed_ms\""),
            "must contain elapsed_ms field"
        );
        // Pin the value to the key so an unrelated "123" elsewhere cannot satisfy the check.
        assert!(
            json.contains("\"elapsed_ms\":123"),
            "deve serializar valor de elapsed_ms"
        );
    }

    #[test]
    fn weights_struct_serializes_correctly() {
        let w = Weights { vec: 0.6, fts: 0.4 };
        let json = serde_json::to_string(&w).unwrap();
        assert!(json.contains("\"vec\""));
        assert!(json.contains("\"fts\""));
    }

    #[test]
    fn hybrid_search_item_omits_fts_rank_when_none() {
        let item = HybridSearchItem {
            memory_id: 1,
            name: "mem".to_string(),
            namespace: "default".to_string(),
            memory_type: "user".to_string(),
            description: "desc".to_string(),
            body: "content".to_string(),
            combined_score: 0.0328,
            score: 0.0328,
            source: "hybrid".to_string(),
            vec_rank: Some(1),
            fts_rank: None,
            rrf_score: Some(0.0328),
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            json.contains("\"vec_rank\""),
            "must contain vec_rank when Some"
        );
        assert!(
            !json.contains("\"fts_rank\""),
            "must not contain fts_rank when None"
        );
    }

    #[test]
    fn hybrid_search_item_omits_vec_rank_when_none() {
        let item = HybridSearchItem {
            memory_id: 2,
            name: "mem2".to_string(),
            namespace: "default".to_string(),
            memory_type: "fact".to_string(),
            description: "desc2".to_string(),
            body: "corpo2".to_string(),
            combined_score: 0.016,
            score: 0.016,
            source: "hybrid".to_string(),
            vec_rank: None,
            fts_rank: Some(2),
            rrf_score: Some(0.016),
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            !json.contains("\"vec_rank\""),
            "must not contain vec_rank when None"
        );
        assert!(
            json.contains("\"fts_rank\""),
            "must contain fts_rank when Some"
        );
    }

    #[test]
    fn hybrid_search_item_serializes_both_ranks_when_some() {
        let item = HybridSearchItem {
            memory_id: 3,
            name: "mem3".to_string(),
            namespace: "ns".to_string(),
            memory_type: "entity".to_string(),
            description: "desc3".to_string(),
            body: "corpo3".to_string(),
            combined_score: 0.05,
            score: 0.05,
            source: "hybrid".to_string(),
            vec_rank: Some(3),
            fts_rank: Some(1),
            rrf_score: Some(0.05),
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(json.contains("\"vec_rank\""), "must contain vec_rank");
        assert!(json.contains("\"fts_rank\""), "must contain fts_rank");
        assert!(json.contains("\"type\""), "deve serializar type renomeado");
        assert!(!json.contains("memory_type"), "must not expose memory_type");
    }

    #[test]
    fn hybrid_search_item_score_aliases_match_combined_score() {
        let item = HybridSearchItem {
            memory_id: 4,
            name: "mem4".to_string(),
            namespace: "global".to_string(),
            memory_type: "note".to_string(),
            description: "desc4".to_string(),
            body: "body4".to_string(),
            combined_score: 0.25,
            score: 0.25,
            source: "hybrid".to_string(),
            vec_rank: Some(1),
            fts_rank: Some(1),
            rrf_score: Some(0.25),
        };
        let json = serde_json::to_value(&item).unwrap();
        assert_eq!(json["combined_score"], 0.25);
        assert_eq!(json["score"], json["combined_score"]);
        assert_eq!(json["rrf_score"], json["combined_score"]);
        assert_eq!(json["source"], "hybrid");
        assert_eq!(json["type"], "note");
    }

    #[test]
    fn hybrid_search_response_serializes_k_correctly() {
        let resp = empty_response(5, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"k\":5"), "deve serializar k=5");
    }

    #[test]
    fn hybrid_search_response_with_graph_matches() {
        use crate::output::RecallItem;
        let resp = HybridSearchResponse {
            query: "test".to_string(),
            k: 5,
            rrf_k: 60,
            weights: Weights { vec: 1.0, fts: 1.0 },
            results: vec![],
            graph_matches: vec![RecallItem {
                memory_id: 1,
                name: "graph-hit".to_string(),
                namespace: "global".to_string(),
                memory_type: "document".to_string(),
                description: "found via graph".to_string(),
                snippet: "graph content".to_string(),
                distance: 0.1,
                score: 0.9,
                source: "graph".to_string(),
                graph_depth: Some(1),
            }],
            elapsed_ms: 42,
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["graph_matches"].as_array().unwrap().len(), 1);
        assert_eq!(json["graph_matches"][0]["source"], "graph");
        assert_eq!(json["graph_matches"][0]["graph_depth"], 1);
    }
}