1#![allow(clippy::cast_precision_loss)]
4#![allow(clippy::unnecessary_wraps)]
5
6use std::collections::HashMap;
7
/// Errors reported when validating the weights of a weighted fusion strategy.
#[derive(Debug, Clone, PartialEq)]
pub enum FusionError {
    /// The three weights do not add up to `1.0` (within the accepted tolerance).
    InvalidWeightSum {
        /// The sum that was actually computed from the supplied weights.
        sum: f32,
    },

    /// One of the supplied weights is below zero.
    NegativeWeight {
        /// The offending weight value.
        weight: f32,
    },
}
22
23impl std::fmt::Display for FusionError {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 match self {
26 Self::InvalidWeightSum { sum } => {
27 write!(f, "Weights must sum to 1.0, got {sum:.4}")
28 }
29 Self::NegativeWeight { weight } => {
30 write!(f, "Weights must be non-negative, got {weight:.4}")
31 }
32 }
33 }
34}
35
36impl std::error::Error for FusionError {}
37
/// How per-query result lists of `(document id, score)` pairs are merged into a
/// single ranked list.
#[derive(Debug, Clone, PartialEq)]
pub enum FusionStrategy {
    /// Mean of a document's best score in each query that returned it.
    Average,

    /// Highest score a document received in any query.
    Maximum,

    /// Reciprocal rank fusion: each query contributes `1 / (k + rank)`, where
    /// `rank` is the 1-based position of the document's first occurrence in
    /// that query's list.
    RRF {
        /// Constant added to the rank in the denominator.
        k: u32,
    },

    /// Linear combination
    /// `avg_weight * avg + max_weight * max + hit_weight * hit_ratio`, where
    /// `hit_ratio` is the fraction of queries that returned the document.
    ///
    /// Constructing this variant directly performs no validation of the
    /// weights; `FusionStrategy::weighted` checks them.
    Weighted {
        /// Weight applied to the average score.
        avg_weight: f32,
        /// Weight applied to the maximum score.
        max_weight: f32,
        /// Weight applied to the hit ratio.
        hit_weight: f32,
    },
}
80
81impl FusionStrategy {
82 #[must_use]
84 pub fn rrf_default() -> Self {
85 Self::RRF { k: 60 }
86 }
87
88 pub fn weighted(
96 avg_weight: f32,
97 max_weight: f32,
98 hit_weight: f32,
99 ) -> Result<Self, FusionError> {
100 if avg_weight < 0.0 {
102 return Err(FusionError::NegativeWeight { weight: avg_weight });
103 }
104 if max_weight < 0.0 {
105 return Err(FusionError::NegativeWeight { weight: max_weight });
106 }
107 if hit_weight < 0.0 {
108 return Err(FusionError::NegativeWeight { weight: hit_weight });
109 }
110
111 let sum = avg_weight + max_weight + hit_weight;
113 if (sum - 1.0).abs() > 0.001 {
114 return Err(FusionError::InvalidWeightSum { sum });
115 }
116
117 Ok(Self::Weighted {
118 avg_weight,
119 max_weight,
120 hit_weight,
121 })
122 }
123
124 pub fn fuse(&self, results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
139 if results.is_empty() {
140 return Ok(Vec::new());
141 }
142
143 let non_empty_count = results.iter().filter(|r| !r.is_empty()).count();
145 if non_empty_count == 0 {
146 return Ok(Vec::new());
147 }
148
149 let total_queries = results.len();
150
151 match self {
152 Self::Average => Self::fuse_average(results),
153 Self::Maximum => Self::fuse_maximum(results),
154 Self::RRF { k } => Self::fuse_rrf(results, *k),
155 Self::Weighted {
156 avg_weight,
157 max_weight,
158 hit_weight,
159 } => Ok(Self::fuse_weighted(
160 results,
161 *avg_weight,
162 *max_weight,
163 *hit_weight,
164 total_queries,
165 )),
166 }
167 }
168
169 fn fuse_average(results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
171 let mut doc_scores: HashMap<u64, Vec<f32>> = HashMap::new();
172
173 for query_results in results {
174 let mut query_best: HashMap<u64, f32> = HashMap::new();
176 for (id, score) in query_results {
177 query_best
178 .entry(id)
179 .and_modify(|s| *s = s.max(score))
180 .or_insert(score);
181 }
182
183 for (id, score) in query_best {
184 doc_scores.entry(id).or_default().push(score);
185 }
186 }
187
188 let mut fused: Vec<(u64, f32)> = doc_scores
189 .into_iter()
190 .map(|(id, scores)| {
191 #[allow(clippy::cast_precision_loss)]
193 let avg = scores.iter().sum::<f32>() / scores.len() as f32;
194 (id, avg)
195 })
196 .collect();
197
198 fused.sort_by(|a, b| b.1.total_cmp(&a.1));
200
201 Ok(fused)
202 }
203
204 fn fuse_maximum(results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
206 let mut doc_max: HashMap<u64, f32> = HashMap::new();
207
208 for query_results in results {
209 for (id, score) in query_results {
210 doc_max
211 .entry(id)
212 .and_modify(|s| *s = s.max(score))
213 .or_insert(score);
214 }
215 }
216
217 let mut fused: Vec<(u64, f32)> = doc_max.into_iter().collect();
218 fused.sort_by(|a, b| b.1.total_cmp(&a.1));
219
220 Ok(fused)
221 }
222
223 fn fuse_rrf(results: Vec<Vec<(u64, f32)>>, k: u32) -> Result<Vec<(u64, f32)>, FusionError> {
225 let mut doc_rrf: HashMap<u64, f32> = HashMap::new();
226 #[allow(clippy::cast_precision_loss)]
228 let k_f32 = k as f32;
229
230 for query_results in results {
231 let mut seen: HashMap<u64, usize> = HashMap::new();
233 for (rank, (id, _score)) in query_results.into_iter().enumerate() {
234 seen.entry(id).or_insert(rank);
236 }
237
238 for (id, rank) in seen {
239 #[allow(clippy::cast_precision_loss)]
241 let rrf_score = 1.0 / (k_f32 + (rank + 1) as f32);
242 *doc_rrf.entry(id).or_insert(0.0) += rrf_score;
243 }
244 }
245
246 let mut fused: Vec<(u64, f32)> = doc_rrf.into_iter().collect();
247 fused.sort_by(|a, b| b.1.total_cmp(&a.1));
248
249 Ok(fused)
250 }
251
252 #[allow(clippy::cast_precision_loss)]
254 fn fuse_weighted(
255 results: Vec<Vec<(u64, f32)>>,
256 avg_weight: f32,
257 max_weight: f32,
258 hit_weight: f32,
259 total_queries: usize,
260 ) -> Vec<(u64, f32)> {
261 let mut doc_scores: HashMap<u64, Vec<f32>> = HashMap::new();
263
264 for query_results in results {
265 let mut query_best: HashMap<u64, f32> = HashMap::new();
266 for (id, score) in query_results {
267 query_best
268 .entry(id)
269 .and_modify(|s| *s = s.max(score))
270 .or_insert(score);
271 }
272
273 for (id, score) in query_best {
274 doc_scores.entry(id).or_default().push(score);
275 }
276 }
277
278 #[allow(clippy::cast_precision_loss)]
280 let total_q = total_queries as f32;
281
282 let mut fused: Vec<(u64, f32)> = doc_scores
283 .into_iter()
284 .map(|(id, scores)| {
285 #[allow(clippy::cast_precision_loss)]
287 let avg = scores.iter().sum::<f32>() / scores.len() as f32;
288 let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
289 #[allow(clippy::cast_precision_loss)]
290 let hit_ratio = scores.len() as f32 / total_q;
291
292 let combined = avg_weight * avg + max_weight * max + hit_weight * hit_ratio;
293 (id, combined)
294 })
295 .collect();
296
297 fused.sort_by(|a, b| b.1.total_cmp(&a.1));
298
299 fused
300 }
301}
302
303impl Default for FusionStrategy {
304 fn default() -> Self {
305 Self::RRF { k: 60 }
306 }
307}