skill_runtime/search/
fusion.rs

1//! Fusion algorithms for combining search results
2//!
//! Implements Reciprocal Rank Fusion (RRF), weighted sum fusion, and max-score
//! fusion for combining results from multiple retrieval systems.
5
6use std::collections::HashMap;
7
/// A single document produced by one of the fusion functions in this module,
/// carrying its ID, its fused score, and the per-source scores it was built from.
#[derive(Debug, Clone, PartialEq)]
pub struct FusedResult {
    /// Document ID, exactly as supplied in the input rankings.
    pub id: String,
    /// Combined score after fusion. Its scale depends on the fusion method
    /// (a sum of reciprocal ranks, a weighted sum of normalized scores, or a
    /// raw maximum), so scores produced by different methods are not comparable.
    pub score: f32,
    /// Original (un-normalized) score reported by each source, keyed by source
    /// name. Kept for debugging/analysis.
    pub source_scores: HashMap<String, f32>,
}
18
/// Strategy used to combine ranked results coming from several retrieval sources.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum FusionMethod {
    /// Reciprocal Rank Fusion (RRF), the default. Uses only rank positions;
    /// see [`reciprocal_rank_fusion`].
    #[default]
    ReciprocalRank,
    /// Weighted sum of per-source min-max normalized scores;
    /// see [`weighted_sum_fusion`].
    WeightedSum,
    /// Maximum raw score given by any source; see [`max_score_fusion`].
    MaxScore,
}
30
31/// Reciprocal Rank Fusion (RRF)
32///
33/// Combines ranked lists using the formula:
34/// score(d) = Σ(1 / (k + rank_i))
35///
36/// where k is typically 60 (per original paper)
37///
38/// # Arguments
39/// * `ranked_lists` - List of (source_name, rankings) where rankings are (id, original_score)
40/// * `k` - RRF constant, default 60
41/// * `top_k` - Number of results to return
42///
43/// # Returns
44/// Fused results sorted by combined score descending
45pub fn reciprocal_rank_fusion(
46    ranked_lists: Vec<(&str, Vec<(String, f32)>)>,
47    k: f32,
48    top_k: usize,
49) -> Vec<FusedResult> {
50    let mut scores: HashMap<String, (f32, HashMap<String, f32>)> = HashMap::new();
51
52    for (source_name, rankings) in ranked_lists {
53        for (rank, (id, original_score)) in rankings.into_iter().enumerate() {
54            let rrf_score = 1.0 / (k + (rank + 1) as f32);
55
56            let entry = scores.entry(id).or_insert_with(|| (0.0, HashMap::new()));
57            entry.0 += rrf_score;
58            entry.1.insert(source_name.to_string(), original_score);
59        }
60    }
61
62    let mut results: Vec<FusedResult> = scores
63        .into_iter()
64        .map(|(id, (score, source_scores))| FusedResult {
65            id,
66            score,
67            source_scores,
68        })
69        .collect();
70
71    // Sort by score descending
72    results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
73
74    results.truncate(top_k);
75    results
76}
77
78/// Weighted sum fusion
79///
80/// Normalizes scores from each source to [0,1] and combines with weights
81///
82/// # Arguments
83/// * `weighted_lists` - List of (source_name, weight, rankings) where rankings are (id, original_score)
84/// * `top_k` - Number of results to return
85///
86/// # Returns
87/// Fused results sorted by combined score descending
88pub fn weighted_sum_fusion(
89    weighted_lists: Vec<(&str, f32, Vec<(String, f32)>)>,
90    top_k: usize,
91) -> Vec<FusedResult> {
92    let mut scores: HashMap<String, (f32, HashMap<String, f32>)> = HashMap::new();
93
94    for (source_name, weight, rankings) in weighted_lists {
95        // Find min/max for normalization
96        let (min_score, max_score) = rankings.iter().fold((f32::MAX, f32::MIN), |(min, max), (_, s)| {
97            (min.min(*s), max.max(*s))
98        });
99
100        let range = max_score - min_score;
101
102        for (id, original_score) in rankings {
103            // Normalize to [0, 1]
104            let normalized = if range > 0.0 {
105                (original_score - min_score) / range
106            } else {
107                1.0 // All scores are the same
108            };
109
110            let weighted_score = normalized * weight;
111
112            let entry = scores.entry(id).or_insert_with(|| (0.0, HashMap::new()));
113            entry.0 += weighted_score;
114            entry.1.insert(source_name.to_string(), original_score);
115        }
116    }
117
118    let mut results: Vec<FusedResult> = scores
119        .into_iter()
120        .map(|(id, (score, source_scores))| FusedResult {
121            id,
122            score,
123            source_scores,
124        })
125        .collect();
126
127    results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
128
129    results.truncate(top_k);
130    results
131}
132
133/// Max score fusion
134///
135/// Takes the maximum score from any source for each document
136pub fn max_score_fusion(
137    ranked_lists: Vec<(&str, Vec<(String, f32)>)>,
138    top_k: usize,
139) -> Vec<FusedResult> {
140    let mut scores: HashMap<String, (f32, HashMap<String, f32>)> = HashMap::new();
141
142    for (source_name, rankings) in ranked_lists {
143        for (id, original_score) in rankings {
144            let entry = scores.entry(id).or_insert_with(|| (f32::MIN, HashMap::new()));
145            if original_score > entry.0 {
146                entry.0 = original_score;
147            }
148            entry.1.insert(source_name.to_string(), original_score);
149        }
150    }
151
152    let mut results: Vec<FusedResult> = scores
153        .into_iter()
154        .map(|(id, (score, source_scores))| FusedResult {
155            id,
156            score,
157            source_scores,
158        })
159        .collect();
160
161    results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
162
163    results.truncate(top_k);
164    results
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// IDs of `results` in output order.
    fn ids_of(results: &[FusedResult]) -> Vec<&str> {
        results.iter().map(|r| r.id.as_str()).collect()
    }

    #[test]
    fn test_rrf_fusion_basic() {
        let dense_results = vec![
            ("doc1".to_string(), 0.95),
            ("doc2".to_string(), 0.85),
            ("doc3".to_string(), 0.75),
        ];

        let sparse_results = vec![
            ("doc2".to_string(), 10.5), // doc2 ranks higher in BM25
            ("doc1".to_string(), 8.3),
            ("doc4".to_string(), 7.1),
        ];

        let results = reciprocal_rank_fusion(
            vec![("dense", dense_results), ("sparse", sparse_results)],
            60.0,
            5,
        );

        // Four distinct documents, fewer than top_k.
        assert_eq!(results.len(), 4);

        // doc1 and doc2 appear in both lists, so they must outrank doc3 and
        // doc4, which each appear in only one list.
        let ids = ids_of(&results);
        let mut top_two = ids[..2].to_vec();
        top_two.sort_unstable();
        assert_eq!(top_two, vec!["doc1", "doc2"]);
        let mut bottom_two = ids[2..].to_vec();
        bottom_two.sort_unstable();
        assert_eq!(bottom_two, vec!["doc3", "doc4"]);
    }

    #[test]
    fn test_rrf_k_parameter() {
        let list = vec![("doc1".to_string(), 1.0), ("doc2".to_string(), 0.9)];

        // With k=60, rank 1 gets score 1/(60+1) ≈ 0.0164
        let results = reciprocal_rank_fusion(vec![("test", list)], 60.0, 5);
        assert!((results[0].score - 1.0 / 61.0).abs() < 0.001);
        assert!((results[1].score - 1.0 / 62.0).abs() < 0.001);
    }

    #[test]
    fn test_rrf_top_k_truncates() {
        let list = vec![
            ("doc1".to_string(), 3.0),
            ("doc2".to_string(), 2.0),
            ("doc3".to_string(), 1.0),
        ];

        let results = reciprocal_rank_fusion(vec![("s", list)], 60.0, 2);
        assert_eq!(ids_of(&results), vec!["doc1", "doc2"]);
    }

    #[test]
    fn test_weighted_sum_fusion() {
        let dense_results = vec![("doc1".to_string(), 0.9), ("doc2".to_string(), 0.7)];

        let sparse_results = vec![
            ("doc1".to_string(), 5.0),
            ("doc2".to_string(), 10.0), // Higher sparse score
        ];

        let results = weighted_sum_fusion(
            vec![("dense", 0.7, dense_results), ("sparse", 0.3, sparse_results)],
            5,
        );

        assert_eq!(ids_of(&results), vec!["doc1", "doc2"]);
        // doc1: 1.0 * 0.7 (dense max) + 0.0 * 0.3 (sparse min)
        assert!((results[0].score - 0.7).abs() < 1e-6);
        // doc2: 0.0 * 0.7 (dense min) + 1.0 * 0.3 (sparse max)
        assert!((results[1].score - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_weighted_sum_constant_scores_get_full_weight() {
        let results =
            weighted_sum_fusion(vec![("only", 0.5, vec![("doc1".to_string(), 3.0)])], 5);

        assert_eq!(results.len(), 1);
        assert!((results[0].score - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_fusion_with_empty_list() {
        assert!(reciprocal_rank_fusion(vec![], 60.0, 5).is_empty());
        assert!(weighted_sum_fusion(vec![], 5).is_empty());
        assert!(max_score_fusion(vec![], 5).is_empty());
        // A source that contributes no documents.
        assert!(weighted_sum_fusion(vec![("empty", 1.0, Vec::new())], 5).is_empty());
    }

    #[test]
    fn test_fusion_source_scores_preserved() {
        let dense_results = vec![("doc1".to_string(), 0.95)];
        let sparse_results = vec![("doc1".to_string(), 8.5)];

        let results = reciprocal_rank_fusion(
            vec![("dense", dense_results), ("sparse", sparse_results)],
            60.0,
            5,
        );

        assert_eq!(results[0].id, "doc1");
        assert_eq!(results[0].source_scores.get("dense"), Some(&0.95));
        assert_eq!(results[0].source_scores.get("sparse"), Some(&8.5));
    }

    #[test]
    fn test_max_score_fusion() {
        let list1 = vec![("doc1".to_string(), 0.5), ("doc2".to_string(), 0.8)];

        let list2 = vec![
            ("doc1".to_string(), 0.9), // Higher
            ("doc2".to_string(), 0.3),
        ];

        let results = max_score_fusion(vec![("a", list1), ("b", list2)], 5);

        // doc1 should have score 0.9 (max), doc2 should have 0.8
        let doc1 = results.iter().find(|r| r.id == "doc1").unwrap();
        let doc2 = results.iter().find(|r| r.id == "doc2").unwrap();
        assert!((doc1.score - 0.9).abs() < 0.001);
        assert!((doc2.score - 0.8).abs() < 0.001);

        // Every source's raw score is kept for analysis.
        assert_eq!(doc1.source_scores.get("a"), Some(&0.5));
        assert_eq!(doc1.source_scores.get("b"), Some(&0.9));
    }
}