1use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::graph::traverse_from_memories_with_hops;
6use crate::output::{self, JsonOutputFormat, RecallItem};
7use crate::paths::AppPaths;
8use crate::storage::connection::open_ro;
9use crate::storage::entities;
10use crate::storage::memories;
11
12use std::collections::HashMap;
13
// Command-line arguments for the `hybrid-search` subcommand.
//
// NOTE: plain `//` comments are used throughout this struct on purpose: clap's
// derive macro turns `///` doc comments into help text, so adding them here
// would change the CLI output.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n \
    # Basic hybrid search combining FTS5 + vector via RRF\n \
    sqlite-graphrag hybrid-search \"postgres migration deadlock\" --k 10\n\n \
    # Tune RRF weights to favor keyword matches over semantic similarity\n \
    sqlite-graphrag hybrid-search \"jwt auth\" --weight-fts 1.5 --weight-vec 0.5 --k 5\n\n \
    # Add graph traversal matches (entities connected to top results)\n \
    sqlite-graphrag hybrid-search \"frontend architecture\" --with-graph --k 10\n\n \
    # Graph traversal with custom depth and minimum edge weight\n \
    sqlite-graphrag hybrid-search \"auth design\" --with-graph --max-hops 3 --min-weight 0.5 --k 10\n\n \
NOTES:\n \
    --with-graph enables entity graph traversal seeded by the top RRF results.\n \
    Graph matches appear in the `graph_matches` array (separate from `results`).\n \
    Without --with-graph, `graph_matches` is always empty.")]
pub struct HybridSearchArgs {
    // Free-text query. It is embedded for the vector search and also passed
    // verbatim to the FTS search. `allow_hyphen_values` lets the query start
    // with a `-`.
    #[arg(
        allow_hyphen_values = true,
        help = "Hybrid search query (vector KNN + FTS5 BM25 fused via RRF)"
    )]
    pub query: String,
    // Number of fused results to return (`-k`, `--k`, `--limit`, `--top-k`).
    // Parsed by `crate::parsers::parse_k_range`; the accepted range is defined
    // there. Each underlying search is asked for `2 * k` candidates.
    #[arg(short = 'k', long, aliases = ["limit", "top-k"], default_value = "10", value_parser = crate::parsers::parse_k_range)]
    pub k: usize,
    // Constant added to the 1-based rank in the RRF denominator:
    // contribution = weight / (rrf_k + rank).
    #[arg(long, default_value = "60")]
    pub rrf_k: u32,
    // Multiplier applied to the vector-search RRF contribution.
    #[arg(long, default_value = "1.0")]
    pub weight_vec: f32,
    // Multiplier applied to the FTS RRF contribution. A value of exactly 0
    // skips the FTS query altogether.
    #[arg(long, default_value = "1.0")]
    pub weight_fts: f32,
    // Optional memory-type filter forwarded to both searches.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // Optional namespace; resolved through `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // Enables the graph traversal that fills `graph_matches`.
    #[arg(long)]
    pub with_graph: bool,
    // Maximum hop count handed to the graph traversal. A non-default value is
    // rejected unless `--with-graph` is set.
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    // Minimum weight handed to the graph traversal. A non-default value is
    // rejected unless `--with-graph` is set.
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    // Accepted output format. `run` reads this field but does not act on it.
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // Database path override; also read from `SQLITE_GRAPHRAG_DB_PATH`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // Hidden compatibility flag with no effect.
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    // Shared daemon options; `autostart_daemon` is consulted when embedding
    // the query.
    #[command(flatten)]
    pub daemon: crate::cli::DaemonOpts,
}
75
/// One fused search hit, serialized as an element of the `results` array.
///
/// Field descriptions below reflect how `run` in this file populates the
/// struct.
#[derive(serde::Serialize)]
pub struct HybridSearchItem {
    /// Identifier of the memory row.
    pub memory_id: i64,
    /// Name of the memory.
    pub name: String,
    /// Namespace the memory is stored in.
    pub namespace: String,
    /// Memory type; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub memory_type: String,
    /// Description stored with the memory.
    pub description: String,
    /// Full body of the memory.
    pub body: String,
    /// First 300 characters of `body`.
    pub snippet: String,
    /// Weighted RRF score: the sum of the vector and FTS contributions.
    pub combined_score: f64,
    /// Same value as `combined_score`.
    pub score: f64,
    /// Origin label; `run` always sets this to `"hybrid"`.
    pub source: String,
    /// 1-based position in the vector KNN result list; omitted when the
    /// memory was not returned by the vector search.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vec_rank: Option<usize>,
    /// 1-based position in the FTS result list; omitted when the memory was
    /// not returned by the FTS search.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fts_rank: Option<usize>,
    /// RRF score; `run` sets it to `Some(combined_score)`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rrf_score: Option<f64>,
    /// `combined_score` divided by the best achievable score (rank 1 in both
    /// lists with the configured weights); 0.0 when that maximum is not
    /// positive.
    pub normalized_score: f64,
    /// Distance reported by the vector KNN search, when the memory appeared
    /// there.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vec_distance: Option<f64>,
    /// BM25 score slot; `run` currently always leaves this as `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fts_bm25: Option<f64>,
}
111
/// RRF weights echoed back in the response so callers can see which
/// multipliers produced the scores.
#[derive(serde::Serialize)]
pub struct Weights {
    /// Multiplier used for the vector-search contribution (`--weight-vec`).
    pub vec: f32,
    /// Multiplier used for the FTS contribution (`--weight-fts`).
    pub fts: f32,
}
118
/// Top-level JSON document emitted by the `hybrid-search` command.
#[derive(serde::Serialize)]
pub struct HybridSearchResponse {
    /// The query string as supplied by the caller.
    pub query: String,
    /// Requested number of fused results.
    pub k: usize,
    /// RRF constant that was used.
    pub rrf_k: u32,
    /// RRF weights that were used.
    pub weights: Weights,
    /// Fused hits, best first, at most `k` entries.
    pub results: Vec<HybridSearchItem>,
    /// Memories reached through graph traversal that are not already in
    /// `results`; stays empty unless graph traversal was requested.
    pub graph_matches: Vec<RecallItem>,
    /// `true` when the FTS query failed and the results are vector-only.
    /// Omitted from the JSON when `false`.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub fts_degraded: bool,
    /// Message of the FTS error behind `fts_degraded`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fts_error: Option<String>,
    /// `true` when the FTS index rebuild statement ran successfully during
    /// this call (even if the retried query then failed). Omitted from the
    /// JSON when `false`.
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    pub fts_auto_rebuilt: bool,
    /// Wall-clock time spent in `run`, in milliseconds.
    pub elapsed_ms: u64,
}
147
148#[tracing::instrument(skip_all, level = "debug", name = "hybrid_search")]
149pub fn run(args: HybridSearchArgs) -> Result<(), AppError> {
150 let start = std::time::Instant::now();
151 let _ = args.format;
152 tracing::debug!(target: "hybrid_search", query = %args.query, k = args.k, "fusing results");
153
154 if !args.with_graph {
156 if args.max_hops != 2 {
157 return Err(AppError::Validation(
158 "--max-hops requires --with-graph to be active".to_string(),
159 ));
160 }
161 if (args.min_weight - 0.3).abs() > f64::EPSILON {
162 return Err(AppError::Validation(
163 "--min-weight requires --with-graph to be active".to_string(),
164 ));
165 }
166 }
167
168 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
169 let paths = AppPaths::resolve(args.db.as_deref())?;
170 crate::storage::connection::ensure_db_ready(&paths)?;
171
172 output::emit_progress_i18n(
173 "Computing query embedding...",
174 "Calculando embedding da consulta...",
175 );
176 let embedding = crate::daemon::embed_query_or_local(
177 &paths.models,
178 &args.query,
179 args.daemon.autostart_daemon,
180 )?;
181
182 let conn = open_ro(&paths.db)?;
183
184 let memory_type_str = args.r#type.map(|t| t.as_str());
185
186 let vec_results = memories::knn_search(
187 &conn,
188 &embedding,
189 &[namespace.clone()],
190 memory_type_str,
191 args.k * 2,
192 )?;
193
194 let vec_rank_map: HashMap<i64, usize> = vec_results
196 .iter()
197 .enumerate()
198 .map(|(pos, (id, _))| (*id, pos + 1))
199 .collect();
200
201 let vec_distance_map: HashMap<i64, f64> = vec_results
203 .iter()
204 .map(|(id, dist)| (*id, *dist as f64))
205 .collect();
206
207 let (fts_results, fts_degraded, fts_error, fts_auto_rebuilt) = if args.weight_fts == 0.0 {
208 (vec![], false, None, false)
209 } else {
210 match memories::fts_search(&conn, &args.query, &namespace, memory_type_str, args.k * 2) {
211 Ok(r) => (r, false, None, false),
212 Err(e) => {
213 let err_msg = e.to_string();
214 let is_malformed = err_msg.contains("malformed") || err_msg.contains("corrupt");
215 if is_malformed {
216 tracing::warn!(target: "hybrid_search", "FTS5 index corrupted, attempting auto-rebuild");
217 if conn
218 .execute_batch("INSERT INTO fts_memories(fts_memories) VALUES('rebuild');")
219 .is_ok()
220 {
221 match memories::fts_search(
222 &conn,
223 &args.query,
224 &namespace,
225 memory_type_str,
226 args.k * 2,
227 ) {
228 Ok(r) => (r, false, None, true),
229 Err(e2) => {
230 tracing::error!(target: "hybrid_search", error = %e2, "FTS5 auto-rebuild failed to recover");
231 (vec![], true, Some(e2.to_string()), true)
232 }
233 }
234 } else {
235 (vec![], true, Some(err_msg), false)
236 }
237 } else {
238 tracing::warn!(target: "hybrid_search", error = %e, "FTS5 query failed, falling back to vec-only");
239 (vec![], true, Some(err_msg), false)
240 }
241 }
242 }
243 };
244
245 let fts_rank_map: HashMap<i64, usize> = fts_results
247 .iter()
248 .enumerate()
249 .map(|(pos, row)| (row.id, pos + 1))
250 .collect();
251
252 let rrf_k = args.rrf_k as f64;
253
254 let mut combined_scores: crate::hash::AHashMap<i64, f64> =
256 crate::hash::AHashMap::with_capacity_and_hasher(
257 vec_results.len() + fts_results.len(),
258 Default::default(),
259 );
260
261 for (rank, (memory_id, _)) in vec_results.iter().enumerate() {
262 let score = args.weight_vec as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
263 *combined_scores.entry(*memory_id).or_insert(0.0) += score;
264 }
265
266 for (rank, row) in fts_results.iter().enumerate() {
267 let score = args.weight_fts as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
268 *combined_scores.entry(row.id).or_insert(0.0) += score;
269 }
270
271 let mut ranked: Vec<(i64, f64)> = combined_scores.into_iter().collect();
273 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
274 ranked.truncate(args.k);
275
276 let top_ids: Vec<i64> = ranked.iter().map(|(id, _)| *id).collect();
278
279 let mut memory_data: crate::hash::AHashMap<i64, memories::MemoryRow> =
281 crate::hash::AHashMap::with_capacity_and_hasher(ranked.len(), Default::default());
282 for id in &top_ids {
283 if let Some(row) = memories::read_full(&conn, *id)? {
284 memory_data.insert(*id, row);
285 }
286 }
287
288 let max_possible = args.weight_vec as f64 * (1.0 / (rrf_k + 1.0))
289 + args.weight_fts as f64 * (1.0 / (rrf_k + 1.0));
290
291 let results: Vec<HybridSearchItem> = ranked
293 .into_iter()
294 .filter_map(|(memory_id, combined_score)| {
295 let normalized_score = if max_possible > 0.0 {
296 combined_score / max_possible
297 } else {
298 0.0
299 };
300 memory_data.remove(&memory_id).map(|row| {
301 let snippet: String = row.body.chars().take(300).collect();
302 HybridSearchItem {
303 memory_id: row.id,
304 name: row.name,
305 namespace: row.namespace,
306 memory_type: row.memory_type,
307 description: row.description,
308 body: row.body,
309 snippet,
310 combined_score,
311 score: combined_score,
312 source: "hybrid".to_string(),
313 vec_rank: vec_rank_map.get(&memory_id).copied(),
314 fts_rank: fts_rank_map.get(&memory_id).copied(),
315 rrf_score: Some(combined_score),
316 normalized_score,
317 vec_distance: vec_distance_map.get(&memory_id).copied(),
318 fts_bm25: None,
319 }
320 })
321 })
322 .collect();
323
324 let mut graph_matches: Vec<RecallItem> = Vec::with_capacity(8);
326 if args.with_graph && !results.is_empty() {
327 let namespace_for_graph = namespace.clone();
328 let memory_ids: Vec<i64> = results.iter().map(|r| r.memory_id).collect();
329
330 let entity_knn = entities::knn_search(&conn, &embedding, &namespace_for_graph, 5)?;
331 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
332
333 let all_seed_ids: Vec<i64> = memory_ids
334 .iter()
335 .chain(entity_ids.iter())
336 .copied()
337 .collect();
338
339 if !all_seed_ids.is_empty() {
340 let graph_memory_ids = traverse_from_memories_with_hops(
341 &conn,
342 &all_seed_ids,
343 &namespace_for_graph,
344 args.min_weight,
345 args.max_hops,
346 )?;
347
348 let already_in_results: std::collections::HashSet<i64> =
349 results.iter().map(|r| r.memory_id).collect();
350
351 for (graph_mem_id, hop) in graph_memory_ids {
352 if already_in_results.contains(&graph_mem_id) {
353 continue;
354 }
355 if let Some(row) = memories::read_full(&conn, graph_mem_id)? {
356 let snippet: String = row.body.chars().take(300).collect();
357 let graph_distance = 1.0 - 1.0 / (hop as f32 + 1.0);
358 graph_matches.push(RecallItem {
359 memory_id: row.id,
360 name: row.name,
361 namespace: row.namespace,
362 memory_type: row.memory_type,
363 description: row.description,
364 snippet,
365 distance: graph_distance,
366 score: RecallItem::score_from_distance(graph_distance),
367 source: "graph".to_string(),
368 graph_depth: Some(hop),
369 });
370 }
371 }
372 }
373 }
374
375 output::emit_json(&HybridSearchResponse {
376 query: args.query,
377 k: args.k,
378 rrf_k: args.rrf_k,
379 weights: Weights {
380 vec: args.weight_vec,
381 fts: args.weight_fts,
382 },
383 results,
384 graph_matches,
385 fts_degraded,
386 fts_error,
387 fts_auto_rebuilt,
388 elapsed_ms: start.elapsed().as_millis() as u64,
389 })?;
390
391 Ok(())
392}
393
/// Serialization tests for the JSON contract of the hybrid-search response types.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a response with no results and no FTS degradation, using the
    /// given `k`, `rrf_k` and weights.
    fn empty_response(
        k: usize,
        rrf_k: u32,
        weight_vec: f32,
        weight_fts: f32,
    ) -> HybridSearchResponse {
        HybridSearchResponse {
            query: "test query".to_string(),
            k,
            rrf_k,
            weights: Weights {
                vec: weight_vec,
                fts: weight_fts,
            },
            results: vec![],
            graph_matches: vec![],
            fts_degraded: false,
            fts_error: None,
            fts_auto_rebuilt: false,
            elapsed_ms: 0,
        }
    }

    /// An empty response still carries the core keys and none of the
    /// `combined_rank` / `*_rank_list` keys.
    #[test]
    fn hybrid_search_response_empty_serializes_correct_fields() {
        let resp = empty_response(10, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"results\""), "must contain results field");
        assert!(json.contains("\"query\""), "must contain query field");
        assert!(json.contains("\"k\""), "must contain k field");
        assert!(
            json.contains("\"graph_matches\""),
            "must contain graph_matches field"
        );
        assert!(
            !json.contains("\"combined_rank\""),
            "must not contain combined_rank"
        );
        assert!(
            !json.contains("\"vec_rank_list\""),
            "must not contain vec_rank_list"
        );
        assert!(
            !json.contains("\"fts_rank_list\""),
            "must not contain fts_rank_list"
        );
    }

    /// `rrf_k` and the nested `weights` object are always serialized.
    #[test]
    fn hybrid_search_response_serializes_rrf_k_and_weights() {
        let resp = empty_response(5, 60, 0.7, 0.3);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"rrf_k\""), "must contain rrf_k field");
        assert!(json.contains("\"weights\""), "must contain weights field");
        assert!(json.contains("\"vec\""), "must contain weights.vec field");
        assert!(json.contains("\"fts\""), "must contain weights.fts field");
    }

    /// `elapsed_ms` is serialized together with its value.
    #[test]
    fn hybrid_search_response_serializes_elapsed_ms() {
        let mut resp = empty_response(5, 60, 1.0, 1.0);
        resp.elapsed_ms = 123;
        let json = serde_json::to_string(&resp).unwrap();
        assert!(
            json.contains("\"elapsed_ms\""),
            "must contain elapsed_ms field"
        );
        assert!(json.contains("123"), "deve serializar valor de elapsed_ms");
    }

    /// `Weights` exposes both of its keys.
    #[test]
    fn weights_struct_serializes_correctly() {
        let w = Weights { vec: 0.6, fts: 0.4 };
        let json = serde_json::to_string(&w).unwrap();
        assert!(json.contains("\"vec\""));
        assert!(json.contains("\"fts\""));
    }

    /// A vector-only hit omits `fts_rank` from the JSON.
    #[test]
    fn hybrid_search_item_omits_fts_rank_when_none() {
        let item = HybridSearchItem {
            memory_id: 1,
            name: "mem".to_string(),
            namespace: "default".to_string(),
            memory_type: "user".to_string(),
            description: "desc".to_string(),
            body: "content".to_string(),
            snippet: "content".to_string(),
            combined_score: 0.0328,
            score: 0.0328,
            source: "hybrid".to_string(),
            vec_rank: Some(1),
            fts_rank: None,
            rrf_score: Some(0.0328),
            normalized_score: 1.0,
            vec_distance: Some(0.12),
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            json.contains("\"vec_rank\""),
            "must contain vec_rank when Some"
        );
        assert!(
            !json.contains("\"fts_rank\""),
            "must not contain fts_rank when None"
        );
    }

    /// An FTS-only hit omits `vec_rank` from the JSON.
    #[test]
    fn hybrid_search_item_omits_vec_rank_when_none() {
        let item = HybridSearchItem {
            memory_id: 2,
            name: "mem2".to_string(),
            namespace: "default".to_string(),
            memory_type: "fact".to_string(),
            description: "desc2".to_string(),
            body: "corpo2".to_string(),
            snippet: "corpo2".to_string(),
            combined_score: 0.016,
            score: 0.016,
            source: "hybrid".to_string(),
            vec_rank: None,
            fts_rank: Some(2),
            rrf_score: Some(0.016),
            normalized_score: 0.5,
            vec_distance: None,
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            !json.contains("\"vec_rank\""),
            "must not contain vec_rank when None"
        );
        assert!(
            json.contains("\"fts_rank\""),
            "must contain fts_rank when Some"
        );
    }

    /// A hit found by both searches serializes both ranks, and the type is
    /// exposed as `type` rather than `memory_type`.
    #[test]
    fn hybrid_search_item_serializes_both_ranks_when_some() {
        let item = HybridSearchItem {
            memory_id: 3,
            name: "mem3".to_string(),
            namespace: "ns".to_string(),
            memory_type: "entity".to_string(),
            description: "desc3".to_string(),
            body: "corpo3".to_string(),
            snippet: "corpo3".to_string(),
            combined_score: 0.05,
            score: 0.05,
            source: "hybrid".to_string(),
            vec_rank: Some(3),
            fts_rank: Some(1),
            rrf_score: Some(0.05),
            normalized_score: 0.8,
            vec_distance: Some(0.25),
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(json.contains("\"vec_rank\""), "must contain vec_rank");
        assert!(json.contains("\"fts_rank\""), "must contain fts_rank");
        assert!(json.contains("\"type\""), "deve serializar type renomeado");
        assert!(!json.contains("memory_type"), "must not expose memory_type");
    }

    /// `k` is serialized with its numeric value.
    #[test]
    fn hybrid_search_response_serializes_k_correctly() {
        let resp = empty_response(5, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"k\":5"), "deve serializar k=5");
    }

    /// Graph matches are serialized in their own array with source and depth.
    #[test]
    fn hybrid_search_response_with_graph_matches() {
        use crate::output::RecallItem;
        let resp = HybridSearchResponse {
            query: "test".to_string(),
            k: 5,
            rrf_k: 60,
            weights: Weights { vec: 1.0, fts: 1.0 },
            results: vec![],
            graph_matches: vec![RecallItem {
                memory_id: 1,
                name: "graph-hit".to_string(),
                namespace: "global".to_string(),
                memory_type: "document".to_string(),
                description: "found via graph".to_string(),
                snippet: "graph content".to_string(),
                distance: 0.1,
                score: 0.9,
                source: "graph".to_string(),
                graph_depth: Some(1),
            }],
            fts_degraded: false,
            fts_error: None,
            fts_auto_rebuilt: false,
            elapsed_ms: 42,
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["graph_matches"].as_array().unwrap().len(), 1);
        assert_eq!(json["graph_matches"][0]["source"], "graph");
        assert_eq!(json["graph_matches"][0]["graph_depth"], 1);
    }

    /// The degradation fields are absent on the happy path and present, with
    /// the error message, when FTS failed.
    #[test]
    fn fts_degraded_omitted_on_success_present_on_failure() {
        // Happy path: neither key is emitted.
        let ok_resp = empty_response(5, 60, 1.0, 1.0);
        let ok_json = serde_json::to_string(&ok_resp).unwrap();
        assert!(
            !ok_json.contains("\"fts_degraded\""),
            "fts_degraded must be absent when false"
        );
        assert!(
            !ok_json.contains("\"fts_error\""),
            "fts_error must be absent when None"
        );

        // Degraded path: both keys are emitted.
        let mut degraded_resp = empty_response(5, 60, 1.0, 1.0);
        degraded_resp.fts_degraded = true;
        degraded_resp.fts_error = Some("FTS5 table corrupted".to_string());
        let degraded_json = serde_json::to_string(&degraded_resp).unwrap();
        assert!(
            degraded_json.contains("\"fts_degraded\":true"),
            "fts_degraded must be present and true when degraded"
        );
        assert!(
            degraded_json.contains("\"fts_error\""),
            "fts_error must be present when Some"
        );
        assert!(
            degraded_json.contains("FTS5 table corrupted"),
            "fts_error must contain the error message"
        );
    }
}