// tensorlogic_train/loss/huberloss_traits.rs
//! # HuberLoss - Trait Implementations
//!
//! This module contains trait implementations for `HuberLoss`.
//!
//! ## Implemented Traits
//!
//! - `Default`
//! - `Loss`

use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, ArrayView, Ix2};

use super::functions::Loss;
use super::types::HuberLoss;
15
16impl Default for HuberLoss {
17    fn default() -> Self {
18        Self { delta: 1.0 }
19    }
20}
22impl Loss for HuberLoss {
23    fn compute(
24        &self,
25        predictions: &ArrayView<f64, Ix2>,
26        targets: &ArrayView<f64, Ix2>,
27    ) -> TrainResult<f64> {
28        if predictions.shape() != targets.shape() {
29            return Err(TrainError::LossError(format!(
30                "Shape mismatch: predictions {:?} vs targets {:?}",
31                predictions.shape(),
32                targets.shape()
33            )));
34        }
35        let n = predictions.len() as f64;
36        let mut total_loss = 0.0;
37        for i in 0..predictions.nrows() {
38            for j in 0..predictions.ncols() {
39                let diff = (predictions[[i, j]] - targets[[i, j]]).abs();
40                if diff <= self.delta {
41                    total_loss += 0.5 * diff * diff;
42                } else {
43                    total_loss += self.delta * (diff - 0.5 * self.delta);
44                }
45            }
46        }
47        Ok(total_loss / n)
48    }
49    fn gradient(
50        &self,
51        predictions: &ArrayView<f64, Ix2>,
52        targets: &ArrayView<f64, Ix2>,
53    ) -> TrainResult<Array<f64, Ix2>> {
54        if predictions.shape() != targets.shape() {
55            return Err(TrainError::LossError(format!(
56                "Shape mismatch: predictions {:?} vs targets {:?}",
57                predictions.shape(),
58                targets.shape()
59            )));
60        }
61        let n = predictions.len() as f64;
62        let mut grad = Array::zeros(predictions.raw_dim());
63        for i in 0..predictions.nrows() {
64            for j in 0..predictions.ncols() {
65                let diff = predictions[[i, j]] - targets[[i, j]];
66                let abs_diff = diff.abs();
67                if abs_diff <= self.delta {
68                    grad[[i, j]] = diff / n;
69                } else {
70                    grad[[i, j]] = self.delta * diff.signum() / n;
71                }
72            }
73        }
74        Ok(grad)
75    }
76}