1#![allow(clippy::cast_precision_loss)]
4#![allow(clippy::unnecessary_wraps)]
5
6use std::collections::HashMap;
7
/// Errors reported by the score-fusion configuration in this module.
#[derive(Clone, Debug, PartialEq)]
pub enum FusionError {
    /// The strategy weights did not add up to `1.0` within the accepted tolerance.
    InvalidWeightSum {
        /// The sum that was actually computed from the supplied weights.
        sum: f32,
    },

    /// A strategy weight failed the non-negativity check.
    NegativeWeight {
        /// The offending weight value.
        weight: f32,
    },
}
22
23impl std::fmt::Display for FusionError {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 match self {
26 Self::InvalidWeightSum { sum } => {
27 write!(f, "Weights must sum to 1.0, got {sum:.4}")
28 }
29 Self::NegativeWeight { weight } => {
30 write!(f, "Weights must be non-negative, got {weight:.4}")
31 }
32 }
33 }
34}
35
36impl std::error::Error for FusionError {}
37
/// How several per-query result lists are merged into one ranking.
///
/// Every strategy yields `(document_id, fused_score)` pairs sorted by descending
/// fused score.
#[derive(Clone, Debug, PartialEq)]
pub enum FusionStrategy {
    /// Mean of a document's best score in each query that returned it. Queries that
    /// did not return the document do not pull its average down.
    Average,

    /// Highest score the document received in any query.
    Maximum,

    /// Reciprocal Rank Fusion: sums `1 / (k + rank)` over queries, where `rank` is the
    /// 1-based list position of the document's first occurrence in that query.
    /// Raw scores are not used.
    RRF {
        /// Smoothing constant added to every rank; larger values shrink the gap
        /// between the contributions of adjacent ranks.
        k: u32,
    },

    /// Linear blend `avg_weight * mean + max_weight * max + hit_weight * hit_ratio`,
    /// where `mean`/`max` are over the document's best score per query and
    /// `hit_ratio` is the fraction of all queries (empty ones included) that
    /// returned the document. Construct via `FusionStrategy::weighted` to have the
    /// weights validated.
    Weighted {
        /// Weight on the document's mean best-per-query score.
        avg_weight: f32,
        /// Weight on the document's maximum score.
        max_weight: f32,
        /// Weight on the fraction of queries that returned the document.
        hit_weight: f32,
    },
}
80
81impl FusionStrategy {
82 #[must_use]
84 pub fn rrf_default() -> Self {
85 Self::RRF { k: 60 }
86 }
87
88 pub fn weighted(
96 avg_weight: f32,
97 max_weight: f32,
98 hit_weight: f32,
99 ) -> Result<Self, FusionError> {
100 if avg_weight < 0.0 {
102 return Err(FusionError::NegativeWeight { weight: avg_weight });
103 }
104 if max_weight < 0.0 {
105 return Err(FusionError::NegativeWeight { weight: max_weight });
106 }
107 if hit_weight < 0.0 {
108 return Err(FusionError::NegativeWeight { weight: hit_weight });
109 }
110
111 let sum = avg_weight + max_weight + hit_weight;
113 if (sum - 1.0).abs() > 0.001 {
114 return Err(FusionError::InvalidWeightSum { sum });
115 }
116
117 Ok(Self::Weighted {
118 avg_weight,
119 max_weight,
120 hit_weight,
121 })
122 }
123
124 pub fn fuse(&self, results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
139 if results.is_empty() {
140 return Ok(Vec::new());
141 }
142
143 let non_empty_count = results.iter().filter(|r| !r.is_empty()).count();
145 if non_empty_count == 0 {
146 return Ok(Vec::new());
147 }
148
149 let total_queries = results.len();
150
151 match self {
152 Self::Average => Self::fuse_average(results),
153 Self::Maximum => Self::fuse_maximum(results),
154 Self::RRF { k } => Self::fuse_rrf(results, *k),
155 Self::Weighted {
156 avg_weight,
157 max_weight,
158 hit_weight,
159 } => Ok(Self::fuse_weighted(
160 results,
161 *avg_weight,
162 *max_weight,
163 *hit_weight,
164 total_queries,
165 )),
166 }
167 }
168
169 fn fuse_average(results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
171 let mut doc_scores: HashMap<u64, Vec<f32>> = HashMap::new();
172
173 for query_results in results {
174 let mut query_best: HashMap<u64, f32> = HashMap::new();
176 for (id, score) in query_results {
177 query_best
178 .entry(id)
179 .and_modify(|s| *s = s.max(score))
180 .or_insert(score);
181 }
182
183 for (id, score) in query_best {
184 doc_scores.entry(id).or_default().push(score);
185 }
186 }
187
188 let mut fused: Vec<(u64, f32)> = doc_scores
189 .into_iter()
190 .map(|(id, scores)| {
191 let avg = scores.iter().sum::<f32>() / scores.len() as f32;
192 (id, avg)
193 })
194 .collect();
195
196 fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
198
199 Ok(fused)
200 }
201
202 fn fuse_maximum(results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
204 let mut doc_max: HashMap<u64, f32> = HashMap::new();
205
206 for query_results in results {
207 for (id, score) in query_results {
208 doc_max
209 .entry(id)
210 .and_modify(|s| *s = s.max(score))
211 .or_insert(score);
212 }
213 }
214
215 let mut fused: Vec<(u64, f32)> = doc_max.into_iter().collect();
216 fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
217
218 Ok(fused)
219 }
220
221 fn fuse_rrf(results: Vec<Vec<(u64, f32)>>, k: u32) -> Result<Vec<(u64, f32)>, FusionError> {
223 let mut doc_rrf: HashMap<u64, f32> = HashMap::new();
224 let k_f32 = k as f32;
225
226 for query_results in results {
227 let mut seen: HashMap<u64, usize> = HashMap::new();
229 for (rank, (id, _score)) in query_results.into_iter().enumerate() {
230 seen.entry(id).or_insert(rank);
232 }
233
234 for (id, rank) in seen {
235 let rrf_score = 1.0 / (k_f32 + (rank + 1) as f32);
236 *doc_rrf.entry(id).or_insert(0.0) += rrf_score;
237 }
238 }
239
240 let mut fused: Vec<(u64, f32)> = doc_rrf.into_iter().collect();
241 fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
242
243 Ok(fused)
244 }
245
246 #[allow(clippy::cast_precision_loss)]
248 fn fuse_weighted(
249 results: Vec<Vec<(u64, f32)>>,
250 avg_weight: f32,
251 max_weight: f32,
252 hit_weight: f32,
253 total_queries: usize,
254 ) -> Vec<(u64, f32)> {
255 let mut doc_scores: HashMap<u64, Vec<f32>> = HashMap::new();
257
258 for query_results in results {
259 let mut query_best: HashMap<u64, f32> = HashMap::new();
260 for (id, score) in query_results {
261 query_best
262 .entry(id)
263 .and_modify(|s| *s = s.max(score))
264 .or_insert(score);
265 }
266
267 for (id, score) in query_best {
268 doc_scores.entry(id).or_default().push(score);
269 }
270 }
271
272 let total_q = total_queries as f32;
273
274 let mut fused: Vec<(u64, f32)> = doc_scores
275 .into_iter()
276 .map(|(id, scores)| {
277 let avg = scores.iter().sum::<f32>() / scores.len() as f32;
278 let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
279 let hit_ratio = scores.len() as f32 / total_q;
280
281 let combined = avg_weight * avg + max_weight * max + hit_weight * hit_ratio;
282 (id, combined)
283 })
284 .collect();
285
286 fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
287
288 fused
289 }
290}
291
292impl Default for FusionStrategy {
293 fn default() -> Self {
294 Self::RRF { k: 60 }
295 }
296}