Skip to main content

tensorlogic_train/loss/
crossentropyloss_traits.rs

1//! # CrossEntropyLoss - Trait Implementations
2//!
3//! This module contains trait implementations for `CrossEntropyLoss`.
4//!
5//! ## Implemented Traits
6//!
7//! - `Default`
8//! - `Loss`
9
10use crate::{TrainError, TrainResult};
11use scirs2_core::ndarray::{Array, ArrayView, Ix2};
12
13use super::functions::Loss;
14use super::types::CrossEntropyLoss;
15
16impl Default for CrossEntropyLoss {
17    fn default() -> Self {
18        Self { epsilon: 1e-10 }
19    }
20}
21
22impl Loss for CrossEntropyLoss {
23    fn compute(
24        &self,
25        predictions: &ArrayView<f64, Ix2>,
26        targets: &ArrayView<f64, Ix2>,
27    ) -> TrainResult<f64> {
28        if predictions.shape() != targets.shape() {
29            return Err(TrainError::LossError(format!(
30                "Shape mismatch: predictions {:?} vs targets {:?}",
31                predictions.shape(),
32                targets.shape()
33            )));
34        }
35        let n = predictions.nrows() as f64;
36        let mut total_loss = 0.0;
37        for i in 0..predictions.nrows() {
38            for j in 0..predictions.ncols() {
39                let pred = predictions[[i, j]]
40                    .max(self.epsilon)
41                    .min(1.0 - self.epsilon);
42                let target = targets[[i, j]];
43                total_loss -= target * pred.ln();
44            }
45        }
46        Ok(total_loss / n)
47    }
48    fn gradient(
49        &self,
50        predictions: &ArrayView<f64, Ix2>,
51        targets: &ArrayView<f64, Ix2>,
52    ) -> TrainResult<Array<f64, Ix2>> {
53        if predictions.shape() != targets.shape() {
54            return Err(TrainError::LossError(format!(
55                "Shape mismatch: predictions {:?} vs targets {:?}",
56                predictions.shape(),
57                targets.shape()
58            )));
59        }
60        let n = predictions.nrows() as f64;
61        let mut grad = Array::zeros(predictions.raw_dim());
62        for i in 0..predictions.nrows() {
63            for j in 0..predictions.ncols() {
64                let pred = predictions[[i, j]]
65                    .max(self.epsilon)
66                    .min(1.0 - self.epsilon);
67                let target = targets[[i, j]];
68                grad[[i, j]] = -(target / pred) / n;
69            }
70        }
71        Ok(grad)
72    }
73}