Skip to main content

sqlite_graphrag/storage/
fusion.rs

1//! RRF (Reciprocal Rank Fusion) utilities shared between `hybrid-search` and
2//! `deep-research`.
3//!
4//! The formula used is the canonical RRF score:
5//!
6//! ```text
7//! score(d) = sum_over_lists { weight * 1 / (rrf_k + rank(d)) }
8//! ```
9//!
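//! where `rank(d)` is the 1-indexed position of `d` in each ordered list; a
//! list that does not contain `d` contributes nothing to its score. The map
//! returned by [`rrf_fuse`] contains un-normalised scores; callers that need a
//! `[0, 1]` range should divide by the theoretical maximum (see
//! [`rrf_max_possible`]):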
13//!
14//! ```text
15//! max_possible = sum_over_lists { weight * 1 / (rrf_k + 1) }
16//! ```
17
18use std::collections::HashMap;
19
/// Fuse multiple ranked lists of integer IDs via Reciprocal Rank Fusion.
///
/// Each element of `lists` is `(weight, ranked_ids)` where `ranked_ids` is
/// ordered best-first (index 0 = rank 1). `ranked_ids` may be any container
/// that can be viewed as `[i64]`: `&Vec<i64>`, `&[i64]`, `Vec<i64>`, arrays, ...
///
/// Each list contributes `weight * 1 / (rrf_k + rank)` to every ID it
/// contains. If an ID is repeated inside a single list, only its first (best)
/// position counts. Later repeats are skipped without renumbering the items
/// after them. With non-negative weights this keeps every score at or below
/// [`rrf_max_possible`] for the same weights and `rrf_k`, so dividing by that
/// value stays within `[0, 1]`.
///
/// `rrf_k` is expected to be non-negative (60 is the conventional choice).
/// When `rrf_k <= -1.0`, the rank-1 denominator is no longer positive.
///
/// Returns a `HashMap<id, combined_score>` using un-normalised RRF scores.
/// A higher score means higher relevance. IDs that appear in no list are
/// absent from the map. IDs that appear only in zero-weight lists map to `0.0`.
///
/// # Examples
///
/// ```
/// use sqlite_graphrag::storage::fusion::rrf_fuse;
///
/// // Two lists with equal weight. Item 1 is ranked first in both lists, so it
/// // accumulates more score than item 2 or item 3, each of which is ranked
/// // second in one list only.
/// let knn: Vec<i64> = vec![1, 2];
/// let fts: Vec<i64> = vec![1, 3];
/// let scores = rrf_fuse(&[(1.0, &knn), (1.0, &fts)], 60.0);
/// assert!(scores[&1] > scores[&2]);
/// assert!(scores[&1] > scores[&3]);
///
/// // A repeated ID only counts at its best position within a list.
/// let dup: Vec<i64> = vec![7, 7];
/// let scores = rrf_fuse(&[(1.0, &dup)], 60.0);
/// assert_eq!(scores[&7], 1.0 / 61.0);
/// ```
pub fn rrf_fuse<L: AsRef<[i64]>>(lists: &[(f64, L)], rrf_k: f64) -> HashMap<i64, f64> {
    let mut combined: HashMap<i64, f64> = HashMap::new();
    // Tracks the IDs already ranked in the current list. It is allocated once
    // and cleared for each list.
    let mut seen: std::collections::HashSet<i64> = std::collections::HashSet::new();
    for (weight, ids) in lists {
        let ids: &[i64] = ids.as_ref();
        seen.clear();
        for (rank, &id) in ids.iter().enumerate() {
            // Only the first (best) occurrence of an ID in a list counts.
            // Otherwise duplicates could push a score above `rrf_max_possible`.
            if !seen.insert(id) {
                continue;
            }
            // `rank` is 0-indexed here and the formula is 1-indexed, so add 1.
            // The `weight * (1.0 / d)` form matches `rrf_max_possible` exactly.
            let contribution = weight * (1.0 / (rrf_k + rank as f64 + 1.0));
            *combined.entry(id).or_default() += contribution;
        }
    }
    combined
}
52
/// Compute the theoretical maximum RRF score for a given set of weights and
/// `rrf_k`.
///
/// An item ranked first in every list receives `weight * 1 / (rrf_k + 1)` from
/// each of them, so the maximum is the sum of those terms over `weights`.
/// Suppose the scores returned by [`rrf_fuse`] are computed with the same
/// weights and `rrf_k`. Dividing them by this value then yields `1.0` for such
/// an item. An empty `weights` slice gives a maximum of zero.
///
/// # Examples
///
/// ```
/// use sqlite_graphrag::storage::fusion::{rrf_fuse, rrf_max_possible};
///
/// let knn: Vec<i64> = vec![1, 2];
/// let fts: Vec<i64> = vec![1, 3];
/// let scores = rrf_fuse(&[(1.0, &knn), (1.0, &fts)], 60.0);
///
/// let max = rrf_max_possible(&[1.0, 1.0], 60.0);
/// assert!(max > 0.0);
/// // Item 1 is ranked first in both lists, so it normalises to 1.0.
/// assert!((scores[&1] / max - 1.0).abs() < 1e-12);
/// assert!(scores[&2] / max < 1.0);
/// ```
pub fn rrf_max_possible(weights: &[f64], rrf_k: f64) -> f64 {
    // Reciprocal-rank term for rank 1; identical for every list.
    let top_rank_term = 1.0 / (rrf_k + 1.0);
    weights.iter().map(|&weight| weight * top_rank_term).sum()
}
68
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rrf_fuse_single_list_rank_order_preserved() {
        // Items at lower rank index get higher scores.
        let list = vec![10i64, 20, 30];
        let scores = rrf_fuse(&[(1.0, &list)], 60.0);
        assert!(scores[&10] > scores[&20]);
        assert!(scores[&20] > scores[&30]);
    }

    #[test]
    fn rrf_fuse_two_lists_overlap_accumulates() {
        // Item 1 is ranked first in both lists. It must beat items 2 and 3,
        // each of which is ranked second in a single list only.
        let knn = vec![1i64, 2];
        let fts = vec![1i64, 3];
        let scores = rrf_fuse(&[(1.0, &knn), (1.0, &fts)], 60.0);
        assert!(scores[&1] > scores[&2], "overlap item must score higher");
        assert!(scores[&1] > scores[&3], "overlap item must score higher");
    }

    #[test]
    fn rrf_fuse_matches_formula() {
        // score = weight * 1 / (rrf_k + rank), with rank 1-indexed.
        let list = vec![10i64, 20];
        let scores = rrf_fuse(&[(0.5, &list)], 60.0);
        assert_eq!(scores.len(), 2);
        assert!((scores[&10] - 0.5 / 61.0).abs() < 1e-12);
        assert!((scores[&20] - 0.5 / 62.0).abs() < 1e-12);
    }

    #[test]
    fn rrf_fuse_empty_lists_returns_empty() {
        let empty: Vec<i64> = vec![];
        let scores = rrf_fuse(&[(1.0, &empty)], 60.0);
        assert!(scores.is_empty());
    }

    #[test]
    fn rrf_fuse_zero_weight_list_has_no_effect() {
        let list_a = vec![1i64, 2];
        let list_b = vec![3i64, 4];
        let scores_with = rrf_fuse(&[(1.0, &list_a), (0.0, &list_b)], 60.0);
        // Items 3 and 4 should have score 0.0 (or not be present at all).
        assert_eq!(scores_with.get(&3).copied().unwrap_or(0.0), 0.0);
        assert_eq!(scores_with.get(&4).copied().unwrap_or(0.0), 0.0);
    }

    #[test]
    fn rrf_fuse_weights_scale_contribution() {
        // A higher weight means a higher score for the same rank.
        let list = vec![1i64];
        let low = rrf_fuse(&[(0.5, &list)], 60.0);
        let high = rrf_fuse(&[(2.0, &list)], 60.0);
        assert!(high[&1] > low[&1]);
    }

    #[test]
    fn rrf_max_possible_sums_weights() {
        // With rrf_k=60, the max for one list of weight 1.0 is 1/(60+1) ≈ 0.01639.
        let max = rrf_max_possible(&[1.0], 60.0);
        let expected = 1.0 / 61.0;
        assert!((max - expected).abs() < 1e-9);

        // Two equal-weight lists: the sum of both.
        let max2 = rrf_max_possible(&[1.0, 1.0], 60.0);
        assert!((max2 - 2.0 / 61.0).abs() < 1e-9);
    }

    #[test]
    fn rrf_fuse_top_ranked_everywhere_reaches_max_possible() {
        // An item ranked first in every list scores exactly the theoretical
        // maximum, and no other item exceeds it.
        let knn = vec![1i64, 2];
        let fts = vec![1i64, 3];
        let scores = rrf_fuse(&[(1.0, &knn), (2.0, &fts)], 60.0);
        let max = rrf_max_possible(&[1.0, 2.0], 60.0);
        assert!((scores[&1] - max).abs() < 1e-12);
        assert!(scores.values().all(|&s| s <= max + 1e-12));
    }

    #[test]
    fn rrf_fuse_deterministic_for_same_input() {
        let list_a = vec![1i64, 2, 3];
        let list_b = vec![2i64, 1, 4];
        let scores_1 = rrf_fuse(&[(1.0, &list_a), (1.0, &list_b)], 60.0);
        let scores_2 = rrf_fuse(&[(1.0, &list_a), (1.0, &list_b)], 60.0);
        for id in [1i64, 2, 3, 4] {
            assert_eq!(
                scores_1.get(&id).copied().unwrap_or(0.0),
                scores_2.get(&id).copied().unwrap_or(0.0)
            );
        }
    }
}