1#![allow(non_snake_case)] use ndarray::{Array1, Array2};
18use serde::{Deserialize, Serialize};
19use statrs::distribution::{ContinuousCDF, Normal};
20
21use so_core::error::{Error, Result};
22use so_linalg::{inv, solve};
23use so_stats::median;
24
/// Robust loss functions used to downweight outlying residuals.
///
/// `weight` supplies the IRLS weight w(r) and `psi` the influence function;
/// the tuning constants trade Gaussian efficiency against outlier resistance.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum LossFunction {
    /// Huber loss: quadratic for |r| <= k, linear beyond; `k` is the corner constant.
    Huber { k: f64 },
    /// Tukey bisquare: redescending, zero influence for |r| > c.
    Tukey { c: f64 },
    /// Hampel three-part redescending loss with knots a < b < c.
    Hampel { a: f64, b: f64, c: f64 },
    /// Andrews sine wave: redescending, zero influence for |r| > c*pi.
    Andrews { c: f64 },
    /// Ordinary least squares (no downweighting); not robust.
    LeastSquares,
}
39
40impl LossFunction {
41 fn weight(&self, r: f64) -> f64 {
43 match self {
44 LossFunction::Huber { k } => {
45 if r.abs() <= *k {
46 1.0
47 } else {
48 k / r.abs()
49 }
50 }
51 LossFunction::Tukey { c } => {
52 if r.abs() <= *c {
53 let t = r / c;
54 (1.0 - t * t).powi(2)
55 } else {
56 0.0
57 }
58 }
59 LossFunction::Hampel { a, b, c } => {
60 let abs_r = r.abs();
61 if abs_r <= *a {
62 1.0
63 } else if abs_r <= *b {
64 a / abs_r
65 } else if abs_r <= *c {
66 a * (c - abs_r) / ((c - b) * abs_r)
67 } else {
68 0.0
69 }
70 }
71 LossFunction::Andrews { c } => {
72 if r.abs() <= *c * std::f64::consts::PI {
73 (c * r.sin() / r).max(0.0)
74 } else {
75 0.0
76 }
77 }
78 LossFunction::LeastSquares => 1.0,
79 }
80 }
81
82 fn psi(&self, r: f64) -> f64 {
84 match self {
85 LossFunction::Huber { k } => {
86 if r.abs() <= *k {
87 r
88 } else {
89 k * r.signum()
90 }
91 }
92 LossFunction::Tukey { c } => {
93 if r.abs() <= *c {
94 let t = r / c;
95 r * (1.0 - t * t).powi(2)
96 } else {
97 0.0
98 }
99 }
100 LossFunction::Hampel { a, b, c } => {
101 let abs_r = r.abs();
102 if abs_r <= *a {
103 r
104 } else if abs_r <= *b {
105 a * r.signum()
106 } else if abs_r <= *c {
107 a * (c - abs_r) / (c - b) * r.signum()
108 } else {
109 0.0
110 }
111 }
112 LossFunction::Andrews { c } => {
113 if r.abs() <= *c * std::f64::consts::PI {
114 c * r.sin()
115 } else {
116 0.0
117 }
118 }
119 LossFunction::LeastSquares => r,
120 }
121 }
122}
123
/// Output of a robust regression fit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RobustRegressionResults {
    /// Estimated regression coefficients (length = number of predictors).
    pub coefficients: Array1<f64>,
    /// Standard errors of the coefficients (zeros when not computed, e.g. by LTS).
    pub standard_errors: Array1<f64>,
    /// Robust estimate of the residual scale.
    pub scale: f64,
    /// Number of iterations (or random subsets, for LTS) actually performed.
    pub iterations: usize,
    /// Final per-observation weights (length n); small or zero flags an outlier.
    pub weights: Array1<f64>,
    /// Nominal breakdown point (fraction of contamination the estimator tolerates).
    pub breakdown_point: f64,
    /// Nominal asymptotic efficiency relative to OLS under Gaussian errors.
    pub efficiency: f64,
}
142
/// M-estimator for linear regression, fitted by iteratively reweighted
/// least squares (IRLS). Construct with [`MEstimator::huber`] or
/// [`MEstimator::tukey`] and refine via the builder methods.
#[derive(Clone)]
pub struct MEstimator {
    loss: LossFunction,        // robust loss defining weights and influence
    max_iter: usize,           // cap on IRLS iterations
    tol: f64,                  // max-norm convergence tolerance on beta
    scale_est: ScaleEstimator, // strategy for the residual scale
    tuning: TuningParameters,  // nominal tuning targets (informational)
}
152
/// Strategy for estimating the residual scale used to standardize residuals.
#[derive(Debug, Clone, Copy)]
pub enum ScaleEstimator {
    /// Median absolute deviation, rescaled for consistency at the normal.
    MAD,
    /// Interquartile range, rescaled for consistency at the normal.
    IQR,
    /// Scale taken from a preliminary S-estimate of regression.
    SEstimate,
    /// Externally supplied fixed scale (used by the MM-estimator's M-step).
    Fixed(f64),
}
165
/// Nominal tuning targets carried alongside an estimator.
///
/// NOTE(review): these values are stored but never read by the fitting code
/// visible in this file — confirm the intended use before relying on them.
#[derive(Debug, Clone, Copy)]
pub struct TuningParameters {
    /// Desired breakdown point (fraction of contamination tolerated).
    pub breakdown_point: f64,
    /// Desired asymptotic Gaussian efficiency.
    pub efficiency: f64,
    /// Small numerical guard value.
    pub delta: f64,
}
176
177impl Default for TuningParameters {
178 fn default() -> Self {
179 Self {
180 breakdown_point: 0.5,
181 efficiency: 0.95,
182 delta: 1e-8,
183 }
184 }
185}
186
187impl MEstimator {
188 pub fn huber(k: f64) -> Self {
190 Self {
191 loss: LossFunction::Huber { k },
192 max_iter: 50,
193 tol: 1e-6,
194 scale_est: ScaleEstimator::MAD,
195 tuning: TuningParameters::default(),
196 }
197 }
198
199 pub fn tukey(c: f64) -> Self {
201 Self {
202 loss: LossFunction::Tukey { c },
203 max_iter: 50,
204 tol: 1e-6,
205 scale_est: ScaleEstimator::MAD,
206 tuning: TuningParameters::default(),
207 }
208 }
209
210 pub fn max_iterations(mut self, max_iter: usize) -> Self {
212 self.max_iter = max_iter;
213 self
214 }
215
216 pub fn tolerance(mut self, tol: f64) -> Self {
218 self.tol = tol;
219 self
220 }
221
222 pub fn scale_estimator(mut self, scale_est: ScaleEstimator) -> Self {
224 self.scale_est = scale_est;
225 self
226 }
227
228 pub fn tuning(mut self, tuning: TuningParameters) -> Self {
230 self.tuning = tuning;
231 self
232 }
233
234 pub fn fit(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<RobustRegressionResults> {
236 let n = X.nrows();
237 let p = X.ncols();
238
239 if n <= p {
240 return Err(Error::DataError(
241 "Need more observations than predictors for robust regression".to_string(),
242 ));
243 }
244
245 let mut beta = self.initial_estimate(X, y)?;
247
248 let mut scale = self.initial_scale(X, y, &beta)?;
250
251 let mut iter = 0;
253 let mut converged = false;
254 let mut weights = Array1::ones(n);
255
256 while !converged && iter < self.max_iter {
257 iter += 1;
258
259 let beta_prev = beta.clone();
261
262 let residuals = y - X.dot(&beta);
264 let scaled_residuals = &residuals / scale;
265
266 for i in 0..n {
268 weights[i] = self.loss.weight(scaled_residuals[i]);
269 }
270
271 let W_sqrt = weights.mapv(|w| w.sqrt());
273 let X_weighted = X * W_sqrt.clone().insert_axis(ndarray::Axis(1));
274 let y_weighted = y * &W_sqrt;
275
276 beta = solve(
277 &X_weighted.t().dot(&X_weighted),
278 &X_weighted.t().dot(&y_weighted),
279 )
280 .map_err(|e| Error::LinearAlgebraError(format!("WLS solve failed: {}", e)))?;
281
282 if matches!(self.scale_est, ScaleEstimator::MAD | ScaleEstimator::IQR) {
284 scale = self.update_scale(&residuals, &weights);
285 }
286
287 let beta_diff = (&beta - &beta_prev).mapv(|x| x.abs());
289 let max_diff = beta_diff.iter().fold(0.0, |a, &b| f64::max(a, b));
290 converged = max_diff < self.tol;
291 }
292
293 let standard_errors = self.compute_standard_errors(X, y, &beta, scale, &weights)?;
295
296 let efficiency = self.compute_efficiency();
298 let breakdown_point = self.breakdown_point();
299
300 Ok(RobustRegressionResults {
301 coefficients: beta,
302 standard_errors,
303 scale,
304 iterations: iter,
305 weights,
306 breakdown_point,
307 efficiency,
308 })
309 }
310
311 fn initial_estimate(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<Array1<f64>> {
313 let lts = LeastTrimmedSquares::default();
315 lts.fit(X, y).map(|results| results.coefficients)
316 }
317
318 fn initial_scale(&self, X: &Array2<f64>, y: &Array1<f64>, beta: &Array1<f64>) -> Result<f64> {
320 match self.scale_est {
321 ScaleEstimator::MAD => {
322 let residuals = y - X.dot(beta);
323 Ok(self.mad(&residuals))
324 }
325 ScaleEstimator::IQR => {
326 let residuals = y - X.dot(beta);
327 Ok(self.iqr_scale(&residuals))
328 }
329 ScaleEstimator::SEstimate => {
330 let s_est = SEstimator::default();
332 s_est.fit(X, y).map(|results| results.scale)
333 }
334 ScaleEstimator::Fixed(scale) => Ok(scale),
335 }
336 }
337
338 fn update_scale(&self, residuals: &Array1<f64>, weights: &Array1<f64>) -> f64 {
340 let _n = residuals.len();
342 let sum_weights: f64 = weights.iter().sum();
343 let weighted_sse: f64 = residuals
344 .iter()
345 .zip(weights.iter())
346 .map(|(&r, &w)| r * r * w)
347 .sum();
348
349 (weighted_sse / sum_weights).sqrt()
350 }
351
352 fn mad(&self, data: &Array1<f64>) -> f64 {
354 let med = median(data).unwrap_or(0.0);
355 let abs_dev: Array1<f64> = data.mapv(|x| (x - med).abs());
356 let mad = median(&abs_dev).unwrap_or(0.0);
357 mad / 0.6745 }
359
360 fn iqr_scale(&self, data: &Array1<f64>) -> f64 {
362 use so_stats::quantile;
363 let q1 = quantile(data, 0.25).unwrap_or(0.0);
364 let q3 = quantile(data, 0.75).unwrap_or(0.0);
365 (q3 - q1) / 1.349 }
367
368 fn compute_standard_errors(
370 &self,
371 X: &Array2<f64>,
372 y: &Array1<f64>,
373 beta: &Array1<f64>,
374 scale: f64,
375 weights: &Array1<f64>,
376 ) -> Result<Array1<f64>> {
377 let n = X.nrows();
378 let p = X.ncols();
379
380 let W_sqrt = weights.mapv(|w| w.sqrt());
382 let X_weighted = X * W_sqrt.clone().insert_axis(ndarray::Axis(1));
383 let XtWX = X_weighted.t().dot(&X_weighted);
384
385 let XtWX_inv = inv(&XtWX)
386 .map_err(|e| Error::LinearAlgebraError(format!("Failed to invert X'WX: {}", e)))?;
387
388 let residuals = y - X.dot(beta);
390 let scaled_residuals = &residuals / scale;
391
392 let mut influence = Array1::<f64>::zeros(p);
394 for i in 0..n {
395 let psi = self.loss.psi(scaled_residuals[i]);
396 let xi = X.row(i);
397 influence = influence + xi.mapv(|x| x * psi);
398 }
399
400 let mut sandwich = Array2::zeros((p, p));
402 for i in 0..n {
403 let psi = self.loss.psi(scaled_residuals[i]);
404 let xi = X.row(i);
405 let outer = xi.t().dot(&xi).to_owned() * psi * psi;
406 sandwich += outer;
407 }
408
409 let cov = XtWX_inv.dot(&sandwich.dot(&XtWX_inv)) * scale * scale / n as f64;
410 let se = cov.diag().mapv(|x| x.sqrt());
411
412 Ok(se)
413 }
414
415 fn compute_efficiency(&self) -> f64 {
417 match self.loss {
419 LossFunction::Huber { k } => {
420 let normal = Normal::new(0.0, 1.0).unwrap();
421 let eff = 1.0 / (1.0 + 2.0 * (1.0 - normal.cdf(k)) / k.powi(2));
422 eff.min(1.0)
423 }
424 LossFunction::Tukey { c } => {
425 let _c2 = c * c;
427
428 if c >= 4.0 { 0.95 } else { 0.85 }
429 }
430 _ => 0.85, }
432 }
433
434 fn breakdown_point(&self) -> f64 {
436 match self.loss {
437 LossFunction::Huber { .. } => 0.0, LossFunction::Tukey { .. } => 0.5, LossFunction::Hampel { .. } => 0.5,
440 LossFunction::Andrews { .. } => 0.5,
441 LossFunction::LeastSquares => 0.0,
442 }
443 }
444}
445
/// Least trimmed squares: minimizes the sum of the smallest h squared
/// residuals, giving a high-breakdown (if inefficient) fit.
pub struct LeastTrimmedSquares {
    coverage: f64, // fraction of observations the fit must cover (0 < coverage <= 1)
}
450
451impl Default for LeastTrimmedSquares {
452 fn default() -> Self {
453 Self { coverage: 0.5 }
454 }
455}
456
457impl LeastTrimmedSquares {
458 pub fn new(coverage: f64) -> Self {
460 Self { coverage }
461 }
462
463 pub fn fit(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<RobustRegressionResults> {
465 let n = X.nrows();
466 let p = X.ncols();
467
468 if n <= p {
469 return Err(Error::DataError(
470 "Need more observations than predictors for LTS".to_string(),
471 ));
472 }
473
474 let h = (n as f64 * self.coverage).ceil() as usize;
475
476 let n_subsets = 500.min(n);
478 let mut best_sse = f64::INFINITY;
479 let mut best_beta = Array1::zeros(p);
480
481 let mut rng = rand::rng();
482
483 for _ in 0..n_subsets {
484 let subset_indices = rand::seq::index::sample(&mut rng, n, p + 1).into_vec();
486 let X_subset = X.select(ndarray::Axis(0), &subset_indices);
487 let y_subset = y.select(ndarray::Axis(0), &subset_indices);
488
489 if let Ok(beta) = solve(&X_subset.t().dot(&X_subset), &X_subset.t().dot(&y_subset)) {
491 let residuals = y - X.dot(&beta);
492 let mut squared_residuals: Vec<(f64, usize)> = residuals
493 .iter()
494 .enumerate()
495 .map(|(i, &r)| (r * r, i))
496 .collect();
497
498 squared_residuals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
499
500 let sse: f64 = squared_residuals[..h].iter().map(|(r2, _)| r2).sum();
501
502 if sse < best_sse {
503 best_sse = sse;
504 best_beta = beta;
505 }
506 }
507 }
508
509 let residuals = y - X.dot(&best_beta);
511 let mut squared_residuals: Vec<(f64, usize)> = residuals
512 .iter()
513 .enumerate()
514 .map(|(i, &r)| (r * r, i))
515 .collect();
516
517 squared_residuals.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
518
519 let best_indices: Vec<usize> = squared_residuals[..h].iter().map(|(_, i)| *i).collect();
520 let X_best = X.select(ndarray::Axis(0), &best_indices);
521 let y_best = y.select(ndarray::Axis(0), &best_indices);
522
523 let final_beta = solve(&X_best.t().dot(&X_best), &X_best.t().dot(&y_best))
524 .map_err(|e| Error::LinearAlgebraError(format!("LTS final fit failed: {}", e)))?;
525
526 let scale = (best_sse / h as f64).sqrt();
528
529 let mut weights = Array1::zeros(n);
531 for &idx in &best_indices {
532 weights[idx] = 1.0;
533 }
534
535 Ok(RobustRegressionResults {
536 coefficients: final_beta,
537 standard_errors: Array1::zeros(p), scale,
539 iterations: n_subsets,
540 weights,
541 breakdown_point: 1.0 - self.coverage,
542 efficiency: 0.7, })
544 }
545}
546
/// S-estimator configuration.
///
/// NOTE(review): `max_iter` and `tol` are currently unused — `fit` delegates
/// to `LeastTrimmedSquares` as a stand-in (hence the `#[allow(dead_code)]`).
#[allow(dead_code)]
pub struct SEstimator {
    breakdown_point: f64, // target breakdown, forwarded as the LTS coverage
    max_iter: usize,      // reserved for a true iterative S-estimate
    tol: f64,             // reserved for a true iterative S-estimate
}
554
555impl Default for SEstimator {
556 fn default() -> Self {
557 Self {
558 breakdown_point: 0.5,
559 max_iter: 100,
560 tol: 1e-6,
561 }
562 }
563}
564
565impl SEstimator {
566 pub fn fit(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<RobustRegressionResults> {
568 let lts = LeastTrimmedSquares::new(self.breakdown_point);
570 lts.fit(X, y)
571 }
572}
573
/// MM-estimator: a high-breakdown S-step that fixes the residual scale,
/// followed by a high-efficiency M-step using that scale.
pub struct MMEstimator {
    s_estimator: SEstimator, // stage 1: robust scale
    m_estimator: MEstimator, // stage 2: efficient coefficients
}
579
580impl MMEstimator {
581 pub fn new() -> Self {
583 Self {
584 s_estimator: SEstimator::default(),
585 m_estimator: MEstimator::tukey(4.685),
586 }
587 }
588
589 pub fn fit(&self, X: &Array2<f64>, y: &Array1<f64>) -> Result<RobustRegressionResults> {
591 let s_results = self.s_estimator.fit(X, y)?;
593
594 let m_estimator = self
596 .m_estimator
597 .clone()
598 .scale_estimator(ScaleEstimator::Fixed(s_results.scale));
599
600 m_estimator.fit(X, y)
601 }
602}