Skip to main content

rust_mlp/
metrics.rs

//! Metrics.
//!
//! Metrics are evaluation helpers: they score predictions and do not participate in backprop.
//!
//! In this crate, metrics are computed sample by sample during evaluation and training,
//! without allocating per step.
7
8use crate::{Error, Result};
9
/// The evaluation metrics supported by this crate.
///
/// A metric only scores predictions; it plays no part in backprop.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Metric {
    /// Mean squared error.
    Mse,
    /// Mean absolute error.
    Mae,
    /// Classification accuracy.
    ///
    /// The flavour depends on the model's output width:
    /// - `output_dim == 1` → binary accuracy;
    /// - `output_dim > 1` → multiclass accuracy, using the argmax as the predicted class.
    Accuracy,
    /// Top-k accuracy for multiclass classification.
    ///
    /// Requires `output_dim > 1` and `k <= output_dim`. Note that
    /// [`Metric::validate`] only rejects `k == 0`; it receives no `output_dim`, so
    /// the width-dependent bounds are not checked there.
    TopKAccuracy {
        /// The `k` in top-k; must be non-zero.
        k: usize,
    },
}
27
28impl Metric {
29    /// Validate metric parameters.
30    pub fn validate(self) -> Result<()> {
31        match self {
32            Metric::TopKAccuracy { k } => {
33                if k == 0 {
34                    return Err(Error::InvalidConfig(
35                        "TopKAccuracy requires k > 0".to_owned(),
36                    ));
37                }
38            }
39            Metric::Mse | Metric::Mae | Metric::Accuracy => {}
40        }
41        Ok(())
42    }
43}