Skip to main content

tensorlogic_train/loss/
polyloss_traits.rs

//! # PolyLoss - Trait Implementations
//!
//! This module contains the trait implementations for `PolyLoss`.
//!
//! ## Implemented Traits
//!
//! - `Default`
//! - `Loss`
9
10use crate::{TrainError, TrainResult};
11use scirs2_core::ndarray::{Array, ArrayView, Ix2};
12
13use super::functions::Loss;
14use super::types::PolyLoss;
15
16impl Default for PolyLoss {
17    fn default() -> Self {
18        Self {
19            epsilon: 1e-10,
20            poly_coeff: 1.0,
21        }
22    }
23}
24
25impl Loss for PolyLoss {
26    fn compute(
27        &self,
28        predictions: &ArrayView<f64, Ix2>,
29        targets: &ArrayView<f64, Ix2>,
30    ) -> TrainResult<f64> {
31        if predictions.shape() != targets.shape() {
32            return Err(TrainError::LossError(format!(
33                "Shape mismatch: predictions {:?} vs targets {:?}",
34                predictions.shape(),
35                targets.shape()
36            )));
37        }
38        let n = predictions.nrows() as f64;
39        let mut total_loss = 0.0;
40        for i in 0..predictions.nrows() {
41            for j in 0..predictions.ncols() {
42                let pred = predictions[[i, j]]
43                    .max(self.epsilon)
44                    .min(1.0 - self.epsilon);
45                let target = targets[[i, j]];
46                let ce = -target * pred.ln();
47                let poly_term = if target > 0.5 {
48                    self.poly_coeff * (1.0 - pred)
49                } else {
50                    self.poly_coeff * pred
51                };
52                total_loss += ce + poly_term;
53            }
54        }
55        Ok(total_loss / n)
56    }
57    fn gradient(
58        &self,
59        predictions: &ArrayView<f64, Ix2>,
60        targets: &ArrayView<f64, Ix2>,
61    ) -> TrainResult<Array<f64, Ix2>> {
62        let n = predictions.nrows() as f64;
63        let mut grad = Array::zeros(predictions.raw_dim());
64        for i in 0..predictions.nrows() {
65            for j in 0..predictions.ncols() {
66                let pred = predictions[[i, j]]
67                    .max(self.epsilon)
68                    .min(1.0 - self.epsilon);
69                let target = targets[[i, j]];
70                let ce_grad = -target / pred;
71                let poly_grad = if target > 0.5 {
72                    -self.poly_coeff
73                } else {
74                    self.poly_coeff
75                };
76                grad[[i, j]] = (ce_grad + poly_grad) / n;
77            }
78        }
79        Ok(grad)
80    }
81    fn name(&self) -> &str {
82        "poly_loss"
83    }
84}