Skip to main content

uni_query/query/
fusion.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright 2024-2026 Dragonscale Team

//! Score fusion algorithms for combining results from multiple search sources.
//!
//! Extracted from `procedure_call.rs` for reuse by the `similar_to` and hybrid search procedures.

8use std::collections::{HashMap, HashSet};
9use uni_common::Vid;
10
11/// Reciprocal Rank Fusion (RRF) for combining ranked result lists.
12///
13/// RRF score = sum(1 / (k + rank + 1)) for each result list.
14/// Results are sorted by fused score descending.
15pub fn fuse_rrf(
16    vec_results: &[(Vid, f32)],
17    fts_results: &[(Vid, f32)],
18    k: usize,
19) -> Vec<(Vid, f32)> {
20    let mut scores: HashMap<Vid, f32> = HashMap::new();
21
22    for ranked_list in [vec_results, fts_results] {
23        for (rank, (vid, _)) in ranked_list.iter().enumerate() {
24            let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
25            *scores.entry(*vid).or_default() += rrf_score;
26        }
27    }
28
29    let mut results: Vec<_> = scores.into_iter().collect();
30    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
31    results
32}
33
34/// Weighted fusion: alpha * vec_score + (1 - alpha) * fts_score.
35///
36/// Both score sets are normalized to [0, 1] range before fusion.
37/// Vector scores are assumed to be distances (lower = more similar)
38/// and are inverted. FTS scores are normalized by max.
39pub fn fuse_weighted(
40    vec_results: &[(Vid, f32)],
41    fts_results: &[(Vid, f32)],
42    alpha: f32,
43) -> Vec<(Vid, f32)> {
44    // Normalize vector scores (distance -> similarity)
45    let vec_max = vec_results.iter().map(|(_, s)| *s).fold(f32::MIN, f32::max);
46    let vec_min = vec_results.iter().map(|(_, s)| *s).fold(f32::MAX, f32::min);
47    let vec_range = if vec_max > vec_min {
48        vec_max - vec_min
49    } else {
50        1.0
51    };
52
53    let fts_max = fts_results.iter().map(|(_, s)| *s).fold(0.0f32, f32::max);
54
55    let vec_scores: HashMap<Vid, f32> = vec_results
56        .iter()
57        .map(|(vid, dist)| {
58            let norm = 1.0 - (dist - vec_min) / vec_range;
59            (*vid, norm)
60        })
61        .collect();
62
63    let fts_scores: HashMap<Vid, f32> = fts_results
64        .iter()
65        .map(|(vid, score)| {
66            let norm = if fts_max > 0.0 { score / fts_max } else { 0.0 };
67            (*vid, norm)
68        })
69        .collect();
70
71    let all_vids: HashSet<Vid> = vec_scores
72        .keys()
73        .chain(fts_scores.keys())
74        .cloned()
75        .collect();
76
77    let mut results: Vec<(Vid, f32)> = all_vids
78        .into_iter()
79        .map(|vid| {
80            let vec_score = *vec_scores.get(&vid).unwrap_or(&0.0);
81            let fts_score = *fts_scores.get(&vid).unwrap_or(&0.0);
82            let fused = alpha * vec_score + (1.0 - alpha) * fts_score;
83            (vid, fused)
84        })
85        .collect();
86
87    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
88    results
89}
90
/// Combines per-source scores into a single weighted sum for `similar_to`.
///
/// Each `scores[i]` is multiplied by `weights[i]` and the products are added
/// up. Unlike the two-source `fuse_weighted`, no normalization happens here:
/// the scores are expected to be pre-normalized to `[0, 1]`, and any number
/// of sources is supported.
///
/// `scores` and `weights` must have the same length. This is only checked in
/// debug builds; in release builds the surplus entries of the longer slice are
/// ignored.
pub fn fuse_weighted_multi(scores: &[f32], weights: &[f32]) -> f32 {
    debug_assert_eq!(scores.len(), weights.len());
    weights
        .iter()
        .zip(scores)
        .map(|(weight, score)| weight * score)
        .sum()
}
99
/// Point-evaluation stand-in for multi-source RRF.
///
/// When a single node is scored in isolation there are no ranked lists, so
/// true RRF cannot be computed. Instead every score receives the same weight
/// `1 / scores.len()` and the weighted values are summed (effectively the
/// arithmetic mean).
///
/// Returns `(fused_score, used_fallback)`. `used_fallback` is `true` whenever
/// at least one score was fused, signalling that equal-weight fusion replaced
/// RRF. This function does not log anything itself; surfacing a warning is
/// left to the caller. An empty `scores` slice yields `(0.0, false)`.
pub fn fuse_rrf_point(scores: &[f32]) -> (f32, bool) {
    match scores.len() {
        0 => (0.0, false),
        count => {
            let share = 1.0 / count as f32;
            let fused = scores.iter().map(|&score| score * share).sum::<f32>();
            (fused, true)
        }
    }
}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing `f32` results.
    const EPSILON: f32 = 1e-6;

    #[test]
    fn test_fuse_weighted_multi() {
        // 0.8 * 0.7 + 0.6 * 0.3 = 0.56 + 0.18
        let fused = fuse_weighted_multi(&[0.8, 0.6], &[0.7, 0.3]);
        assert!((fused - 0.74).abs() < EPSILON);
    }

    #[test]
    fn test_fuse_weighted_multi_equal() {
        let third: f32 = 1.0 / 3.0;
        let fused = fuse_weighted_multi(&[0.5; 3], &[third; 3]);
        assert!((fused - 0.5).abs() < EPSILON);
    }

    #[test]
    fn test_fuse_rrf_point_fallback() {
        let (fused, used_fallback) = fuse_rrf_point(&[0.8, 0.6]);
        assert!(used_fallback);
        assert!((fused - 0.7).abs() < EPSILON);
    }

    #[test]
    fn test_fuse_rrf_point_empty() {
        let (fused, used_fallback) = fuse_rrf_point(&[]);
        assert!(!used_fallback);
        assert!(fused.abs() < EPSILON);
    }
}