Skip to main content

touchstone_rs/metrics/
classification.rs

1use super::{
2    Metric,
3    thresholding::{Threshold, apply_threshold},
4};
5
6/// Precision metric: ratio of true positives to predicted positives.
7pub struct Precision {
8    /// Threshold strategy used to binarize anomaly scores.
9    threshold: Box<dyn Threshold>,
10}
11
12/// Recall metric: ratio of true positives to actual positives.
13pub struct Recall {
14    /// Threshold strategy used to binarize anomaly scores.
15    threshold: Box<dyn Threshold>,
16}
17
18/// F1 Score metric: harmonic mean of precision and recall.
19pub struct F1Score {
20    /// Threshold strategy used to binarize anomaly scores.
21    threshold: Box<dyn Threshold>,
22}
23
24impl Precision {
25    /// Creates a new Precision metric with the given threshold strategy.
26    pub fn new(t: impl Threshold + 'static) -> Self {
27        Self {
28            threshold: Box::new(t),
29        }
30    }
31}
32
33impl Recall {
34    /// Creates a new Recall metric with the given threshold strategy.
35    pub fn new(t: impl Threshold + 'static) -> Self {
36        Self {
37            threshold: Box::new(t),
38        }
39    }
40}
41
42impl F1Score {
43    /// Creates a new F1 Score metric with the given threshold strategy.
44    pub fn new(t: impl Threshold + 'static) -> Self {
45        Self {
46            threshold: Box::new(t),
47        }
48    }
49}
50
/// Tallies the confusion-matrix cells needed by the metrics in this module.
///
/// Walks `labels` and `preds` pairwise and returns
/// `(true_positives, false_positives, false_negatives)`.
///
/// Pairs are formed with `zip`, so if the slices differ in length the extra
/// elements of the longer one are not counted. Any pair other than `(1, 1)`,
/// `(0, 1)` or `(1, 0)` (true negatives, or values outside `{0, 1}`) is ignored.
fn confusion(labels: &[u8], preds: &[u8]) -> (usize, usize, usize) {
    labels.iter().zip(preds).fold(
        (0, 0, 0),
        |(hits, false_alarms, misses), (&label, &pred)| match (label, pred) {
            (1, 1) => (hits + 1, false_alarms, misses),
            (0, 1) => (hits, false_alarms + 1, misses),
            (1, 0) => (hits, false_alarms, misses + 1),
            _ => (hits, false_alarms, misses),
        },
    )
}
66
67impl Metric for Precision {
68    fn name(&self) -> &str {
69        "Precision"
70    }
71    fn score(&self, labels: &[u8], scores: &[f32]) -> f64 {
72        let thresh = self.threshold.threshold(scores);
73        let preds = apply_threshold(scores, thresh);
74        let (tp, fp, _) = confusion(labels, &preds);
75        if tp + fp == 0 {
76            return 0.0;
77        }
78        tp as f64 / (tp + fp) as f64
79    }
80}
81
82impl Metric for Recall {
83    fn name(&self) -> &str {
84        "Recall"
85    }
86    fn score(&self, labels: &[u8], scores: &[f32]) -> f64 {
87        let thresh = self.threshold.threshold(scores);
88        let preds = apply_threshold(scores, thresh);
89        let (tp, _, fn_) = confusion(labels, &preds);
90        if tp + fn_ == 0 {
91            return f64::NAN;
92        }
93        tp as f64 / (tp + fn_) as f64
94    }
95}
96
97impl Metric for F1Score {
98    fn name(&self) -> &str {
99        "F1"
100    }
101    fn score(&self, labels: &[u8], scores: &[f32]) -> f64 {
102        let thresh = self.threshold.threshold(scores);
103        let preds = apply_threshold(scores, thresh);
104        let (tp, fp, fn_) = confusion(labels, &preds);
105        let denom = 2 * tp + fp + fn_;
106        if denom == 0 {
107            return 0.0;
108        }
109        2.0 * tp as f64 / denom as f64
110    }
111}