use crate::common::CovarianceType;
use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use scirs2_core::random::thread_rng;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::f64::consts::PI;

/// Type of M-estimator used to downweight outlying samples.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MEstimatorType {
    /// Huber estimator with tuning constant `c`.
    Huber { c: f64 },
    /// Tukey biweight (bisquare) estimator with tuning constant `c`.
    Tukey { c: f64 },
    /// Cauchy estimator with tuning constant `c`.
    Cauchy { c: f64 },
    /// Andrews sine-wave estimator with tuning constant `c`.
    Andrews { c: f64 },
}

impl Default for MEstimatorType {
    fn default() -> Self {
        // c = 1.345 is the conventional choice, giving ~95% efficiency under
        // Gaussian errors.
        MEstimatorType::Huber { c: 1.345 }
    }
}

impl MEstimatorType {
    /// Weight assigned to a sample with the given (scaled) residual.
    pub fn weight(&self, residual: f64) -> f64 {
        match self {
            MEstimatorType::Huber { c } => {
                let abs_r = residual.abs();
                if abs_r <= *c {
                    1.0
                } else {
                    c / abs_r
                }
            }
            MEstimatorType::Tukey { c } => {
                let abs_r = residual.abs();
                if abs_r <= *c {
                    let ratio = residual / c;
                    (1.0 - ratio * ratio).powi(2)
                } else {
                    0.0
                }
            }
            MEstimatorType::Cauchy { c } => 1.0 / (1.0 + (residual / c).powi(2)),
            MEstimatorType::Andrews { c } => {
                let abs_r = residual.abs();
                if abs_r <= PI * c {
                    // Guard the removable singularity at zero:
                    // sin(PI * r / c) / r tends to PI / c as r -> 0.
                    if abs_r < f64::EPSILON {
                        PI / c
                    } else {
                        (PI * residual / c).sin() / residual
                    }
                } else {
                    0.0
                }
            }
        }
    }

    /// Approximate asymptotic efficiency under Gaussian data for the
    /// conventional tuning constants.
    pub fn efficiency(&self) -> f64 {
        match self {
            MEstimatorType::Huber { c: _ } => 0.95,
            MEstimatorType::Tukey { c: _ } => 0.88,
            MEstimatorType::Cauchy { c: _ } => 0.82,
            MEstimatorType::Andrews { c: _ } => 0.85,
        }
    }

    /// Theoretical breakdown point of the estimator.
    pub fn breakdown_point(&self) -> f64 {
        match self {
            MEstimatorType::Huber { c: _ } => 0.0,
            MEstimatorType::Tukey { c: _ } => 0.5,
            MEstimatorType::Cauchy { c: _ } => 0.0,
            MEstimatorType::Andrews { c: _ } => 0.5,
        }
    }
}

/// Configuration for trimmed-likelihood estimation.
#[derive(Debug, Clone)]
pub struct TrimmedLikelihoodConfig {
    /// Fraction of samples trimmed from the likelihood.
    pub trim_fraction: f64,
    /// Whether the trimming fraction is adapted to the data.
    pub adaptive: bool,
    /// Minimum number of samples retained after trimming.
    pub min_samples: usize,
}

impl Default for TrimmedLikelihoodConfig {
    fn default() -> Self {
        Self {
            trim_fraction: 0.1,
            adaptive: true,
            min_samples: 10,
        }
    }
}

/// Per-sample influence diagnostics for a fitted robust mixture model.
#[derive(Debug, Clone)]
pub struct InfluenceDiagnostics {
    /// Overall influence score of each sample.
    pub influence_scores: Array1<f64>,
    /// Cook's distance of each sample.
    pub cooks_distance: Array1<f64>,
    /// Leverage of each sample.
    pub leverage: Array1<f64>,
    /// Standardized residual of each sample.
    pub standardized_residuals: Array1<f64>,
    /// Flags marking samples considered outliers.
    pub outlier_flags: Vec<bool>,
}

/// Breakdown-point analysis of a fitted robust mixture model.
#[derive(Debug, Clone)]
pub struct BreakdownAnalysis {
    /// Theoretical breakdown point of the chosen M-estimator.
    pub theoretical_breakdown: f64,
    /// Empirically estimated breakdown point.
    pub empirical_breakdown: f64,
    /// Maximum contamination fraction the model is expected to tolerate.
    pub max_contamination: f64,
    /// Stability scores under increasing contamination levels.
    pub stability_scores: Vec<f64>,
}
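
/// Gaussian mixture model fitted with M-estimator based robust weighting of
/// samples during the EM iterations.
///
/// Illustrative usage sketch (data and paths are placeholders; see the unit
/// tests at the bottom of this file for runnable examples):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let data = array![[0.0, 0.0], [1.0, 1.0], [5.0, 5.0], [6.0, 6.0]];
/// let gmm = MEstimatorGMM::builder()
///     .n_components(2)
///     .m_estimator(MEstimatorType::Huber { c: 1.345 })
///     .max_iter(50)
///     .build();
/// let fitted = gmm.fit(&data.view(), &())?;
/// let labels = fitted.predict(&data.view())?;
/// ```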
#[derive(Debug, Clone)]
pub struct MEstimatorGMM<S = Untrained> {
    n_components: usize,
    m_estimator: MEstimatorType,
    covariance_type: CovarianceType,
    max_iter: usize,
    tol: f64,
    reg_covar: f64,
    random_state: Option<u64>,
    _phantom: std::marker::PhantomData<S>,
}

/// Parameters learned by a fitted [`MEstimatorGMM`].
#[derive(Debug, Clone)]
pub struct MEstimatorGMMTrained {
    /// Mixture weight of each component.
    pub weights: Array1<f64>,
    /// Component means, one row per component.
    pub means: Array2<f64>,
    /// Shared diagonal covariance matrix (n_features x n_features).
    pub covariances: Array2<f64>,
    /// Robust M-estimator weight of each sample for each component.
    pub robust_weights: Array2<f64>,
    /// Log-likelihood value recorded at each EM iteration.
    pub log_likelihood_history: Vec<f64>,
    /// Number of EM iterations performed.
    pub n_iter: usize,
    /// Whether the EM loop converged within `max_iter` iterations.
    pub converged: bool,
    /// M-estimator used during fitting.
    pub m_estimator: MEstimatorType,
}

impl MEstimatorGMM<Untrained> {
    /// Create a builder for configuring an [`MEstimatorGMM`].
    pub fn builder() -> MEstimatorGMMBuilder {
        MEstimatorGMMBuilder::new()
    }
}

/// Builder for [`MEstimatorGMM`].
#[derive(Debug, Clone)]
pub struct MEstimatorGMMBuilder {
    n_components: usize,
    m_estimator: MEstimatorType,
    covariance_type: CovarianceType,
    max_iter: usize,
    tol: f64,
    reg_covar: f64,
    random_state: Option<u64>,
}

impl MEstimatorGMMBuilder {
    /// Create a builder with default settings.
    pub fn new() -> Self {
        Self {
            n_components: 1,
            m_estimator: MEstimatorType::default(),
            covariance_type: CovarianceType::Full,
            max_iter: 100,
            tol: 1e-3,
            reg_covar: 1e-6,
            random_state: None,
        }
    }

    /// Set the number of mixture components.
    pub fn n_components(mut self, n_components: usize) -> Self {
        self.n_components = n_components;
        self
    }

    /// Set the M-estimator used for robust weighting.
    pub fn m_estimator(mut self, m_estimator: MEstimatorType) -> Self {
        self.m_estimator = m_estimator;
        self
    }

    /// Set the covariance type.
    pub fn covariance_type(mut self, covariance_type: CovarianceType) -> Self {
        self.covariance_type = covariance_type;
        self
    }

    /// Set the maximum number of EM iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance on the objective value.
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set the regularization added to covariance diagonals.
    pub fn reg_covar(mut self, reg_covar: f64) -> Self {
        self.reg_covar = reg_covar;
        self
    }

    /// Set the random seed.
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Build the untrained estimator.
    pub fn build(self) -> MEstimatorGMM<Untrained> {
        MEstimatorGMM {
            n_components: self.n_components,
            m_estimator: self.m_estimator,
            covariance_type: self.covariance_type,
            max_iter: self.max_iter,
            tol: self.tol,
            reg_covar: self.reg_covar,
            random_state: self.random_state,
            _phantom: std::marker::PhantomData,
        }
    }
}

impl Default for MEstimatorGMMBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for MEstimatorGMM<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for MEstimatorGMM<Untrained> {
    type Fitted = MEstimatorGMM<MEstimatorGMMTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X_owned = X.to_owned();
        let (n_samples, n_features) = X_owned.dim();

        if n_samples < self.n_components {
            return Err(SklearsError::InvalidInput(
                "Number of samples must be >= number of components".to_string(),
            ));
        }

        // NOTE: seeding is not yet wired up; `random_state` is stored on the
        // estimator, but a thread-local RNG is always used for initialization.
        let mut rng = thread_rng();

        // Initialize component means from distinct randomly chosen samples.
        let mut means = Array2::zeros((self.n_components, n_features));
        let mut used_indices = Vec::new();
        for k in 0..self.n_components {
            let idx = loop {
                let candidate = rng.gen_range(0..n_samples);
                if !used_indices.contains(&candidate) {
                    used_indices.push(candidate);
                    break candidate;
                }
            };
            means.row_mut(k).assign(&X_owned.row(idx));
        }

        // Uniform initial mixture weights.
        let mut weights = Array1::from_elem(self.n_components, 1.0 / self.n_components as f64);

        // A single diagonal covariance matrix shared by all components,
        // regularized on the diagonal.
        let mut covariances =
            Array2::<f64>::eye(n_features) + &(Array2::<f64>::eye(n_features) * self.reg_covar);

        let mut robust_weights = Array2::zeros((n_samples, self.n_components));
        let mut log_likelihood_history = Vec::new();
        let mut converged = false;

        for iter in 0..self.max_iter {
            // E-step: compute responsibilities, downweighting each sample by
            // its robust M-estimator weight.
            let mut responsibilities = Array2::zeros((n_samples, self.n_components));

            for i in 0..n_samples {
                let x = X_owned.row(i);
                let mut log_probs = Vec::new();

                for k in 0..self.n_components {
                    let mean = means.row(k);
                    let diff = &x.to_owned() - &mean.to_owned();

                    // Mahalanobis-style distance using only the diagonal of the
                    // covariance matrix.
                    let mahal_dist = diff
                        .iter()
                        .zip(covariances.diag().iter())
                        .map(|(d, cov): (&f64, &f64)| d * d / cov.max(self.reg_covar))
                        .sum::<f64>()
                        .sqrt();

                    // Robust weight based on the distance treated as a residual.
                    let m_weight = self.m_estimator.weight(mahal_dist);
                    robust_weights[[i, k]] = m_weight;

                    // Diagonal Gaussian log-density plus the mixture weight.
                    let log_det = covariances
                        .diag()
                        .iter()
                        .map(|c| c.max(self.reg_covar).ln())
                        .sum::<f64>();
                    let log_prob = weights[k].ln()
                        - 0.5 * (n_features as f64 * (2.0 * PI).ln() + log_det)
                        - 0.5 * mahal_dist * mahal_dist;

                    // Scale the log-probability by the robust weight so that
                    // outlying samples contribute less.
                    log_probs.push(log_prob * m_weight);
                }

                // Normalize responsibilities with the log-sum-exp trick.
                let max_log_prob = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                let sum_exp: f64 = log_probs.iter().map(|&lp| (lp - max_log_prob).exp()).sum();

                for k in 0..self.n_components {
                    responsibilities[[i, k]] =
                        ((log_probs[k] - max_log_prob).exp() / sum_exp).max(1e-10);
                }
            }

            // M-step: update weights, means, and diagonal covariances using the
            // responsibilities multiplied by the robust weights.
            for k in 0..self.n_components {
                let resps = responsibilities.column(k);
                let weighted_resps = &resps.to_owned() * &robust_weights.column(k).to_owned();
                let nk = weighted_resps.sum().max(1e-10);

                weights[k] = nk / n_samples as f64;

                // Weighted mean update.
                let mut new_mean = Array1::zeros(n_features);
                for i in 0..n_samples {
                    new_mean += &(X_owned.row(i).to_owned() * weighted_resps[i]);
                }
                new_mean /= nk;
                means.row_mut(k).assign(&new_mean);

                // Weighted diagonal covariance update with regularization.
                let mut new_cov_diag = Array1::zeros(n_features);
                for i in 0..n_samples {
                    let diff = &X_owned.row(i).to_owned() - &new_mean;
                    new_cov_diag += &(diff.mapv(|x| x * x) * weighted_resps[i]);
                }
                new_cov_diag = new_cov_diag / nk + Array1::from_elem(n_features, self.reg_covar);
                covariances.diag_mut().assign(&new_cov_diag);
            }

            // Renormalize the mixture weights.
            let weight_sum = weights.sum();
            weights /= weight_sum;

            // Robust surrogate objective: responsibilities scaled by the robust
            // weights, used only to monitor convergence.
            let mut log_likelihood = 0.0;
            for i in 0..n_samples {
                let mut sample_ll = 0.0;
                for k in 0..self.n_components {
                    sample_ll += responsibilities[[i, k]] * robust_weights[[i, k]];
                }
                log_likelihood += sample_ll.max(1e-10).ln();
            }
            log_likelihood_history.push(log_likelihood);

            // Stop once the objective improvement falls below the tolerance.
            if iter > 0 {
                let improvement = (log_likelihood - log_likelihood_history[iter - 1]).abs();
                if improvement < self.tol {
                    converged = true;
                    break;
                }
            }
        }

        let n_iter = log_likelihood_history.len();
        let trained_state = MEstimatorGMMTrained {
            weights,
            means,
            covariances,
            robust_weights,
            log_likelihood_history,
            n_iter,
            converged,
            m_estimator: self.m_estimator,
        };

        Ok(MEstimatorGMM {
            n_components: self.n_components,
            m_estimator: self.m_estimator,
            covariance_type: self.covariance_type,
            max_iter: self.max_iter,
            tol: self.tol,
            reg_covar: self.reg_covar,
            random_state: self.random_state,
            _phantom: std::marker::PhantomData,
        }
        .with_state(trained_state))
    }
}

impl MEstimatorGMM<Untrained> {
    fn with_state(self, _state: MEstimatorGMMTrained) -> MEstimatorGMM<MEstimatorGMMTrained> {
        // The typestate wrapper carries no trained parameters of its own, so the
        // state argument is currently discarded here.
        MEstimatorGMM {
            n_components: self.n_components,
            m_estimator: self.m_estimator,
            covariance_type: self.covariance_type,
            max_iter: self.max_iter,
            tol: self.tol,
            reg_covar: self.reg_covar,
            random_state: self.random_state,
            _phantom: std::marker::PhantomData,
        }
    }
}

impl MEstimatorGMM<MEstimatorGMMTrained> {
    /// Access the trained parameters.
    pub fn state(&self) -> &MEstimatorGMMTrained {
        unimplemented!("State access needs proper implementation")
    }

    /// Compute per-sample influence diagnostics.
    ///
    /// Currently a placeholder that returns zeroed diagnostics of the correct
    /// shape.
    pub fn influence_diagnostics(
        &self,
        X: &ArrayView2<'_, Float>,
    ) -> SklResult<InfluenceDiagnostics> {
        let (n_samples, _n_features) = X.dim();

        let influence_scores = Array1::zeros(n_samples);
        let cooks_distance = Array1::zeros(n_samples);
        let leverage = Array1::zeros(n_samples);
        let standardized_residuals = Array1::zeros(n_samples);
        let outlier_flags = vec![false; n_samples];

        Ok(InfluenceDiagnostics {
            influence_scores,
            cooks_distance,
            leverage,
            standardized_residuals,
            outlier_flags,
        })
    }

    /// Analyze the breakdown behaviour of the fitted model.
    ///
    /// The empirical quantities are currently placeholders derived from the
    /// theoretical breakdown point of the configured M-estimator.
    pub fn breakdown_analysis(&self, _X: &ArrayView2<'_, Float>) -> SklResult<BreakdownAnalysis> {
        let theoretical_breakdown = self.m_estimator.breakdown_point();

        Ok(BreakdownAnalysis {
            theoretical_breakdown,
            empirical_breakdown: theoretical_breakdown,
            max_contamination: 0.5,
            stability_scores: vec![1.0, 0.95, 0.90],
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<usize>> for MEstimatorGMM<MEstimatorGMMTrained> {
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<usize>> {
        // Placeholder: assigns every sample to component 0.
        let (n_samples, _) = X.dim();
        Ok(Array1::zeros(n_samples))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_m_estimator_weights() {
        let huber = MEstimatorType::Huber { c: 1.345 };
        assert!((huber.weight(0.5) - 1.0).abs() < 1e-10);
        assert!(huber.weight(2.0) < 1.0);

        let tukey = MEstimatorType::Tukey { c: 4.685 };
        assert!(tukey.weight(0.0) == 1.0);
        assert!(tukey.weight(10.0) == 0.0);
    }
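
    // Additional boundary checks grounded in the weight formulas above: the
    // Huber weight stays at 1.0 up to |r| = c and decays continuously beyond
    // it, and the Tukey weight vanishes exactly at |r| = c.
    #[test]
    fn test_m_estimator_weight_boundaries() {
        let c = 1.345;
        let huber = MEstimatorType::Huber { c };
        assert!((huber.weight(c) - 1.0).abs() < 1e-12);
        assert!((huber.weight(c + 1e-9) - 1.0).abs() < 1e-6);

        let tukey = MEstimatorType::Tukey { c: 4.685 };
        assert_eq!(tukey.weight(4.685), 0.0);
    }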

    #[test]
    fn test_m_estimator_properties() {
        let huber = MEstimatorType::Huber { c: 1.345 };
        assert!(huber.efficiency() > 0.9);
        assert_eq!(huber.breakdown_point(), 0.0);

        let tukey = MEstimatorType::Tukey { c: 4.685 };
        assert!(tukey.breakdown_point() == 0.5);
    }

    #[test]
    fn test_m_estimator_gmm_builder() {
        let gmm = MEstimatorGMM::builder()
            .n_components(3)
            .m_estimator(MEstimatorType::Tukey { c: 4.685 })
            .max_iter(50)
            .build();

        assert_eq!(gmm.n_components, 3);
        assert_eq!(gmm.max_iter, 50);
    }

    #[test]
    fn test_m_estimator_gmm_fit() {
        let X = array![
            [0.0, 0.0],
            [1.0, 1.0],
            [2.0, 2.0],
            [10.0, 10.0],
            [11.0, 11.0],
            [12.0, 12.0]
        ];

        let gmm = MEstimatorGMM::builder()
            .n_components(2)
            .m_estimator(MEstimatorType::Huber { c: 1.345 })
            .max_iter(20)
            .build();

        let result = gmm.fit(&X.view(), &());
        assert!(result.is_ok());
    }
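
    // The `Predict` implementation is currently a placeholder that assigns
    // every sample to component 0; this only checks that fitting followed by
    // predicting yields one label per sample.
    #[test]
    fn test_m_estimator_gmm_predict_shape() {
        let X = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0]];

        let gmm = MEstimatorGMM::builder()
            .n_components(2)
            .max_iter(10)
            .build();

        let fitted = gmm.fit(&X.view(), &()).unwrap();
        let labels = fitted.predict(&X.view()).unwrap();
        assert_eq!(labels.len(), 4);
    }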

    #[test]
    fn test_trimmed_likelihood_config() {
        let config = TrimmedLikelihoodConfig::default();
        assert_eq!(config.trim_fraction, 0.1);
        assert!(config.adaptive);
        assert_eq!(config.min_samples, 10);
    }

    #[test]
    fn test_m_estimator_types_coverage() {
        let estimators = vec![
            MEstimatorType::Huber { c: 1.345 },
            MEstimatorType::Tukey { c: 4.685 },
            MEstimatorType::Cauchy { c: 2.385 },
            MEstimatorType::Andrews { c: 1.339 },
        ];

        for est in estimators {
            let w1 = est.weight(0.5);
            let w2 = est.weight(5.0);
            assert!(w1 >= 0.0 && w1.is_finite());
            assert!(w2 >= 0.0 && w2.is_finite());
            assert!(est.efficiency() > 0.0 && est.efficiency() <= 1.0);
            assert!(est.breakdown_point() >= 0.0 && est.breakdown_point() <= 0.5);
        }
    }
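
    // The default estimator documented above is Huber with c = 1.345.
    #[test]
    fn test_default_m_estimator() {
        assert_eq!(
            MEstimatorType::default(),
            MEstimatorType::Huber { c: 1.345 }
        );
    }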

    #[test]
    fn test_cauchy_estimator_weight() {
        let cauchy = MEstimatorType::Cauchy { c: 2.385 };
        let w = cauchy.weight(1.0);
        assert!(w > 0.0 && w < 1.0);
    }

    #[test]
    fn test_andrews_estimator_weight() {
        let andrews = MEstimatorType::Andrews { c: 1.339 };
        let w1 = andrews.weight(0.5);
        let w2 = andrews.weight(10.0);
        assert!(w1 > 0.0);
        assert_eq!(w2, 0.0);
    }
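
    // Fitting a small dataset and then querying the breakdown analysis should
    // surface the theoretical breakdown point of the configured M-estimator
    // (0.5 for the Tukey biweight, per `breakdown_point` above).
    #[test]
    fn test_breakdown_analysis_matches_estimator() {
        let X = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]];

        let gmm = MEstimatorGMM::builder()
            .n_components(1)
            .m_estimator(MEstimatorType::Tukey { c: 4.685 })
            .max_iter(5)
            .build();

        let fitted = gmm.fit(&X.view(), &()).unwrap();
        let analysis = fitted.breakdown_analysis(&X.view()).unwrap();
        assert_eq!(analysis.theoretical_breakdown, 0.5);
    }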
}