Skip to main content

sidereon_core/astro/math/
robust.rs

1//! Robust M-estimation primitives for iteratively reweighted least squares.
2//!
3//! These are the per-outer-iteration reweighting pieces of a Huber IRLS loop:
4//! a median-absolute-deviation scale estimate and the Huber weight function.
5//! They are deliberately pure `f64` arithmetic (abs, compare, divide, sort by
6//! [`f64::total_cmp`]) with no fused-multiply-add and no contraction, so the
7//! per-iteration weight vector is bit-reproducible against an explicit
8//! outer-loop reference recipe. The trust-region linear-algebra step that
9//! consumes the weights is BLAS-bound and is NOT a 0-ULP target.
10
/// The default Huber tuning constant.
///
/// Scaled residuals whose magnitude is at or below this value (in units of
/// the robust scale) keep full weight; larger ones are down-weighted as
/// `k / |u|` (see [`huber_weight`]). `1.345` gives ~95% efficiency at the
/// Gaussian model.
pub const HUBER_K: f64 = 1.345;
15
/// The MAD-to-sigma consistency constant for a normal distribution,
/// `1 / Phi^-1(3/4)`, stored as its conventional four-decimal rounding.
/// Multiplying the median absolute deviation by this makes it a consistent
/// estimator of the standard deviation under normality.
pub const MAD_NORMAL_CONST: f64 = 1.4826;
20
/// Errors reported by the robust-statistics helpers in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum RobustError {
    /// An argument failed validation before any statistic was computed.
    #[error("invalid robust statistic {field}: {reason}")]
    InvalidInput {
        /// Name of the rejected argument (e.g. `"values"` or `"scale_floor"`).
        field: &'static str,
        /// Short description of the violation (e.g. `"not finite"`).
        reason: &'static str,
    },
}
29
30impl RobustError {
31    pub const fn field(&self) -> &'static str {
32        match self {
33            Self::InvalidInput { field, .. } => field,
34        }
35    }
36
37    pub const fn reason(&self) -> &'static str {
38        match self {
39            Self::InvalidInput { reason, .. } => reason,
40        }
41    }
42}
43
44/// The median of `values`, computed on a `total_cmp` sort so the order (and
45/// thus the result for an even count, which averages the two central values)
46/// is deterministic. An empty slice yields `0.0`. The averaging of the two
47/// central elements is a single `(a + b) / 2.0`, no FMA.
48pub fn median(values: &[f64]) -> Result<f64, RobustError> {
49    validate_finite_slice(values, "values")?;
50    if values.is_empty() {
51        return Ok(0.0);
52    }
53    let mut v: Vec<f64> = values.to_vec();
54    Ok(median_sorting_in_place(&mut v).unwrap_or(0.0))
55}
56
/// The shared median kernel.
///
/// Sorts `values` in place by `total_cmp` and returns the middle element (odd
/// count) or the `(a + b) / 2.0` average of the two central elements (even
/// count, no FMA). Returns `None` for an empty slice. On return the slice is
/// left in ascending `total_cmp` order.
///
/// No finiteness validation is performed; callers own their input contracts.
pub(crate) fn median_sorting_in_place(values: &mut [f64]) -> Option<f64> {
    if values.is_empty() {
        return None;
    }
    // Two floats compare `Equal` under `total_cmp` only when they are
    // bit-identical, so an unstable sort produces exactly the same slice as a
    // stable one while avoiding the stable sort's scratch allocation.
    values.sort_unstable_by(f64::total_cmp);
    let mid = values.len() / 2;
    let center = if values.len() % 2 == 1 {
        values[mid]
    } else {
        // Even count: `mid >= 1` here because the slice is non-empty.
        (values[mid - 1] + values[mid]) / 2.0
    };
    Some(center)
}
73
74/// The median-absolute-deviation scale of `residuals`, scaled to a normal-sigma
75/// estimate and floored at `scale_floor`.
76///
77/// `s = max(scale_floor, MAD_NORMAL_CONST * median(|r_i - median(r)|))`. The
78/// floor prevents a near-perfect fit (MAD approaching zero) from blowing up the
79/// scaled residuals `u_i = r_i / s` and spuriously down-weighting every
80/// observation. Both medians use [`median`]'s `total_cmp` sort.
81pub fn mad_scale(residuals: &[f64], scale_floor: f64) -> Result<f64, RobustError> {
82    validate_finite_positive(scale_floor, "scale_floor")?;
83    let med = median(residuals)?;
84    let abs_dev: Vec<f64> = residuals.iter().map(|r| (r - med).abs()).collect();
85    let mad = median(&abs_dev)?;
86    let scaled = MAD_NORMAL_CONST * mad;
87    if scaled > scale_floor {
88        Ok(scaled)
89    } else {
90        Ok(scale_floor)
91    }
92}
93
/// Huber reweighting factor for a scaled residual `u = r / s`.
///
/// `w(u) = 1` when `|u| <= k`, otherwise `w(u) = k / |u|` (the Huber
/// `psi(u) / u` form). For finite `u` and `k > 0` the result lies in `(0, 1]`,
/// and `u == 0` gives `1`. Inputs are not validated: a NaN `u` fails the
/// `<=` test and propagates as NaN through the division.
///
/// The returned value is the multiplier applied on top of any base
/// (elevation) weight to obtain the effective per-observation weight of the
/// current outer iteration.
pub fn huber_weight(u: f64, k: f64) -> f64 {
    let magnitude = u.abs();
    if magnitude <= k {
        return 1.0;
    }
    k / magnitude
}
108
109fn validate_finite_slice(values: &[f64], field: &'static str) -> Result<(), RobustError> {
110    if values.iter().all(|value| value.is_finite()) {
111        Ok(())
112    } else {
113        Err(invalid_input(field, "not finite"))
114    }
115}
116
117fn validate_finite_positive(value: f64, field: &'static str) -> Result<(), RobustError> {
118    if !value.is_finite() {
119        Err(invalid_input(field, "not finite"))
120    } else if value <= 0.0 {
121        Err(invalid_input(field, "not positive"))
122    } else {
123        Ok(())
124    }
125}
126
/// Builds a [`RobustError::InvalidInput`] from the argument name and the
/// reason it was rejected.
fn invalid_input(field: &'static str, reason: &'static str) -> RobustError {
    RobustError::InvalidInput { field, reason }
}
130
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn median_odd_even() {
        assert_eq!(median(&[3.0, 1.0, 2.0]).unwrap(), 2.0);
        assert_eq!(median(&[1.0, 2.0, 3.0, 4.0]).unwrap(), 2.5);
        assert_eq!(median(&[]).unwrap(), 0.0);
    }

    #[test]
    fn median_rejects_nonfinite_sample() {
        assert_eq!(
            median(&[1.0, f64::NAN]),
            Err(RobustError::InvalidInput {
                field: "values",
                reason: "not finite"
            })
        );
    }

    #[test]
    fn median_kernel_sorts_in_place() {
        let mut empty: [f64; 0] = [];
        assert_eq!(median_sorting_in_place(&mut empty), None);

        let mut even = [4.0_f64, 1.0, 3.0, 2.0];
        assert_eq!(median_sorting_in_place(&mut even), Some(2.5));
        assert_eq!(even, [1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn huber_weight_breaks_at_k() {
        assert_eq!(huber_weight(0.0, HUBER_K), 1.0);
        assert_eq!(huber_weight(HUBER_K, HUBER_K), 1.0);
        let w = huber_weight(2.0 * HUBER_K, HUBER_K);
        assert!((w - 0.5).abs() < 1e-15);
    }

    #[test]
    fn huber_weight_is_symmetric_in_sign() {
        // Only the magnitude of the scaled residual matters.
        assert_eq!(huber_weight(-HUBER_K, HUBER_K), 1.0);
        assert_eq!(
            huber_weight(-2.0 * HUBER_K, HUBER_K),
            huber_weight(2.0 * HUBER_K, HUBER_K)
        );
    }

    #[test]
    fn mad_scale_floored() {
        // All-equal residuals give MAD 0, so the floor governs.
        assert_eq!(mad_scale(&[5.0, 5.0, 5.0], 0.25).unwrap(), 0.25);
    }

    #[test]
    fn mad_scale_empty_residuals_yield_floor() {
        // Both medians of an empty sample are 0.0, so the floor governs.
        assert_eq!(mad_scale(&[], 0.5).unwrap(), 0.5);
    }

    #[test]
    fn mad_scale_above_floor() {
        // Median 3, absolute deviations {2, 1, 0, 1, 97}, MAD 1: the single
        // outlier does not inflate the scale.
        assert_eq!(
            mad_scale(&[1.0, 2.0, 3.0, 4.0, 100.0], 0.25).unwrap(),
            MAD_NORMAL_CONST
        );
    }

    #[test]
    fn mad_scale_rejects_nonfinite_sample() {
        assert_eq!(
            mad_scale(&[5.0, f64::INFINITY], 0.25),
            Err(RobustError::InvalidInput {
                field: "values",
                reason: "not finite"
            })
        );
    }

    #[test]
    fn mad_scale_rejects_invalid_floor() {
        for floor in [0.0, -1.0] {
            assert_eq!(
                mad_scale(&[1.0, 2.0], floor),
                Err(RobustError::InvalidInput {
                    field: "scale_floor",
                    reason: "not positive"
                })
            );
        }
        for floor in [f64::NAN, f64::INFINITY] {
            assert_eq!(
                mad_scale(&[1.0, 2.0], floor),
                Err(RobustError::InvalidInput {
                    field: "scale_floor",
                    reason: "not finite"
                })
            );
        }
    }

    #[test]
    fn error_accessors_and_display() {
        let err = median(&[f64::NAN]).unwrap_err();
        assert_eq!(err.field(), "values");
        assert_eq!(err.reason(), "not finite");
        assert_eq!(
            err.to_string(),
            "invalid robust statistic values: not finite"
        );
    }
}