1use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::graph::traverse_from_memories_with_hops;
6use crate::output::{self, JsonOutputFormat, RecallItem};
7use crate::paths::AppPaths;
8use crate::storage::connection::open_ro;
9use crate::storage::entities;
10use crate::storage::memories;
11
12use std::collections::HashMap;
13
14#[derive(clap::Args)]
21#[command(after_long_help = "EXAMPLES:\n \
22 # Basic hybrid search combining FTS5 + vector via RRF\n \
23 sqlite-graphrag hybrid-search \"postgres migration deadlock\" --k 10\n\n \
24 # Tune RRF weights to favor keyword matches over semantic similarity\n \
25 sqlite-graphrag hybrid-search \"jwt auth\" --weight-fts 1.5 --weight-vec 0.5 --k 5\n\n \
26 # Add graph traversal matches (entities connected to top results)\n \
27 sqlite-graphrag hybrid-search \"frontend architecture\" --with-graph --k 10\n\n \
28 # Graph traversal with custom depth and minimum edge weight\n \
29 sqlite-graphrag hybrid-search \"auth design\" --with-graph --max-hops 3 --min-weight 0.5 --k 10\n\n \
30NOTES:\n \
31 --with-graph enables entity graph traversal seeded by the top RRF results.\n \
32 Graph matches appear in the `graph_matches` array (separate from `results`).\n \
33 Without --with-graph, `graph_matches` is always empty.")]
34pub struct HybridSearchArgs {
35 #[arg(
36 allow_hyphen_values = true,
37 help = "Hybrid search query (vector KNN + FTS5 BM25 fused via RRF)"
38 )]
39 pub query: String,
40 #[arg(short = 'k', long, aliases = ["limit", "top-k"], default_value = "10", value_parser = crate::parsers::parse_k_range)]
45 pub k: usize,
46 #[arg(long, default_value = "60")]
47 pub rrf_k: u32,
48 #[arg(long, default_value = "1.0")]
49 pub weight_vec: f32,
50 #[arg(long, default_value = "1.0")]
51 pub weight_fts: f32,
52 #[arg(long, value_enum)]
56 pub r#type: Option<MemoryType>,
57 #[arg(long)]
58 pub namespace: Option<String>,
59 #[arg(long)]
60 pub with_graph: bool,
61 #[arg(long, default_value = "2")]
62 pub max_hops: u32,
63 #[arg(long, default_value = "0.3")]
64 pub min_weight: f64,
65 #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
66 pub format: JsonOutputFormat,
67 #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
68 pub db: Option<String>,
69 #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
71 pub json: bool,
72}
73
74#[derive(serde::Serialize)]
75pub struct HybridSearchItem {
76 pub memory_id: i64,
77 pub name: String,
78 pub namespace: String,
79 #[serde(rename = "type")]
80 pub memory_type: String,
81 pub description: String,
82 pub body: String,
83 pub snippet: String,
84 pub combined_score: f64,
85 pub score: f64,
87 pub source: String,
89 #[serde(skip_serializing_if = "Option::is_none")]
90 pub vec_rank: Option<usize>,
91 #[serde(skip_serializing_if = "Option::is_none")]
92 pub fts_rank: Option<usize>,
93 #[serde(skip_serializing_if = "Option::is_none")]
95 pub rrf_score: Option<f64>,
96 pub normalized_score: f64,
98 #[serde(skip_serializing_if = "Option::is_none")]
103 pub vec_distance: Option<f64>,
104 #[serde(skip_serializing_if = "Option::is_none")]
107 pub fts_bm25: Option<f64>,
108}
109
110#[derive(serde::Serialize)]
112pub struct Weights {
113 pub vec: f32,
114 pub fts: f32,
115}
116
117#[derive(serde::Serialize)]
118pub struct HybridSearchResponse {
119 pub query: String,
120 pub k: usize,
121 pub rrf_k: u32,
123 pub weights: Weights,
125 pub results: Vec<HybridSearchItem>,
126 pub graph_matches: Vec<RecallItem>,
127 #[serde(skip_serializing_if = "std::ops::Not::not")]
131 pub fts_degraded: bool,
132 #[serde(skip_serializing_if = "Option::is_none")]
136 pub fts_error: Option<String>,
137 #[serde(skip_serializing_if = "std::ops::Not::not")]
141 pub fts_auto_rebuilt: bool,
142 pub elapsed_ms: u64,
144}
145
146#[tracing::instrument(skip_all, level = "debug", name = "hybrid_search")]
147pub fn run(args: HybridSearchArgs) -> Result<(), AppError> {
148 let start = std::time::Instant::now();
149 let _ = args.format;
150 tracing::debug!(target: "hybrid_search", query = %args.query, k = args.k, "fusing results");
151
152 if !args.with_graph {
154 if args.max_hops != 2 {
155 return Err(AppError::Validation(
156 "--max-hops requires --with-graph to be active".to_string(),
157 ));
158 }
159 if (args.min_weight - 0.3).abs() > f64::EPSILON {
160 return Err(AppError::Validation(
161 "--min-weight requires --with-graph to be active".to_string(),
162 ));
163 }
164 }
165
166 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
167 let paths = AppPaths::resolve(args.db.as_deref())?;
168 crate::storage::connection::ensure_db_ready(&paths)?;
169
170 output::emit_progress_i18n(
171 "Computing query embedding...",
172 "Calculando embedding da consulta...",
173 );
174 let embedding = crate::embedder::embed_query_local(&paths.models, &args.query)?;
175
176 let conn = open_ro(&paths.db)?;
177
178 let memory_type_str = args.r#type.map(|t| t.as_str());
179
180 let vec_results = memories::knn_search(
181 &conn,
182 &embedding,
183 &[namespace.clone()],
184 memory_type_str,
185 args.k * 2,
186 )?;
187
188 let vec_rank_map: HashMap<i64, usize> = vec_results
190 .iter()
191 .enumerate()
192 .map(|(pos, (id, _))| (*id, pos + 1))
193 .collect();
194
195 let vec_distance_map: HashMap<i64, f64> = vec_results
197 .iter()
198 .map(|(id, dist)| (*id, *dist as f64))
199 .collect();
200
201 let (fts_results, fts_degraded, fts_error, fts_auto_rebuilt) = if args.weight_fts == 0.0 {
202 (vec![], false, None, false)
203 } else {
204 match memories::fts_search(&conn, &args.query, &namespace, memory_type_str, args.k * 2) {
205 Ok(r) => (r, false, None, false),
206 Err(e) => {
207 let err_msg = e.to_string();
208 let is_malformed = err_msg.contains("malformed") || err_msg.contains("corrupt");
209 if is_malformed {
210 tracing::warn!(target: "hybrid_search", "FTS5 index corrupted, attempting auto-rebuild");
211 if conn
212 .execute_batch("INSERT INTO fts_memories(fts_memories) VALUES('rebuild');")
213 .is_ok()
214 {
215 match memories::fts_search(
216 &conn,
217 &args.query,
218 &namespace,
219 memory_type_str,
220 args.k * 2,
221 ) {
222 Ok(r) => (r, false, None, true),
223 Err(e2) => {
224 tracing::error!(target: "hybrid_search", error = %e2, "FTS5 auto-rebuild failed to recover");
225 (vec![], true, Some(e2.to_string()), true)
226 }
227 }
228 } else {
229 (vec![], true, Some(err_msg), false)
230 }
231 } else {
232 tracing::warn!(target: "hybrid_search", error = %e, "FTS5 query failed, falling back to vec-only");
233 (vec![], true, Some(err_msg), false)
234 }
235 }
236 }
237 };
238
239 let fts_rank_map: HashMap<i64, usize> = fts_results
241 .iter()
242 .enumerate()
243 .map(|(pos, row)| (row.id, pos + 1))
244 .collect();
245
246 let rrf_k = args.rrf_k as f64;
247
248 let mut combined_scores: crate::hash::AHashMap<i64, f64> =
250 crate::hash::AHashMap::with_capacity_and_hasher(
251 vec_results.len() + fts_results.len(),
252 Default::default(),
253 );
254
255 for (rank, (memory_id, _)) in vec_results.iter().enumerate() {
256 let score = args.weight_vec as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
257 *combined_scores.entry(*memory_id).or_insert(0.0) += score;
258 }
259
260 for (rank, row) in fts_results.iter().enumerate() {
261 let score = args.weight_fts as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
262 *combined_scores.entry(row.id).or_insert(0.0) += score;
263 }
264
265 let mut ranked: Vec<(i64, f64)> = combined_scores.into_iter().collect();
267 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
268 ranked.truncate(args.k);
269
270 let top_ids: Vec<i64> = ranked.iter().map(|(id, _)| *id).collect();
272
273 let mut memory_data: crate::hash::AHashMap<i64, memories::MemoryRow> =
275 crate::hash::AHashMap::with_capacity_and_hasher(ranked.len(), Default::default());
276 for id in &top_ids {
277 if let Some(row) = memories::read_full(&conn, *id)? {
278 memory_data.insert(*id, row);
279 }
280 }
281
282 let max_possible = args.weight_vec as f64 * (1.0 / (rrf_k + 1.0))
283 + args.weight_fts as f64 * (1.0 / (rrf_k + 1.0));
284
285 let results: Vec<HybridSearchItem> = ranked
287 .into_iter()
288 .filter_map(|(memory_id, combined_score)| {
289 let normalized_score = if max_possible > 0.0 {
290 combined_score / max_possible
291 } else {
292 0.0
293 };
294 memory_data.remove(&memory_id).map(|row| {
295 let snippet: String = row.body.chars().take(300).collect();
296 HybridSearchItem {
297 memory_id: row.id,
298 name: row.name,
299 namespace: row.namespace,
300 memory_type: row.memory_type,
301 description: row.description,
302 body: row.body,
303 snippet,
304 combined_score,
305 score: combined_score,
306 source: "hybrid".to_string(),
307 vec_rank: vec_rank_map.get(&memory_id).copied(),
308 fts_rank: fts_rank_map.get(&memory_id).copied(),
309 rrf_score: Some(combined_score),
310 normalized_score,
311 vec_distance: vec_distance_map.get(&memory_id).copied(),
312 fts_bm25: None,
313 }
314 })
315 })
316 .collect();
317
318 let mut graph_matches: Vec<RecallItem> = Vec::with_capacity(8);
320 if args.with_graph && !results.is_empty() {
321 let namespace_for_graph = namespace.clone();
322 let memory_ids: Vec<i64> = results.iter().map(|r| r.memory_id).collect();
323
324 let entity_knn = entities::knn_search(&conn, &embedding, &namespace_for_graph, 5)?;
325 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
326
327 let all_seed_ids: Vec<i64> = memory_ids
328 .iter()
329 .chain(entity_ids.iter())
330 .copied()
331 .collect();
332
333 if !all_seed_ids.is_empty() {
334 let graph_memory_ids = traverse_from_memories_with_hops(
335 &conn,
336 &all_seed_ids,
337 &namespace_for_graph,
338 args.min_weight,
339 args.max_hops,
340 )?;
341
342 let already_in_results: std::collections::HashSet<i64> =
343 results.iter().map(|r| r.memory_id).collect();
344
345 for (graph_mem_id, hop) in graph_memory_ids {
346 if already_in_results.contains(&graph_mem_id) {
347 continue;
348 }
349 if let Some(row) = memories::read_full(&conn, graph_mem_id)? {
350 let snippet: String = row.body.chars().take(300).collect();
351 let graph_distance = 1.0 - 1.0 / (hop as f32 + 1.0);
352 graph_matches.push(RecallItem {
353 memory_id: row.id,
354 name: row.name,
355 namespace: row.namespace,
356 memory_type: row.memory_type,
357 description: row.description,
358 snippet,
359 distance: graph_distance,
360 score: RecallItem::score_from_distance(graph_distance),
361 source: "graph".to_string(),
362 graph_depth: Some(hop),
363 });
364 }
365 }
366 }
367 }
368
369 output::emit_json(&HybridSearchResponse {
370 query: args.query,
371 k: args.k,
372 rrf_k: args.rrf_k,
373 weights: Weights {
374 vec: args.weight_vec,
375 fts: args.weight_fts,
376 },
377 results,
378 graph_matches,
379 fts_degraded,
380 fts_error,
381 fts_auto_rebuilt,
382 elapsed_ms: start.elapsed().as_millis() as u64,
383 })?;
384
385 Ok(())
386}
387
#[cfg(test)]
mod tests {
    //! Serialization tests for the hybrid-search response types.
    use super::*;

    /// Builds a response with no results, no graph matches and healthy FTS.
    fn empty_response(
        k: usize,
        rrf_k: u32,
        weight_vec: f32,
        weight_fts: f32,
    ) -> HybridSearchResponse {
        HybridSearchResponse {
            query: "test query".to_string(),
            k,
            rrf_k,
            weights: Weights {
                vec: weight_vec,
                fts: weight_fts,
            },
            results: vec![],
            graph_matches: vec![],
            fts_degraded: false,
            fts_error: None,
            fts_auto_rebuilt: false,
            elapsed_ms: 0,
        }
    }

    #[test]
    fn hybrid_search_response_empty_serializes_correct_fields() {
        let resp = empty_response(10, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"results\""), "must contain results field");
        assert!(json.contains("\"query\""), "must contain query field");
        assert!(json.contains("\"k\""), "must contain k field");
        assert!(
            json.contains("\"graph_matches\""),
            "must contain graph_matches field"
        );
        assert!(
            !json.contains("\"combined_rank\""),
            "must not contain combined_rank"
        );
        assert!(
            !json.contains("\"vec_rank_list\""),
            "must not contain vec_rank_list"
        );
        assert!(
            !json.contains("\"fts_rank_list\""),
            "must not contain fts_rank_list"
        );
    }

    #[test]
    fn hybrid_search_response_serializes_rrf_k_and_weights() {
        let resp = empty_response(5, 60, 0.7, 0.3);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"rrf_k\""), "must contain rrf_k field");
        assert!(json.contains("\"weights\""), "must contain weights field");
        assert!(json.contains("\"vec\""), "must contain weights.vec field");
        assert!(json.contains("\"fts\""), "must contain weights.fts field");

        // Check the values too, not just the keys (f32 -> f64 widening is inexact).
        let value = serde_json::to_value(&resp).unwrap();
        assert_eq!(value["rrf_k"], 60);
        assert!((value["weights"]["vec"].as_f64().unwrap() - 0.7).abs() < 1e-6);
        assert!((value["weights"]["fts"].as_f64().unwrap() - 0.3).abs() < 1e-6);
    }

    #[test]
    fn hybrid_search_response_serializes_elapsed_ms() {
        let mut resp = empty_response(5, 60, 1.0, 1.0);
        resp.elapsed_ms = 123;
        let json = serde_json::to_string(&resp).unwrap();
        assert!(
            json.contains("\"elapsed_ms\""),
            "must contain elapsed_ms field"
        );
        // Pin the value to its key so an unrelated "123" cannot satisfy the check.
        assert!(
            json.contains("\"elapsed_ms\":123"),
            "deve serializar valor de elapsed_ms"
        );
    }

    #[test]
    fn weights_struct_serializes_correctly() {
        let w = Weights { vec: 0.6, fts: 0.4 };
        let json = serde_json::to_string(&w).unwrap();
        assert!(json.contains("\"vec\""));
        assert!(json.contains("\"fts\""));
    }

    #[test]
    fn hybrid_search_item_omits_fts_rank_when_none() {
        let item = HybridSearchItem {
            memory_id: 1,
            name: "mem".to_string(),
            namespace: "default".to_string(),
            memory_type: "user".to_string(),
            description: "desc".to_string(),
            body: "content".to_string(),
            snippet: "content".to_string(),
            combined_score: 0.0328,
            score: 0.0328,
            source: "hybrid".to_string(),
            vec_rank: Some(1),
            fts_rank: None,
            rrf_score: Some(0.0328),
            normalized_score: 1.0,
            vec_distance: Some(0.12),
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            json.contains("\"vec_rank\""),
            "must contain vec_rank when Some"
        );
        assert!(
            !json.contains("\"fts_rank\""),
            "must not contain fts_rank when None"
        );
    }

    #[test]
    fn hybrid_search_item_omits_vec_rank_when_none() {
        let item = HybridSearchItem {
            memory_id: 2,
            name: "mem2".to_string(),
            namespace: "default".to_string(),
            memory_type: "fact".to_string(),
            description: "desc2".to_string(),
            body: "corpo2".to_string(),
            snippet: "corpo2".to_string(),
            combined_score: 0.016,
            score: 0.016,
            source: "hybrid".to_string(),
            vec_rank: None,
            fts_rank: Some(2),
            rrf_score: Some(0.016),
            normalized_score: 0.5,
            vec_distance: None,
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            !json.contains("\"vec_rank\""),
            "must not contain vec_rank when None"
        );
        assert!(
            json.contains("\"fts_rank\""),
            "must contain fts_rank when Some"
        );
    }

    #[test]
    fn hybrid_search_item_serializes_both_ranks_when_some() {
        let item = HybridSearchItem {
            memory_id: 3,
            name: "mem3".to_string(),
            namespace: "ns".to_string(),
            memory_type: "entity".to_string(),
            description: "desc3".to_string(),
            body: "corpo3".to_string(),
            snippet: "corpo3".to_string(),
            combined_score: 0.05,
            score: 0.05,
            source: "hybrid".to_string(),
            vec_rank: Some(3),
            fts_rank: Some(1),
            rrf_score: Some(0.05),
            normalized_score: 0.8,
            vec_distance: Some(0.25),
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(json.contains("\"vec_rank\""), "must contain vec_rank");
        assert!(json.contains("\"fts_rank\""), "must contain fts_rank");
        assert!(json.contains("\"type\""), "deve serializar type renomeado");
        assert!(!json.contains("memory_type"), "must not expose memory_type");
    }

    #[test]
    fn hybrid_search_response_serializes_k_correctly() {
        let resp = empty_response(5, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"k\":5"), "deve serializar k=5");
    }

    #[test]
    fn hybrid_search_response_with_graph_matches() {
        use crate::output::RecallItem;
        let resp = HybridSearchResponse {
            query: "test".to_string(),
            k: 5,
            rrf_k: 60,
            weights: Weights { vec: 1.0, fts: 1.0 },
            results: vec![],
            graph_matches: vec![RecallItem {
                memory_id: 1,
                name: "graph-hit".to_string(),
                namespace: "global".to_string(),
                memory_type: "document".to_string(),
                description: "found via graph".to_string(),
                snippet: "graph content".to_string(),
                distance: 0.1,
                score: 0.9,
                source: "graph".to_string(),
                graph_depth: Some(1),
            }],
            fts_degraded: false,
            fts_error: None,
            fts_auto_rebuilt: false,
            elapsed_ms: 42,
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["graph_matches"].as_array().unwrap().len(), 1);
        assert_eq!(json["graph_matches"][0]["source"], "graph");
        assert_eq!(json["graph_matches"][0]["graph_depth"], 1);
    }

    #[test]
    fn fts_degraded_omitted_on_success_present_on_failure() {
        let ok_resp = empty_response(5, 60, 1.0, 1.0);
        let ok_json = serde_json::to_string(&ok_resp).unwrap();
        assert!(
            !ok_json.contains("\"fts_degraded\""),
            "fts_degraded must be absent when false"
        );
        assert!(
            !ok_json.contains("\"fts_error\""),
            "fts_error must be absent when None"
        );

        let mut degraded_resp = empty_response(5, 60, 1.0, 1.0);
        degraded_resp.fts_degraded = true;
        degraded_resp.fts_error = Some("FTS5 table corrupted".to_string());
        let degraded_json = serde_json::to_string(&degraded_resp).unwrap();
        assert!(
            degraded_json.contains("\"fts_degraded\":true"),
            "fts_degraded must be present and true when degraded"
        );
        assert!(
            degraded_json.contains("\"fts_error\""),
            "fts_error must be present when Some"
        );
        assert!(
            degraded_json.contains("FTS5 table corrupted"),
            "fts_error must contain the error message"
        );
    }

    #[test]
    fn fts_auto_rebuilt_omitted_when_false_present_when_true() {
        let ok_resp = empty_response(5, 60, 1.0, 1.0);
        let ok_json = serde_json::to_string(&ok_resp).unwrap();
        assert!(
            !ok_json.contains("\"fts_auto_rebuilt\""),
            "fts_auto_rebuilt must be absent when false"
        );

        let mut rebuilt_resp = empty_response(5, 60, 1.0, 1.0);
        rebuilt_resp.fts_auto_rebuilt = true;
        let rebuilt_json = serde_json::to_string(&rebuilt_resp).unwrap();
        assert!(
            rebuilt_json.contains("\"fts_auto_rebuilt\":true"),
            "fts_auto_rebuilt must be present and true when a rebuild ran"
        );
    }
}