Skip to main content

tensorlogic_train/loss/
diceloss_traits.rs

1//! # DiceLoss - Trait Implementations
2//!
3//! This module contains trait implementations for `DiceLoss`.
4//!
5//! ## Implemented Traits
6//!
7//! - `Default`
8//! - `Loss`
9
10use crate::{TrainError, TrainResult};
11use scirs2_core::ndarray::{Array, ArrayView, Ix2};
12
13use super::functions::Loss;
14use super::types::DiceLoss;
15
16impl Default for DiceLoss {
17    fn default() -> Self {
18        Self { smooth: 1.0 }
19    }
20}
21
22impl Loss for DiceLoss {
23    fn compute(
24        &self,
25        predictions: &ArrayView<f64, Ix2>,
26        targets: &ArrayView<f64, Ix2>,
27    ) -> TrainResult<f64> {
28        if predictions.shape() != targets.shape() {
29            return Err(TrainError::LossError(format!(
30                "Shape mismatch: predictions {:?} vs targets {:?}",
31                predictions.shape(),
32                targets.shape()
33            )));
34        }
35        let mut intersection = 0.0;
36        let mut pred_sum = 0.0;
37        let mut target_sum = 0.0;
38        for i in 0..predictions.nrows() {
39            for j in 0..predictions.ncols() {
40                let pred = predictions[[i, j]];
41                let target = targets[[i, j]];
42                intersection += pred * target;
43                pred_sum += pred;
44                target_sum += target;
45            }
46        }
47        let dice_coef = (2.0 * intersection + self.smooth) / (pred_sum + target_sum + self.smooth);
48        Ok(1.0 - dice_coef)
49    }
50    fn gradient(
51        &self,
52        predictions: &ArrayView<f64, Ix2>,
53        targets: &ArrayView<f64, Ix2>,
54    ) -> TrainResult<Array<f64, Ix2>> {
55        if predictions.shape() != targets.shape() {
56            return Err(TrainError::LossError(format!(
57                "Shape mismatch: predictions {:?} vs targets {:?}",
58                predictions.shape(),
59                targets.shape()
60            )));
61        }
62        let mut intersection = 0.0;
63        let mut pred_sum = 0.0;
64        let mut target_sum = 0.0;
65        for i in 0..predictions.nrows() {
66            for j in 0..predictions.ncols() {
67                intersection += predictions[[i, j]] * targets[[i, j]];
68                pred_sum += predictions[[i, j]];
69                target_sum += targets[[i, j]];
70            }
71        }
72        let denominator = pred_sum + target_sum + self.smooth;
73        let numerator = 2.0 * intersection + self.smooth;
74        let mut grad = Array::zeros(predictions.raw_dim());
75        for i in 0..predictions.nrows() {
76            for j in 0..predictions.ncols() {
77                let target = targets[[i, j]];
78                grad[[i, j]] =
79                    -2.0 * (target * denominator - numerator) / (denominator * denominator);
80            }
81        }
82        Ok(grad)
83    }
84}