// tensorrs/loss/binarycrossentropy.rs

1use crate::linalg::Matrix;
2use crate::loss::Loss;
3use crate::Float;
4
5pub struct BCE<T: Float>(T);
6
7impl<T: Float> BCE<T> {
8    pub fn new(_data_type: T) -> Self {
9        Self(_data_type)
10    }
11}
12
13impl<T: Float> Loss<T> for BCE<T> {
14    fn call(&self, output: &Matrix<T>, target: &Matrix<T>) -> T {
15        if output.shape() != target.shape() {
16            panic!("!!!Size of output matrix and target must be equal!!!\nOutput size:{:?} Target size: {:?}", output.shape(), target.shape())
17        }
18        let n = output.data.len();
19        if n == 0 {
20            return T::default();
21        }
22
23        let epsilon = T::f32_f64(1e-7, 1e-15);
24        let output_clamped = output.max(epsilon).min(T::one() - epsilon);
25
26        let a = target & &output_clamped.ln();
27        let b = target.map(|x| T::one() - x) & &output_clamped.map(|z| T::one() - z).ln();
28        let loss = -(a + &b);
29        loss.sum() / T::from_usize(n)
30    }
31
32    fn gradient(&self, output: &Matrix<T>, target: &Matrix<T>) -> Matrix<T> {
33        let grads = output.zip_with(target, |y_pred, y_true| {
34            y_true / y_pred - (T::one() - y_true) / (T::one() - y_pred)
35        });
36        let n = grads.data.len();
37        grads * (T::one() / T::from_usize(n))
38    }
39}
40
41#[cfg(test)]
42mod tests {
43    use crate::linalg::Matrix;
44    use crate::loss::{Loss, BCE};
45    use crate::{matrix, DataType};
46
47    #[test]
48    fn call() {
49        let a = matrix![[1.0, 0.0]];
50        let b = matrix![[0.5, 0.5]];
51
52        let bce = BCE::new(DataType::f64());
53        println!("{}", bce.call(&b, &a));
54    }
55
56    #[test]
57    fn grad() {
58        let a = matrix![[1.0, 0.0]];
59        let b = matrix![[0.5, 0.5]];
60
61        let bce = BCE::new(DataType::f64());
62        println!("{}", bce.gradient(&b, &a));
63    }
64
65    #[test]
66    fn help() {
67        let tar = matrix![[1.0, 0.0, 1.0, 0.0]];
68        let out = matrix![[0.9, 0.1, 0.8, 0.2]];
69        let bce = BCE::new(DataType::f32());
70        println!("{}", bce.call(&out, &tar));
71
72        println!("{}", bce.gradient(&out, &tar));
73    }
74}