Skip to main content

velesdb_core/fusion/
strategy.rs

//! Fusion strategies for combining multi-query search results.
2
3#![allow(clippy::unnecessary_wraps)]
4
5use std::collections::HashMap;
6
/// Errors reported when a fusion strategy is configured with invalid weights.
#[derive(Debug, Clone, PartialEq)]
pub enum FusionError {
    /// The supplied weights do not add up to 1.0 (within tolerance).
    InvalidWeightSum {
        /// The sum that was actually observed.
        sum: f32,
    },
    /// One of the supplied weights is below zero.
    NegativeWeight {
        /// The offending weight.
        weight: f32,
    },
}

impl std::fmt::Display for FusionError {
    /// Renders a human-readable message with the offending value printed to
    /// four decimal places.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Both payloads are `f32` (Copy), so matching on the dereferenced
        // value binds them by value.
        match *self {
            Self::NegativeWeight { weight } => {
                write!(f, "Weights must be non-negative, got {weight:.4}")
            }
            Self::InvalidWeightSum { sum } => {
                write!(f, "Weights must sum to 1.0, got {sum:.4}")
            }
        }
    }
}

// No source error to chain: the default `Error` methods are sufficient.
impl std::error::Error for FusionError {}
36
/// How the result lists of several searches are merged into one ranking.
///
/// The variants trade off differently:
/// - `Average`: general-purpose; mean of the scores a document received
/// - `Maximum`: favors a document that scored very high in any single query
/// - `RRF`: rank-based, so it is insensitive to differing score scales
/// - `Weighted`: explicit blend of average, maximum, and hit ratio
/// - `RelativeScore`: dense + sparse hybrid with per-branch normalization
#[derive(Debug, Clone, PartialEq)]
pub enum FusionStrategy {
    /// Mean score over the queries in which the document appears.
    ///
    /// Score = mean(scores for this document across queries)
    Average,

    /// Highest score the document received in any query.
    ///
    /// Score = max(scores for this document across queries)
    Maximum,

    /// Reciprocal Rank Fusion.
    ///
    /// Score = Σ 1/(k + `rank_i`) over every query that returned the document,
    /// where `rank_i` is the document's 1-based position in that query.
    /// The conventional k=60 balances emphasis on top ranks against
    /// still crediting lower-ranked results.
    RRF {
        /// Ranking constant (default: 60).
        k: u32,
    },

    /// Weighted blend of average score, maximum score, and hit ratio.
    ///
    /// Score = `avg_weight` × `avg_score` + `max_weight` × `max_score` + `hit_weight` × `hit_ratio`
    /// where `hit_ratio` = (number of queries containing doc) / (total queries)
    Weighted {
        /// Multiplier applied to the average-score component.
        avg_weight: f32,
        /// Multiplier applied to the maximum-score component.
        max_weight: f32,
        /// Multiplier applied to the hit-ratio component.
        hit_weight: f32,
    },

    /// Relative Score Fusion for dense + sparse hybrid search.
    ///
    /// Each branch is min-max normalized on its own and the two are then
    /// combined as `final = dense_weight * norm_dense + sparse_weight * norm_sparse`.
    /// A document present in only one branch contributes 0 for the other.
    RelativeScore {
        /// Multiplier for the dense (vector) branch.
        dense_weight: f32,
        /// Multiplier for the sparse branch.
        sparse_weight: f32,
    },
}
91
92impl FusionStrategy {
93    /// Creates an RRF strategy with the standard k=60 parameter.
94    #[must_use]
95    pub fn rrf_default() -> Self {
96        Self::RRF { k: 60 }
97    }
98
99    /// Creates a `RelativeScore` strategy with validation.
100    ///
101    /// # Errors
102    ///
103    /// Returns an error if:
104    /// - Weights do not sum to 1.0 (within 0.001 tolerance)
105    /// - Any weight is negative
106    pub fn relative_score(dense_weight: f32, sparse_weight: f32) -> Result<Self, FusionError> {
107        validate_non_negative(&[dense_weight, sparse_weight])?;
108        validate_weight_sum(dense_weight + sparse_weight)?;
109        Ok(Self::RelativeScore {
110            dense_weight,
111            sparse_weight,
112        })
113    }
114
115    /// Creates a Weighted strategy with validation.
116    ///
117    /// # Errors
118    ///
119    /// Returns an error if:
120    /// - Weights do not sum to 1.0 (within 0.001 tolerance)
121    /// - Any weight is negative
122    pub fn weighted(
123        avg_weight: f32,
124        max_weight: f32,
125        hit_weight: f32,
126    ) -> Result<Self, FusionError> {
127        validate_non_negative(&[avg_weight, max_weight, hit_weight])?;
128        validate_weight_sum(avg_weight + max_weight + hit_weight)?;
129
130        Ok(Self::Weighted {
131            avg_weight,
132            max_weight,
133            hit_weight,
134        })
135    }
136
137    /// Fuses results from multiple queries into a single ranked list.
138    ///
139    /// # Arguments
140    ///
141    /// * `results` - Vec of search results, one per query. Each inner Vec
142    ///   contains `(document_id, score)` tuples, assumed sorted by score descending.
143    ///
144    /// # Returns
145    ///
146    /// A single Vec of `(document_id, fused_score)` sorted by score descending.
147    ///
148    /// # Errors
149    ///
150    /// Currently infallible, but returns Result for future extensibility.
151    pub fn fuse(&self, results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
152        if results.is_empty() {
153            return Ok(Vec::new());
154        }
155
156        // Filter out empty query results for counting
157        let non_empty_count = results.iter().filter(|r| !r.is_empty()).count();
158        if non_empty_count == 0 {
159            return Ok(Vec::new());
160        }
161
162        let total_queries = results.len();
163
164        match self {
165            Self::Average => Self::fuse_average(results),
166            Self::Maximum => Self::fuse_maximum(results),
167            Self::RRF { k } => Self::fuse_rrf(results, *k),
168            Self::Weighted {
169                avg_weight,
170                max_weight,
171                hit_weight,
172            } => Ok(Self::fuse_weighted(
173                results,
174                *avg_weight,
175                *max_weight,
176                *hit_weight,
177                total_queries,
178            )),
179            Self::RelativeScore {
180                dense_weight,
181                sparse_weight,
182            } => Self::fuse_relative_score(&results, *dense_weight, *sparse_weight),
183        }
184    }
185
186    /// Collects per-document best scores across queries (deduplicates within each query).
187    ///
188    /// Returns a map from document ID to the list of its best scores (one per query
189    /// where it appeared).
190    fn collect_doc_scores(results: Vec<Vec<(u64, f32)>>) -> HashMap<u64, Vec<f32>> {
191        let mut doc_scores: HashMap<u64, Vec<f32>> = HashMap::new();
192
193        for query_results in results {
194            let mut query_best: HashMap<u64, f32> = HashMap::new();
195            for (id, score) in query_results {
196                query_best
197                    .entry(id)
198                    .and_modify(|s| *s = s.max(score))
199                    .or_insert(score);
200            }
201
202            for (id, score) in query_best {
203                doc_scores.entry(id).or_default().push(score);
204            }
205        }
206
207        doc_scores
208    }
209
210    /// Sorts a fused result set by score descending.
211    fn sort_descending(fused: &mut [(u64, f32)]) {
212        fused.sort_by(|a, b| b.1.total_cmp(&a.1));
213    }
214
215    /// Average fusion: mean of scores for each document.
216    #[allow(clippy::cast_precision_loss)]
217    // Reason: scores.len() is the number of queries a document appeared in;
218    // this is a small count that fits exactly in f32.
219    fn fuse_average(results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
220        let mut fused: Vec<(u64, f32)> = Self::collect_doc_scores(results)
221            .into_iter()
222            .map(|(id, scores)| {
223                let avg = scores.iter().sum::<f32>() / scores.len() as f32;
224                (id, avg)
225            })
226            .collect();
227
228        Self::sort_descending(&mut fused);
229        Ok(fused)
230    }
231
232    /// Maximum fusion: best score for each document.
233    fn fuse_maximum(results: Vec<Vec<(u64, f32)>>) -> Result<Vec<(u64, f32)>, FusionError> {
234        let mut doc_max: HashMap<u64, f32> = HashMap::new();
235
236        for query_results in results {
237            for (id, score) in query_results {
238                doc_max
239                    .entry(id)
240                    .and_modify(|s| *s = s.max(score))
241                    .or_insert(score);
242            }
243        }
244
245        let mut fused: Vec<(u64, f32)> = doc_max.into_iter().collect();
246        Self::sort_descending(&mut fused);
247        Ok(fused)
248    }
249
250    /// RRF fusion: reciprocal rank fusion.
251    #[allow(clippy::cast_precision_loss)]
252    // Reason: k (u32, typically 60) and rank+1 (small loop index) both fit
253    // exactly in f32 (exact up to 2^24).
254    fn fuse_rrf(results: Vec<Vec<(u64, f32)>>, k: u32) -> Result<Vec<(u64, f32)>, FusionError> {
255        let mut doc_rrf: HashMap<u64, f32> = HashMap::new();
256        // Reason: k is the RRF constant (default 60, max u32); u32 → f32 is
257        // exact for values <= 16_777_216, so no precision loss in practice.
258        let k_f32 = k as f32;
259
260        for query_results in results {
261            // Deduplicate and get rank order
262            let mut seen: HashMap<u64, usize> = HashMap::new();
263            for (rank, (id, _score)) in query_results.into_iter().enumerate() {
264                // Only count first occurrence (best rank) for each doc in this query
265                seen.entry(id).or_insert(rank);
266            }
267
268            for (id, rank) in seen {
269                let rrf_score = 1.0 / (k_f32 + (rank + 1) as f32);
270                *doc_rrf.entry(id).or_insert(0.0) += rrf_score;
271            }
272        }
273
274        let mut fused: Vec<(u64, f32)> = doc_rrf.into_iter().collect();
275        Self::sort_descending(&mut fused);
276        Ok(fused)
277    }
278
279    /// Weighted fusion: combination of avg, max, and hit ratio.
280    #[allow(clippy::cast_precision_loss)]
281    // Reason: total_queries and scores.len() are small counts (number of
282    // queries/hits per document); both fit exactly in f32.
283    fn fuse_weighted(
284        results: Vec<Vec<(u64, f32)>>,
285        avg_weight: f32,
286        max_weight: f32,
287        hit_weight: f32,
288        total_queries: usize,
289    ) -> Vec<(u64, f32)> {
290        let total_q = total_queries as f32;
291
292        let mut fused: Vec<(u64, f32)> = Self::collect_doc_scores(results)
293            .into_iter()
294            .map(|(id, scores)| {
295                let avg = scores.iter().sum::<f32>() / scores.len() as f32;
296                let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
297                let hit_ratio = scores.len() as f32 / total_q;
298
299                let combined = avg_weight * avg + max_weight * max + hit_weight * hit_ratio;
300                (id, combined)
301            })
302            .collect();
303
304        Self::sort_descending(&mut fused);
305        fused
306    }
307
308    /// Relative Score Fusion: per-branch min-max normalization + weighted sum.
309    ///
310    /// Expects exactly two branches in `results`: index 0 = dense, index 1 = sparse.
311    /// If more branches are provided, only the first two are used; the extras
312    /// are silently discarded. A warning is emitted so callers can detect the
313    /// accidental multi-branch case during development.
314    fn fuse_relative_score(
315        results: &[Vec<(u64, f32)>],
316        dense_weight: f32,
317        sparse_weight: f32,
318    ) -> Result<Vec<(u64, f32)>, FusionError> {
319        if results.len() > 2 {
320            tracing::warn!(
321                branch_count = results.len(),
322                "RelativeScore fusion received {} branches but only supports 2 (dense + sparse). \
323                 Branches beyond index 1 are ignored.",
324                results.len(),
325            );
326        }
327
328        let dense = results.first().map_or(&[][..], |v| v.as_slice());
329        let sparse = results.get(1).map_or(&[][..], |v| v.as_slice());
330
331        let norm_dense = min_max_normalize(dense);
332        let norm_sparse = min_max_normalize(sparse);
333
334        // Collect all doc IDs
335        let mut all_ids: HashMap<u64, f32> = HashMap::new();
336        for (&id, &nd) in &norm_dense {
337            let ns = norm_sparse.get(&id).copied().unwrap_or(0.0);
338            all_ids.insert(id, dense_weight * nd + sparse_weight * ns);
339        }
340        for (&id, &ns) in &norm_sparse {
341            all_ids.entry(id).or_insert_with(|| {
342                let nd = norm_dense.get(&id).copied().unwrap_or(0.0);
343                dense_weight * nd + sparse_weight * ns
344            });
345        }
346
347        let mut fused: Vec<(u64, f32)> = all_ids.into_iter().collect();
348        Self::sort_descending(&mut fused);
349        Ok(fused)
350    }
351}
352
353impl Default for FusionStrategy {
354    fn default() -> Self {
355        Self::RRF { k: 60 }
356    }
357}
358
// ---------------------------------------------------------------------------
// Shared helpers: weight validation (used by `relative_score` / `weighted`)
// and per-branch score normalization (used by `fuse_relative_score`)
// ---------------------------------------------------------------------------
362
363/// Validates that no weight in the slice is negative.
364fn validate_non_negative(weights: &[f32]) -> Result<(), FusionError> {
365    for &w in weights {
366        if w < 0.0 {
367            return Err(FusionError::NegativeWeight { weight: w });
368        }
369    }
370    Ok(())
371}
372
373/// Validates that a weight sum is 1.0 (within 0.001 tolerance).
374fn validate_weight_sum(sum: f32) -> Result<(), FusionError> {
375    if (sum - 1.0).abs() > 0.001 {
376        return Err(FusionError::InvalidWeightSum { sum });
377    }
378    Ok(())
379}
380
/// Min-max normalizes a branch of `(id, score)` pairs into the range `[0, 1]`.
///
/// The lowest score in the branch maps to 0 and the highest to 1. If the score
/// range is smaller than `f32::EPSILON` (including the single-item case), all
/// items receive 0.5. An empty branch yields an empty map.
///
/// When a document ID occurs more than once in the branch, its highest
/// normalized score is kept, matching the best-score deduplication used by the
/// other fusion strategies.
fn min_max_normalize(branch: &[(u64, f32)]) -> HashMap<u64, f32> {
    if branch.is_empty() {
        return HashMap::new();
    }

    // One pass for both extremes.
    let (min, max) = branch.iter().fold(
        (f32::INFINITY, f32::NEG_INFINITY),
        |(lo, hi), &(_, score)| (lo.min(score), hi.max(score)),
    );
    let range = max - min;
    // A degenerate range would divide by (almost) zero; use the midpoint instead.
    let is_flat = range < f32::EPSILON;

    let mut normalized: HashMap<u64, f32> = HashMap::with_capacity(branch.len());
    for &(id, score) in branch {
        let norm = if is_flat { 0.5 } else { (score - min) / range };
        // Keep the best score for a repeated ID instead of letting the last
        // occurrence overwrite earlier ones.
        normalized
            .entry(id)
            .and_modify(|best| *best = best.max(norm))
            .or_insert(norm);
    }
    normalized
}