Skip to main content

seekr_code/search/
fusion.rs

//! Multi-source result fusion.
//!
//! Implements Reciprocal Rank Fusion (RRF) to combine results from
//! text search, semantic search, and AST pattern search into a unified
//! ranked list.
//!
//! Formula: `score = sum(1 / (k + rank_i))`, where `rank_i` is the 1-based
//! rank of a chunk in result list `i` and `k` defaults to 60.
7
8use std::collections::HashMap;
9
10use crate::index::SearchHit;
11use crate::search::ast_pattern::AstMatch;
12use crate::search::text::TextMatch;
13
/// One chunk's entry in a fused result list.
///
/// Holds the score the list is ranked by together with the raw score of
/// every search source that returned the chunk. A source that did not
/// return the chunk is recorded as `None`.
#[derive(Clone, Debug)]
pub struct FusedResult {
    /// Identifier of the chunk this entry describes.
    pub chunk_id: u64,

    /// Score used for ranking: the RRF sum when produced by `rrf_fuse` /
    /// `rrf_fuse_three`, or the source's own score when produced by one of
    /// the single-source `fuse_*_only` helpers.
    pub fused_score: f32,

    /// Raw text-search score, if text search returned this chunk.
    pub text_score: Option<f32>,

    /// Raw semantic-search score, if semantic search returned this chunk.
    pub semantic_score: Option<f32>,

    /// Raw AST-pattern-search score, if AST search returned this chunk.
    pub ast_score: Option<f32>,

    /// Line numbers matched by text search, carried along for display.
    /// Empty when the chunk was not returned by text search.
    pub matched_lines: Vec<usize>,
}
35
36/// Perform Reciprocal Rank Fusion (RRF) on multiple result lists.
37///
38/// Combines text search and semantic search results into a single
39/// ranked list using the RRF formula: `score = sum(1 / (k + rank))`.
40///
41/// # Arguments
42/// * `text_results` - Results from text regex search (in rank order).
43/// * `semantic_results` - Results from semantic vector search (in rank order).
44/// * `k` - RRF parameter controlling rank discount (default: 60).
45/// * `top_k` - Maximum number of fused results to return.
46pub fn rrf_fuse(
47    text_results: &[TextMatch],
48    semantic_results: &[SearchHit],
49    k: u32,
50    top_k: usize,
51) -> Vec<FusedResult> {
52    let mut scores: HashMap<u64, FusedResult> = HashMap::new();
53
54    // Process text search results
55    for (rank, result) in text_results.iter().enumerate() {
56        let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
57
58        scores
59            .entry(result.chunk_id)
60            .and_modify(|e| {
61                e.fused_score += rrf_score;
62                e.text_score = Some(result.score);
63                e.matched_lines = result.matched_lines.clone();
64            })
65            .or_insert(FusedResult {
66                chunk_id: result.chunk_id,
67                fused_score: rrf_score,
68                text_score: Some(result.score),
69                semantic_score: None,
70                ast_score: None,
71                matched_lines: result.matched_lines.clone(),
72            });
73    }
74
75    // Process semantic search results
76    for (rank, result) in semantic_results.iter().enumerate() {
77        let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
78
79        scores
80            .entry(result.chunk_id)
81            .and_modify(|e| {
82                e.fused_score += rrf_score;
83                e.semantic_score = Some(result.score);
84            })
85            .or_insert(FusedResult {
86                chunk_id: result.chunk_id,
87                fused_score: rrf_score,
88                text_score: None,
89                semantic_score: Some(result.score),
90                ast_score: None,
91                matched_lines: Vec::new(),
92            });
93    }
94
95    // Sort by fused score descending
96    let mut fused: Vec<FusedResult> = scores.into_values().collect();
97    fused.sort_by(|a, b| {
98        b.fused_score
99            .partial_cmp(&a.fused_score)
100            .unwrap_or(std::cmp::Ordering::Equal)
101    });
102
103    // Truncate to top-k
104    fused.truncate(top_k);
105
106    fused
107}
108
109/// Perform three-way RRF fusion across text, semantic, and AST results.
110///
111/// Extends the basic two-way RRF to include AST pattern search results.
112pub fn rrf_fuse_three(
113    text_results: &[TextMatch],
114    semantic_results: &[SearchHit],
115    ast_results: &[AstMatch],
116    k: u32,
117    top_k: usize,
118) -> Vec<FusedResult> {
119    let mut scores: HashMap<u64, FusedResult> = HashMap::new();
120
121    // Process text search results
122    for (rank, result) in text_results.iter().enumerate() {
123        let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
124
125        scores
126            .entry(result.chunk_id)
127            .and_modify(|e| {
128                e.fused_score += rrf_score;
129                e.text_score = Some(result.score);
130                e.matched_lines = result.matched_lines.clone();
131            })
132            .or_insert(FusedResult {
133                chunk_id: result.chunk_id,
134                fused_score: rrf_score,
135                text_score: Some(result.score),
136                semantic_score: None,
137                ast_score: None,
138                matched_lines: result.matched_lines.clone(),
139            });
140    }
141
142    // Process semantic search results
143    for (rank, result) in semantic_results.iter().enumerate() {
144        let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
145
146        scores
147            .entry(result.chunk_id)
148            .and_modify(|e| {
149                e.fused_score += rrf_score;
150                e.semantic_score = Some(result.score);
151            })
152            .or_insert(FusedResult {
153                chunk_id: result.chunk_id,
154                fused_score: rrf_score,
155                text_score: None,
156                semantic_score: Some(result.score),
157                ast_score: None,
158                matched_lines: Vec::new(),
159            });
160    }
161
162    // Process AST pattern search results
163    for (rank, result) in ast_results.iter().enumerate() {
164        let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
165
166        scores
167            .entry(result.chunk_id)
168            .and_modify(|e| {
169                e.fused_score += rrf_score;
170                e.ast_score = Some(result.score);
171            })
172            .or_insert(FusedResult {
173                chunk_id: result.chunk_id,
174                fused_score: rrf_score,
175                text_score: None,
176                semantic_score: None,
177                ast_score: Some(result.score),
178                matched_lines: Vec::new(),
179            });
180    }
181
182    // Sort by fused score descending
183    let mut fused: Vec<FusedResult> = scores.into_values().collect();
184    fused.sort_by(|a, b| {
185        b.fused_score
186            .partial_cmp(&a.fused_score)
187            .unwrap_or(std::cmp::Ordering::Equal)
188    });
189
190    fused.truncate(top_k);
191    fused
192}
193
194/// Fuse only semantic results (used when text search is not applicable).
195pub fn fuse_semantic_only(semantic_results: &[SearchHit], top_k: usize) -> Vec<FusedResult> {
196    semantic_results
197        .iter()
198        .take(top_k)
199        .map(|r| FusedResult {
200            chunk_id: r.chunk_id,
201            fused_score: r.score,
202            text_score: None,
203            semantic_score: Some(r.score),
204            ast_score: None,
205            matched_lines: Vec::new(),
206        })
207        .collect()
208}
209
210/// Fuse only text results (used when semantic search is not applicable).
211pub fn fuse_text_only(text_results: &[TextMatch], top_k: usize) -> Vec<FusedResult> {
212    text_results
213        .iter()
214        .take(top_k)
215        .map(|r| FusedResult {
216            chunk_id: r.chunk_id,
217            fused_score: r.score,
218            text_score: Some(r.score),
219            semantic_score: None,
220            ast_score: None,
221            matched_lines: r.matched_lines.clone(),
222        })
223        .collect()
224}
225
226/// Fuse only AST pattern search results.
227pub fn fuse_ast_only(ast_results: &[AstMatch], top_k: usize) -> Vec<FusedResult> {
228    ast_results
229        .iter()
230        .take(top_k)
231        .map(|r| FusedResult {
232            chunk_id: r.chunk_id,
233            fused_score: r.score,
234            text_score: None,
235            semantic_score: None,
236            ast_score: Some(r.score),
237            matched_lines: Vec::new(),
238        })
239        .collect()
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Fixture builders -------------------------------------------------

    /// Text matches for `chunk_ids`, in the given order. Scores count down
    /// from the list length so earlier entries score higher; every match
    /// reports line 0.
    fn make_text_matches(chunk_ids: &[u64]) -> Vec<TextMatch> {
        let total = chunk_ids.len();
        let mut matches = Vec::with_capacity(total);
        for (position, &chunk_id) in chunk_ids.iter().enumerate() {
            matches.push(TextMatch {
                chunk_id,
                matched_lines: vec![0],
                score: (total - position) as f32,
            });
        }
        matches
    }

    /// Semantic hits for `chunk_ids`, in the given order. Scores start at
    /// 1.0 and drop by 0.1 per position.
    fn make_semantic_hits(chunk_ids: &[u64]) -> Vec<SearchHit> {
        let mut hits = Vec::with_capacity(chunk_ids.len());
        for (position, &chunk_id) in chunk_ids.iter().enumerate() {
            hits.push(SearchHit {
                chunk_id,
                score: 1.0 - (position as f32 * 0.1),
            });
        }
        hits
    }

    /// AST matches for `chunk_ids`, in the given order. Scores start at 1.0
    /// and drop by 0.1 per position.
    fn make_ast_matches(chunk_ids: &[u64]) -> Vec<AstMatch> {
        let mut matches = Vec::with_capacity(chunk_ids.len());
        for (position, &chunk_id) in chunk_ids.iter().enumerate() {
            matches.push(AstMatch {
                chunk_id,
                score: 1.0 - (position as f32 * 0.1),
            });
        }
        matches
    }

    /// Looks up the fused entry for `chunk_id`, panicking if it is absent.
    fn entry_for(results: &[FusedResult], chunk_id: u64) -> &FusedResult {
        results.iter().find(|r| r.chunk_id == chunk_id).unwrap()
    }

    // ---- Two-way RRF ------------------------------------------------------

    #[test]
    fn test_rrf_basic_fusion() {
        // Text ranks chunks 1, 2, 3; semantic ranks chunks 2, 3, 4.
        let text_hits = make_text_matches(&[1, 2, 3]);
        let semantic_hits = make_semantic_hits(&[2, 3, 4]);

        let results = rrf_fuse(&text_hits, &semantic_hits, 60, 10);
        assert!(!results.is_empty());

        // Chunk 2 is returned by both sources, chunk 1 by text only.
        let in_both = entry_for(&results, 2);
        let text_only = entry_for(&results, 1);

        assert!(
            in_both.fused_score > text_only.fused_score,
            "Chunk appearing in both lists should rank higher"
        );
    }

    #[test]
    fn test_rrf_preserves_all_unique_results() {
        // Disjoint inputs: nothing should be merged or dropped.
        let results = rrf_fuse(
            &make_text_matches(&[1, 2]),
            &make_semantic_hits(&[3, 4]),
            60,
            10,
        );
        assert_eq!(results.len(), 4, "All unique chunks should be in results");
    }

    #[test]
    fn test_rrf_top_k_truncation() {
        // Ten distinct chunks in, only three allowed out.
        let results = rrf_fuse(
            &make_text_matches(&[1, 2, 3, 4, 5]),
            &make_semantic_hits(&[6, 7, 8, 9, 10]),
            60,
            3,
        );
        assert_eq!(results.len(), 3, "Should respect top-k");
    }

    #[test]
    fn test_rrf_empty_inputs() {
        let results = rrf_fuse(&[], &[], 60, 10);
        assert!(results.is_empty());
    }

    // ---- Single-source helpers --------------------------------------------

    #[test]
    fn test_fuse_semantic_only() {
        let results = fuse_semantic_only(&make_semantic_hits(&[1, 2, 3]), 2);
        assert_eq!(results.len(), 2);

        let first = &results[0];
        assert!(first.text_score.is_none());
        assert!(first.semantic_score.is_some());
    }

    #[test]
    fn test_fuse_text_only() {
        let results = fuse_text_only(&make_text_matches(&[1, 2, 3]), 2);
        assert_eq!(results.len(), 2);

        let first = &results[0];
        assert!(first.text_score.is_some());
        assert!(first.semantic_score.is_none());
        assert!(first.ast_score.is_none());
    }

    #[test]
    fn test_fuse_ast_only() {
        let results = fuse_ast_only(&make_ast_matches(&[1, 2, 3]), 2);
        assert_eq!(results.len(), 2);

        let first = &results[0];
        assert!(first.text_score.is_none());
        assert!(first.semantic_score.is_none());
        assert!(first.ast_score.is_some());
    }

    // ---- Three-way RRF ----------------------------------------------------

    #[test]
    fn test_rrf_three_way_fusion() {
        let text_hits = make_text_matches(&[1, 2]);
        let semantic_hits = make_semantic_hits(&[2, 3]);
        let ast_hits = make_ast_matches(&[3, 4]);

        let results = rrf_fuse_three(&text_hits, &semantic_hits, &ast_hits, 60, 10);

        // Chunks 1, 2, 3 and 4 must each appear exactly once.
        assert_eq!(results.len(), 4);

        // Chunk 2 is in text + semantic, chunk 3 in semantic + AST;
        // chunks 1 and 4 come from a single source each.
        let chunk_1 = entry_for(&results, 1);
        let chunk_2 = entry_for(&results, 2);
        let chunk_3 = entry_for(&results, 3);
        let chunk_4 = entry_for(&results, 4);

        // Two contributing sources must outrank one.
        assert!(chunk_2.fused_score > chunk_1.fused_score);
        assert!(chunk_3.fused_score > chunk_4.fused_score);
    }
}