// tensorlogic_train/loss/diceloss_traits.rs
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, ArrayView, Ix2};

use super::functions::Loss;
use super::types::DiceLoss;

16impl Default for DiceLoss {
17 fn default() -> Self {
18 Self { smooth: 1.0 }
19 }
20}
21
22impl Loss for DiceLoss {
23 fn compute(
24 &self,
25 predictions: &ArrayView<f64, Ix2>,
26 targets: &ArrayView<f64, Ix2>,
27 ) -> TrainResult<f64> {
28 if predictions.shape() != targets.shape() {
29 return Err(TrainError::LossError(format!(
30 "Shape mismatch: predictions {:?} vs targets {:?}",
31 predictions.shape(),
32 targets.shape()
33 )));
34 }
35 let mut intersection = 0.0;
36 let mut pred_sum = 0.0;
37 let mut target_sum = 0.0;
38 for i in 0..predictions.nrows() {
39 for j in 0..predictions.ncols() {
40 let pred = predictions[[i, j]];
41 let target = targets[[i, j]];
42 intersection += pred * target;
43 pred_sum += pred;
44 target_sum += target;
45 }
46 }
47 let dice_coef = (2.0 * intersection + self.smooth) / (pred_sum + target_sum + self.smooth);
48 Ok(1.0 - dice_coef)
49 }
50 fn gradient(
51 &self,
52 predictions: &ArrayView<f64, Ix2>,
53 targets: &ArrayView<f64, Ix2>,
54 ) -> TrainResult<Array<f64, Ix2>> {
55 if predictions.shape() != targets.shape() {
56 return Err(TrainError::LossError(format!(
57 "Shape mismatch: predictions {:?} vs targets {:?}",
58 predictions.shape(),
59 targets.shape()
60 )));
61 }
62 let mut intersection = 0.0;
63 let mut pred_sum = 0.0;
64 let mut target_sum = 0.0;
65 for i in 0..predictions.nrows() {
66 for j in 0..predictions.ncols() {
67 intersection += predictions[[i, j]] * targets[[i, j]];
68 pred_sum += predictions[[i, j]];
69 target_sum += targets[[i, j]];
70 }
71 }
72 let denominator = pred_sum + target_sum + self.smooth;
73 let numerator = 2.0 * intersection + self.smooth;
74 let mut grad = Array::zeros(predictions.raw_dim());
75 for i in 0..predictions.nrows() {
76 for j in 0..predictions.ncols() {
77 let target = targets[[i, j]];
78 grad[[i, j]] =
79 -2.0 * (target * denominator - numerator) / (denominator * denominator);
80 }
81 }
82 Ok(grad)
83 }
84}