Skip to main content

tensorlogic_train/loss/
focalloss_traits.rs

//! # FocalLoss - Trait Implementations
//!
//! This module contains the trait implementations for `FocalLoss`.
//!
//! ## Implemented Traits
//!
//! - `Default`
//! - `Loss`
9
10use crate::{TrainError, TrainResult};
11use scirs2_core::ndarray::{Array, ArrayView, Ix2};
12
13use super::functions::Loss;
14use super::types::FocalLoss;
15
16impl Default for FocalLoss {
17    fn default() -> Self {
18        Self {
19            alpha: 0.25,
20            gamma: 2.0,
21            epsilon: 1e-10,
22        }
23    }
24}
25
26impl Loss for FocalLoss {
27    fn compute(
28        &self,
29        predictions: &ArrayView<f64, Ix2>,
30        targets: &ArrayView<f64, Ix2>,
31    ) -> TrainResult<f64> {
32        if predictions.shape() != targets.shape() {
33            return Err(TrainError::LossError(format!(
34                "Shape mismatch: predictions {:?} vs targets {:?}",
35                predictions.shape(),
36                targets.shape()
37            )));
38        }
39        let n = predictions.nrows() as f64;
40        let mut total_loss = 0.0;
41        for i in 0..predictions.nrows() {
42            for j in 0..predictions.ncols() {
43                let pred = predictions[[i, j]]
44                    .max(self.epsilon)
45                    .min(1.0 - self.epsilon);
46                let target = targets[[i, j]];
47                if target > 0.5 {
48                    let focal_weight = (1.0 - pred).powf(self.gamma);
49                    total_loss -= self.alpha * focal_weight * pred.ln();
50                } else {
51                    let focal_weight = pred.powf(self.gamma);
52                    total_loss -= (1.0 - self.alpha) * focal_weight * (1.0 - pred).ln();
53                }
54            }
55        }
56        Ok(total_loss / n)
57    }
58    fn gradient(
59        &self,
60        predictions: &ArrayView<f64, Ix2>,
61        targets: &ArrayView<f64, Ix2>,
62    ) -> TrainResult<Array<f64, Ix2>> {
63        if predictions.shape() != targets.shape() {
64            return Err(TrainError::LossError(format!(
65                "Shape mismatch: predictions {:?} vs targets {:?}",
66                predictions.shape(),
67                targets.shape()
68            )));
69        }
70        let n = predictions.nrows() as f64;
71        let mut grad = Array::zeros(predictions.raw_dim());
72        for i in 0..predictions.nrows() {
73            for j in 0..predictions.ncols() {
74                let pred = predictions[[i, j]]
75                    .max(self.epsilon)
76                    .min(1.0 - self.epsilon);
77                let target = targets[[i, j]];
78                if target > 0.5 {
79                    let focal_weight = (1.0 - pred).powf(self.gamma);
80                    let d_focal = self.gamma * (1.0 - pred).powf(self.gamma - 1.0);
81                    grad[[i, j]] = -self.alpha * (focal_weight / pred - d_focal * pred.ln()) / n;
82                } else {
83                    let focal_weight = pred.powf(self.gamma);
84                    let d_focal = self.gamma * pred.powf(self.gamma - 1.0);
85                    grad[[i, j]] = -(1.0 - self.alpha)
86                        * (d_focal * (1.0 - pred).ln() - focal_weight / (1.0 - pred))
87                        / n;
88                }
89            }
90        }
91        Ok(grad)
92    }
93}