rust_mlp/metrics.rs
//! Metrics.
//!
//! Metrics are evaluation helpers (they do not participate in backprop).
//!
//! In this crate, metrics are computed sample-by-sample during evaluation/training
//! without allocating per step.
7
8use crate::{Error, Result};
9
/// Supported evaluation metrics.
///
/// `Metric` is a small `Copy` value type. It derives `Eq` and `Hash`, so it can be
/// used as a key in hash-based collections (e.g. to store per-metric results).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Metric {
    /// Mean squared error.
    Mse,
    /// Mean absolute error.
    Mae,
    /// Classification accuracy.
    ///
    /// - For `output_dim == 1`: binary accuracy.
    /// - For `output_dim > 1`: multiclass accuracy (argmax).
    Accuracy,
    /// Top-k accuracy for multiclass classification.
    ///
    /// This metric requires `output_dim > 1` and `k <= output_dim`.
    TopKAccuracy {
        /// How many of the highest-scoring classes are considered.
        ///
        /// Must be non-zero; see [`Metric::validate`].
        k: usize,
    },
}
27
28impl Metric {
29 /// Validate metric parameters.
30 pub fn validate(self) -> Result<()> {
31 match self {
32 Metric::TopKAccuracy { k } => {
33 if k == 0 {
34 return Err(Error::InvalidConfig(
35 "TopKAccuracy requires k > 0".to_owned(),
36 ));
37 }
38 }
39 Metric::Mse | Metric::Mae | Metric::Accuracy => {}
40 }
41 Ok(())
42 }
43}