1use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::graph::traverse_from_memories_with_hops;
6use crate::output::{self, JsonOutputFormat, RecallItem};
7use crate::paths::AppPaths;
8use crate::storage::connection::open_ro;
9use crate::storage::entities;
10use crate::storage::memories;
11
12use std::collections::HashMap;
13
// CLI arguments for the `hybrid-search` command.
//
// Plain `//` comments are used in this struct on purpose: clap's derive turns `///`
// doc comments on fields into argument help text, which would change `--help` output.
#[derive(clap::Args)]
#[command(after_long_help = "EXAMPLES:\n \
    # Basic hybrid search combining FTS5 + vector via RRF\n \
    sqlite-graphrag hybrid-search \"postgres migration deadlock\" --k 10\n\n \
    # Tune RRF weights to favor keyword matches over semantic similarity\n \
    sqlite-graphrag hybrid-search \"jwt auth\" --weight-fts 1.5 --weight-vec 0.5 --k 5\n\n \
    # Add graph traversal matches (entities connected to top results)\n \
    sqlite-graphrag hybrid-search \"frontend architecture\" --with-graph --k 10\n\n \
    # Graph traversal with custom depth and minimum edge weight\n \
    sqlite-graphrag hybrid-search \"auth design\" --with-graph --max-hops 3 --min-weight 0.5 --k 10\n\n \
    NOTES:\n \
    --with-graph enables entity graph traversal seeded by the top RRF results.\n \
    Graph matches appear in the `graph_matches` array (separate from `results`).\n \
    Without --with-graph, `graph_matches` is always empty.")]
pub struct HybridSearchArgs {
    // Free-text query; `run` embeds it for the vector KNN pass and also hands it to FTS5.
    #[arg(help = "Hybrid search query (vector KNN + FTS5 BM25 fused via RRF)")]
    pub query: String,
    // Number of fused results to return (`-k` / `--k`, alias `--limit`), validated by
    // `crate::parsers::parse_k_range`. Each retriever is asked for `2 * k` candidates.
    #[arg(short = 'k', long, alias = "limit", default_value = "10", value_parser = crate::parsers::parse_k_range)]
    pub k: usize,
    // RRF constant: an item at 1-based rank `r` in a list contributes
    // `weight / (rrf_k + r)` to its fused score.
    #[arg(long, default_value = "60")]
    pub rrf_k: u32,
    // Multiplier applied to every vector-KNN RRF contribution.
    #[arg(long, default_value = "1.0")]
    pub weight_vec: f32,
    // Multiplier applied to every FTS5 RRF contribution.
    #[arg(long, default_value = "1.0")]
    pub weight_fts: f32,
    // Optional memory-type filter, forwarded to both the vector and the FTS5 search.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // Namespace to search, resolved through `crate::namespace::resolve_namespace`
    // (NOTE(review): presumably falls back to a default when omitted — confirm there).
    #[arg(long)]
    pub namespace: Option<String>,
    // When set, memories reached by graph traversal from the fused results are
    // reported in `graph_matches`.
    #[arg(long)]
    pub with_graph: bool,
    // Maximum traversal depth passed to `traverse_from_memories_with_hops`.
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    // Minimum edge weight passed to `traverse_from_memories_with_hops`.
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    // Output format selector; `run` reads it but always emits JSON.
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // Database path override; may also come from the `SQLITE_GRAPHRAG_DB_PATH` env var.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // Hidden flag; accepted on the command line but never read by `run`.
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
    // Shared daemon options; `run` reads `autostart_daemon` when embedding the query.
    #[command(flatten)]
    pub daemon: crate::cli::DaemonOpts,
}
72
/// One fused hit in the `results` array: the full memory row plus its RRF ranking data.
///
/// Field order here is the JSON key order of the serialized item.
#[derive(serde::Serialize)]
pub struct HybridSearchItem {
    /// Id of the memory row.
    pub memory_id: i64,
    /// Memory name.
    pub name: String,
    /// Namespace the memory belongs to.
    pub namespace: String,
    /// Memory type; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub memory_type: String,
    /// Memory description.
    pub description: String,
    /// Full memory body, copied untruncated from the stored row.
    pub body: String,
    /// Weighted RRF score of this memory; `run` writes the same value to `score`
    /// and `rrf_score`.
    pub combined_score: f64,
    /// Same value as `combined_score`.
    /// NOTE(review): presumably kept for parity with other search outputs — confirm.
    pub score: f64,
    /// Origin tag; `run` always sets `"hybrid"`.
    pub source: String,
    /// 1-based position in the vector-KNN candidate list; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vec_rank: Option<usize>,
    /// 1-based position in the FTS5 candidate list; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fts_rank: Option<usize>,
    /// Fused RRF score (equal to `combined_score`); omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rrf_score: Option<f64>,
}
95
/// RRF weights used for a request, echoed back in the response.
#[derive(serde::Serialize)]
pub struct Weights {
    /// Multiplier on vector-KNN RRF contributions (`--weight-vec`).
    pub vec: f32,
    /// Multiplier on FTS5 RRF contributions (`--weight-fts`).
    pub fts: f32,
}
102
/// JSON document emitted by `hybrid-search`.
///
/// Field order here is the JSON key order of the serialized response.
#[derive(serde::Serialize)]
pub struct HybridSearchResponse {
    /// The query text exactly as given on the command line.
    pub query: String,
    /// Requested number of fused results.
    pub k: usize,
    /// RRF constant used for fusion.
    pub rrf_k: u32,
    /// Per-retriever weights used for fusion.
    pub weights: Weights,
    /// Fused hits, highest RRF score first, at most `k` entries.
    pub results: Vec<HybridSearchItem>,
    /// Memories reached through graph traversal that are not already in `results`;
    /// empty unless `--with-graph` was given.
    pub graph_matches: Vec<RecallItem>,
    /// Wall-clock milliseconds from the start of `run` until the response was built.
    pub elapsed_ms: u64,
}
116
117pub fn run(args: HybridSearchArgs) -> Result<(), AppError> {
118 let start = std::time::Instant::now();
119 let _ = args.format;
120
121 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
122 let paths = AppPaths::resolve(args.db.as_deref())?;
123 crate::storage::connection::ensure_db_ready(&paths)?;
124
125 output::emit_progress_i18n(
126 "Computing query embedding...",
127 "Calculando embedding da consulta...",
128 );
129 let embedding = crate::daemon::embed_query_or_local(
130 &paths.models,
131 &args.query,
132 args.daemon.autostart_daemon,
133 )?;
134
135 let conn = open_ro(&paths.db)?;
136
137 let memory_type_str = args.r#type.map(|t| t.as_str());
138
139 let vec_results = memories::knn_search(
140 &conn,
141 &embedding,
142 &[namespace.clone()],
143 memory_type_str,
144 args.k * 2,
145 )?;
146
147 let vec_rank_map: HashMap<i64, usize> = vec_results
149 .iter()
150 .enumerate()
151 .map(|(pos, (id, _))| (*id, pos + 1))
152 .collect();
153
154 let fts_results =
155 memories::fts_search(&conn, &args.query, &namespace, memory_type_str, args.k * 2)?;
156
157 let fts_rank_map: HashMap<i64, usize> = fts_results
159 .iter()
160 .enumerate()
161 .map(|(pos, row)| (row.id, pos + 1))
162 .collect();
163
164 let rrf_k = args.rrf_k as f64;
165
166 let mut combined_scores: HashMap<i64, f64> = HashMap::new();
168
169 for (rank, (memory_id, _)) in vec_results.iter().enumerate() {
170 let score = args.weight_vec as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
171 *combined_scores.entry(*memory_id).or_insert(0.0) += score;
172 }
173
174 for (rank, row) in fts_results.iter().enumerate() {
175 let score = args.weight_fts as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
176 *combined_scores.entry(row.id).or_insert(0.0) += score;
177 }
178
179 let mut ranked: Vec<(i64, f64)> = combined_scores.into_iter().collect();
181 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
182 ranked.truncate(args.k);
183
184 let top_ids: Vec<i64> = ranked.iter().map(|(id, _)| *id).collect();
186
187 let mut memory_data: HashMap<i64, memories::MemoryRow> = HashMap::new();
189 for id in &top_ids {
190 if let Some(row) = memories::read_full(&conn, *id)? {
191 memory_data.insert(*id, row);
192 }
193 }
194
195 let results: Vec<HybridSearchItem> = ranked
197 .into_iter()
198 .filter_map(|(memory_id, combined_score)| {
199 memory_data.remove(&memory_id).map(|row| HybridSearchItem {
200 memory_id: row.id,
201 name: row.name,
202 namespace: row.namespace,
203 memory_type: row.memory_type,
204 description: row.description,
205 body: row.body,
206 combined_score,
207 score: combined_score,
208 source: "hybrid".to_string(),
209 vec_rank: vec_rank_map.get(&memory_id).copied(),
210 fts_rank: fts_rank_map.get(&memory_id).copied(),
211 rrf_score: Some(combined_score),
212 })
213 })
214 .collect();
215
216 let mut graph_matches: Vec<RecallItem> = Vec::new();
218 if args.with_graph && !results.is_empty() {
219 let namespace_for_graph = namespace.clone();
220 let memory_ids: Vec<i64> = results.iter().map(|r| r.memory_id).collect();
221
222 let entity_knn = entities::knn_search(&conn, &embedding, &namespace_for_graph, 5)?;
223 let entity_ids: Vec<i64> = entity_knn.iter().map(|(id, _)| *id).collect();
224
225 let all_seed_ids: Vec<i64> = memory_ids
226 .iter()
227 .chain(entity_ids.iter())
228 .copied()
229 .collect();
230
231 if !all_seed_ids.is_empty() {
232 let graph_memory_ids = traverse_from_memories_with_hops(
233 &conn,
234 &all_seed_ids,
235 &namespace_for_graph,
236 args.min_weight,
237 args.max_hops,
238 )?;
239
240 let already_in_results: std::collections::HashSet<i64> =
241 results.iter().map(|r| r.memory_id).collect();
242
243 for (graph_mem_id, hop) in graph_memory_ids {
244 if already_in_results.contains(&graph_mem_id) {
245 continue;
246 }
247 if let Some(row) = memories::read_full(&conn, graph_mem_id)? {
248 let snippet: String = row.body.chars().take(300).collect();
249 let graph_distance = 1.0 - 1.0 / (hop as f32 + 1.0);
250 graph_matches.push(RecallItem {
251 memory_id: row.id,
252 name: row.name,
253 namespace: row.namespace,
254 memory_type: row.memory_type,
255 description: row.description,
256 snippet,
257 distance: graph_distance,
258 score: RecallItem::score_from_distance(graph_distance),
259 source: "graph".to_string(),
260 graph_depth: Some(hop),
261 });
262 }
263 }
264 }
265 }
266
267 output::emit_json(&HybridSearchResponse {
268 query: args.query,
269 k: args.k,
270 rrf_k: args.rrf_k,
271 weights: Weights {
272 vec: args.weight_vec,
273 fts: args.weight_fts,
274 },
275 results,
276 graph_matches,
277 elapsed_ms: start.elapsed().as_millis() as u64,
278 })?;
279
280 Ok(())
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a response with no results or graph matches, for serialization checks.
    fn empty_response(
        k: usize,
        rrf_k: u32,
        weight_vec: f32,
        weight_fts: f32,
    ) -> HybridSearchResponse {
        HybridSearchResponse {
            query: "busca teste".to_string(),
            k,
            rrf_k,
            weights: Weights {
                vec: weight_vec,
                fts: weight_fts,
            },
            results: vec![],
            graph_matches: vec![],
            elapsed_ms: 0,
        }
    }

    /// Asserts that `value` is a JSON number within `1e-6` of `expected`.
    ///
    /// The weights are `f32`, so they do not round-trip through `f64` exactly.
    fn assert_json_close(value: &serde_json::Value, expected: f64) {
        let actual = value
            .as_f64()
            .unwrap_or_else(|| panic!("expected a JSON number, got {}", value));
        assert!(
            (actual - expected).abs() < 1e-6,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn hybrid_search_response_empty_serializes_correct_fields() {
        let resp = empty_response(10, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"results\""), "must contain results field");
        assert!(json.contains("\"query\""), "must contain query field");
        assert!(json.contains("\"k\""), "must contain k field");
        assert!(
            json.contains("\"graph_matches\""),
            "must contain graph_matches field"
        );
        assert!(
            !json.contains("\"combined_rank\""),
            "must not contain combined_rank"
        );
        assert!(
            !json.contains("\"vec_rank_list\""),
            "must not contain vec_rank_list"
        );
        assert!(
            !json.contains("\"fts_rank_list\""),
            "must not contain fts_rank_list"
        );
    }

    #[test]
    fn hybrid_search_response_serializes_rrf_k_and_weights() {
        let resp = empty_response(5, 60, 0.7, 0.3);
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["rrf_k"], 60, "must serialize the rrf_k value");
        assert_json_close(&json["weights"]["vec"], 0.7);
        assert_json_close(&json["weights"]["fts"], 0.3);
    }

    #[test]
    fn hybrid_search_response_serializes_elapsed_ms() {
        let mut resp = empty_response(5, 60, 1.0, 1.0);
        resp.elapsed_ms = 123;
        let json = serde_json::to_value(&resp).unwrap();
        // Checking the keyed value (not a substring) pins both presence and content.
        assert_eq!(
            json["elapsed_ms"], 123,
            "deve serializar valor de elapsed_ms"
        );
    }

    #[test]
    fn weights_struct_serializes_correctly() {
        let w = Weights { vec: 0.6, fts: 0.4 };
        let json = serde_json::to_value(&w).unwrap();
        assert_json_close(&json["vec"], 0.6);
        assert_json_close(&json["fts"], 0.4);
    }

    #[test]
    fn hybrid_search_item_omits_fts_rank_when_none() {
        let item = HybridSearchItem {
            memory_id: 1,
            name: "mem".to_string(),
            namespace: "default".to_string(),
            memory_type: "user".to_string(),
            description: "desc".to_string(),
            body: "content".to_string(),
            combined_score: 0.0328,
            score: 0.0328,
            source: "hybrid".to_string(),
            vec_rank: Some(1),
            fts_rank: None,
            rrf_score: Some(0.0328),
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            json.contains("\"vec_rank\""),
            "must contain vec_rank when Some"
        );
        assert!(
            !json.contains("\"fts_rank\""),
            "must not contain fts_rank when None"
        );
    }

    #[test]
    fn hybrid_search_item_omits_vec_rank_when_none() {
        let item = HybridSearchItem {
            memory_id: 2,
            name: "mem2".to_string(),
            namespace: "default".to_string(),
            memory_type: "fact".to_string(),
            description: "desc2".to_string(),
            body: "corpo2".to_string(),
            combined_score: 0.016,
            score: 0.016,
            source: "hybrid".to_string(),
            vec_rank: None,
            fts_rank: Some(2),
            rrf_score: Some(0.016),
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            !json.contains("\"vec_rank\""),
            "must not contain vec_rank when None"
        );
        assert!(
            json.contains("\"fts_rank\""),
            "must contain fts_rank when Some"
        );
    }

    #[test]
    fn hybrid_search_item_serializes_both_ranks_when_some() {
        let item = HybridSearchItem {
            memory_id: 3,
            name: "mem3".to_string(),
            namespace: "ns".to_string(),
            memory_type: "entity".to_string(),
            description: "desc3".to_string(),
            body: "corpo3".to_string(),
            combined_score: 0.05,
            score: 0.05,
            source: "hybrid".to_string(),
            vec_rank: Some(3),
            fts_rank: Some(1),
            rrf_score: Some(0.05),
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(json.contains("\"vec_rank\""), "must contain vec_rank");
        assert!(json.contains("\"fts_rank\""), "must contain fts_rank");
        assert!(json.contains("\"type\""), "deve serializar type renomeado");
        assert!(!json.contains("memory_type"), "must not expose memory_type");
    }

    #[test]
    fn hybrid_search_response_serializes_k_correctly() {
        let resp = empty_response(5, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"k\":5"), "deve serializar k=5");
    }

    #[test]
    fn hybrid_search_response_with_graph_matches() {
        use crate::output::RecallItem;
        let resp = HybridSearchResponse {
            query: "test".to_string(),
            k: 5,
            rrf_k: 60,
            weights: Weights { vec: 1.0, fts: 1.0 },
            results: vec![],
            graph_matches: vec![RecallItem {
                memory_id: 1,
                name: "graph-hit".to_string(),
                namespace: "global".to_string(),
                memory_type: "document".to_string(),
                description: "found via graph".to_string(),
                snippet: "graph content".to_string(),
                distance: 0.1,
                score: 0.9,
                source: "graph".to_string(),
                graph_depth: Some(1),
            }],
            elapsed_ms: 42,
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["graph_matches"].as_array().unwrap().len(), 1);
        assert_eq!(json["graph_matches"][0]["source"], "graph");
        assert_eq!(json["graph_matches"][0]["graph_depth"], 1);
    }
}
474}