Skip to main content

velesdb_core/fusion/
strategy.rs

1//! Fusion strategies for combining multi-query search results.
2
3#![allow(clippy::cast_precision_loss)]
4#![allow(clippy::unnecessary_wraps)]
5
6use std::collections::HashMap;
7
/// Error type for fusion operations.
///
/// Both variants are produced by [`FusionStrategy::weighted`] when it rejects
/// a set of weights. Each variant carries the offending value so it can be
/// shown in the error message.
#[derive(Debug, Clone, PartialEq)]
pub enum FusionError {
    /// Weights do not sum to 1.0 (within tolerance).
    ///
    /// The tolerance is the one applied by [`FusionStrategy::weighted`].
    InvalidWeightSum {
        /// The actual sum of weights.
        sum: f32,
    },
    /// Negative weight provided.
    ///
    /// Only the first negative weight encountered (in argument order) is reported.
    NegativeWeight {
        /// The negative weight value.
        weight: f32,
    },
}
22
23impl std::fmt::Display for FusionError {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        match self {
26            Self::InvalidWeightSum { sum } => {
27                write!(f, "Weights must sum to 1.0, got {sum:.4}")
28            }
29            Self::NegativeWeight { weight } => {
30                write!(f, "Weights must be non-negative, got {weight:.4}")
31            }
32        }
33    }
34}
35
// Marker impl: `FusionError` wraps no underlying error, so the default
// `Error` methods (e.g. `source()` returning `None`) are sufficient.
impl std::error::Error for FusionError {}
37
/// Strategy for fusing results from multiple vector searches.
///
/// Each strategy combines results differently, optimizing for various use cases:
/// - `Average`: Good for general-purpose fusion
/// - `Maximum`: Emphasizes documents that score very high in any query
/// - `RRF`: Position-based fusion, robust to score scale differences
/// - `Weighted`: Custom combination with explicit control over factors
///
/// The [`Default`] strategy is `RRF` with `k = 60`.
#[derive(Debug, Clone, PartialEq)]
pub enum FusionStrategy {
    /// Average score across all queries where the document appears.
    ///
    /// Score = mean(scores for this document across queries)
    ///
    /// A document listed more than once in the same query contributes only its
    /// best score from that query. Queries that do not contain the document are
    /// left out of the mean rather than counted as zero.
    Average,

    /// Maximum score across all queries.
    ///
    /// Score = max(scores for this document across queries)
    Maximum,

    /// Reciprocal Rank Fusion.
    ///
    /// Score = Σ 1/(k + `rank_i`) for each query where document appears.
    /// Standard k=60 provides good balance between emphasizing top ranks
    /// while still considering lower-ranked results.
    ///
    /// `rank_i` is the 1-based position of the document's first occurrence in
    /// the result list of query `i`; the raw scores are not used.
    RRF {
        /// Ranking constant (default: 60).
        k: u32,
    },

    /// Weighted combination of average, maximum, and hit ratio.
    ///
    /// Score = `avg_weight` × `avg_score` + `max_weight` × `max_score` + `hit_weight` × `hit_ratio`
    /// where `hit_ratio` = (number of queries containing doc) / (total queries)
    ///
    /// Total queries counts every submitted query, including those that
    /// returned no results. Prefer building this variant through
    /// [`FusionStrategy::weighted`], which validates the weights; constructing
    /// the variant directly performs no validation.
    Weighted {
        /// Weight for average score component.
        avg_weight: f32,
        /// Weight for maximum score component.
        max_weight: f32,
        /// Weight for hit ratio component.
        hit_weight: f32,
    },
}
80
81impl FusionStrategy {
82    /// Creates an RRF strategy with the standard k=60 parameter.
83    #[must_use]
84    pub fn rrf_default() -> Self {
85        Self::RRF { k: 60 }
86    }
87
88    /// Creates a Weighted strategy with validation.
89    ///
90    /// # Errors
91    ///
92    /// Returns an error if:
93    /// - Weights do not sum to 1.0 (within 0.001 tolerance)
94    /// - Any weight is negative
95    pub fn weighted(
96        avg_weight: f32,
97        max_weight: f32,
98        hit_weight: f32,
99    ) -> Result<Self, FusionError> {
100        // Validate non-negative
101        if avg_weight < 0.0 {
102            return Err(FusionError::NegativeWeight { weight: avg_weight });
103        }
104        if max_weight < 0.0 {
105            return Err(FusionError::NegativeWeight { weight: max_weight });
106        }
107        if hit_weight < 0.0 {
108            return Err(FusionError::NegativeWeight { weight: hit_weight });
109        }
110
111        // Validate sum to 1.0
112        let sum = avg_weight + max_weight + hit_weight;
113        if (sum - 1.0).abs() > 0.001 {
114            return Err(FusionError::InvalidWeightSum { sum });
115        }
116
117        Ok(Self::Weighted {
118            avg_weight,
119            max_weight,
120            hit_weight,
121        })
122    }
123
124    /// Fuses results from multiple queries into a single ranked list.
125    ///
126    /// # Arguments
127    ///
128    /// * `results` - Vec of search results, one per query. Each inner Vec
129    ///   contains `(document_id, score)` tuples, assumed sorted by score descending.
130    ///
131    /// # Returns
132    ///
133    /// A single Vec of `(document_id, fused_score)` sorted by score descending.
134    ///
135    /// # Errors
136    ///
137    /// Currently infallible, but returns Result for future extensibility.
138    pub fn fuse(&self, results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
139        if results.is_empty() {
140            return Ok(Vec::new());
141        }
142
143        // Filter out empty query results for counting
144        let non_empty_count = results.iter().filter(|r| !r.is_empty()).count();
145        if non_empty_count == 0 {
146            return Ok(Vec::new());
147        }
148
149        let total_queries = results.len();
150
151        match self {
152            Self::Average => Self::fuse_average(results),
153            Self::Maximum => Self::fuse_maximum(results),
154            Self::RRF { k } => Self::fuse_rrf(results, *k),
155            Self::Weighted {
156                avg_weight,
157                max_weight,
158                hit_weight,
159            } => Ok(Self::fuse_weighted(
160                results,
161                *avg_weight,
162                *max_weight,
163                *hit_weight,
164                total_queries,
165            )),
166        }
167    }
168
169    /// Average fusion: mean of scores for each document.
170    fn fuse_average(results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
171        let mut doc_scores: HashMap<u64, Vec<f32>> = HashMap::new();
172
173        for query_results in results {
174            // Deduplicate within query (take best score for each doc)
175            let mut query_best: HashMap<u64, f32> = HashMap::new();
176            for (id, score) in query_results {
177                query_best
178                    .entry(id)
179                    .and_modify(|s| *s = s.max(score))
180                    .or_insert(score);
181            }
182
183            for (id, score) in query_best {
184                doc_scores.entry(id).or_default().push(score);
185            }
186        }
187
188        let mut fused: Vec<(u64, f32)> = doc_scores
189            .into_iter()
190            .map(|(id, scores)| {
191                // SAFETY: scores.len() is typically < 1000, fits in f32 with full precision
192                #[allow(clippy::cast_precision_loss)]
193                let avg = scores.iter().sum::<f32>() / scores.len() as f32;
194                (id, avg)
195            })
196            .collect();
197
198        // Sort by score descending
199        fused.sort_by(|a, b| b.1.total_cmp(&a.1));
200
201        Ok(fused)
202    }
203
204    /// Maximum fusion: best score for each document.
205    fn fuse_maximum(results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
206        let mut doc_max: HashMap<u64, f32> = HashMap::new();
207
208        for query_results in results {
209            for (id, score) in query_results {
210                doc_max
211                    .entry(id)
212                    .and_modify(|s| *s = s.max(score))
213                    .or_insert(score);
214            }
215        }
216
217        let mut fused: Vec<(u64, f32)> = doc_max.into_iter().collect();
218        fused.sort_by(|a, b| b.1.total_cmp(&a.1));
219
220        Ok(fused)
221    }
222
223    /// RRF fusion: reciprocal rank fusion.
224    fn fuse_rrf(results: Vec<Vec<(u64, f32)>>, k: u32) -> Result<Vec<(u64, f32)>, FusionError> {
225        let mut doc_rrf: HashMap<u64, f32> = HashMap::new();
226        // SAFETY: k is typically 60, fits in f32 with full precision
227        #[allow(clippy::cast_precision_loss)]
228        let k_f32 = k as f32;
229
230        for query_results in results {
231            // Deduplicate and get rank order
232            let mut seen: HashMap<u64, usize> = HashMap::new();
233            for (rank, (id, _score)) in query_results.into_iter().enumerate() {
234                // Only count first occurrence (best rank) for each doc in this query
235                seen.entry(id).or_insert(rank);
236            }
237
238            for (id, rank) in seen {
239                // SAFETY: rank is typically < 1000, fits in f32 with full precision
240                #[allow(clippy::cast_precision_loss)]
241                let rrf_score = 1.0 / (k_f32 + (rank + 1) as f32);
242                *doc_rrf.entry(id).or_insert(0.0) += rrf_score;
243            }
244        }
245
246        let mut fused: Vec<(u64, f32)> = doc_rrf.into_iter().collect();
247        fused.sort_by(|a, b| b.1.total_cmp(&a.1));
248
249        Ok(fused)
250    }
251
252    /// Weighted fusion: combination of avg, max, and hit ratio.
253    #[allow(clippy::cast_precision_loss)]
254    fn fuse_weighted(
255        results: Vec<Vec<(u64, f32)>>,
256        avg_weight: f32,
257        max_weight: f32,
258        hit_weight: f32,
259        total_queries: usize,
260    ) -> Vec<(u64, f32)> {
261        // Collect all scores per document
262        let mut doc_scores: HashMap<u64, Vec<f32>> = HashMap::new();
263
264        for query_results in results {
265            let mut query_best: HashMap<u64, f32> = HashMap::new();
266            for (id, score) in query_results {
267                query_best
268                    .entry(id)
269                    .and_modify(|s| *s = s.max(score))
270                    .or_insert(score);
271            }
272
273            for (id, score) in query_best {
274                doc_scores.entry(id).or_default().push(score);
275            }
276        }
277
278        // SAFETY: total_queries is typically < 100, fits in f32 with full precision
279        #[allow(clippy::cast_precision_loss)]
280        let total_q = total_queries as f32;
281
282        let mut fused: Vec<(u64, f32)> = doc_scores
283            .into_iter()
284            .map(|(id, scores)| {
285                // SAFETY: scores.len() is typically < 1000, fits in f32 with full precision
286                #[allow(clippy::cast_precision_loss)]
287                let avg = scores.iter().sum::<f32>() / scores.len() as f32;
288                let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
289                #[allow(clippy::cast_precision_loss)]
290                let hit_ratio = scores.len() as f32 / total_q;
291
292                let combined = avg_weight * avg + max_weight * max + hit_weight * hit_ratio;
293                (id, combined)
294            })
295            .collect();
296
297        fused.sort_by(|a, b| b.1.total_cmp(&a.1));
298
299        fused
300    }
301}
302
303impl Default for FusionStrategy {
304    fn default() -> Self {
305        Self::RRF { k: 60 }
306    }
307}