1use crate::cli::MemoryType;
2use crate::errors::AppError;
3use crate::output::{self, JsonOutputFormat, RecallItem};
4use crate::paths::AppPaths;
5use crate::storage::connection::open_ro;
6use crate::storage::memories;
7
8use std::collections::HashMap;
9
// Command-line arguments for the hybrid (vector + full-text) search command.
//
// NOTE: plain `//` comments are used here on purpose. clap's derive macros turn
// `///` doc comments into help text, which would change the CLI output.
#[derive(clap::Args)]
pub struct HybridSearchArgs {
    // Search text (positional). `run` embeds it for the vector search and also
    // passes it unchanged to the full-text search.
    pub query: String,
    // Number of results kept after fusion. Each search is asked for `k * 2`
    // candidates before the fused list is truncated to `k`.
    #[arg(short = 'k', long, default_value = "10")]
    pub k: usize,
    // Constant added to the 1-based rank in the reciprocal-rank denominator:
    // each hit contributes `weight / (rrf_k + rank)`.
    #[arg(long, default_value = "60")]
    pub rrf_k: u32,
    // Multiplier applied to contributions coming from the vector search ranking.
    #[arg(long, default_value = "1.0")]
    pub weight_vec: f32,
    // Multiplier applied to contributions coming from the full-text search ranking.
    #[arg(long, default_value = "1.0")]
    pub weight_fts: f32,
    // Optional memory-type filter, forwarded to both searches as a string.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // Optional namespace; handed to `crate::namespace::resolve_namespace`, which
    // decides what to use when it is absent.
    #[arg(long)]
    pub namespace: Option<String>,
    // NOTE(review): `with_graph`, `max_hops` and `min_weight` are parsed but never
    // read by `run` in this file; `graph_matches` is always emitted empty.
    #[arg(long)]
    pub with_graph: bool,
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    // Accepted but discarded by `run` (`let _ = args.format`).
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // Explicit database path; may also come from `SQLITE_GRAPHRAG_DB_PATH`.
    // Passed to `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // Not read by `run`; kept as a no-op flag, as its help text states.
    #[arg(long, help = "No-op; JSON is always emitted on stdout")]
    pub json: bool,
}
39
/// One fused search hit as serialized in the `results` array of
/// [`HybridSearchResponse`].
#[derive(serde::Serialize)]
pub struct HybridSearchItem {
    /// Identifier of the memory row.
    pub memory_id: i64,
    pub name: String,
    pub namespace: String,
    /// Memory type; serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub memory_type: String,
    pub description: String,
    pub body: String,
    /// Fused score: the sum of the weighted reciprocal-rank contributions.
    pub combined_score: f64,
    /// `run` fills this with the same value as `combined_score`.
    pub score: f64,
    /// Origin label of the hit; `run` always sets it to `"hybrid"`.
    pub source: String,
    /// 1-based position in the vector search results; omitted from the JSON when
    /// the memory was not returned by that search.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vec_rank: Option<usize>,
    /// 1-based position in the full-text search results; omitted from the JSON
    /// when the memory was not returned by that search.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fts_rank: Option<usize>,
    /// `run` fills this with `Some(combined_score)`; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rrf_score: Option<f64>,
}
62
/// The per-source multipliers used for the fusion, echoed back in the response.
#[derive(serde::Serialize)]
pub struct Weights {
    /// Multiplier for the vector search ranking (`--weight-vec`).
    pub vec: f32,
    /// Multiplier for the full-text search ranking (`--weight-fts`).
    pub fts: f32,
}
69
/// Top-level JSON document emitted by [`run`].
#[derive(serde::Serialize)]
pub struct HybridSearchResponse {
    /// The query text exactly as received.
    pub query: String,
    /// Requested maximum number of results.
    pub k: usize,
    /// Reciprocal-rank constant that was used.
    pub rrf_k: u32,
    /// Weights that were applied to each ranking.
    pub weights: Weights,
    /// Fused hits, best first, at most `k` entries.
    pub results: Vec<HybridSearchItem>,
    /// NOTE(review): `run` always emits this as an empty list.
    pub graph_matches: Vec<RecallItem>,
    /// Wall-clock time spent inside `run`, in milliseconds.
    pub elapsed_ms: u64,
}
83
84pub fn run(args: HybridSearchArgs) -> Result<(), AppError> {
85 let start = std::time::Instant::now();
86 let _ = args.format;
87
88 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
89 let paths = AppPaths::resolve(args.db.as_deref())?;
90 if !paths.db.exists() {
91 return Err(AppError::NotFound(
92 crate::i18n::erros::banco_nao_encontrado(&paths.db.display().to_string()),
93 ));
94 }
95
96 output::emit_progress_i18n(
97 "Computing query embedding...",
98 "Calculando embedding da consulta...",
99 );
100 let embedding = crate::daemon::embed_query_or_local(&paths.models, &args.query)?;
101
102 let conn = open_ro(&paths.db)?;
103
104 let memory_type_str = args.r#type.map(|t| t.as_str());
105
106 let vec_results =
107 memories::knn_search(&conn, &embedding, &namespace, memory_type_str, args.k * 2)?;
108
109 let vec_rank_map: HashMap<i64, usize> = vec_results
111 .iter()
112 .enumerate()
113 .map(|(pos, (id, _))| (*id, pos + 1))
114 .collect();
115
116 let fts_results =
117 memories::fts_search(&conn, &args.query, &namespace, memory_type_str, args.k * 2)?;
118
119 let fts_rank_map: HashMap<i64, usize> = fts_results
121 .iter()
122 .enumerate()
123 .map(|(pos, row)| (row.id, pos + 1))
124 .collect();
125
126 let rrf_k = args.rrf_k as f64;
127
128 let mut combined_scores: HashMap<i64, f64> = HashMap::new();
130
131 for (rank, (memory_id, _)) in vec_results.iter().enumerate() {
132 let score = args.weight_vec as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
133 *combined_scores.entry(*memory_id).or_insert(0.0) += score;
134 }
135
136 for (rank, row) in fts_results.iter().enumerate() {
137 let score = args.weight_fts as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
138 *combined_scores.entry(row.id).or_insert(0.0) += score;
139 }
140
141 let mut ranked: Vec<(i64, f64)> = combined_scores.into_iter().collect();
143 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
144 ranked.truncate(args.k);
145
146 let top_ids: Vec<i64> = ranked.iter().map(|(id, _)| *id).collect();
148
149 let mut memory_data: HashMap<i64, memories::MemoryRow> = HashMap::new();
151 for id in &top_ids {
152 if let Some(row) = memories::read_full(&conn, *id)? {
153 memory_data.insert(*id, row);
154 }
155 }
156
157 let results: Vec<HybridSearchItem> = ranked
159 .into_iter()
160 .filter_map(|(memory_id, combined_score)| {
161 memory_data.remove(&memory_id).map(|row| HybridSearchItem {
162 memory_id: row.id,
163 name: row.name,
164 namespace: row.namespace,
165 memory_type: row.memory_type,
166 description: row.description,
167 body: row.body,
168 combined_score,
169 score: combined_score,
170 source: "hybrid".to_string(),
171 vec_rank: vec_rank_map.get(&memory_id).copied(),
172 fts_rank: fts_rank_map.get(&memory_id).copied(),
173 rrf_score: Some(combined_score),
174 })
175 })
176 .collect();
177
178 output::emit_json(&HybridSearchResponse {
179 query: args.query,
180 k: args.k,
181 rrf_k: args.rrf_k,
182 weights: Weights {
183 vec: args.weight_vec,
184 fts: args.weight_fts,
185 },
186 results,
187 graph_matches: vec![],
188 elapsed_ms: start.elapsed().as_millis() as u64,
189 })?;
190
191 Ok(())
192}
193
#[cfg(test)]
mod testes {
    use super::*;

    /// Asserts that every `(needle, message)` pair is present in `json`, in order.
    fn exige_presentes(json: &str, esperados: &[(&str, &str)]) {
        for &(trecho, mensagem) in esperados {
            assert!(json.contains(trecho), "{}", mensagem);
        }
    }

    /// Asserts that none of the `(needle, message)` pairs is present in `json`, in order.
    fn exige_ausentes(json: &str, proibidos: &[(&str, &str)]) {
        for &(trecho, mensagem) in proibidos {
            assert!(!json.contains(trecho), "{}", mensagem);
        }
    }

    /// Builds a response with no results and a zero elapsed time.
    fn resposta_vazia(
        k: usize,
        rrf_k: u32,
        weight_vec: f32,
        weight_fts: f32,
    ) -> HybridSearchResponse {
        let weights = Weights {
            vec: weight_vec,
            fts: weight_fts,
        };
        HybridSearchResponse {
            query: String::from("busca teste"),
            k,
            rrf_k,
            weights,
            results: Vec::new(),
            graph_matches: Vec::new(),
            elapsed_ms: 0,
        }
    }

    // The empty response exposes the expected top-level keys and no legacy ones.
    #[test]
    fn hybrid_search_response_vazia_serializa_campos_corretos() {
        let json = serde_json::to_string(&resposta_vazia(10, 60, 1.0, 1.0)).unwrap();
        exige_presentes(
            &json,
            &[
                ("\"results\"", "deve conter campo results"),
                ("\"query\"", "deve conter campo query"),
                ("\"k\"", "deve conter campo k"),
                ("\"graph_matches\"", "deve conter campo graph_matches"),
            ],
        );
        exige_ausentes(
            &json,
            &[
                ("\"combined_rank\"", "NÃO deve conter combined_rank"),
                ("\"vec_rank_list\"", "NÃO deve conter vec_rank_list"),
                ("\"fts_rank_list\"", "NÃO deve conter fts_rank_list"),
            ],
        );
    }

    // `rrf_k` and the nested weights object are part of the payload.
    #[test]
    fn hybrid_search_response_serializa_rrf_k_e_weights() {
        let json = serde_json::to_string(&resposta_vazia(5, 60, 0.7, 0.3)).unwrap();
        exige_presentes(
            &json,
            &[
                ("\"rrf_k\"", "deve conter campo rrf_k"),
                ("\"weights\"", "deve conter campo weights"),
                ("\"vec\"", "deve conter campo weights.vec"),
                ("\"fts\"", "deve conter campo weights.fts"),
            ],
        );
    }

    // `elapsed_ms` is serialized together with its value.
    #[test]
    fn hybrid_search_response_serializa_elapsed_ms() {
        let resp = HybridSearchResponse {
            elapsed_ms: 123,
            ..resposta_vazia(5, 60, 1.0, 1.0)
        };
        let json = serde_json::to_string(&resp).unwrap();
        exige_presentes(
            &json,
            &[
                ("\"elapsed_ms\"", "deve conter campo elapsed_ms"),
                ("123", "deve serializar valor de elapsed_ms"),
            ],
        );
    }

    // `Weights` on its own serializes both keys.
    #[test]
    fn weights_struct_serializa_corretamente() {
        let json = serde_json::to_string(&Weights { vec: 0.6, fts: 0.4 }).unwrap();
        assert!(json.contains("\"vec\""));
        assert!(json.contains("\"fts\""));
    }

    // A hit found only by the vector search omits `fts_rank`.
    #[test]
    fn hybrid_search_item_omite_fts_rank_quando_none() {
        let pontuacao = 0.0328;
        let item = HybridSearchItem {
            memory_id: 1,
            name: String::from("mem"),
            namespace: String::from("default"),
            memory_type: String::from("user"),
            description: String::from("desc"),
            body: String::from("conteúdo"),
            combined_score: pontuacao,
            score: pontuacao,
            source: String::from("hybrid"),
            vec_rank: Some(1),
            fts_rank: None,
            rrf_score: Some(pontuacao),
        };
        let json = serde_json::to_string(&item).unwrap();
        exige_presentes(&json, &[("\"vec_rank\"", "deve conter vec_rank quando Some")]);
        exige_ausentes(
            &json,
            &[("\"fts_rank\"", "NÃO deve conter fts_rank quando None")],
        );
    }

    // A hit found only by the full-text search omits `vec_rank`.
    #[test]
    fn hybrid_search_item_omite_vec_rank_quando_none() {
        let pontuacao = 0.016;
        let item = HybridSearchItem {
            memory_id: 2,
            name: String::from("mem2"),
            namespace: String::from("default"),
            memory_type: String::from("fact"),
            description: String::from("desc2"),
            body: String::from("corpo2"),
            combined_score: pontuacao,
            score: pontuacao,
            source: String::from("hybrid"),
            vec_rank: None,
            fts_rank: Some(2),
            rrf_score: Some(pontuacao),
        };
        let json = serde_json::to_string(&item).unwrap();
        exige_ausentes(
            &json,
            &[("\"vec_rank\"", "NÃO deve conter vec_rank quando None")],
        );
        exige_presentes(&json, &[("\"fts_rank\"", "deve conter fts_rank quando Some")]);
    }

    // A hit found by both searches carries both ranks and the renamed `type` key.
    #[test]
    fn hybrid_search_item_serializa_ambos_ranks_quando_some() {
        let pontuacao = 0.05;
        let item = HybridSearchItem {
            memory_id: 3,
            name: String::from("mem3"),
            namespace: String::from("ns"),
            memory_type: String::from("entity"),
            description: String::from("desc3"),
            body: String::from("corpo3"),
            combined_score: pontuacao,
            score: pontuacao,
            source: String::from("hybrid"),
            vec_rank: Some(3),
            fts_rank: Some(1),
            rrf_score: Some(pontuacao),
        };
        let json = serde_json::to_string(&item).unwrap();
        exige_presentes(
            &json,
            &[
                ("\"vec_rank\"", "deve conter vec_rank"),
                ("\"fts_rank\"", "deve conter fts_rank"),
                ("\"type\"", "deve serializar type renomeado"),
            ],
        );
        exige_ausentes(&json, &[("memory_type", "NÃO deve expor memory_type")]);
    }

    // `k` is serialized as a bare integer.
    #[test]
    fn hybrid_search_response_serializa_k_corretamente() {
        let json = serde_json::to_string(&resposta_vazia(5, 60, 1.0, 1.0)).unwrap();
        assert!(json.contains("\"k\":5"), "deve serializar k=5");
    }
}