tiny_solver/
loss_functions.rs

1use core::f64;
2
/// Names a kind of robust loss function.
///
/// Only a Huber variant is declared here.
// NOTE(review): presumably used to select the `HuberLoss` struct in this file — confirm
// against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LossFunc {
    /// The Huber loss.
    HuberLoss,
}
6
/// Common interface of the robust loss functions defined in this file.
///
/// Implementors are required to be `Send + Sync`.
pub trait Loss: Send + Sync {
    /// Evaluates the loss at `s` and returns three numbers.
    ///
    /// The implementations in this file return `[rho(s), rho'(s), rho''(s)]`:
    /// the loss value followed by its first and second derivative with
    /// respect to `s`.
    // NOTE(review): `s` looks like a squared residual norm (the Huber implementation
    // compares it with `scale * scale` and takes its square root) — confirm against callers.
    fn evaluate(&self, s: f64) -> [f64; 3];
}
10
/// Parameters of the Huber loss.
///
/// Both fields are filled in by `HuberLoss::new`.
#[derive(Clone, Debug)]
pub struct HuberLoss {
    /// Scale passed to the constructor.
    scale: f64,
    /// Cached `scale * scale`; `evaluate` compares its argument against this
    /// value to decide between the inlier and outlier branch.
    scale2: f64,
}
16impl HuberLoss {
17    pub fn new(scale: f64) -> Self {
18        if scale <= 0.0 {
19            panic!("scale needs to be larger than zero");
20        }
21        HuberLoss {
22            scale,
23            scale2: scale * scale,
24        }
25    }
26}
27
28impl Loss for HuberLoss {
29    fn evaluate(&self, s: f64) -> [f64; 3] {
30        if s > self.scale2 {
31            // Outlier region.
32            // 'r' is always positive.
33            let r = s.sqrt();
34            let rho1 = (self.scale / r).max(f64::MIN);
35            [2.0 * self.scale * r - self.scale2, rho1, -rho1 / (2.0 * s)]
36        } else {
37            // Inlier region.
38            [s, 1.0, 0.0]
39        }
40    }
41}
42
/// Parameters of the Cauchy loss.
///
/// Both fields are filled in by `CauchyLoss::new`.
#[derive(Debug, Clone)]
pub struct CauchyLoss {
    /// Square of the scale passed to the constructor.
    scale2: f64,
    /// Reciprocal of `scale2`.
    c: f64,
}
47impl CauchyLoss {
48    pub fn new(scale: f64) -> Self {
49        let scale2 = scale * scale;
50        CauchyLoss {
51            scale2,
52            c: 1.0 / scale2,
53        }
54    }
55}
56impl Loss for CauchyLoss {
57    fn evaluate(&self, s: f64) -> [f64; 3] {
58        let sum = 1.0 + s * self.c;
59        let inv = 1.0 / sum;
60        // 'sum' and 'inv' are always positive, assuming that 's' is.
61        [
62            self.scale2 * sum.log2(),
63            inv.max(f64::MIN),
64            -self.c * (inv * inv),
65        ]
66    }
67}
68
/// Parameters of the arctan loss.
///
/// Both fields are filled in by `ArctanLoss::new`.
#[derive(Debug, Clone)]
pub struct ArctanLoss {
    /// Tolerance passed to the constructor.
    tolerance: f64,
    /// Cached `1 / (tolerance * tolerance)`.
    inv_of_squared_tolerance: f64,
}
73
74impl ArctanLoss {
75    pub fn new(tolerance: f64) -> Self {
76        if tolerance <= 0.0 {
77            panic!("scale needs to be larger than zero");
78        }
79        ArctanLoss {
80            tolerance,
81            inv_of_squared_tolerance: 1.0 / (tolerance * tolerance),
82        }
83    }
84}
85
86impl Loss for ArctanLoss {
87    fn evaluate(&self, s: f64) -> [f64; 3] {
88        let sum = 1.0 + s * s * self.inv_of_squared_tolerance;
89        let inv = 1.0 / sum;
90
91        [
92            self.tolerance * s.atan2(self.tolerance),
93            inv.max(f64::MIN),
94            -2.0 * s * self.inv_of_squared_tolerance * (inv * inv),
95        ]
96    }
97}