tensorlogic_train/loss/
tverskyloss_traits.rs1use crate::{TrainError, TrainResult};
11use scirs2_core::ndarray::{Array, ArrayView, Ix2};
12
13use super::functions::Loss;
14use super::types::TverskyLoss;
15
16impl Default for TverskyLoss {
17 fn default() -> Self {
18 Self {
19 alpha: 0.5,
20 beta: 0.5,
21 smooth: 1.0,
22 }
23 }
24}
25
26impl Loss for TverskyLoss {
27 fn compute(
28 &self,
29 predictions: &ArrayView<f64, Ix2>,
30 targets: &ArrayView<f64, Ix2>,
31 ) -> TrainResult<f64> {
32 if predictions.shape() != targets.shape() {
33 return Err(TrainError::LossError(format!(
34 "Shape mismatch: predictions {:?} vs targets {:?}",
35 predictions.shape(),
36 targets.shape()
37 )));
38 }
39 let mut true_pos = 0.0;
40 let mut false_pos = 0.0;
41 let mut false_neg = 0.0;
42 for i in 0..predictions.nrows() {
43 for j in 0..predictions.ncols() {
44 let pred = predictions[[i, j]];
45 let target = targets[[i, j]];
46 true_pos += pred * target;
47 false_pos += pred * (1.0 - target);
48 false_neg += (1.0 - pred) * target;
49 }
50 }
51 let tversky_index = (true_pos + self.smooth)
52 / (true_pos + self.alpha * false_pos + self.beta * false_neg + self.smooth);
53 Ok(1.0 - tversky_index)
54 }
55 fn gradient(
56 &self,
57 predictions: &ArrayView<f64, Ix2>,
58 targets: &ArrayView<f64, Ix2>,
59 ) -> TrainResult<Array<f64, Ix2>> {
60 if predictions.shape() != targets.shape() {
61 return Err(TrainError::LossError(format!(
62 "Shape mismatch: predictions {:?} vs targets {:?}",
63 predictions.shape(),
64 targets.shape()
65 )));
66 }
67 let mut true_pos = 0.0;
68 let mut false_pos = 0.0;
69 let mut false_neg = 0.0;
70 for i in 0..predictions.nrows() {
71 for j in 0..predictions.ncols() {
72 let pred = predictions[[i, j]];
73 let target = targets[[i, j]];
74 true_pos += pred * target;
75 false_pos += pred * (1.0 - target);
76 false_neg += (1.0 - pred) * target;
77 }
78 }
79 let denominator = true_pos + self.alpha * false_pos + self.beta * false_neg + self.smooth;
80 let numerator = true_pos + self.smooth;
81 let mut grad = Array::zeros(predictions.raw_dim());
82 for i in 0..predictions.nrows() {
83 for j in 0..predictions.ncols() {
84 let target = targets[[i, j]];
85 let d_tp = target;
86 let d_fp = self.alpha * (1.0 - target);
87 let d_fn = -self.beta * target;
88 grad[[i, j]] = -(d_tp * denominator - numerator * (d_tp + d_fp + d_fn))
89 / (denominator * denominator);
90 }
91 }
92 Ok(grad)
93 }
94}