1use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::graph::traverse_from_memories_with_hops;
6use crate::output::{self, JsonOutputFormat, RecallItem};
7use crate::paths::AppPaths;
8use crate::storage::connection::open_ro;
9use crate::storage::entities;
10use crate::storage::memories;
11
12use std::collections::HashMap;
13
14#[derive(clap::Args)]
21#[command(after_long_help = "EXAMPLES:\n \
22 # Basic hybrid search combining FTS5 + vector via RRF\n \
23 sqlite-graphrag hybrid-search \"postgres migration deadlock\" --k 10\n\n \
24 # Tune RRF weights to favor keyword matches over semantic similarity\n \
25 sqlite-graphrag hybrid-search \"jwt auth\" --weight-fts 1.5 --weight-vec 0.5 --k 5\n\n \
26 # Add graph traversal matches (entities connected to top results)\n \
27 sqlite-graphrag hybrid-search \"frontend architecture\" --with-graph --k 10\n\n \
28 # Graph traversal with custom depth and minimum edge weight\n \
29 sqlite-graphrag hybrid-search \"auth design\" --with-graph --max-hops 3 --min-weight 0.5 --k 10\n\n \
30NOTES:\n \
31 --with-graph enables entity graph traversal seeded by the top RRF results.\n \
32 Graph matches appear in the `graph_matches` array (separate from `results`).\n \
33 Without --with-graph, `graph_matches` is always empty.")]
34pub struct HybridSearchArgs {
35 #[arg(help = "Hybrid search query (vector KNN + FTS5 BM25 fused via RRF)")]
36 pub query: String,
37 #[arg(short = 'k', long, alias = "limit", default_value = "10", value_parser = crate::parsers::parse_k_range)]
42 pub k: usize,
43 #[arg(long, default_value = "60")]
44 pub rrf_k: u32,
45 #[arg(long, default_value = "1.0")]
46 pub weight_vec: f32,
47 #[arg(long, default_value = "1.0")]
48 pub weight_fts: f32,
49 #[arg(long, value_enum)]
53 pub r#type: Option<MemoryType>,
54 #[arg(long)]
55 pub namespace: Option<String>,
56 #[arg(long)]
57 pub with_graph: bool,
58 #[arg(long, default_value = "2")]
59 pub max_hops: u32,
60 #[arg(long, default_value = "0.3")]
61 pub min_weight: f64,
62 #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
63 pub format: JsonOutputFormat,
64 #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
65 pub db: Option<String>,
66 #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
68 pub json: bool,
69 #[command(flatten)]
70 pub daemon: crate::cli::DaemonOpts,
71}
72
73#[derive(serde::Serialize)]
74pub struct HybridSearchItem {
75 pub memory_id: i64,
76 pub name: String,
77 pub namespace: String,
78 #[serde(rename = "type")]
79 pub memory_type: String,
80 pub description: String,
81 pub body: String,
82 pub combined_score: f64,
83 pub score: f64,
85 pub source: String,
87 #[serde(skip_serializing_if = "Option::is_none")]
88 pub vec_rank: Option<usize>,
89 #[serde(skip_serializing_if = "Option::is_none")]
90 pub fts_rank: Option<usize>,
91 #[serde(skip_serializing_if = "Option::is_none")]
93 pub rrf_score: Option<f64>,
94 pub normalized_score: f64,
96 #[serde(skip_serializing_if = "Option::is_none")]
101 pub vec_distance: Option<f64>,
102 #[serde(skip_serializing_if = "Option::is_none")]
105 pub fts_bm25: Option<f64>,
106}
107
108#[derive(serde::Serialize)]
110pub struct Weights {
111 pub vec: f32,
112 pub fts: f32,
113}
114
115#[derive(serde::Serialize)]
116pub struct HybridSearchResponse {
117 pub query: String,
118 pub k: usize,
119 pub rrf_k: u32,
121 pub weights: Weights,
123 pub results: Vec<HybridSearchItem>,
124 pub graph_matches: Vec<RecallItem>,
125 #[serde(skip_serializing_if = "std::ops::Not::not")]
129 pub fts_degraded: bool,
130 #[serde(skip_serializing_if = "Option::is_none")]
134 pub fts_error: Option<String>,
135 #[serde(skip_serializing_if = "std::ops::Not::not")]
139 pub fts_auto_rebuilt: bool,
140 pub elapsed_ms: u64,
142}
143
144pub fn run(args: HybridSearchArgs) -> Result<(), AppError> {
145 let start = std::time::Instant::now();
146 let _ = args.format;
147
148 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
149 let paths = AppPaths::resolve(args.db.as_deref())?;
150 crate::storage::connection::ensure_db_ready(&paths)?;
151
152 output::emit_progress_i18n(
153 "Computing query embedding...",
154 "Calculando embedding da consulta...",
155 );
156 let embedding = crate::daemon::embed_query_or_local(
157 &paths.models,
158 &args.query,
159 args.daemon.autostart_daemon,
160 )?;
161
162 let conn = open_ro(&paths.db)?;
163
164 let memory_type_str = args.r#type.map(|t| t.as_str());
165
166 let vec_results = memories::knn_search(
167 &conn,
168 &embedding,
169 &[namespace.clone()],
170 memory_type_str,
171 args.k * 2,
172 )?;
173
174 let vec_rank_map: HashMap<i64, usize> = vec_results
176 .iter()
177 .enumerate()
178 .map(|(pos, (id, _))| (*id, pos + 1))
179 .collect();
180
181 let vec_distance_map: HashMap<i64, f64> = vec_results
183 .iter()
184 .map(|(id, dist)| (*id, *dist as f64))
185 .collect();
186
187 let (fts_results, fts_degraded, fts_error, fts_auto_rebuilt) = if args.weight_fts == 0.0 {
188 (vec![], false, None, false)
189 } else {
190 match memories::fts_search(&conn, &args.query, &namespace, memory_type_str, args.k * 2) {
191 Ok(r) => (r, false, None, false),
192 Err(e) => {
193 let err_msg = e.to_string();
194 let is_malformed = err_msg.contains("malformed") || err_msg.contains("corrupt");
195 if is_malformed {
196 tracing::warn!("FTS5 index corrupted, attempting auto-rebuild");
197 if conn
198 .execute_batch("INSERT INTO fts_memories(fts_memories) VALUES('rebuild');")
199 .is_ok()
200 {
201 match memories::fts_search(
202 &conn,
203 &args.query,
204 &namespace,
205 memory_type_str,
206 args.k * 2,
207 ) {
208 Ok(r) => (r, false, None, true),
209 Err(e2) => {
210 tracing::error!("FTS5 auto-rebuild failed to recover: {e2}");
211 (vec![], true, Some(e2.to_string()), true)
212 }
213 }
214 } else {
215 (vec![], true, Some(err_msg), false)
216 }
217 } else {
218 tracing::warn!("FTS5 query failed, falling back to vec-only: {e}");
219 (vec![], true, Some(err_msg), false)
220 }
221 }
222 }
223 };
224
225 let fts_rank_map: HashMap<i64, usize> = fts_results
227 .iter()
228 .enumerate()
229 .map(|(pos, row)| (row.id, pos + 1))
230 .collect();
231
232 let rrf_k = args.rrf_k as f64;
233
234 let mut combined_scores: HashMap<i64, f64> = HashMap::new();
236
237 for (rank, (memory_id, _)) in vec_results.iter().enumerate() {
238 let score = args.weight_vec as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
239 *combined_scores.entry(*memory_id).or_insert(0.0) += score;
240 }
241
242 for (rank, row) in fts_results.iter().enumerate() {
243 let score = args.weight_fts as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
244 *combined_scores.entry(row.id).or_insert(0.0) += score;
245 }
246
247 let mut ranked: Vec<(i64, f64)> = combined_scores.into_iter().collect();
249 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
250 ranked.truncate(args.k);
251
252 let top_ids: Vec<i64> = ranked.iter().map(|(id, _)| *id).collect();
254
255 let mut memory_data: HashMap<i64, memories::MemoryRow> = HashMap::new();
257 for id in &top_ids {
258 if let Some(row) = memories::read_full(&conn, *id)? {
259 memory_data.insert(*id, row);
260 }
261 }
262
263 let max_possible = args.weight_vec as f64 * (1.0 / (rrf_k + 1.0))
264 + args.weight_fts as f64 * (1.0 / (rrf_k + 1.0));
265
266 let results: Vec<HybridSearchItem> = ranked
268 .into_iter()
269 .filter_map(|(memory_id, combined_score)| {
270 let normalized_score = if max_possible > 0.0 {
271 combined_score / max_possible
272 } else {
273 0.0
274 };
275 memory_data.remove(&memory_id).map(|row| HybridSearchItem {
276 memory_id: row.id,
277 name: row.name,
278 namespace: row.namespace,
279 memory_type: row.memory_type,
280 description: row.description,
281 body: row.body,
282 combined_score,
283 score: combined_score,
284 source: "hybrid".to_string(),
285 vec_rank: vec_rank_map.get(&memory_id).copied(),
286 fts_rank: fts_rank_map.get(&memory_id).copied(),
287 rrf_score: Some(combined_score),
288 normalized_score,
289 vec_distance: vec_distance_map.get(&memory_id).copied(),
290 fts_bm25: None,
291 })
292 })
293 .collect();
294
295 let mut graph_matches: Vec<RecallItem> = Vec::with_capacity(8);
297 if args.with_graph && !results.is_empty() {
298 let namespace_for_graph = namespace.clone();
299 let memory_ids: Vec<i64> = results.iter().map(|r| r.memory_id).collect();
300
301 let entity_knn = entities::knn_search(&conn, &embedding, &namespace_for_graph, 5)?;
302 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
303
304 let all_seed_ids: Vec<i64> = memory_ids
305 .iter()
306 .chain(entity_ids.iter())
307 .copied()
308 .collect();
309
310 if !all_seed_ids.is_empty() {
311 let graph_memory_ids = traverse_from_memories_with_hops(
312 &conn,
313 &all_seed_ids,
314 &namespace_for_graph,
315 args.min_weight,
316 args.max_hops,
317 )?;
318
319 let already_in_results: std::collections::HashSet<i64> =
320 results.iter().map(|r| r.memory_id).collect();
321
322 for (graph_mem_id, hop) in graph_memory_ids {
323 if already_in_results.contains(&graph_mem_id) {
324 continue;
325 }
326 if let Some(row) = memories::read_full(&conn, graph_mem_id)? {
327 let snippet: String = row.body.chars().take(300).collect();
328 let graph_distance = 1.0 - 1.0 / (hop as f32 + 1.0);
329 graph_matches.push(RecallItem {
330 memory_id: row.id,
331 name: row.name,
332 namespace: row.namespace,
333 memory_type: row.memory_type,
334 description: row.description,
335 snippet,
336 distance: graph_distance,
337 score: RecallItem::score_from_distance(graph_distance),
338 source: "graph".to_string(),
339 graph_depth: Some(hop),
340 });
341 }
342 }
343 }
344 }
345
346 output::emit_json(&HybridSearchResponse {
347 query: args.query,
348 k: args.k,
349 rrf_k: args.rrf_k,
350 weights: Weights {
351 vec: args.weight_vec,
352 fts: args.weight_fts,
353 },
354 results,
355 graph_matches,
356 fts_degraded,
357 fts_error,
358 fts_auto_rebuilt,
359 elapsed_ms: start.elapsed().as_millis() as u64,
360 })?;
361
362 Ok(())
363}
364
#[cfg(test)]
mod tests {
    //! Serialization-shape tests for the hybrid-search JSON output types.
    use super::*;

    /// Builds a response with no results, no graph matches and FTS healthy.
    fn empty_response(
        k: usize,
        rrf_k: u32,
        weight_vec: f32,
        weight_fts: f32,
    ) -> HybridSearchResponse {
        HybridSearchResponse {
            query: "test query".to_string(),
            k,
            rrf_k,
            weights: Weights {
                vec: weight_vec,
                fts: weight_fts,
            },
            results: vec![],
            graph_matches: vec![],
            fts_degraded: false,
            fts_error: None,
            fts_auto_rebuilt: false,
            elapsed_ms: 0,
        }
    }

    #[test]
    fn hybrid_search_response_empty_serializes_correct_fields() {
        let resp = empty_response(10, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"results\""), "must contain results field");
        assert!(json.contains("\"query\""), "must contain query field");
        assert!(json.contains("\"k\""), "must contain k field");
        assert!(
            json.contains("\"graph_matches\""),
            "must contain graph_matches field"
        );
        assert!(
            !json.contains("\"combined_rank\""),
            "must not contain combined_rank"
        );
        assert!(
            !json.contains("\"vec_rank_list\""),
            "must not contain vec_rank_list"
        );
        assert!(
            !json.contains("\"fts_rank_list\""),
            "must not contain fts_rank_list"
        );
    }

    #[test]
    fn hybrid_search_response_serializes_rrf_k_and_weights() {
        let resp = empty_response(5, 60, 0.7, 0.3);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"rrf_k\""), "must contain rrf_k field");
        assert!(json.contains("\"weights\""), "must contain weights field");
        assert!(json.contains("\"vec\""), "must contain weights.vec field");
        assert!(json.contains("\"fts\""), "must contain weights.fts field");
    }

    #[test]
    fn hybrid_search_response_serializes_elapsed_ms() {
        let mut resp = empty_response(5, 60, 1.0, 1.0);
        resp.elapsed_ms = 123;
        let json = serde_json::to_string(&resp).unwrap();
        assert!(
            json.contains("\"elapsed_ms\""),
            "must contain elapsed_ms field"
        );
        assert!(json.contains("123"), "deve serializar valor de elapsed_ms");
    }

    #[test]
    fn weights_struct_serializes_correctly() {
        let w = Weights { vec: 0.6, fts: 0.4 };
        let json = serde_json::to_string(&w).unwrap();
        assert!(json.contains("\"vec\""));
        assert!(json.contains("\"fts\""));
    }

    #[test]
    fn hybrid_search_item_omits_fts_rank_when_none() {
        let item = HybridSearchItem {
            memory_id: 1,
            name: "mem".to_string(),
            namespace: "default".to_string(),
            memory_type: "user".to_string(),
            description: "desc".to_string(),
            body: "content".to_string(),
            combined_score: 0.0328,
            score: 0.0328,
            source: "hybrid".to_string(),
            vec_rank: Some(1),
            fts_rank: None,
            rrf_score: Some(0.0328),
            normalized_score: 1.0,
            vec_distance: Some(0.12),
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            json.contains("\"vec_rank\""),
            "must contain vec_rank when Some"
        );
        assert!(
            !json.contains("\"fts_rank\""),
            "must not contain fts_rank when None"
        );
    }

    #[test]
    fn hybrid_search_item_omits_vec_rank_when_none() {
        let item = HybridSearchItem {
            memory_id: 2,
            name: "mem2".to_string(),
            namespace: "default".to_string(),
            memory_type: "fact".to_string(),
            description: "desc2".to_string(),
            body: "corpo2".to_string(),
            combined_score: 0.016,
            score: 0.016,
            source: "hybrid".to_string(),
            vec_rank: None,
            fts_rank: Some(2),
            rrf_score: Some(0.016),
            normalized_score: 0.5,
            vec_distance: None,
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            !json.contains("\"vec_rank\""),
            "must not contain vec_rank when None"
        );
        assert!(
            json.contains("\"fts_rank\""),
            "must contain fts_rank when Some"
        );
    }

    #[test]
    fn hybrid_search_item_serializes_both_ranks_when_some() {
        let item = HybridSearchItem {
            memory_id: 3,
            name: "mem3".to_string(),
            namespace: "ns".to_string(),
            memory_type: "entity".to_string(),
            description: "desc3".to_string(),
            body: "corpo3".to_string(),
            combined_score: 0.05,
            score: 0.05,
            source: "hybrid".to_string(),
            vec_rank: Some(3),
            fts_rank: Some(1),
            rrf_score: Some(0.05),
            normalized_score: 0.8,
            vec_distance: Some(0.25),
            fts_bm25: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(json.contains("\"vec_rank\""), "must contain vec_rank");
        assert!(json.contains("\"fts_rank\""), "must contain fts_rank");
        assert!(json.contains("\"type\""), "deve serializar type renomeado");
        assert!(!json.contains("memory_type"), "must not expose memory_type");
    }

    #[test]
    fn hybrid_search_response_serializes_k_correctly() {
        let resp = empty_response(5, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"k\":5"), "deve serializar k=5");
    }

    #[test]
    fn hybrid_search_response_with_graph_matches() {
        use crate::output::RecallItem;
        let resp = HybridSearchResponse {
            query: "test".to_string(),
            k: 5,
            rrf_k: 60,
            weights: Weights { vec: 1.0, fts: 1.0 },
            results: vec![],
            graph_matches: vec![RecallItem {
                memory_id: 1,
                name: "graph-hit".to_string(),
                namespace: "global".to_string(),
                memory_type: "document".to_string(),
                description: "found via graph".to_string(),
                snippet: "graph content".to_string(),
                distance: 0.1,
                score: 0.9,
                source: "graph".to_string(),
                graph_depth: Some(1),
            }],
            fts_degraded: false,
            fts_error: None,
            fts_auto_rebuilt: false,
            elapsed_ms: 42,
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["graph_matches"].as_array().unwrap().len(), 1);
        assert_eq!(json["graph_matches"][0]["source"], "graph");
        assert_eq!(json["graph_matches"][0]["graph_depth"], 1);
    }

    #[test]
    fn fts_degraded_omitted_on_success_present_on_failure() {
        // Healthy FTS: both degradation fields are skipped by serde.
        let ok_resp = empty_response(5, 60, 1.0, 1.0);
        let ok_json = serde_json::to_string(&ok_resp).unwrap();
        assert!(
            !ok_json.contains("\"fts_degraded\""),
            "fts_degraded must be absent when false"
        );
        assert!(
            !ok_json.contains("\"fts_error\""),
            "fts_error must be absent when None"
        );

        // Degraded FTS: both fields appear with their values.
        let mut degraded_resp = empty_response(5, 60, 1.0, 1.0);
        degraded_resp.fts_degraded = true;
        degraded_resp.fts_error = Some("FTS5 table corrupted".to_string());
        let degraded_json = serde_json::to_string(&degraded_resp).unwrap();
        assert!(
            degraded_json.contains("\"fts_degraded\":true"),
            "fts_degraded must be present and true when degraded"
        );
        assert!(
            degraded_json.contains("\"fts_error\""),
            "fts_error must be present when Some"
        );
        assert!(
            degraded_json.contains("FTS5 table corrupted"),
            "fts_error must contain the error message"
        );
    }
}