1use crate::ChunkId;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
9pub enum FusionStrategy {
10 RRF {
12 k: f32,
14 },
15 Linear {
17 dense_weight: f32,
19 },
20 Convex {
22 alpha: f32,
24 },
25 DBSF,
27 Union,
29 Intersection,
31 #[cfg(feature = "multivector")]
33 ThreeWay {
34 dense_weight: f32,
36 sparse_weight: f32,
38 multivector_weight: f32,
40 },
41}
42
43impl Default for FusionStrategy {
44 fn default() -> Self {
45 Self::RRF { k: 60.0 }
46 }
47}
48
49impl FusionStrategy {
50 #[must_use]
52 pub fn fuse(
53 &self,
54 dense_results: &[(ChunkId, f32)],
55 sparse_results: &[(ChunkId, f32)],
56 ) -> Vec<(ChunkId, f32)> {
57 match self {
58 FusionStrategy::RRF { k } => {
59 Self::reciprocal_rank_fusion(dense_results, sparse_results, *k)
60 }
61 FusionStrategy::Linear { dense_weight } => {
62 Self::linear_fusion(dense_results, sparse_results, *dense_weight)
63 }
64 FusionStrategy::Convex { alpha } => {
65 Self::convex_fusion(dense_results, sparse_results, *alpha)
66 }
67 FusionStrategy::DBSF => Self::dbsf_fusion(dense_results, sparse_results),
68 FusionStrategy::Union => Self::union_fusion(dense_results, sparse_results),
69 FusionStrategy::Intersection => {
70 Self::intersection_fusion(dense_results, sparse_results)
71 }
72 #[cfg(feature = "multivector")]
74 FusionStrategy::ThreeWay { .. } => {
75 Self::reciprocal_rank_fusion(dense_results, sparse_results, 60.0)
76 }
77 }
78 }
79
80 fn reciprocal_rank_fusion(
84 dense: &[(ChunkId, f32)],
85 sparse: &[(ChunkId, f32)],
86 k: f32,
87 ) -> Vec<(ChunkId, f32)> {
88 let mut scores: HashMap<ChunkId, f32> = HashMap::new();
89
90 for (rank, (id, _)) in dense.iter().enumerate() {
91 *scores.entry(*id).or_insert(0.0) += 1.0 / (k + rank as f32 + 1.0);
92 }
93
94 for (rank, (id, _)) in sparse.iter().enumerate() {
95 *scores.entry(*id).or_insert(0.0) += 1.0 / (k + rank as f32 + 1.0);
96 }
97
98 Self::sort_by_score(scores)
99 }
100
101 fn linear_fusion(
103 dense: &[(ChunkId, f32)],
104 sparse: &[(ChunkId, f32)],
105 dense_weight: f32,
106 ) -> Vec<(ChunkId, f32)> {
107 let sparse_weight = 1.0 - dense_weight;
108
109 let dense_normalized = Self::min_max_normalize(dense);
111 let sparse_normalized = Self::min_max_normalize(sparse);
112
113 let mut scores: HashMap<ChunkId, f32> = HashMap::new();
114
115 for (id, score) in dense_normalized {
116 *scores.entry(id).or_insert(0.0) += dense_weight * score;
117 }
118
119 for (id, score) in sparse_normalized {
120 *scores.entry(id).or_insert(0.0) += sparse_weight * score;
121 }
122
123 Self::sort_by_score(scores)
124 }
125
126 fn convex_fusion(
128 dense: &[(ChunkId, f32)],
129 sparse: &[(ChunkId, f32)],
130 alpha: f32,
131 ) -> Vec<(ChunkId, f32)> {
132 Self::linear_fusion(dense, sparse, alpha)
134 }
135
136 fn dbsf_fusion(dense: &[(ChunkId, f32)], sparse: &[(ChunkId, f32)]) -> Vec<(ChunkId, f32)> {
138 let dense_normalized = Self::z_score_normalize(dense);
140 let sparse_normalized = Self::z_score_normalize(sparse);
141
142 let mut scores: HashMap<ChunkId, f32> = HashMap::new();
143
144 for (id, score) in dense_normalized {
145 *scores.entry(id).or_insert(0.0) += score;
146 }
147
148 for (id, score) in sparse_normalized {
149 *scores.entry(id).or_insert(0.0) += score;
150 }
151
152 Self::sort_by_score(scores)
153 }
154
155 fn union_fusion(dense: &[(ChunkId, f32)], sparse: &[(ChunkId, f32)]) -> Vec<(ChunkId, f32)> {
157 let mut scores: HashMap<ChunkId, (f32, usize)> = HashMap::new();
158
159 for (rank, (id, score)) in dense.iter().enumerate() {
161 scores.insert(*id, (*score, rank));
162 }
163
164 for (rank, (id, score)) in sparse.iter().enumerate() {
166 scores.entry(*id).or_insert((*score, dense.len() + rank));
167 }
168
169 let mut results: Vec<_> = scores.into_iter().collect();
170 results.sort_by(|a, b| a.1 .1.cmp(&b.1 .1)); results.into_iter().map(|(id, (score, _))| (id, score)).collect()
172 }
173
174 fn intersection_fusion(
176 dense: &[(ChunkId, f32)],
177 sparse: &[(ChunkId, f32)],
178 ) -> Vec<(ChunkId, f32)> {
179 let dense_ids: HashMap<ChunkId, f32> = dense.iter().copied().collect();
180 let sparse_ids: HashMap<ChunkId, f32> = sparse.iter().copied().collect();
181
182 let mut scores: HashMap<ChunkId, f32> = HashMap::new();
183
184 for (id, dense_score) in &dense_ids {
185 if let Some(sparse_score) = sparse_ids.get(id) {
186 scores.insert(*id, (dense_score + sparse_score) / 2.0);
188 }
189 }
190
191 Self::sort_by_score(scores)
192 }
193
194 fn min_max_normalize(results: &[(ChunkId, f32)]) -> Vec<(ChunkId, f32)> {
196 if results.is_empty() {
197 return Vec::new();
198 }
199
200 let scores: Vec<f32> = results.iter().map(|(_, s)| *s).collect();
201 let min = scores.iter().cloned().fold(f32::INFINITY, f32::min);
202 let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
203 let range = max - min;
204
205 if range.abs() < f32::EPSILON {
206 return results.iter().map(|(id, _)| (*id, 1.0)).collect();
208 }
209
210 results.iter().map(|(id, score)| (*id, (score - min) / range)).collect()
211 }
212
213 fn z_score_normalize(results: &[(ChunkId, f32)]) -> Vec<(ChunkId, f32)> {
215 if results.is_empty() {
216 return Vec::new();
217 }
218
219 let scores: Vec<f32> = results.iter().map(|(_, s)| *s).collect();
220 let n = scores.len() as f32;
221 let mean: f32 = scores.iter().sum::<f32>() / n;
222 let variance: f32 = scores.iter().map(|s| (s - mean).powi(2)).sum::<f32>() / n;
223 let std_dev = variance.sqrt();
224
225 if std_dev.abs() < f32::EPSILON {
226 return results.iter().map(|(id, _)| (*id, 0.0)).collect();
227 }
228
229 results.iter().map(|(id, score)| (*id, (score - mean) / std_dev)).collect()
230 }
231
232 fn sort_by_score(scores: HashMap<ChunkId, f32>) -> Vec<(ChunkId, f32)> {
234 let mut results: Vec<_> = scores.into_iter().collect();
235 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
236 results
237 }
238
239 #[cfg(feature = "multivector")]
254 #[must_use]
255 pub fn fuse_three(
256 &self,
257 dense: &[(ChunkId, f32)],
258 sparse: &[(ChunkId, f32)],
259 multivector: &[(ChunkId, f32)],
260 ) -> Vec<(ChunkId, f32)> {
261 if let FusionStrategy::ThreeWay { dense_weight, sparse_weight, multivector_weight } = self {
262 Self::three_way_linear(
263 dense,
264 sparse,
265 multivector,
266 *dense_weight,
267 *sparse_weight,
268 *multivector_weight,
269 )
270 } else {
271 let dense_sparse = self.fuse(dense, sparse);
274 self.fuse(&dense_sparse, multivector)
276 }
277 }
278
279 #[cfg(feature = "multivector")]
281 fn three_way_linear(
282 dense: &[(ChunkId, f32)],
283 sparse: &[(ChunkId, f32)],
284 multivector: &[(ChunkId, f32)],
285 w_dense: f32,
286 w_sparse: f32,
287 w_multi: f32,
288 ) -> Vec<(ChunkId, f32)> {
289 let mut scores: HashMap<ChunkId, f32> = HashMap::new();
290
291 let dense_norm = Self::min_max_normalize(dense);
293 let sparse_norm = Self::min_max_normalize(sparse);
294 let multi_norm = Self::min_max_normalize(multivector);
295
296 for (id, score) in dense_norm {
297 *scores.entry(id).or_insert(0.0) += w_dense * score;
298 }
299 for (id, score) in sparse_norm {
300 *scores.entry(id).or_insert(0.0) += w_sparse * score;
301 }
302 for (id, score) in multi_norm {
303 *scores.entry(id).or_insert(0.0) += w_multi * score;
304 }
305
306 Self::sort_by_score(scores)
307 }
308
309 #[cfg(feature = "multivector")]
313 #[must_use]
314 pub fn three_way(dense_weight: f32, sparse_weight: f32, multivector_weight: f32) -> Self {
315 Self::ThreeWay { dense_weight, sparse_weight, multivector_weight }
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Builds a deterministic chunk id from an integer.
    fn chunk_id(n: u128) -> ChunkId {
        ChunkId(uuid::Uuid::from_u128(n))
    }

    /// Returns the chunk ids of a fused result list, in result order.
    fn ids(results: &[(ChunkId, f32)]) -> Vec<ChunkId> {
        results.iter().map(|(id, _)| *id).collect()
    }

    // --- Construction and serialization ---

    #[test]
    fn test_fusion_strategy_default() {
        let strategy = FusionStrategy::default();
        match strategy {
            FusionStrategy::RRF { k } => assert!((k - 60.0).abs() < 0.01),
            _ => panic!("Expected RRF"),
        }
    }

    #[test]
    fn test_fusion_strategy_serialization() {
        let strategy = FusionStrategy::Linear { dense_weight: 0.7 };
        let json = serde_json::to_string(&strategy).unwrap();
        let deserialized: FusionStrategy = serde_json::from_str(&json).unwrap();

        match deserialized {
            FusionStrategy::Linear { dense_weight } => {
                assert!((dense_weight - 0.7).abs() < 0.01);
            }
            _ => panic!("Wrong strategy type"),
        }
    }

    // --- Reciprocal rank fusion ---

    #[test]
    fn test_rrf_empty() {
        let strategy = FusionStrategy::RRF { k: 60.0 };
        let results = strategy.fuse(&[], &[]);
        assert!(results.is_empty());
    }

    #[test]
    fn test_rrf_dense_only() {
        let strategy = FusionStrategy::RRF { k: 60.0 };
        let dense = vec![(chunk_id(1), 0.9), (chunk_id(2), 0.8)];
        let results = strategy.fuse(&dense, &[]);

        assert_eq!(ids(&results), vec![chunk_id(1), chunk_id(2)]);
        // Only the rank matters: 1 / (k + rank + 1).
        assert!((results[0].1 - 1.0 / 61.0).abs() < 1e-6);
        assert!((results[1].1 - 1.0 / 62.0).abs() < 1e-6);
    }

    #[test]
    fn test_rrf_sparse_only() {
        let strategy = FusionStrategy::RRF { k: 60.0 };
        let sparse = vec![(chunk_id(1), 0.9), (chunk_id(2), 0.8)];
        let results = strategy.fuse(&[], &sparse);

        assert_eq!(ids(&results), vec![chunk_id(1), chunk_id(2)]);
    }

    #[test]
    fn test_rrf_combines_ranks() {
        let strategy = FusionStrategy::RRF { k: 60.0 };

        // Chunk 1 is ranked first by both retrievers, so it must win.
        let dense = vec![(chunk_id(1), 0.9), (chunk_id(2), 0.8)];
        let sparse = vec![(chunk_id(1), 0.9), (chunk_id(3), 0.8)];

        let results = strategy.fuse(&dense, &sparse);

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, chunk_id(1));
        assert!((results[0].1 - 2.0 / 61.0).abs() < 1e-6);
    }

    #[test]
    fn test_rrf_score_calculation() {
        let strategy = FusionStrategy::RRF { k: 60.0 };

        let dense = vec![(chunk_id(1), 1.0)];
        let sparse = vec![(chunk_id(1), 1.0)];
        let results = strategy.fuse(&dense, &sparse);

        // Rank 0 in both lists: 1/61 + 1/61.
        let expected = 2.0 / 61.0;
        assert!((results[0].1 - expected).abs() < 0.001);
    }

    // --- Linear / convex fusion ---

    #[test]
    fn test_linear_empty() {
        let strategy = FusionStrategy::Linear { dense_weight: 0.5 };
        let results = strategy.fuse(&[], &[]);
        assert!(results.is_empty());
    }

    #[test]
    fn test_linear_dense_only() {
        let strategy = FusionStrategy::Linear { dense_weight: 0.7 };
        let dense = vec![(chunk_id(1), 1.0), (chunk_id(2), 0.5)];
        let results = strategy.fuse(&dense, &[]);

        // Min-max normalization maps the dense scores to 1.0 and 0.0.
        assert_eq!(ids(&results), vec![chunk_id(1), chunk_id(2)]);
        assert!((results[0].1 - 0.7).abs() < 1e-6);
        assert!(results[1].1.abs() < 1e-6);
    }

    #[test]
    fn test_linear_equal_weights() {
        let strategy = FusionStrategy::Linear { dense_weight: 0.5 };

        let dense = vec![(chunk_id(1), 1.0)];
        let sparse = vec![(chunk_id(1), 1.0)];

        let results = strategy.fuse(&dense, &sparse);

        assert!((results[0].1 - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_linear_weight_preference() {
        let strategy = FusionStrategy::Linear { dense_weight: 0.9 };

        // Each retriever prefers a different chunk; the heavier dense weight decides.
        let dense = vec![(chunk_id(1), 1.0), (chunk_id(2), 0.0)];
        let sparse = vec![(chunk_id(2), 1.0), (chunk_id(1), 0.0)];

        let results = strategy.fuse(&dense, &sparse);

        assert_eq!(results[0].0, chunk_id(1));
    }

    #[test]
    fn test_convex_same_as_linear() {
        let linear = FusionStrategy::Linear { dense_weight: 0.6 };
        let convex = FusionStrategy::Convex { alpha: 0.6 };

        let dense = vec![(chunk_id(1), 0.9), (chunk_id(2), 0.5)];
        let sparse = vec![(chunk_id(2), 0.8), (chunk_id(3), 0.4)];

        let linear_results = linear.fuse(&dense, &sparse);
        let convex_results = convex.fuse(&dense, &sparse);

        assert_eq!(ids(&linear_results), vec![chunk_id(1), chunk_id(2), chunk_id(3)]);
        assert_eq!(linear_results.len(), convex_results.len());
        for ((l_id, l_score), (c_id, c_score)) in linear_results.iter().zip(&convex_results) {
            assert_eq!(l_id, c_id);
            assert!((l_score - c_score).abs() < 1e-6);
        }
    }

    // --- Distribution-based score fusion ---

    #[test]
    fn test_dbsf_empty() {
        let strategy = FusionStrategy::DBSF;
        let results = strategy.fuse(&[], &[]);
        assert!(results.is_empty());
    }

    #[test]
    fn test_dbsf_z_score() {
        let strategy = FusionStrategy::DBSF;

        // Same shape at different scales: z-scores make the lists comparable.
        let dense = vec![(chunk_id(1), 10.0), (chunk_id(2), 5.0), (chunk_id(3), 0.0)];
        let sparse = vec![(chunk_id(1), 100.0), (chunk_id(2), 50.0), (chunk_id(3), 0.0)];

        let results = strategy.fuse(&dense, &sparse);

        assert_eq!(ids(&results), vec![chunk_id(1), chunk_id(2), chunk_id(3)]);
    }

    // --- Union ---

    #[test]
    fn test_union_combines_all() {
        let strategy = FusionStrategy::Union;

        let dense = vec![(chunk_id(1), 0.9)];
        let sparse = vec![(chunk_id(2), 0.8)];

        let results = strategy.fuse(&dense, &sparse);

        assert_eq!(ids(&results), vec![chunk_id(1), chunk_id(2)]);
        assert!((results[0].1 - 0.9).abs() < f32::EPSILON);
        assert!((results[1].1 - 0.8).abs() < f32::EPSILON);
    }

    #[test]
    fn test_union_deduplicates() {
        let strategy = FusionStrategy::Union;

        let dense = vec![(chunk_id(1), 0.9), (chunk_id(2), 0.8)];
        let sparse = vec![(chunk_id(1), 0.7), (chunk_id(3), 0.6)];

        let results = strategy.fuse(&dense, &sparse);

        // Dense order first, then sparse-only chunks.
        assert_eq!(ids(&results), vec![chunk_id(1), chunk_id(2), chunk_id(3)]);
    }

    #[test]
    fn test_union_prefers_dense_rank() {
        let strategy = FusionStrategy::Union;

        let dense = vec![(chunk_id(1), 0.9)];
        let sparse = vec![(chunk_id(1), 0.5)];
        let results = strategy.fuse(&dense, &sparse);

        assert!((results[0].1 - 0.9).abs() < f32::EPSILON);
    }

    // --- Intersection ---

    #[test]
    fn test_intersection_empty_no_overlap() {
        let strategy = FusionStrategy::Intersection;

        let dense = vec![(chunk_id(1), 0.9)];
        let sparse = vec![(chunk_id(2), 0.8)];

        let results = strategy.fuse(&dense, &sparse);

        assert!(results.is_empty());
    }

    #[test]
    fn test_intersection_keeps_overlap() {
        let strategy = FusionStrategy::Intersection;

        let dense = vec![(chunk_id(1), 0.8), (chunk_id(2), 0.6)];
        let sparse = vec![(chunk_id(2), 0.9), (chunk_id(3), 0.5)];

        let results = strategy.fuse(&dense, &sparse);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, chunk_id(2));
    }

    #[test]
    fn test_intersection_averages_scores() {
        let strategy = FusionStrategy::Intersection;

        let dense = vec![(chunk_id(1), 0.8)];
        let sparse = vec![(chunk_id(1), 0.4)];

        let results = strategy.fuse(&dense, &sparse);

        assert!((results[0].1 - 0.6).abs() < 0.001);
    }

    // --- Normalization helpers ---

    #[test]
    fn test_min_max_normalize_empty() {
        let normalized = FusionStrategy::min_max_normalize(&[]);
        assert!(normalized.is_empty());
    }

    #[test]
    fn test_min_max_normalize_single() {
        let results = vec![(chunk_id(1), 5.0)];
        let normalized = FusionStrategy::min_max_normalize(&results);

        assert_eq!(normalized.len(), 1);
        assert!((normalized[0].1 - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_min_max_normalize_range() {
        let results = vec![(chunk_id(1), 10.0), (chunk_id(2), 5.0), (chunk_id(3), 0.0)];
        let normalized = FusionStrategy::min_max_normalize(&results);

        assert!((normalized[0].1 - 1.0).abs() < 0.001);
        assert!((normalized[1].1 - 0.5).abs() < 0.001);
        assert!((normalized[2].1 - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_z_score_normalize_empty() {
        let normalized = FusionStrategy::z_score_normalize(&[]);
        assert!(normalized.is_empty());
    }

    #[test]
    fn test_z_score_normalize_same_values() {
        let results = vec![(chunk_id(1), 5.0), (chunk_id(2), 5.0), (chunk_id(3), 5.0)];
        let normalized = FusionStrategy::z_score_normalize(&results);

        for (_, score) in normalized {
            assert!(score.abs() < 0.001);
        }
    }

    #[test]
    fn test_z_score_normalize_known_values() {
        // mean = 3, population std dev = 1.
        let results = vec![(chunk_id(1), 2.0), (chunk_id(2), 4.0)];
        let normalized = FusionStrategy::z_score_normalize(&results);

        assert_eq!(ids(&normalized), vec![chunk_id(1), chunk_id(2)]);
        assert!((normalized[0].1 + 1.0).abs() < 1e-6);
        assert!((normalized[1].1 - 1.0).abs() < 1e-6);
    }

    // --- Property tests ---

    proptest! {
        #[test]
        fn prop_rrf_scores_positive(
            n_dense in 1usize..10,
            n_sparse in 1usize..10
        ) {
            let dense: Vec<_> = (0..n_dense)
                .map(|i| (chunk_id(i as u128), 1.0 - i as f32 * 0.1))
                .collect();
            let sparse: Vec<_> = (100..100 + n_sparse)
                .map(|i| (chunk_id(i as u128), 1.0 - (i - 100) as f32 * 0.1))
                .collect();

            let strategy = FusionStrategy::RRF { k: 60.0 };
            let results = strategy.fuse(&dense, &sparse);

            for (_, score) in results {
                prop_assert!(score > 0.0);
            }
        }

        #[test]
        fn prop_linear_weights_sum_to_one(dense_weight in 0.0f32..1.0) {
            let dense = vec![(chunk_id(1), 1.0)];
            let sparse = vec![(chunk_id(1), 1.0)];

            let strategy = FusionStrategy::Linear { dense_weight };
            let results = strategy.fuse(&dense, &sparse);

            prop_assert!((results[0].1 - 1.0).abs() < 0.01);
        }

        #[test]
        fn prop_intersection_subset_of_inputs(
            dense_ids in prop::collection::vec(0u128..100, 1..10),
            sparse_ids in prop::collection::vec(0u128..100, 1..10)
        ) {
            let dense: Vec<_> = dense_ids.iter().map(|&i| (chunk_id(i), 1.0)).collect();
            let sparse: Vec<_> = sparse_ids.iter().map(|&i| (chunk_id(i), 1.0)).collect();

            let strategy = FusionStrategy::Intersection;
            let results = strategy.fuse(&dense, &sparse);

            let dense_set: std::collections::HashSet<_> = dense_ids.iter().copied().collect();
            let sparse_set: std::collections::HashSet<_> = sparse_ids.iter().copied().collect();

            for (id, _) in results {
                let id_val = id.0.as_u128();
                prop_assert!(dense_set.contains(&id_val) && sparse_set.contains(&id_val));
            }
        }

        #[test]
        fn prop_fusion_deterministic(
            n in 1usize..5
        ) {
            let dense: Vec<_> = (0..n).map(|i| (chunk_id(i as u128), 1.0)).collect();
            let sparse: Vec<_> = (0..n).map(|i| (chunk_id(i as u128), 0.5)).collect();

            let strategy = FusionStrategy::RRF { k: 60.0 };
            let results1 = strategy.fuse(&dense, &sparse);
            let results2 = strategy.fuse(&dense, &sparse);

            prop_assert_eq!(results1.len(), results2.len());
            for ((id1, s1), (id2, s2)) in results1.iter().zip(results2.iter()) {
                prop_assert_eq!(id1, id2);
                prop_assert!((s1 - s2).abs() < 0.0001);
            }
        }
    }
}