velesdb_core/fusion/
strategy.rs

1//! Fusion strategies for combining multi-query search results.
2
3#![allow(clippy::cast_precision_loss)]
4#![allow(clippy::unnecessary_wraps)]
5
6use std::collections::HashMap;
7
/// Errors reported when configuring a fusion strategy.
#[derive(Clone, Debug, PartialEq)]
pub enum FusionError {
    /// The supplied weights do not add up to 1.0 (within tolerance).
    InvalidWeightSum {
        /// Sum of the weights that was actually observed.
        sum: f32,
    },
    /// One of the supplied weights is below zero.
    NegativeWeight {
        /// The offending (negative) weight.
        weight: f32,
    },
}
22
23impl std::fmt::Display for FusionError {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        match self {
26            Self::InvalidWeightSum { sum } => {
27                write!(f, "Weights must sum to 1.0, got {sum:.4}")
28            }
29            Self::NegativeWeight { weight } => {
30                write!(f, "Weights must be non-negative, got {weight:.4}")
31            }
32        }
33    }
34}
35
36impl std::error::Error for FusionError {}
37
/// How the result lists of several vector searches are merged into one ranking.
///
/// The variants trade off differently:
/// - `Average`: general-purpose fusion based on the mean score
/// - `Maximum`: favours documents that score very high in at least one query
/// - `RRF`: rank-based fusion, robust when score scales differ between queries
/// - `Weighted`: explicit blend of average score, maximum score and hit ratio
#[derive(Clone, Debug, PartialEq)]
pub enum FusionStrategy {
    /// Mean of a document's scores over the queries in which it appears.
    ///
    /// Score = mean(scores for this document across queries)
    Average,

    /// Highest score a document reaches in any query.
    ///
    /// Score = max(scores for this document across queries)
    Maximum,

    /// Reciprocal Rank Fusion.
    ///
    /// Score = Σ 1/(k + `rank_i`), summed over the queries containing the document.
    /// The standard k=60 balances emphasis on top ranks against the
    /// contribution of lower-ranked results.
    RRF {
        /// Ranking constant (default: 60).
        k: u32,
    },

    /// Weighted blend of average score, maximum score and hit ratio.
    ///
    /// Score = `avg_weight` × `avg_score` + `max_weight` × `max_score` + `hit_weight` × `hit_ratio`
    /// where `hit_ratio` = (number of queries containing doc) / (total queries)
    Weighted {
        /// Factor applied to the average score.
        avg_weight: f32,
        /// Factor applied to the maximum score.
        max_weight: f32,
        /// Factor applied to the hit ratio.
        hit_weight: f32,
    },
}
80
81impl FusionStrategy {
82    /// Creates an RRF strategy with the standard k=60 parameter.
83    #[must_use]
84    pub fn rrf_default() -> Self {
85        Self::RRF { k: 60 }
86    }
87
88    /// Creates a Weighted strategy with validation.
89    ///
90    /// # Errors
91    ///
92    /// Returns an error if:
93    /// - Weights do not sum to 1.0 (within 0.001 tolerance)
94    /// - Any weight is negative
95    pub fn weighted(
96        avg_weight: f32,
97        max_weight: f32,
98        hit_weight: f32,
99    ) -> Result<Self, FusionError> {
100        // Validate non-negative
101        if avg_weight < 0.0 {
102            return Err(FusionError::NegativeWeight { weight: avg_weight });
103        }
104        if max_weight < 0.0 {
105            return Err(FusionError::NegativeWeight { weight: max_weight });
106        }
107        if hit_weight < 0.0 {
108            return Err(FusionError::NegativeWeight { weight: hit_weight });
109        }
110
111        // Validate sum to 1.0
112        let sum = avg_weight + max_weight + hit_weight;
113        if (sum - 1.0).abs() > 0.001 {
114            return Err(FusionError::InvalidWeightSum { sum });
115        }
116
117        Ok(Self::Weighted {
118            avg_weight,
119            max_weight,
120            hit_weight,
121        })
122    }
123
124    /// Fuses results from multiple queries into a single ranked list.
125    ///
126    /// # Arguments
127    ///
128    /// * `results` - Vec of search results, one per query. Each inner Vec
129    ///   contains `(document_id, score)` tuples, assumed sorted by score descending.
130    ///
131    /// # Returns
132    ///
133    /// A single Vec of `(document_id, fused_score)` sorted by score descending.
134    ///
135    /// # Errors
136    ///
137    /// Currently infallible, but returns Result for future extensibility.
138    pub fn fuse(&self, results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
139        if results.is_empty() {
140            return Ok(Vec::new());
141        }
142
143        // Filter out empty query results for counting
144        let non_empty_count = results.iter().filter(|r| !r.is_empty()).count();
145        if non_empty_count == 0 {
146            return Ok(Vec::new());
147        }
148
149        let total_queries = results.len();
150
151        match self {
152            Self::Average => Self::fuse_average(results),
153            Self::Maximum => Self::fuse_maximum(results),
154            Self::RRF { k } => Self::fuse_rrf(results, *k),
155            Self::Weighted {
156                avg_weight,
157                max_weight,
158                hit_weight,
159            } => Ok(Self::fuse_weighted(
160                results,
161                *avg_weight,
162                *max_weight,
163                *hit_weight,
164                total_queries,
165            )),
166        }
167    }
168
169    /// Average fusion: mean of scores for each document.
170    fn fuse_average(results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
171        let mut doc_scores: HashMap<u64, Vec<f32>> = HashMap::new();
172
173        for query_results in results {
174            // Deduplicate within query (take best score for each doc)
175            let mut query_best: HashMap<u64, f32> = HashMap::new();
176            for (id, score) in query_results {
177                query_best
178                    .entry(id)
179                    .and_modify(|s| *s = s.max(score))
180                    .or_insert(score);
181            }
182
183            for (id, score) in query_best {
184                doc_scores.entry(id).or_default().push(score);
185            }
186        }
187
188        let mut fused: Vec<(u64, f32)> = doc_scores
189            .into_iter()
190            .map(|(id, scores)| {
191                let avg = scores.iter().sum::<f32>() / scores.len() as f32;
192                (id, avg)
193            })
194            .collect();
195
196        // Sort by score descending
197        fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
198
199        Ok(fused)
200    }
201
202    /// Maximum fusion: best score for each document.
203    fn fuse_maximum(results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
204        let mut doc_max: HashMap<u64, f32> = HashMap::new();
205
206        for query_results in results {
207            for (id, score) in query_results {
208                doc_max
209                    .entry(id)
210                    .and_modify(|s| *s = s.max(score))
211                    .or_insert(score);
212            }
213        }
214
215        let mut fused: Vec<(u64, f32)> = doc_max.into_iter().collect();
216        fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
217
218        Ok(fused)
219    }
220
221    /// RRF fusion: reciprocal rank fusion.
222    fn fuse_rrf(results: Vec<Vec<(u64, f32)>>, k: u32) -> Result<Vec<(u64, f32)>, FusionError> {
223        let mut doc_rrf: HashMap<u64, f32> = HashMap::new();
224        let k_f32 = k as f32;
225
226        for query_results in results {
227            // Deduplicate and get rank order
228            let mut seen: HashMap<u64, usize> = HashMap::new();
229            for (rank, (id, _score)) in query_results.into_iter().enumerate() {
230                // Only count first occurrence (best rank) for each doc in this query
231                seen.entry(id).or_insert(rank);
232            }
233
234            for (id, rank) in seen {
235                let rrf_score = 1.0 / (k_f32 + (rank + 1) as f32);
236                *doc_rrf.entry(id).or_insert(0.0) += rrf_score;
237            }
238        }
239
240        let mut fused: Vec<(u64, f32)> = doc_rrf.into_iter().collect();
241        fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
242
243        Ok(fused)
244    }
245
246    /// Weighted fusion: combination of avg, max, and hit ratio.
247    #[allow(clippy::cast_precision_loss)]
248    fn fuse_weighted(
249        results: Vec<Vec<(u64, f32)>>,
250        avg_weight: f32,
251        max_weight: f32,
252        hit_weight: f32,
253        total_queries: usize,
254    ) -> Vec<(u64, f32)> {
255        // Collect all scores per document
256        let mut doc_scores: HashMap<u64, Vec<f32>> = HashMap::new();
257
258        for query_results in results {
259            let mut query_best: HashMap<u64, f32> = HashMap::new();
260            for (id, score) in query_results {
261                query_best
262                    .entry(id)
263                    .and_modify(|s| *s = s.max(score))
264                    .or_insert(score);
265            }
266
267            for (id, score) in query_best {
268                doc_scores.entry(id).or_default().push(score);
269            }
270        }
271
272        let total_q = total_queries as f32;
273
274        let mut fused: Vec<(u64, f32)> = doc_scores
275            .into_iter()
276            .map(|(id, scores)| {
277                let avg = scores.iter().sum::<f32>() / scores.len() as f32;
278                let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
279                let hit_ratio = scores.len() as f32 / total_q;
280
281                let combined = avg_weight * avg + max_weight * max + hit_weight * hit_ratio;
282                (id, combined)
283            })
284            .collect();
285
286        fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
287
288        fused
289    }
290}
291
292impl Default for FusionStrategy {
293    fn default() -> Self {
294        Self::RRF { k: 60 }
295    }
296}