1use crate::{retrieve::RetrievalResult, Result};
4use serde::{Deserialize, Serialize};
5
6pub trait Reranker: Send + Sync {
8 fn rerank(
10 &self,
11 query: &str,
12 candidates: &[RetrievalResult],
13 top_k: usize,
14 ) -> Result<Vec<RetrievalResult>>;
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct LexicalReranker {
20 pub exact_match_weight: f32,
22 pub coverage_weight: f32,
24 pub position_weight: f32,
26 pub case_insensitive: bool,
28}
29
30impl Default for LexicalReranker {
31 fn default() -> Self {
32 Self {
33 exact_match_weight: 0.3,
34 coverage_weight: 0.5,
35 position_weight: 0.2,
36 case_insensitive: true,
37 }
38 }
39}
40
41impl LexicalReranker {
42 #[must_use]
44 pub fn new() -> Self {
45 Self::default()
46 }
47
48 #[must_use]
50 pub fn with_weights(mut self, exact_match: f32, coverage: f32, position: f32) -> Self {
51 self.exact_match_weight = exact_match;
52 self.coverage_weight = coverage;
53 self.position_weight = position;
54 self
55 }
56
57 fn score(&self, query: &str, content: &str) -> f32 {
59 let (query, content) = if self.case_insensitive {
60 (query.to_lowercase(), content.to_lowercase())
61 } else {
62 (query.to_string(), content.to_string())
63 };
64
65 let query_terms: Vec<&str> = query.split_whitespace().collect();
66 if query_terms.is_empty() {
67 return 0.0;
68 }
69
70 let exact_match = if content.contains(&query) { 1.0 } else { 0.0 };
72
73 let matches = query_terms.iter().filter(|term| content.contains(*term)).count() as f32;
75 let coverage = matches / query_terms.len().max(1) as f32;
76
77 let position_score = query_terms
79 .iter()
80 .filter_map(|term| content.find(term))
81 .map(|pos| 1.0 / (1.0 + pos as f32 / 100.0))
82 .sum::<f32>()
83 / query_terms.len().max(1) as f32;
84
85 self.exact_match_weight * exact_match
86 + self.coverage_weight * coverage
87 + self.position_weight * position_score
88 }
89}
90
91impl Reranker for LexicalReranker {
92 fn rerank(
93 &self,
94 query: &str,
95 candidates: &[RetrievalResult],
96 top_k: usize,
97 ) -> Result<Vec<RetrievalResult>> {
98 let mut scored: Vec<(RetrievalResult, f32)> = candidates
99 .iter()
100 .map(|c| {
101 let score = self.score(query, &c.chunk.content);
102 (c.clone(), score)
103 })
104 .collect();
105
106 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
108
109 Ok(scored
111 .into_iter()
112 .take(top_k)
113 .map(|(mut result, score)| {
114 result.rerank_score = Some(score);
115 result
116 })
117 .collect())
118 }
119}
120
/// Stand-in for a cross-encoder reranker.
///
/// No model is loaded: the "relevance" of a document is simply the share of
/// distinct query terms that also appear as whole whitespace-separated terms
/// in the document.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MockCrossEncoderReranker {
    /// Identifier supplied at construction; stored and exposed, never used
    /// for scoring.
    model_id: String,
}

impl MockCrossEncoderReranker {
    /// Creates a mock reranker labelled with `model_id`.
    #[must_use]
    pub fn new(model_id: impl Into<String>) -> Self {
        Self { model_id: model_id.into() }
    }

    /// Returns the identifier passed to [`MockCrossEncoderReranker::new`].
    #[must_use]
    pub fn model_id(&self) -> &str {
        &self.model_id
    }

    /// Scores a query/document pair in `[0.0, 1.0]`.
    ///
    /// Both inputs are lowercased and split on whitespace into sets of
    /// distinct terms; the score is `|query ∩ document| / |query|`. An empty
    /// query or an empty document scores 0.0.
    // `self` is unused by the mock, but the method form is kept because the
    // `Reranker` impl calls it as `self.score_pair(..)`.
    #[allow(clippy::unused_self)]
    fn score_pair(&self, query: &str, document: &str) -> f32 {
        use std::collections::HashSet;

        let query_lower = query.to_lowercase();
        let doc_lower = document.to_lowercase();

        let query_terms: HashSet<&str> = query_lower.split_whitespace().collect();
        let doc_terms: HashSet<&str> = doc_lower.split_whitespace().collect();

        if query_terms.is_empty() || doc_terms.is_empty() {
            return 0.0;
        }

        let overlap = query_terms.intersection(&doc_terms).count();
        overlap as f32 / query_terms.len() as f32
    }
}
159
160impl Reranker for MockCrossEncoderReranker {
161 fn rerank(
162 &self,
163 query: &str,
164 candidates: &[RetrievalResult],
165 top_k: usize,
166 ) -> Result<Vec<RetrievalResult>> {
167 let mut scored: Vec<(RetrievalResult, f32)> = candidates
168 .iter()
169 .map(|c| {
170 let score = self.score_pair(query, &c.chunk.content);
171 (c.clone(), score)
172 })
173 .collect();
174
175 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
176
177 Ok(scored
178 .into_iter()
179 .take(top_k)
180 .map(|(mut result, score)| {
181 result.rerank_score = Some(score);
182 result
183 })
184 .collect())
185 }
186}
187
188pub struct CompositeReranker {
190 rerankers: Vec<(Box<dyn Reranker>, f32)>,
191}
192
193impl CompositeReranker {
194 #[must_use]
196 pub fn new() -> Self {
197 Self { rerankers: Vec::new() }
198 }
199
200 #[must_use]
202 pub fn with_reranker(mut self, reranker: Box<dyn Reranker>, weight: f32) -> Self {
203 self.rerankers.push((reranker, weight));
204 self
205 }
206}
207
208impl Default for CompositeReranker {
209 fn default() -> Self {
210 Self::new()
211 }
212}
213
214impl Reranker for CompositeReranker {
215 fn rerank(
216 &self,
217 query: &str,
218 candidates: &[RetrievalResult],
219 top_k: usize,
220 ) -> Result<Vec<RetrievalResult>> {
221 if self.rerankers.is_empty() {
222 return Ok(candidates.iter().take(top_k).cloned().collect());
223 }
224
225 let mut combined_scores: std::collections::HashMap<usize, f32> =
227 std::collections::HashMap::new();
228
229 for (reranker, weight) in &self.rerankers {
230 let reranked = reranker.rerank(query, candidates, candidates.len())?;
231 for result in &reranked {
232 for (orig_idx, orig) in candidates.iter().enumerate() {
234 if orig.chunk.id == result.chunk.id {
235 let score = result.rerank_score.unwrap_or(0.0);
236 *combined_scores.entry(orig_idx).or_insert(0.0) += weight * score;
237 break;
238 }
239 }
240 }
241 }
242
243 let mut indexed: Vec<_> = combined_scores.into_iter().collect();
245 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
246
247 Ok(indexed
248 .into_iter()
249 .take(top_k)
250 .map(|(idx, score)| {
251 let mut result = candidates[idx].clone();
252 result.rerank_score = Some(score);
253 result
254 })
255 .collect())
256 }
257}
258
/// Reranker that performs no scoring at all.
///
/// It is a zero-sized marker type, so it is `Copy` and all values compare
/// equal.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoOpReranker;

impl NoOpReranker {
    /// Creates the reranker; usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
270
271impl Reranker for NoOpReranker {
272 fn rerank(
273 &self,
274 _query: &str,
275 candidates: &[RetrievalResult],
276 top_k: usize,
277 ) -> Result<Vec<RetrievalResult>> {
278 Ok(candidates.iter().take(top_k).cloned().collect())
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Chunk, DocumentId};
    use proptest::prelude::*;

    /// Builds a `RetrievalResult` around a new chunk holding `content`.
    fn create_result(content: &str) -> RetrievalResult {
        let chunk = Chunk::new(DocumentId::new(), content.to_string(), 0, content.len());
        RetrievalResult::new(chunk)
    }

    /// Same as [`create_result`], with the dense score set to `dense`.
    fn create_result_with_score(content: &str, dense: f32) -> RetrievalResult {
        create_result(content).with_dense_score(dense)
    }

    // ---- LexicalReranker ----

    #[test]
    fn test_lexical_reranker_default() {
        let reranker = LexicalReranker::default();
        assert!((reranker.exact_match_weight - 0.3).abs() < 0.01);
        assert!((reranker.coverage_weight - 0.5).abs() < 0.01);
        assert!((reranker.position_weight - 0.2).abs() < 0.01);
        assert!(reranker.case_insensitive);
    }

    #[test]
    fn test_lexical_reranker_with_weights() {
        let reranker = LexicalReranker::new().with_weights(0.5, 0.3, 0.2);
        assert!((reranker.exact_match_weight - 0.5).abs() < 0.01);
        assert!((reranker.coverage_weight - 0.3).abs() < 0.01);
        assert!((reranker.position_weight - 0.2).abs() < 0.01);
        // The builder must not touch the case-sensitivity flag.
        assert!(reranker.case_insensitive);
    }

    #[test]
    fn test_lexical_reranker_exact_match() {
        let reranker = LexicalReranker::new();
        let candidates = vec![
            create_result("This contains the exact query machine learning"),
            create_result("This mentions machine and learning separately"),
        ];

        let results = reranker.rerank("machine learning", &candidates, 2).unwrap();

        assert!(results[0].chunk.content.contains("exact query"));
        assert!(results[0].rerank_score.unwrap() > results[1].rerank_score.unwrap());
    }

    #[test]
    fn test_lexical_reranker_coverage() {
        let reranker = LexicalReranker::new();
        let candidates =
            vec![create_result("machine learning algorithms"), create_result("machine only here")];

        let results = reranker.rerank("machine learning neural networks", &candidates, 2).unwrap();

        // Two of four terms matched beats one of four.
        assert!(results[0].chunk.content.contains("machine learning"));
        assert!(results[0].rerank_score.unwrap() > results[1].rerank_score.unwrap());
    }

    #[test]
    fn test_lexical_reranker_top_k() {
        let reranker = LexicalReranker::new();
        let candidates: Vec<_> = (0..10).map(|i| create_result(&format!("doc {i}"))).collect();

        let results = reranker.rerank("doc", &candidates, 3).unwrap();
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_lexical_reranker_empty_query() {
        let reranker = LexicalReranker::new();
        let candidates = vec![create_result("test content")];

        let results = reranker.rerank("", &candidates, 10).unwrap();
        assert_eq!(results.len(), 1);
        assert!((results[0].rerank_score.unwrap() - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_lexical_reranker_case_insensitive() {
        let reranker = LexicalReranker::new();
        let candidates = vec![create_result("MACHINE LEARNING"), create_result("machine learning")];

        let results = reranker.rerank("Machine Learning", &candidates, 2).unwrap();

        let diff = (results[0].rerank_score.unwrap() - results[1].rerank_score.unwrap()).abs();
        assert!(diff < 0.01);
    }

    // ---- MockCrossEncoderReranker ----

    #[test]
    fn test_mock_cross_encoder_new() {
        let reranker = MockCrossEncoderReranker::new("ms-marco-MiniLM");
        assert_eq!(reranker.model_id(), "ms-marco-MiniLM");
    }

    #[test]
    fn test_mock_cross_encoder_rerank() {
        let reranker = MockCrossEncoderReranker::new("test-model");
        let candidates =
            vec![create_result("machine learning algorithms"), create_result("cooking recipes")];

        let results = reranker.rerank("machine learning", &candidates, 2).unwrap();

        assert!(results[0].chunk.content.contains("machine"));
        assert!(results[0].rerank_score.unwrap() > results[1].rerank_score.unwrap());
    }

    #[test]
    fn test_mock_cross_encoder_empty() {
        let reranker = MockCrossEncoderReranker::new("test-model");
        let results = reranker.rerank("test", &[], 10).unwrap();
        assert!(results.is_empty());
    }

    // ---- CompositeReranker ----

    #[test]
    fn test_composite_reranker_empty() {
        let reranker = CompositeReranker::new();
        let candidates = vec![create_result("test")];

        let results = reranker.rerank("test", &candidates, 10).unwrap();
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_composite_reranker_single() {
        let lexical = Box::new(LexicalReranker::new());
        let reranker = CompositeReranker::new().with_reranker(lexical, 1.0);

        let candidates =
            vec![create_result("exact match query here"), create_result("no match at all")];

        let results = reranker.rerank("query", &candidates, 2).unwrap();
        assert_eq!(results.len(), 2);
        assert!(results[0].rerank_score.is_some());
        assert!(results[0].chunk.content.contains("query"));
    }

    #[test]
    fn test_composite_reranker_multiple() {
        let lexical = Box::new(LexicalReranker::new());
        let cross = Box::new(MockCrossEncoderReranker::new("test"));

        let reranker =
            CompositeReranker::new().with_reranker(lexical, 0.5).with_reranker(cross, 0.5);

        let candidates =
            vec![create_result("machine learning test"), create_result("unrelated content")];

        let results = reranker.rerank("machine learning", &candidates, 2).unwrap();
        assert_eq!(results.len(), 2);
        // Both members favour the first document, so the blend must too.
        assert!(results[0].chunk.content.contains("machine learning"));
        assert!(results[0].rerank_score.unwrap() > results[1].rerank_score.unwrap());
    }

    // ---- NoOpReranker ----

    #[test]
    fn test_noop_reranker() {
        let reranker = NoOpReranker::new();
        let candidates =
            vec![create_result_with_score("first", 0.9), create_result_with_score("second", 0.8)];

        let results = reranker.rerank("query", &candidates, 10).unwrap();

        assert_eq!(results.len(), 2);
        assert!(results[0].chunk.content.contains("first"));
        assert!(results[1].chunk.content.contains("second"));
    }

    #[test]
    fn test_noop_reranker_top_k() {
        let reranker = NoOpReranker::new();
        let candidates: Vec<_> = (0..10).map(|i| create_result(&format!("doc {i}"))).collect();

        let results = reranker.rerank("query", &candidates, 3).unwrap();
        assert_eq!(results.len(), 3);
    }

    // ---- Property tests ----

    proptest! {
        #[test]
        fn prop_lexical_rerank_scores_bounded(
            query in "[a-zA-Z ]{1,20}",
            content in "[a-zA-Z ]{1,100}"
        ) {
            let reranker = LexicalReranker::new();
            let candidates = vec![create_result(&content)];

            let results = reranker.rerank(&query, &candidates, 1).unwrap();
            let score = results[0].rerank_score.unwrap();

            // Default weights sum to 1 and every signal lies in [0, 1].
            prop_assert!(score >= 0.0);
            prop_assert!(score <= 1.0);
        }

        #[test]
        fn prop_rerank_respects_top_k(k in 1usize..10, n in 1usize..20) {
            let reranker = LexicalReranker::new();
            let candidates: Vec<_> = (0..n)
                .map(|i| create_result(&format!("document {i}")))
                .collect();

            let results = reranker.rerank("document", &candidates, k).unwrap();
            prop_assert!(results.len() <= k);
            prop_assert!(results.len() <= n);
        }

        #[test]
        fn prop_noop_preserves_order(n in 1usize..10) {
            let reranker = NoOpReranker::new();
            let candidates: Vec<_> = (0..n)
                .map(|i| create_result(&format!("doc {i}")))
                .collect();

            let results = reranker.rerank("query", &candidates, n).unwrap();

            for (i, result) in results.iter().enumerate() {
                let expected = format!("doc {i}");
                prop_assert!(result.chunk.content.contains(&expected));
            }
        }
    }
}