1use std::collections::HashMap;
9
10use crate::index::SearchHit;
11use crate::search::ast_pattern::AstMatch;
12use crate::search::text::TextMatch;
13
/// One chunk in the merged output of the fusion functions in this module.
///
/// The three `*_score` fields carry the raw score reported by each search
/// source; a field is `None` when that source did not contribute this chunk.
///
/// `PartialEq` is derived so results can be compared directly (for example
/// with `assert_eq!`). `Eq` is intentionally not derived because the scores
/// are `f32`.
#[derive(Debug, Clone, PartialEq)]
pub struct FusedResult {
    /// Identifier of the chunk this result describes.
    pub chunk_id: u64,

    /// Score the result is ranked by: the summed reciprocal-rank terms in
    /// `rrf_fuse` / `rrf_fuse_three`, or the source's raw score in the
    /// `fuse_*_only` functions.
    pub fused_score: f32,

    /// Raw score from the text search, if the chunk appeared there.
    pub text_score: Option<f32>,

    /// Raw score from the semantic search, if the chunk appeared there.
    pub semantic_score: Option<f32>,

    /// Raw score from the AST pattern search, if the chunk appeared there.
    pub ast_score: Option<f32>,

    /// Matched line numbers copied from the text match; empty when no text
    /// match contributed. NOTE(review): whether these are 0- or 1-based and
    /// chunk- or file-relative is not visible here — confirm against
    /// `TextMatch`.
    pub matched_lines: Vec<usize>,
}
35
36pub fn rrf_fuse(
47 text_results: &[TextMatch],
48 semantic_results: &[SearchHit],
49 k: u32,
50 top_k: usize,
51) -> Vec<FusedResult> {
52 let mut scores: HashMap<u64, FusedResult> = HashMap::new();
53
54 for (rank, result) in text_results.iter().enumerate() {
56 let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
57
58 scores
59 .entry(result.chunk_id)
60 .and_modify(|e| {
61 e.fused_score += rrf_score;
62 e.text_score = Some(result.score);
63 e.matched_lines = result.matched_lines.clone();
64 })
65 .or_insert(FusedResult {
66 chunk_id: result.chunk_id,
67 fused_score: rrf_score,
68 text_score: Some(result.score),
69 semantic_score: None,
70 ast_score: None,
71 matched_lines: result.matched_lines.clone(),
72 });
73 }
74
75 for (rank, result) in semantic_results.iter().enumerate() {
77 let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
78
79 scores
80 .entry(result.chunk_id)
81 .and_modify(|e| {
82 e.fused_score += rrf_score;
83 e.semantic_score = Some(result.score);
84 })
85 .or_insert(FusedResult {
86 chunk_id: result.chunk_id,
87 fused_score: rrf_score,
88 text_score: None,
89 semantic_score: Some(result.score),
90 ast_score: None,
91 matched_lines: Vec::new(),
92 });
93 }
94
95 let mut fused: Vec<FusedResult> = scores.into_values().collect();
97 fused.sort_by(|a, b| {
98 b.fused_score
99 .partial_cmp(&a.fused_score)
100 .unwrap_or(std::cmp::Ordering::Equal)
101 });
102
103 fused.truncate(top_k);
105
106 fused
107}
108
109pub fn rrf_fuse_three(
113 text_results: &[TextMatch],
114 semantic_results: &[SearchHit],
115 ast_results: &[AstMatch],
116 k: u32,
117 top_k: usize,
118) -> Vec<FusedResult> {
119 let mut scores: HashMap<u64, FusedResult> = HashMap::new();
120
121 for (rank, result) in text_results.iter().enumerate() {
123 let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
124
125 scores
126 .entry(result.chunk_id)
127 .and_modify(|e| {
128 e.fused_score += rrf_score;
129 e.text_score = Some(result.score);
130 e.matched_lines = result.matched_lines.clone();
131 })
132 .or_insert(FusedResult {
133 chunk_id: result.chunk_id,
134 fused_score: rrf_score,
135 text_score: Some(result.score),
136 semantic_score: None,
137 ast_score: None,
138 matched_lines: result.matched_lines.clone(),
139 });
140 }
141
142 for (rank, result) in semantic_results.iter().enumerate() {
144 let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
145
146 scores
147 .entry(result.chunk_id)
148 .and_modify(|e| {
149 e.fused_score += rrf_score;
150 e.semantic_score = Some(result.score);
151 })
152 .or_insert(FusedResult {
153 chunk_id: result.chunk_id,
154 fused_score: rrf_score,
155 text_score: None,
156 semantic_score: Some(result.score),
157 ast_score: None,
158 matched_lines: Vec::new(),
159 });
160 }
161
162 for (rank, result) in ast_results.iter().enumerate() {
164 let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
165
166 scores
167 .entry(result.chunk_id)
168 .and_modify(|e| {
169 e.fused_score += rrf_score;
170 e.ast_score = Some(result.score);
171 })
172 .or_insert(FusedResult {
173 chunk_id: result.chunk_id,
174 fused_score: rrf_score,
175 text_score: None,
176 semantic_score: None,
177 ast_score: Some(result.score),
178 matched_lines: Vec::new(),
179 });
180 }
181
182 let mut fused: Vec<FusedResult> = scores.into_values().collect();
184 fused.sort_by(|a, b| {
185 b.fused_score
186 .partial_cmp(&a.fused_score)
187 .unwrap_or(std::cmp::Ordering::Equal)
188 });
189
190 fused.truncate(top_k);
191 fused
192}
193
194pub fn fuse_semantic_only(semantic_results: &[SearchHit], top_k: usize) -> Vec<FusedResult> {
196 semantic_results
197 .iter()
198 .take(top_k)
199 .map(|r| FusedResult {
200 chunk_id: r.chunk_id,
201 fused_score: r.score,
202 text_score: None,
203 semantic_score: Some(r.score),
204 ast_score: None,
205 matched_lines: Vec::new(),
206 })
207 .collect()
208}
209
210pub fn fuse_text_only(text_results: &[TextMatch], top_k: usize) -> Vec<FusedResult> {
212 text_results
213 .iter()
214 .take(top_k)
215 .map(|r| FusedResult {
216 chunk_id: r.chunk_id,
217 fused_score: r.score,
218 text_score: Some(r.score),
219 semantic_score: None,
220 ast_score: None,
221 matched_lines: r.matched_lines.clone(),
222 })
223 .collect()
224}
225
226pub fn fuse_ast_only(ast_results: &[AstMatch], top_k: usize) -> Vec<FusedResult> {
228 ast_results
229 .iter()
230 .take(top_k)
231 .map(|r| FusedResult {
232 chunk_id: r.chunk_id,
233 fused_score: r.score,
234 text_score: None,
235 semantic_score: None,
236 ast_score: Some(r.score),
237 matched_lines: Vec::new(),
238 })
239 .collect()
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    // ---- fixture builders -------------------------------------------------

    /// Text matches for `chunk_ids`, each with `matched_lines == [0]` and a
    /// score counting down from `chunk_ids.len()` to 1.
    fn make_text_matches(chunk_ids: &[u64]) -> Vec<TextMatch> {
        let total = chunk_ids.len();
        let mut matches = Vec::with_capacity(total);
        for (position, &chunk_id) in chunk_ids.iter().enumerate() {
            matches.push(TextMatch {
                chunk_id,
                matched_lines: vec![0],
                score: (total - position) as f32,
            });
        }
        matches
    }

    /// Semantic hits for `chunk_ids` with scores 1.0, 0.9, 0.8, ...
    fn make_semantic_hits(chunk_ids: &[u64]) -> Vec<SearchHit> {
        let mut hits = Vec::with_capacity(chunk_ids.len());
        for (position, &chunk_id) in chunk_ids.iter().enumerate() {
            hits.push(SearchHit {
                chunk_id,
                score: 1.0 - (position as f32 * 0.1),
            });
        }
        hits
    }

    /// AST matches for `chunk_ids` with scores 1.0, 0.9, 0.8, ...
    fn make_ast_matches(chunk_ids: &[u64]) -> Vec<AstMatch> {
        let mut matches = Vec::with_capacity(chunk_ids.len());
        for (position, &chunk_id) in chunk_ids.iter().enumerate() {
            matches.push(AstMatch {
                chunk_id,
                score: 1.0 - (position as f32 * 0.1),
            });
        }
        matches
    }

    /// Fused score of `chunk_id` in `results`; panics if the chunk is absent.
    fn fused_score_of(results: &[FusedResult], chunk_id: u64) -> f32 {
        results
            .iter()
            .find(|r| r.chunk_id == chunk_id)
            .unwrap()
            .fused_score
    }

    // ---- two-way fusion ---------------------------------------------------

    #[test]
    fn test_rrf_basic_fusion() {
        let text = make_text_matches(&[1, 2, 3]);
        let semantic = make_semantic_hits(&[2, 3, 4]);

        let fused = rrf_fuse(&text, &semantic, 60, 10);

        assert!(!fused.is_empty());

        // Chunk 2 is in both lists, chunk 1 only in the text list.
        let in_both = fused_score_of(&fused, 2);
        let text_only = fused_score_of(&fused, 1);

        assert!(
            in_both > text_only,
            "Chunk appearing in both lists should rank higher"
        );
    }

    #[test]
    fn test_rrf_preserves_all_unique_results() {
        let fused = rrf_fuse(
            &make_text_matches(&[1, 2]),
            &make_semantic_hits(&[3, 4]),
            60,
            10,
        );
        assert_eq!(fused.len(), 4, "All unique chunks should be in results");
    }

    #[test]
    fn test_rrf_top_k_truncation() {
        let fused = rrf_fuse(
            &make_text_matches(&[1, 2, 3, 4, 5]),
            &make_semantic_hits(&[6, 7, 8, 9, 10]),
            60,
            3,
        );
        assert_eq!(fused.len(), 3, "Should respect top-k");
    }

    #[test]
    fn test_rrf_empty_inputs() {
        let no_text: &[TextMatch] = &[];
        let no_semantic: &[SearchHit] = &[];

        let fused = rrf_fuse(no_text, no_semantic, 60, 10);
        assert!(fused.is_empty());
    }

    // ---- single-source wrappers -------------------------------------------

    #[test]
    fn test_fuse_semantic_only() {
        let fused = fuse_semantic_only(&make_semantic_hits(&[1, 2, 3]), 2);
        assert_eq!(fused.len(), 2);

        let first = &fused[0];
        assert!(first.text_score.is_none());
        assert!(first.semantic_score.is_some());
    }

    #[test]
    fn test_fuse_text_only() {
        let fused = fuse_text_only(&make_text_matches(&[1, 2, 3]), 2);
        assert_eq!(fused.len(), 2);

        let first = &fused[0];
        assert!(first.text_score.is_some());
        assert!(first.semantic_score.is_none());
        assert!(first.ast_score.is_none());
    }

    #[test]
    fn test_fuse_ast_only() {
        let fused = fuse_ast_only(&make_ast_matches(&[1, 2, 3]), 2);
        assert_eq!(fused.len(), 2);

        let first = &fused[0];
        assert!(first.text_score.is_none());
        assert!(first.semantic_score.is_none());
        assert!(first.ast_score.is_some());
    }

    // ---- three-way fusion -------------------------------------------------

    #[test]
    fn test_rrf_three_way_fusion() {
        let text = make_text_matches(&[1, 2]);
        let semantic = make_semantic_hits(&[2, 3]);
        let ast = make_ast_matches(&[3, 4]);

        let fused = rrf_fuse_three(&text, &semantic, &ast, 60, 10);

        assert_eq!(fused.len(), 4);

        // Chunks 2 and 3 each appear in two lists; chunks 1 and 4 in one.
        let score_2 = fused_score_of(&fused, 2);
        let score_3 = fused_score_of(&fused, 3);
        let score_1 = fused_score_of(&fused, 1);
        let score_4 = fused_score_of(&fused, 4);

        assert!(score_2 > score_1);
        assert!(score_3 > score_4);
    }
}