1#![allow(clippy::unnecessary_wraps)]
4
5use std::collections::HashMap;
6
/// Errors reported when a set of fusion weights fails validation.
#[derive(Debug, Clone, PartialEq)]
pub enum FusionError {
    /// The supplied weights do not add up to 1.0.
    InvalidWeightSum {
        /// The sum that was computed from the supplied weights.
        sum: f32,
    },
    /// One of the supplied weights is below zero.
    NegativeWeight {
        /// The offending weight.
        weight: f32,
    },
}
21
22impl std::fmt::Display for FusionError {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 match self {
25 Self::InvalidWeightSum { sum } => {
26 write!(f, "Weights must sum to 1.0, got {sum:.4}")
27 }
28 Self::NegativeWeight { weight } => {
29 write!(f, "Weights must be non-negative, got {weight:.4}")
30 }
31 }
32 }
33}
34
// Marker impl so `FusionError` can be used wherever a `std::error::Error` is expected.
// No trait methods are overridden, so the trait's provided defaults apply.
impl std::error::Error for FusionError {}
36
/// Strategy used to merge the `(id, score)` result lists of several queries into one list.
///
/// The variants carry public fields, so they can be built directly; the
/// `relative_score` and `weighted` constructors additionally validate the weights.
#[derive(Debug, Clone, PartialEq)]
pub enum FusionStrategy {
    /// Scores a document with the mean of its per-query scores, counting only the
    /// queries in which the document appears.
    Average,

    /// Scores a document with the highest score it received in any query.
    Maximum,

    /// Reciprocal rank fusion: every query that returned the document contributes
    /// `1 / (k + rank)`, where `rank` is the document's 1-based position in that
    /// query's list.
    RRF {
        /// Constant added to the rank in the denominator.
        k: u32,
    },

    /// Scores a document with
    /// `avg_weight * avg + max_weight * max + hit_weight * hit_ratio`, where
    /// `hit_ratio` is the fraction of queries that returned the document.
    Weighted {
        /// Weight applied to the mean of the document's per-query scores.
        avg_weight: f32,
        /// Weight applied to the document's best per-query score.
        max_weight: f32,
        /// Weight applied to the document's hit ratio.
        hit_weight: f32,
    },

    /// Min-max normalizes the first two result lists (treated as the dense and the
    /// sparse branch) and scores a document with
    /// `dense_weight * dense + sparse_weight * sparse`. A document missing from a
    /// branch contributes `0.0` for that branch.
    RelativeScore {
        /// Weight applied to the normalized score from the first (dense) list.
        dense_weight: f32,
        /// Weight applied to the normalized score from the second (sparse) list.
        sparse_weight: f32,
    },
}
91
92impl FusionStrategy {
93 #[must_use]
95 pub fn rrf_default() -> Self {
96 Self::RRF { k: 60 }
97 }
98
99 pub fn relative_score(dense_weight: f32, sparse_weight: f32) -> Result<Self, FusionError> {
107 validate_non_negative(&[dense_weight, sparse_weight])?;
108 validate_weight_sum(dense_weight + sparse_weight)?;
109 Ok(Self::RelativeScore {
110 dense_weight,
111 sparse_weight,
112 })
113 }
114
115 pub fn weighted(
123 avg_weight: f32,
124 max_weight: f32,
125 hit_weight: f32,
126 ) -> Result<Self, FusionError> {
127 validate_non_negative(&[avg_weight, max_weight, hit_weight])?;
128 validate_weight_sum(avg_weight + max_weight + hit_weight)?;
129
130 Ok(Self::Weighted {
131 avg_weight,
132 max_weight,
133 hit_weight,
134 })
135 }
136
137 pub fn fuse(&self, results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
152 if results.is_empty() {
153 return Ok(Vec::new());
154 }
155
156 let non_empty_count = results.iter().filter(|r| !r.is_empty()).count();
158 if non_empty_count == 0 {
159 return Ok(Vec::new());
160 }
161
162 let total_queries = results.len();
163
164 match self {
165 Self::Average => Self::fuse_average(results),
166 Self::Maximum => Self::fuse_maximum(results),
167 Self::RRF { k } => Self::fuse_rrf(results, *k),
168 Self::Weighted {
169 avg_weight,
170 max_weight,
171 hit_weight,
172 } => Ok(Self::fuse_weighted(
173 results,
174 *avg_weight,
175 *max_weight,
176 *hit_weight,
177 total_queries,
178 )),
179 Self::RelativeScore {
180 dense_weight,
181 sparse_weight,
182 } => Self::fuse_relative_score(&results, *dense_weight, *sparse_weight),
183 }
184 }
185
186 fn collect_doc_scores(results: Vec<Vec<(u64, f32)>>) -> HashMap<u64, Vec<f32>> {
191 let mut doc_scores: HashMap<u64, Vec<f32>> = HashMap::new();
192
193 for query_results in results {
194 let mut query_best: HashMap<u64, f32> = HashMap::new();
195 for (id, score) in query_results {
196 query_best
197 .entry(id)
198 .and_modify(|s| *s = s.max(score))
199 .or_insert(score);
200 }
201
202 for (id, score) in query_best {
203 doc_scores.entry(id).or_default().push(score);
204 }
205 }
206
207 doc_scores
208 }
209
210 fn sort_descending(fused: &mut [(u64, f32)]) {
212 fused.sort_by(|a, b| b.1.total_cmp(&a.1));
213 }
214
215 #[allow(clippy::cast_precision_loss)]
217 fn fuse_average(results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
220 let mut fused: Vec<(u64, f32)> = Self::collect_doc_scores(results)
221 .into_iter()
222 .map(|(id, scores)| {
223 let avg = scores.iter().sum::<f32>() / scores.len() as f32;
224 (id, avg)
225 })
226 .collect();
227
228 Self::sort_descending(&mut fused);
229 Ok(fused)
230 }
231
232 fn fuse_maximum(results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
234 let mut doc_max: HashMap<u64, f32> = HashMap::new();
235
236 for query_results in results {
237 for (id, score) in query_results {
238 doc_max
239 .entry(id)
240 .and_modify(|s| *s = s.max(score))
241 .or_insert(score);
242 }
243 }
244
245 let mut fused: Vec<(u64, f32)> = doc_max.into_iter().collect();
246 Self::sort_descending(&mut fused);
247 Ok(fused)
248 }
249
250 #[allow(clippy::cast_precision_loss)]
252 fn fuse_rrf(results: Vec<Vec<(u64, f32)>>, k: u32) -> Result<Vec<(u64, f32)>, FusionError> {
255 let mut doc_rrf: HashMap<u64, f32> = HashMap::new();
256 let k_f32 = k as f32;
259
260 for query_results in results {
261 let mut seen: HashMap<u64, usize> = HashMap::new();
263 for (rank, (id, _score)) in query_results.into_iter().enumerate() {
264 seen.entry(id).or_insert(rank);
266 }
267
268 for (id, rank) in seen {
269 let rrf_score = 1.0 / (k_f32 + (rank + 1) as f32);
270 *doc_rrf.entry(id).or_insert(0.0) += rrf_score;
271 }
272 }
273
274 let mut fused: Vec<(u64, f32)> = doc_rrf.into_iter().collect();
275 Self::sort_descending(&mut fused);
276 Ok(fused)
277 }
278
279 #[allow(clippy::cast_precision_loss)]
281 fn fuse_weighted(
284 results: Vec<Vec<(u64, f32)>>,
285 avg_weight: f32,
286 max_weight: f32,
287 hit_weight: f32,
288 total_queries: usize,
289 ) -> Vec<(u64, f32)> {
290 let total_q = total_queries as f32;
291
292 let mut fused: Vec<(u64, f32)> = Self::collect_doc_scores(results)
293 .into_iter()
294 .map(|(id, scores)| {
295 let avg = scores.iter().sum::<f32>() / scores.len() as f32;
296 let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
297 let hit_ratio = scores.len() as f32 / total_q;
298
299 let combined = avg_weight * avg + max_weight * max + hit_weight * hit_ratio;
300 (id, combined)
301 })
302 .collect();
303
304 Self::sort_descending(&mut fused);
305 fused
306 }
307
308 fn fuse_relative_score(
315 results: &[Vec<(u64, f32)>],
316 dense_weight: f32,
317 sparse_weight: f32,
318 ) -> Result<Vec<(u64, f32)>, FusionError> {
319 if results.len() > 2 {
320 tracing::warn!(
321 branch_count = results.len(),
322 "RelativeScore fusion received {} branches but only supports 2 (dense + sparse). \
323 Branches beyond index 1 are ignored.",
324 results.len(),
325 );
326 }
327
328 let dense = results.first().map_or(&[][..], |v| v.as_slice());
329 let sparse = results.get(1).map_or(&[][..], |v| v.as_slice());
330
331 let norm_dense = min_max_normalize(dense);
332 let norm_sparse = min_max_normalize(sparse);
333
334 let mut all_ids: HashMap<u64, f32> = HashMap::new();
336 for (&id, &nd) in &norm_dense {
337 let ns = norm_sparse.get(&id).copied().unwrap_or(0.0);
338 all_ids.insert(id, dense_weight * nd + sparse_weight * ns);
339 }
340 for (&id, &ns) in &norm_sparse {
341 all_ids.entry(id).or_insert_with(|| {
342 let nd = norm_dense.get(&id).copied().unwrap_or(0.0);
343 dense_weight * nd + sparse_weight * ns
344 });
345 }
346
347 let mut fused: Vec<(u64, f32)> = all_ids.into_iter().collect();
348 Self::sort_descending(&mut fused);
349 Ok(fused)
350 }
351}
352
353impl Default for FusionStrategy {
354 fn default() -> Self {
355 Self::RRF { k: 60 }
356 }
357}
358
359fn validate_non_negative(weights: &[f32]) -> Result<(), FusionError> {
365 for &w in weights {
366 if w < 0.0 {
367 return Err(FusionError::NegativeWeight { weight: w });
368 }
369 }
370 Ok(())
371}
372
373fn validate_weight_sum(sum: f32) -> Result<(), FusionError> {
375 if (sum - 1.0).abs() > 0.001 {
376 return Err(FusionError::InvalidWeightSum { sum });
377 }
378 Ok(())
379}
380
/// Min-max normalizes the scores of one result list into `[0.0, 1.0]`.
///
/// Returns a map from document id to `(score - min) / (max - min)`, where `min` and
/// `max` are taken over the whole list. When the score range is smaller than
/// `f32::EPSILON` (for example a single entry, or all scores equal) every document
/// gets `0.5`. An empty list yields an empty map.
///
/// If an id occurs more than once, its highest normalized score is kept, so a later
/// low-scoring duplicate cannot overwrite an earlier, better one.
fn min_max_normalize(branch: &[(u64, f32)]) -> HashMap<u64, f32> {
    if branch.is_empty() {
        return HashMap::new();
    }

    // One pass for both bounds; `f32::min`/`f32::max` skip NaN operands.
    let (min, max) = branch.iter().fold(
        (f32::INFINITY, f32::NEG_INFINITY),
        |(lo, hi), &(_, score)| (lo.min(score), hi.max(score)),
    );
    let range = max - min;

    let mut normalized: HashMap<u64, f32> = HashMap::with_capacity(branch.len());
    for &(id, score) in branch {
        let norm = if range < f32::EPSILON {
            // Degenerate range: no meaningful spread, so use the midpoint.
            0.5
        } else {
            (score - min) / range
        };
        normalized
            .entry(id)
            .and_modify(|best| *best = best.max(norm))
            .or_insert(norm);
    }

    normalized
}