1use crate::ChunkId;
4use serde::{Deserialize, Serialize};
5use std::collections::HashSet;
6
7#[derive(Debug, Clone, Default, Serialize, Deserialize)]
9pub struct RetrievalMetrics {
10 pub recall: std::collections::HashMap<usize, f32>,
12 pub precision: std::collections::HashMap<usize, f32>,
14 pub mrr: f32,
16 pub ndcg: std::collections::HashMap<usize, f32>,
18 pub map: f32,
20}
21
22impl RetrievalMetrics {
23 pub fn compute(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k_values: &[usize]) -> Self {
25 let mut metrics = Self::default();
26
27 for &k in k_values {
28 metrics.recall.insert(k, Self::recall_at_k(retrieved, relevant, k));
29 metrics.precision.insert(k, Self::precision_at_k(retrieved, relevant, k));
30 metrics.ndcg.insert(k, Self::ndcg_at_k(retrieved, relevant, k));
31 }
32
33 metrics.mrr = Self::mean_reciprocal_rank(retrieved, relevant);
34 metrics.map = Self::average_precision(retrieved, relevant);
35
36 metrics
37 }
38
39 #[must_use]
43 pub fn recall_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
44 if relevant.is_empty() {
45 return 0.0;
46 }
47
48 let retrieved_k: HashSet<ChunkId> = retrieved.iter().take(k).copied().collect();
49 let relevant_retrieved = retrieved_k.intersection(relevant).count();
50
51 relevant_retrieved as f32 / relevant.len() as f32
52 }
53
54 #[must_use]
58 pub fn precision_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
59 if k == 0 {
60 return 0.0;
61 }
62
63 let retrieved_k: HashSet<ChunkId> = retrieved.iter().take(k).copied().collect();
64 let relevant_retrieved = retrieved_k.intersection(relevant).count();
65
66 relevant_retrieved as f32 / k as f32
67 }
68
69 #[must_use]
73 pub fn mean_reciprocal_rank(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>) -> f32 {
74 for (rank, id) in retrieved.iter().enumerate() {
75 if relevant.contains(id) {
76 return 1.0 / (rank + 1) as f32;
77 }
78 }
79 0.0
80 }
81
82 #[must_use]
86 pub fn ndcg_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
87 let dcg = Self::dcg_at_k(retrieved, relevant, k);
88 let idcg = Self::ideal_dcg_at_k(relevant.len(), k);
89
90 if idcg == 0.0 {
91 0.0
92 } else {
93 dcg / idcg
94 }
95 }
96
97 fn dcg_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
102 let mut seen = HashSet::new();
103 retrieved
104 .iter()
105 .take(k)
106 .enumerate()
107 .filter(|(_, id)| relevant.contains(id) && seen.insert(**id))
108 .map(|(rank, _)| 1.0 / (rank as f32 + 2.0).max(f32::EPSILON).log2())
109 .sum()
110 }
111
112 fn ideal_dcg_at_k(num_relevant: usize, k: usize) -> f32 {
114 (0..num_relevant.min(k))
115 .map(|rank| 1.0 / (rank as f32 + 2.0).max(f32::EPSILON).log2())
116 .sum()
117 }
118
119 #[must_use]
123 pub fn average_precision(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>) -> f32 {
124 if relevant.is_empty() {
125 return 0.0;
126 }
127
128 let mut sum_precision = 0.0;
129 let mut relevant_count = 0;
130
131 for (rank, id) in retrieved.iter().enumerate() {
132 if relevant.contains(id) {
133 relevant_count += 1;
134 sum_precision += relevant_count as f32 / (rank + 1) as f32;
135 }
136 }
137
138 sum_precision / relevant.len().max(1) as f32
139 }
140
141 #[must_use]
143 pub fn f1_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
144 let precision = Self::precision_at_k(retrieved, relevant, k);
145 let recall = Self::recall_at_k(retrieved, relevant, k);
146
147 if precision + recall == 0.0 {
148 0.0
149 } else {
150 2.0 * precision * recall / (precision + recall)
151 }
152 }
153
154 #[must_use]
156 pub fn hit_rate_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
157 let retrieved_k: HashSet<ChunkId> = retrieved.iter().take(k).copied().collect();
158 if retrieved_k.intersection(relevant).next().is_some() {
159 1.0
160 } else {
161 0.0
162 }
163 }
164}
165
166#[derive(Debug, Clone, Default, Serialize, Deserialize)]
168pub struct AggregatedMetrics {
169 pub mean_recall: std::collections::HashMap<usize, f32>,
171 pub mean_precision: std::collections::HashMap<usize, f32>,
173 pub mean_mrr: f32,
175 pub mean_ndcg: std::collections::HashMap<usize, f32>,
177 pub map: f32,
179 pub query_count: usize,
181}
182
183impl AggregatedMetrics {
184 pub fn aggregate(metrics: &[RetrievalMetrics]) -> Self {
186 if metrics.is_empty() {
187 return Self::default();
188 }
189
190 let n = metrics.len() as f32;
191 let mut agg = Self { query_count: metrics.len(), ..Default::default() };
192
193 agg.mean_mrr = metrics.iter().map(|m| m.mrr).sum::<f32>() / n;
195 agg.map = metrics.iter().map(|m| m.map).sum::<f32>() / n;
196
197 if let Some(first) = metrics.first() {
199 for &k in first.recall.keys() {
200 let mean_recall = metrics.iter().filter_map(|m| m.recall.get(&k)).sum::<f32>() / n;
201 agg.mean_recall.insert(k, mean_recall);
202
203 let mean_precision =
204 metrics.iter().filter_map(|m| m.precision.get(&k)).sum::<f32>() / n;
205 agg.mean_precision.insert(k, mean_precision);
206
207 let mean_ndcg = metrics.iter().filter_map(|m| m.ndcg.get(&k)).sum::<f32>() / n;
208 agg.mean_ndcg.insert(k, mean_ndcg);
209 }
210 }
211
212 agg
213 }
214}
215
// Unit tests for the individual metric functions and aggregation, followed by
// property tests checking that the metrics stay within [0, 1] for random inputs.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a deterministic ChunkId from a small integer so tests can name chunks by number.
    fn chunk_id(n: u128) -> ChunkId {
        ChunkId(uuid::Uuid::from_u128(n))
    }

    // ---- recall@k ----

    #[test]
    fn test_recall_at_k_perfect() {
        let retrieved = vec![chunk_id(1), chunk_id(2), chunk_id(3)];
        let relevant: HashSet<_> = [chunk_id(1), chunk_id(2), chunk_id(3)].into();

        let recall = RetrievalMetrics::recall_at_k(&retrieved, &relevant, 3);
        assert!((recall - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_recall_at_k_partial() {
        let retrieved = vec![chunk_id(1), chunk_id(4), chunk_id(5)];
        let relevant: HashSet<_> = [chunk_id(1), chunk_id(2), chunk_id(3)].into();

        // Only chunk 1 of the 3 relevant chunks is retrieved.
        let recall = RetrievalMetrics::recall_at_k(&retrieved, &relevant, 3);
        assert!((recall - 1.0 / 3.0).abs() < 0.001);
    }

    #[test]
    fn test_recall_at_k_none() {
        let retrieved = vec![chunk_id(4), chunk_id(5), chunk_id(6)];
        let relevant: HashSet<_> = [chunk_id(1), chunk_id(2), chunk_id(3)].into();

        let recall = RetrievalMetrics::recall_at_k(&retrieved, &relevant, 3);
        assert!((recall - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_recall_at_k_empty_relevant() {
        let retrieved = vec![chunk_id(1), chunk_id(2)];
        let relevant: HashSet<ChunkId> = HashSet::new();

        // An empty relevant set is defined to give 0.0 rather than dividing by zero.
        let recall = RetrievalMetrics::recall_at_k(&retrieved, &relevant, 2);
        assert!((recall - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_recall_at_k_smaller_k() {
        let retrieved = vec![chunk_id(4), chunk_id(1), chunk_id(2)];
        let relevant: HashSet<_> = [chunk_id(1), chunk_id(2)].into();

        // Top 1 is only the irrelevant chunk 4.
        let recall = RetrievalMetrics::recall_at_k(&retrieved, &relevant, 1);
        assert!((recall - 0.0).abs() < 0.001);

        // Top 2 adds chunk 1: one of the two relevant chunks is found.
        let recall = RetrievalMetrics::recall_at_k(&retrieved, &relevant, 2);
        assert!((recall - 0.5).abs() < 0.001);
    }

    // ---- precision@k ----

    #[test]
    fn test_precision_at_k_perfect() {
        let retrieved = vec![chunk_id(1), chunk_id(2)];
        let relevant: HashSet<_> = [chunk_id(1), chunk_id(2)].into();

        let precision = RetrievalMetrics::precision_at_k(&retrieved, &relevant, 2);
        assert!((precision - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_precision_at_k_half() {
        let retrieved = vec![chunk_id(1), chunk_id(4)];
        let relevant: HashSet<_> = [chunk_id(1), chunk_id(2)].into();

        let precision = RetrievalMetrics::precision_at_k(&retrieved, &relevant, 2);
        assert!((precision - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_precision_at_k_zero() {
        // k == 0 must short-circuit to 0.0 instead of dividing by zero.
        let precision = RetrievalMetrics::precision_at_k(&[], &HashSet::new(), 0);
        assert!((precision - 0.0).abs() < 0.001);
    }

    // ---- reciprocal rank ----

    #[test]
    fn test_mrr_first_position() {
        let retrieved = vec![chunk_id(1), chunk_id(2), chunk_id(3)];
        let relevant: HashSet<_> = [chunk_id(1)].into();

        let mrr = RetrievalMetrics::mean_reciprocal_rank(&retrieved, &relevant);
        assert!((mrr - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_mrr_second_position() {
        let retrieved = vec![chunk_id(4), chunk_id(1), chunk_id(3)];
        let relevant: HashSet<_> = [chunk_id(1)].into();

        let mrr = RetrievalMetrics::mean_reciprocal_rank(&retrieved, &relevant);
        assert!((mrr - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_mrr_third_position() {
        let retrieved = vec![chunk_id(4), chunk_id(5), chunk_id(1)];
        let relevant: HashSet<_> = [chunk_id(1)].into();

        let mrr = RetrievalMetrics::mean_reciprocal_rank(&retrieved, &relevant);
        assert!((mrr - 1.0 / 3.0).abs() < 0.001);
    }

    #[test]
    fn test_mrr_not_found() {
        let retrieved = vec![chunk_id(4), chunk_id(5), chunk_id(6)];
        let relevant: HashSet<_> = [chunk_id(1)].into();

        let mrr = RetrievalMetrics::mean_reciprocal_rank(&retrieved, &relevant);
        assert!((mrr - 0.0).abs() < 0.001);
    }

    // ---- nDCG@k ----

    #[test]
    fn test_ndcg_perfect_order() {
        let retrieved = vec![chunk_id(1), chunk_id(2)];
        let relevant: HashSet<_> = [chunk_id(1), chunk_id(2)].into();

        let ndcg = RetrievalMetrics::ndcg_at_k(&retrieved, &relevant, 2);
        assert!((ndcg - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_ndcg_no_relevant() {
        let retrieved = vec![chunk_id(3), chunk_id(4)];
        let relevant: HashSet<_> = [chunk_id(1), chunk_id(2)].into();

        let ndcg = RetrievalMetrics::ndcg_at_k(&retrieved, &relevant, 2);
        assert!((ndcg - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_ndcg_empty_relevant() {
        let retrieved = vec![chunk_id(1), chunk_id(2)];
        let relevant: HashSet<ChunkId> = HashSet::new();

        // Ideal DCG is 0 here, so nDCG is defined as 0.0 rather than 0/0.
        let ndcg = RetrievalMetrics::ndcg_at_k(&retrieved, &relevant, 2);
        assert!((ndcg - 0.0).abs() < 0.001);
    }

    // ---- average precision ----

    #[test]
    fn test_ap_perfect() {
        let retrieved = vec![chunk_id(1), chunk_id(2), chunk_id(3)];
        let relevant: HashSet<_> = [chunk_id(1), chunk_id(2), chunk_id(3)].into();

        // Every position is a hit, so precision at each hit is 1.
        let ap = RetrievalMetrics::average_precision(&retrieved, &relevant);
        assert!((ap - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_ap_interleaved() {
        let retrieved = vec![chunk_id(1), chunk_id(4), chunk_id(2)];
        let relevant: HashSet<_> = [chunk_id(1), chunk_id(2)].into();

        // Hits at ranks 1 and 3: (1/1 + 2/3) / 2 = 5/6.
        let ap = RetrievalMetrics::average_precision(&retrieved, &relevant);
        assert!((ap - 5.0 / 6.0).abs() < 0.001);
    }

    #[test]
    fn test_ap_empty_relevant() {
        let retrieved = vec![chunk_id(1), chunk_id(2)];
        let relevant: HashSet<ChunkId> = HashSet::new();

        let ap = RetrievalMetrics::average_precision(&retrieved, &relevant);
        assert!((ap - 0.0).abs() < 0.001);
    }

    // ---- F1@k ----

    #[test]
    fn test_f1_perfect() {
        let retrieved = vec![chunk_id(1), chunk_id(2)];
        let relevant: HashSet<_> = [chunk_id(1), chunk_id(2)].into();

        let f1 = RetrievalMetrics::f1_at_k(&retrieved, &relevant, 2);
        assert!((f1 - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_f1_zero() {
        let retrieved = vec![chunk_id(3), chunk_id(4)];
        let relevant: HashSet<_> = [chunk_id(1), chunk_id(2)].into();

        // Precision and recall are both 0, exercising the zero-denominator guard.
        let f1 = RetrievalMetrics::f1_at_k(&retrieved, &relevant, 2);
        assert!((f1 - 0.0).abs() < 0.001);
    }

    // ---- hit rate@k ----

    #[test]
    fn test_hit_rate_hit() {
        let retrieved = vec![chunk_id(3), chunk_id(1), chunk_id(4)];
        let relevant: HashSet<_> = [chunk_id(1), chunk_id(2)].into();

        let hr = RetrievalMetrics::hit_rate_at_k(&retrieved, &relevant, 3);
        assert!((hr - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_hit_rate_miss() {
        let retrieved = vec![chunk_id(3), chunk_id(4)];
        let relevant: HashSet<_> = [chunk_id(1), chunk_id(2)].into();

        let hr = RetrievalMetrics::hit_rate_at_k(&retrieved, &relevant, 2);
        assert!((hr - 0.0).abs() < 0.001);
    }

    // ---- compute() ----

    #[test]
    fn test_compute_all_metrics() {
        let retrieved = vec![chunk_id(1), chunk_id(4), chunk_id(2), chunk_id(5)];
        let relevant: HashSet<_> = [chunk_id(1), chunk_id(2), chunk_id(3)].into();
        // Includes cutoffs (5, 10) larger than the retrieved list.
        let k_values = vec![1, 2, 5, 10];

        let metrics = RetrievalMetrics::compute(&retrieved, &relevant, &k_values);

        assert!(!metrics.recall.is_empty());
        assert!(!metrics.precision.is_empty());
        assert!(!metrics.ndcg.is_empty());
        assert!(metrics.mrr > 0.0);
    }

    // ---- aggregation ----

    #[test]
    fn test_aggregate_empty() {
        let agg = AggregatedMetrics::aggregate(&[]);
        assert_eq!(agg.query_count, 0);
    }

    #[test]
    fn test_aggregate_single() {
        let retrieved = vec![chunk_id(1), chunk_id(2)];
        let relevant: HashSet<_> = [chunk_id(1), chunk_id(2)].into();
        let metrics = RetrievalMetrics::compute(&retrieved, &relevant, &[1, 2]);

        let agg = AggregatedMetrics::aggregate(&[metrics]);
        assert_eq!(agg.query_count, 1);
        assert!((agg.mean_mrr - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_aggregate_multiple() {
        // Two hand-built per-query results so the expected means are easy to verify.
        let metrics1 = RetrievalMetrics {
            mrr: 1.0,
            map: 1.0,
            recall: [(1, 1.0), (2, 1.0)].into(),
            precision: [(1, 1.0), (2, 1.0)].into(),
            ndcg: [(1, 1.0), (2, 1.0)].into(),
        };
        let metrics2 = RetrievalMetrics {
            mrr: 0.5,
            map: 0.5,
            recall: [(1, 0.5), (2, 0.5)].into(),
            precision: [(1, 0.5), (2, 0.5)].into(),
            ndcg: [(1, 0.5), (2, 0.5)].into(),
        };

        let agg = AggregatedMetrics::aggregate(&[metrics1, metrics2]);

        assert_eq!(agg.query_count, 2);
        assert!((agg.mean_mrr - 0.75).abs() < 0.001);
        assert!((agg.map - 0.75).abs() < 0.001);
    }

    // ---- property tests ----
    // Random id lists (drawn from 0..100, so duplicates can occur) must keep each
    // metric within [0, 1].
    use proptest::prelude::*;

    proptest! {
        #[test]
        fn prop_recall_bounded(
            retrieved_ids in prop::collection::vec(0u128..100, 1..20),
            relevant_ids in prop::collection::vec(0u128..100, 1..10),
            k in 1usize..20
        ) {
            let retrieved: Vec<_> = retrieved_ids.into_iter().map(chunk_id).collect();
            let relevant: HashSet<_> = relevant_ids.into_iter().map(chunk_id).collect();

            let recall = RetrievalMetrics::recall_at_k(&retrieved, &relevant, k);
            prop_assert!(recall >= 0.0);
            prop_assert!(recall <= 1.0);
        }

        #[test]
        fn prop_precision_bounded(
            retrieved_ids in prop::collection::vec(0u128..100, 1..20),
            relevant_ids in prop::collection::vec(0u128..100, 1..10),
            k in 1usize..20
        ) {
            let retrieved: Vec<_> = retrieved_ids.into_iter().map(chunk_id).collect();
            let relevant: HashSet<_> = relevant_ids.into_iter().map(chunk_id).collect();

            let precision = RetrievalMetrics::precision_at_k(&retrieved, &relevant, k);
            prop_assert!(precision >= 0.0);
            prop_assert!(precision <= 1.0);
        }

        #[test]
        fn prop_mrr_bounded(
            retrieved_ids in prop::collection::vec(0u128..100, 1..20),
            relevant_ids in prop::collection::vec(0u128..100, 1..10)
        ) {
            let retrieved: Vec<_> = retrieved_ids.into_iter().map(chunk_id).collect();
            let relevant: HashSet<_> = relevant_ids.into_iter().map(chunk_id).collect();

            let mrr = RetrievalMetrics::mean_reciprocal_rank(&retrieved, &relevant);
            prop_assert!(mrr >= 0.0);
            prop_assert!(mrr <= 1.0);
        }

        #[test]
        fn prop_ndcg_bounded(
            retrieved_ids in prop::collection::vec(0u128..100, 1..20),
            relevant_ids in prop::collection::vec(0u128..100, 1..10),
            k in 1usize..20
        ) {
            let retrieved: Vec<_> = retrieved_ids.into_iter().map(chunk_id).collect();
            let relevant: HashSet<_> = relevant_ids.into_iter().map(chunk_id).collect();

            // Relies on dcg_at_k ignoring repeated ids, which keeps DCG <= ideal DCG.
            let ndcg = RetrievalMetrics::ndcg_at_k(&retrieved, &relevant, k);
            prop_assert!(ndcg >= 0.0);
            prop_assert!(ndcg <= 1.0);
        }
    }
}