1use crate::cli::MemoryType;
4use crate::errors::AppError;
5use crate::output::{self, JsonOutputFormat, RecallItem};
6use crate::paths::AppPaths;
7use crate::storage::connection::open_ro;
8use crate::storage::memories;
9
10use std::collections::HashMap;
11
// Command-line arguments for the hybrid search command (vector KNN + full-text, fused by RRF).
//
// NOTE: the notes in this struct are plain `//` comments rather than `///` doc comments on
// purpose: clap's derive turns doc comments into help text, so `///` here would change the
// CLI output.
#[derive(clap::Args)]
pub struct HybridSearchArgs {
    // Positional search text. `run` embeds it for the vector search and also passes it
    // verbatim to the full-text search.
    pub query: String,
    // Maximum number of fused results to return (`-k` / `--k`). `run` asks each backend for
    // twice this many candidates before fusing.
    #[arg(short = 'k', long, default_value = "10")]
    pub k: usize,
    // Constant added to every rank in the RRF denominator (`--rrf-k`).
    #[arg(long, default_value = "60")]
    pub rrf_k: u32,
    // Multiplier applied to each vector-search contribution (`--weight-vec`).
    #[arg(long, default_value = "1.0")]
    pub weight_vec: f32,
    // Multiplier applied to each full-text contribution (`--weight-fts`).
    #[arg(long, default_value = "1.0")]
    pub weight_fts: f32,
    // Optional memory-type filter (`--type`); `run` forwards it to both searches as a string.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // Optional namespace (`--namespace`); `run` passes it through
    // `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // NOTE(review): `with_graph`, `max_hops` and `min_weight` are parsed but never read by
    // `run` in this file, which always emits an empty `graph_matches` list.
    #[arg(long)]
    pub with_graph: bool,
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    // Output format selector (`--format`, default `Json`). `run` reads and discards it.
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // Database path (`--db`), also settable through the `SQLITE_GRAPHRAG_DB_PATH` environment
    // variable; `run` hands it to `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // Hidden flag that, per its own help text, is a no-op; `run` never reads it.
    #[arg(long, hide = true, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
}
50
/// One fused search hit as serialized in the JSON response.
///
/// `run` in this file fills `combined_score`, `score` and `rrf_score` with the same fused
/// value and always sets `source` to `"hybrid"`.
#[derive(serde::Serialize)]
pub struct HybridSearchItem {
    /// Identifier of the memory row.
    pub memory_id: i64,
    /// Name of the memory row.
    pub name: String,
    /// Namespace stored on the memory row.
    pub namespace: String,
    /// Memory type; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub memory_type: String,
    /// Description stored on the memory row.
    pub description: String,
    /// Body stored on the memory row.
    pub body: String,
    /// Fused score used for ordering the results.
    pub combined_score: f64,
    /// Same value as `combined_score` when produced by `run`.
    pub score: f64,
    /// Label of the search that produced the item.
    pub source: String,
    /// 1-based position in the vector-search candidates; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vec_rank: Option<usize>,
    /// 1-based position in the full-text candidates; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fts_rank: Option<usize>,
    /// Fused score again, as an optional field; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rrf_score: Option<f64>,
}
73
/// The two fusion multipliers echoed back in the response under `weights`.
#[derive(serde::Serialize)]
pub struct Weights {
    /// Multiplier that was applied to vector-search contributions.
    pub vec: f32,
    /// Multiplier that was applied to full-text contributions.
    pub fts: f32,
}
80
/// Top-level JSON document emitted by `run`.
#[derive(serde::Serialize)]
pub struct HybridSearchResponse {
    /// The query text exactly as received.
    pub query: String,
    /// The requested result limit.
    pub k: usize,
    /// The RRF constant that was used.
    pub rrf_k: u32,
    /// The per-source multipliers that were used.
    pub weights: Weights,
    /// Fused hits, best first.
    pub results: Vec<HybridSearchItem>,
    /// Graph-derived matches. NOTE(review): `run` in this file always leaves this empty.
    pub graph_matches: Vec<RecallItem>,
    /// Wall-clock time, in milliseconds, measured from the start of `run`.
    pub elapsed_ms: u64,
}
94
95pub fn run(args: HybridSearchArgs) -> Result<(), AppError> {
96 let start = std::time::Instant::now();
97 let _ = args.format;
98
99 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
100 let paths = AppPaths::resolve(args.db.as_deref())?;
101 if !paths.db.exists() {
102 return Err(AppError::NotFound(
103 crate::i18n::errors_msg::database_not_found(&paths.db.display().to_string()),
104 ));
105 }
106
107 output::emit_progress_i18n(
108 "Computing query embedding...",
109 "Calculando embedding da consulta...",
110 );
111 let embedding = crate::daemon::embed_query_or_local(&paths.models, &args.query)?;
112
113 let conn = open_ro(&paths.db)?;
114
115 let memory_type_str = args.r#type.map(|t| t.as_str());
116
117 let vec_results = memories::knn_search(
118 &conn,
119 &embedding,
120 &[namespace.clone()],
121 memory_type_str,
122 args.k * 2,
123 )?;
124
125 let vec_rank_map: HashMap<i64, usize> = vec_results
127 .iter()
128 .enumerate()
129 .map(|(pos, (id, _))| (*id, pos + 1))
130 .collect();
131
132 let fts_results =
133 memories::fts_search(&conn, &args.query, &namespace, memory_type_str, args.k * 2)?;
134
135 let fts_rank_map: HashMap<i64, usize> = fts_results
137 .iter()
138 .enumerate()
139 .map(|(pos, row)| (row.id, pos + 1))
140 .collect();
141
142 let rrf_k = args.rrf_k as f64;
143
144 let mut combined_scores: HashMap<i64, f64> = HashMap::new();
146
147 for (rank, (memory_id, _)) in vec_results.iter().enumerate() {
148 let score = args.weight_vec as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
149 *combined_scores.entry(*memory_id).or_insert(0.0) += score;
150 }
151
152 for (rank, row) in fts_results.iter().enumerate() {
153 let score = args.weight_fts as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
154 *combined_scores.entry(row.id).or_insert(0.0) += score;
155 }
156
157 let mut ranked: Vec<(i64, f64)> = combined_scores.into_iter().collect();
159 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
160 ranked.truncate(args.k);
161
162 let top_ids: Vec<i64> = ranked.iter().map(|(id, _)| *id).collect();
164
165 let mut memory_data: HashMap<i64, memories::MemoryRow> = HashMap::new();
167 for id in &top_ids {
168 if let Some(row) = memories::read_full(&conn, *id)? {
169 memory_data.insert(*id, row);
170 }
171 }
172
173 let results: Vec<HybridSearchItem> = ranked
175 .into_iter()
176 .filter_map(|(memory_id, combined_score)| {
177 memory_data.remove(&memory_id).map(|row| HybridSearchItem {
178 memory_id: row.id,
179 name: row.name,
180 namespace: row.namespace,
181 memory_type: row.memory_type,
182 description: row.description,
183 body: row.body,
184 combined_score,
185 score: combined_score,
186 source: "hybrid".to_string(),
187 vec_rank: vec_rank_map.get(&memory_id).copied(),
188 fts_rank: fts_rank_map.get(&memory_id).copied(),
189 rrf_score: Some(combined_score),
190 })
191 })
192 .collect();
193
194 output::emit_json(&HybridSearchResponse {
195 query: args.query,
196 k: args.k,
197 rrf_k: args.rrf_k,
198 weights: Weights {
199 vec: args.weight_vec,
200 fts: args.weight_fts,
201 },
202 results,
203 graph_matches: vec![],
204 elapsed_ms: start.elapsed().as_millis() as u64,
205 })?;
206
207 Ok(())
208}
209
#[cfg(test)]
mod tests {
    //! Serialization tests for the hybrid-search response types.
    //!
    //! Assertions run against the parsed JSON tree rather than raw substrings, so a field name
    //! cannot be confused with a value or with a key at a different nesting level.

    use super::*;
    use serde_json::Value;

    /// Builds a response with no results and the given fusion parameters.
    fn resposta_vazia(
        k: usize,
        rrf_k: u32,
        weight_vec: f32,
        weight_fts: f32,
    ) -> HybridSearchResponse {
        HybridSearchResponse {
            query: "busca teste".to_string(),
            k,
            rrf_k,
            weights: Weights {
                vec: weight_vec,
                fts: weight_fts,
            },
            results: vec![],
            graph_matches: vec![],
            elapsed_ms: 0,
        }
    }

    /// Builds a result item in which only the two optional ranks vary.
    fn item_com_ranks(vec_rank: Option<usize>, fts_rank: Option<usize>) -> HybridSearchItem {
        HybridSearchItem {
            memory_id: 1,
            name: "mem".to_string(),
            namespace: "default".to_string(),
            memory_type: "user".to_string(),
            description: "desc".to_string(),
            body: "conteúdo".to_string(),
            combined_score: 0.0328,
            score: 0.0328,
            source: "hybrid".to_string(),
            vec_rank,
            fts_rank,
            rrf_score: Some(0.0328),
        }
    }

    /// Serializes `value` to text and parses it back, so the tests see the emitted JSON.
    fn json_de<T: serde::Serialize>(value: &T) -> Value {
        let text = serde_json::to_string(value).unwrap();
        serde_json::from_str(&text).unwrap()
    }

    #[test]
    fn hybrid_search_response_vazia_serializa_campos_corretos() {
        let json = json_de(&resposta_vazia(10, 60, 1.0, 1.0));
        assert!(json.get("results").is_some(), "deve conter campo results");
        assert!(json.get("query").is_some(), "deve conter campo query");
        assert!(json.get("k").is_some(), "deve conter campo k");
        assert!(
            json.get("graph_matches").is_some(),
            "deve conter campo graph_matches"
        );
        assert!(
            json.get("combined_rank").is_none(),
            "NÃO deve conter combined_rank"
        );
        assert!(
            json.get("vec_rank_list").is_none(),
            "NÃO deve conter vec_rank_list"
        );
        assert!(
            json.get("fts_rank_list").is_none(),
            "NÃO deve conter fts_rank_list"
        );
    }

    #[test]
    fn hybrid_search_response_serializa_rrf_k_e_weights() {
        let json = json_de(&resposta_vazia(5, 60, 0.7, 0.3));
        assert_eq!(
            json["rrf_k"].as_u64(),
            Some(60),
            "deve conter campo rrf_k"
        );
        assert!(json["weights"].is_object(), "deve conter campo weights");
        // The weights must be nested under `weights`, not merely present somewhere.
        assert!(
            json["weights"].get("vec").is_some(),
            "deve conter campo weights.vec"
        );
        assert!(
            json["weights"].get("fts").is_some(),
            "deve conter campo weights.fts"
        );
    }

    #[test]
    fn hybrid_search_response_serializa_elapsed_ms() {
        let mut resp = resposta_vazia(5, 60, 1.0, 1.0);
        resp.elapsed_ms = 123;
        let json = json_de(&resp);
        assert!(
            json.get("elapsed_ms").is_some(),
            "deve conter campo elapsed_ms"
        );
        // Pin the value to the field itself instead of looking for "123" anywhere.
        assert_eq!(
            json["elapsed_ms"].as_u64(),
            Some(123),
            "deve serializar valor de elapsed_ms"
        );
    }

    #[test]
    fn weights_struct_serializa_corretamente() {
        let json = json_de(&Weights { vec: 0.6, fts: 0.4 });
        assert!(json.get("vec").is_some());
        assert!(json.get("fts").is_some());
    }

    #[test]
    fn hybrid_search_item_omite_fts_rank_quando_none() {
        let json = json_de(&item_com_ranks(Some(1), None));
        assert!(
            json.get("vec_rank").is_some(),
            "deve conter vec_rank quando Some"
        );
        assert!(
            json.get("fts_rank").is_none(),
            "NÃO deve conter fts_rank quando None"
        );
    }

    #[test]
    fn hybrid_search_item_omite_vec_rank_quando_none() {
        let json = json_de(&item_com_ranks(None, Some(2)));
        assert!(
            json.get("vec_rank").is_none(),
            "NÃO deve conter vec_rank quando None"
        );
        assert!(
            json.get("fts_rank").is_some(),
            "deve conter fts_rank quando Some"
        );
    }

    #[test]
    fn hybrid_search_item_serializa_ambos_ranks_quando_some() {
        let json = json_de(&item_com_ranks(Some(3), Some(1)));
        assert!(json.get("vec_rank").is_some(), "deve conter vec_rank");
        assert!(json.get("fts_rank").is_some(), "deve conter fts_rank");
        assert!(
            json.get("type").is_some(),
            "deve serializar type renomeado"
        );
        assert!(
            json.get("memory_type").is_none(),
            "NÃO deve expor memory_type"
        );
    }

    #[test]
    fn hybrid_search_response_serializa_k_corretamente() {
        let json = json_de(&resposta_vazia(5, 60, 1.0, 1.0));
        assert_eq!(json["k"].as_u64(), Some(5), "deve serializar k=5");
    }
}