scirs2_neural/utils/
metrics.rs

1//! Evaluation metrics for neural networks
2
3use crate::error::{NeuralError, Result};
4use ndarray::{Array, Ix1, Ix2, IxDyn, Zip};
5use num_traits::Float;
6use std::fmt::Debug;
7
/// Trait for metrics that evaluate model performance.
///
/// Implementors compare a set of model outputs against ground-truth values
/// and reduce them to a single scalar score.
pub trait Metric<F: Float> {
    /// Compute the metric for the given predictions and targets.
    ///
    /// # Arguments
    ///
    /// * `predictions` - Model outputs, same shape as `targets`
    /// * `targets` - Ground-truth values
    ///
    /// # Errors
    ///
    /// Returns an error when the metric cannot be computed
    /// (e.g. shape mismatch between the two arrays).
    fn compute(&self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>) -> Result<F>;
}
13
/// Mean squared error metric.
///
/// Unit struct implementing [`Metric`] by delegating to the free-standing
/// [`mean_squared_error`] helper.
pub struct MeanSquaredError;

impl<F: Float + Debug> Metric<F> for MeanSquaredError {
    /// Mean of squared element-wise differences between `predictions` and
    /// `targets`; errors on shape mismatch (see [`mean_squared_error`]).
    fn compute(&self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>) -> Result<F> {
        mean_squared_error(predictions, targets)
    }
}
22
/// Binary accuracy metric for classification.
///
/// Scores at or above [`BinaryAccuracy::threshold`] are treated as the
/// positive class (1), everything below as the negative class (0).
pub struct BinaryAccuracy {
    /// Threshold for classifying a prediction as positive
    pub threshold: f64,
}

impl BinaryAccuracy {
    /// Create a new binary accuracy metric with the given threshold.
    pub fn new(threshold: f64) -> Self {
        Self { threshold }
    }
}

impl Default for BinaryAccuracy {
    /// Default metric uses the conventional 0.5 decision threshold.
    fn default() -> Self {
        Self::new(0.5)
    }
}
41
42impl<F: Float> Metric<F> for BinaryAccuracy {
43    fn compute(&self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>) -> Result<F> {
44        if predictions.shape() != targets.shape() {
45            return Err(NeuralError::InferenceError(format!(
46                "Predictions shape {:?} does not match targets shape {:?}",
47                predictions.shape(),
48                targets.shape()
49            )));
50        }
51
52        let threshold = F::from(self.threshold).ok_or_else(|| {
53            NeuralError::Other("Could not convert threshold to the required float type".to_string())
54        })?;
55
56        let mut correct = 0;
57        let n_elements = predictions.len();
58
59        for (pred, target) in predictions.iter().zip(targets.iter()) {
60            let pred_class = if *pred >= threshold {
61                F::one()
62            } else {
63                F::zero()
64            };
65            if pred_class == *target {
66                correct += 1;
67            }
68        }
69
70        Ok(F::from(correct).unwrap_or(F::zero()) / F::from(n_elements).unwrap_or(F::one()))
71    }
72}
73
74/// Categorical accuracy metric for multi-class classification
75pub struct CategoricalAccuracy;
76
77impl<F: Float + Debug> Metric<F> for CategoricalAccuracy {
78    fn compute(&self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>) -> Result<F> {
79        if predictions.ndim() >= 2 && targets.ndim() >= 2 {
80            categorical_accuracy(
81                &predictions.to_owned().into_dimensionality::<Ix2>().unwrap(),
82                &targets.to_owned().into_dimensionality::<Ix2>().unwrap(),
83            )
84        } else {
85            Err(NeuralError::Other(
86                "Predictions and targets must have at least 2 dimensions for categorical accuracy"
87                    .to_string(),
88            ))
89        }
90    }
91}
92
93/// Coefficient of determination (R²) metric
94pub struct R2Score;
95
96impl<F: Float> Metric<F> for R2Score {
97    fn compute(&self, predictions: &Array<F, IxDyn>, targets: &Array<F, IxDyn>) -> Result<F> {
98        if predictions.shape() != targets.shape() {
99            return Err(NeuralError::InferenceError(format!(
100                "Predictions shape {:?} does not match targets shape {:?}",
101                predictions.shape(),
102                targets.shape()
103            )));
104        }
105
106        let n_elements = F::from(targets.len()).unwrap_or(F::one());
107
108        // Calculate mean of targets
109        let target_mean = targets.iter().fold(F::zero(), |acc, &x| acc + x) / n_elements;
110
111        // Calculate total sum of squares
112        let mut ss_tot = F::zero();
113        for target in targets.iter() {
114            let diff = *target - target_mean;
115            ss_tot = ss_tot + diff * diff;
116        }
117
118        // Calculate residual sum of squares
119        let mut ss_res = F::zero();
120        for (pred, target) in predictions.iter().zip(targets.iter()) {
121            let diff = *target - *pred;
122            ss_res = ss_res + diff * diff;
123        }
124
125        // Calculate R²
126        let r2 = F::one() - ss_res / ss_tot;
127
128        // R² can be negative if the model is worse than predicting the mean
129        Ok(r2)
130    }
131}
132
133/// Calculate the mean squared error between predictions and targets
134///
135/// # Arguments
136///
137/// * `predictions` - Predicted values
138/// * `targets` - Target values
139///
140/// # Returns
141///
142/// * The mean squared error
143///
144/// # Examples
145///
146/// ```
147/// use scirs2_neural::utils::mean_squared_error;
148/// use ndarray::arr1;
149///
150/// let predictions = arr1(&[1.0f64, 2.0, 3.0]).into_dyn();
151/// let targets = arr1(&[1.5f64, 1.8, 2.5]).into_dyn();
152///
153/// let mse = mean_squared_error(&predictions, &targets).unwrap();
154/// assert!(mse > 0.0f64);
155/// ```
156pub fn mean_squared_error<F: Float + Debug>(
157    predictions: &Array<F, IxDyn>,
158    targets: &Array<F, IxDyn>,
159) -> Result<F> {
160    if predictions.shape() != targets.shape() {
161        return Err(NeuralError::InferenceError(format!(
162            "Shape mismatch in mean_squared_error: predictions {:?} vs targets {:?}",
163            predictions.shape(),
164            targets.shape()
165        )));
166    }
167
168    let n = F::from(predictions.len())
169        .ok_or_else(|| NeuralError::Other("Could not convert array length to float".to_string()))?;
170
171    let mut sum_squared_diff = F::zero();
172
173    for (p, t) in predictions.iter().zip(targets.iter()) {
174        let diff = *p - *t;
175        sum_squared_diff = sum_squared_diff + diff * diff;
176    }
177
178    Ok(sum_squared_diff / n)
179}
180
181/// Calculate the binary accuracy between predictions and targets
182///
183/// # Arguments
184///
185/// * `predictions` - Predicted values (should be between 0 and 1)
186/// * `targets` - Target values (should be either 0 or 1)
187/// * `threshold` - Threshold value for binary classification (default: 0.5)
188///
189/// # Returns
190///
191/// * The accuracy (proportion of correct predictions)
192///
193/// # Examples
194///
195/// ```
196/// use scirs2_neural::utils::binary_accuracy;
197/// use ndarray::arr1;
198///
199/// let predictions = arr1(&[0.7f64, 0.3, 0.8, 0.2]);
200/// let targets = arr1(&[1.0f64, 0.0, 1.0, 0.0]);
201///
202/// let accuracy = binary_accuracy(&predictions, &targets, 0.5f64).unwrap();
203/// assert_eq!(accuracy, 1.0f64); // All predictions are correct
204/// ```
205pub fn binary_accuracy<F: Float + Debug>(
206    predictions: &Array<F, Ix1>,
207    targets: &Array<F, Ix1>,
208    threshold: F,
209) -> Result<F> {
210    if predictions.shape() != targets.shape() {
211        return Err(NeuralError::InferenceError(format!(
212            "Shape mismatch in binary_accuracy: predictions {:?} vs targets {:?}",
213            predictions.shape(),
214            targets.shape()
215        )));
216    }
217
218    let n = F::from(predictions.len()).ok_or_else(|| {
219        NeuralError::InferenceError("Could not convert array length to float".to_string())
220    })?;
221
222    let mut correct = F::zero();
223
224    Zip::from(predictions).and(targets).for_each(|&p, &t| {
225        let pred_class = if p >= threshold { F::one() } else { F::zero() };
226        if pred_class == t {
227            correct = correct + F::one();
228        }
229    });
230
231    Ok(correct / n)
232}
233
234/// Calculate the categorical accuracy between predictions and targets
235///
236/// # Arguments
237///
238/// * `predictions` - Predicted class probabilities (each row sums to 1)
239/// * `targets` - One-hot encoded target classes (each row has a single 1)
240///
241/// # Returns
242///
243/// * The accuracy (proportion of correct predictions)
244///
245/// # Examples
246///
247/// ```
248/// use scirs2_neural::utils::categorical_accuracy;
249/// use ndarray::arr2;
250///
251/// let predictions = arr2(&[
252///     [0.7f64, 0.2, 0.1],  // Predicted class: 0
253///     [0.3f64, 0.6, 0.1],  // Predicted class: 1
254///     [0.2f64, 0.3, 0.5]   // Predicted class: 2
255/// ]);
256///
257/// let targets = arr2(&[
258///     [1.0f64, 0.0, 0.0],  // True class: 0
259///     [0.0f64, 1.0, 0.0],  // True class: 1
260///     [0.0f64, 0.0, 1.0]   // True class: 2
261/// ]);
262///
263/// let accuracy = categorical_accuracy(&predictions, &targets).unwrap();
264/// assert_eq!(accuracy, 1.0f64); // All predictions are correct
265/// ```
266pub fn categorical_accuracy<F: Float + Debug>(
267    predictions: &Array<F, Ix2>,
268    targets: &Array<F, Ix2>,
269) -> Result<F> {
270    if predictions.shape() != targets.shape() {
271        return Err(NeuralError::InferenceError(format!(
272            "Shape mismatch in categorical_accuracy: predictions {:?} vs targets {:?}",
273            predictions.shape(),
274            targets.shape()
275        )));
276    }
277
278    let n = F::from(predictions.shape()[0]).ok_or_else(|| {
279        NeuralError::InferenceError("Could not convert sample count to float".to_string())
280    })?;
281
282    let mut correct = F::zero();
283
284    for i in 0..predictions.shape()[0] {
285        // Find predicted class (index of max value)
286        let mut pred_class = 0;
287        let mut max_prob = predictions[[i, 0]];
288
289        for j in 1..predictions.shape()[1] {
290            if predictions[[i, j]] > max_prob {
291                max_prob = predictions[[i, j]];
292                pred_class = j;
293            }
294        }
295
296        // Find true class (index of 1 in one-hot encoding)
297        let mut true_class = 0;
298
299        for j in 0..targets.shape()[1] {
300            if targets[[i, j]] == F::one() {
301                true_class = j;
302                break;
303            }
304        }
305
306        // Check if prediction is correct
307        if pred_class == true_class {
308            correct = correct + F::one();
309        }
310    }
311
312    Ok(correct / n)
313}