sklears_multioutput/
loss.rs1use scirs2_core::ndarray::Array2;
8use sklears_core::types::Float;
9
/// Selects which loss [`LossFunction::compute_loss`] evaluates.
///
/// The enum carries no data, so it also derives `Eq` and `Hash`. This lets it be
/// compared exactly and used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LossFunction {
    /// Mean of the squared element-wise differences, taken over every element.
    MeanSquaredError,
    /// Cross-entropy `-Σ true * ln(pred)`, summed over all elements and divided by
    /// the number of rows. Presumably each row of the targets is a one-hot or
    /// probability vector (confirm against callers).
    CrossEntropy,
    /// Element-wise binary cross-entropy
    /// `-[true * ln(pred) + (1 - true) * ln(1 - pred)]`, averaged over every element.
    BinaryCrossEntropy,
}
20
21impl LossFunction {
22 pub fn compute_loss(&self, y_pred: &Array2<Float>, y_true: &Array2<Float>) -> Float {
24 match self {
25 LossFunction::MeanSquaredError => {
26 let diff = y_pred - y_true;
27 diff.map(|x| x * x).mean().unwrap()
28 }
29 LossFunction::CrossEntropy => {
30 let mut total_loss = 0.0;
31 for i in 0..y_pred.nrows() {
32 for j in 0..y_pred.ncols() {
33 let pred = y_pred[[i, j]].clamp(1e-15, 1.0 - 1e-15); total_loss -= y_true[[i, j]] * pred.ln();
35 }
36 }
37 total_loss / (y_pred.nrows() as Float)
38 }
39 LossFunction::BinaryCrossEntropy => {
40 let mut total_loss = 0.0;
41 for i in 0..y_pred.nrows() {
42 for j in 0..y_pred.ncols() {
43 let pred = y_pred[[i, j]].clamp(1e-15, 1.0 - 1e-15); total_loss -=
45 y_true[[i, j]] * pred.ln() + (1.0 - y_true[[i, j]]) * (1.0 - pred).ln();
46 }
47 }
48 total_loss / (y_pred.nrows() as Float * y_pred.ncols() as Float)
49 }
50 }
51 }
52}