sklears_multioutput/
loss.rs

1//! Loss functions for neural network training
2//!
3//! This module provides various loss functions commonly used in neural network training,
4//! including Mean Squared Error for regression and Cross-Entropy for classification.
5
6// Use SciRS2-Core for arrays (SciRS2 Policy)
7use scirs2_core::ndarray::Array2;
8use sklears_core::types::Float;
9
/// Objective used to score a network's predictions against its targets during training.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum LossFunction {
    /// Mean squared error, used for regression.
    MeanSquaredError,
    /// Cross-entropy, used for classification.
    CrossEntropy,
    /// Binary cross-entropy, used for multi-label classification.
    BinaryCrossEntropy,
}
20
21impl LossFunction {
22    /// Compute loss between predictions and targets
23    pub fn compute_loss(&self, y_pred: &Array2<Float>, y_true: &Array2<Float>) -> Float {
24        match self {
25            LossFunction::MeanSquaredError => {
26                let diff = y_pred - y_true;
27                diff.map(|x| x * x).mean().unwrap()
28            }
29            LossFunction::CrossEntropy => {
30                let mut total_loss = 0.0;
31                for i in 0..y_pred.nrows() {
32                    for j in 0..y_pred.ncols() {
33                        let pred = y_pred[[i, j]].clamp(1e-15, 1.0 - 1e-15); // Clip for numerical stability
34                        total_loss -= y_true[[i, j]] * pred.ln();
35                    }
36                }
37                total_loss / (y_pred.nrows() as Float)
38            }
39            LossFunction::BinaryCrossEntropy => {
40                let mut total_loss = 0.0;
41                for i in 0..y_pred.nrows() {
42                    for j in 0..y_pred.ncols() {
43                        let pred = y_pred[[i, j]].clamp(1e-15, 1.0 - 1e-15); // Clip for numerical stability
44                        total_loss -=
45                            y_true[[i, j]] * pred.ln() + (1.0 - y_true[[i, j]]) * (1.0 - pred).ln();
46                    }
47                }
48                total_loss / (y_pred.nrows() as Float * y_pred.ncols() as Float)
49            }
50        }
51    }
52}