sqlite_graphrag/storage/
fusion.rs1use std::collections::HashMap;
19
/// Fuses several ranked id lists into one score per id using weighted
/// Reciprocal Rank Fusion.
///
/// Each entry of `lists` is a `(weight, ids)` pair. The position of an id
/// inside `ids` is treated as its zero-based rank, and that occurrence adds
/// `weight * (1.0 / (rrf_k + rank + 1.0))` to the id's score. Every occurrence
/// of an id — across lists or repeated within one list — is accumulated.
///
/// # Returns
///
/// A map from id to its summed score. Ids that occur in no list are absent;
/// ids that occur only in zero-weight lists are present with a score of `0.0`.
pub fn rrf_fuse(lists: &[(f64, &Vec<i64>)], rrf_k: f64) -> HashMap<i64, f64> {
    // The number of distinct ids can never exceed the total number of entries,
    // so sizing the map by that total avoids rehashing while filling it.
    let capacity_hint: usize = lists.iter().map(|(_, ids)| ids.len()).sum();
    let mut fused: HashMap<i64, f64> = HashMap::with_capacity(capacity_hint);

    lists
        .iter()
        .flat_map(|&(weight, ids)| {
            ids.iter().enumerate().map(move |(position, &doc_id)| {
                // `+ 1.0` turns the zero-based position into a one-based rank.
                let reciprocal = 1.0 / (rrf_k + position as f64 + 1.0);
                (doc_id, weight * reciprocal)
            })
        })
        .for_each(|(doc_id, share)| {
            *fused.entry(doc_id).or_default() += share;
        });

    fused
}
53
/// Returns the sum, over all `weights`, of the contribution a rank-0 entry
/// receives: `weight * (1.0 / (rrf_k + 1.0))`.
///
/// This equals the score of an id that sits at the first position of every
/// weighted list. NOTE(review): presumably used as a normalisation ceiling for
/// fused scores; that holds only for non-negative weights and ids appearing at
/// most once per list — confirm against callers.
pub fn rrf_max_possible(weights: &[f64], rrf_k: f64) -> f64 {
    // Identical for every list, so compute the top-rank reciprocal once.
    let top_rank_share = 1.0 / (rrf_k + 1.0);
    weights
        .iter()
        .map(|weight| weight * top_rank_share)
        .sum::<f64>()
}
69
#[cfg(test)]
mod tests {
    use super::*;

    // Earlier positions in a single list must yield strictly higher scores.
    #[test]
    fn rrf_fuse_single_list_rank_order_preserved() {
        let ranked = vec![10i64, 20, 30];
        let fused = rrf_fuse(&[(1.0, &ranked)], 60.0);
        let (first, second, third) = (fused[&10], fused[&20], fused[&30]);
        assert!(first > second);
        assert!(second > third);
    }

    // An id present in both lists collects a contribution from each of them.
    #[test]
    fn rrf_fuse_two_lists_overlap_accumulates() {
        let knn = vec![1i64, 2];
        let fts = vec![1i64, 3];
        let fused = rrf_fuse(&[(1.0, &knn), (1.0, &fts)], 60.0);
        let shared = fused[&1];
        for solo_id in [2i64, 3] {
            assert!(shared > fused[&solo_id], "overlap item must score higher");
        }
    }

    // A list without ids contributes no entries to the result.
    #[test]
    fn rrf_fuse_empty_lists_returns_empty() {
        let no_ids: Vec<i64> = Vec::new();
        assert!(rrf_fuse(&[(1.0, &no_ids)], 60.0).is_empty());
    }

    // Ids that only occur in a zero-weight list must not gain any score.
    #[test]
    fn rrf_fuse_zero_weight_list_has_no_effect() {
        let weighted = vec![1i64, 2];
        let muted = vec![3i64, 4];
        let fused = rrf_fuse(&[(1.0, &weighted), (0.0, &muted)], 60.0);
        for id in &muted {
            assert_eq!(fused.get(id).copied().unwrap_or(0.0), 0.0);
        }
    }

    // A larger weight on the same list must produce a larger score.
    #[test]
    fn rrf_fuse_weights_scale_contribution() {
        let single = vec![1i64];
        let score_for = |weight: f64| rrf_fuse(&[(weight, &single)], 60.0)[&1];
        let low = score_for(0.5);
        let high = score_for(2.0);
        assert!(high > low);
    }

    // The ceiling is the rank-0 contribution summed over every weight.
    #[test]
    fn rrf_max_possible_sums_weights() {
        let k = 60.0;
        let top_share = 1.0 / 61.0;

        let one_list = rrf_max_possible(&[1.0], k);
        assert!((one_list - top_share).abs() < 1e-9);

        let two_lists = rrf_max_possible(&[1.0, 1.0], k);
        assert!((two_lists - 2.0 / 61.0).abs() < 1e-9);
    }

    // Fusing the same input twice must give the same score for every id.
    #[test]
    fn rrf_fuse_deterministic_for_same_input() {
        let list_a = vec![1i64, 2, 3];
        let list_b = vec![2i64, 1, 4];
        let run = || rrf_fuse(&[(1.0, &list_a), (1.0, &list_b)], 60.0);
        let (first_run, second_run) = (run(), run());
        for id in 1i64..=4 {
            assert_eq!(
                first_run.get(&id).copied().unwrap_or(0.0),
                second_run.get(&id).copied().unwrap_or(0.0)
            );
        }
    }
}