// tensorlogic_train/loss/bcewithlogitsloss_traits.rs
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, ArrayView, Ix2};

use super::functions::Loss;
use super::types::BCEWithLogitsLoss;
15impl Loss for BCEWithLogitsLoss {
16 fn compute(
17 &self,
18 logits: &ArrayView<f64, Ix2>,
19 targets: &ArrayView<f64, Ix2>,
20 ) -> TrainResult<f64> {
21 if logits.shape() != targets.shape() {
22 return Err(TrainError::LossError(format!(
23 "Shape mismatch: logits {:?} vs targets {:?}",
24 logits.shape(),
25 targets.shape()
26 )));
27 }
28 let n = logits.len() as f64;
29 let mut total_loss = 0.0;
30 for i in 0..logits.nrows() {
31 for j in 0..logits.ncols() {
32 let logit = logits[[i, j]];
33 let target = targets[[i, j]];
34 let max_val = logit.max(0.0);
35 total_loss += max_val - logit * target + (1.0 + (-logit.abs()).exp()).ln();
36 }
37 }
38 Ok(total_loss / n)
39 }
40 fn gradient(
41 &self,
42 logits: &ArrayView<f64, Ix2>,
43 targets: &ArrayView<f64, Ix2>,
44 ) -> TrainResult<Array<f64, Ix2>> {
45 if logits.shape() != targets.shape() {
46 return Err(TrainError::LossError(format!(
47 "Shape mismatch: logits {:?} vs targets {:?}",
48 logits.shape(),
49 targets.shape()
50 )));
51 }
52 let n = logits.len() as f64;
53 let mut grad = Array::zeros(logits.raw_dim());
54 for i in 0..logits.nrows() {
55 for j in 0..logits.ncols() {
56 let logit = logits[[i, j]];
57 let target = targets[[i, j]];
58 let sigmoid = 1.0 / (1.0 + (-logit).exp());
59 grad[[i, j]] = (sigmoid - target) / n;
60 }
61 }
62 Ok(grad)
63 }
64}