1use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::graph::traverse_from_memories_with_hops;
6use crate::output::{self, JsonOutputFormat, RecallItem};
7use crate::paths::AppPaths;
8use crate::storage::connection::open_ro;
9use crate::storage::entities;
10use crate::storage::memories;
11
12use std::collections::HashMap;
13
14#[derive(clap::Args)]
21#[command(after_long_help = "EXAMPLES:\n \
22 # Basic hybrid search combining FTS5 + vector via RRF\n \
23 sqlite-graphrag hybrid-search \"postgres migration deadlock\" --k 10\n\n \
24 # Tune RRF weights to favor keyword matches over semantic similarity\n \
25 sqlite-graphrag hybrid-search \"jwt auth\" --weight-fts 1.5 --weight-vec 0.5 --k 5\n\n \
26 # Add graph traversal matches (entities connected to top results)\n \
27 sqlite-graphrag hybrid-search \"frontend architecture\" --with-graph --k 10\n\n \
28 # Graph traversal with custom depth and minimum edge weight\n \
29 sqlite-graphrag hybrid-search \"auth design\" --with-graph --max-hops 3 --min-weight 0.5 --k 10\n\n \
30NOTES:\n \
31 --with-graph enables entity graph traversal seeded by the top RRF results.\n \
32 Graph matches appear in the `graph_matches` array (separate from `results`).\n \
33 Without --with-graph, `graph_matches` is always empty.")]
34pub struct HybridSearchArgs {
35 #[arg(
36 allow_hyphen_values = true,
37 help = "Hybrid search query (vector KNN + FTS5 BM25 fused via RRF)"
38 )]
39 pub query: String,
40 #[arg(short = 'k', long, aliases = ["limit", "top-k"], default_value = "10", value_parser = crate::parsers::parse_k_range)]
45 pub k: usize,
46 #[arg(long, default_value = "60")]
47 pub rrf_k: u32,
48 #[arg(long, default_value = "1.0")]
49 pub weight_vec: f32,
50 #[arg(long, default_value = "1.0")]
51 pub weight_fts: f32,
52 #[arg(long, value_enum)]
56 pub r#type: Option<MemoryType>,
57 #[arg(long)]
58 pub namespace: Option<String>,
59 #[arg(long)]
60 pub with_graph: bool,
61 #[arg(long, help = "Skip live query embedding; serve FTS5 BM25 only")]
64 pub fallback_fts_only: bool,
65 #[arg(long)]
67 pub max_hops: Option<u32>,
68 #[arg(long)]
70 pub min_weight: Option<f64>,
71 #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
72 pub format: JsonOutputFormat,
73 #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
74 pub db: Option<String>,
75 #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
77 pub json: bool,
78}
79
80#[derive(serde::Serialize)]
81pub struct HybridSearchItem {
82 pub memory_id: i64,
83 pub name: String,
84 pub namespace: String,
85 #[serde(rename = "type")]
86 pub memory_type: String,
87 pub description: String,
88 pub body: String,
89 pub snippet: String,
90 pub combined_score: f64,
91 pub score: f64,
93 pub source: String,
95 #[serde(skip_serializing_if = "Option::is_none")]
96 pub vec_rank: Option<usize>,
97 #[serde(skip_serializing_if = "Option::is_none")]
98 pub fts_rank: Option<usize>,
99 #[serde(skip_serializing_if = "Option::is_none")]
101 pub rrf_score: Option<f64>,
102 pub normalized_score: f64,
104 #[serde(skip_serializing_if = "Option::is_none")]
109 pub vec_distance: Option<f64>,
110 #[serde(skip_serializing_if = "Option::is_none")]
113 pub fts_bm25: Option<f64>,
114}
115
116#[derive(serde::Serialize)]
118pub struct Weights {
119 pub vec: f32,
120 pub fts: f32,
121}
122
123#[derive(serde::Serialize)]
124pub struct HybridSearchResponse {
125 pub query: String,
126 pub k: usize,
127 pub rrf_k: u32,
129 pub weights: Weights,
131 pub results: Vec<HybridSearchItem>,
132 pub graph_matches: Vec<RecallItem>,
133 #[serde(skip_serializing_if = "std::ops::Not::not")]
137 pub fts_degraded: bool,
138 #[serde(skip_serializing_if = "Option::is_none")]
142 pub fts_error: Option<String>,
143 #[serde(skip_serializing_if = "std::ops::Not::not")]
147 pub fts_auto_rebuilt: bool,
148 #[serde(skip_serializing_if = "std::ops::Not::not", default)]
152 pub vec_degraded: bool,
153 #[serde(skip_serializing_if = "Option::is_none")]
157 pub vec_error: Option<String>,
158 #[serde(skip_serializing_if = "Option::is_none")]
162 pub warning: Option<String>,
163 #[serde(skip_serializing_if = "Option::is_none")]
167 pub backend_invoked: Option<&'static str>,
168 #[serde(skip_serializing_if = "Option::is_none")]
172 pub vec_degraded_reason: Option<String>,
173 pub elapsed_ms: u64,
175}
176
177#[tracing::instrument(skip_all, level = "debug", name = "hybrid_search")]
178pub fn run(
179 args: HybridSearchArgs,
180 llm_backend: crate::cli::LlmBackendChoice,
181) -> Result<(), AppError> {
182 let start = std::time::Instant::now();
183 let _ = args.format;
184 tracing::debug!(target: "hybrid_search", query = %args.query, k = args.k, "fusing results");
185
186 if !args.with_graph {
190 if args.max_hops.is_some() {
191 return Err(AppError::Validation(
192 "--max-hops requires --with-graph to be active".to_string(),
193 ));
194 }
195 if args.min_weight.is_some() {
196 return Err(AppError::Validation(
197 "--min-weight requires --with-graph to be active".to_string(),
198 ));
199 }
200 }
201
202 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
203 let paths = AppPaths::resolve(args.db.as_deref())?;
204 crate::storage::connection::ensure_db_ready(&paths)?;
205
206 output::emit_progress_i18n(
207 "Computing query embedding...",
208 "Calculando embedding da consulta...",
209 );
210 let conn = open_ro(&paths.db)?;
211 let (embedding, vec_degraded, vec_error, backend_invoked) = if args.fallback_fts_only {
219 (
220 None,
221 true,
222 Some("fallback_fts_only requested".to_string()),
223 None,
224 )
225 } else {
226 match crate::embedder::try_embed_query_with_deterministic_fallback(
233 &paths.models,
234 &args.query,
235 Some(llm_backend),
236 ) {
237 Ok((v, backend)) => (Some(v), false, None, Some(backend.as_str())),
238 Err(reason) => {
239 let msg = reason.to_string();
240 tracing::warn!(target: "hybrid_search", fallback_reason = %msg, reason_code = %reason.reason_code(), "live embedding failed; falling back to FTS5");
241 (None, true, Some(msg), None)
242 }
243 }
244 };
245
246 let memory_type_str = args.r#type.map(|t| t.as_str());
247
248 let vec_results: Vec<(i64, f32)> = if let Some(emb) = embedding.as_ref() {
249 memories::knn_search(
250 &conn,
251 emb,
252 std::slice::from_ref(&namespace),
253 memory_type_str,
254 args.k * 2,
255 )?
256 } else {
257 Vec::new()
258 };
259
260 let vec_rank_map: HashMap<i64, usize> = vec_results
262 .iter()
263 .enumerate()
264 .map(|(pos, (id, _))| (*id, pos + 1))
265 .collect();
266
267 let vec_distance_map: HashMap<i64, f64> = vec_results
269 .iter()
270 .map(|(id, dist)| (*id, *dist as f64))
271 .collect();
272
273 let (fts_results, fts_degraded, fts_error, fts_auto_rebuilt) = if args.weight_fts == 0.0 {
274 (vec![], false, None, false)
275 } else {
276 match memories::fts_search(&conn, &args.query, &namespace, memory_type_str, args.k * 2) {
277 Ok(r) => (r, false, None, false),
278 Err(e) => {
279 let err_msg = e.to_string();
280 let is_malformed = err_msg.contains("malformed") || err_msg.contains("corrupt");
281 if is_malformed {
282 tracing::warn!(target: "hybrid_search", "FTS5 index corrupted, attempting auto-rebuild");
283 if conn
284 .execute_batch("INSERT INTO fts_memories(fts_memories) VALUES('rebuild');")
285 .is_ok()
286 {
287 match memories::fts_search(
288 &conn,
289 &args.query,
290 &namespace,
291 memory_type_str,
292 args.k * 2,
293 ) {
294 Ok(r) => (r, false, None, true),
295 Err(e2) => {
296 tracing::error!(target: "hybrid_search", error = %e2, "FTS5 auto-rebuild failed to recover");
297 (vec![], true, Some(e2.to_string()), true)
298 }
299 }
300 } else {
301 (vec![], true, Some(err_msg), false)
302 }
303 } else {
304 tracing::warn!(target: "hybrid_search", error = %e, "FTS5 query failed, falling back to vec-only");
305 (vec![], true, Some(err_msg), false)
306 }
307 }
308 }
309 };
310
311 let fts_rank_map: HashMap<i64, usize> = fts_results
313 .iter()
314 .enumerate()
315 .map(|(pos, row)| (row.id, pos + 1))
316 .collect();
317
318 let rrf_k = args.rrf_k as f64;
319
320 let mut combined_scores: crate::hash::AHashMap<i64, f64> =
322 crate::hash::AHashMap::with_capacity_and_hasher(
323 vec_results.len() + fts_results.len(),
324 Default::default(),
325 );
326
327 for (rank, (memory_id, _)) in vec_results.iter().enumerate() {
328 let score = args.weight_vec as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
329 *combined_scores.entry(*memory_id).or_insert(0.0) += score;
330 }
331
332 for (rank, row) in fts_results.iter().enumerate() {
333 let score = args.weight_fts as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
334 *combined_scores.entry(row.id).or_insert(0.0) += score;
335 }
336
337 let mut ranked: Vec<(i64, f64)> = combined_scores.into_iter().collect();
339 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
340 ranked.truncate(args.k);
341
342 let top_ids: Vec<i64> = ranked.iter().map(|(id, _)| *id).collect();
344
345 let mut memory_data: crate::hash::AHashMap<i64, memories::MemoryRow> =
347 crate::hash::AHashMap::with_capacity_and_hasher(ranked.len(), Default::default());
348 for id in &top_ids {
349 if let Some(row) = memories::read_full(&conn, *id)? {
350 memory_data.insert(*id, row);
351 }
352 }
353
354 let max_possible = args.weight_vec as f64 * (1.0 / (rrf_k + 1.0))
355 + args.weight_fts as f64 * (1.0 / (rrf_k + 1.0));
356
357 let results: Vec<HybridSearchItem> = ranked
359 .into_iter()
360 .filter_map(|(memory_id, combined_score)| {
361 let normalized_score = if max_possible > 0.0 {
362 combined_score / max_possible
363 } else {
364 0.0
365 };
366 memory_data.remove(&memory_id).map(|row| {
367 let snippet: String = row.body.chars().take(300).collect();
368 HybridSearchItem {
369 memory_id: row.id,
370 name: row.name,
371 namespace: row.namespace,
372 memory_type: row.memory_type,
373 description: row.description,
374 body: row.body,
375 snippet,
376 combined_score,
377 score: combined_score,
378 source: "hybrid".to_string(),
379 vec_rank: vec_rank_map.get(&memory_id).copied(),
380 fts_rank: fts_rank_map.get(&memory_id).copied(),
381 rrf_score: Some(combined_score),
382 normalized_score,
383 vec_distance: vec_distance_map.get(&memory_id).copied(),
384 fts_bm25: None,
385 }
386 })
387 })
388 .collect();
389
390 let mut graph_matches: Vec<RecallItem> = Vec::with_capacity(8);
392 if let Some(emb) = args
393 .with_graph
394 .then_some(())
395 .filter(|_| !results.is_empty())
396 .and(embedding.as_ref())
397 {
398 let namespace_for_graph = namespace.clone();
399 let memory_ids: Vec<i64> = results.iter().map(|r| r.memory_id).collect();
400
401 let entity_knn = entities::knn_search(&conn, emb, &namespace_for_graph, 5)?;
402 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
403
404 let all_seed_ids: Vec<i64> = memory_ids
405 .iter()
406 .chain(entity_ids.iter())
407 .copied()
408 .collect();
409
410 if !all_seed_ids.is_empty() {
411 let graph_memory_ids = traverse_from_memories_with_hops(
412 &conn,
413 &all_seed_ids,
414 &namespace_for_graph,
415 args.min_weight.unwrap_or(0.3),
416 args.max_hops.unwrap_or(2),
417 )?;
418
419 let already_in_results: std::collections::HashSet<i64> =
420 results.iter().map(|r| r.memory_id).collect();
421
422 for (graph_mem_id, hop) in graph_memory_ids {
423 if already_in_results.contains(&graph_mem_id) {
424 continue;
425 }
426 if let Some(row) = memories::read_full(&conn, graph_mem_id)? {
427 let snippet: String = row.body.chars().take(300).collect();
428 let graph_distance = 1.0 - 1.0 / (hop as f32 + 1.0);
429 graph_matches.push(RecallItem {
430 memory_id: row.id,
431 name: row.name,
432 namespace: row.namespace,
433 memory_type: row.memory_type,
434 description: row.description,
435 snippet,
436 distance: graph_distance,
437 score: RecallItem::score_from_distance(graph_distance),
438 source: "graph".to_string(),
439 graph_depth: Some(hop),
440 });
441 }
442 }
443 }
444 }
445
446 output::emit_json(&HybridSearchResponse {
447 query: args.query,
448 k: args.k,
449 rrf_k: args.rrf_k,
450 weights: Weights {
451 vec: args.weight_vec,
452 fts: args.weight_fts,
453 },
454 results,
455 graph_matches,
456 fts_degraded,
457 fts_error,
458 fts_auto_rebuilt,
459 vec_degraded,
460 vec_error: vec_error.clone(),
461 warning: if vec_degraded {
462 Some(
463 "live query embedding unavailable; results are FTS5 BM25 only (semantic relevance reduced)"
464 .to_string(),
465 )
466 } else {
467 None
468 },
469 backend_invoked,
470 vec_degraded_reason: if vec_degraded { vec_error } else { None },
471 elapsed_ms: start.elapsed().as_millis() as u64,
472 })?;
473
474 Ok(())
475}
476
#[cfg(test)]
mod tests {
    //! Unit tests for argument parsing and the JSON shape of the response types.

    use super::*;

    /// Minimal parser wrapper so `HybridSearchArgs` can be parsed on its own.
    #[derive(clap::Parser)]
    struct TestCli {
        #[command(flatten)]
        args: HybridSearchArgs,
    }

    #[test]
    fn graph_flags_parse_as_none_when_absent() {
        use clap::Parser;
        // Both graph-tuning flags are `Option`s so that `run` can tell an
        // omitted flag from an explicit value.
        let cli = TestCli::try_parse_from(["hybrid-search", "q"]).expect("bare query parses");
        assert!(cli.args.max_hops.is_none());
        assert!(cli.args.min_weight.is_none());
        let cli = TestCli::try_parse_from(["hybrid-search", "q", "--max-hops", "2"])
            .expect("explicit flag parses");
        assert_eq!(cli.args.max_hops, Some(2));
    }

    /// Builds a response with no hits and every degradation flag cleared.
    fn empty_response(
        k: usize,
        rrf_k: u32,
        weight_vec: f32,
        weight_fts: f32,
    ) -> HybridSearchResponse {
        HybridSearchResponse {
            query: "test query".to_string(),
            k,
            rrf_k,
            weights: Weights {
                vec: weight_vec,
                fts: weight_fts,
            },
            results: vec![],
            graph_matches: vec![],
            fts_degraded: false,
            fts_error: None,
            fts_auto_rebuilt: false,
            vec_degraded: false,
            vec_error: None,
            warning: None,
            backend_invoked: None,
            vec_degraded_reason: None,
            elapsed_ms: 0,
        }
    }

    #[test]
    fn hybrid_search_response_empty_serializes_correct_fields() {
        let resp = empty_response(10, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"results\""), "must contain results field");
        assert!(json.contains("\"query\""), "must contain query field");
        assert!(json.contains("\"k\""), "must contain k field");
        assert!(
            json.contains("\"graph_matches\""),
            "must contain graph_matches field"
        );
        assert!(
            !json.contains("\"combined_rank\""),
            "must not contain combined_rank"
        );
        assert!(
            !json.contains("\"vec_rank_list\""),
            "must not contain vec_rank_list"
        );
        assert!(
            !json.contains("\"fts_rank_list\""),
            "must not contain fts_rank_list"
        );
    }

    #[test]
    fn hybrid_search_response_serializes_rrf_k_and_weights() {
        let resp = empty_response(5, 60, 0.7, 0.3);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"rrf_k\""), "must contain rrf_k field");
        assert!(json.contains("\"weights\""), "must contain weights field");
        assert!(json.contains("\"vec\""), "must contain weights.vec field");
        assert!(json.contains("\"fts\""), "must contain weights.fts field");
    }

    #[test]
    fn hybrid_search_response_serializes_elapsed_ms() {
        let mut resp = empty_response(5, 60, 1.0, 1.0);
        resp.elapsed_ms = 123;
        let json = serde_json::to_string(&resp).unwrap();
        assert!(
            json.contains("\"elapsed_ms\""),
            "must contain elapsed_ms field"
        );
        assert!(json.contains("123"), "deve serializar valor de elapsed_ms");
    }

    #[test]
    fn weights_struct_serializes_correctly() {
        let w = Weights { vec: 0.6, fts: 0.4 };
        let json = serde_json::to_string(&w).unwrap();
        assert!(json.contains("\"vec\""));
        assert!(json.contains("\"fts\""));
    }

    #[test]
    fn hybrid_search_item_omits_fts_rank_when_none() {
        let item = HybridSearchItem {
            memory_id: 1,
            name: "mem".to_string(),
            namespace: "default".to_string(),
            memory_type: "user".to_string(),
            description: "desc".to_string(),
            body: "content".to_string(),
            snippet: "content".to_string(),
            combined_score: 0.0328,
            score: 0.0328,
            source: "hybrid".to_string(),
            vec_rank: Some(1),
            fts_rank: None,
            rrf_score: Some(0.0328),
            normalized_score: 1.0,
            vec_distance: Some(0.12),
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            json.contains("\"vec_rank\""),
            "must contain vec_rank when Some"
        );
        assert!(
            !json.contains("\"fts_rank\""),
            "must not contain fts_rank when None"
        );
    }

    #[test]
    fn hybrid_search_item_omits_vec_rank_when_none() {
        let item = HybridSearchItem {
            memory_id: 2,
            name: "mem2".to_string(),
            namespace: "default".to_string(),
            memory_type: "fact".to_string(),
            description: "desc2".to_string(),
            body: "corpo2".to_string(),
            snippet: "corpo2".to_string(),
            combined_score: 0.016,
            score: 0.016,
            source: "hybrid".to_string(),
            vec_rank: None,
            fts_rank: Some(2),
            rrf_score: Some(0.016),
            normalized_score: 0.5,
            vec_distance: None,
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            !json.contains("\"vec_rank\""),
            "must not contain vec_rank when None"
        );
        assert!(
            json.contains("\"fts_rank\""),
            "must contain fts_rank when Some"
        );
    }

    #[test]
    fn hybrid_search_item_serializes_both_ranks_when_some() {
        let item = HybridSearchItem {
            memory_id: 3,
            name: "mem3".to_string(),
            namespace: "ns".to_string(),
            memory_type: "entity".to_string(),
            description: "desc3".to_string(),
            body: "corpo3".to_string(),
            snippet: "corpo3".to_string(),
            combined_score: 0.05,
            score: 0.05,
            source: "hybrid".to_string(),
            vec_rank: Some(3),
            fts_rank: Some(1),
            rrf_score: Some(0.05),
            normalized_score: 0.8,
            vec_distance: Some(0.25),
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(json.contains("\"vec_rank\""), "must contain vec_rank");
        assert!(json.contains("\"fts_rank\""), "must contain fts_rank");
        // The Rust field `memory_type` is exposed under the JSON key `type`.
        assert!(json.contains("\"type\""), "deve serializar type renomeado");
        assert!(!json.contains("memory_type"), "must not expose memory_type");
    }

    #[test]
    fn hybrid_search_response_serializes_k_correctly() {
        let resp = empty_response(5, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"k\":5"), "deve serializar k=5");
    }

    #[test]
    fn hybrid_search_response_with_graph_matches() {
        use crate::output::RecallItem;
        let resp = HybridSearchResponse {
            query: "test".to_string(),
            k: 5,
            rrf_k: 60,
            weights: Weights { vec: 1.0, fts: 1.0 },
            results: vec![],
            graph_matches: vec![RecallItem {
                memory_id: 1,
                name: "graph-hit".to_string(),
                namespace: "global".to_string(),
                memory_type: "document".to_string(),
                description: "found via graph".to_string(),
                snippet: "graph content".to_string(),
                distance: 0.1,
                score: 0.9,
                source: "graph".to_string(),
                graph_depth: Some(1),
            }],
            fts_degraded: false,
            fts_error: None,
            fts_auto_rebuilt: false,
            vec_degraded: false,
            vec_error: None,
            warning: None,
            backend_invoked: None,
            vec_degraded_reason: None,
            elapsed_ms: 42,
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["graph_matches"].as_array().unwrap().len(), 1);
        assert_eq!(json["graph_matches"][0]["source"], "graph");
        assert_eq!(json["graph_matches"][0]["graph_depth"], 1);
    }

    #[test]
    fn fts_degraded_omitted_on_success_present_on_failure() {
        // Healthy response: both FTS degradation keys are skipped.
        let ok_resp = empty_response(5, 60, 1.0, 1.0);
        let ok_json = serde_json::to_string(&ok_resp).unwrap();
        assert!(
            !ok_json.contains("\"fts_degraded\""),
            "fts_degraded must be absent when false"
        );
        assert!(
            !ok_json.contains("\"fts_error\""),
            "fts_error must be absent when None"
        );

        // Degraded response: the flag and the error message are both emitted.
        let mut degraded_resp = empty_response(5, 60, 1.0, 1.0);
        degraded_resp.fts_degraded = true;
        degraded_resp.fts_error = Some("FTS5 table corrupted".to_string());
        let degraded_json = serde_json::to_string(&degraded_resp).unwrap();
        assert!(
            degraded_json.contains("\"fts_degraded\":true"),
            "fts_degraded must be present and true when degraded"
        );
        assert!(
            degraded_json.contains("\"fts_error\""),
            "fts_error must be present when Some"
        );
        assert!(
            degraded_json.contains("FTS5 table corrupted"),
            "fts_error must contain the error message"
        );
    }
}