Skip to main content

trueno_rag/
metrics.rs

1//! Retrieval evaluation metrics
2
3use crate::ChunkId;
4use serde::{Deserialize, Serialize};
5use std::collections::HashSet;
6
/// Retrieval metrics for evaluation
///
/// Holds the scores of a single query as produced by `RetrievalMetrics::compute`.
/// The three maps are keyed by the cut-off `k` that was requested.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RetrievalMetrics {
    /// Recall@k for various k values (key = cut-off `k`)
    pub recall: std::collections::HashMap<usize, f32>,
    /// Precision@k for various k values (key = cut-off `k`)
    pub precision: std::collections::HashMap<usize, f32>,
    /// Reciprocal rank of the first relevant result for this query
    /// (averaging it over queries yields the Mean Reciprocal Rank)
    pub mrr: f32,
    /// Normalized Discounted Cumulative Gain@k (key = cut-off `k`)
    pub ndcg: std::collections::HashMap<usize, f32>,
    /// Average precision for this query
    /// (averaging it over queries yields the Mean Average Precision)
    pub map: f32,
}
21
22impl RetrievalMetrics {
23    /// Compute all metrics for a single query
24    pub fn compute(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k_values: &[usize]) -> Self {
25        let mut metrics = Self::default();
26
27        for &k in k_values {
28            metrics.recall.insert(k, Self::recall_at_k(retrieved, relevant, k));
29            metrics.precision.insert(k, Self::precision_at_k(retrieved, relevant, k));
30            metrics.ndcg.insert(k, Self::ndcg_at_k(retrieved, relevant, k));
31        }
32
33        metrics.mrr = Self::mean_reciprocal_rank(retrieved, relevant);
34        metrics.map = Self::average_precision(retrieved, relevant);
35
36        metrics
37    }
38
39    /// Compute Recall@k
40    ///
41    /// Recall@k = |relevant ∩ retrieved@k| / |relevant|
42    #[must_use]
43    pub fn recall_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
44        if relevant.is_empty() {
45            return 0.0;
46        }
47
48        let retrieved_k: HashSet<ChunkId> = retrieved.iter().take(k).copied().collect();
49        let relevant_retrieved = retrieved_k.intersection(relevant).count();
50
51        relevant_retrieved as f32 / relevant.len() as f32
52    }
53
54    /// Compute Precision@k
55    ///
56    /// Precision@k = |relevant ∩ retrieved@k| / k
57    #[must_use]
58    pub fn precision_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
59        if k == 0 {
60            return 0.0;
61        }
62
63        let retrieved_k: HashSet<ChunkId> = retrieved.iter().take(k).copied().collect();
64        let relevant_retrieved = retrieved_k.intersection(relevant).count();
65
66        relevant_retrieved as f32 / k as f32
67    }
68
69    /// Compute Mean Reciprocal Rank (MRR)
70    ///
71    /// MRR = 1 / rank of first relevant result
72    #[must_use]
73    pub fn mean_reciprocal_rank(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>) -> f32 {
74        for (rank, id) in retrieved.iter().enumerate() {
75            if relevant.contains(id) {
76                return 1.0 / (rank + 1) as f32;
77            }
78        }
79        0.0
80    }
81
82    /// Compute Normalized Discounted Cumulative Gain@k
83    ///
84    /// NDCG@k = DCG@k / IDCG@k
85    #[must_use]
86    pub fn ndcg_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
87        let dcg = Self::dcg_at_k(retrieved, relevant, k);
88        let idcg = Self::ideal_dcg_at_k(relevant.len(), k);
89
90        if idcg == 0.0 {
91            0.0
92        } else {
93            dcg / idcg
94        }
95    }
96
97    /// Compute Discounted Cumulative Gain@k
98    ///
99    /// Note: Each relevant item is counted at most once (at its first occurrence)
100    /// to ensure NDCG remains bounded by 1.0.
101    fn dcg_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
102        let mut seen = HashSet::new();
103        retrieved
104            .iter()
105            .take(k)
106            .enumerate()
107            .filter(|(_, id)| relevant.contains(id) && seen.insert(**id))
108            .map(|(rank, _)| 1.0 / (rank as f32 + 2.0).max(f32::EPSILON).log2())
109            .sum()
110    }
111
112    /// Compute Ideal DCG@k (best possible DCG)
113    fn ideal_dcg_at_k(num_relevant: usize, k: usize) -> f32 {
114        (0..num_relevant.min(k))
115            .map(|rank| 1.0 / (rank as f32 + 2.0).max(f32::EPSILON).log2())
116            .sum()
117    }
118
119    /// Compute Average Precision (AP)
120    ///
121    /// AP = (1/|relevant|) * Σ (Precision@k * rel(k))
122    #[must_use]
123    pub fn average_precision(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>) -> f32 {
124        if relevant.is_empty() {
125            return 0.0;
126        }
127
128        let mut sum_precision = 0.0;
129        let mut relevant_count = 0;
130
131        for (rank, id) in retrieved.iter().enumerate() {
132            if relevant.contains(id) {
133                relevant_count += 1;
134                sum_precision += relevant_count as f32 / (rank + 1) as f32;
135            }
136        }
137
138        sum_precision / relevant.len().max(1) as f32
139    }
140
141    /// Compute F1 score at k
142    #[must_use]
143    pub fn f1_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
144        let precision = Self::precision_at_k(retrieved, relevant, k);
145        let recall = Self::recall_at_k(retrieved, relevant, k);
146
147        if precision + recall == 0.0 {
148            0.0
149        } else {
150            2.0 * precision * recall / (precision + recall)
151        }
152    }
153
154    /// Compute Hit Rate (1 if any relevant in top-k, else 0)
155    #[must_use]
156    pub fn hit_rate_at_k(retrieved: &[ChunkId], relevant: &HashSet<ChunkId>, k: usize) -> f32 {
157        let retrieved_k: HashSet<ChunkId> = retrieved.iter().take(k).copied().collect();
158        if retrieved_k.intersection(relevant).next().is_some() {
159            1.0
160        } else {
161            0.0
162        }
163    }
164}
165
/// Aggregated metrics across multiple queries
///
/// Produced by `AggregatedMetrics::aggregate`; the maps are keyed by the
/// cut-off `k`, mirroring the per-query `RetrievalMetrics` maps.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AggregatedMetrics {
    /// Mean Recall@k (key = cut-off `k`)
    pub mean_recall: std::collections::HashMap<usize, f32>,
    /// Mean Precision@k (key = cut-off `k`)
    pub mean_precision: std::collections::HashMap<usize, f32>,
    /// Mean MRR (average of the per-query `mrr` values)
    pub mean_mrr: f32,
    /// Mean NDCG@k (key = cut-off `k`)
    pub mean_ndcg: std::collections::HashMap<usize, f32>,
    /// Mean Average Precision (MAP) (average of the per-query `map` values)
    pub map: f32,
    /// Number of queries that were aggregated
    pub query_count: usize,
}
182
183impl AggregatedMetrics {
184    /// Aggregate metrics from multiple queries
185    pub fn aggregate(metrics: &[RetrievalMetrics]) -> Self {
186        if metrics.is_empty() {
187            return Self::default();
188        }
189
190        let n = metrics.len() as f32;
191        let mut agg = Self { query_count: metrics.len(), ..Default::default() };
192
193        // Aggregate MRR and MAP
194        agg.mean_mrr = metrics.iter().map(|m| m.mrr).sum::<f32>() / n;
195        agg.map = metrics.iter().map(|m| m.map).sum::<f32>() / n;
196
197        // Aggregate k-based metrics
198        if let Some(first) = metrics.first() {
199            for &k in first.recall.keys() {
200                let mean_recall = metrics.iter().filter_map(|m| m.recall.get(&k)).sum::<f32>() / n;
201                agg.mean_recall.insert(k, mean_recall);
202
203                let mean_precision =
204                    metrics.iter().filter_map(|m| m.precision.get(&k)).sum::<f32>() / n;
205                agg.mean_precision.insert(k, mean_precision);
206
207                let mean_ndcg = metrics.iter().filter_map(|m| m.ndcg.get(&k)).sum::<f32>() / n;
208                agg.mean_ndcg.insert(k, mean_ndcg);
209            }
210        }
211
212        agg
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Absolute tolerance used by every floating-point comparison in this module.
    const TOLERANCE: f32 = 0.001;

    /// Build a deterministic `ChunkId` from a small integer.
    fn chunk_id(n: u128) -> ChunkId {
        ChunkId(uuid::Uuid::from_u128(n))
    }

    /// Ranked result list made of the given raw ids, in order.
    fn ranked(raw: &[u128]) -> Vec<ChunkId> {
        raw.iter().copied().map(chunk_id).collect()
    }

    /// Ground-truth set made of the given raw ids.
    fn relevant_of(raw: &[u128]) -> HashSet<ChunkId> {
        raw.iter().copied().map(chunk_id).collect()
    }

    /// Assert that `actual` is within `TOLERANCE` of `expected`.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOLERANCE, "expected {expected}, got {actual}");
    }

    // ============ Recall Tests ============

    #[test]
    fn test_recall_at_k_perfect() {
        let recall =
            RetrievalMetrics::recall_at_k(&ranked(&[1, 2, 3]), &relevant_of(&[1, 2, 3]), 3);
        assert_close(recall, 1.0);
    }

    #[test]
    fn test_recall_at_k_partial() {
        let recall =
            RetrievalMetrics::recall_at_k(&ranked(&[1, 4, 5]), &relevant_of(&[1, 2, 3]), 3);
        assert_close(recall, 1.0 / 3.0);
    }

    #[test]
    fn test_recall_at_k_none() {
        let recall =
            RetrievalMetrics::recall_at_k(&ranked(&[4, 5, 6]), &relevant_of(&[1, 2, 3]), 3);
        assert_close(recall, 0.0);
    }

    #[test]
    fn test_recall_at_k_empty_relevant() {
        let nothing_relevant: HashSet<ChunkId> = HashSet::new();
        let recall = RetrievalMetrics::recall_at_k(&ranked(&[1, 2]), &nothing_relevant, 2);
        assert_close(recall, 0.0);
    }

    #[test]
    fn test_recall_at_k_smaller_k() {
        let results = ranked(&[4, 1, 2]);
        let truth = relevant_of(&[1, 2]);

        // k=1 sees only chunk 4, which is not relevant.
        assert_close(RetrievalMetrics::recall_at_k(&results, &truth, 1), 0.0);

        // k=2 additionally sees chunk 1, one of the two relevant chunks.
        assert_close(RetrievalMetrics::recall_at_k(&results, &truth, 2), 0.5);
    }

    // ============ Precision Tests ============

    #[test]
    fn test_precision_at_k_perfect() {
        let precision =
            RetrievalMetrics::precision_at_k(&ranked(&[1, 2]), &relevant_of(&[1, 2]), 2);
        assert_close(precision, 1.0);
    }

    #[test]
    fn test_precision_at_k_half() {
        let precision =
            RetrievalMetrics::precision_at_k(&ranked(&[1, 4]), &relevant_of(&[1, 2]), 2);
        assert_close(precision, 0.5);
    }

    #[test]
    fn test_precision_at_k_zero() {
        let precision = RetrievalMetrics::precision_at_k(&[], &HashSet::new(), 0);
        assert_close(precision, 0.0);
    }

    // ============ MRR Tests ============

    #[test]
    fn test_mrr_first_position() {
        let mrr =
            RetrievalMetrics::mean_reciprocal_rank(&ranked(&[1, 2, 3]), &relevant_of(&[1]));
        assert_close(mrr, 1.0);
    }

    #[test]
    fn test_mrr_second_position() {
        let mrr =
            RetrievalMetrics::mean_reciprocal_rank(&ranked(&[4, 1, 3]), &relevant_of(&[1]));
        assert_close(mrr, 0.5);
    }

    #[test]
    fn test_mrr_third_position() {
        let mrr =
            RetrievalMetrics::mean_reciprocal_rank(&ranked(&[4, 5, 1]), &relevant_of(&[1]));
        assert_close(mrr, 1.0 / 3.0);
    }

    #[test]
    fn test_mrr_not_found() {
        let mrr =
            RetrievalMetrics::mean_reciprocal_rank(&ranked(&[4, 5, 6]), &relevant_of(&[1]));
        assert_close(mrr, 0.0);
    }

    // ============ NDCG Tests ============

    #[test]
    fn test_ndcg_perfect_order() {
        let ndcg = RetrievalMetrics::ndcg_at_k(&ranked(&[1, 2]), &relevant_of(&[1, 2]), 2);
        assert_close(ndcg, 1.0);
    }

    #[test]
    fn test_ndcg_no_relevant() {
        let ndcg = RetrievalMetrics::ndcg_at_k(&ranked(&[3, 4]), &relevant_of(&[1, 2]), 2);
        assert_close(ndcg, 0.0);
    }

    #[test]
    fn test_ndcg_empty_relevant() {
        let nothing_relevant: HashSet<ChunkId> = HashSet::new();
        let ndcg = RetrievalMetrics::ndcg_at_k(&ranked(&[1, 2]), &nothing_relevant, 2);
        assert_close(ndcg, 0.0);
    }

    // ============ Average Precision Tests ============

    #[test]
    fn test_ap_perfect() {
        let ap =
            RetrievalMetrics::average_precision(&ranked(&[1, 2, 3]), &relevant_of(&[1, 2, 3]));
        // AP = (1/3) * (1/1 + 2/2 + 3/3) = 1.0
        assert_close(ap, 1.0);
    }

    #[test]
    fn test_ap_interleaved() {
        let ap = RetrievalMetrics::average_precision(&ranked(&[1, 4, 2]), &relevant_of(&[1, 2]));
        // AP = (1/2) * (1/1 + 2/3) = 5/6 ≈ 0.833
        assert_close(ap, 5.0 / 6.0);
    }

    #[test]
    fn test_ap_empty_relevant() {
        let nothing_relevant: HashSet<ChunkId> = HashSet::new();
        let ap = RetrievalMetrics::average_precision(&ranked(&[1, 2]), &nothing_relevant);
        assert_close(ap, 0.0);
    }

    // ============ F1 Tests ============

    #[test]
    fn test_f1_perfect() {
        let f1 = RetrievalMetrics::f1_at_k(&ranked(&[1, 2]), &relevant_of(&[1, 2]), 2);
        assert_close(f1, 1.0);
    }

    #[test]
    fn test_f1_zero() {
        let f1 = RetrievalMetrics::f1_at_k(&ranked(&[3, 4]), &relevant_of(&[1, 2]), 2);
        assert_close(f1, 0.0);
    }

    // ============ Hit Rate Tests ============

    #[test]
    fn test_hit_rate_hit() {
        let hr = RetrievalMetrics::hit_rate_at_k(&ranked(&[3, 1, 4]), &relevant_of(&[1, 2]), 3);
        assert_close(hr, 1.0);
    }

    #[test]
    fn test_hit_rate_miss() {
        let hr = RetrievalMetrics::hit_rate_at_k(&ranked(&[3, 4]), &relevant_of(&[1, 2]), 2);
        assert_close(hr, 0.0);
    }

    // ============ Compute Tests ============

    #[test]
    fn test_compute_all_metrics() {
        let metrics = RetrievalMetrics::compute(
            &ranked(&[1, 4, 2, 5]),
            &relevant_of(&[1, 2, 3]),
            &[1, 2, 5, 10],
        );

        assert!(!metrics.recall.is_empty());
        assert!(!metrics.precision.is_empty());
        assert!(!metrics.ndcg.is_empty());
        assert!(metrics.mrr > 0.0);
    }

    // ============ Aggregation Tests ============

    #[test]
    fn test_aggregate_empty() {
        assert_eq!(AggregatedMetrics::aggregate(&[]).query_count, 0);
    }

    #[test]
    fn test_aggregate_single() {
        let per_query =
            RetrievalMetrics::compute(&ranked(&[1, 2]), &relevant_of(&[1, 2]), &[1, 2]);

        let agg = AggregatedMetrics::aggregate(&[per_query]);
        assert_eq!(agg.query_count, 1);
        assert_close(agg.mean_mrr, 1.0);
    }

    #[test]
    fn test_aggregate_multiple() {
        // A query whose every metric equals `score` at cut-offs 1 and 2.
        let uniform = |score: f32| RetrievalMetrics {
            mrr: score,
            map: score,
            recall: [(1, score), (2, score)].into(),
            precision: [(1, score), (2, score)].into(),
            ndcg: [(1, score), (2, score)].into(),
        };

        let agg = AggregatedMetrics::aggregate(&[uniform(1.0), uniform(0.5)]);

        assert_eq!(agg.query_count, 2);
        assert_close(agg.mean_mrr, 0.75);
        assert_close(agg.map, 0.75);
    }

    // ============ Property-Based Tests ============

    proptest! {
        #[test]
        fn prop_recall_bounded(
            retrieved_ids in prop::collection::vec(0u128..100, 1..20),
            relevant_ids in prop::collection::vec(0u128..100, 1..10),
            k in 1usize..20
        ) {
            let recall = RetrievalMetrics::recall_at_k(
                &ranked(&retrieved_ids),
                &relevant_of(&relevant_ids),
                k,
            );
            prop_assert!((0.0..=1.0).contains(&recall));
        }

        #[test]
        fn prop_precision_bounded(
            retrieved_ids in prop::collection::vec(0u128..100, 1..20),
            relevant_ids in prop::collection::vec(0u128..100, 1..10),
            k in 1usize..20
        ) {
            let precision = RetrievalMetrics::precision_at_k(
                &ranked(&retrieved_ids),
                &relevant_of(&relevant_ids),
                k,
            );
            prop_assert!((0.0..=1.0).contains(&precision));
        }

        #[test]
        fn prop_mrr_bounded(
            retrieved_ids in prop::collection::vec(0u128..100, 1..20),
            relevant_ids in prop::collection::vec(0u128..100, 1..10)
        ) {
            let mrr = RetrievalMetrics::mean_reciprocal_rank(
                &ranked(&retrieved_ids),
                &relevant_of(&relevant_ids),
            );
            prop_assert!((0.0..=1.0).contains(&mrr));
        }

        #[test]
        fn prop_ndcg_bounded(
            retrieved_ids in prop::collection::vec(0u128..100, 1..20),
            relevant_ids in prop::collection::vec(0u128..100, 1..10),
            k in 1usize..20
        ) {
            let ndcg = RetrievalMetrics::ndcg_at_k(
                &ranked(&retrieved_ids),
                &relevant_of(&relevant_ids),
                k,
            );
            prop_assert!((0.0..=1.0).contains(&ndcg));
        }
    }
}