touchstone_rs/metrics/
classification.rs1use super::{
2 Metric,
3 thresholding::{Threshold, apply_threshold},
4};
5
6pub struct Precision {
8 threshold: Box<dyn Threshold>,
10}
11
12pub struct Recall {
14 threshold: Box<dyn Threshold>,
16}
17
18pub struct F1Score {
20 threshold: Box<dyn Threshold>,
22}
23
24impl Precision {
25 pub fn new(t: impl Threshold + 'static) -> Self {
27 Self {
28 threshold: Box::new(t),
29 }
30 }
31}
32
33impl Recall {
34 pub fn new(t: impl Threshold + 'static) -> Self {
36 Self {
37 threshold: Box::new(t),
38 }
39 }
40}
41
42impl F1Score {
43 pub fn new(t: impl Threshold + 'static) -> Self {
45 Self {
46 threshold: Box::new(t),
47 }
48 }
49}
50
/// Counts `(tp, fp, fn)` for paired ground-truth labels and 0/1 predictions.
///
/// `labels[i]` is the ground truth and `preds[i]` the prediction for sample `i`.
/// Only the values `0` and `1` are recognised. A pair with any other value is
/// ignored, and so are true negatives `(0, 0)`, because none of the metrics in
/// this module need them.
///
/// Both slices must have the same length. `zip` would otherwise silently score
/// only the common prefix. A mismatch is treated as a caller bug and caught in
/// debug builds.
fn confusion(labels: &[u8], preds: &[u8]) -> (usize, usize, usize) {
    debug_assert_eq!(
        labels.len(),
        preds.len(),
        "labels and predictions must have the same length"
    );
    labels
        .iter()
        .zip(preds)
        .fold((0, 0, 0), |(tp, fp, fn_), (&label, &pred)| match (label, pred) {
            (1, 1) => (tp + 1, fp, fn_),
            (0, 1) => (tp, fp + 1, fn_),
            (1, 0) => (tp, fp, fn_ + 1),
            _ => (tp, fp, fn_),
        })
}
66
67impl Metric for Precision {
68 fn name(&self) -> &str {
69 "Precision"
70 }
71 fn score(&self, labels: &[u8], scores: &[f32]) -> f64 {
72 let thresh = self.threshold.threshold(scores);
73 let preds = apply_threshold(scores, thresh);
74 let (tp, fp, _) = confusion(labels, &preds);
75 if tp + fp == 0 {
76 return 0.0;
77 }
78 tp as f64 / (tp + fp) as f64
79 }
80}
81
82impl Metric for Recall {
83 fn name(&self) -> &str {
84 "Recall"
85 }
86 fn score(&self, labels: &[u8], scores: &[f32]) -> f64 {
87 let thresh = self.threshold.threshold(scores);
88 let preds = apply_threshold(scores, thresh);
89 let (tp, _, fn_) = confusion(labels, &preds);
90 if tp + fn_ == 0 {
91 return f64::NAN;
92 }
93 tp as f64 / (tp + fn_) as f64
94 }
95}
96
97impl Metric for F1Score {
98 fn name(&self) -> &str {
99 "F1"
100 }
101 fn score(&self, labels: &[u8], scores: &[f32]) -> f64 {
102 let thresh = self.threshold.threshold(scores);
103 let preds = apply_threshold(scores, thresh);
104 let (tp, fp, fn_) = confusion(labels, &preds);
105 let denom = 2 * tp + fp + fn_;
106 if denom == 0 {
107 return 0.0;
108 }
109 2.0 * tp as f64 / denom as f64
110 }
111}