Skip to main content

second_brain_api/eval/
metrics.rs

1use std::collections::HashSet;
2
3use uuid::Uuid;
4
5pub fn recall_at_k(ranked: &[Uuid], relevant: &HashSet<Uuid>, k: usize) -> f32 {
6    let hit = ranked.iter().take(k).any(|id| relevant.contains(id));
7    if hit { 1.0 } else { 0.0 }
8}
9
10pub fn mrr(ranked: &[Uuid], relevant: &HashSet<Uuid>) -> f32 {
11    ranked
12        .iter()
13        .position(|id| relevant.contains(id))
14        .map(|idx| 1.0 / (idx as f32 + 1.0))
15        .unwrap_or(0.0)
16}
17
18pub fn precision_at_k(ranked: &[Uuid], relevant: &HashSet<Uuid>, k: usize) -> f32 {
19    if k == 0 {
20        return 0.0;
21    }
22    let hits = ranked
23        .iter()
24        .take(k)
25        .filter(|id| relevant.contains(id))
26        .count();
27    hits as f32 / k as f32
28}
29
/// Fraction of entries in `flags` that are `true`.
///
/// Returns `0.0` for an empty slice rather than dividing by zero.
pub fn gated_out_rate(flags: &[bool]) -> f32 {
    match flags.len() {
        0 => 0.0,
        total => {
            let gated: usize = flags.iter().map(|&flag| usize::from(flag)).sum();
            gated as f32 / total as f32
        }
    }
}
37
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing fractional metric values.
    const EPSILON: f32 = 1e-6;

    /// Produces `n` freshly generated random ids to act as a ranking.
    fn ids(n: usize) -> Vec<Uuid> {
        std::iter::repeat_with(Uuid::new_v4).take(n).collect()
    }

    /// Collects the given ids into a relevance set.
    fn relevant_set(members: &[Uuid]) -> HashSet<Uuid> {
        members.iter().copied().collect()
    }

    /// Asserts that `actual` is within `EPSILON` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < EPSILON,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn gold_at_rank_three_scores_as_expected() {
        let ranked = ids(5);
        let relevant = relevant_set(&[ranked[2]]);

        // The single gold id sits at rank 3, so it only counts once k >= 3.
        assert_eq!(recall_at_k(&ranked, &relevant, 1), 0.0);
        assert_eq!(recall_at_k(&ranked, &relevant, 3), 1.0);
        assert_eq!(recall_at_k(&ranked, &relevant, 5), 1.0);
        assert_close(mrr(&ranked, &relevant), 1.0 / 3.0);
        assert_close(precision_at_k(&ranked, &relevant, 5), 1.0 / 5.0);
    }

    #[test]
    fn no_relevant_in_ranking_scores_zero() {
        let ranked = ids(5);
        // A fresh id that is not part of the ranking.
        let relevant = relevant_set(&[Uuid::new_v4()]);

        assert_eq!(recall_at_k(&ranked, &relevant, 5), 0.0);
        assert_eq!(mrr(&ranked, &relevant), 0.0);
        assert_eq!(precision_at_k(&ranked, &relevant, 5), 0.0);
    }

    #[test]
    fn empty_ranking_scores_zero() {
        let ranked = ids(0);
        let relevant = relevant_set(&[Uuid::new_v4()]);

        assert_eq!(recall_at_k(&ranked, &relevant, 5), 0.0);
        assert_eq!(mrr(&ranked, &relevant), 0.0);
        assert_eq!(precision_at_k(&ranked, &relevant, 5), 0.0);
    }

    #[test]
    fn precision_counts_intersection_in_window() {
        let ranked = ids(5);
        let relevant = relevant_set(&[ranked[0], ranked[2]]);

        assert_close(precision_at_k(&ranked, &relevant, 5), 2.0 / 5.0);
        assert_close(precision_at_k(&ranked, &relevant, 1), 1.0);
    }

    #[test]
    fn mrr_uses_first_relevant_rank() {
        let ranked = ids(4);
        let relevant = relevant_set(&[ranked[1], ranked[3]]);

        // Ranks 2 and 4 are relevant; only the earlier one determines the score.
        assert_close(mrr(&ranked, &relevant), 1.0 / 2.0);
    }

    #[test]
    fn gated_out_rate_is_fraction_true() {
        assert_close(gated_out_rate(&[true, false, false, false]), 0.25);
        assert_eq!(gated_out_rate(&[false, false]), 0.0);
        assert_eq!(gated_out_rate(&[true, true]), 1.0);
        assert_eq!(gated_out_rate(&[]), 0.0);
    }
}