Skip to main content

uni_query_functions/
fusion.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Score fusion algorithms for combining results from multiple search sources.
5//!
6//! Extracted from procedure_call.rs for reuse by `similar_to` and hybrid search procedures.
7
8use std::collections::{HashMap, HashSet};
9use uni_common::Vid;
10
11/// Reciprocal Rank Fusion (RRF) over an arbitrary number of ranked lists.
12///
13/// RRF score = sum over every list of `1 / (k + rank + 1)`; results are sorted
14/// by fused score descending. An empty list iterates zero times and therefore
15/// contributes nothing, so passing a source with no hits is a no-op — a two-way
16/// fusion stays identical when a third (e.g. sparse) source is absent.
17pub fn fuse_rrf_multi(ranked_lists: &[&[(Vid, f32)]], k: usize) -> Vec<(Vid, f32)> {
18    let mut scores: HashMap<Vid, f32> = HashMap::new();
19
20    for ranked_list in ranked_lists {
21        for (rank, (vid, _)) in ranked_list.iter().enumerate() {
22            let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
23            *scores.entry(*vid).or_default() += rrf_score;
24        }
25    }
26
27    let mut results: Vec<_> = scores.into_iter().collect();
28    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
29    results
30}
31
32/// Reciprocal Rank Fusion (RRF) for combining two ranked result lists.
33///
34/// Thin two-source shim over [`fuse_rrf_multi`]; preserved so existing callers
35/// remain unchanged.
36pub fn fuse_rrf(
37    vec_results: &[(Vid, f32)],
38    fts_results: &[(Vid, f32)],
39    k: usize,
40) -> Vec<(Vid, f32)> {
41    fuse_rrf_multi(&[vec_results, fts_results], k)
42}
43
44/// Weighted fusion: alpha * vec_score + (1 - alpha) * fts_score.
45///
46/// Both score sets are normalized to [0, 1] range before fusion.
47/// Vector scores are assumed to be distances (lower = more similar)
48/// and are inverted. FTS scores are normalized by max.
49pub fn fuse_weighted(
50    vec_results: &[(Vid, f32)],
51    fts_results: &[(Vid, f32)],
52    alpha: f32,
53) -> Vec<(Vid, f32)> {
54    // Normalize vector scores (distance -> similarity)
55    let vec_max = vec_results.iter().map(|(_, s)| *s).fold(f32::MIN, f32::max);
56    let vec_min = vec_results.iter().map(|(_, s)| *s).fold(f32::MAX, f32::min);
57    let vec_range = if vec_max > vec_min {
58        vec_max - vec_min
59    } else {
60        1.0
61    };
62
63    let fts_max = fts_results.iter().map(|(_, s)| *s).fold(0.0f32, f32::max);
64
65    let vec_scores: HashMap<Vid, f32> = vec_results
66        .iter()
67        .map(|(vid, dist)| {
68            let norm = 1.0 - (dist - vec_min) / vec_range;
69            (*vid, norm)
70        })
71        .collect();
72
73    let fts_scores: HashMap<Vid, f32> = fts_results
74        .iter()
75        .map(|(vid, score)| {
76            let norm = if fts_max > 0.0 { score / fts_max } else { 0.0 };
77            (*vid, norm)
78        })
79        .collect();
80
81    let all_vids: HashSet<Vid> = vec_scores
82        .keys()
83        .chain(fts_scores.keys())
84        .cloned()
85        .collect();
86
87    let mut results: Vec<(Vid, f32)> = all_vids
88        .into_iter()
89        .map(|vid| {
90            let vec_score = *vec_scores.get(&vid).unwrap_or(&0.0);
91            let fts_score = *fts_scores.get(&vid).unwrap_or(&0.0);
92            let fused = alpha * vec_score + (1.0 - alpha) * fts_score;
93            (vid, fused)
94        })
95        .collect();
96
97    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
98    results
99}
100
/// Tells a fusion routine how to bring one source's raw scores onto `[0, 1]`.
///
/// Sources disagree on direction: for a distance a smaller value is a closer
/// match, while for a relevance score a larger value is. Each variant names the
/// normalization suited to one of those directions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NormKind {
    /// Smaller is better (e.g. vector distance): min-max scaled, then inverted
    /// into a similarity.
    DistanceToSim,
    /// Larger is better (e.g. FTS relevance or a sparse dot product): divided
    /// by the largest raw value.
    ScoreByMax,
}
112
113/// One input to [`fuse_weighted_sources`]: a ranked `(vid, raw_score)` list, the
114/// source's fusion weight, and how its raw scores normalize onto `[0, 1]`.
115pub type WeightedSource<'a> = (&'a [(Vid, f32)], f32, NormKind);
116
117/// Weighted fusion across an arbitrary number of per-source-normalized lists.
118///
119/// Each source carries its ranked `(vid, raw_score)` list, a fusion weight, and
120/// a [`NormKind`] describing how its raw scores normalize onto `[0, 1]` before
121/// the weighted sum. Results are sorted by fused score descending. A vid present
122/// in only some sources contributes only from those sources (others count zero).
123///
124/// This generalizes [`fuse_weighted`] to three or more sources (e.g.
125/// dense + text + sparse), reproducing the per-source normalization the
126/// two-source path applies (`DistanceToSim` for vectors, `ScoreByMax` for FTS).
127pub fn fuse_weighted_sources(sources: &[WeightedSource<'_>]) -> Vec<(Vid, f32)> {
128    let mut fused: HashMap<Vid, f32> = HashMap::new();
129
130    for (results, weight, norm) in sources {
131        let normalized: HashMap<Vid, f32> = match norm {
132            NormKind::DistanceToSim => {
133                let max = results.iter().map(|(_, s)| *s).fold(f32::MIN, f32::max);
134                let min = results.iter().map(|(_, s)| *s).fold(f32::MAX, f32::min);
135                let range = if max > min { max - min } else { 1.0 };
136                results
137                    .iter()
138                    .map(|(vid, dist)| (*vid, 1.0 - (dist - min) / range))
139                    .collect()
140            }
141            NormKind::ScoreByMax => {
142                let max = results.iter().map(|(_, s)| *s).fold(0.0f32, f32::max);
143                results
144                    .iter()
145                    .map(|(vid, score)| {
146                        let norm = if max > 0.0 { score / max } else { 0.0 };
147                        (*vid, norm)
148                    })
149                    .collect()
150            }
151        };
152
153        for (vid, norm_score) in normalized {
154            *fused.entry(vid).or_default() += weight * norm_score;
155        }
156    }
157
158    let mut results: Vec<(Vid, f32)> = fused.into_iter().collect();
159    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
160    results
161}
162
/// Weighted sum of one node's per-source scores, used by `similar_to`.
///
/// Unlike the two-source [`fuse_weighted`], this expects scores that are
/// already normalized to `[0, 1]` and accepts any number of sources; the result
/// is the sum of `scores[i] * weights[i]`.
///
/// `scores` and `weights` must have the same length. That is only checked in
/// debug builds; in release builds the surplus entries of the longer slice are
/// ignored.
pub fn fuse_weighted_multi(scores: &[f32], weights: &[f32]) -> f32 {
    debug_assert_eq!(scores.len(), weights.len());
    let weighted_terms = scores
        .iter()
        .zip(weights)
        .map(|(score, weight)| score * weight);
    weighted_terms.sum()
}
171
/// Equal-weight stand-in for RRF when fusing the scores of a single node.
///
/// RRF works on ranks, and point computation (evaluating one node on its own)
/// has no global ranking to take them from. This function therefore returns
/// the equal-weight mean of `scores`: each score is multiplied by
/// `1 / scores.len()` and the products are summed.
///
/// Returns `(fused_score, used_fallback)`. `used_fallback` is `true` whenever
/// the fallback was applied, i.e. for any non-empty `scores`, so the caller can
/// surface a warning. Empty `scores` yield `(0.0, false)`.
pub fn fuse_rrf_point(scores: &[f32]) -> (f32, bool) {
    match scores.len() {
        0 => (0.0, false),
        count => {
            let equal_weight = 1.0 / count as f32;
            let fused: f32 = scores.iter().map(|&s| s * equal_weight).sum();
            (fused, true)
        }
    }
}
187
// Unit tests for the fusion helpers in this module. VIDs are built with
// `Vid::from(u64)`; float results are compared against a 1e-6 tolerance.
#[cfg(test)]
mod tests {
    use super::*;

    // 0.8 * 0.7 + 0.6 * 0.3 = 0.56 + 0.18 = 0.74.
    #[test]
    fn test_fuse_weighted_multi() {
        let scores = vec![0.8, 0.6];
        let weights = vec![0.7, 0.3];
        let result = fuse_weighted_multi(&scores, &weights);
        assert!((result - 0.74).abs() < 1e-6);
    }

    // Equal scores under equal weights summing to 1 fuse to that same score.
    #[test]
    fn test_fuse_weighted_multi_equal() {
        let scores = vec![0.5, 0.5, 0.5];
        let weights = vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0];
        let result = fuse_weighted_multi(&scores, &weights);
        assert!((result - 0.5).abs() < 1e-6);
    }

    // Point RRF falls back to the equal-weight mean: (0.8 + 0.6) / 2 = 0.7.
    #[test]
    fn test_fuse_rrf_point_fallback() {
        let scores = vec![0.8, 0.6];
        let (result, used_fallback) = fuse_rrf_point(&scores);
        assert!(used_fallback);
        assert!((result - 0.7).abs() < 1e-6);
    }

    // No scores: result 0 and the fallback flag stays unset.
    #[test]
    fn test_fuse_rrf_point_empty() {
        let (result, used_fallback) = fuse_rrf_point(&[]);
        assert!(!used_fallback);
        assert!((result - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_fuse_rrf_disjoint_lists() {
        let vec_results = vec![(Vid::from(1u64), 0.9), (Vid::from(2u64), 0.7)];
        let fts_results = vec![(Vid::from(3u64), 0.8), (Vid::from(4u64), 0.6)];
        let fused = fuse_rrf(&vec_results, &fts_results, 60);

        // All 4 VIDs should appear (disjoint union)
        assert_eq!(fused.len(), 4);
        let vids: HashSet<Vid> = fused.iter().map(|(v, _)| *v).collect();
        assert!(vids.contains(&Vid::from(1u64)));
        assert!(vids.contains(&Vid::from(2u64)));
        assert!(vids.contains(&Vid::from(3u64)));
        assert!(vids.contains(&Vid::from(4u64)));
    }

    #[test]
    fn test_fuse_rrf_overlapping_lists() {
        let vec_results = vec![(Vid::from(1u64), 0.9), (Vid::from(2u64), 0.7)];
        let fts_results = vec![(Vid::from(1u64), 0.8), (Vid::from(3u64), 0.6)];
        let fused = fuse_rrf(&vec_results, &fts_results, 60);

        // VID 1 appears in both lists → should have highest fused score
        // (2/61 at k = 60, versus 1/62 for VIDs 2 and 3).
        assert_eq!(fused.len(), 3);
        assert_eq!(
            fused[0].0,
            Vid::from(1u64),
            "Overlapping VID should rank first"
        );
    }

    // Two empty inputs fuse to an empty result.
    #[test]
    fn test_fuse_rrf_empty_lists() {
        let fused = fuse_rrf(&[], &[], 60);
        assert!(fused.is_empty());
    }

    #[test]
    fn test_fuse_rrf_multi_three_sources_overlap_wins() {
        let vec_results = vec![(Vid::from(1u64), 0.9), (Vid::from(2u64), 0.7)];
        let fts_results = vec![(Vid::from(1u64), 0.8), (Vid::from(3u64), 0.6)];
        // Raw sparse magnitudes differ from the other sources; RRF only uses rank.
        let sparse_results = vec![(Vid::from(1u64), 5.0), (Vid::from(4u64), 1.0)];
        let fused = fuse_rrf_multi(&[&vec_results, &fts_results, &sparse_results], 60);

        // VID 1 is the only id in all three lists → must rank first.
        assert_eq!(fused.len(), 4);
        assert_eq!(fused[0].0, Vid::from(1u64));
    }

    #[test]
    fn test_fuse_rrf_multi_empty_third_source_is_noop() {
        let vec_results = vec![(Vid::from(1u64), 0.9), (Vid::from(2u64), 0.7)];
        let fts_results = vec![(Vid::from(1u64), 0.8), (Vid::from(3u64), 0.6)];

        // Compare as score maps: an empty source adds nothing, so the fused
        // scores are identical. Comparing maps makes the check depend only on
        // the per-VID scores, not on how tied entries happen to be ordered.
        let two_way: HashMap<Vid, f32> = fuse_rrf(&vec_results, &fts_results, 60)
            .into_iter()
            .collect();
        let three_way: HashMap<Vid, f32> = fuse_rrf_multi(&[&vec_results, &fts_results, &[]], 60)
            .into_iter()
            .collect();

        assert_eq!(two_way, three_way, "absent sparse source must be a no-op");
    }

    #[test]
    fn test_fuse_weighted_sources_normalizes_per_source() {
        // Vector scores are distances (lower better); sparse are dot (higher better).
        let vec_results = vec![(Vid::from(1u64), 0.0), (Vid::from(2u64), 1.0)];
        let sparse_results = vec![(Vid::from(1u64), 2.0), (Vid::from(2u64), 4.0)];
        let fused = fuse_weighted_sources(&[
            (&vec_results, 0.5, NormKind::DistanceToSim),
            (&sparse_results, 0.5, NormKind::ScoreByMax),
        ]);

        // VID1: 0.5*1.0 (closest) + 0.5*(2/4)=0.25 → 0.75.
        // VID2: 0.5*0.0 (farthest) + 0.5*(4/4)=0.5  → 0.50.
        let v1 = fused.iter().find(|(v, _)| *v == Vid::from(1u64)).unwrap().1;
        let v2 = fused.iter().find(|(v, _)| *v == Vid::from(2u64)).unwrap().1;
        assert!((v1 - 0.75).abs() < 1e-6);
        assert!((v2 - 0.50).abs() < 1e-6);
        assert_eq!(fused[0].0, Vid::from(1u64));
    }

    #[test]
    fn test_fuse_weighted_sources_zero_max_sparse() {
        // All-zero sparse scores must normalize to 0, not divide by zero.
        let vec_results = vec![(Vid::from(1u64), 0.0), (Vid::from(2u64), 1.0)];
        let sparse_results = vec![(Vid::from(1u64), 0.0), (Vid::from(2u64), 0.0)];
        let fused = fuse_weighted_sources(&[
            (&vec_results, 0.5, NormKind::DistanceToSim),
            (&sparse_results, 0.5, NormKind::ScoreByMax),
        ]);

        // VID1 is the closest vector hit (similarity 1.0), so only the vector
        // half of the weight contributes: 0.5 * 1.0 + 0.5 * 0.0 = 0.5.
        let v1 = fused.iter().find(|(v, _)| *v == Vid::from(1u64)).unwrap().1;
        assert!(
            (v1 - 0.5).abs() < 1e-6,
            "sparse contributes 0 when all zero"
        );
    }
}