scirs2_neural/utils/metrics.rs

//! Evaluation metrics for neural networks

use crate::error::{NeuralError, Result};
use scirs2_core::ndarray::{Array, Ix1, Ix2, IxDyn, Zip};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

/// Trait for metrics that evaluate model performance
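///
/// # Examples
/// A minimal sketch of a custom metric. `MeanAbsoluteError` is illustrative,
/// not part of this crate, and the `use` paths assume the same re-exports the
/// free-function doctests in this module rely on:
/// ```
/// use scirs2_neural::utils::Metric;
/// use scirs2_neural::error::Result;
/// use scirs2_core::ndarray::{Array, IxDyn};
/// use scirs2_core::numeric::Float;
///
/// // Hypothetical metric, defined here only to illustrate the trait
/// struct MeanAbsoluteError;
///
/// impl<F: Float> Metric<F> for MeanAbsoluteError {
///     fn compute(&self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>) -> Result<F> {
///         let n = F::from(predictions.len()).unwrap_or(F::one());
///         // Sum of absolute differences, averaged over all elements
///         let sum = predictions
///             .iter()
///             .zip(targets.iter())
///             .fold(F::zero(), |acc, (&p, &t)| acc + (p - t).abs());
///         Ok(sum / n)
///     }
/// }
/// ```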
pub trait Metric<F: Float> {
    /// Compute the metric for the given predictions and targets
    fn compute(&self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>) -> Result<F>;
}

/// Mean squared error metric
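///
/// # Examples
/// A minimal usage sketch (assumes this type is re-exported from
/// `scirs2_neural::utils`, like the free functions in this module):
/// ```
/// // import path assumed to mirror the free-function doctests below
/// use scirs2_neural::utils::{Metric, MeanSquaredError};
/// use scirs2_core::ndarray::arr1;
///
/// let metric = MeanSquaredError;
/// let predictions = arr1(&[1.0f64, 2.0, 3.0]).into_dyn();
/// let targets = arr1(&[1.0f64, 3.0, 5.0]).into_dyn();
/// // (0^2 + 1^2 + 2^2) / 3 = 5/3
/// let mse = metric.compute(&predictions, &targets).unwrap();
/// assert!((mse - 5.0 / 3.0).abs() < 1e-10);
/// ```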
pub struct MeanSquaredError;

impl<F: Float + Debug> Metric<F> for MeanSquaredError {
    fn compute(&self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>) -> Result<F> {
        mean_squared_error(predictions, targets)
    }
}

/// Binary accuracy metric for classification
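///
/// # Examples
/// A minimal usage sketch (assumes this type is re-exported from
/// `scirs2_neural::utils`, like the free functions in this module):
/// ```
/// // import path assumed to mirror the free-function doctests below
/// use scirs2_neural::utils::{Metric, BinaryAccuracy};
/// use scirs2_core::ndarray::arr1;
///
/// let metric = BinaryAccuracy::new(0.5);
/// let predictions = arr1(&[0.8f64, 0.2, 0.6, 0.4]).into_dyn();
/// let targets = arr1(&[1.0f64, 0.0, 0.0, 0.0]).into_dyn();
/// // 0.6 is thresholded to 1.0 but the target is 0.0, so 3 of 4 are correct
/// let accuracy = metric.compute(&predictions, &targets).unwrap();
/// assert!((accuracy - 0.75).abs() < 1e-10);
/// ```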
pub struct BinaryAccuracy {
    /// Threshold for classifying a prediction as positive
    pub threshold: f64,
}

impl BinaryAccuracy {
    /// Create a new binary accuracy metric with the given threshold
    pub fn new(threshold: f64) -> Self {
        Self { threshold }
    }
}

impl Default for BinaryAccuracy {
    fn default() -> Self {
        Self { threshold: 0.5 }
    }
}

impl<F: Float> Metric<F> for BinaryAccuracy {
    fn compute(&self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>) -> Result<F> {
        if predictions.shape() != targets.shape() {
            return Err(NeuralError::InferenceError(format!(
                "Predictions shape {:?} does not match targets shape {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let threshold = F::from(self.threshold).ok_or_else(|| {
            NeuralError::Other("Could not convert threshold to the required float type".to_string())
        })?;

        let mut correct = 0;
        let n_elements = predictions.len();
        for (pred, target) in predictions.iter().zip(targets.iter()) {
            let pred_class = if *pred >= threshold {
                F::one()
            } else {
                F::zero()
            };
            if pred_class == *target {
                correct += 1;
            }
        }

        Ok(F::from(correct).unwrap_or(F::zero()) / F::from(n_elements).unwrap_or(F::one()))
    }
}

/// Categorical accuracy metric for multi-class classification
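///
/// # Examples
/// A minimal usage sketch (assumes this type is re-exported from
/// `scirs2_neural::utils`, like the free functions in this module):
/// ```
/// // import path assumed to mirror the free-function doctests below
/// use scirs2_neural::utils::{Metric, CategoricalAccuracy};
/// use scirs2_core::ndarray::arr2;
///
/// let metric = CategoricalAccuracy;
/// let predictions = arr2(&[[0.7f64, 0.3], [0.4, 0.6]]).into_dyn();
/// let targets = arr2(&[[1.0f64, 0.0], [1.0, 0.0]]).into_dyn();
/// // Row 0 predicts class 0 (correct); row 1 predicts class 1 (incorrect)
/// let accuracy = metric.compute(&predictions, &targets).unwrap();
/// assert!((accuracy - 0.5).abs() < 1e-10);
/// ```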
pub struct CategoricalAccuracy;

impl<F: Float + Debug> Metric<F> for CategoricalAccuracy {
    fn compute(&self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>) -> Result<F> {
        if predictions.ndim() == 2 && targets.ndim() == 2 {
            let predictions = predictions
                .to_owned()
                .into_dimensionality::<Ix2>()
                .map_err(|e| NeuralError::Other(format!("Invalid predictions shape: {e}")))?;
            let targets = targets
                .to_owned()
                .into_dimensionality::<Ix2>()
                .map_err(|e| NeuralError::Other(format!("Invalid targets shape: {e}")))?;
            categorical_accuracy(&predictions, &targets)
        } else {
            Err(NeuralError::Other(
                "Predictions and targets must be exactly 2-dimensional for categorical accuracy"
                    .to_string(),
            ))
        }
    }
}

/// Coefficient of determination (R²) metric
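///
/// Computed as `R² = 1 - SS_res / SS_tot`; the value can be negative when the
/// model performs worse than always predicting the target mean.
///
/// # Examples
/// A minimal usage sketch (assumes this type is re-exported from
/// `scirs2_neural::utils`, like the free functions in this module):
/// ```
/// // import path assumed to mirror the free-function doctests below
/// use scirs2_neural::utils::{Metric, R2Score};
/// use scirs2_core::ndarray::arr1;
///
/// let metric = R2Score;
/// let predictions = arr1(&[2.5f64, 0.0, 2.0, 8.0]).into_dyn();
/// let targets = arr1(&[3.0f64, -0.5, 2.0, 7.0]).into_dyn();
/// let r2 = metric.compute(&predictions, &targets).unwrap();
/// assert!(r2 > 0.9f64); // ≈ 0.9486
/// ```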
pub struct R2Score;

impl<F: Float> Metric<F> for R2Score {
    fn compute(&self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>) -> Result<F> {
        if predictions.shape() != targets.shape() {
            return Err(NeuralError::InferenceError(format!(
                "Predictions shape {:?} does not match targets shape {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n_elements = F::from(targets.len()).unwrap_or(F::one());

        // Calculate mean of targets
        let target_mean = targets.iter().fold(F::zero(), |acc, &x| acc + x) / n_elements;

        // Calculate total sum of squares
        let mut ss_tot = F::zero();
        for target in targets.iter() {
            let diff = *target - target_mean;
            ss_tot = ss_tot + diff * diff;
        }

        // Calculate residual sum of squares
        let mut ss_res = F::zero();
        for (pred, target) in predictions.iter().zip(targets.iter()) {
            let diff = *target - *pred;
            ss_res = ss_res + diff * diff;
        }

        // Calculate R²
        let r2 = F::one() - ss_res / ss_tot;

        // R² can be negative if the model is worse than predicting the mean
        Ok(r2)
    }
}

/// Calculate the mean squared error between predictions and targets
///
/// # Arguments
/// * `predictions` - Predicted values
/// * `targets` - Target values
///
/// # Returns
/// * The mean squared error
///
/// # Examples
/// ```
/// use scirs2_neural::utils::mean_squared_error;
/// use scirs2_core::ndarray::arr1;
///
/// let predictions = arr1(&[1.0f64, 2.0, 3.0]).into_dyn();
/// let targets = arr1(&[1.5f64, 1.8, 2.5]).into_dyn();
/// let mse = mean_squared_error(&predictions, &targets).unwrap();
/// assert!(mse > 0.0f64);
/// ```
#[allow(dead_code)]
pub fn mean_squared_error<F: Float + Debug>(
    predictions: &Array<F, IxDyn>,
    targets: &Array<F, IxDyn>,
) -> Result<F> {
    if predictions.shape() != targets.shape() {
        return Err(NeuralError::InferenceError(format!(
            "Shape mismatch in mean_squared_error: predictions {:?} vs targets {:?}",
            predictions.shape(),
            targets.shape()
        )));
    }

    let n = F::from(predictions.len())
        .ok_or_else(|| NeuralError::Other("Could not convert array length to float".to_string()))?;

    let mut sum_squared_diff = F::zero();
    for (p, t) in predictions.iter().zip(targets.iter()) {
        let diff = *p - *t;
        sum_squared_diff = sum_squared_diff + diff * diff;
    }

    Ok(sum_squared_diff / n)
}

/// Calculate the binary accuracy between predictions and targets
///
/// # Arguments
/// * `predictions` - Predicted values (should be between 0 and 1)
/// * `targets` - Target values (should be either 0 or 1)
/// * `threshold` - Values at or above this threshold are classified as positive (commonly 0.5)
///
/// # Returns
/// * The accuracy (proportion of correct predictions)
///
/// # Examples
/// ```
/// use scirs2_neural::utils::binary_accuracy;
/// use scirs2_core::ndarray::arr1;
///
/// let predictions = arr1(&[0.7f64, 0.3, 0.8, 0.2]);
/// let targets = arr1(&[1.0f64, 0.0, 1.0, 0.0]);
/// let accuracy = binary_accuracy(&predictions, &targets, 0.5f64).unwrap();
/// assert_eq!(accuracy, 1.0f64); // All predictions are correct
/// ```
#[allow(dead_code)]
pub fn binary_accuracy<F: Float + Debug>(
    predictions: &Array<F, Ix1>,
    targets: &Array<F, Ix1>,
    threshold: F,
) -> Result<F> {
    if predictions.shape() != targets.shape() {
        return Err(NeuralError::InferenceError(format!(
            "Shape mismatch in binary_accuracy: predictions {:?} vs targets {:?}",
            predictions.shape(),
            targets.shape()
        )));
    }

    let n = F::from(predictions.len()).ok_or_else(|| {
        NeuralError::InferenceError("Could not convert array length to float".to_string())
    })?;

    let mut correct = F::zero();
    Zip::from(predictions).and(targets).for_each(|&p, &t| {
        // Predictions at or above the threshold count as the positive class
        let pred_class = if p >= threshold { F::one() } else { F::zero() };
        if pred_class == t {
            correct = correct + F::one();
        }
    });

    Ok(correct / n)
}

/// Calculate the categorical accuracy between predictions and targets
///
/// # Arguments
/// * `predictions` - Predicted class probabilities (each row sums to 1)
/// * `targets` - One-hot encoded target classes (each row has a single 1)
///
/// # Returns
/// * The accuracy (proportion of correct predictions)
///
/// # Examples
/// ```
/// use scirs2_neural::utils::categorical_accuracy;
/// use scirs2_core::ndarray::arr2;
///
/// let predictions = arr2(&[
///     [0.7f64, 0.2, 0.1],  // Predicted class: 0
///     [0.3f64, 0.6, 0.1],  // Predicted class: 1
///     [0.2f64, 0.3, 0.5]   // Predicted class: 2
/// ]);
/// let targets = arr2(&[
///     [1.0f64, 0.0, 0.0],  // True class: 0
///     [0.0f64, 1.0, 0.0],  // True class: 1
///     [0.0f64, 0.0, 1.0]   // True class: 2
/// ]);
/// let accuracy = categorical_accuracy(&predictions, &targets).unwrap();
/// assert_eq!(accuracy, 1.0f64); // All predictions are correct
/// ```
#[allow(dead_code)]
pub fn categorical_accuracy<F: Float + Debug>(
    predictions: &Array<F, Ix2>,
    targets: &Array<F, Ix2>,
) -> Result<F> {
    if predictions.shape() != targets.shape() {
        return Err(NeuralError::InferenceError(format!(
            "Shape mismatch in categorical_accuracy: predictions {:?} vs targets {:?}",
            predictions.shape(),
            targets.shape()
        )));
    }

    let n = F::from(predictions.shape()[0]).ok_or_else(|| {
        NeuralError::InferenceError("Could not convert sample count to float".to_string())
    })?;

    let mut correct = F::zero();
    for i in 0..predictions.shape()[0] {
        // Find predicted class (index of max value)
        let mut pred_class = 0;
        let mut max_prob = predictions[[i, 0]];
        for j in 1..predictions.shape()[1] {
            if predictions[[i, j]] > max_prob {
                max_prob = predictions[[i, j]];
                pred_class = j;
            }
        }

        // Find true class (index of 1 in one-hot encoding)
        let mut true_class = 0;
        for j in 0..targets.shape()[1] {
            if targets[[i, j]] == F::one() {
                true_class = j;
                break;
            }
        }

        // Check if prediction is correct
        if pred_class == true_class {
            correct = correct + F::one();
        }
    }

    Ok(correct / n)
}