1use std::collections::HashMap;
70
/// Identifier of a document as it appears in ranked result lists and in
/// [`SearchResult`].
pub type DocId = u64;
77
/// One ranked hit produced by the fusion and search routines in this file.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Identifier of the matched document.
    pub doc_id: DocId,

    /// Ranking score of the hit. The fusion routines store the summed
    /// reciprocal-rank score here; `HybridSearchEngine::search` forwards a
    /// backend's own score unchanged when only one backend produced hits.
    pub score: f32,

    /// Per-backend scores and ranks. `RRFFusion::fuse` fills this only when it
    /// is called with `keep_details = true`; the other producers in this file
    /// always leave it as `None`.
    pub component_scores: Option<ComponentScores>,
}
90
/// Breakdown of where a fused hit came from.
///
/// Each field is `None` when the document did not appear in the corresponding
/// input list.
#[derive(Debug, Clone)]
pub struct ComponentScores {
    /// Score the document carried in the vector result list.
    pub vector_score: Option<f32>,

    /// 1-based position of the document in the vector result list.
    pub vector_rank: Option<usize>,

    /// Score the document carried in the lexical result list.
    pub lexical_score: Option<f32>,

    /// 1-based position of the document in the lexical result list.
    pub lexical_rank: Option<usize>,
}
106
/// Tunable parameters for weighted Reciprocal Rank Fusion.
///
/// A document at 1-based position `rank` in a result list contributes
/// `weight / (k + rank)` to its fused score.
#[derive(Debug, Clone, Copy)]
pub struct RRFConfig {
    /// Constant added to every rank in the denominator.
    pub k: f32,

    /// Weight of contributions coming from the vector result list.
    pub vector_weight: f32,

    /// Weight of contributions coming from the lexical result list.
    pub lexical_weight: f32,
}

impl Default for RRFConfig {
    /// Equal weights of `1.0` for both sources and `k = 60`.
    fn default() -> Self {
        Self::with_weights(1.0, 1.0)
    }
}

impl RRFConfig {
    /// Value of `k` used by every constructor and preset below.
    const PRESET_K: f32 = 60.0;

    /// Builds a configuration with the given weights and `k = 60`.
    pub fn with_weights(vector_weight: f32, lexical_weight: f32) -> Self {
        Self {
            k: Self::PRESET_K,
            vector_weight,
            lexical_weight,
        }
    }

    /// Preset that favours the vector list: weights `0.7` / `0.3`.
    pub fn semantic_focused() -> Self {
        Self::with_weights(0.7, 0.3)
    }

    /// Preset that favours the lexical list: weights `0.3` / `0.7`.
    pub fn keyword_focused() -> Self {
        Self::with_weights(0.3, 0.7)
    }

    /// Preset with equal weights; identical to [`Default::default`].
    pub fn balanced() -> Self {
        Self::default()
    }
}
168
169pub struct RRFFusion {
175 config: RRFConfig,
176}
177
178impl RRFFusion {
179 pub fn new(config: RRFConfig) -> Self {
181 Self { config }
182 }
183
184 pub fn fuse(
195 &self,
196 vector_results: &[(DocId, f32)],
197 lexical_results: &[(DocId, f32)],
198 limit: usize,
199 keep_details: bool,
200 ) -> Vec<SearchResult> {
201 let k = self.config.k;
202
203 let mut doc_scores: HashMap<DocId, FusionState> = HashMap::new();
205
206 for (rank, &(doc_id, score)) in vector_results.iter().enumerate() {
208 let rrf_score = self.config.vector_weight / (k + (rank + 1) as f32);
209
210 let state = doc_scores.entry(doc_id).or_default();
211 state.rrf_score += rrf_score;
212 state.vector_score = Some(score);
213 state.vector_rank = Some(rank + 1);
214 }
215
216 for (rank, &(doc_id, score)) in lexical_results.iter().enumerate() {
218 let rrf_score = self.config.lexical_weight / (k + (rank + 1) as f32);
219
220 let state = doc_scores.entry(doc_id).or_default();
221 state.rrf_score += rrf_score;
222 state.lexical_score = Some(score);
223 state.lexical_rank = Some(rank + 1);
224 }
225
226 let mut results: Vec<SearchResult> = doc_scores
228 .into_iter()
229 .map(|(doc_id, state)| SearchResult {
230 doc_id,
231 score: state.rrf_score,
232 component_scores: if keep_details {
233 Some(ComponentScores {
234 vector_score: state.vector_score,
235 vector_rank: state.vector_rank,
236 lexical_score: state.lexical_score,
237 lexical_rank: state.lexical_rank,
238 })
239 } else {
240 None
241 },
242 })
243 .collect();
244
245 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
247
248 results.truncate(limit);
250
251 results
252 }
253
254 pub fn fuse_multi(
264 &self,
265 result_lists: &[(&[(DocId, f32)], f32)], limit: usize,
267 ) -> Vec<SearchResult> {
268 let k = self.config.k;
269 let mut doc_scores: HashMap<DocId, f32> = HashMap::new();
270
271 for (results, weight) in result_lists {
272 for (rank, &(doc_id, _score)) in results.iter().enumerate() {
273 let rrf_score = *weight / (k + (rank + 1) as f32);
274 *doc_scores.entry(doc_id).or_default() += rrf_score;
275 }
276 }
277
278 let mut results: Vec<SearchResult> = doc_scores
279 .into_iter()
280 .map(|(doc_id, score)| SearchResult {
281 doc_id,
282 score,
283 component_scores: None,
284 })
285 .collect();
286
287 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
288 results.truncate(limit);
289
290 results
291 }
292}
293
/// Per-document accumulator used by `RRFFusion::fuse` while it walks the two
/// input lists. The derived `Default` starts with a zero score and no
/// per-list information.
#[derive(Default)]
struct FusionState {
    /// Running sum of the weighted reciprocal-rank contributions.
    rrf_score: f32,
    /// Score the document carried in the vector list, if it appeared there.
    vector_score: Option<f32>,
    /// 1-based position in the vector list, if it appeared there.
    vector_rank: Option<usize>,
    /// Score the document carried in the lexical list, if it appeared there.
    lexical_score: Option<f32>,
    /// 1-based position in the lexical list, if it appeared there.
    lexical_rank: Option<usize>,
}
303
304impl Default for RRFFusion {
305 fn default() -> Self {
306 Self::new(RRFConfig::default())
307 }
308}
309
/// Search front-end that pairs a vector backend `V` with a lexical backend `L`
/// and merges what they return.
pub struct HybridSearchEngine<V, L> {
    /// Backend that receives the `&[f32]` vector query.
    vector_search: V,

    /// Backend that receives the `&str` text query.
    lexical_search: L,

    /// Parameters handed to `RRFFusion` when results are fused.
    fusion_config: RRFConfig,

    /// Multiplier applied to the caller's `limit` to decide how many candidates
    /// are requested from each backend.
    overfetch_factor: f32,
}
328
/// Backend able to answer a vector (`&[f32]`) query.
pub trait VectorSearchBackend {
    /// Runs `query` and returns `(document id, score)` pairs; `k` is the number
    /// of candidates the caller asks for.
    ///
    /// NOTE(review): the callers in this file treat the order of the returned
    /// vector as the ranking (first entry = best) — confirm that
    /// implementations return results best-first.
    fn search(&self, query: &[f32], k: usize) -> Vec<(DocId, f32)>;
}
334
/// Backend able to answer a text (`&str`) query.
pub trait LexicalSearchBackend {
    /// Runs `query` and returns `(document id, score)` pairs; `k` is the number
    /// of candidates the caller asks for.
    ///
    /// NOTE(review): the callers in this file treat the order of the returned
    /// vector as the ranking (first entry = best) — confirm that
    /// implementations return results best-first.
    fn search(&self, query: &str, k: usize) -> Vec<(DocId, f32)>;
}
340
341impl<V, L> HybridSearchEngine<V, L>
342where
343 V: VectorSearchBackend,
344 L: LexicalSearchBackend,
345{
346 pub fn new(vector_search: V, lexical_search: L) -> Self {
348 Self {
349 vector_search,
350 lexical_search,
351 fusion_config: RRFConfig::default(),
352 overfetch_factor: 2.0,
353 }
354 }
355
356 pub fn with_fusion_config(mut self, config: RRFConfig) -> Self {
358 self.fusion_config = config;
359 self
360 }
361
362 pub fn with_overfetch(mut self, factor: f32) -> Self {
364 self.overfetch_factor = factor.max(1.0);
365 self
366 }
367
368 pub fn search(
370 &self,
371 vector_query: Option<&[f32]>,
372 text_query: Option<&str>,
373 limit: usize,
374 ) -> Vec<SearchResult> {
375 let fetch_k = (limit as f32 * self.overfetch_factor) as usize;
376
377 let vector_results = match vector_query {
379 Some(q) => self.vector_search.search(q, fetch_k),
380 None => Vec::new(),
381 };
382
383 let lexical_results = match text_query {
385 Some(q) => self.lexical_search.search(q, fetch_k),
386 None => Vec::new(),
387 };
388
389 if vector_results.is_empty() {
391 return lexical_results
392 .into_iter()
393 .take(limit)
394 .map(|(doc_id, score)| SearchResult {
395 doc_id,
396 score,
397 component_scores: None,
398 })
399 .collect();
400 }
401
402 if lexical_results.is_empty() {
403 return vector_results
404 .into_iter()
405 .take(limit)
406 .map(|(doc_id, score)| SearchResult {
407 doc_id,
408 score,
409 component_scores: None,
410 })
411 .collect();
412 }
413
414 let fusion = RRFFusion::new(self.fusion_config);
416 fusion.fuse(&vector_results, &lexical_results, limit, false)
417 }
418
419 pub fn search_detailed(
421 &self,
422 vector_query: Option<&[f32]>,
423 text_query: Option<&str>,
424 limit: usize,
425 ) -> Vec<SearchResult> {
426 let fetch_k = (limit as f32 * self.overfetch_factor) as usize;
427
428 let vector_results = vector_query
429 .map(|q| self.vector_search.search(q, fetch_k))
430 .unwrap_or_default();
431
432 let lexical_results = text_query
433 .map(|q| self.lexical_search.search(q, fetch_k))
434 .unwrap_or_default();
435
436 let fusion = RRFFusion::new(self.fusion_config);
437 fusion.fuse(&vector_results, &lexical_results, limit, true)
438 }
439}
440
441pub fn filter_results<F>(
447 results: Vec<SearchResult>,
448 predicate: F,
449 limit: usize,
450) -> Vec<SearchResult>
451where
452 F: Fn(DocId) -> bool,
453{
454 results
455 .into_iter()
456 .filter(|r| predicate(r.doc_id))
457 .take(limit)
458 .collect()
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    /// Fusing two overlapping lists yields hits, all with a positive score.
    #[test]
    fn test_rrf_fusion_basic() {
        let vector_hits = vec![(1, 0.95), (2, 0.90), (3, 0.85)];
        let lexical_hits = vec![(2, 5.0), (4, 4.5), (3, 4.0)];

        let fused = RRFFusion::default().fuse(&vector_hits, &lexical_hits, 10, false);

        assert!(!fused.is_empty());
        assert!(fused.iter().all(|hit| hit.score > 0.0));
    }

    /// With details enabled, a document present in both lists reports its
    /// 1-based rank and original score from each list.
    #[test]
    fn test_rrf_fusion_with_details() {
        let vector_hits = vec![(1, 0.9), (2, 0.8)];
        let lexical_hits = vec![(2, 5.0), (3, 4.0)];

        let fused = RRFFusion::default().fuse(&vector_hits, &lexical_hits, 10, true);

        let shared = fused
            .iter()
            .find(|hit| hit.doc_id == 2)
            .expect("doc 2 is present in both lists");
        let parts = shared
            .component_scores
            .as_ref()
            .expect("details were requested");

        assert_eq!(parts.vector_rank, Some(2));
        assert_eq!(parts.lexical_rank, Some(1));
        assert_eq!(parts.vector_score, Some(0.8));
        assert_eq!(parts.lexical_score, Some(5.0));
    }

    /// A document found by both backends outranks one found by only one.
    #[test]
    fn test_rrf_ranking() {
        let vector_hits = vec![(1, 0.95), (2, 0.90)];
        let lexical_hits = vec![(2, 5.0)];

        let fused = RRFFusion::default().fuse(&vector_hits, &lexical_hits, 10, false);

        assert_eq!(fused[0].doc_id, 2);
    }

    /// With keyword-focused weights the lexical-only hit wins at equal rank.
    #[test]
    fn test_rrf_weights() {
        let fusion = RRFFusion::new(RRFConfig::keyword_focused());

        let vector_hits = vec![(1, 0.95)];
        let lexical_hits = vec![(2, 5.0)];

        let fused = fusion.fuse(&vector_hits, &lexical_hits, 10, false);

        assert_eq!(fused[0].doc_id, 2);
    }

    /// Every document from any of the three lists shows up in the output.
    #[test]
    fn test_fuse_multi() {
        let first: Vec<(DocId, f32)> = vec![(1, 0.9), (2, 0.8)];
        let second: Vec<(DocId, f32)> = vec![(2, 0.9), (3, 0.8)];
        let third: Vec<(DocId, f32)> = vec![(3, 0.9), (1, 0.8)];

        let fused = RRFFusion::default().fuse_multi(
            &[
                (first.as_slice(), 1.0),
                (second.as_slice(), 1.0),
                (third.as_slice(), 1.0),
            ],
            10,
        );

        let seen: Vec<DocId> = fused.iter().map(|hit| hit.doc_id).collect();
        for wanted in 1..=3 {
            assert!(seen.contains(&wanted));
        }
    }

    /// Pins the exact formula: `weight / (k + rank)` with 1-based ranks, summed
    /// across lists.
    #[test]
    fn test_fuse_multi_rrf_formula_golden() {
        let k = 60.0_f32;
        let fusion = RRFFusion::new(RRFConfig {
            k,
            vector_weight: 1.0,
            lexical_weight: 1.0,
        });
        let score_of = |hits: &[SearchResult], id: DocId| -> f32 {
            hits.iter().find(|hit| hit.doc_id == id).unwrap().score
        };

        let docs: Vec<(DocId, f32)> = vec![(7, 0.9), (8, 0.5)];
        let single = fusion.fuse_multi(&[(docs.as_slice(), 2.0)], 10);
        let s7 = score_of(&single, 7);
        let s8 = score_of(&single, 8);
        assert!(
            (s7 - 2.0 / (k + 1.0)).abs() < 1e-6,
            "rank-1 must be 1-indexed weighted"
        );
        assert!(
            (s8 - 2.0 / (k + 2.0)).abs() < 1e-6,
            "rank-2 must be 1-indexed weighted"
        );
        assert!(s7 > s8, "earlier rank must score higher");

        let la: Vec<(DocId, f32)> = vec![(1, 0.0)];
        let lb: Vec<(DocId, f32)> = vec![(1, 0.0)];
        let merged = fusion.fuse_multi(&[(la.as_slice(), 1.0), (lb.as_slice(), 3.0)], 10);
        let s1 = score_of(&merged, 1);
        let expected = 1.0 / (k + 1.0) + 3.0 / (k + 1.0);
        assert!(
            (s1 - expected).abs() < 1e-6,
            "weights must sum across lists"
        );
    }

    /// Only results whose id passes the predicate are kept, in input order.
    #[test]
    fn test_filter_results() {
        let hits: Vec<SearchResult> = [(1, 0.9), (2, 0.8), (3, 0.7), (4, 0.6)]
            .iter()
            .map(|&(doc_id, score)| SearchResult {
                doc_id,
                score,
                component_scores: None,
            })
            .collect();

        let even_only = filter_results(hits, |id| id % 2 == 0, 10);

        assert_eq!(even_only.len(), 2);
        assert_eq!(even_only[0].doc_id, 2);
        assert_eq!(even_only[1].doc_id, 4);
    }
}