// tensorlogic_train/loss/bcewithlogitsloss_traits.rs
1//! # BCEWithLogitsLoss - Trait Implementations
2//!
3//! This module contains trait implementations for `BCEWithLogitsLoss`.
4//!
5//! ## Implemented Traits
6//!
7//! - `Loss`
8
9use crate::{TrainError, TrainResult};
10use scirs2_core::ndarray::{Array, ArrayView, Ix2};
11
12use super::functions::Loss;
13use super::types::BCEWithLogitsLoss;
14
15impl Loss for BCEWithLogitsLoss {
16    fn compute(
17        &self,
18        logits: &ArrayView<f64, Ix2>,
19        targets: &ArrayView<f64, Ix2>,
20    ) -> TrainResult<f64> {
21        if logits.shape() != targets.shape() {
22            return Err(TrainError::LossError(format!(
23                "Shape mismatch: logits {:?} vs targets {:?}",
24                logits.shape(),
25                targets.shape()
26            )));
27        }
28        let n = logits.len() as f64;
29        let mut total_loss = 0.0;
30        for i in 0..logits.nrows() {
31            for j in 0..logits.ncols() {
32                let logit = logits[[i, j]];
33                let target = targets[[i, j]];
34                let max_val = logit.max(0.0);
35                total_loss += max_val - logit * target + (1.0 + (-logit.abs()).exp()).ln();
36            }
37        }
38        Ok(total_loss / n)
39    }
40    fn gradient(
41        &self,
42        logits: &ArrayView<f64, Ix2>,
43        targets: &ArrayView<f64, Ix2>,
44    ) -> TrainResult<Array<f64, Ix2>> {
45        if logits.shape() != targets.shape() {
46            return Err(TrainError::LossError(format!(
47                "Shape mismatch: logits {:?} vs targets {:?}",
48                logits.shape(),
49                targets.shape()
50            )));
51        }
52        let n = logits.len() as f64;
53        let mut grad = Array::zeros(logits.raw_dim());
54        for i in 0..logits.nrows() {
55            for j in 0..logits.ncols() {
56                let logit = logits[[i, j]];
57                let target = targets[[i, j]];
58                let sigmoid = 1.0 / (1.0 + (-logit).exp());
59                grad[[i, j]] = (sigmoid - target) / n;
60            }
61        }
62        Ok(grad)
63    }
64}