1use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::graph::traverse_from_memories_with_hops;
6use crate::output::{self, JsonOutputFormat, RecallItem};
7use crate::paths::AppPaths;
8use crate::storage::connection::open_ro;
9use crate::storage::entities;
10use crate::storage::memories;
11
12use std::collections::HashMap;
13
// CLI arguments for `sqlite-graphrag hybrid-search`.
//
// Plain `//` comments are used on purpose: clap's derive uses `///` doc
// comments as help/about text, so adding them would change `--help` output.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n \
    # Basic hybrid search combining FTS5 + vector via RRF\n \
    sqlite-graphrag hybrid-search \"postgres migration deadlock\" --k 10\n\n \
    # Tune RRF weights to favor keyword matches over semantic similarity\n \
    sqlite-graphrag hybrid-search \"jwt auth\" --weight-fts 1.5 --weight-vec 0.5 --k 5\n\n \
    # Add graph traversal matches (entities connected to top results)\n \
    sqlite-graphrag hybrid-search \"frontend architecture\" --with-graph --k 10\n\n \
    # Graph traversal with custom depth and minimum edge weight\n \
    sqlite-graphrag hybrid-search \"auth design\" --with-graph --max-hops 3 --min-weight 0.5 --k 10\n\n \
NOTES:\n \
    --with-graph enables entity graph traversal seeded by the top RRF results.\n \
    Graph matches appear in the `graph_matches` array (separate from `results`).\n \
    Without --with-graph, `graph_matches` is always empty.")]
pub struct HybridSearchArgs {
    // Free-text query; values starting with `-` are accepted as the query.
    #[arg(
        allow_hyphen_values = true,
        help = "Hybrid search query (vector KNN + FTS5 BM25 fused via RRF)"
    )]
    pub query: String,
    // Number of fused results to return. `run` asks each channel for `k * 2`
    // candidates before fusion. Range-checked by `crate::parsers::parse_k_range`.
    #[arg(short = 'k', long, aliases = ["limit", "top-k"], default_value = "10", value_parser = crate::parsers::parse_k_range)]
    pub k: usize,
    // RRF smoothing constant: a hit at 0-based rank `r` contributes
    // `weight * 1 / (rrf_k + r + 1)` to its fused score in `run`.
    #[arg(long, default_value = "60")]
    pub rrf_k: u32,
    // Multiplier on the vector (KNN) channel's RRF contributions.
    #[arg(long, default_value = "1.0")]
    pub weight_vec: f32,
    // Multiplier on the FTS5 channel's RRF contributions; `run` skips the FTS5
    // query entirely when this is exactly `0.0`.
    #[arg(long, default_value = "1.0")]
    pub weight_fts: f32,
    // Optional memory-type filter passed to both the KNN and the FTS5 queries.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // Namespace to search; resolved via `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // Adds entity-graph traversal hits to the separate `graph_matches` array.
    #[arg(long)]
    pub with_graph: bool,
    #[arg(long, help = "Skip live query embedding; serve FTS5 BM25 only")]
    pub fallback_fts_only: bool,
    // Graph traversal depth; rejected unless `--with-graph` is set
    // (`run` uses 2 when absent).
    #[arg(long)]
    pub max_hops: Option<u32>,
    // Minimum edge weight followed during traversal; rejected unless
    // `--with-graph` is set (`run` uses 0.3 when absent).
    #[arg(long)]
    pub min_weight: Option<f64>,
    // Accepted for CLI compatibility; `run` ignores it and always emits JSON.
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // Database path override; clap falls back to `SQLITE_GRAPHRAG_DB_PATH`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
}
79
80#[derive(serde::Serialize)]
81pub struct HybridSearchItem {
82 pub memory_id: i64,
83 pub name: String,
84 pub namespace: String,
85 #[serde(rename = "type")]
86 pub memory_type: String,
87 pub description: String,
88 pub body: String,
89 pub snippet: String,
90 pub combined_score: f64,
91 pub score: f64,
93 pub source: String,
95 #[serde(skip_serializing_if = "Option::is_none")]
96 pub vec_rank: Option<usize>,
97 #[serde(skip_serializing_if = "Option::is_none")]
98 pub fts_rank: Option<usize>,
99 #[serde(skip_serializing_if = "Option::is_none")]
101 pub rrf_score: Option<f64>,
102 pub normalized_score: f64,
104 #[serde(skip_serializing_if = "Option::is_none")]
109 pub vec_distance: Option<f64>,
110 #[serde(skip_serializing_if = "Option::is_none")]
113 pub fts_bm25: Option<f64>,
114}
115
116#[derive(serde::Serialize)]
118pub struct Weights {
119 pub vec: f32,
120 pub fts: f32,
121}
122
/// JSON payload printed on stdout by `hybrid-search`.
///
/// Degradation flags and diagnostics are omitted from the JSON when they are
/// `false`/`None`, so a healthy response only carries the unconditional fields.
/// Field order is the key order of the serialized JSON.
#[derive(serde::Serialize)]
pub struct HybridSearchResponse {
    /// The query string as given on the command line.
    pub query: String,
    /// Requested number of fused results.
    pub k: usize,
    /// RRF smoothing constant used for fusion.
    pub rrf_k: u32,
    /// Channel weights used for fusion.
    pub weights: Weights,
    /// Fused results, best first.
    pub results: Vec<HybridSearchItem>,
    /// Memories reached through graph traversal that are not already in
    /// `results`; empty unless `--with-graph` is set and a query embedding
    /// was available.
    pub graph_matches: Vec<RecallItem>,
    /// `true` when the FTS5 query failed, so no FTS5 hits entered fusion.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub fts_degraded: bool,
    /// Error message of the failed FTS5 query, when `fts_degraded`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fts_error: Option<String>,
    /// `true` when an FTS5 `rebuild` was executed after a malformed/corrupt
    /// index error (set even if the retried query still failed).
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub fts_auto_rebuilt: bool,
    /// `true` when no query embedding was available (`--fallback-fts-only`
    /// or an embedding failure), so no vector hits entered fusion.
    // `default` only affects deserialization; this type only derives Serialize.
    #[serde(skip_serializing_if = "std::ops::Not::not", default)]
    pub vec_degraded: bool,
    /// Why the query embedding is missing, when `vec_degraded`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vec_error: Option<String>,
    /// Human-readable warning; set whenever `vec_degraded` is true.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub warning: Option<String>,
    /// Name of the embedding backend that produced the query embedding.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub backend_invoked: Option<&'static str>,
    /// Same text as `vec_error`, populated only when `vec_degraded`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vec_degraded_reason: Option<String>,
    /// Wall-clock milliseconds from the start of `run` until the response
    /// was assembled.
    pub elapsed_ms: u64,
}
176
177#[tracing::instrument(skip_all, level = "debug", name = "hybrid_search")]
178pub fn run(
179 args: HybridSearchArgs,
180 llm_backend: crate::cli::LlmBackendChoice,
181 embedding_backend: crate::cli::EmbeddingBackendChoice,
182) -> Result<(), AppError> {
183 let start = std::time::Instant::now();
184 let _ = args.format;
185 tracing::debug!(target: "hybrid_search", query = %args.query, k = args.k, "fusing results");
186
187 if !args.with_graph {
191 if args.max_hops.is_some() {
192 return Err(AppError::Validation(
193 "--max-hops requires --with-graph to be active".to_string(),
194 ));
195 }
196 if args.min_weight.is_some() {
197 return Err(AppError::Validation(
198 "--min-weight requires --with-graph to be active".to_string(),
199 ));
200 }
201 }
202
203 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
204 let paths = AppPaths::resolve(args.db.as_deref())?;
205 crate::storage::connection::ensure_db_ready(&paths)?;
206
207 output::emit_progress_i18n(
208 "Computing query embedding...",
209 "Calculando embedding da consulta...",
210 );
211 let conn = open_ro(&paths.db)?;
212 let (embedding, vec_degraded, vec_error, backend_invoked) = if args.fallback_fts_only {
220 (
221 None,
222 true,
223 Some("fallback_fts_only requested".to_string()),
224 None,
225 )
226 } else {
227 match crate::embedder::try_embed_query_with_embedding_choice(
234 &paths.models,
235 &args.query,
236 embedding_backend,
237 llm_backend,
238 ) {
239 Ok((v, backend)) => (Some(v), false, None, Some(backend.as_str())),
240 Err(reason) => {
241 let msg = reason.to_string();
242 tracing::warn!(target: "hybrid_search", fallback_reason = %msg, reason_code = %reason.reason_code(), "live embedding failed; falling back to FTS5");
243 (None, true, Some(msg), None)
244 }
245 }
246 };
247
248 let memory_type_str = args.r#type.map(|t| t.as_str());
249
250 let vec_results: Vec<(i64, f32)> = if let Some(emb) = embedding.as_ref() {
251 memories::knn_search(
252 &conn,
253 emb,
254 std::slice::from_ref(&namespace),
255 memory_type_str,
256 args.k * 2,
257 )?
258 } else {
259 Vec::new()
260 };
261
262 let vec_rank_map: HashMap<i64, usize> = vec_results
264 .iter()
265 .enumerate()
266 .map(|(pos, (id, _))| (*id, pos + 1))
267 .collect();
268
269 let vec_distance_map: HashMap<i64, f64> = vec_results
271 .iter()
272 .map(|(id, dist)| (*id, *dist as f64))
273 .collect();
274
275 let (fts_results, fts_degraded, fts_error, fts_auto_rebuilt) = if args.weight_fts == 0.0 {
276 (vec![], false, None, false)
277 } else {
278 match memories::fts_search(&conn, &args.query, &namespace, memory_type_str, args.k * 2) {
279 Ok(r) => (r, false, None, false),
280 Err(e) => {
281 let err_msg = e.to_string();
282 let is_malformed = err_msg.contains("malformed") || err_msg.contains("corrupt");
283 if is_malformed {
284 tracing::warn!(target: "hybrid_search", "FTS5 index corrupted, attempting auto-rebuild");
285 if conn
286 .execute_batch("INSERT INTO fts_memories(fts_memories) VALUES('rebuild');")
287 .is_ok()
288 {
289 match memories::fts_search(
290 &conn,
291 &args.query,
292 &namespace,
293 memory_type_str,
294 args.k * 2,
295 ) {
296 Ok(r) => (r, false, None, true),
297 Err(e2) => {
298 tracing::error!(target: "hybrid_search", error = %e2, "FTS5 auto-rebuild failed to recover");
299 (vec![], true, Some(e2.to_string()), true)
300 }
301 }
302 } else {
303 (vec![], true, Some(err_msg), false)
304 }
305 } else {
306 tracing::warn!(target: "hybrid_search", error = %e, "FTS5 query failed, falling back to vec-only");
307 (vec![], true, Some(err_msg), false)
308 }
309 }
310 }
311 };
312
313 let fts_rank_map: HashMap<i64, usize> = fts_results
315 .iter()
316 .enumerate()
317 .map(|(pos, row)| (row.id, pos + 1))
318 .collect();
319
320 let rrf_k = args.rrf_k as f64;
321
322 let mut combined_scores: crate::hash::AHashMap<i64, f64> =
324 crate::hash::AHashMap::with_capacity_and_hasher(
325 vec_results.len() + fts_results.len(),
326 Default::default(),
327 );
328
329 for (rank, (memory_id, _)) in vec_results.iter().enumerate() {
330 let score = args.weight_vec as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
331 *combined_scores.entry(*memory_id).or_insert(0.0) += score;
332 }
333
334 for (rank, row) in fts_results.iter().enumerate() {
335 let score = args.weight_fts as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
336 *combined_scores.entry(row.id).or_insert(0.0) += score;
337 }
338
339 let mut ranked: Vec<(i64, f64)> = combined_scores.into_iter().collect();
341 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
342 ranked.truncate(args.k);
343
344 let top_ids: Vec<i64> = ranked.iter().map(|(id, _)| *id).collect();
346
347 let mut memory_data: crate::hash::AHashMap<i64, memories::MemoryRow> =
349 crate::hash::AHashMap::with_capacity_and_hasher(ranked.len(), Default::default());
350 for id in &top_ids {
351 if let Some(row) = memories::read_full(&conn, *id)? {
352 memory_data.insert(*id, row);
353 }
354 }
355
356 let max_possible = args.weight_vec as f64 * (1.0 / (rrf_k + 1.0))
357 + args.weight_fts as f64 * (1.0 / (rrf_k + 1.0));
358
359 let results: Vec<HybridSearchItem> = ranked
361 .into_iter()
362 .filter_map(|(memory_id, combined_score)| {
363 let normalized_score = if max_possible > 0.0 {
364 combined_score / max_possible
365 } else {
366 0.0
367 };
368 memory_data.remove(&memory_id).map(|row| {
369 let snippet: String = row.body.chars().take(300).collect();
370 HybridSearchItem {
371 memory_id: row.id,
372 name: row.name,
373 namespace: row.namespace,
374 memory_type: row.memory_type,
375 description: row.description,
376 body: row.body,
377 snippet,
378 combined_score,
379 score: combined_score,
380 source: "hybrid".to_string(),
381 vec_rank: vec_rank_map.get(&memory_id).copied(),
382 fts_rank: fts_rank_map.get(&memory_id).copied(),
383 rrf_score: Some(combined_score),
384 normalized_score,
385 vec_distance: vec_distance_map.get(&memory_id).copied(),
386 fts_bm25: None,
387 }
388 })
389 })
390 .collect();
391
392 let mut graph_matches: Vec<RecallItem> = Vec::with_capacity(8);
394 if let Some(emb) = args
395 .with_graph
396 .then_some(())
397 .filter(|_| !results.is_empty())
398 .and(embedding.as_ref())
399 {
400 let namespace_for_graph = namespace.clone();
401 let memory_ids: Vec<i64> = results.iter().map(|r| r.memory_id).collect();
402
403 let entity_knn = entities::knn_search(&conn, emb, &namespace_for_graph, 5)?;
404 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
405
406 let all_seed_ids: Vec<i64> = memory_ids
407 .iter()
408 .chain(entity_ids.iter())
409 .copied()
410 .collect();
411
412 if !all_seed_ids.is_empty() {
413 let graph_memory_ids = traverse_from_memories_with_hops(
414 &conn,
415 &all_seed_ids,
416 &namespace_for_graph,
417 args.min_weight.unwrap_or(0.3),
418 args.max_hops.unwrap_or(2),
419 )?;
420
421 let already_in_results: std::collections::HashSet<i64> =
422 results.iter().map(|r| r.memory_id).collect();
423
424 for (graph_mem_id, hop) in graph_memory_ids {
425 if already_in_results.contains(&graph_mem_id) {
426 continue;
427 }
428 if let Some(row) = memories::read_full(&conn, graph_mem_id)? {
429 let snippet: String = row.body.chars().take(300).collect();
430 let graph_distance = 1.0 - 1.0 / (hop as f32 + 1.0);
431 graph_matches.push(RecallItem {
432 memory_id: row.id,
433 name: row.name,
434 namespace: row.namespace,
435 memory_type: row.memory_type,
436 description: row.description,
437 snippet,
438 distance: graph_distance,
439 score: RecallItem::score_from_distance(graph_distance),
440 source: "graph".to_string(),
441 graph_depth: Some(hop),
442 });
443 }
444 }
445 }
446 }
447
448 output::emit_json(&HybridSearchResponse {
449 query: args.query,
450 k: args.k,
451 rrf_k: args.rrf_k,
452 weights: Weights {
453 vec: args.weight_vec,
454 fts: args.weight_fts,
455 },
456 results,
457 graph_matches,
458 fts_degraded,
459 fts_error,
460 fts_auto_rebuilt,
461 vec_degraded,
462 vec_error: vec_error.clone(),
463 warning: if vec_degraded {
464 Some(
465 "live query embedding unavailable; results are FTS5 BM25 only (semantic relevance reduced)"
466 .to_string(),
467 )
468 } else {
469 None
470 },
471 backend_invoked,
472 vec_degraded_reason: if vec_degraded { vec_error } else { None },
473 elapsed_ms: start.elapsed().as_millis() as u64,
474 })?;
475
476 Ok(())
477}
478
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal parser wrapper so `HybridSearchArgs` can be parsed on its own.
    #[derive(clap::Parser)]
    struct TestCli {
        #[command(flatten)]
        args: HybridSearchArgs,
    }

    #[test]
    fn graph_flags_parse_as_none_when_absent() {
        use clap::Parser;
        let cli = TestCli::try_parse_from(["hybrid-search", "q"]).expect("bare query parses");
        assert!(cli.args.max_hops.is_none());
        assert!(cli.args.min_weight.is_none());
        let cli = TestCli::try_parse_from(["hybrid-search", "q", "--max-hops", "2"])
            .expect("explicit flag parses");
        assert_eq!(cli.args.max_hops, Some(2));
    }

    /// Builds a healthy response with no results and no diagnostics set.
    fn empty_response(
        k: usize,
        rrf_k: u32,
        weight_vec: f32,
        weight_fts: f32,
    ) -> HybridSearchResponse {
        HybridSearchResponse {
            query: "test query".to_string(),
            k,
            rrf_k,
            weights: Weights {
                vec: weight_vec,
                fts: weight_fts,
            },
            results: vec![],
            graph_matches: vec![],
            fts_degraded: false,
            fts_error: None,
            fts_auto_rebuilt: false,
            vec_degraded: false,
            vec_error: None,
            warning: None,
            backend_invoked: None,
            vec_degraded_reason: None,
            elapsed_ms: 0,
        }
    }

    #[test]
    fn hybrid_search_response_empty_serializes_correct_fields() {
        let resp = empty_response(10, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"results\""), "must contain results field");
        assert!(json.contains("\"query\""), "must contain query field");
        assert!(json.contains("\"k\""), "must contain k field");
        assert!(
            json.contains("\"graph_matches\""),
            "must contain graph_matches field"
        );
        assert!(
            !json.contains("\"combined_rank\""),
            "must not contain combined_rank"
        );
        assert!(
            !json.contains("\"vec_rank_list\""),
            "must not contain vec_rank_list"
        );
        assert!(
            !json.contains("\"fts_rank_list\""),
            "must not contain fts_rank_list"
        );
    }

    #[test]
    fn hybrid_search_response_serializes_rrf_k_and_weights() {
        let resp = empty_response(5, 60, 0.7, 0.3);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"rrf_k\""), "must contain rrf_k field");
        assert!(json.contains("\"weights\""), "must contain weights field");
        assert!(json.contains("\"vec\""), "must contain weights.vec field");
        assert!(json.contains("\"fts\""), "must contain weights.fts field");
    }

    #[test]
    fn hybrid_search_response_serializes_elapsed_ms() {
        let mut resp = empty_response(5, 60, 1.0, 1.0);
        resp.elapsed_ms = 123;
        let json = serde_json::to_string(&resp).unwrap();
        assert!(
            json.contains("\"elapsed_ms\""),
            "must contain elapsed_ms field"
        );
        // Pin the value to its key; a bare "123" could match anywhere.
        assert!(
            json.contains("\"elapsed_ms\":123"),
            "deve serializar valor de elapsed_ms"
        );
    }

    #[test]
    fn weights_struct_serializes_correctly() {
        let w = Weights { vec: 0.6, fts: 0.4 };
        let json = serde_json::to_string(&w).unwrap();
        assert!(json.contains("\"vec\""));
        assert!(json.contains("\"fts\""));
    }

    #[test]
    fn hybrid_search_item_omits_fts_rank_when_none() {
        let item = HybridSearchItem {
            memory_id: 1,
            name: "mem".to_string(),
            namespace: "default".to_string(),
            memory_type: "user".to_string(),
            description: "desc".to_string(),
            body: "content".to_string(),
            snippet: "content".to_string(),
            combined_score: 0.0328,
            score: 0.0328,
            source: "hybrid".to_string(),
            vec_rank: Some(1),
            fts_rank: None,
            rrf_score: Some(0.0328),
            normalized_score: 1.0,
            vec_distance: Some(0.12),
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            json.contains("\"vec_rank\""),
            "must contain vec_rank when Some"
        );
        assert!(
            !json.contains("\"fts_rank\""),
            "must not contain fts_rank when None"
        );
    }

    #[test]
    fn hybrid_search_item_omits_vec_rank_when_none() {
        let item = HybridSearchItem {
            memory_id: 2,
            name: "mem2".to_string(),
            namespace: "default".to_string(),
            memory_type: "fact".to_string(),
            description: "desc2".to_string(),
            body: "corpo2".to_string(),
            snippet: "corpo2".to_string(),
            combined_score: 0.016,
            score: 0.016,
            source: "hybrid".to_string(),
            vec_rank: None,
            fts_rank: Some(2),
            rrf_score: Some(0.016),
            normalized_score: 0.5,
            vec_distance: None,
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            !json.contains("\"vec_rank\""),
            "must not contain vec_rank when None"
        );
        assert!(
            json.contains("\"fts_rank\""),
            "must contain fts_rank when Some"
        );
    }

    #[test]
    fn hybrid_search_item_serializes_both_ranks_when_some() {
        let item = HybridSearchItem {
            memory_id: 3,
            name: "mem3".to_string(),
            namespace: "ns".to_string(),
            memory_type: "entity".to_string(),
            description: "desc3".to_string(),
            body: "corpo3".to_string(),
            snippet: "corpo3".to_string(),
            combined_score: 0.05,
            score: 0.05,
            source: "hybrid".to_string(),
            vec_rank: Some(3),
            fts_rank: Some(1),
            rrf_score: Some(0.05),
            normalized_score: 0.8,
            vec_distance: Some(0.25),
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(json.contains("\"vec_rank\""), "must contain vec_rank");
        assert!(json.contains("\"fts_rank\""), "must contain fts_rank");
        assert!(json.contains("\"type\""), "deve serializar type renomeado");
        assert!(!json.contains("memory_type"), "must not expose memory_type");
    }

    #[test]
    fn hybrid_search_response_serializes_k_correctly() {
        let resp = empty_response(5, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"k\":5"), "deve serializar k=5");
    }

    #[test]
    fn hybrid_search_response_with_graph_matches() {
        use crate::output::RecallItem;
        let resp = HybridSearchResponse {
            query: "test".to_string(),
            k: 5,
            rrf_k: 60,
            weights: Weights { vec: 1.0, fts: 1.0 },
            results: vec![],
            graph_matches: vec![RecallItem {
                memory_id: 1,
                name: "graph-hit".to_string(),
                namespace: "global".to_string(),
                memory_type: "document".to_string(),
                description: "found via graph".to_string(),
                snippet: "graph content".to_string(),
                distance: 0.1,
                score: 0.9,
                source: "graph".to_string(),
                graph_depth: Some(1),
            }],
            fts_degraded: false,
            fts_error: None,
            fts_auto_rebuilt: false,
            vec_degraded: false,
            vec_error: None,
            warning: None,
            backend_invoked: None,
            vec_degraded_reason: None,
            elapsed_ms: 42,
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["graph_matches"].as_array().unwrap().len(), 1);
        assert_eq!(json["graph_matches"][0]["source"], "graph");
        assert_eq!(json["graph_matches"][0]["graph_depth"], 1);
    }

    #[test]
    fn fts_degraded_omitted_on_success_present_on_failure() {
        let ok_resp = empty_response(5, 60, 1.0, 1.0);
        let ok_json = serde_json::to_string(&ok_resp).unwrap();
        assert!(
            !ok_json.contains("\"fts_degraded\""),
            "fts_degraded must be absent when false"
        );
        assert!(
            !ok_json.contains("\"fts_error\""),
            "fts_error must be absent when None"
        );

        let mut degraded_resp = empty_response(5, 60, 1.0, 1.0);
        degraded_resp.fts_degraded = true;
        degraded_resp.fts_error = Some("FTS5 table corrupted".to_string());
        let degraded_json = serde_json::to_string(&degraded_resp).unwrap();
        assert!(
            degraded_json.contains("\"fts_degraded\":true"),
            "fts_degraded must be present and true when degraded"
        );
        assert!(
            degraded_json.contains("\"fts_error\""),
            "fts_error must be present when Some"
        );
        assert!(
            degraded_json.contains("FTS5 table corrupted"),
            "fts_error must contain the error message"
        );
    }

    #[test]
    fn vec_degradation_fields_omitted_when_healthy_present_when_degraded() {
        let ok_json = serde_json::to_value(empty_response(5, 60, 1.0, 1.0)).unwrap();
        for key in [
            "vec_degraded",
            "vec_error",
            "warning",
            "backend_invoked",
            "vec_degraded_reason",
            "fts_auto_rebuilt",
        ] {
            assert!(
                ok_json.get(key).is_none(),
                "{key} must be absent on a healthy response"
            );
        }

        let mut degraded = empty_response(5, 60, 1.0, 1.0);
        degraded.vec_degraded = true;
        degraded.vec_error = Some("model missing".to_string());
        degraded.vec_degraded_reason = Some("model missing".to_string());
        let degraded_json = serde_json::to_value(&degraded).unwrap();
        assert_eq!(degraded_json["vec_degraded"], true);
        assert_eq!(degraded_json["vec_error"], "model missing");
        assert_eq!(degraded_json["vec_degraded_reason"], "model missing");
    }
}