// scry_learn/metrics/classification.rs
// SPDX-License-Identifier: MIT OR Apache-2.0
//! Classification metrics: accuracy, precision, recall, F1, confusion matrix.

use std::collections::HashMap;
use std::fmt;
6
/// Averaging strategy for multi-class metrics.
///
/// Controls how per-class precision/recall/F1 values are reduced to a
/// single number. Marked `#[non_exhaustive]` so new strategies (e.g.
/// micro-averaging) can be added without a breaking change.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum Average {
    /// Binary classification (positive class = 1.0).
    Binary,
    /// Unweighted mean across all classes.
    Macro,
    /// Weighted mean by class support (number of true instances).
    Weighted,
}

/// A confusion matrix.
///
/// Rows index the true class and columns the predicted class; both are
/// ordered by ascending class id (see `confusion_matrix`).
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct ConfusionMatrix {
    /// The matrix: `matrix[true_class][predicted_class]`.
    pub matrix: Vec<Vec<usize>>,
    /// Class labels, parallel to the matrix rows/columns.
    pub labels: Vec<String>,
}

/// Per-class metrics.
///
/// One instance per class in a `ClassificationReport`, plus the
/// macro/weighted aggregate rows.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct ClassMetrics {
    /// Precision for this class: TP / (TP + FP).
    pub precision: f64,
    /// Recall for this class: TP / (TP + FN).
    pub recall: f64,
    /// F1-score for this class (harmonic mean of precision and recall).
    pub f1: f64,
    /// Number of true instances (support).
    pub support: usize,
}

/// A full classification report.
///
/// Mirrors the shape of sklearn's `classification_report`: per-class
/// rows plus accuracy, macro-average, and weighted-average summaries.
/// Rendered as a text table via its `Display` impl.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct ClassificationReport {
    /// Overall accuracy.
    pub accuracy: f64,
    /// Per-class metrics, keyed by class label.
    pub per_class: Vec<(String, ClassMetrics)>,
    /// Macro-averaged metrics (unweighted mean over classes).
    pub macro_avg: ClassMetrics,
    /// Weighted-averaged metrics (mean weighted by class support).
    pub weighted_avg: ClassMetrics,
    /// Total number of samples.
    pub total_support: usize,
}

59impl fmt::Display for ClassificationReport {
60    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61        writeln!(
62            f,
63            "{:>15} {:>10} {:>10} {:>10} {:>10}",
64            "", "precision", "recall", "f1-score", "support"
65        )?;
66        writeln!(f)?;
67        for (label, m) in &self.per_class {
68            writeln!(
69                f,
70                "{:>15} {:>10.4} {:>10.4} {:>10.4} {:>10}",
71                label, m.precision, m.recall, m.f1, m.support
72            )?;
73        }
74        writeln!(f)?;
75        writeln!(
76            f,
77            "{:>15} {:>10.4} {:>10.4} {:>10.4} {:>10}",
78            "accuracy", "", "", self.accuracy, self.total_support
79        )?;
80        writeln!(
81            f,
82            "{:>15} {:>10.4} {:>10.4} {:>10.4} {:>10}",
83            "macro avg",
84            self.macro_avg.precision,
85            self.macro_avg.recall,
86            self.macro_avg.f1,
87            self.total_support
88        )?;
89        writeln!(
90            f,
91            "{:>15} {:>10.4} {:>10.4} {:>10.4} {:>10}",
92            "weighted avg",
93            self.weighted_avg.precision,
94            self.weighted_avg.recall,
95            self.weighted_avg.f1,
96            self.total_support
97        )?;
98        Ok(())
99    }
100}
101
/// Compute accuracy: fraction of correct predictions.
///
/// Labels are compared with a small tolerance (1e-6) because they are
/// carried as `f64`. Returns 0.0 for empty input.
pub fn accuracy(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let n = y_true.len();
    if n == 0 {
        return 0.0;
    }
    let mut correct = 0usize;
    for (truth, pred) in y_true.iter().zip(y_pred) {
        if (truth - pred).abs() < 1e-6 {
            correct += 1;
        }
    }
    correct as f64 / n as f64
}

115/// Compute precision from a pre-built confusion matrix.
116fn precision_from_cm(cm: &ConfusionMatrix, avg: Average) -> f64 {
117    let n = cm.matrix.len();
118    match avg {
119        Average::Binary => {
120            let tp = if n >= 2 { cm.matrix[1][1] } else { 0 };
121            let fp = (0..n)
122                .map(|i| if i == 1 { 0 } else { cm.matrix[i][1] })
123                .sum::<usize>();
124            if tp + fp == 0 {
125                0.0
126            } else {
127                tp as f64 / (tp + fp) as f64
128            }
129        }
130        Average::Macro => {
131            let mut total = 0.0;
132            for c in 0..n {
133                let tp = cm.matrix[c][c];
134                let fp: usize = (0..n)
135                    .map(|i| if i == c { 0 } else { cm.matrix[i][c] })
136                    .sum();
137                total += if tp + fp == 0 {
138                    0.0
139                } else {
140                    tp as f64 / (tp + fp) as f64
141                };
142            }
143            total / n as f64
144        }
145        Average::Weighted => {
146            let mut total = 0.0;
147            let mut total_support = 0;
148            for c in 0..n {
149                let support: usize = cm.matrix[c].iter().sum();
150                let tp = cm.matrix[c][c];
151                let fp: usize = (0..n)
152                    .map(|i| if i == c { 0 } else { cm.matrix[i][c] })
153                    .sum();
154                let p = if tp + fp == 0 {
155                    0.0
156                } else {
157                    tp as f64 / (tp + fp) as f64
158                };
159                total += p * support as f64;
160                total_support += support;
161            }
162            if total_support == 0 {
163                0.0
164            } else {
165                total / total_support as f64
166            }
167        }
168    }
169}
170
171/// Compute recall from a pre-built confusion matrix.
172fn recall_from_cm(cm: &ConfusionMatrix, avg: Average) -> f64 {
173    let n = cm.matrix.len();
174    match avg {
175        Average::Binary => {
176            let tp = if n >= 2 { cm.matrix[1][1] } else { 0 };
177            let fn_ = if n >= 2 {
178                (0..n)
179                    .map(|j| if j == 1 { 0 } else { cm.matrix[1][j] })
180                    .sum::<usize>()
181            } else {
182                0
183            };
184            if tp + fn_ == 0 {
185                0.0
186            } else {
187                tp as f64 / (tp + fn_) as f64
188            }
189        }
190        Average::Macro => {
191            let mut total = 0.0;
192            for c in 0..n {
193                let tp = cm.matrix[c][c];
194                let support: usize = cm.matrix[c].iter().sum();
195                total += if support == 0 {
196                    0.0
197                } else {
198                    tp as f64 / support as f64
199                };
200            }
201            total / n as f64
202        }
203        Average::Weighted => {
204            let mut total = 0.0;
205            let mut total_support = 0;
206            for c in 0..n {
207                let support: usize = cm.matrix[c].iter().sum();
208                let tp = cm.matrix[c][c];
209                let r = if support == 0 {
210                    0.0
211                } else {
212                    tp as f64 / support as f64
213                };
214                total += r * support as f64;
215                total_support += support;
216            }
217            if total_support == 0 {
218                0.0
219            } else {
220                total / total_support as f64
221            }
222        }
223    }
224}
225
226/// Compute precision (builds confusion matrix internally).
227pub fn precision(y_true: &[f64], y_pred: &[f64], avg: Average) -> f64 {
228    let cm = confusion_matrix(y_true, y_pred);
229    precision_from_cm(&cm, avg)
230}
231
232/// Compute recall (builds confusion matrix internally).
233pub fn recall(y_true: &[f64], y_pred: &[f64], avg: Average) -> f64 {
234    let cm = confusion_matrix(y_true, y_pred);
235    recall_from_cm(&cm, avg)
236}
237
238/// Compute F1 score.
239///
240/// For `Binary`, computes `2 * precision * recall / (precision + recall)`.
241/// For `Macro`, computes per-class F1 scores then averages (matching sklearn).
242/// For `Weighted`, computes per-class F1 scores then takes a support-weighted average.
243pub fn f1_score(y_true: &[f64], y_pred: &[f64], avg: Average) -> f64 {
244    let cm = confusion_matrix(y_true, y_pred);
245    let n = cm.matrix.len();
246
247    match avg {
248        Average::Binary => {
249            let p = precision_from_cm(&cm, Average::Binary);
250            let r = recall_from_cm(&cm, Average::Binary);
251            if p + r == 0.0 {
252                0.0
253            } else {
254                2.0 * p * r / (p + r)
255            }
256        }
257        Average::Macro => {
258            let mut total_f1 = 0.0;
259            for c in 0..n {
260                let tp = cm.matrix[c][c];
261                let fp: usize = (0..n)
262                    .map(|i| if i == c { 0 } else { cm.matrix[i][c] })
263                    .sum();
264                let support: usize = cm.matrix[c].iter().sum();
265                let p = if tp + fp == 0 {
266                    0.0
267                } else {
268                    tp as f64 / (tp + fp) as f64
269                };
270                let r = if support == 0 {
271                    0.0
272                } else {
273                    tp as f64 / support as f64
274                };
275                total_f1 += if p + r == 0.0 {
276                    0.0
277                } else {
278                    2.0 * p * r / (p + r)
279                };
280            }
281            total_f1 / n as f64
282        }
283        Average::Weighted => {
284            let mut total_f1 = 0.0;
285            let mut total_support = 0;
286            for c in 0..n {
287                let tp = cm.matrix[c][c];
288                let fp: usize = (0..n)
289                    .map(|i| if i == c { 0 } else { cm.matrix[i][c] })
290                    .sum();
291                let support: usize = cm.matrix[c].iter().sum();
292                let p = if tp + fp == 0 {
293                    0.0
294                } else {
295                    tp as f64 / (tp + fp) as f64
296                };
297                let r = if support == 0 {
298                    0.0
299                } else {
300                    tp as f64 / support as f64
301                };
302                let f = if p + r == 0.0 {
303                    0.0
304                } else {
305                    2.0 * p * r / (p + r)
306                };
307                total_f1 += f * support as f64;
308                total_support += support;
309            }
310            if total_support == 0 {
311                0.0
312            } else {
313                total_f1 / total_support as f64
314            }
315        }
316    }
317}
318
319/// Build a confusion matrix from true and predicted labels.
320pub fn confusion_matrix(y_true: &[f64], y_pred: &[f64]) -> ConfusionMatrix {
321    let mut classes: Vec<i64> = y_true
322        .iter()
323        .chain(y_pred.iter())
324        .map(|&v| v as i64)
325        .collect();
326    classes.sort_unstable();
327    classes.dedup();
328
329    let n = classes.len();
330    let mut matrix = vec![vec![0usize; n]; n];
331    let labels: Vec<String> = classes
332        .iter()
333        .map(std::string::ToString::to_string)
334        .collect();
335
336    // O(1) lookup per sample instead of O(k) linear scan.
337    let class_map: HashMap<i64, usize> = classes.iter().enumerate().map(|(i, &c)| (c, i)).collect();
338
339    for (&t, &p) in y_true.iter().zip(y_pred.iter()) {
340        let ti = class_map.get(&(t as i64)).copied().unwrap_or(0);
341        let pi = class_map.get(&(p as i64)).copied().unwrap_or(0);
342        matrix[ti][pi] += 1;
343    }
344
345    ConfusionMatrix { matrix, labels }
346}
347
348/// Generate a full classification report (like sklearn's `classification_report`).
349pub fn classification_report(y_true: &[f64], y_pred: &[f64]) -> ClassificationReport {
350    let cm = confusion_matrix(y_true, y_pred);
351    let n = cm.matrix.len();
352    let total: usize = cm.matrix.iter().flat_map(|r| r.iter()).sum();
353
354    let mut per_class = Vec::with_capacity(n);
355    let mut macro_p = 0.0;
356    let mut macro_r = 0.0;
357    let mut macro_f = 0.0;
358    let mut weighted_p = 0.0;
359    let mut weighted_r = 0.0;
360    let mut weighted_f = 0.0;
361
362    for c in 0..n {
363        let tp = cm.matrix[c][c];
364        let support: usize = cm.matrix[c].iter().sum();
365        let fp: usize = (0..n)
366            .map(|i| if i == c { 0 } else { cm.matrix[i][c] })
367            .sum();
368
369        let p = if tp + fp == 0 {
370            0.0
371        } else {
372            tp as f64 / (tp + fp) as f64
373        };
374        let r = if support == 0 {
375            0.0
376        } else {
377            tp as f64 / support as f64
378        };
379        let f = if p + r == 0.0 {
380            0.0
381        } else {
382            2.0 * p * r / (p + r)
383        };
384
385        per_class.push((
386            cm.labels[c].clone(),
387            ClassMetrics {
388                precision: p,
389                recall: r,
390                f1: f,
391                support,
392            },
393        ));
394
395        macro_p += p;
396        macro_r += r;
397        macro_f += f;
398        weighted_p += p * support as f64;
399        weighted_r += r * support as f64;
400        weighted_f += f * support as f64;
401    }
402
403    let n_f = n as f64;
404    let total_f = total as f64;
405
406    ClassificationReport {
407        accuracy: accuracy(y_true, y_pred),
408        per_class,
409        macro_avg: ClassMetrics {
410            precision: macro_p / n_f,
411            recall: macro_r / n_f,
412            f1: macro_f / n_f,
413            support: total,
414        },
415        weighted_avg: ClassMetrics {
416            precision: if total > 0 { weighted_p / total_f } else { 0.0 },
417            recall: if total > 0 { weighted_r / total_f } else { 0.0 },
418            f1: if total > 0 { weighted_f / total_f } else { 0.0 },
419            support: total,
420        },
421        total_support: total,
422    }
423}
424
/// Compute log-loss (cross-entropy loss) for probabilistic predictions.
///
/// # Arguments
/// - `y_true` — true class labels (0-indexed integers as f64)
/// - `y_prob` — predicted probability vectors, one per sample
///
/// Probabilities are clipped to `[1e-15, 1 - 1e-15]` to avoid `log(0)`.
/// A label outside its probability vector contributes nothing to the sum
/// (but still counts toward the denominator `y_true.len()`).
///
/// Fix: the loop previously indexed `y_prob[i]` for every `i` in
/// `y_true`, so a `y_prob` shorter than `y_true` panicked; extra pairs
/// beyond the shorter input are now ignored instead.
pub fn log_loss(y_true: &[f64], y_prob: &[Vec<f64>]) -> f64 {
    if y_true.is_empty() || y_prob.is_empty() {
        return 0.0;
    }
    let eps = 1e-15;
    let n = y_true.len();
    let mut total = 0.0;
    // zip (rather than indexing y_prob[i]) so a shorter y_prob cannot panic.
    for (&label, probs) in y_true.iter().zip(y_prob) {
        // NOTE: a negative f64 label saturates to 0 under `as usize`.
        let class_idx = label as usize;
        if let Some(&p) = probs.get(class_idx) {
            total -= p.clamp(eps, 1.0 - eps).ln();
        }
    }
    total / n as f64
}

449/// Balanced accuracy: mean per-class recall (macro recall).
450///
451/// Particularly useful when classes are imbalanced, since it weights
452/// each class equally regardless of support.
453pub fn balanced_accuracy(y_true: &[f64], y_pred: &[f64]) -> f64 {
454    if y_true.is_empty() {
455        return 0.0;
456    }
457    let cm = confusion_matrix(y_true, y_pred);
458    let n = cm.matrix.len();
459    let mut total_recall = 0.0;
460    for c in 0..n {
461        let support: usize = cm.matrix[c].iter().sum();
462        let tp = cm.matrix[c][c];
463        total_recall += if support == 0 {
464            0.0
465        } else {
466            tp as f64 / support as f64
467        };
468    }
469    total_recall / n as f64
470}
471
472/// Cohen's kappa coefficient — inter-rater agreement adjusted for chance.
473///
474/// Returns a value in `[-1, 1]` where 1 means perfect agreement,
475/// 0 means agreement no better than chance, and negative values
476/// mean worse than chance.
477pub fn cohen_kappa_score(y_true: &[f64], y_pred: &[f64]) -> f64 {
478    if y_true.is_empty() {
479        return 0.0;
480    }
481    let cm = confusion_matrix(y_true, y_pred);
482    let n_classes = cm.matrix.len();
483    let total: f64 = cm.matrix.iter().flat_map(|r| r.iter()).sum::<usize>() as f64;
484    if total == 0.0 {
485        return 0.0;
486    }
487
488    // Observed agreement
489    let p_o: f64 = (0..n_classes).map(|c| cm.matrix[c][c] as f64).sum::<f64>() / total;
490
491    // Expected agreement by chance
492    let mut p_e = 0.0;
493    for c in 0..n_classes {
494        let row_sum: f64 = cm.matrix[c].iter().sum::<usize>() as f64;
495        let col_sum: f64 = (0..n_classes).map(|r| cm.matrix[r][c] as f64).sum::<f64>();
496        p_e += (row_sum * col_sum) / (total * total);
497    }
498
499    if (1.0 - p_e).abs() < 1e-15 {
500        return if (p_o - 1.0).abs() < 1e-15 { 1.0 } else { 0.0 };
501    }
502
503    (p_o - p_e) / (1.0 - p_e)
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    // Every prediction matches -> accuracy 1.0.
    #[test]
    fn test_accuracy_perfect() {
        assert!((accuracy(&[0.0, 1.0, 2.0], &[0.0, 1.0, 2.0]) - 1.0).abs() < 1e-10);
    }

    // Three of four labels agree -> 0.75 (the test name is historical).
    #[test]
    fn test_accuracy_half() {
        assert!((accuracy(&[0.0, 1.0, 0.0, 1.0], &[0.0, 0.0, 0.0, 1.0]) - 0.75).abs() < 1e-10);
    }

    // One TP/TN/FP/FN each -> every cell of the 2x2 matrix is 1.
    #[test]
    fn test_confusion_matrix_binary() {
        let y_true = vec![0.0, 0.0, 1.0, 1.0];
        let y_pred = vec![0.0, 1.0, 0.0, 1.0];
        let cm = confusion_matrix(&y_true, &y_pred);
        assert_eq!(cm.matrix, vec![vec![1, 1], vec![1, 1]]);
    }

    // Smoke test: the rendered report contains the summary row labels.
    #[test]
    fn test_classification_report_display() {
        let y_true = vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0];
        let y_pred = vec![0.0, 0.0, 1.0, 2.0, 1.0, 2.0];
        let report = classification_report(&y_true, &y_pred);
        let output = format!("{report}");
        assert!(output.contains("accuracy"));
        assert!(output.contains("macro avg"));
    }

    #[test]
    fn test_f1_binary() {
        // TP=1, FP=1, FN=1 → P=0.5, R=0.5, F1=0.5
        let y_true = vec![0.0, 1.0, 1.0];
        let y_pred = vec![1.0, 1.0, 0.0];
        let f = f1_score(&y_true, &y_pred, Average::Binary);
        assert!((f - 0.5).abs() < 1e-6, "expected F1=0.5, got {f}");
    }

    // -----------------------------------------------------------------------
    // log_loss tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_log_loss_perfect() {
        // Perfect predictions → each true class has probability 1.0
        // (clipping to 1 - 1e-15 keeps the loss tiny but nonzero).
        let y_true = vec![0.0, 1.0, 2.0];
        let y_prob = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let ll = log_loss(&y_true, &y_prob);
        assert!(ll < 1e-10, "perfect log_loss should be ~0, got {ll}");
    }

    #[test]
    fn test_log_loss_random() {
        // Uniform random predictions → log_loss should be ln(3) ≈ 1.099
        let y_true = vec![0.0, 1.0, 2.0];
        let y_prob = vec![
            vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0],
            vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0],
            vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0],
        ];
        let ll = log_loss(&y_true, &y_prob);
        assert!(ll > 0.5, "random log_loss should be positive, got {ll}");
        assert!(
            (ll - 3.0_f64.ln()).abs() < 1e-6,
            "expected ~ln(3), got {ll}"
        );
    }

    // -----------------------------------------------------------------------
    // balanced_accuracy tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_balanced_accuracy_perfect() {
        let ba = balanced_accuracy(&[0.0, 1.0, 2.0], &[0.0, 1.0, 2.0]);
        assert!((ba - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_balanced_accuracy_imbalanced() {
        // 90 class-0, 10 class-1. Predict all as 0.
        let mut y_true = vec![0.0; 90];
        y_true.extend(vec![1.0; 10]);
        let y_pred = vec![0.0; 100];

        let raw = accuracy(&y_true, &y_pred);
        let bal = balanced_accuracy(&y_true, &y_pred);

        // Raw accuracy = 0.90, balanced = (1.0 + 0.0)/2 = 0.50
        assert!((raw - 0.90).abs() < 1e-10);
        assert!((bal - 0.50).abs() < 1e-10);
        assert!(bal < raw, "balanced should be lower on imbalanced data");
    }

    // -----------------------------------------------------------------------
    // cohen_kappa tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_cohen_kappa_perfect() {
        let kappa = cohen_kappa_score(&[0.0, 1.0, 2.0, 0.0, 1.0], &[0.0, 1.0, 2.0, 0.0, 1.0]);
        assert!(
            (kappa - 1.0).abs() < 1e-10,
            "perfect kappa should be 1.0, got {kappa}"
        );
    }

    #[test]
    fn test_cohen_kappa_chance() {
        // All predict class 0 on balanced data → kappa ≈ 0
        let y_true = vec![0.0, 0.0, 1.0, 1.0];
        let y_pred = vec![0.0, 0.0, 0.0, 0.0];
        let kappa = cohen_kappa_score(&y_true, &y_pred);
        assert!(
            kappa.abs() < 1e-10,
            "chance kappa should be ~0, got {kappa}"
        );
    }

    #[test]
    fn test_cohen_kappa_partial() {
        // Known: 3 agree out of 4 on binary
        let y_true = vec![0.0, 0.0, 1.0, 1.0];
        let y_pred = vec![0.0, 0.0, 0.0, 1.0];
        let kappa = cohen_kappa_score(&y_true, &y_pred);
        // p_o = 3/4 = 0.75, row/col sums: [2,2] x [3,1]
        // p_e = (2*3)/(4*4) + (2*1)/(4*4) = 6/16 + 2/16 = 0.5
        // kappa = (0.75 - 0.5) / (1.0 - 0.5) = 0.5
        assert!(
            (kappa - 0.5).abs() < 1e-10,
            "expected kappa=0.5, got {kappa}"
        );
    }
}