1use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::output::{self, JsonOutputFormat, RecallItem};
6use crate::paths::AppPaths;
7use crate::storage::connection::open_ro;
8use crate::storage::memories;
9
10use std::collections::HashMap;
11
// Command-line arguments for the hybrid search subcommand (consumed by `run`).
//
// Plain `//` comments are used here on purpose: clap's derive macros turn `///`
// doc comments into help text, so doc comments would change the CLI output.
#[derive(clap::Args)]
pub struct HybridSearchArgs {
    // Positional query text; it is both embedded for the vector search and
    // passed verbatim to the full-text search.
    #[arg(help = "Hybrid search query (vector KNN + FTS5 BM25 fused via RRF)")]
    pub query: String,
    // Number of results to return (`-k`, `--k`, or the `--limit` alias).
    // Validated by `crate::parsers::parse_k_range`; `run` asks each backend
    // for twice this many candidates before fusing.
    #[arg(short = 'k', long, alias = "limit", default_value = "10", value_parser = crate::parsers::parse_k_range)]
    pub k: usize,
    // Constant added to the rank in the RRF denominator.
    #[arg(long, default_value = "60")]
    pub rrf_k: u32,
    // Multiplier applied to each vector-search contribution.
    #[arg(long, default_value = "1.0")]
    pub weight_vec: f32,
    // Multiplier applied to each full-text-search contribution.
    #[arg(long, default_value = "1.0")]
    pub weight_fts: f32,
    // Optional memory-type filter, forwarded to both searches as a string.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // Optional namespace override, resolved by `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // NOTE(review): `with_graph`, `max_hops` and `min_weight` are not read by
    // `run` in this file — confirm whether graph expansion is handled elsewhere.
    #[arg(long)]
    pub with_graph: bool,
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    // Accepted on the command line, but `run` discards the value.
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // Database path override; falls back to the SQLITE_GRAPHRAG_DB_PATH
    // environment variable and is handed to `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // Hidden flag kept as a no-op (see its help text).
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
}
55
/// One fused search hit as serialized in the `results` array of the response.
#[derive(serde::Serialize)]
pub struct HybridSearchItem {
    /// Identifier of the memory row.
    pub memory_id: i64,
    /// Name of the memory row.
    pub name: String,
    /// Namespace the memory belongs to.
    pub namespace: String,
    /// Memory type; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub memory_type: String,
    /// Description of the memory row.
    pub description: String,
    /// Full body text of the memory row.
    pub body: String,
    /// Fused RRF score for this memory.
    pub combined_score: f64,
    /// Set by `run` to the same value as `combined_score`.
    pub score: f64,
    /// Origin label for the hit; `run` always writes `"hybrid"`.
    pub source: String,
    /// 1-based position in the vector-search results; omitted from JSON when
    /// the memory did not appear there.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vec_rank: Option<usize>,
    /// 1-based position in the full-text results; omitted from JSON when the
    /// memory did not appear there.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fts_rank: Option<usize>,
    /// Fused score again (`run` stores `Some(combined_score)`); omitted from
    /// JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rrf_score: Option<f64>,
}
78
/// Per-backend multipliers echoed back in the response under `weights`.
#[derive(serde::Serialize)]
pub struct Weights {
    /// Weight applied to vector-search contributions.
    pub vec: f32,
    /// Weight applied to full-text-search contributions.
    pub fts: f32,
}
85
/// Top-level JSON document emitted by the hybrid search command.
#[derive(serde::Serialize)]
pub struct HybridSearchResponse {
    /// The query string as supplied by the caller.
    pub query: String,
    /// Requested number of results.
    pub k: usize,
    /// RRF constant that was used for fusion.
    pub rrf_k: u32,
    /// Backend weights that were used for fusion.
    pub weights: Weights,
    /// Fused hits, best score first.
    pub results: Vec<HybridSearchItem>,
    /// Graph-derived matches; `run` in this file always emits an empty list.
    pub graph_matches: Vec<RecallItem>,
    /// Wall-clock time spent in `run`, in milliseconds.
    pub elapsed_ms: u64,
}
99
100pub fn run(args: HybridSearchArgs) -> Result<(), AppError> {
101 let start = std::time::Instant::now();
102 let _ = args.format;
103
104 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
105 let paths = AppPaths::resolve(args.db.as_deref())?;
106 crate::storage::connection::ensure_db_ready(&paths)?;
107
108 output::emit_progress_i18n(
109 "Computing query embedding...",
110 "Calculando embedding da consulta...",
111 );
112 let embedding = crate::daemon::embed_query_or_local(&paths.models, &args.query)?;
113
114 let conn = open_ro(&paths.db)?;
115
116 let memory_type_str = args.r#type.map(|t| t.as_str());
117
118 let vec_results = memories::knn_search(
119 &conn,
120 &embedding,
121 &[namespace.clone()],
122 memory_type_str,
123 args.k * 2,
124 )?;
125
126 let vec_rank_map: HashMap<i64, usize> = vec_results
128 .iter()
129 .enumerate()
130 .map(|(pos, (id, _))| (*id, pos + 1))
131 .collect();
132
133 let fts_results =
134 memories::fts_search(&conn, &args.query, &namespace, memory_type_str, args.k * 2)?;
135
136 let fts_rank_map: HashMap<i64, usize> = fts_results
138 .iter()
139 .enumerate()
140 .map(|(pos, row)| (row.id, pos + 1))
141 .collect();
142
143 let rrf_k = args.rrf_k as f64;
144
145 let mut combined_scores: HashMap<i64, f64> = HashMap::new();
147
148 for (rank, (memory_id, _)) in vec_results.iter().enumerate() {
149 let score = args.weight_vec as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
150 *combined_scores.entry(*memory_id).or_insert(0.0) += score;
151 }
152
153 for (rank, row) in fts_results.iter().enumerate() {
154 let score = args.weight_fts as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
155 *combined_scores.entry(row.id).or_insert(0.0) += score;
156 }
157
158 let mut ranked: Vec<(i64, f64)> = combined_scores.into_iter().collect();
160 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
161 ranked.truncate(args.k);
162
163 let top_ids: Vec<i64> = ranked.iter().map(|(id, _)| *id).collect();
165
166 let mut memory_data: HashMap<i64, memories::MemoryRow> = HashMap::new();
168 for id in &top_ids {
169 if let Some(row) = memories::read_full(&conn, *id)? {
170 memory_data.insert(*id, row);
171 }
172 }
173
174 let results: Vec<HybridSearchItem> = ranked
176 .into_iter()
177 .filter_map(|(memory_id, combined_score)| {
178 memory_data.remove(&memory_id).map(|row| HybridSearchItem {
179 memory_id: row.id,
180 name: row.name,
181 namespace: row.namespace,
182 memory_type: row.memory_type,
183 description: row.description,
184 body: row.body,
185 combined_score,
186 score: combined_score,
187 source: "hybrid".to_string(),
188 vec_rank: vec_rank_map.get(&memory_id).copied(),
189 fts_rank: fts_rank_map.get(&memory_id).copied(),
190 rrf_score: Some(combined_score),
191 })
192 })
193 .collect();
194
195 output::emit_json(&HybridSearchResponse {
196 query: args.query,
197 k: args.k,
198 rrf_k: args.rrf_k,
199 weights: Weights {
200 vec: args.weight_vec,
201 fts: args.weight_fts,
202 },
203 results,
204 graph_matches: vec![],
205 elapsed_ms: start.elapsed().as_millis() as u64,
206 })?;
207
208 Ok(())
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `value` to a compact JSON string, panicking on failure.
    #[track_caller]
    fn to_json<T: serde::Serialize>(value: &T) -> String {
        serde_json::to_string(value).unwrap()
    }

    /// Builds a response with no hits, no graph matches and zero elapsed time.
    fn response_without_hits(
        k: usize,
        rrf_k: u32,
        weight_vec: f32,
        weight_fts: f32,
    ) -> HybridSearchResponse {
        let weights = Weights {
            vec: weight_vec,
            fts: weight_fts,
        };
        HybridSearchResponse {
            query: String::from("busca teste"),
            k,
            rrf_k,
            weights,
            results: Vec::new(),
            graph_matches: Vec::new(),
            elapsed_ms: 0,
        }
    }

    /// Builds a `"hybrid"` item whose three score fields all carry `score`.
    ///
    /// `text` is `[name, namespace, memory_type, description, body]`.
    fn hybrid_item(
        memory_id: i64,
        text: [&str; 5],
        score: f64,
        vec_rank: Option<usize>,
        fts_rank: Option<usize>,
    ) -> HybridSearchItem {
        let [name, namespace, memory_type, description, body] = text;
        HybridSearchItem {
            memory_id,
            name: name.to_string(),
            namespace: namespace.to_string(),
            memory_type: memory_type.to_string(),
            description: description.to_string(),
            body: body.to_string(),
            combined_score: score,
            score,
            source: "hybrid".to_string(),
            vec_rank,
            fts_rank,
            rrf_score: Some(score),
        }
    }

    #[test]
    fn hybrid_search_response_empty_serializes_correct_fields() {
        let json = to_json(&response_without_hits(10, 60, 1.0, 1.0));

        // Keys that must always be present in the envelope.
        for (needle, msg) in [
            ("\"results\"", "must contain results field"),
            ("\"query\"", "must contain query field"),
            ("\"k\"", "must contain k field"),
            ("\"graph_matches\"", "must contain graph_matches field"),
        ] {
            assert!(json.contains(needle), "{}", msg);
        }

        // Keys that must never appear in the envelope.
        for (needle, msg) in [
            ("\"combined_rank\"", "must not contain combined_rank"),
            ("\"vec_rank_list\"", "must not contain vec_rank_list"),
            ("\"fts_rank_list\"", "must not contain fts_rank_list"),
        ] {
            assert!(!json.contains(needle), "{}", msg);
        }
    }

    #[test]
    fn hybrid_search_response_serializes_rrf_k_and_weights() {
        let json = to_json(&response_without_hits(5, 60, 0.7, 0.3));

        for (needle, msg) in [
            ("\"rrf_k\"", "must contain rrf_k field"),
            ("\"weights\"", "must contain weights field"),
            ("\"vec\"", "must contain weights.vec field"),
            ("\"fts\"", "must contain weights.fts field"),
        ] {
            assert!(json.contains(needle), "{}", msg);
        }
    }

    #[test]
    fn hybrid_search_response_serializes_elapsed_ms() {
        let resp = HybridSearchResponse {
            elapsed_ms: 123,
            ..response_without_hits(5, 60, 1.0, 1.0)
        };
        let json = to_json(&resp);

        assert!(
            json.contains("\"elapsed_ms\""),
            "must contain elapsed_ms field"
        );
        assert!(json.contains("123"), "deve serializar valor de elapsed_ms");
    }

    #[test]
    fn weights_struct_serializes_correctly() {
        let json = to_json(&Weights { vec: 0.6, fts: 0.4 });
        assert!(json.contains("\"vec\""));
        assert!(json.contains("\"fts\""));
    }

    #[test]
    fn hybrid_search_item_omits_fts_rank_when_none() {
        let item = hybrid_item(
            1,
            ["mem", "default", "user", "desc", "content"],
            0.0328,
            Some(1),
            None,
        );
        let json = to_json(&item);

        assert!(
            json.contains("\"vec_rank\""),
            "must contain vec_rank when Some"
        );
        assert!(
            !json.contains("\"fts_rank\""),
            "must not contain fts_rank when None"
        );
    }

    #[test]
    fn hybrid_search_item_omits_vec_rank_when_none() {
        let item = hybrid_item(
            2,
            ["mem2", "default", "fact", "desc2", "corpo2"],
            0.016,
            None,
            Some(2),
        );
        let json = to_json(&item);

        assert!(
            !json.contains("\"vec_rank\""),
            "must not contain vec_rank when None"
        );
        assert!(
            json.contains("\"fts_rank\""),
            "must contain fts_rank when Some"
        );
    }

    #[test]
    fn hybrid_search_item_serializes_both_ranks_when_some() {
        let item = hybrid_item(
            3,
            ["mem3", "ns", "entity", "desc3", "corpo3"],
            0.05,
            Some(3),
            Some(1),
        );
        let json = to_json(&item);

        for (needle, msg) in [
            ("\"vec_rank\"", "must contain vec_rank"),
            ("\"fts_rank\"", "must contain fts_rank"),
            ("\"type\"", "deve serializar type renomeado"),
        ] {
            assert!(json.contains(needle), "{}", msg);
        }
        assert!(!json.contains("memory_type"), "must not expose memory_type");
    }

    #[test]
    fn hybrid_search_response_serializes_k_correctly() {
        let json = to_json(&response_without_hits(5, 60, 1.0, 1.0));
        assert!(json.contains("\"k\":5"), "deve serializar k=5");
    }
}