1use crate::cli::MemoryType;
2use crate::errors::AppError;
3use crate::output::{self, JsonOutputFormat, RecallItem};
4use crate::paths::AppPaths;
5use crate::storage::connection::open_ro;
6use crate::storage::memories;
7
8use std::collections::HashMap;
9
// Command-line arguments consumed by `run` in this file.
//
// NOTE: plain `//` comments are used on purpose. clap's derive macro turns
// `///` doc comments into help text, so rustdoc comments here would change
// what `--help` prints.
#[derive(clap::Args)]
pub struct HybridSearchArgs {
    // Query text. `run` embeds it for the vector search and also passes it
    // verbatim to the full-text search.
    pub query: String,
    // Maximum number of fused results kept by `run`; each underlying search
    // is asked for twice this many candidates.
    #[arg(short = 'k', long, default_value = "10")]
    pub k: usize,
    // Smoothing constant of the rank fusion: every hit contributes
    // `weight / (rrf_k + rank)` with a 1-based `rank`.
    #[arg(long, default_value = "60")]
    pub rrf_k: u32,
    // Multiplier applied to contributions from the vector search.
    #[arg(long, default_value = "1.0")]
    pub weight_vec: f32,
    // Multiplier applied to contributions from the full-text search.
    #[arg(long, default_value = "1.0")]
    pub weight_fts: f32,
    // Optional memory-type filter, forwarded to both searches.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // Optional namespace; resolved through `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // Graph-related options. They are not read by `run` in this file, which
    // always emits an empty `graph_matches` list.
    #[arg(long)]
    pub with_graph: bool,
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    // Accepted but discarded by `run`; output always goes through
    // `output::emit_json`.
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // Database path override, also settable via `SQLITE_GRAPHRAG_DB_PATH`;
    // handed to `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // Compatibility flag with no effect (see its help text).
    #[arg(long, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
}
48
/// One fused search hit as serialized in the JSON response.
///
/// Field order is the order of the keys in the emitted JSON object.
#[derive(serde::Serialize)]
pub struct HybridSearchItem {
    /// Identifier of the memory row this hit refers to.
    pub memory_id: i64,
    /// Name of the memory, copied from the stored row.
    pub name: String,
    /// Namespace of the memory, copied from the stored row.
    pub namespace: String,
    /// Memory type of the stored row; serialized under the key `type`.
    #[serde(rename = "type")]
    pub memory_type: String,
    /// Description of the memory, copied from the stored row.
    pub description: String,
    /// Body of the memory, copied from the stored row.
    pub body: String,
    /// Fused score of the hit. `run` writes the same value into `score` and
    /// `rrf_score`.
    pub combined_score: f64,
    /// Same value as `combined_score` when produced by `run`.
    pub score: f64,
    /// Origin label of the hit; `run` always sets it to `"hybrid"`.
    pub source: String,
    /// 1-based position of the memory in the vector-search results, as filled
    /// in by `run`. Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vec_rank: Option<usize>,
    /// 1-based position of the memory in the full-text-search results, as
    /// filled in by `run`. Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fts_rank: Option<usize>,
    /// Fused score again; `run` always sets it to `Some(combined_score)`.
    /// Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rrf_score: Option<f64>,
}
71
/// Per-source weights echoed back in the response.
#[derive(serde::Serialize)]
pub struct Weights {
    /// Weight applied to the vector-search contributions.
    pub vec: f32,
    /// Weight applied to the full-text-search contributions.
    pub fts: f32,
}
78
/// Top-level JSON document emitted by `run`.
#[derive(serde::Serialize)]
pub struct HybridSearchResponse {
    /// The query text the search was run with.
    pub query: String,
    /// The requested maximum number of results.
    pub k: usize,
    /// The rank-fusion constant that was used.
    pub rrf_k: u32,
    /// The per-source weights that were used.
    pub weights: Weights,
    /// Fused hits, in the order `run` ranked them.
    pub results: Vec<HybridSearchItem>,
    /// Graph-derived matches; `run` in this file always emits an empty list.
    pub graph_matches: Vec<RecallItem>,
    /// Milliseconds elapsed between the start of `run` and the moment the
    /// response is built.
    pub elapsed_ms: u64,
}
92
93pub fn run(args: HybridSearchArgs) -> Result<(), AppError> {
94 let start = std::time::Instant::now();
95 let _ = args.format;
96
97 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
98 let paths = AppPaths::resolve(args.db.as_deref())?;
99 if !paths.db.exists() {
100 return Err(AppError::NotFound(
101 crate::i18n::erros::banco_nao_encontrado(&paths.db.display().to_string()),
102 ));
103 }
104
105 output::emit_progress_i18n(
106 "Computing query embedding...",
107 "Calculando embedding da consulta...",
108 );
109 let embedding = crate::daemon::embed_query_or_local(&paths.models, &args.query)?;
110
111 let conn = open_ro(&paths.db)?;
112
113 let memory_type_str = args.r#type.map(|t| t.as_str());
114
115 let vec_results = memories::knn_search(
116 &conn,
117 &embedding,
118 &[namespace.clone()],
119 memory_type_str,
120 args.k * 2,
121 )?;
122
123 let vec_rank_map: HashMap<i64, usize> = vec_results
125 .iter()
126 .enumerate()
127 .map(|(pos, (id, _))| (*id, pos + 1))
128 .collect();
129
130 let fts_results =
131 memories::fts_search(&conn, &args.query, &namespace, memory_type_str, args.k * 2)?;
132
133 let fts_rank_map: HashMap<i64, usize> = fts_results
135 .iter()
136 .enumerate()
137 .map(|(pos, row)| (row.id, pos + 1))
138 .collect();
139
140 let rrf_k = args.rrf_k as f64;
141
142 let mut combined_scores: HashMap<i64, f64> = HashMap::new();
144
145 for (rank, (memory_id, _)) in vec_results.iter().enumerate() {
146 let score = args.weight_vec as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
147 *combined_scores.entry(*memory_id).or_insert(0.0) += score;
148 }
149
150 for (rank, row) in fts_results.iter().enumerate() {
151 let score = args.weight_fts as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
152 *combined_scores.entry(row.id).or_insert(0.0) += score;
153 }
154
155 let mut ranked: Vec<(i64, f64)> = combined_scores.into_iter().collect();
157 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
158 ranked.truncate(args.k);
159
160 let top_ids: Vec<i64> = ranked.iter().map(|(id, _)| *id).collect();
162
163 let mut memory_data: HashMap<i64, memories::MemoryRow> = HashMap::new();
165 for id in &top_ids {
166 if let Some(row) = memories::read_full(&conn, *id)? {
167 memory_data.insert(*id, row);
168 }
169 }
170
171 let results: Vec<HybridSearchItem> = ranked
173 .into_iter()
174 .filter_map(|(memory_id, combined_score)| {
175 memory_data.remove(&memory_id).map(|row| HybridSearchItem {
176 memory_id: row.id,
177 name: row.name,
178 namespace: row.namespace,
179 memory_type: row.memory_type,
180 description: row.description,
181 body: row.body,
182 combined_score,
183 score: combined_score,
184 source: "hybrid".to_string(),
185 vec_rank: vec_rank_map.get(&memory_id).copied(),
186 fts_rank: fts_rank_map.get(&memory_id).copied(),
187 rrf_score: Some(combined_score),
188 })
189 })
190 .collect();
191
192 output::emit_json(&HybridSearchResponse {
193 query: args.query,
194 k: args.k,
195 rrf_k: args.rrf_k,
196 weights: Weights {
197 vec: args.weight_vec,
198 fts: args.weight_fts,
199 },
200 results,
201 graph_matches: vec![],
202 elapsed_ms: start.elapsed().as_millis() as u64,
203 })?;
204
205 Ok(())
206}
207
#[cfg(test)]
mod testes {
    use super::*;

    /// Builds a response without any hits, using the given fusion parameters.
    fn resposta_vazia(
        k: usize,
        rrf_k: u32,
        weight_vec: f32,
        weight_fts: f32,
    ) -> HybridSearchResponse {
        let weights = Weights {
            vec: weight_vec,
            fts: weight_fts,
        };
        HybridSearchResponse {
            query: String::from("busca teste"),
            k,
            rrf_k,
            weights,
            results: Vec::new(),
            graph_matches: Vec::new(),
            elapsed_ms: 0,
        }
    }

    /// Serializes any value to a compact JSON string, panicking on failure.
    fn to_json<T: serde::Serialize>(value: &T) -> String {
        serde_json::to_string(value).unwrap()
    }

    /// Returns the quoted form of a JSON object key, e.g. `"k"`.
    fn quoted(key: &str) -> String {
        format!("\"{key}\"")
    }

    #[test]
    fn hybrid_search_response_vazia_serializa_campos_corretos() {
        let json = to_json(&resposta_vazia(10, 60, 1.0, 1.0));

        // Keys that must be present in the serialized response.
        for campo in ["results", "query", "k", "graph_matches"] {
            assert!(json.contains(&quoted(campo)), "deve conter campo {campo}");
        }

        // Keys that must never be present.
        for campo in ["combined_rank", "vec_rank_list", "fts_rank_list"] {
            assert!(!json.contains(&quoted(campo)), "NÃO deve conter {campo}");
        }
    }

    #[test]
    fn hybrid_search_response_serializa_rrf_k_e_weights() {
        let json = to_json(&resposta_vazia(5, 60, 0.7, 0.3));

        // (JSON key, label used in the failure message)
        let esperados = [
            ("rrf_k", "rrf_k"),
            ("weights", "weights"),
            ("vec", "weights.vec"),
            ("fts", "weights.fts"),
        ];
        for (chave, rotulo) in esperados {
            assert!(json.contains(&quoted(chave)), "deve conter campo {rotulo}");
        }
    }

    #[test]
    fn hybrid_search_response_serializa_elapsed_ms() {
        let resp = HybridSearchResponse {
            elapsed_ms: 123,
            ..resposta_vazia(5, 60, 1.0, 1.0)
        };
        let json = to_json(&resp);
        assert!(
            json.contains(&quoted("elapsed_ms")),
            "deve conter campo elapsed_ms"
        );
        assert!(json.contains("123"), "deve serializar valor de elapsed_ms");
    }

    #[test]
    fn weights_struct_serializa_corretamente() {
        let json = to_json(&Weights { vec: 0.6, fts: 0.4 });
        assert!(json.contains(&quoted("vec")));
        assert!(json.contains(&quoted("fts")));
    }

    #[test]
    fn hybrid_search_item_omite_fts_rank_quando_none() {
        let json = to_json(&HybridSearchItem {
            memory_id: 1,
            name: String::from("mem"),
            namespace: String::from("default"),
            memory_type: String::from("user"),
            description: String::from("desc"),
            body: String::from("conteúdo"),
            combined_score: 0.0328,
            score: 0.0328,
            source: String::from("hybrid"),
            vec_rank: Some(1),
            fts_rank: None,
            rrf_score: Some(0.0328),
        });
        assert!(
            json.contains(&quoted("vec_rank")),
            "deve conter vec_rank quando Some"
        );
        assert!(
            !json.contains(&quoted("fts_rank")),
            "NÃO deve conter fts_rank quando None"
        );
    }

    #[test]
    fn hybrid_search_item_omite_vec_rank_quando_none() {
        let json = to_json(&HybridSearchItem {
            memory_id: 2,
            name: String::from("mem2"),
            namespace: String::from("default"),
            memory_type: String::from("fact"),
            description: String::from("desc2"),
            body: String::from("corpo2"),
            combined_score: 0.016,
            score: 0.016,
            source: String::from("hybrid"),
            vec_rank: None,
            fts_rank: Some(2),
            rrf_score: Some(0.016),
        });
        assert!(
            !json.contains(&quoted("vec_rank")),
            "NÃO deve conter vec_rank quando None"
        );
        assert!(
            json.contains(&quoted("fts_rank")),
            "deve conter fts_rank quando Some"
        );
    }

    #[test]
    fn hybrid_search_item_serializa_ambos_ranks_quando_some() {
        let json = to_json(&HybridSearchItem {
            memory_id: 3,
            name: String::from("mem3"),
            namespace: String::from("ns"),
            memory_type: String::from("entity"),
            description: String::from("desc3"),
            body: String::from("corpo3"),
            combined_score: 0.05,
            score: 0.05,
            source: String::from("hybrid"),
            vec_rank: Some(3),
            fts_rank: Some(1),
            rrf_score: Some(0.05),
        });
        assert!(json.contains(&quoted("vec_rank")), "deve conter vec_rank");
        assert!(json.contains(&quoted("fts_rank")), "deve conter fts_rank");
        assert!(
            json.contains(&quoted("type")),
            "deve serializar type renomeado"
        );
        // The Rust field name must not leak into the JSON output.
        assert!(!json.contains("memory_type"), "NÃO deve expor memory_type");
    }

    #[test]
    fn hybrid_search_response_serializa_k_corretamente() {
        let json = to_json(&resposta_vazia(5, 60, 1.0, 1.0));
        assert!(json.contains("\"k\":5"), "deve serializar k=5");
    }
}