tiny_solver/
loss_functions.rs1use core::f64;
2
/// Identifies a kind of loss function.
///
/// Only the Huber variant is declared here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LossFunc {
    /// The Huber loss.
    HuberLoss,
}
6
/// A loss function evaluated on a single scalar input.
///
/// Implementors must be `Send + Sync`.
pub trait Loss: Send + Sync {
    /// Evaluates the loss at `s` and returns three `f64` values.
    ///
    /// The implementations in this file return the loss value followed by its
    /// first and second derivatives with respect to `s`.
    /// NOTE(review): `s` is presumably a squared residual norm — confirm
    /// against the callers.
    fn evaluate(&self, s: f64) -> [f64; 3];
}
10
/// Parameters of the Huber loss.
///
/// Stores the scale together with its square so that evaluation does not
/// have to recompute the square on every call.
#[derive(Debug, Clone)]
pub struct HuberLoss {
    /// Scale supplied to `new`.
    scale: f64,
    /// Cached `scale * scale`; `evaluate` compares its input against this.
    scale2: f64,
}
16impl HuberLoss {
17 pub fn new(scale: f64) -> Self {
18 if scale <= 0.0 {
19 panic!("scale needs to be larger than zero");
20 }
21 HuberLoss {
22 scale,
23 scale2: scale * scale,
24 }
25 }
26}
27
28impl Loss for HuberLoss {
29 fn evaluate(&self, s: f64) -> [f64; 3] {
30 if s > self.scale2 {
31 let r = s.sqrt();
34 let rho1 = (self.scale / r).max(f64::MIN);
35 [2.0 * self.scale * r - self.scale2, rho1, -rho1 / (2.0 * s)]
36 } else {
37 [s, 1.0, 0.0]
39 }
40 }
41}
42
/// Parameters of the Cauchy loss.
///
/// Both fields are derived from the scale given to `new`.
#[derive(Debug, Clone)]
pub struct CauchyLoss {
    /// `scale * scale`.
    scale2: f64,
    /// `1.0 / scale2`, cached so evaluation multiplies instead of divides.
    c: f64,
}
47impl CauchyLoss {
48 pub fn new(scale: f64) -> Self {
49 let scale2 = scale * scale;
50 CauchyLoss {
51 scale2,
52 c: 1.0 / scale2,
53 }
54 }
55}
56impl Loss for CauchyLoss {
57 fn evaluate(&self, s: f64) -> [f64; 3] {
58 let sum = 1.0 + s * self.c;
59 let inv = 1.0 / sum;
60 [
62 self.scale2 * sum.log2(),
63 inv.max(f64::MIN),
64 -self.c * (inv * inv),
65 ]
66 }
67}
68
/// Parameters of the arctangent loss.
#[derive(Debug, Clone)]
pub struct ArctanLoss {
    /// Tolerance supplied to `new`.
    tolerance: f64,
    /// Cached `1.0 / (tolerance * tolerance)`.
    inv_of_squared_tolerance: f64,
}
73
74impl ArctanLoss {
75 pub fn new(tolerance: f64) -> Self {
76 if tolerance <= 0.0 {
77 panic!("scale needs to be larger than zero");
78 }
79 ArctanLoss {
80 tolerance,
81 inv_of_squared_tolerance: 1.0 / (tolerance * tolerance),
82 }
83 }
84}
85
86impl Loss for ArctanLoss {
87 fn evaluate(&self, s: f64) -> [f64; 3] {
88 let sum = 1.0 + s * s * self.inv_of_squared_tolerance;
89 let inv = 1.0 / sum;
90
91 [
92 self.tolerance * s.atan2(self.tolerance),
93 inv.max(f64::MIN),
94 -2.0 * s * self.inv_of_squared_tolerance * (inv * inv),
95 ]
96 }
97}