Skip to main content

tensorlogic_train/loss/
tverskyloss_traits.rs

//! # TverskyLoss - Trait Implementations
//!
//! This module contains the trait implementations for `TverskyLoss`.
//!
//! ## Implemented Traits
//!
//! - `Default`
//! - `Loss`
9
10use crate::{TrainError, TrainResult};
11use scirs2_core::ndarray::{Array, ArrayView, Ix2};
12
13use super::functions::Loss;
14use super::types::TverskyLoss;
15
16impl Default for TverskyLoss {
17    fn default() -> Self {
18        Self {
19            alpha: 0.5,
20            beta: 0.5,
21            smooth: 1.0,
22        }
23    }
24}
25
26impl Loss for TverskyLoss {
27    fn compute(
28        &self,
29        predictions: &ArrayView<f64, Ix2>,
30        targets: &ArrayView<f64, Ix2>,
31    ) -> TrainResult<f64> {
32        if predictions.shape() != targets.shape() {
33            return Err(TrainError::LossError(format!(
34                "Shape mismatch: predictions {:?} vs targets {:?}",
35                predictions.shape(),
36                targets.shape()
37            )));
38        }
39        let mut true_pos = 0.0;
40        let mut false_pos = 0.0;
41        let mut false_neg = 0.0;
42        for i in 0..predictions.nrows() {
43            for j in 0..predictions.ncols() {
44                let pred = predictions[[i, j]];
45                let target = targets[[i, j]];
46                true_pos += pred * target;
47                false_pos += pred * (1.0 - target);
48                false_neg += (1.0 - pred) * target;
49            }
50        }
51        let tversky_index = (true_pos + self.smooth)
52            / (true_pos + self.alpha * false_pos + self.beta * false_neg + self.smooth);
53        Ok(1.0 - tversky_index)
54    }
55    fn gradient(
56        &self,
57        predictions: &ArrayView<f64, Ix2>,
58        targets: &ArrayView<f64, Ix2>,
59    ) -> TrainResult<Array<f64, Ix2>> {
60        if predictions.shape() != targets.shape() {
61            return Err(TrainError::LossError(format!(
62                "Shape mismatch: predictions {:?} vs targets {:?}",
63                predictions.shape(),
64                targets.shape()
65            )));
66        }
67        let mut true_pos = 0.0;
68        let mut false_pos = 0.0;
69        let mut false_neg = 0.0;
70        for i in 0..predictions.nrows() {
71            for j in 0..predictions.ncols() {
72                let pred = predictions[[i, j]];
73                let target = targets[[i, j]];
74                true_pos += pred * target;
75                false_pos += pred * (1.0 - target);
76                false_neg += (1.0 - pred) * target;
77            }
78        }
79        let denominator = true_pos + self.alpha * false_pos + self.beta * false_neg + self.smooth;
80        let numerator = true_pos + self.smooth;
81        let mut grad = Array::zeros(predictions.raw_dim());
82        for i in 0..predictions.nrows() {
83            for j in 0..predictions.ncols() {
84                let target = targets[[i, j]];
85                let d_tp = target;
86                let d_fp = self.alpha * (1.0 - target);
87                let d_fn = -self.beta * target;
88                grad[[i, j]] = -(d_tp * denominator - numerator * (d_tp + d_fp + d_fn))
89                    / (denominator * denominator);
90            }
91        }
92        Ok(grad)
93    }
94}