Skip to main content

tensorlogic_train/loss/
mseloss_traits.rs

//! # MseLoss - Trait Implementations
//!
//! This module contains trait implementations for `MseLoss`.
//!
//! ## Implemented Traits
//!
//! - `Loss`

9use crate::{TrainError, TrainResult};
10use scirs2_core::ndarray::{Array, ArrayView, Ix2};
11
12use super::functions::Loss;
13use super::types::MseLoss;
14
15impl Loss for MseLoss {
16    fn compute(
17        &self,
18        predictions: &ArrayView<f64, Ix2>,
19        targets: &ArrayView<f64, Ix2>,
20    ) -> TrainResult<f64> {
21        if predictions.shape() != targets.shape() {
22            return Err(TrainError::LossError(format!(
23                "Shape mismatch: predictions {:?} vs targets {:?}",
24                predictions.shape(),
25                targets.shape()
26            )));
27        }
28        let n = predictions.len() as f64;
29        let mut total_loss = 0.0;
30        for i in 0..predictions.nrows() {
31            for j in 0..predictions.ncols() {
32                let diff = predictions[[i, j]] - targets[[i, j]];
33                total_loss += diff * diff;
34            }
35        }
36        Ok(total_loss / n)
37    }
38    fn gradient(
39        &self,
40        predictions: &ArrayView<f64, Ix2>,
41        targets: &ArrayView<f64, Ix2>,
42    ) -> TrainResult<Array<f64, Ix2>> {
43        if predictions.shape() != targets.shape() {
44            return Err(TrainError::LossError(format!(
45                "Shape mismatch: predictions {:?} vs targets {:?}",
46                predictions.shape(),
47                targets.shape()
48            )));
49        }
50        let n = predictions.len() as f64;
51        let mut grad = Array::zeros(predictions.raw_dim());
52        for i in 0..predictions.nrows() {
53            for j in 0..predictions.ncols() {
54                grad[[i, j]] = 2.0 * (predictions[[i, j]] - targets[[i, j]]) / n;
55            }
56        }
57        Ok(grad)
58    }
59}