1#![allow(non_snake_case)] use ndarray::{Array1, Array2};
18use serde::{Deserialize, Serialize};
19use statrs::distribution::{ContinuousCDF, Normal};
20
21use so_core::error::{Error, Result};
22use so_linalg::{inv, solve};
23use so_stats::median;
24
/// Robust loss (rho) functions, identified by their tuning constants.
///
/// Each variant supplies the IRLS weight w(r) = psi(r)/r and the
/// influence function psi(r) used by the M-estimators in this module.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum LossFunction {
    /// Huber loss: quadratic for |r| <= k, linear beyond.
    /// k = 1.345 is the conventional ~95%-efficiency tuning.
    Huber { k: f64 },
    /// Tukey bisquare: redescending, zero influence for |r| > c.
    /// c = 4.685 is the conventional ~95%-efficiency tuning.
    Tukey { c: f64 },
    /// Hampel three-part redescending loss with knots a < b < c.
    Hampel { a: f64, b: f64, c: f64 },
    /// Andrews sine-wave loss with tuning constant c.
    Andrews { c: f64 },
    /// Ordinary least squares (constant unit weight); not robust.
    LeastSquares,
}
39
40impl LossFunction {
41 fn weight(&self, r: f64) -> f64 {
43 match self {
44 LossFunction::Huber { k } => {
45 if r.abs() <= *k {
46 1.0
47 } else {
48 k / r.abs()
49 }
50 }
51 LossFunction::Tukey { c } => {
52 if r.abs() <= *c {
53 let t = r / c;
54 (1.0 - t * t).powi(2)
55 } else {
56 0.0
57 }
58 }
59 LossFunction::Hampel { a, b, c } => {
60 let abs_r = r.abs();
61 if abs_r <= *a {
62 1.0
63 } else if abs_r <= *b {
64 a / abs_r
65 } else if abs_r <= *c {
66 a * (c - abs_r) / ((c - b) * abs_r)
67 } else {
68 0.0
69 }
70 }
71 LossFunction::Andrews { c } => {
72 let abs_r = r.abs();
73 if abs_r <= *c * std::f64::consts::PI {
74 if abs_r < 1e-12 {
75 1.0 } else {
77 (c * r.sin() / r).max(0.0)
78 }
79 } else {
80 0.0
81 }
82 }
83 LossFunction::LeastSquares => 1.0,
84 }
85 }
86
87 fn psi(&self, r: f64) -> f64 {
89 match self {
90 LossFunction::Huber { k } => {
91 if r.abs() <= *k {
92 r
93 } else {
94 k * r.signum()
95 }
96 }
97 LossFunction::Tukey { c } => {
98 if r.abs() <= *c {
99 let t = r / c;
100 r * (1.0 - t * t).powi(2)
101 } else {
102 0.0
103 }
104 }
105 LossFunction::Hampel { a, b, c } => {
106 let abs_r = r.abs();
107 if abs_r <= *a {
108 r
109 } else if abs_r <= *b {
110 a * r.signum()
111 } else if abs_r <= *c {
112 a * (c - abs_r) / (c - b) * r.signum()
113 } else {
114 0.0
115 }
116 }
117 LossFunction::Andrews { c } => {
118 if r.abs() <= *c * std::f64::consts::PI {
119 c * r.sin()
120 } else {
121 0.0
122 }
123 }
124 LossFunction::LeastSquares => r,
125 }
126 }
127}
128
/// Output of a robust regression fit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RobustRegressionResults {
    /// Estimated coefficient vector (length p).
    pub coefficients: Array1<f64>,
    /// Approximate standard errors, one per coefficient.
    pub standard_errors: Array1<f64>,
    /// Residual scale estimate used to standardize residuals.
    pub scale: f64,
    /// Iterations performed (for LTS: number of random subsets tried).
    pub iterations: usize,
    /// Final observation weights, one per row of the design matrix.
    pub weights: Array1<f64>,
    /// Nominal breakdown point (fraction of contamination tolerated).
    pub breakdown_point: f64,
    /// Nominal Gaussian efficiency of the estimator.
    pub efficiency: f64,
}
147
/// M-estimator for linear regression, fit by iteratively reweighted
/// least squares (IRLS). Construct via `MEstimator::huber` or
/// `MEstimator::tukey`, then chain the builder methods.
#[derive(Clone)]
pub struct MEstimator {
    // Loss defining the weight and influence functions.
    loss: LossFunction,
    // Cap on IRLS iterations.
    max_iter: usize,
    // Convergence tolerance on the max absolute coefficient change.
    tol: f64,
    // Residual-scale estimation strategy.
    scale_est: ScaleEstimator,
    // NOTE(review): stored but not read by the visible fitting code —
    // confirm intended use.
    tuning: TuningParameters,
}
157
/// Strategy for estimating the residual scale.
#[derive(Debug, Clone, Copy)]
pub enum ScaleEstimator {
    /// Median absolute deviation (normalized for Gaussian consistency).
    MAD,
    /// Interquartile range (normalized for Gaussian consistency).
    IQR,
    /// Scale taken from a preliminary S-estimate fit.
    SEstimate,
    /// Fixed, caller-supplied scale (used by MM-estimation).
    Fixed(f64),
}
170
/// Tuning knobs for robust estimators.
///
/// NOTE(review): these values are carried on `MEstimator` but are not
/// read by the visible fitting code — confirm intended use.
#[derive(Debug, Clone, Copy)]
pub struct TuningParameters {
    /// Desired breakdown point (fraction of contamination tolerated).
    pub breakdown_point: f64,
    /// Desired Gaussian efficiency.
    pub efficiency: f64,
    /// Small numerical tolerance.
    pub delta: f64,
}
181
182impl Default for TuningParameters {
183 fn default() -> Self {
184 Self {
185 breakdown_point: 0.5,
186 efficiency: 0.95,
187 delta: 1e-8,
188 }
189 }
190}
191
192impl MEstimator {
193 pub fn huber(k: f64) -> Self {
195 Self {
196 loss: LossFunction::Huber { k },
197 max_iter: 50,
198 tol: 1e-6,
199 scale_est: ScaleEstimator::MAD,
200 tuning: TuningParameters::default(),
201 }
202 }
203
204 pub fn tukey(c: f64) -> Self {
206 Self {
207 loss: LossFunction::Tukey { c },
208 max_iter: 50,
209 tol: 1e-6,
210 scale_est: ScaleEstimator::MAD,
211 tuning: TuningParameters::default(),
212 }
213 }
214
215 pub fn max_iterations(mut self, max_iter: usize) -> Self {
217 self.max_iter = max_iter;
218 self
219 }
220
221 pub fn tolerance(mut self, tol: f64) -> Self {
223 self.tol = tol;
224 self
225 }
226
227 pub fn scale_estimator(mut self, scale_est: ScaleEstimator) -> Self {
229 self.scale_est = scale_est;
230 self
231 }
232
233 pub fn tuning(mut self, tuning: TuningParameters) -> Self {
235 self.tuning = tuning;
236 self
237 }
238
239 pub fn fit(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<RobustRegressionResults> {
241 let n = X.nrows();
242 let p = X.ncols();
243
244 if n <= p {
245 return Err(Error::DataError(
246 "Need more observations than predictors for robust regression".to_string(),
247 ));
248 }
249
250 let mut beta = self.initial_estimate(X, y)?;
252
253 let mut scale = self.initial_scale(X, y, &beta)?;
255
256 let mut iter = 0;
258 let mut converged = false;
259 let mut weights = Array1::ones(n);
260
261 while !converged && iter < self.max_iter {
262 iter += 1;
263
264 let beta_prev = beta.clone();
266
267 let residuals = y - X.dot(&beta);
269 let scaled_residuals = &residuals / scale;
270
271 for i in 0..n {
273 weights[i] = self.loss.weight(scaled_residuals[i]);
274 }
275
276 let W_sqrt = weights.mapv(|w| w.sqrt());
278 let X_weighted = X * W_sqrt.clone().insert_axis(ndarray::Axis(1));
279 let y_weighted = y * &W_sqrt;
280
281 beta = solve(
282 &X_weighted.t().dot(&X_weighted),
283 &X_weighted.t().dot(&y_weighted),
284 )
285 .map_err(|e| Error::LinearAlgebraError(format!("WLS solve failed: {}", e)))?;
286
287 if matches!(self.scale_est, ScaleEstimator::MAD | ScaleEstimator::IQR) {
289 scale = self.update_scale(&residuals, &weights);
290 }
291
292 let beta_diff = (&beta - &beta_prev).mapv(|x| x.abs());
294 let max_diff = beta_diff.iter().fold(0.0, |a, &b| f64::max(a, b));
295 converged = max_diff < self.tol;
296 }
297
298 let standard_errors = self.compute_standard_errors(X, y, &beta, scale, &weights)?;
300
301 let efficiency = self.compute_efficiency();
303 let breakdown_point = self.breakdown_point();
304
305 Ok(RobustRegressionResults {
306 coefficients: beta,
307 standard_errors,
308 scale,
309 iterations: iter,
310 weights,
311 breakdown_point,
312 efficiency,
313 })
314 }
315
316 fn initial_estimate(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<Array1<f64>> {
318 let lts = LeastTrimmedSquares::default();
320 match lts.fit(X, y) {
321 Ok(results) => Ok(results.coefficients),
322 Err(_) => {
323 solve(&X.t().dot(X), &X.t().dot(y)).map_err(|e| {
325 Error::LinearAlgebraError(format!("Initial estimate failed: {}", e))
326 })
327 }
328 }
329 }
330
331 fn initial_scale(&self, X: &Array2<f64>, y: &Array1<f64>, beta: &Array1<f64>) -> Result<f64> {
333 match self.scale_est {
334 ScaleEstimator::MAD => {
335 let residuals = y - X.dot(beta);
336 Ok(self.mad(&residuals))
337 }
338 ScaleEstimator::IQR => {
339 let residuals = y - X.dot(beta);
340 Ok(self.iqr_scale(&residuals))
341 }
342 ScaleEstimator::SEstimate => {
343 let s_est = SEstimator::default();
345 s_est.fit(X, y).map(|results| results.scale)
346 }
347 ScaleEstimator::Fixed(scale) => Ok(scale),
348 }
349 }
350
351 fn update_scale(&self, residuals: &Array1<f64>, weights: &Array1<f64>) -> f64 {
353 let sum_weights: f64 = weights.iter().sum();
355 if sum_weights < 1e-12 {
356 return self.mad(residuals); }
358
359 let weighted_sse: f64 = residuals
360 .iter()
361 .zip(weights.iter())
362 .map(|(&r, &w)| r * r * w)
363 .sum();
364
365 let scale = (weighted_sse / sum_weights).sqrt();
366 if scale < 1e-12 {
367 self.mad(residuals) } else {
369 scale
370 }
371 }
372
373 fn mad(&self, data: &Array1<f64>) -> f64 {
375 let med = median(data).unwrap_or(0.0);
376 let abs_dev: Array1<f64> = data.mapv(|x| (x - med).abs());
377 let mad = median(&abs_dev).unwrap_or(0.0);
378 let scale = mad / 0.6745; if scale < 1e-12 {
380 1.0 } else {
382 scale
383 }
384 }
385
386 fn iqr_scale(&self, data: &Array1<f64>) -> f64 {
388 use so_stats::quantile;
389 let q1 = quantile(data, 0.25).unwrap_or(0.0);
390 let q3 = quantile(data, 0.75).unwrap_or(0.0);
391 (q3 - q1) / 1.349 }
393
394 fn compute_standard_errors(
396 &self,
397 X: &Array2<f64>,
398 y: &Array1<f64>,
399 beta: &Array1<f64>,
400 scale: f64,
401 weights: &Array1<f64>,
402 ) -> Result<Array1<f64>> {
403 let n = X.nrows();
404 let p = X.ncols();
405
406 let W_sqrt = weights.mapv(|w| w.sqrt());
408 let X_weighted = X * W_sqrt.clone().insert_axis(ndarray::Axis(1));
409 let XtWX = X_weighted.t().dot(&X_weighted);
410
411 let XtWX_inv = inv(&XtWX)
412 .map_err(|e| Error::LinearAlgebraError(format!("Failed to invert X'WX: {}", e)))?;
413
414 let residuals = y - X.dot(beta);
416 let scaled_residuals = &residuals / scale;
417
418 let mut influence = Array1::<f64>::zeros(p);
420 for i in 0..n {
421 let psi = self.loss.psi(scaled_residuals[i]);
422 let xi = X.row(i);
423 influence = influence + xi.mapv(|x| x * psi);
424 }
425
426 let mut sandwich = Array2::zeros((p, p));
428 for i in 0..n {
429 let psi = self.loss.psi(scaled_residuals[i]);
430 let xi = X.row(i);
431 let outer = xi.t().dot(&xi).to_owned() * psi * psi;
432 sandwich += outer;
433 }
434
435 let cov = XtWX_inv.dot(&sandwich.dot(&XtWX_inv)) * scale * scale / n as f64;
436 let se = cov.diag().mapv(|x| x.sqrt());
437
438 Ok(se)
439 }
440
441 fn compute_efficiency(&self) -> f64 {
443 match self.loss {
445 LossFunction::Huber { k } => {
446 let normal = Normal::new(0.0, 1.0).unwrap();
447 let eff = 1.0 / (1.0 + 2.0 * (1.0 - normal.cdf(k)) / k.powi(2));
448 eff.min(1.0)
449 }
450 LossFunction::Tukey { c } => {
451 let _c2 = c * c;
453
454 if c >= 4.0 { 0.95 } else { 0.85 }
455 }
456 _ => 0.85, }
458 }
459
460 fn breakdown_point(&self) -> f64 {
462 match self.loss {
463 LossFunction::Huber { .. } => 0.0, LossFunction::Tukey { .. } => 0.5, LossFunction::Hampel { .. } => 0.5,
466 LossFunction::Andrews { .. } => 0.5,
467 LossFunction::LeastSquares => 0.0,
468 }
469 }
470}
471
/// Least trimmed squares (LTS) regression: minimizes the sum of the h
/// smallest squared residuals using random elemental subsets.
pub struct LeastTrimmedSquares {
    // Fraction of observations retained in the trimmed sum
    // (h = ceil(n * coverage)).
    coverage: f64,
}
476
477impl Default for LeastTrimmedSquares {
478 fn default() -> Self {
479 Self { coverage: 0.5 }
480 }
481}
482
483impl LeastTrimmedSquares {
484 pub fn new(coverage: f64) -> Self {
486 Self { coverage }
487 }
488
489 pub fn fit(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<RobustRegressionResults> {
491 let n = X.nrows();
492 let p = X.ncols();
493
494 if n <= p {
495 return Err(Error::DataError(
496 "Need more observations than predictors for LTS".to_string(),
497 ));
498 }
499
500 let h = (n as f64 * self.coverage).ceil() as usize;
501
502 let n_subsets = 500.min(n);
504 let mut best_sse = f64::INFINITY;
505 let mut best_beta = Array1::zeros(p);
506
507 let mut rng = rand::rng();
508
509 for _ in 0..n_subsets {
510 let subset_indices = rand::seq::index::sample(&mut rng, n, p + 1).into_vec();
512 let X_subset = X.select(ndarray::Axis(0), &subset_indices);
513 let y_subset = y.select(ndarray::Axis(0), &subset_indices);
514
515 if let Ok(beta) = solve(&X_subset.t().dot(&X_subset), &X_subset.t().dot(&y_subset)) {
517 let residuals = y - X.dot(&beta);
518 let mut squared_residuals: Vec<(f64, usize)> = residuals
519 .iter()
520 .enumerate()
521 .map(|(i, &r)| (r * r, i))
522 .collect();
523
524 squared_residuals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
525
526 let sse: f64 = squared_residuals[..h].iter().map(|(r2, _)| r2).sum();
527
528 if sse < best_sse {
529 best_sse = sse;
530 best_beta = beta;
531 }
532 }
533 }
534
535 let residuals = y - X.dot(&best_beta);
537 let mut squared_residuals: Vec<(f64, usize)> = residuals
538 .iter()
539 .enumerate()
540 .map(|(i, &r)| (r * r, i))
541 .collect();
542
543 squared_residuals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
544
545 let best_indices: Vec<usize> = squared_residuals[..h].iter().map(|(_, i)| *i).collect();
546 let X_best = X.select(ndarray::Axis(0), &best_indices);
547 let y_best = y.select(ndarray::Axis(0), &best_indices);
548
549 let final_beta = solve(&X_best.t().dot(&X_best), &X_best.t().dot(&y_best))
550 .map_err(|e| Error::LinearAlgebraError(format!("LTS final fit failed: {}", e)))?;
551
552 let scale = (best_sse / h as f64).sqrt();
554
555 let mut weights = Array1::zeros(n);
557 for &idx in &best_indices {
558 weights[idx] = 1.0;
559 }
560
561 Ok(RobustRegressionResults {
562 coefficients: final_beta,
563 standard_errors: Array1::zeros(p), scale,
565 iterations: n_subsets,
566 weights,
567 breakdown_point: 1.0 - self.coverage,
568 efficiency: 0.7, })
570 }
571}
572
/// S-estimator configuration.
///
/// NOTE(review): `max_iter` and `tol` are not used by the visible `fit`,
/// which delegates to LTS — confirm whether a full S iteration is planned.
#[allow(dead_code)]
pub struct SEstimator {
    // Target breakdown point.
    breakdown_point: f64,
    // Iteration cap (unused by the current delegating implementation).
    max_iter: usize,
    // Convergence tolerance (unused by the current delegating implementation).
    tol: f64,
}
580
581impl Default for SEstimator {
582 fn default() -> Self {
583 Self {
584 breakdown_point: 0.5,
585 max_iter: 100,
586 tol: 1e-6,
587 }
588 }
589}
590
591impl SEstimator {
592 pub fn fit(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<RobustRegressionResults> {
594 let lts = LeastTrimmedSquares::new(self.breakdown_point);
596 lts.fit(X, y)
597 }
598}
599
/// MM-estimator: a high-breakdown scale from an S-step followed by a
/// high-efficiency M-step run with that scale held fixed.
pub struct MMEstimator {
    // Stage 1: robust scale (and preliminary fit).
    s_estimator: SEstimator,
    // Stage 2: efficient M-estimate (Tukey, c = 4.685).
    m_estimator: MEstimator,
}
605
606impl MMEstimator {
607 pub fn new() -> Self {
609 Self {
610 s_estimator: SEstimator::default(),
611 m_estimator: MEstimator::tukey(4.685),
612 }
613 }
614
615 pub fn fit(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<RobustRegressionResults> {
617 let s_results = self.s_estimator.fit(X, y)?;
619
620 let m_estimator = self
622 .m_estimator
623 .clone()
624 .scale_estimator(ScaleEstimator::Fixed(s_results.scale));
625
626 m_estimator.fit(X, y)
627 }
628}
629
#[cfg(test)]
mod tests {
    use super::*;
    // FIX: `Array2` was imported here but never referenced (the 2-D test
    // fixtures come from the `array!` macro), causing an unused-import
    // warning under `-D warnings`.
    use ndarray::{array, Array1};

    /// Every loss gives unit weight and zero influence at r = 0.
    #[test]
    fn test_loss_functions() {
        let huber = LossFunction::Huber { k: 1.345 };
        let tukey = LossFunction::Tukey { c: 4.685 };
        let hampel = LossFunction::Hampel {
            a: 1.0,
            b: 2.0,
            c: 3.0,
        };
        let andrews = LossFunction::Andrews { c: 1.339 };
        let ls = LossFunction::LeastSquares;

        assert_eq!(huber.weight(0.0), 1.0);
        assert_eq!(tukey.weight(0.0), 1.0);
        assert_eq!(hampel.weight(0.0), 1.0);
        assert_eq!(andrews.weight(0.0), 1.0);
        assert_eq!(ls.weight(0.0), 1.0);

        assert_eq!(huber.psi(0.0), 0.0);
        assert_eq!(tukey.psi(0.0), 0.0);
        assert_eq!(hampel.psi(0.0), 0.0);
        assert_eq!(andrews.psi(0.0), 0.0);
        assert_eq!(ls.psi(0.0), 0.0);
    }

    #[test]
    fn test_mad_and_iqr() {
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let estimator = MEstimator::huber(1.345);

        let mad = estimator.mad(&data);
        let iqr_scale = estimator.iqr_scale(&data);

        assert!(mad > 0.0);
        assert!(iqr_scale > 0.0);
    }

    /// Huber regression should run to completion despite one gross outlier.
    #[test]
    fn test_huber_regression() {
        let X = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 20.0]);

        let huber = MEstimator::huber(1.345);
        let result = huber.fit(&X, &y);

        assert!(result.is_ok());
        let results = result.unwrap();
        assert_eq!(results.coefficients.len(), 1);
        assert!(results.scale > 0.0);
        assert!(results.iterations > 0);
        assert!(results.weights.len() == 6);

        // IRLS weights always lie in [0, 1].
        for w in results.weights.iter() {
            assert!(*w >= 0.0 && *w <= 1.0);
        }
    }

    #[test]
    fn test_tukey_regression() {
        let X = array![
            [1.0, 1.0],
            [1.0, 2.0],
            [1.0, 3.0],
            [1.0, 4.0],
            [1.0, 5.0],
        ];
        let y = Array1::from_vec(vec![2.0, 4.0, 6.0, 8.0, 10.0]);

        let tukey = MEstimator::tukey(4.685);
        let result = tukey.fit(&X, &y);

        assert!(result.is_ok());
        let results = result.unwrap();
        assert_eq!(results.coefficients.len(), 2);
        assert!(results.breakdown_point > 0.0);
    }

    /// With 50% coverage, LTS should trim the gross outlier at index 4.
    #[test]
    fn test_lts_regression() {
        let X = array![
            [1.0, 1.0],
            [1.0, 2.0],
            [1.0, 3.0],
            [1.0, 4.0],
            [1.0, 5.0],
        ];
        let y = Array1::from_vec(vec![2.0, 4.0, 6.0, 8.0, 100.0]);

        let lts = LeastTrimmedSquares::new(0.5);
        let result = lts.fit(&X, &y);

        assert!(result.is_ok());
        let results = result.unwrap();
        assert_eq!(results.coefficients.len(), 2);
        assert!(results.breakdown_point >= 0.5);
        assert!(results.weights[4] == 0.0);
    }

    #[test]
    fn test_mm_estimator() {
        let X = array![
            [1.0, 1.0],
            [1.0, 2.0],
            [1.0, 3.0],
            [1.0, 4.0],
            [1.0, 5.0],
        ];
        let y = Array1::from_vec(vec![2.0, 4.0, 6.0, 8.0, 100.0]);

        let mm = MMEstimator::new();
        let result = mm.fit(&X, &y);

        assert!(result.is_ok());
        let results = result.unwrap();
        assert_eq!(results.coefficients.len(), 2);
        assert!(results.breakdown_point > 0.0);
        assert!(results.efficiency > 0.8);
    }

    /// n <= p must be rejected up front.
    #[test]
    fn test_insufficient_data() {
        let X = array![[1.0]];
        let y = Array1::from_vec(vec![1.0]);

        let huber = MEstimator::huber(1.345);
        let result = huber.fit(&X, &y);

        assert!(result.is_err());
    }
}