tensorlogic_train/loss/
focalloss_traits.rs1use crate::{TrainError, TrainResult};
11use scirs2_core::ndarray::{Array, ArrayView, Ix2};
12
13use super::functions::Loss;
14use super::types::FocalLoss;
15
16impl Default for FocalLoss {
17 fn default() -> Self {
18 Self {
19 alpha: 0.25,
20 gamma: 2.0,
21 epsilon: 1e-10,
22 }
23 }
24}
25
26impl Loss for FocalLoss {
27 fn compute(
28 &self,
29 predictions: &ArrayView<f64, Ix2>,
30 targets: &ArrayView<f64, Ix2>,
31 ) -> TrainResult<f64> {
32 if predictions.shape() != targets.shape() {
33 return Err(TrainError::LossError(format!(
34 "Shape mismatch: predictions {:?} vs targets {:?}",
35 predictions.shape(),
36 targets.shape()
37 )));
38 }
39 let n = predictions.nrows() as f64;
40 let mut total_loss = 0.0;
41 for i in 0..predictions.nrows() {
42 for j in 0..predictions.ncols() {
43 let pred = predictions[[i, j]]
44 .max(self.epsilon)
45 .min(1.0 - self.epsilon);
46 let target = targets[[i, j]];
47 if target > 0.5 {
48 let focal_weight = (1.0 - pred).powf(self.gamma);
49 total_loss -= self.alpha * focal_weight * pred.ln();
50 } else {
51 let focal_weight = pred.powf(self.gamma);
52 total_loss -= (1.0 - self.alpha) * focal_weight * (1.0 - pred).ln();
53 }
54 }
55 }
56 Ok(total_loss / n)
57 }
58 fn gradient(
59 &self,
60 predictions: &ArrayView<f64, Ix2>,
61 targets: &ArrayView<f64, Ix2>,
62 ) -> TrainResult<Array<f64, Ix2>> {
63 if predictions.shape() != targets.shape() {
64 return Err(TrainError::LossError(format!(
65 "Shape mismatch: predictions {:?} vs targets {:?}",
66 predictions.shape(),
67 targets.shape()
68 )));
69 }
70 let n = predictions.nrows() as f64;
71 let mut grad = Array::zeros(predictions.raw_dim());
72 for i in 0..predictions.nrows() {
73 for j in 0..predictions.ncols() {
74 let pred = predictions[[i, j]]
75 .max(self.epsilon)
76 .min(1.0 - self.epsilon);
77 let target = targets[[i, j]];
78 if target > 0.5 {
79 let focal_weight = (1.0 - pred).powf(self.gamma);
80 let d_focal = self.gamma * (1.0 - pred).powf(self.gamma - 1.0);
81 grad[[i, j]] = -self.alpha * (focal_weight / pred - d_focal * pred.ln()) / n;
82 } else {
83 let focal_weight = pred.powf(self.gamma);
84 let d_focal = self.gamma * pred.powf(self.gamma - 1.0);
85 grad[[i, j]] = -(1.0 - self.alpha)
86 * (d_focal * (1.0 - pred).ln() - focal_weight / (1.0 - pred))
87 / n;
88 }
89 }
90 }
91 Ok(grad)
92 }
93}