// so_models/robust.rs
//! Robust statistical methods for StatOxide
//!
//! This module implements robust regression and estimation methods that are
//! less sensitive to outliers and violations of classical assumptions.
//!
//! # Methods Implemented
//!
//! 1. **M-estimators**: Huber, Tukey's biweight, Hampel, Andrews
//! 2. **S-estimators**: High breakdown point estimators
//! 3. **MM-estimators**: Combine high breakdown and high efficiency
//! 4. **LTS/LMS**: Least Trimmed Squares / Least Median of Squares
//! 5. **Robust covariance estimation**: Minimum Covariance Determinant (MCD)
//!

#![allow(non_snake_case)] // Allow mathematical notation (X, W, etc.)

use ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use statrs::distribution::{ContinuousCDF, Normal};

use so_core::error::{Error, Result};
use so_linalg::{inv, solve};
use so_stats::median;

/// Loss functions for M-estimation.
///
/// Each variant carries its tuning constant(s). Common choices for ~95%
/// asymptotic efficiency under normal errors are Huber `k = 1.345` and
/// Tukey `c = 4.685` (see `MEstimator::huber` / `MEstimator::tukey`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum LossFunction {
    /// Huber loss: quadratic near zero, linear in tails (monotone psi).
    Huber { k: f64 },
    /// Tukey's biweight: redescending, completely rejects gross outliers.
    Tukey { c: f64 },
    /// Hampel loss: three-part redescending family with thresholds a < b < c.
    Hampel { a: f64, b: f64, c: f64 },
    /// Andrews' sine wave: redescending, rejects residuals beyond c*pi.
    Andrews { c: f64 },
    /// Least squares (non-robust baseline; unit weight everywhere).
    LeastSquares,
}

40impl LossFunction {
41    /// Compute weight for a standardized residual
42    fn weight(&self, r: f64) -> f64 {
43        match self {
44            LossFunction::Huber { k } => {
45                if r.abs() <= *k {
46                    1.0
47                } else {
48                    k / r.abs()
49                }
50            }
51            LossFunction::Tukey { c } => {
52                if r.abs() <= *c {
53                    let t = r / c;
54                    (1.0 - t * t).powi(2)
55                } else {
56                    0.0
57                }
58            }
59            LossFunction::Hampel { a, b, c } => {
60                let abs_r = r.abs();
61                if abs_r <= *a {
62                    1.0
63                } else if abs_r <= *b {
64                    a / abs_r
65                } else if abs_r <= *c {
66                    a * (c - abs_r) / ((c - b) * abs_r)
67                } else {
68                    0.0
69                }
70            }
71            LossFunction::Andrews { c } => {
72                if r.abs() <= *c * std::f64::consts::PI {
73                    (c * r.sin() / r).max(0.0)
74                } else {
75                    0.0
76                }
77            }
78            LossFunction::LeastSquares => 1.0,
79        }
80    }
81
82    /// Compute psi function (derivative of loss)
83    fn psi(&self, r: f64) -> f64 {
84        match self {
85            LossFunction::Huber { k } => {
86                if r.abs() <= *k {
87                    r
88                } else {
89                    k * r.signum()
90                }
91            }
92            LossFunction::Tukey { c } => {
93                if r.abs() <= *c {
94                    let t = r / c;
95                    r * (1.0 - t * t).powi(2)
96                } else {
97                    0.0
98                }
99            }
100            LossFunction::Hampel { a, b, c } => {
101                let abs_r = r.abs();
102                if abs_r <= *a {
103                    r
104                } else if abs_r <= *b {
105                    a * r.signum()
106                } else if abs_r <= *c {
107                    a * (c - abs_r) / (c - b) * r.signum()
108                } else {
109                    0.0
110                }
111            }
112            LossFunction::Andrews { c } => {
113                if r.abs() <= *c * std::f64::consts::PI {
114                    c * r.sin()
115                } else {
116                    0.0
117                }
118            }
119            LossFunction::LeastSquares => r,
120        }
121    }
122}
123
/// Results of a robust regression fit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RobustRegressionResults {
    /// Fitted robust coefficients.
    pub coefficients: Array1<f64>,
    /// Robust standard errors (sandwich estimator; LTS fills zeros because it
    /// does not compute them).
    pub standard_errors: Array1<f64>,
    /// Robust residual scale estimate (MAD or similar).
    pub scale: f64,
    /// Number of iterations performed (IRLS passes, or subsets tried for LTS).
    pub iterations: usize,
    /// Final observation weights; small or zero weights flag likely outliers.
    pub weights: Array1<f64>,
    /// Breakdown point achieved by the estimator.
    pub breakdown_point: f64,
    /// Efficiency relative to OLS (approximate, under normal errors).
    pub efficiency: f64,
}

/// M-estimator for robust regression.
///
/// Configured via the builder methods (`max_iterations`, `tolerance`,
/// `scale_estimator`, `tuning`); fitting uses Iteratively Reweighted Least
/// Squares (see `MEstimator::fit`).
#[derive(Clone)]
pub struct MEstimator {
    loss: LossFunction,        // which psi/weight family to use
    max_iter: usize,           // IRLS iteration cap
    tol: f64,                  // convergence tolerance on coefficient change
    scale_est: ScaleEstimator, // how the residual scale is estimated
    tuning: TuningParameters,  // breakdown/efficiency/stability knobs
}

/// Scale estimation methods for standardizing residuals.
#[derive(Debug, Clone, Copy)]
pub enum ScaleEstimator {
    /// Median Absolute Deviation, rescaled by 1/0.6745 for normal consistency (robust).
    MAD,
    /// Interquartile Range / 1.349 (normal-consistent).
    IQR,
    /// Scale taken from an S-estimator fit.
    SEstimate,
    /// Fixed, caller-supplied scale (used e.g. by the MM-estimator's M-step).
    Fixed(f64),
}

/// Tuning parameters shared by the robust estimators.
#[derive(Debug, Clone, Copy)]
pub struct TuningParameters {
    /// Target breakdown point for S-estimators (fraction of tolerable
    /// contamination, at most 0.5).
    pub breakdown_point: f64,
    /// Target asymptotic efficiency for MM-estimators.
    pub efficiency: f64,
    /// Small constant used for numerical stability.
    pub delta: f64,
}

177impl Default for TuningParameters {
178    fn default() -> Self {
179        Self {
180            breakdown_point: 0.5,
181            efficiency: 0.95,
182            delta: 1e-8,
183        }
184    }
185}
186
187impl MEstimator {
188    /// Create a new M-estimator with Huber loss (k=1.345 gives 95% efficiency)
189    pub fn huber(k: f64) -> Self {
190        Self {
191            loss: LossFunction::Huber { k },
192            max_iter: 50,
193            tol: 1e-6,
194            scale_est: ScaleEstimator::MAD,
195            tuning: TuningParameters::default(),
196        }
197    }
198
199    /// Create a new M-estimator with Tukey's biweight (c=4.685 gives 95% efficiency)
200    pub fn tukey(c: f64) -> Self {
201        Self {
202            loss: LossFunction::Tukey { c },
203            max_iter: 50,
204            tol: 1e-6,
205            scale_est: ScaleEstimator::MAD,
206            tuning: TuningParameters::default(),
207        }
208    }
209
210    /// Set maximum iterations
211    pub fn max_iterations(mut self, max_iter: usize) -> Self {
212        self.max_iter = max_iter;
213        self
214    }
215
216    /// Set convergence tolerance
217    pub fn tolerance(mut self, tol: f64) -> Self {
218        self.tol = tol;
219        self
220    }
221
222    /// Set scale estimation method
223    pub fn scale_estimator(mut self, scale_est: ScaleEstimator) -> Self {
224        self.scale_est = scale_est;
225        self
226    }
227
228    /// Set tuning parameters
229    pub fn tuning(mut self, tuning: TuningParameters) -> Self {
230        self.tuning = tuning;
231        self
232    }
233
234    /// Fit robust regression using Iteratively Reweighted Least Squares (IRLS)
235    pub fn fit(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<RobustRegressionResults> {
236        let n = X.nrows();
237        let p = X.ncols();
238
239        if n <= p {
240            return Err(Error::DataError(
241                "Need more observations than predictors for robust regression".to_string(),
242            ));
243        }
244
245        // Initial OLS estimate
246        let mut beta = self.initial_estimate(X, y)?;
247
248        // Initial scale estimate
249        let mut scale = self.initial_scale(X, y, &beta)?;
250
251        // Iteratively reweighted least squares
252        let mut iter = 0;
253        let mut converged = false;
254        let mut weights = Array1::ones(n);
255
256        while !converged && iter < self.max_iter {
257            iter += 1;
258
259            // Store previous coefficients
260            let beta_prev = beta.clone();
261
262            // Compute standardized residuals
263            let residuals = y - X.dot(&beta);
264            let scaled_residuals = &residuals / scale;
265
266            // Compute weights based on loss function
267            for i in 0..n {
268                weights[i] = self.loss.weight(scaled_residuals[i]);
269            }
270
271            // Solve weighted least squares
272            let W_sqrt = weights.mapv(|w| w.sqrt());
273            let X_weighted = X * W_sqrt.clone().insert_axis(ndarray::Axis(1));
274            let y_weighted = y * &W_sqrt;
275
276            beta = solve(
277                &X_weighted.t().dot(&X_weighted),
278                &X_weighted.t().dot(&y_weighted),
279            )
280            .map_err(|e| Error::LinearAlgebraError(format!("WLS solve failed: {}", e)))?;
281
282            // Update scale estimate if needed
283            if matches!(self.scale_est, ScaleEstimator::MAD | ScaleEstimator::IQR) {
284                scale = self.update_scale(&residuals, &weights);
285            }
286
287            // Check convergence
288            let beta_diff = (&beta - &beta_prev).mapv(|x| x.abs());
289            let max_diff = beta_diff.iter().fold(0.0, |a, &b| f64::max(a, b));
290            converged = max_diff < self.tol;
291        }
292
293        // Compute robust standard errors
294        let standard_errors = self.compute_standard_errors(X, y, &beta, scale, &weights)?;
295
296        // Compute efficiency and breakdown point
297        let efficiency = self.compute_efficiency();
298        let breakdown_point = self.breakdown_point();
299
300        Ok(RobustRegressionResults {
301            coefficients: beta,
302            standard_errors,
303            scale,
304            iterations: iter,
305            weights,
306            breakdown_point,
307            efficiency,
308        })
309    }
310
311    /// Initial estimate (usually LTS or LMS for high breakdown)
312    fn initial_estimate(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<Array1<f64>> {
313        // For simplicity, use LTS with default coverage
314        let lts = LeastTrimmedSquares::default();
315        lts.fit(X, y).map(|results| results.coefficients)
316    }
317
318    /// Initial scale estimate
319    fn initial_scale(&self, X: &Array2<f64>, y: &Array1<f64>, beta: &Array1<f64>) -> Result<f64> {
320        match self.scale_est {
321            ScaleEstimator::MAD => {
322                let residuals = y - X.dot(beta);
323                Ok(self.mad(&residuals))
324            }
325            ScaleEstimator::IQR => {
326                let residuals = y - X.dot(beta);
327                Ok(self.iqr_scale(&residuals))
328            }
329            ScaleEstimator::SEstimate => {
330                // Use S-estimator for initial scale
331                let s_est = SEstimator::default();
332                s_est.fit(X, y).map(|results| results.scale)
333            }
334            ScaleEstimator::Fixed(scale) => Ok(scale),
335        }
336    }
337
338    /// Update scale estimate based on residuals and weights
339    fn update_scale(&self, residuals: &Array1<f64>, weights: &Array1<f64>) -> f64 {
340        // Weighted scale estimate
341        let _n = residuals.len();
342        let sum_weights: f64 = weights.iter().sum();
343        let weighted_sse: f64 = residuals
344            .iter()
345            .zip(weights.iter())
346            .map(|(&r, &w)| r * r * w)
347            .sum();
348
349        (weighted_sse / sum_weights).sqrt()
350    }
351
352    /// Compute Median Absolute Deviation
353    fn mad(&self, data: &Array1<f64>) -> f64 {
354        let med = median(data).unwrap_or(0.0);
355        let abs_dev: Array1<f64> = data.mapv(|x| (x - med).abs());
356        let mad = median(&abs_dev).unwrap_or(0.0);
357        mad / 0.6745 // Convert to consistent estimator for normal distribution
358    }
359
360    /// Compute IQR-based scale estimate
361    fn iqr_scale(&self, data: &Array1<f64>) -> f64 {
362        use so_stats::quantile;
363        let q1 = quantile(data, 0.25).unwrap_or(0.0);
364        let q3 = quantile(data, 0.75).unwrap_or(0.0);
365        (q3 - q1) / 1.349 // Convert to consistent estimator for normal distribution
366    }
367
368    /// Compute robust standard errors
369    fn compute_standard_errors(
370        &self,
371        X: &Array2<f64>,
372        y: &Array1<f64>,
373        beta: &Array1<f64>,
374        scale: f64,
375        weights: &Array1<f64>,
376    ) -> Result<Array1<f64>> {
377        let n = X.nrows();
378        let p = X.ncols();
379
380        // Compute weighted X'X inverse
381        let W_sqrt = weights.mapv(|w| w.sqrt());
382        let X_weighted = X * W_sqrt.clone().insert_axis(ndarray::Axis(1));
383        let XtWX = X_weighted.t().dot(&X_weighted);
384
385        let XtWX_inv = inv(&XtWX)
386            .map_err(|e| Error::LinearAlgebraError(format!("Failed to invert X'WX: {}", e)))?;
387
388        // Compute leverage-adjusted residuals
389        let residuals = y - X.dot(beta);
390        let scaled_residuals = &residuals / scale;
391
392        // Compute empirical influence function
393        let mut influence = Array1::<f64>::zeros(p);
394        for i in 0..n {
395            let psi = self.loss.psi(scaled_residuals[i]);
396            let xi = X.row(i);
397            influence = influence + xi.mapv(|x| x * psi);
398        }
399
400        // Compute sandwich variance estimator
401        let mut sandwich = Array2::zeros((p, p));
402        for i in 0..n {
403            let psi = self.loss.psi(scaled_residuals[i]);
404            let xi = X.row(i);
405            let outer = xi.t().dot(&xi).to_owned() * psi * psi;
406            sandwich += outer;
407        }
408
409        let cov = XtWX_inv.dot(&sandwich.dot(&XtWX_inv)) * scale * scale / n as f64;
410        let se = cov.diag().mapv(|x| x.sqrt());
411
412        Ok(se)
413    }
414
415    /// Compute asymptotic efficiency
416    fn compute_efficiency(&self) -> f64 {
417        // Asymptotic efficiency relative to OLS under normality
418        match self.loss {
419            LossFunction::Huber { k } => {
420                let normal = Normal::new(0.0, 1.0).unwrap();
421                let eff = 1.0 / (1.0 + 2.0 * (1.0 - normal.cdf(k)) / k.powi(2));
422                eff.min(1.0)
423            }
424            LossFunction::Tukey { c } => {
425                // Approximation for Tukey's efficiency
426                let _c2 = c * c;
427
428                if c >= 4.0 { 0.95 } else { 0.85 }
429            }
430            _ => 0.85, // Conservative estimate for other loss functions
431        }
432    }
433
434    /// Estimate breakdown point
435    fn breakdown_point(&self) -> f64 {
436        match self.loss {
437            LossFunction::Huber { .. } => 0.0, // M-estimators have 0 breakdown
438            LossFunction::Tukey { .. } => 0.5, // Redescending M-estimators can have high breakdown
439            LossFunction::Hampel { .. } => 0.5,
440            LossFunction::Andrews { .. } => 0.5,
441            LossFunction::LeastSquares => 0.0,
442        }
443    }
444}
445
/// Least Trimmed Squares estimator (high breakdown).
pub struct LeastTrimmedSquares {
    /// Fraction of observations kept by the trimmed fit (h = ceil(n * coverage));
    /// the reported breakdown point is `1 - coverage`.
    coverage: f64,
}

451impl Default for LeastTrimmedSquares {
452    fn default() -> Self {
453        Self { coverage: 0.5 }
454    }
455}
456
457impl LeastTrimmedSquares {
458    /// Create LTS with specified coverage
459    pub fn new(coverage: f64) -> Self {
460        Self { coverage }
461    }
462
463    /// Fit LTS regression
464    pub fn fit(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<RobustRegressionResults> {
465        let n = X.nrows();
466        let p = X.ncols();
467
468        if n <= p {
469            return Err(Error::DataError(
470                "Need more observations than predictors for LTS".to_string(),
471            ));
472        }
473
474        let h = (n as f64 * self.coverage).ceil() as usize;
475
476        // Simplified LTS: use random subsets (in practice, use fast algorithms)
477        let n_subsets = 500.min(n);
478        let mut best_sse = f64::INFINITY;
479        let mut best_beta = Array1::zeros(p);
480
481        let mut rng = rand::rng();
482
483        for _ in 0..n_subsets {
484            // Random subset of size p+1
485            let subset_indices = rand::seq::index::sample(&mut rng, n, p + 1).into_vec();
486            let X_subset = X.select(ndarray::Axis(0), &subset_indices);
487            let y_subset = y.select(ndarray::Axis(0), &subset_indices);
488
489            // Fit on subset
490            if let Ok(beta) = solve(&X_subset.t().dot(&X_subset), &X_subset.t().dot(&y_subset)) {
491                let residuals = y - X.dot(&beta);
492                let mut squared_residuals: Vec<(f64, usize)> = residuals
493                    .iter()
494                    .enumerate()
495                    .map(|(i, &r)| (r * r, i))
496                    .collect();
497
498                squared_residuals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
499
500                let sse: f64 = squared_residuals[..h].iter().map(|(r2, _)| r2).sum();
501
502                if sse < best_sse {
503                    best_sse = sse;
504                    best_beta = beta;
505                }
506            }
507        }
508
509        // Refit on best h points
510        let residuals = y - X.dot(&best_beta);
511        let mut squared_residuals: Vec<(f64, usize)> = residuals
512            .iter()
513            .enumerate()
514            .map(|(i, &r)| (r * r, i))
515            .collect();
516
517        squared_residuals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
518
519        let best_indices: Vec<usize> = squared_residuals[..h].iter().map(|(_, i)| *i).collect();
520        let X_best = X.select(ndarray::Axis(0), &best_indices);
521        let y_best = y.select(ndarray::Axis(0), &best_indices);
522
523        let final_beta = solve(&X_best.t().dot(&X_best), &X_best.t().dot(&y_best))
524            .map_err(|e| Error::LinearAlgebraError(format!("LTS final fit failed: {}", e)))?;
525
526        // Compute scale from trimmed residuals
527        let scale = (best_sse / h as f64).sqrt();
528
529        // Create weight vector (1 for inliers, 0 for outliers)
530        let mut weights = Array1::zeros(n);
531        for &idx in &best_indices {
532            weights[idx] = 1.0;
533        }
534
535        Ok(RobustRegressionResults {
536            coefficients: final_beta,
537            standard_errors: Array1::zeros(p), // Simplified
538            scale,
539            iterations: n_subsets,
540            weights,
541            breakdown_point: 1.0 - self.coverage,
542            efficiency: 0.7, // LTS has lower efficiency
543        })
544    }
545}
546
/// S-estimator (high breakdown point).
///
/// NOTE(review): `max_iter` and `tol` are currently unused — the simplified
/// `fit` delegates to LTS; they appear reserved for a full iterative
/// S-algorithm (hence the `dead_code` allow).
#[allow(dead_code)]
pub struct SEstimator {
    breakdown_point: f64, // target breakdown point, at most 0.5
    max_iter: usize,      // reserved for the iterative S-algorithm
    tol: f64,             // reserved for the iterative S-algorithm
}

555impl Default for SEstimator {
556    fn default() -> Self {
557        Self {
558            breakdown_point: 0.5,
559            max_iter: 100,
560            tol: 1e-6,
561        }
562    }
563}
564
565impl SEstimator {
566    /// Fit S-estimator (simplified implementation)
567    pub fn fit(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<RobustRegressionResults> {
568        // Simplified: use LTS as starting point
569        let lts = LeastTrimmedSquares::new(self.breakdown_point);
570        lts.fit(X, y)
571    }
572}
573
/// MM-estimator: combines the high breakdown point of an initial S-estimate
/// with the high efficiency of a subsequent M-estimate at fixed scale.
pub struct MMEstimator {
    s_estimator: SEstimator, // stage 1: high-breakdown fit supplying the scale
    m_estimator: MEstimator, // stage 2: efficient M-refinement at fixed scale
}

580impl MMEstimator {
581    /// Create new MM-estimator
582    pub fn new() -> Self {
583        Self {
584            s_estimator: SEstimator::default(),
585            m_estimator: MEstimator::tukey(4.685),
586        }
587    }
588
589    /// Fit MM-estimator
590    pub fn fit(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<RobustRegressionResults> {
591        // Step 1: S-estimator for high breakdown
592        let s_results = self.s_estimator.fit(X, y)?;
593
594        // Step 2: M-estimation with fixed scale from S-estimator
595        let m_estimator = self
596            .m_estimator
597            .clone()
598            .scale_estimator(ScaleEstimator::Fixed(s_results.scale));
599
600        m_estimator.fit(X, y)
601    }
602}