1use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::output::{self, JsonOutputFormat, RecallItem};
6use crate::paths::AppPaths;
7use crate::storage::connection::open_ro;
8use crate::storage::memories;
9
10use std::collections::HashMap;
11
12#[derive(clap::Args)]
19pub struct HybridSearchArgs {
20 #[arg(help = "Hybrid search query (vector KNN + FTS5 BM25 fused via RRF)")]
21 pub query: String,
22 #[arg(short = 'k', long, alias = "limit", default_value = "10", value_parser = crate::parsers::parse_k_range)]
27 pub k: usize,
28 #[arg(long, default_value = "60")]
29 pub rrf_k: u32,
30 #[arg(long, default_value = "1.0")]
31 pub weight_vec: f32,
32 #[arg(long, default_value = "1.0")]
33 pub weight_fts: f32,
34 #[arg(long, value_enum)]
38 pub r#type: Option<MemoryType>,
39 #[arg(long)]
40 pub namespace: Option<String>,
41 #[arg(long)]
42 pub with_graph: bool,
43 #[arg(long, default_value = "2")]
44 pub max_hops: u32,
45 #[arg(long, default_value = "0.3")]
46 pub min_weight: f64,
47 #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
48 pub format: JsonOutputFormat,
49 #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
50 pub db: Option<String>,
51 #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
53 pub json: bool,
54 #[command(flatten)]
55 pub daemon: crate::cli::DaemonOpts,
56}
57
58#[derive(serde::Serialize)]
59pub struct HybridSearchItem {
60 pub memory_id: i64,
61 pub name: String,
62 pub namespace: String,
63 #[serde(rename = "type")]
64 pub memory_type: String,
65 pub description: String,
66 pub body: String,
67 pub combined_score: f64,
68 pub score: f64,
70 pub source: String,
72 #[serde(skip_serializing_if = "Option::is_none")]
73 pub vec_rank: Option<usize>,
74 #[serde(skip_serializing_if = "Option::is_none")]
75 pub fts_rank: Option<usize>,
76 #[serde(skip_serializing_if = "Option::is_none")]
78 pub rrf_score: Option<f64>,
79}
80
81#[derive(serde::Serialize)]
83pub struct Weights {
84 pub vec: f32,
85 pub fts: f32,
86}
87
88#[derive(serde::Serialize)]
89pub struct HybridSearchResponse {
90 pub query: String,
91 pub k: usize,
92 pub rrf_k: u32,
94 pub weights: Weights,
96 pub results: Vec<HybridSearchItem>,
97 pub graph_matches: Vec<RecallItem>,
98 pub elapsed_ms: u64,
100}
101
102pub fn run(args: HybridSearchArgs) -> Result<(), AppError> {
103 let start = std::time::Instant::now();
104 let _ = args.format;
105
106 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
107 let paths = AppPaths::resolve(args.db.as_deref())?;
108 crate::storage::connection::ensure_db_ready(&paths)?;
109
110 output::emit_progress_i18n(
111 "Computing query embedding...",
112 "Calculando embedding da consulta...",
113 );
114 let embedding = crate::daemon::embed_query_or_local(
115 &paths.models,
116 &args.query,
117 args.daemon.autostart_daemon,
118 )?;
119
120 let conn = open_ro(&paths.db)?;
121
122 let memory_type_str = args.r#type.map(|t| t.as_str());
123
124 let vec_results = memories::knn_search(
125 &conn,
126 &embedding,
127 &[namespace.clone()],
128 memory_type_str,
129 args.k * 2,
130 )?;
131
132 let vec_rank_map: HashMap<i64, usize> = vec_results
134 .iter()
135 .enumerate()
136 .map(|(pos, (id, _))| (*id, pos + 1))
137 .collect();
138
139 let fts_results =
140 memories::fts_search(&conn, &args.query, &namespace, memory_type_str, args.k * 2)?;
141
142 let fts_rank_map: HashMap<i64, usize> = fts_results
144 .iter()
145 .enumerate()
146 .map(|(pos, row)| (row.id, pos + 1))
147 .collect();
148
149 let rrf_k = args.rrf_k as f64;
150
151 let mut combined_scores: HashMap<i64, f64> = HashMap::new();
153
154 for (rank, (memory_id, _)) in vec_results.iter().enumerate() {
155 let score = args.weight_vec as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
156 *combined_scores.entry(*memory_id).or_insert(0.0) += score;
157 }
158
159 for (rank, row) in fts_results.iter().enumerate() {
160 let score = args.weight_fts as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
161 *combined_scores.entry(row.id).or_insert(0.0) += score;
162 }
163
164 let mut ranked: Vec<(i64, f64)> = combined_scores.into_iter().collect();
166 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
167 ranked.truncate(args.k);
168
169 let top_ids: Vec<i64> = ranked.iter().map(|(id, _)| *id).collect();
171
172 let mut memory_data: HashMap<i64, memories::MemoryRow> = HashMap::new();
174 for id in &top_ids {
175 if let Some(row) = memories::read_full(&conn, *id)? {
176 memory_data.insert(*id, row);
177 }
178 }
179
180 let results: Vec<HybridSearchItem> = ranked
182 .into_iter()
183 .filter_map(|(memory_id, combined_score)| {
184 memory_data.remove(&memory_id).map(|row| HybridSearchItem {
185 memory_id: row.id,
186 name: row.name,
187 namespace: row.namespace,
188 memory_type: row.memory_type,
189 description: row.description,
190 body: row.body,
191 combined_score,
192 score: combined_score,
193 source: "hybrid".to_string(),
194 vec_rank: vec_rank_map.get(&memory_id).copied(),
195 fts_rank: fts_rank_map.get(&memory_id).copied(),
196 rrf_score: Some(combined_score),
197 })
198 })
199 .collect();
200
201 output::emit_json(&HybridSearchResponse {
202 query: args.query,
203 k: args.k,
204 rrf_k: args.rrf_k,
205 weights: Weights {
206 vec: args.weight_vec,
207 fts: args.weight_fts,
208 },
209 results,
210 graph_matches: vec![],
211 elapsed_ms: start.elapsed().as_millis() as u64,
212 })?;
213
214 Ok(())
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `value` to a compact JSON string, panicking on failure.
    fn to_json<T: serde::Serialize>(value: &T) -> String {
        serde_json::to_string(value).unwrap()
    }

    /// Returns true when `json` contains `key` wrapped in double quotes.
    fn has_key(json: &str, key: &str) -> bool {
        json.contains(&format!("\"{key}\""))
    }

    /// Builds a response with no results for the given parameters.
    fn empty_response(k: usize, rrf_k: u32, vec: f32, fts: f32) -> HybridSearchResponse {
        HybridSearchResponse {
            query: "busca teste".to_string(),
            k,
            rrf_k,
            weights: Weights { vec, fts },
            results: Vec::new(),
            graph_matches: Vec::new(),
            elapsed_ms: 0,
        }
    }

    #[test]
    fn hybrid_search_response_empty_serializes_correct_fields() {
        let json = to_json(&empty_response(10, 60, 1.0, 1.0));

        // Fields that must be present.
        assert!(has_key(&json, "results"), "must contain results field");
        assert!(has_key(&json, "query"), "must contain query field");
        assert!(has_key(&json, "k"), "must contain k field");
        assert!(has_key(&json, "graph_matches"), "must contain graph_matches field");

        // Keys that must never appear.
        assert!(!has_key(&json, "combined_rank"), "must not contain combined_rank");
        assert!(!has_key(&json, "vec_rank_list"), "must not contain vec_rank_list");
        assert!(!has_key(&json, "fts_rank_list"), "must not contain fts_rank_list");
    }

    #[test]
    fn hybrid_search_response_serializes_rrf_k_and_weights() {
        let json = to_json(&empty_response(5, 60, 0.7, 0.3));
        assert!(has_key(&json, "rrf_k"), "must contain rrf_k field");
        assert!(has_key(&json, "weights"), "must contain weights field");
        assert!(has_key(&json, "vec"), "must contain weights.vec field");
        assert!(has_key(&json, "fts"), "must contain weights.fts field");
    }

    #[test]
    fn hybrid_search_response_serializes_elapsed_ms() {
        let response = HybridSearchResponse {
            elapsed_ms: 123,
            ..empty_response(5, 60, 1.0, 1.0)
        };
        let json = to_json(&response);
        assert!(has_key(&json, "elapsed_ms"), "must contain elapsed_ms field");
        assert!(json.contains("123"), "deve serializar valor de elapsed_ms");
    }

    #[test]
    fn weights_struct_serializes_correctly() {
        let json = to_json(&Weights { vec: 0.6, fts: 0.4 });
        assert!(has_key(&json, "vec"));
        assert!(has_key(&json, "fts"));
    }

    #[test]
    fn hybrid_search_item_omits_fts_rank_when_none() {
        let json = to_json(&HybridSearchItem {
            memory_id: 1,
            name: "mem".to_string(),
            namespace: "default".to_string(),
            memory_type: "user".to_string(),
            description: "desc".to_string(),
            body: "content".to_string(),
            combined_score: 0.0328,
            score: 0.0328,
            source: "hybrid".to_string(),
            vec_rank: Some(1),
            fts_rank: None,
            rrf_score: Some(0.0328),
        });
        assert!(has_key(&json, "vec_rank"), "must contain vec_rank when Some");
        assert!(!has_key(&json, "fts_rank"), "must not contain fts_rank when None");
    }

    #[test]
    fn hybrid_search_item_omits_vec_rank_when_none() {
        let json = to_json(&HybridSearchItem {
            memory_id: 2,
            name: "mem2".to_string(),
            namespace: "default".to_string(),
            memory_type: "fact".to_string(),
            description: "desc2".to_string(),
            body: "corpo2".to_string(),
            combined_score: 0.016,
            score: 0.016,
            source: "hybrid".to_string(),
            vec_rank: None,
            fts_rank: Some(2),
            rrf_score: Some(0.016),
        });
        assert!(!has_key(&json, "vec_rank"), "must not contain vec_rank when None");
        assert!(has_key(&json, "fts_rank"), "must contain fts_rank when Some");
    }

    #[test]
    fn hybrid_search_item_serializes_both_ranks_when_some() {
        let json = to_json(&HybridSearchItem {
            memory_id: 3,
            name: "mem3".to_string(),
            namespace: "ns".to_string(),
            memory_type: "entity".to_string(),
            description: "desc3".to_string(),
            body: "corpo3".to_string(),
            combined_score: 0.05,
            score: 0.05,
            source: "hybrid".to_string(),
            vec_rank: Some(3),
            fts_rank: Some(1),
            rrf_score: Some(0.05),
        });
        assert!(has_key(&json, "vec_rank"), "must contain vec_rank");
        assert!(has_key(&json, "fts_rank"), "must contain fts_rank");
        assert!(has_key(&json, "type"), "deve serializar type renomeado");
        // The Rust field name must not leak into the JSON in any form.
        assert!(!json.contains("memory_type"), "must not expose memory_type");
    }

    #[test]
    fn hybrid_search_response_serializes_k_correctly() {
        let json = to_json(&empty_response(5, 60, 1.0, 1.0));
        assert!(json.contains("\"k\":5"), "deve serializar k=5");
    }
}