Skip to main content

sidereon_core/astro/math/
robust.rs

//! Robust M-estimation primitives for iteratively reweighted least squares.
//!
//! These are the per-outer-iteration reweighting pieces of a Huber IRLS loop:
//! a median-absolute-deviation scale estimate and the Huber weight function.
//! They are deliberately pure `f64` arithmetic (abs, compare, divide, sort by
//! [`f64::total_cmp`]) with no fused-multiply-add and no contraction, so the
//! per-iteration weight vector is bit-reproducible against an explicit
//! outer-loop reference recipe. The trust-region linear-algebra step that
//! consumes the weights is BLAS-bound and is NOT a 0-ULP target.
10
/// The default Huber tuning constant `k`.
///
/// Scaled residuals with `|u| <= k` (in units of the robust scale) keep full
/// weight; larger ones are down-weighted as `k / |u|`. `1.345` gives ~95%
/// efficiency at the Gaussian model.
pub const HUBER_K: f64 = 1.345;
15
/// The MAD-to-sigma consistency constant for a normal distribution,
/// `1 / Phi^-1(3/4)`. Multiplying the median absolute deviation by this makes
/// it a consistent estimator of the standard deviation under normality.
///
/// The literal is the conventional four-decimal rounding of the exact value
/// (approximately `1.4826022`), not the full-precision constant.
/// NOTE(review): presumably kept at four decimals to match the reference
/// recipe bit-for-bit -- confirm before changing.
pub const MAD_NORMAL_CONST: f64 = 1.4826;
20
/// Errors reported by the robust-statistics helpers in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum RobustError {
    /// An argument failed validation (for example it was non-finite or not
    /// strictly positive).
    #[error("invalid robust statistic {field}: {reason}")]
    InvalidInput {
        /// Static label of the offending argument (e.g. `"values"`).
        field: &'static str,
        /// Static description of the failed check (e.g. `"not finite"`).
        reason: &'static str,
    },
}
29
30impl RobustError {
31    pub const fn field(&self) -> &'static str {
32        match self {
33            Self::InvalidInput { field, .. } => field,
34        }
35    }
36
37    pub const fn reason(&self) -> &'static str {
38        match self {
39            Self::InvalidInput { reason, .. } => reason,
40        }
41    }
42}
43
44/// The median of `values`, computed on a `total_cmp` sort so the order (and
45/// thus the result for an even count, which averages the two central values)
46/// is deterministic. An empty slice yields `0.0`. The averaging of the two
47/// central elements is a single `(a + b) / 2.0`, no FMA.
48pub fn median(values: &[f64]) -> Result<f64, RobustError> {
49    validate_finite_slice(values, "values")?;
50    if values.is_empty() {
51        return Ok(0.0);
52    }
53    let mut v: Vec<f64> = values.to_vec();
54    v.sort_by(|a, b| a.total_cmp(b));
55    let n = v.len();
56    if n % 2 == 1 {
57        Ok(v[n / 2])
58    } else {
59        Ok((v[n / 2 - 1] + v[n / 2]) / 2.0)
60    }
61}
62
63/// The median-absolute-deviation scale of `residuals`, scaled to a normal-sigma
64/// estimate and floored at `scale_floor`.
65///
66/// `s = max(scale_floor, MAD_NORMAL_CONST * median(|r_i - median(r)|))`. The
67/// floor prevents a near-perfect fit (MAD approaching zero) from blowing up the
68/// scaled residuals `u_i = r_i / s` and spuriously down-weighting every
69/// observation. Both medians use [`median`]'s `total_cmp` sort.
70pub fn mad_scale(residuals: &[f64], scale_floor: f64) -> Result<f64, RobustError> {
71    validate_finite_positive(scale_floor, "scale_floor")?;
72    let med = median(residuals)?;
73    let abs_dev: Vec<f64> = residuals.iter().map(|r| (r - med).abs()).collect();
74    let mad = median(&abs_dev)?;
75    let scaled = MAD_NORMAL_CONST * mad;
76    if scaled > scale_floor {
77        Ok(scaled)
78    } else {
79        Ok(scale_floor)
80    }
81}
82
/// The Huber weight for a scaled residual `u = r / s` with tuning constant `k`.
///
/// `w(u) = 1` when `|u| <= k` and `w(u) = k / |u|` otherwise (the Huber
/// `psi(u) / u` form). At `u == 0` the weight is `1`. For finite `u` and a
/// finite positive `k` the weight lies in `(0, 1]`, barring underflow of the
/// quotient. The inputs are not validated: a NaN `u` fails the comparison and
/// yields NaN, and an infinite `u` yields `0`.
///
/// This is the multiplier applied on top of any base (elevation) weight to
/// obtain the effective per-observation weight of the current outer iteration.
pub fn huber_weight(u: f64, k: f64) -> f64 {
    match u.abs() {
        // Inside the threshold: the observation keeps full weight.
        magnitude if magnitude <= k => 1.0,
        // Outside it (or not comparable): scale down by `k / |u|`.
        magnitude => k / magnitude,
    }
}
97
98fn validate_finite_slice(values: &[f64], field: &'static str) -> Result<(), RobustError> {
99    if values.iter().all(|value| value.is_finite()) {
100        Ok(())
101    } else {
102        Err(invalid_input(field, "not finite"))
103    }
104}
105
106fn validate_finite_positive(value: f64, field: &'static str) -> Result<(), RobustError> {
107    if !value.is_finite() {
108        Err(invalid_input(field, "not finite"))
109    } else if value <= 0.0 {
110        Err(invalid_input(field, "not positive"))
111    } else {
112        Ok(())
113    }
114}
115
/// Builds a [`RobustError::InvalidInput`] from an argument label and the
/// description of the check it failed.
fn invalid_input(field: &'static str, reason: &'static str) -> RobustError {
    RobustError::InvalidInput { field, reason }
}
119
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn median_odd_even() {
        assert_eq!(median(&[3.0, 1.0, 2.0]).unwrap(), 2.0);
        assert_eq!(median(&[1.0, 2.0, 3.0, 4.0]).unwrap(), 2.5);
        assert_eq!(median(&[]).unwrap(), 0.0);
    }

    #[test]
    fn median_single_and_unsorted_even() {
        assert_eq!(median(&[7.5]).unwrap(), 7.5);
        // Sorted order is -1, 2, 3, 4, so the central pair averages to 2.5.
        assert_eq!(median(&[4.0, -1.0, 3.0, 2.0]).unwrap(), 2.5);
    }

    #[test]
    fn median_rejects_nonfinite_sample() {
        assert_eq!(
            median(&[1.0, f64::NAN]),
            Err(RobustError::InvalidInput {
                field: "values",
                reason: "not finite"
            })
        );
        for bad in [f64::INFINITY, f64::NEG_INFINITY] {
            assert_eq!(
                median(&[1.0, bad]),
                Err(RobustError::InvalidInput {
                    field: "values",
                    reason: "not finite"
                })
            );
        }
    }

    #[test]
    fn huber_weight_breaks_at_k() {
        assert_eq!(huber_weight(0.0, HUBER_K), 1.0);
        assert_eq!(huber_weight(HUBER_K, HUBER_K), 1.0);
        let w = huber_weight(2.0 * HUBER_K, HUBER_K);
        assert!((w - 0.5).abs() < 1e-15);
    }

    #[test]
    fn huber_weight_is_symmetric_in_sign() {
        assert_eq!(huber_weight(-HUBER_K, HUBER_K), 1.0);
        assert_eq!(
            huber_weight(-2.0 * HUBER_K, HUBER_K),
            huber_weight(2.0 * HUBER_K, HUBER_K)
        );
    }

    #[test]
    fn mad_scale_floored() {
        // All-equal residuals give MAD 0, so the floor governs.
        assert_eq!(mad_scale(&[5.0, 5.0, 5.0], 0.25).unwrap(), 0.25);
    }

    #[test]
    fn mad_scale_empty_returns_floor() {
        // Both medians of an empty sample are 0.0, so the floor governs.
        assert_eq!(mad_scale(&[], 0.25).unwrap(), 0.25);
    }

    #[test]
    fn mad_scale_above_floor() {
        // Median 3, absolute deviations {2, 1, 0, 1, 2}, MAD 1.
        assert_eq!(
            mad_scale(&[1.0, 2.0, 3.0, 4.0, 5.0], 0.25).unwrap(),
            MAD_NORMAL_CONST
        );
        // A single gross outlier leaves the MAD (and thus the scale) unchanged.
        assert_eq!(
            mad_scale(&[1.0, 2.0, 3.0, 4.0, 100.0], 0.25).unwrap(),
            MAD_NORMAL_CONST
        );
    }

    #[test]
    fn mad_scale_rejects_nonfinite_sample() {
        assert_eq!(
            mad_scale(&[5.0, f64::INFINITY], 0.25),
            Err(RobustError::InvalidInput {
                field: "values",
                reason: "not finite"
            })
        );
    }

    #[test]
    fn mad_scale_rejects_invalid_floor() {
        let cases = [
            (0.0, "not positive"),
            (-1.0, "not positive"),
            (f64::NAN, "not finite"),
            (f64::INFINITY, "not finite"),
        ];
        for (floor, reason) in cases {
            assert_eq!(
                mad_scale(&[1.0, 2.0], floor),
                Err(RobustError::InvalidInput {
                    field: "scale_floor",
                    reason
                })
            );
        }
    }

    #[test]
    fn error_accessors_and_message() {
        let err = mad_scale(&[1.0], 0.0).unwrap_err();
        assert_eq!(err.field(), "scale_floor");
        assert_eq!(err.reason(), "not positive");
        assert_eq!(
            err.to_string(),
            "invalid robust statistic scale_floor: not positive"
        );
    }
}