1use crate::cli::MemoryType;
2use crate::errors::AppError;
3use crate::output::{self, JsonOutputFormat, RecallItem};
4use crate::paths::AppPaths;
5use crate::storage::connection::open_ro;
6use crate::storage::memories;
7
8use std::collections::HashMap;
9
// CLI arguments for the hybrid (vector KNN + full-text) search command.
//
// Plain `//` comments are used on purpose: with `#[derive(clap::Args)]`,
// `///` doc comments are turned into `--help` text, so adding them here would
// change the CLI's user-visible output.
#[derive(clap::Args)]
pub struct HybridSearchArgs {
    // Query text. `run` embeds it for the vector search and passes it as-is to
    // the full-text search.
    pub query: String,
    // Number of fused results to return. `run` asks each underlying search
    // for `k * 2` candidates before fusing.
    #[arg(short = 'k', long, default_value = "10")]
    pub k: usize,
    // Reciprocal Rank Fusion constant. `run` gives an item at 0-based list
    // position `rank` a contribution of `weight / (rrf_k + rank + 1)`.
    #[arg(long, default_value = "60")]
    pub rrf_k: u32,
    // Multiplier applied to the vector-search RRF contributions.
    #[arg(long, default_value = "1.0")]
    pub weight_vec: f32,
    // Multiplier applied to the full-text-search RRF contributions.
    #[arg(long, default_value = "1.0")]
    pub weight_fts: f32,
    // Optional memory-type filter, forwarded (as a string) to both searches.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    // Optional namespace; resolved through `crate::namespace::resolve_namespace`.
    #[arg(long)]
    pub namespace: Option<String>,
    // Graph-expansion options (`with_graph`, `max_hops`, `min_weight`).
    // NOTE(review): `run` in this file never reads these three fields and
    // always emits an empty `graph_matches`; confirm whether graph expansion
    // is meant to be wired in here.
    #[arg(long)]
    pub with_graph: bool,
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    // Requested output format. `run` discards it (`let _ = args.format`) and
    // always calls `output::emit_json`.
    #[arg(long, value_enum, default_value_t = JsonOutputFormat::Json)]
    pub format: JsonOutputFormat,
    // Database path override, also settable via `SQLITE_GRAPHRAG_DB_PATH`;
    // passed to `AppPaths::resolve`.
    #[arg(long, env = "SQLITE_GRAPHRAG_DB_PATH")]
    pub db: Option<String>,
    // Hidden flag that `run` never reads.
    // NOTE(review): presumably kept so older invocations passing `--json`
    // still parse; confirm.
    #[arg(long, hide = true)]
    pub json: bool,
}
39
/// One fused hybrid-search hit, serialized into the `results` array.
///
/// Field order here is the JSON key order of the output.
#[derive(serde::Serialize)]
pub struct HybridSearchItem {
    /// Database id of the memory row.
    pub memory_id: i64,
    /// Memory name.
    pub name: String,
    /// Namespace the memory belongs to.
    pub namespace: String,
    /// Memory type, serialized under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub memory_type: String,
    /// Memory description.
    pub description: String,
    /// Full memory body.
    pub body: String,
    /// Weighted Reciprocal Rank Fusion score summed over both result lists.
    pub combined_score: f64,
    /// Same value as `combined_score` (the search command fills both from one
    /// variable); NOTE(review): presumably a generic alias for consumers that
    /// read `score` across commands — confirm.
    pub score: f64,
    /// Origin tag of the hit; the search command sets it to `"hybrid"`.
    pub source: String,
    /// 1-based position in the vector-search candidate list; omitted from the
    /// JSON when the memory was not returned by that search.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vec_rank: Option<usize>,
    /// 1-based position in the full-text-search candidate list; omitted from
    /// the JSON when the memory was not returned by that search.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fts_rank: Option<usize>,
}
59
/// Per-source RRF weights echoed back in the response's `weights` object.
#[derive(serde::Serialize)]
pub struct Weights {
    /// Multiplier applied to vector-search contributions (`--weight-vec`).
    pub vec: f32,
    /// Multiplier applied to full-text-search contributions (`--weight-fts`).
    pub fts: f32,
}
66
/// Top-level JSON document emitted by the hybrid-search command.
///
/// Field order here is the JSON key order of the output.
#[derive(serde::Serialize)]
pub struct HybridSearchResponse {
    /// The query text as given on the command line.
    pub query: String,
    /// Requested number of results (upper bound on `results.len()`).
    pub k: usize,
    /// RRF constant used when fusing the two ranked lists.
    pub rrf_k: u32,
    /// Weights applied to the vector and full-text contributions.
    pub weights: Weights,
    /// Fused hits; the search command emits them in descending score order.
    pub results: Vec<HybridSearchItem>,
    /// Graph-expansion matches; the search command in this file always emits
    /// an empty list.
    pub graph_matches: Vec<RecallItem>,
    /// Wall-clock milliseconds from the start of the command to serialization.
    pub elapsed_ms: u64,
}
80
81pub fn run(args: HybridSearchArgs) -> Result<(), AppError> {
82 let start = std::time::Instant::now();
83 let _ = args.format;
84
85 let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
86 let paths = AppPaths::resolve(args.db.as_deref())?;
87
88 output::emit_progress_i18n(
89 "Computing query embedding...",
90 "Calculando embedding da consulta...",
91 );
92 let embedder = crate::embedder::get_embedder(&paths.models)?;
93 let embedding = crate::embedder::embed_query(embedder, &args.query)?;
94
95 let conn = open_ro(&paths.db)?;
96
97 let memory_type_str = args.r#type.map(|t| t.as_str());
98
99 let vec_results =
100 memories::knn_search(&conn, &embedding, &namespace, memory_type_str, args.k * 2)?;
101
102 let vec_rank_map: HashMap<i64, usize> = vec_results
104 .iter()
105 .enumerate()
106 .map(|(pos, (id, _))| (*id, pos + 1))
107 .collect();
108
109 let fts_results =
110 memories::fts_search(&conn, &args.query, &namespace, memory_type_str, args.k * 2)?;
111
112 let fts_rank_map: HashMap<i64, usize> = fts_results
114 .iter()
115 .enumerate()
116 .map(|(pos, row)| (row.id, pos + 1))
117 .collect();
118
119 let rrf_k = args.rrf_k as f64;
120
121 let mut combined_scores: HashMap<i64, f64> = HashMap::new();
123
124 for (rank, (memory_id, _)) in vec_results.iter().enumerate() {
125 let score = args.weight_vec as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
126 *combined_scores.entry(*memory_id).or_insert(0.0) += score;
127 }
128
129 for (rank, row) in fts_results.iter().enumerate() {
130 let score = args.weight_fts as f64 * (1.0 / (rrf_k + rank as f64 + 1.0));
131 *combined_scores.entry(row.id).or_insert(0.0) += score;
132 }
133
134 let mut ranked: Vec<(i64, f64)> = combined_scores.into_iter().collect();
136 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
137 ranked.truncate(args.k);
138
139 let top_ids: Vec<i64> = ranked.iter().map(|(id, _)| *id).collect();
141
142 let mut memory_data: HashMap<i64, memories::MemoryRow> = HashMap::new();
144 for id in &top_ids {
145 if let Some(row) = memories::read_full(&conn, *id)? {
146 memory_data.insert(*id, row);
147 }
148 }
149
150 let results: Vec<HybridSearchItem> = ranked
152 .into_iter()
153 .filter_map(|(memory_id, combined_score)| {
154 memory_data.remove(&memory_id).map(|row| HybridSearchItem {
155 memory_id: row.id,
156 name: row.name,
157 namespace: row.namespace,
158 memory_type: row.memory_type,
159 description: row.description,
160 body: row.body,
161 combined_score,
162 score: combined_score,
163 source: "hybrid".to_string(),
164 vec_rank: vec_rank_map.get(&memory_id).copied(),
165 fts_rank: fts_rank_map.get(&memory_id).copied(),
166 })
167 })
168 .collect();
169
170 output::emit_json(&HybridSearchResponse {
171 query: args.query,
172 k: args.k,
173 rrf_k: args.rrf_k,
174 weights: Weights {
175 vec: args.weight_vec,
176 fts: args.weight_fts,
177 },
178 results,
179 graph_matches: vec![],
180 elapsed_ms: start.elapsed().as_millis() as u64,
181 })?;
182
183 Ok(())
184}
185
#[cfg(test)]
mod testes {
    //! Tests pinning the JSON shape of the hybrid-search output types.

    use super::*;

    /// Builds a response with the given parameters, query `"busca teste"`,
    /// no results, no graph matches and `elapsed_ms = 0`.
    fn resposta_vazia(
        k: usize,
        rrf_k: u32,
        weight_vec: f32,
        weight_fts: f32,
    ) -> HybridSearchResponse {
        HybridSearchResponse {
            query: "busca teste".to_string(),
            k,
            rrf_k,
            weights: Weights {
                vec: weight_vec,
                fts: weight_fts,
            },
            results: vec![],
            graph_matches: vec![],
            elapsed_ms: 0,
        }
    }

    /// Serializes `value` into a `serde_json::Value`.
    ///
    /// This lets tests check exact fields and nesting instead of substrings
    /// that could match anywhere in the document.
    fn como_valor<T: serde::Serialize>(value: &T) -> serde_json::Value {
        serde_json::to_value(value).unwrap()
    }

    #[test]
    fn hybrid_search_response_vazia_serializa_campos_corretos() {
        let resp = resposta_vazia(10, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"results\""), "deve conter campo results");
        assert!(json.contains("\"query\""), "deve conter campo query");
        assert!(json.contains("\"k\""), "deve conter campo k");
        assert!(
            json.contains("\"graph_matches\""),
            "deve conter campo graph_matches"
        );
        assert!(
            !json.contains("\"combined_rank\""),
            "NÃO deve conter combined_rank"
        );
        assert!(
            !json.contains("\"vec_rank_list\""),
            "NÃO deve conter vec_rank_list"
        );
        assert!(
            !json.contains("\"fts_rank_list\""),
            "NÃO deve conter fts_rank_list"
        );
    }

    #[test]
    fn hybrid_search_response_serializa_rrf_k_e_weights() {
        let resp = resposta_vazia(5, 60, 0.7, 0.3);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"rrf_k\""), "deve conter campo rrf_k");
        assert!(json.contains("\"weights\""), "deve conter campo weights");

        // Check the exact value and the nesting, not just that the keys appear.
        let v = como_valor(&resp);
        assert_eq!(v["rrf_k"], 60, "deve conter campo rrf_k");
        assert!(
            v["weights"]["vec"].is_number(),
            "deve conter campo weights.vec"
        );
        assert!(
            v["weights"]["fts"].is_number(),
            "deve conter campo weights.fts"
        );
    }

    #[test]
    fn hybrid_search_response_serializa_elapsed_ms() {
        let mut resp = resposta_vazia(5, 60, 1.0, 1.0);
        resp.elapsed_ms = 123;
        let json = serde_json::to_string(&resp).unwrap();
        assert!(
            json.contains("\"elapsed_ms\""),
            "deve conter campo elapsed_ms"
        );
        // Pin the value on the `elapsed_ms` key itself; a bare substring check
        // would also pass if "123" appeared in any other field.
        assert_eq!(
            como_valor(&resp)["elapsed_ms"],
            123,
            "deve serializar valor de elapsed_ms"
        );
    }

    #[test]
    fn weights_struct_serializa_corretamente() {
        let w = Weights { vec: 0.6, fts: 0.4 };
        let json = serde_json::to_string(&w).unwrap();
        assert!(json.contains("\"vec\""));
        assert!(json.contains("\"fts\""));
    }

    #[test]
    fn hybrid_search_item_omite_fts_rank_quando_none() {
        let item = HybridSearchItem {
            memory_id: 1,
            name: "mem".to_string(),
            namespace: "default".to_string(),
            memory_type: "user".to_string(),
            description: "desc".to_string(),
            body: "conteúdo".to_string(),
            combined_score: 0.0328,
            score: 0.0328,
            source: "hybrid".to_string(),
            vec_rank: Some(1),
            fts_rank: None,
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            json.contains("\"vec_rank\""),
            "deve conter vec_rank quando Some"
        );
        assert!(
            !json.contains("\"fts_rank\""),
            "NÃO deve conter fts_rank quando None"
        );
    }

    #[test]
    fn hybrid_search_item_omite_vec_rank_quando_none() {
        let item = HybridSearchItem {
            memory_id: 2,
            name: "mem2".to_string(),
            namespace: "default".to_string(),
            memory_type: "fact".to_string(),
            description: "desc2".to_string(),
            body: "corpo2".to_string(),
            combined_score: 0.016,
            score: 0.016,
            source: "hybrid".to_string(),
            vec_rank: None,
            fts_rank: Some(2),
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(
            !json.contains("\"vec_rank\""),
            "NÃO deve conter vec_rank quando None"
        );
        assert!(
            json.contains("\"fts_rank\""),
            "deve conter fts_rank quando Some"
        );
    }

    #[test]
    fn hybrid_search_item_serializa_ambos_ranks_quando_some() {
        let item = HybridSearchItem {
            memory_id: 3,
            name: "mem3".to_string(),
            namespace: "ns".to_string(),
            memory_type: "entity".to_string(),
            description: "desc3".to_string(),
            body: "corpo3".to_string(),
            combined_score: 0.05,
            score: 0.05,
            source: "hybrid".to_string(),
            vec_rank: Some(3),
            fts_rank: Some(1),
        };
        let json = serde_json::to_string(&item).unwrap();
        assert!(json.contains("\"vec_rank\""), "deve conter vec_rank");
        assert!(json.contains("\"fts_rank\""), "deve conter fts_rank");
        assert!(json.contains("\"type\""), "deve serializar type renomeado");
        assert!(!json.contains("memory_type"), "NÃO deve expor memory_type");
    }

    #[test]
    fn hybrid_search_response_serializa_k_corretamente() {
        let resp = resposta_vazia(5, 60, 1.0, 1.0);
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"k\":5"), "deve serializar k=5");
    }
}