Skip to main content

seekr_code/search/
fusion.rs

1//! Multi-source result fusion.
2//!
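//! Implements Reciprocal Rank Fusion (RRF) to combine results from
//! text search, semantic search, and AST pattern search into a unified
//! ranked list.
//!
//! Formula: `score = sum(1 / (k + rank_i))`, where `rank_i` is the chunk's
//! 1-based rank in result list `i` and `k` is the rank-discount constant
//! (default: 60).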
7
8use std::collections::HashMap;
9
10use crate::index::SearchHit;
11use crate::search::ast_pattern::AstMatch;
12use crate::search::text::TextMatch;
13
/// Fused search result combining scores from multiple sources.
///
/// One value describes one chunk. Each per-source score is `None` when that
/// source did not report the chunk. `PartialEq` is derived so results can be
/// compared directly (for example in tests); `Eq` is not available because
/// the scores are floating point.
#[derive(Debug, Clone, PartialEq)]
pub struct FusedResult {
    /// The chunk ID.
    pub chunk_id: u64,

    /// The fused score used for ranking.
    pub fused_score: f32,

    /// Original score from text search (if any).
    pub text_score: Option<f32>,

    /// Original score from semantic search (if any).
    pub semantic_score: Option<f32>,

    /// Original score from AST pattern search (if any).
    pub ast_score: Option<f32>,

    /// Matched line numbers from text search (propagated for display).
    /// Empty when the chunk was not reported by text search.
    pub matched_lines: Vec<usize>,
}
35
36/// Perform Reciprocal Rank Fusion (RRF) on multiple result lists.
37///
38/// Combines text search and semantic search results into a single
39/// ranked list using the RRF formula: `score = sum(1 / (k + rank))`.
40///
41/// # Arguments
42/// * `text_results` - Results from text regex search (in rank order).
43/// * `semantic_results` - Results from semantic vector search (in rank order).
44/// * `k` - RRF parameter controlling rank discount (default: 60).
45/// * `top_k` - Maximum number of fused results to return.
46pub fn rrf_fuse(
47    text_results: &[TextMatch],
48    semantic_results: &[SearchHit],
49    k: u32,
50    top_k: usize,
51) -> Vec<FusedResult> {
52    let mut scores: HashMap<u64, FusedResult> = HashMap::new();
53
54    // Process text search results
55    for (rank, result) in text_results.iter().enumerate() {
56        let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
57
58        scores
59            .entry(result.chunk_id)
60            .and_modify(|e| {
61                e.fused_score += rrf_score;
62                e.text_score = Some(result.score);
63                e.matched_lines = result.matched_lines.clone();
64            })
65            .or_insert(FusedResult {
66                chunk_id: result.chunk_id,
67                fused_score: rrf_score,
68                text_score: Some(result.score),
69                semantic_score: None,
70                ast_score: None,
71                matched_lines: result.matched_lines.clone(),
72            });
73    }
74
75    // Process semantic search results
76    for (rank, result) in semantic_results.iter().enumerate() {
77        let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
78
79        scores
80            .entry(result.chunk_id)
81            .and_modify(|e| {
82                e.fused_score += rrf_score;
83                e.semantic_score = Some(result.score);
84            })
85            .or_insert(FusedResult {
86                chunk_id: result.chunk_id,
87                fused_score: rrf_score,
88                text_score: None,
89                semantic_score: Some(result.score),
90                ast_score: None,
91                matched_lines: Vec::new(),
92            });
93    }
94
95    // Sort by fused score descending
96    let mut fused: Vec<FusedResult> = scores.into_values().collect();
97    fused.sort_by(|a, b| {
98        b.fused_score
99            .partial_cmp(&a.fused_score)
100            .unwrap_or(std::cmp::Ordering::Equal)
101    });
102
103    // Truncate to top-k
104    fused.truncate(top_k);
105
106    fused
107}
108
109/// Perform three-way RRF fusion across text, semantic, and AST results.
110///
111/// Extends the basic two-way RRF to include AST pattern search results.
112pub fn rrf_fuse_three(
113    text_results: &[TextMatch],
114    semantic_results: &[SearchHit],
115    ast_results: &[AstMatch],
116    k: u32,
117    top_k: usize,
118) -> Vec<FusedResult> {
119    let mut scores: HashMap<u64, FusedResult> = HashMap::new();
120
121    // Process text search results
122    for (rank, result) in text_results.iter().enumerate() {
123        let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
124
125        scores
126            .entry(result.chunk_id)
127            .and_modify(|e| {
128                e.fused_score += rrf_score;
129                e.text_score = Some(result.score);
130                e.matched_lines = result.matched_lines.clone();
131            })
132            .or_insert(FusedResult {
133                chunk_id: result.chunk_id,
134                fused_score: rrf_score,
135                text_score: Some(result.score),
136                semantic_score: None,
137                ast_score: None,
138                matched_lines: result.matched_lines.clone(),
139            });
140    }
141
142    // Process semantic search results
143    for (rank, result) in semantic_results.iter().enumerate() {
144        let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
145
146        scores
147            .entry(result.chunk_id)
148            .and_modify(|e| {
149                e.fused_score += rrf_score;
150                e.semantic_score = Some(result.score);
151            })
152            .or_insert(FusedResult {
153                chunk_id: result.chunk_id,
154                fused_score: rrf_score,
155                text_score: None,
156                semantic_score: Some(result.score),
157                ast_score: None,
158                matched_lines: Vec::new(),
159            });
160    }
161
162    // Process AST pattern search results
163    for (rank, result) in ast_results.iter().enumerate() {
164        let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
165
166        scores
167            .entry(result.chunk_id)
168            .and_modify(|e| {
169                e.fused_score += rrf_score;
170                e.ast_score = Some(result.score);
171            })
172            .or_insert(FusedResult {
173                chunk_id: result.chunk_id,
174                fused_score: rrf_score,
175                text_score: None,
176                semantic_score: None,
177                ast_score: Some(result.score),
178                matched_lines: Vec::new(),
179            });
180    }
181
182    // Sort by fused score descending
183    let mut fused: Vec<FusedResult> = scores.into_values().collect();
184    fused.sort_by(|a, b| {
185        b.fused_score
186            .partial_cmp(&a.fused_score)
187            .unwrap_or(std::cmp::Ordering::Equal)
188    });
189
190    fused.truncate(top_k);
191    fused
192}
193
194/// Fuse only semantic results (used when text search is not applicable).
195pub fn fuse_semantic_only(
196    semantic_results: &[SearchHit],
197    top_k: usize,
198) -> Vec<FusedResult> {
199    semantic_results
200        .iter()
201        .take(top_k)
202        .map(|r| FusedResult {
203            chunk_id: r.chunk_id,
204            fused_score: r.score,
205            text_score: None,
206            semantic_score: Some(r.score),
207            ast_score: None,
208            matched_lines: Vec::new(),
209        })
210        .collect()
211}
212
213/// Fuse only text results (used when semantic search is not applicable).
214pub fn fuse_text_only(
215    text_results: &[TextMatch],
216    top_k: usize,
217) -> Vec<FusedResult> {
218    text_results
219        .iter()
220        .take(top_k)
221        .map(|r| FusedResult {
222            chunk_id: r.chunk_id,
223            fused_score: r.score,
224            text_score: Some(r.score),
225            semantic_score: None,
226            ast_score: None,
227            matched_lines: r.matched_lines.clone(),
228        })
229        .collect()
230}
231
232/// Fuse only AST pattern search results.
233pub fn fuse_ast_only(
234    ast_results: &[AstMatch],
235    top_k: usize,
236) -> Vec<FusedResult> {
237    ast_results
238        .iter()
239        .take(top_k)
240        .map(|r| FusedResult {
241            chunk_id: r.chunk_id,
242            fused_score: r.score,
243            text_score: None,
244            semantic_score: None,
245            ast_score: Some(r.score),
246            matched_lines: Vec::new(),
247        })
248        .collect()
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Build text matches for the given chunk IDs, best match first.
    /// Scores count down from `chunk_ids.len()`; every match reports line 0.
    fn make_text_matches(chunk_ids: &[u64]) -> Vec<TextMatch> {
        chunk_ids
            .iter()
            .enumerate()
            .map(|(i, &id)| TextMatch {
                chunk_id: id,
                matched_lines: vec![0],
                score: (chunk_ids.len() - i) as f32,
            })
            .collect()
    }

    /// Build semantic hits for the given chunk IDs, best match first.
    /// Scores start at 1.0 and drop by 0.1 per position.
    fn make_semantic_hits(chunk_ids: &[u64]) -> Vec<SearchHit> {
        chunk_ids
            .iter()
            .enumerate()
            .map(|(i, &id)| SearchHit {
                chunk_id: id,
                score: 1.0 - (i as f32 * 0.1),
            })
            .collect()
    }

    /// Build AST matches for the given chunk IDs, best match first.
    /// Scores start at 1.0 and drop by 0.1 per position.
    fn make_ast_matches(chunk_ids: &[u64]) -> Vec<AstMatch> {
        chunk_ids
            .iter()
            .enumerate()
            .map(|(i, &id)| AstMatch {
                chunk_id: id,
                score: 1.0 - (i as f32 * 0.1),
            })
            .collect()
    }

    /// True when the fused scores never increase from one result to the next.
    fn is_sorted_descending(fused: &[FusedResult]) -> bool {
        fused
            .windows(2)
            .all(|pair| pair[0].fused_score >= pair[1].fused_score)
    }

    #[test]
    fn test_rrf_basic_fusion() {
        // Text: [1, 2, 3]
        // Semantic: [2, 3, 4]
        let text = make_text_matches(&[1, 2, 3]);
        let semantic = make_semantic_hits(&[2, 3, 4]);

        let fused = rrf_fuse(&text, &semantic, 60, 10);

        assert!(!fused.is_empty());
        assert!(is_sorted_descending(&fused), "Results must be sorted by fused score");

        // Chunk 2 appears in both lists at the best combined ranks, so it
        // must come first (no other chunk ties with it for these inputs).
        assert_eq!(fused[0].chunk_id, 2, "Chunk 2 should have the highest fused score");

        let chunk_2 = fused.iter().find(|r| r.chunk_id == 2).unwrap();
        let chunk_1 = fused.iter().find(|r| r.chunk_id == 1).unwrap();

        // Chunk 2 is in both, chunk 1 is in text only
        assert!(
            chunk_2.fused_score > chunk_1.fused_score,
            "Chunk appearing in both lists should rank higher"
        );

        // Per-source details are carried through to the fused entries.
        assert!(chunk_2.text_score.is_some());
        assert!(chunk_2.semantic_score.is_some());
        assert!(chunk_2.ast_score.is_none());
        assert_eq!(chunk_2.matched_lines, vec![0]);
        assert!(chunk_1.text_score.is_some());
        assert!(chunk_1.semantic_score.is_none());
    }

    #[test]
    fn test_rrf_preserves_all_unique_results() {
        let text = make_text_matches(&[1, 2]);
        let semantic = make_semantic_hits(&[3, 4]);

        let fused = rrf_fuse(&text, &semantic, 60, 10);
        assert_eq!(fused.len(), 4, "All unique chunks should be in results");

        for id in [1u64, 2, 3, 4] {
            assert!(
                fused.iter().any(|r| r.chunk_id == id),
                "Chunk {} should be present",
                id
            );
        }
    }

    #[test]
    fn test_rrf_top_k_truncation() {
        let text = make_text_matches(&[1, 2, 3, 4, 5]);
        let semantic = make_semantic_hits(&[6, 7, 8, 9, 10]);

        let fused = rrf_fuse(&text, &semantic, 60, 3);
        assert_eq!(fused.len(), 3, "Should respect top-k");
        assert!(is_sorted_descending(&fused));

        // A limit of zero yields no results at all.
        let none = rrf_fuse(&text, &semantic, 60, 0);
        assert!(none.is_empty(), "top_k = 0 should return nothing");
    }

    #[test]
    fn test_rrf_empty_inputs() {
        let fused = rrf_fuse(&[], &[], 60, 10);
        assert!(fused.is_empty());
    }

    #[test]
    fn test_fuse_semantic_only() {
        let semantic = make_semantic_hits(&[1, 2, 3]);
        let fused = fuse_semantic_only(&semantic, 2);
        assert_eq!(fused.len(), 2);
        assert_eq!(fused[0].chunk_id, 1);
        assert_eq!(fused[1].chunk_id, 2);
        assert!(fused[0].text_score.is_none());
        assert!(fused[0].semantic_score.is_some());
        assert!(fused[0].ast_score.is_none());
        assert!(fused[0].matched_lines.is_empty());
    }

    #[test]
    fn test_fuse_text_only() {
        let text = make_text_matches(&[1, 2, 3]);
        let fused = fuse_text_only(&text, 2);
        assert_eq!(fused.len(), 2);
        assert_eq!(fused[0].chunk_id, 1);
        assert_eq!(fused[1].chunk_id, 2);
        assert!(fused[0].text_score.is_some());
        assert!(fused[0].semantic_score.is_none());
        assert!(fused[0].ast_score.is_none());
        // Matched lines from the text match are propagated for display.
        assert_eq!(fused[0].matched_lines, vec![0]);
    }

    #[test]
    fn test_fuse_ast_only() {
        let ast = make_ast_matches(&[1, 2, 3]);
        let fused = fuse_ast_only(&ast, 2);
        assert_eq!(fused.len(), 2);
        assert_eq!(fused[0].chunk_id, 1);
        assert_eq!(fused[1].chunk_id, 2);
        assert!(fused[0].text_score.is_none());
        assert!(fused[0].semantic_score.is_none());
        assert!(fused[0].ast_score.is_some());
        assert!(fused[0].matched_lines.is_empty());
    }

    #[test]
    fn test_rrf_three_way_fusion() {
        let text = make_text_matches(&[1, 2]);
        let semantic = make_semantic_hits(&[2, 3]);
        let ast = make_ast_matches(&[3, 4]);

        let fused = rrf_fuse_three(&text, &semantic, &ast, 60, 10);

        // All 4 unique chunks should appear
        assert_eq!(fused.len(), 4);
        assert!(is_sorted_descending(&fused));

        // Chunk 2 appears in text + semantic, chunk 3 in semantic + ast
        let chunk_2 = fused.iter().find(|r| r.chunk_id == 2).unwrap();
        let chunk_3 = fused.iter().find(|r| r.chunk_id == 3).unwrap();
        let chunk_1 = fused.iter().find(|r| r.chunk_id == 1).unwrap();
        let chunk_4 = fused.iter().find(|r| r.chunk_id == 4).unwrap();

        // Chunks appearing in 2 lists should rank higher than those in 1
        assert!(chunk_2.fused_score > chunk_1.fused_score);
        assert!(chunk_3.fused_score > chunk_4.fused_score);

        // Each fused entry records exactly the sources that reported it.
        assert!(chunk_2.text_score.is_some());
        assert!(chunk_2.semantic_score.is_some());
        assert!(chunk_2.ast_score.is_none());
        assert!(chunk_3.text_score.is_none());
        assert!(chunk_3.semantic_score.is_some());
        assert!(chunk_3.ast_score.is_some());
        assert!(chunk_4.text_score.is_none());
        assert!(chunk_4.semantic_score.is_none());
        assert!(chunk_4.ast_score.is_some());
    }
}