second_brain_api/eval/
metrics.rs1use std::collections::HashSet;
2
3use uuid::Uuid;
4
5pub fn recall_at_k(ranked: &[Uuid], relevant: &HashSet<Uuid>, k: usize) -> f32 {
6 let hit = ranked.iter().take(k).any(|id| relevant.contains(id));
7 if hit { 1.0 } else { 0.0 }
8}
9
10pub fn mrr(ranked: &[Uuid], relevant: &HashSet<Uuid>) -> f32 {
11 ranked
12 .iter()
13 .position(|id| relevant.contains(id))
14 .map(|idx| 1.0 / (idx as f32 + 1.0))
15 .unwrap_or(0.0)
16}
17
18pub fn precision_at_k(ranked: &[Uuid], relevant: &HashSet<Uuid>, k: usize) -> f32 {
19 if k == 0 {
20 return 0.0;
21 }
22 let hits = ranked
23 .iter()
24 .take(k)
25 .filter(|id| relevant.contains(id))
26 .count();
27 hits as f32 / k as f32
28}
29
/// Fraction of entries in `flags` that are `true`.
///
/// Returns `0.0` for an empty slice so the division is never performed on a
/// zero length.
pub fn gated_out_rate(flags: &[bool]) -> f32 {
    match flags.len() {
        0 => 0.0,
        total => {
            let gated = flags.iter().copied().filter(|&flag| flag).count();
            gated as f32 / total as f32
        }
    }
}
37
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing a computed metric with an expected fraction.
    const EPS: f32 = 1e-6;

    /// Builds a list of `n` ids, each produced by `Uuid::new_v4`.
    fn ids(n: usize) -> Vec<Uuid> {
        std::iter::repeat_with(Uuid::new_v4).take(n).collect()
    }

    /// True when `actual` is within `EPS` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPS
    }

    #[test]
    fn gold_at_rank_three_scores_as_expected() {
        let ranking = ids(5);
        let gold = HashSet::from([ranking[2]]);

        assert_eq!(recall_at_k(&ranking, &gold, 1), 0.0);
        assert_eq!(recall_at_k(&ranking, &gold, 3), 1.0);
        assert_eq!(recall_at_k(&ranking, &gold, 5), 1.0);
        assert!(close(mrr(&ranking, &gold), 1.0 / 3.0));
        assert!(close(precision_at_k(&ranking, &gold, 5), 1.0 / 5.0));
    }

    #[test]
    fn no_relevant_in_ranking_scores_zero() {
        let ranking = ids(5);
        // A fresh id that was not placed in the ranking.
        let gold = HashSet::from([Uuid::new_v4()]);

        assert_eq!(recall_at_k(&ranking, &gold, 5), 0.0);
        assert_eq!(mrr(&ranking, &gold), 0.0);
        assert_eq!(precision_at_k(&ranking, &gold, 5), 0.0);
    }

    #[test]
    fn empty_ranking_scores_zero() {
        let ranking: Vec<Uuid> = Vec::new();
        let gold = HashSet::from([Uuid::new_v4()]);

        assert_eq!(recall_at_k(&ranking, &gold, 5), 0.0);
        assert_eq!(mrr(&ranking, &gold), 0.0);
        assert_eq!(precision_at_k(&ranking, &gold, 5), 0.0);
    }

    #[test]
    fn precision_counts_intersection_in_window() {
        let ranking = ids(5);
        let gold = HashSet::from([ranking[0], ranking[2]]);

        assert!(close(precision_at_k(&ranking, &gold, 5), 2.0 / 5.0));
        assert!(close(precision_at_k(&ranking, &gold, 1), 1.0));
    }

    #[test]
    fn mrr_uses_first_relevant_rank() {
        let ranking = ids(4);
        let gold = HashSet::from([ranking[1], ranking[3]]);

        assert!(close(mrr(&ranking, &gold), 1.0 / 2.0));
    }

    #[test]
    fn gated_out_rate_is_fraction_true() {
        assert!(close(gated_out_rate(&[true, false, false, false]), 0.25));
        assert_eq!(gated_out_rate(&[false, false]), 0.0);
        assert_eq!(gated_out_rate(&[true, true]), 1.0);
        assert_eq!(gated_out_rate(&[]), 0.0);
    }
}