1use std::collections::{HashMap, HashSet};
9use uni_common::Vid;
10
11pub fn fuse_rrf_multi(ranked_lists: &[&[(Vid, f32)]], k: usize) -> Vec<(Vid, f32)> {
18 let mut scores: HashMap<Vid, f32> = HashMap::new();
19
20 for ranked_list in ranked_lists {
21 for (rank, (vid, _)) in ranked_list.iter().enumerate() {
22 let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
23 *scores.entry(*vid).or_default() += rrf_score;
24 }
25 }
26
27 let mut results: Vec<_> = scores.into_iter().collect();
28 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
29 results
30}
31
32pub fn fuse_rrf(
37 vec_results: &[(Vid, f32)],
38 fts_results: &[(Vid, f32)],
39 k: usize,
40) -> Vec<(Vid, f32)> {
41 fuse_rrf_multi(&[vec_results, fts_results], k)
42}
43
44pub fn fuse_weighted(
50 vec_results: &[(Vid, f32)],
51 fts_results: &[(Vid, f32)],
52 alpha: f32,
53) -> Vec<(Vid, f32)> {
54 let vec_max = vec_results.iter().map(|(_, s)| *s).fold(f32::MIN, f32::max);
56 let vec_min = vec_results.iter().map(|(_, s)| *s).fold(f32::MAX, f32::min);
57 let vec_range = if vec_max > vec_min {
58 vec_max - vec_min
59 } else {
60 1.0
61 };
62
63 let fts_max = fts_results.iter().map(|(_, s)| *s).fold(0.0f32, f32::max);
64
65 let vec_scores: HashMap<Vid, f32> = vec_results
66 .iter()
67 .map(|(vid, dist)| {
68 let norm = 1.0 - (dist - vec_min) / vec_range;
69 (*vid, norm)
70 })
71 .collect();
72
73 let fts_scores: HashMap<Vid, f32> = fts_results
74 .iter()
75 .map(|(vid, score)| {
76 let norm = if fts_max > 0.0 { score / fts_max } else { 0.0 };
77 (*vid, norm)
78 })
79 .collect();
80
81 let all_vids: HashSet<Vid> = vec_scores
82 .keys()
83 .chain(fts_scores.keys())
84 .cloned()
85 .collect();
86
87 let mut results: Vec<(Vid, f32)> = all_vids
88 .into_iter()
89 .map(|vid| {
90 let vec_score = *vec_scores.get(&vid).unwrap_or(&0.0);
91 let fts_score = *fts_scores.get(&vid).unwrap_or(&0.0);
92 let fused = alpha * vec_score + (1.0 - alpha) * fts_score;
93 (vid, fused)
94 })
95 .collect();
96
97 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
98 results
99}
100
101#[derive(Debug, Clone, Copy, PartialEq, Eq)]
106pub enum NormKind {
107 DistanceToSim,
109 ScoreByMax,
111}
112
113pub type WeightedSource<'a> = (&'a [(Vid, f32)], f32, NormKind);
116
117pub fn fuse_weighted_sources(sources: &[WeightedSource<'_>]) -> Vec<(Vid, f32)> {
128 let mut fused: HashMap<Vid, f32> = HashMap::new();
129
130 for (results, weight, norm) in sources {
131 let normalized: HashMap<Vid, f32> = match norm {
132 NormKind::DistanceToSim => {
133 let max = results.iter().map(|(_, s)| *s).fold(f32::MIN, f32::max);
134 let min = results.iter().map(|(_, s)| *s).fold(f32::MAX, f32::min);
135 let range = if max > min { max - min } else { 1.0 };
136 results
137 .iter()
138 .map(|(vid, dist)| (*vid, 1.0 - (dist - min) / range))
139 .collect()
140 }
141 NormKind::ScoreByMax => {
142 let max = results.iter().map(|(_, s)| *s).fold(0.0f32, f32::max);
143 results
144 .iter()
145 .map(|(vid, score)| {
146 let norm = if max > 0.0 { score / max } else { 0.0 };
147 (*vid, norm)
148 })
149 .collect()
150 }
151 };
152
153 for (vid, norm_score) in normalized {
154 *fused.entry(vid).or_default() += weight * norm_score;
155 }
156 }
157
158 let mut results: Vec<(Vid, f32)> = fused.into_iter().collect();
159 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
160 results
161}
162
/// Combines already-normalized scores into one value as a weighted sum.
///
/// Computes `scores[0] * weights[0] + scores[1] * weights[1] + ...`.
///
/// # Arguments
///
/// * `scores` - One score per source.
/// * `weights` - The weight for the score at the same index.
///
/// # Returns
///
/// The sum of the pairwise products.
///
/// # Panics
///
/// In builds with debug assertions enabled, panics when the two slices differ
/// in length. Without debug assertions the surplus elements of the longer
/// slice are silently ignored.
pub fn fuse_weighted_multi(scores: &[f32], weights: &[f32]) -> f32 {
    debug_assert_eq!(scores.len(), weights.len());
    let weighted_terms = scores
        .iter()
        .zip(weights)
        .map(|(&score, &weight)| score * weight);
    weighted_terms.sum()
}
171
/// Fuses the scores of a single point by giving every score the same weight.
///
/// Each score is multiplied by `1 / scores.len()` and the products are
/// summed, i.e. the result is the arithmetic mean of `scores`.
///
/// # Returns
///
/// `(fused, flag)`:
///
/// * empty input yields `(0.0, false)`;
/// * otherwise `fused` is the mean and `flag` is `true`.
///
/// NOTE(review): the unit tests bind the flag to `used_fallback`, so it
/// presumably tells callers that an equal-weight average stood in for a real
/// rank-based fusion — confirm against the call sites.
pub fn fuse_rrf_point(scores: &[f32]) -> (f32, bool) {
    match scores.len() {
        0 => (0.0, false),
        count => {
            let weight = 1.0 / count as f32;
            let fused: f32 = scores.iter().map(|score| score * weight).sum();
            (fused, true)
        }
    }
}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `Vid` from a raw integer id.
    fn vid(raw: u64) -> Vid {
        Vid::from(raw)
    }

    /// Builds a result list from `(raw id, value)` pairs.
    fn ranked(pairs: &[(u64, f32)]) -> Vec<(Vid, f32)> {
        pairs
            .iter()
            .map(|&(raw, value)| (vid(raw), value))
            .collect()
    }

    /// Returns the fused score stored for `raw`; panics if the id is absent.
    fn score_of(fused: &[(Vid, f32)], raw: u64) -> f32 {
        fused.iter().find(|(v, _)| *v == vid(raw)).unwrap().1
    }

    #[test]
    fn test_fuse_weighted_multi() {
        // 0.8 * 0.7 + 0.6 * 0.3 = 0.74
        let result = fuse_weighted_multi(&[0.8, 0.6], &[0.7, 0.3]);
        assert!((result - 0.74).abs() < 1e-6);
    }

    #[test]
    fn test_fuse_weighted_multi_equal() {
        // Equal weights over equal scores reproduce the score itself.
        let third = 1.0 / 3.0;
        let result = fuse_weighted_multi(&[0.5, 0.5, 0.5], &[third, third, third]);
        assert!((result - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_fuse_rrf_point_fallback() {
        // Non-empty input: the flag is set and the value is the mean.
        let (result, used_fallback) = fuse_rrf_point(&[0.8, 0.6]);
        assert!(used_fallback);
        assert!((result - 0.7).abs() < 1e-6);
    }

    #[test]
    fn test_fuse_rrf_point_empty() {
        // Empty input: zero score and the flag stays unset.
        let (result, used_fallback) = fuse_rrf_point(&[]);
        assert!(!used_fallback);
        assert!((result - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_fuse_rrf_disjoint_lists() {
        let vec_results = ranked(&[(1, 0.9), (2, 0.7)]);
        let fts_results = ranked(&[(3, 0.8), (4, 0.6)]);
        let fused = fuse_rrf(&vec_results, &fts_results, 60);

        assert_eq!(fused.len(), 4);
        // Every id from both lists must survive the fusion.
        let vids: HashSet<Vid> = fused.iter().map(|(v, _)| *v).collect();
        for raw in 1..=4u64 {
            assert!(vids.contains(&vid(raw)));
        }
    }

    #[test]
    fn test_fuse_rrf_overlapping_lists() {
        let vec_results = ranked(&[(1, 0.9), (2, 0.7)]);
        let fts_results = ranked(&[(1, 0.8), (3, 0.6)]);
        let fused = fuse_rrf(&vec_results, &fts_results, 60);

        assert_eq!(fused.len(), 3);
        // Id 1 is ranked first in both lists, so it collects two contributions.
        assert_eq!(fused[0].0, vid(1), "Overlapping VID should rank first");
    }

    #[test]
    fn test_fuse_rrf_empty_lists() {
        assert!(fuse_rrf(&[], &[], 60).is_empty());
    }

    #[test]
    fn test_fuse_rrf_multi_three_sources_overlap_wins() {
        let vec_results = ranked(&[(1, 0.9), (2, 0.7)]);
        let fts_results = ranked(&[(1, 0.8), (3, 0.6)]);
        let sparse_results = ranked(&[(1, 5.0), (4, 1.0)]);
        let fused = fuse_rrf_multi(&[&vec_results, &fts_results, &sparse_results], 60);

        assert_eq!(fused.len(), 4);
        // Id 1 leads all three lists and therefore has the highest fused score.
        assert_eq!(fused[0].0, vid(1));
    }

    #[test]
    fn test_fuse_rrf_multi_empty_third_source_is_noop() {
        let vec_results = ranked(&[(1, 0.9), (2, 0.7)]);
        let fts_results = ranked(&[(1, 0.8), (3, 0.6)]);

        // Compare as maps so the check does not depend on the order of ties.
        let two_way: HashMap<Vid, f32> = fuse_rrf(&vec_results, &fts_results, 60)
            .into_iter()
            .collect();
        let three_way: HashMap<Vid, f32> =
            fuse_rrf_multi(&[&vec_results, &fts_results, &[]], 60)
                .into_iter()
                .collect();

        assert_eq!(two_way, three_way, "absent sparse source must be a no-op");
    }

    #[test]
    fn test_fuse_weighted_sources_normalizes_per_source() {
        // Distances 0.0 / 1.0 normalize to similarities 1.0 / 0.0.
        let vec_results = ranked(&[(1, 0.0), (2, 1.0)]);
        // Scores 2.0 / 4.0 normalize to 0.5 / 1.0.
        let sparse_results = ranked(&[(1, 2.0), (2, 4.0)]);
        let fused = fuse_weighted_sources(&[
            (&vec_results, 0.5, NormKind::DistanceToSim),
            (&sparse_results, 0.5, NormKind::ScoreByMax),
        ]);

        // id 1: 0.5 * 1.0 + 0.5 * 0.5 = 0.75; id 2: 0.5 * 0.0 + 0.5 * 1.0 = 0.50
        let v1 = score_of(&fused, 1);
        let v2 = score_of(&fused, 2);
        assert!((v1 - 0.75).abs() < 1e-6);
        assert!((v2 - 0.50).abs() < 1e-6);
        assert_eq!(fused[0].0, vid(1));
    }

    #[test]
    fn test_fuse_weighted_sources_zero_max_sparse() {
        let vec_results = ranked(&[(1, 0.0), (2, 1.0)]);
        // An all-zero source has max 0.0 and must not add anything.
        let sparse_results = ranked(&[(1, 0.0), (2, 0.0)]);
        let fused = fuse_weighted_sources(&[
            (&vec_results, 0.5, NormKind::DistanceToSim),
            (&sparse_results, 0.5, NormKind::ScoreByMax),
        ]);

        let v1 = score_of(&fused, 1);
        assert!(
            (v1 - 0.5).abs() < 1e-6,
            "sparse contributes 0 when all zero"
        );
    }
}