// sidereon_core/astro/math/robust.rs — robust statistics: median, MAD scale, Huber weights.

/// Default threshold `k` for [`huber_weight`]: residuals with `|u| <= k` keep full weight.
///
/// 1.345 is the conventional tuning constant for Huber's M-estimator.
/// NOTE(review): presumably chosen for ~95% efficiency under Gaussian noise — confirm.
pub const HUBER_K: f64 = 1.345;

/// Multiplier applied to the median absolute deviation in [`mad_scale`].
///
/// 1.4826 ≈ 1 / Φ⁻¹(3/4), the factor that makes the MAD a consistent estimator of the
/// standard deviation for normally distributed data.
pub const MAD_NORMAL_CONST: f64 = 1.4826;

21#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
22pub enum RobustError {
23 #[error("invalid robust statistic {field}: {reason}")]
24 InvalidInput {
25 field: &'static str,
26 reason: &'static str,
27 },
28}
29
30impl RobustError {
31 pub const fn field(&self) -> &'static str {
32 match self {
33 Self::InvalidInput { field, .. } => field,
34 }
35 }
36
37 pub const fn reason(&self) -> &'static str {
38 match self {
39 Self::InvalidInput { reason, .. } => reason,
40 }
41 }
42}
43
44pub fn median(values: &[f64]) -> Result<f64, RobustError> {
49 validate_finite_slice(values, "values")?;
50 if values.is_empty() {
51 return Ok(0.0);
52 }
53 let mut v: Vec<f64> = values.to_vec();
54 Ok(median_sorting_in_place(&mut v).unwrap_or(0.0))
55}
56
/// Sorts `values` ascending (by [`f64::total_cmp`]) in place and returns its median.
///
/// Returns `None` for an empty slice. For an even number of elements the median is the
/// midpoint of the two central elements. The slice is left fully sorted on return.
///
/// The midpoint is computed as `(lo + hi) / 2`, which is correctly rounded whenever the
/// sum is representable. If the sum overflows (both values near `±f64::MAX`), the
/// computation falls back to `lo / 2 + hi / 2`, so finite inputs always give a finite
/// median. Non-finite inputs are not validated here and behave as before.
pub(crate) fn median_sorting_in_place(values: &mut [f64]) -> Option<f64> {
    if values.is_empty() {
        return None;
    }
    // Elements that compare equal under `total_cmp` are bit-identical, so an unstable
    // sort produces exactly the same slice as a stable one, without its allocation.
    values.sort_unstable_by(f64::total_cmp);
    let n = values.len();
    let mid = n / 2;
    if n % 2 == 1 {
        return Some(values[mid]);
    }
    let (lo, hi) = (values[mid - 1], values[mid]);
    let sum = lo + hi;
    // Halving each term first cannot overflow. It is only used when the direct sum did
    // overflow, in which case both magnitudes are huge and halving them is exact.
    Some(if sum.is_finite() {
        sum / 2.0
    } else {
        lo / 2.0 + hi / 2.0
    })
}

74pub fn mad_scale(residuals: &[f64], scale_floor: f64) -> Result<f64, RobustError> {
82 validate_finite_positive(scale_floor, "scale_floor")?;
83 let med = median(residuals)?;
84 let abs_dev: Vec<f64> = residuals.iter().map(|r| (r - med).abs()).collect();
85 let mad = median(&abs_dev)?;
86 let scaled = MAD_NORMAL_CONST * mad;
87 if scaled > scale_floor {
88 Ok(scaled)
89 } else {
90 Ok(scale_floor)
91 }
92}
93
/// Huber M-estimator weight for residual `u` with threshold `k`.
///
/// Returns `1.0` when `|u| <= k`, otherwise `k / |u|`, so the weight decays inversely
/// with the residual's magnitude beyond the threshold. No validation is performed: a NaN
/// `u` yields NaN. `u` is presumably already standardized by a scale estimate and `k`
/// positive (e.g. [`HUBER_K`]) — confirm at call sites.
pub fn huber_weight(u: f64, k: f64) -> f64 {
    let magnitude = u.abs();
    // Keep `<=` (not `!(>)`) so a NaN magnitude falls through to the division.
    if magnitude <= k {
        return 1.0;
    }
    k / magnitude
}

109fn validate_finite_slice(values: &[f64], field: &'static str) -> Result<(), RobustError> {
110 if values.iter().all(|value| value.is_finite()) {
111 Ok(())
112 } else {
113 Err(invalid_input(field, "not finite"))
114 }
115}
116
117fn validate_finite_positive(value: f64, field: &'static str) -> Result<(), RobustError> {
118 if !value.is_finite() {
119 Err(invalid_input(field, "not finite"))
120 } else if value <= 0.0 {
121 Err(invalid_input(field, "not positive"))
122 } else {
123 Ok(())
124 }
125}
126
127fn invalid_input(field: &'static str, reason: &'static str) -> RobustError {
128 RobustError::InvalidInput { field, reason }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// The error every non-finite sample is expected to produce. `mad_scale` also reports
    /// it under `"values"` because validation happens inside `median`.
    fn not_finite_values() -> RobustError {
        RobustError::InvalidInput {
            field: "values",
            reason: "not finite",
        }
    }

    #[test]
    fn median_odd_even() {
        // Odd length picks the middle element; even length averages the central pair.
        assert_eq!(median(&[3.0, 1.0, 2.0]), Ok(2.0));
        assert_eq!(median(&[1.0, 2.0, 3.0, 4.0]), Ok(2.5));
        // Empty input is defined to have median 0.0 rather than an error.
        assert_eq!(median(&[]), Ok(0.0));
    }

    #[test]
    fn median_rejects_nonfinite_sample() {
        assert_eq!(median(&[1.0, f64::NAN]), Err(not_finite_values()));
    }

    #[test]
    fn huber_weight_breaks_at_k() {
        // Inside the threshold (inclusive) the weight is exactly one.
        for u in [0.0, HUBER_K] {
            assert_eq!(huber_weight(u, HUBER_K), 1.0);
        }
        let halved = huber_weight(2.0 * HUBER_K, HUBER_K);
        assert!((halved - 0.5).abs() < 1e-15);
    }

    #[test]
    fn mad_scale_floored() {
        // Identical samples have zero MAD, so the floor is returned.
        assert_eq!(mad_scale(&[5.0, 5.0, 5.0], 0.25), Ok(0.25));
    }

    #[test]
    fn mad_scale_rejects_nonfinite_sample() {
        assert_eq!(
            mad_scale(&[5.0, f64::INFINITY], 0.25),
            Err(not_finite_values())
        );
    }
}