sklears_mixture/
robust.rs

//! Robust Gaussian Mixture Models
//!
//! This module implements robust Gaussian mixture models that are resistant to outliers.
//! The implementation uses trimmed likelihood estimation and outlier detection to provide
//! robust parameter estimates even in the presence of outliers.

use crate::common::{CovarianceType, ModelSelection};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::f64::consts::PI;

/// Computes `ln(exp(a) + exp(b))` in a numerically stable way by factoring
/// out the maximum, so that e.g. `log_sum_exp(1000.0, 1000.0)` returns
/// `1000.0 + ln(2)` instead of overflowing to infinity.
fn log_sum_exp(a: f64, b: f64) -> f64 {
    let max_val = a.max(b);
    if max_val.is_finite() {
        max_val + ((a - max_val).exp() + (b - max_val).exp()).ln()
    } else {
        max_val
    }
}

/// Robust Gaussian Mixture Model
///
/// A robust variant of the Gaussian mixture model that is resistant to outliers.
/// This implementation uses trimmed likelihood estimation and outlier detection
/// to provide robust parameter estimates even in the presence of outliers.
///
/// # Parameters
///
/// * `n_components` - Number of mixture components
/// * `covariance_type` - Type of covariance parameters
/// * `tol` - Convergence threshold
/// * `reg_covar` - Regularization added to the diagonal of covariance
/// * `max_iter` - Maximum number of EM iterations
/// * `n_init` - Number of initializations to perform
/// * `outlier_fraction` - Expected fraction of outliers in the data (0.0 to 0.5)
/// * `outlier_threshold` - Threshold for outlier detection, in standard deviations
///   (stored on the model; the current detection rule is quantile-based via
///   `outlier_fraction`)
/// * `robust_covariance` - Whether to use robust covariance estimation
/// * `random_state` - Random state for reproducibility
///
/// # Examples
///
/// ```
/// use sklears_mixture::{RobustGaussianMixture, CovarianceType};
/// use sklears_core::traits::{Predict, Fit};
/// use scirs2_core::ndarray::array;
///
/// let X = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [100.0, 100.0], [11.0, 11.0], [12.0, 12.0]];
///
/// let rgmm = RobustGaussianMixture::new()
///     .n_components(2)
///     .outlier_fraction(0.15)
///     .covariance_type(CovarianceType::Diagonal)
///     .max_iter(100);
/// let fitted = rgmm.fit(&X.view(), &()).unwrap();
/// let labels = fitted.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct RobustGaussianMixture<S = Untrained> {
    pub(crate) state: S,
    n_components: usize,
    covariance_type: CovarianceType,
    tol: f64,
    reg_covar: f64,
    max_iter: usize,
    n_init: usize,
    outlier_fraction: f64,
    outlier_threshold: f64,
    robust_covariance: bool,
    random_state: Option<u64>,
}

impl RobustGaussianMixture<Untrained> {
    /// Create a new RobustGaussianMixture instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_components: 1,
            covariance_type: CovarianceType::Full,
            tol: 1e-3,
            reg_covar: 1e-6,
            max_iter: 100,
            n_init: 1,
            outlier_fraction: 0.1,
            outlier_threshold: 3.0,
            robust_covariance: true,
            random_state: None,
        }
    }

    /// Set the number of components
    pub fn n_components(mut self, n_components: usize) -> Self {
        self.n_components = n_components;
        self
    }

    /// Set the covariance type
    pub fn covariance_type(mut self, covariance_type: CovarianceType) -> Self {
        self.covariance_type = covariance_type;
        self
    }

    /// Set the convergence tolerance
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set the regularization parameter
    pub fn reg_covar(mut self, reg_covar: f64) -> Self {
        self.reg_covar = reg_covar;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the number of initializations
    pub fn n_init(mut self, n_init: usize) -> Self {
        self.n_init = n_init;
        self
    }

    /// Set the expected fraction of outliers
    pub fn outlier_fraction(mut self, outlier_fraction: f64) -> Self {
        self.outlier_fraction = outlier_fraction.clamp(0.0, 0.5);
        self
    }

    /// Set the outlier detection threshold (in standard deviations)
    pub fn outlier_threshold(mut self, outlier_threshold: f64) -> Self {
        self.outlier_threshold = outlier_threshold;
        self
    }

    /// Set whether to use robust covariance estimation
    pub fn robust_covariance(mut self, robust_covariance: bool) -> Self {
        self.robust_covariance = robust_covariance;
        self
    }

    /// Set the random state
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

impl Default for RobustGaussianMixture<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for RobustGaussianMixture<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for RobustGaussianMixture<Untrained> {
    type Fitted = RobustGaussianMixture<RobustGaussianMixtureTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if self.n_components == 0 {
            return Err(SklearsError::InvalidInput(
                "Number of components must be positive".to_string(),
            ));
        }

        if n_samples < self.n_components {
            return Err(SklearsError::InvalidInput(
                "Number of samples must be at least the number of components".to_string(),
            ));
        }

        if self.n_init == 0 {
            return Err(SklearsError::InvalidInput(
                "Number of initializations must be positive".to_string(),
            ));
        }

        let mut best_params = None;
        let mut best_log_likelihood = f64::NEG_INFINITY;
        let mut best_n_iter = 0;
        let mut best_converged = false;
        let mut best_outlier_mask = None;

        // Run multiple initializations and keep the best
        for init_run in 0..self.n_init {
            let seed = self.random_state.map(|s| s + init_run as u64);

            // Initialize parameters
            let (mut weights, mut means, mut covariances) = self.initialize_parameters(&X, seed)?;

            let mut log_likelihood = f64::NEG_INFINITY;
            let mut converged = false;
            let mut n_iter = 0;
            let mut outlier_mask = Array1::from_elem(n_samples, false);

            // Robust EM iterations
            for iteration in 0..self.max_iter {
                n_iter = iteration + 1;

                // E-step: Compute responsibilities with outlier detection
                let responsibilities = self.compute_robust_responsibilities(
                    &X,
                    &weights,
                    &means,
                    &covariances,
                    &mut outlier_mask,
                )?;

                // M-step: Update parameters with outlier weighting
                let (new_weights, new_means, new_covariances) =
                    self.update_robust_parameters(&X, &responsibilities, &outlier_mask)?;

                // Compute trimmed log-likelihood (excluding outliers)
                let new_log_likelihood = self.compute_trimmed_log_likelihood(
                    &X,
                    &new_weights,
                    &new_means,
                    &new_covariances,
                    &outlier_mask,
                )?;

                // Check convergence
                if iteration > 0 && (new_log_likelihood - log_likelihood).abs() < self.tol {
                    converged = true;
                }

                weights = new_weights;
                means = new_means;
                covariances = new_covariances;
                log_likelihood = new_log_likelihood;

                if converged {
                    break;
                }
            }

            // Keep track of the best parameters across initializations; the
            // first run is always recorded so `best_params` is never empty
            if best_params.is_none() || log_likelihood > best_log_likelihood {
                best_log_likelihood = log_likelihood;
                best_params = Some((weights, means, covariances));
                best_n_iter = n_iter;
                best_converged = converged;
                best_outlier_mask = Some(outlier_mask);
            }
        }

        let (weights, means, covariances) = best_params.unwrap();
        let outlier_mask = best_outlier_mask.unwrap();

        // Calculate model selection criteria
        let n_params =
            ModelSelection::n_parameters(self.n_components, n_features, &self.covariance_type);
        let bic = ModelSelection::bic(best_log_likelihood, n_params, n_samples);
        let aic = ModelSelection::aic(best_log_likelihood, n_params);

        // Count detected outliers
        let n_outliers = outlier_mask.iter().filter(|&&x| x).count();

        Ok(RobustGaussianMixture {
            state: RobustGaussianMixtureTrained {
                weights,
                means,
                covariances,
                log_likelihood: best_log_likelihood,
                n_iter: best_n_iter,
                converged: best_converged,
                bic,
                aic,
                outlier_mask,
                n_outliers,
            },
            n_components: self.n_components,
            covariance_type: self.covariance_type,
            tol: self.tol,
            reg_covar: self.reg_covar,
            max_iter: self.max_iter,
            n_init: self.n_init,
            outlier_fraction: self.outlier_fraction,
            outlier_threshold: self.outlier_threshold,
            robust_covariance: self.robust_covariance,
            random_state: self.random_state,
        })
    }
}

impl RobustGaussianMixture<Untrained> {
    /// Initialize parameters for the robust EM algorithm
    fn initialize_parameters(
        &self,
        X: &Array2<f64>,
        seed: Option<u64>,
    ) -> SklResult<(Array1<f64>, Array2<f64>, Vec<Array2<f64>>)> {
        // Initialize weights (uniform)
        let weights = Array1::from_elem(self.n_components, 1.0 / self.n_components as f64);

        // Initialize means using robust initialization
        let means = self.initialize_robust_means(X, seed)?;

        // Initialize covariances
        let covariances = self.initialize_robust_covariances(X, &means)?;

        Ok((weights, means, covariances))
    }

    /// Initialize means deterministically on evenly spaced samples
    fn initialize_robust_means(
        &self,
        X: &Array2<f64>,
        seed: Option<u64>,
    ) -> SklResult<Array2<f64>> {
        let (n_samples, n_features) = X.dim();
        let mut means = Array2::zeros((self.n_components, n_features));

        // Spread the component means over evenly spaced samples; this is a
        // deterministic, data-dependent initialization rather than a random one
        for i in 0..self.n_components {
            let step = n_samples / self.n_components;
            let sample_idx = if step == 0 {
                i.min(n_samples - 1)
            } else {
                (i * step).min(n_samples - 1)
            };

            let mut mean = means.row_mut(i);
            mean.assign(&X.row(sample_idx));

            // When a seed is provided, add a small deterministic perturbation
            // (depending only on the component index) to break ties
            if let Some(_seed) = seed {
                for j in 0..n_features {
                    mean[j] += 0.01 * (i as f64 - self.n_components as f64 / 2.0);
                }
            }
        }

        Ok(means)
    }

    /// Initialize covariances with robust estimation
    ///
    /// Every covariance type starts from the same isotropic diagonal matrix,
    /// so the per-type cases collapse into a single construction.
    fn initialize_robust_covariances(
        &self,
        X: &Array2<f64>,
        _means: &Array2<f64>,
    ) -> SklResult<Vec<Array2<f64>>> {
        let (_, n_features) = X.dim();

        // Use robust scale estimation when enabled
        let robust_scale = if self.robust_covariance {
            self.estimate_robust_scale(X)?
        } else {
            1.0
        };

        // Shared isotropic diagonal initialization
        let mut cov = Array2::zeros((n_features, n_features));
        for i in 0..n_features {
            cov[[i, i]] = robust_scale + self.reg_covar;
        }

        Ok(vec![cov; self.n_components])
    }

    /// Estimate a robust scale using the median absolute deviation (MAD),
    /// pooled across all features
    fn estimate_robust_scale(&self, X: &Array2<f64>) -> SklResult<f64> {
        let (n_samples, n_features) = X.dim();
        let mut all_deviations = Vec::new();

        // Calculate the median of each feature
        for j in 0..n_features {
            let mut feature_values: Vec<f64> = X.column(j).to_vec();
            feature_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
            let median = if n_samples % 2 == 0 {
                (feature_values[n_samples / 2 - 1] + feature_values[n_samples / 2]) / 2.0
            } else {
                feature_values[n_samples / 2]
            };

            // Collect absolute deviations from the median
            for i in 0..n_samples {
                all_deviations.push((X[[i, j]] - median).abs());
            }
        }

        // Calculate the median absolute deviation over the pooled deviations
        all_deviations.sort_by(|a, b| a.partial_cmp(b).unwrap());
        let mad = if all_deviations.len() % 2 == 0 {
            (all_deviations[all_deviations.len() / 2 - 1]
                + all_deviations[all_deviations.len() / 2])
                / 2.0
        } else {
            all_deviations[all_deviations.len() / 2]
        };

        // Scale MAD to approximate the standard deviation (factor 1.4826 for
        // normally distributed data), with a small positive floor
        Ok((mad * 1.4826).max(1e-6))
    }

    /// Compute responsibilities with outlier detection (robust E-step)
    fn compute_robust_responsibilities(
        &self,
        X: &Array2<f64>,
        weights: &Array1<f64>,
        means: &Array2<f64>,
        covariances: &[Array2<f64>],
        outlier_mask: &mut Array1<bool>,
    ) -> SklResult<Array2<f64>> {
        let (n_samples, _) = X.dim();
        let mut responsibilities = Array2::zeros((n_samples, self.n_components));
        let mut sample_likelihoods = Array1::zeros(n_samples);

        // Compute standard responsibilities first
        for i in 0..n_samples {
            let sample = X.row(i);
            let mut log_prob_sum = f64::NEG_INFINITY;
            let mut log_probs = Vec::new();

            for k in 0..self.n_components {
                let mean = means.row(k);
                let cov = &covariances[k];
                let log_weight = weights[k].ln();
                let log_likelihood = self.multivariate_normal_log_pdf(&sample, &mean, cov)?;
                let log_prob = log_weight + log_likelihood;
                log_probs.push(log_prob);
                log_prob_sum = log_sum_exp(log_prob_sum, log_prob);
            }

            sample_likelihoods[i] = log_prob_sum;

            // Normalize to get responsibilities
            for k in 0..self.n_components {
                responsibilities[[i, k]] = (log_probs[k] - log_prob_sum).exp();
            }
        }

        // Detect outliers based on a likelihood threshold
        self.detect_outliers(&sample_likelihoods, outlier_mask)?;

        // Down-weight outliers in the responsibilities
        for i in 0..n_samples {
            if outlier_mask[i] {
                for k in 0..self.n_components {
                    responsibilities[[i, k]] *= 0.1; // Down-weight factor
                }
            }
        }

        Ok(responsibilities)
    }

    /// Detect outliers based on likelihood values
    fn detect_outliers(
        &self,
        sample_likelihoods: &Array1<f64>,
        outlier_mask: &mut Array1<bool>,
    ) -> SklResult<()> {
        let n_samples = sample_likelihoods.len();

        // Calculate a robust threshold from the sorted likelihoods
        let mut sorted_likelihoods = sample_likelihoods.to_vec();
        sorted_likelihoods.sort_by(|a, b| a.partial_cmp(b).unwrap());

        // Use the `outlier_fraction` quantile as the threshold, so that
        // roughly that fraction of samples falls below it
        let threshold_idx = (self.outlier_fraction * n_samples as f64) as usize;
        let threshold_idx = threshold_idx.min(n_samples - 1);
        let likelihood_threshold = sorted_likelihoods[threshold_idx];

        // Mark samples with likelihood strictly below the threshold as outliers
        for i in 0..n_samples {
            outlier_mask[i] = sample_likelihoods[i] < likelihood_threshold;
        }

        Ok(())
    }

    /// Update parameters with outlier weighting (robust M-step)
    fn update_robust_parameters(
        &self,
        X: &Array2<f64>,
        responsibilities: &Array2<f64>,
        outlier_mask: &Array1<bool>,
    ) -> SklResult<(Array1<f64>, Array2<f64>, Vec<Array2<f64>>)> {
        let (n_samples, n_features) = X.dim();

        // Calculate effective sample weights. The E-step already down-weighted
        // outlier responsibilities by 0.1; this applies a second factor, giving
        // outliers a combined weight of 0.01 in the parameter updates.
        let mut effective_responsibilities = responsibilities.clone();
        for i in 0..n_samples {
            if outlier_mask[i] {
                for k in 0..self.n_components {
                    effective_responsibilities[[i, k]] *= 0.1;
                }
            }
        }

        // Update weights
        let n_k: Array1<f64> = effective_responsibilities.sum_axis(Axis(0));
        let total_weight = n_k.sum();
        let weights = &n_k / total_weight;

        // Update means
        let mut means = Array2::zeros((self.n_components, n_features));
        for k in 0..self.n_components {
            if n_k[k] > 1e-10 {
                for i in 0..n_samples {
                    for j in 0..n_features {
                        means[[k, j]] += effective_responsibilities[[i, k]] * X[[i, j]];
                    }
                }
                for j in 0..n_features {
                    means[[k, j]] /= n_k[k];
                }
            }
        }

        // Update covariances with robust estimation
        let covariances =
            self.update_robust_covariances(X, &effective_responsibilities, &means, &n_k)?;

        Ok((weights, means, covariances))
    }

    /// Update covariances with robust estimation
    fn update_robust_covariances(
        &self,
        X: &Array2<f64>,
        responsibilities: &Array2<f64>,
        means: &Array2<f64>,
        n_k: &Array1<f64>,
    ) -> SklResult<Vec<Array2<f64>>> {
        let (n_samples, n_features) = X.dim();
        let mut covariances = Vec::new();

        // Stronger diagonal regularization when robust estimation is enabled
        let robust_reg = if self.robust_covariance {
            self.reg_covar * 10.0
        } else {
            self.reg_covar
        };

        match self.covariance_type {
            CovarianceType::Full => {
                for k in 0..self.n_components {
                    let mut cov = Array2::zeros((n_features, n_features));

                    if n_k[k] > 1e-10 {
                        let mean_k = means.row(k);

                        for i in 0..n_samples {
                            let sample = X.row(i);
                            let diff = &sample - &mean_k;

                            for d1 in 0..n_features {
                                for d2 in 0..n_features {
                                    cov[[d1, d2]] += responsibilities[[i, k]] * diff[d1] * diff[d2];
                                }
                            }
                        }

                        for d1 in 0..n_features {
                            for d2 in 0..n_features {
                                cov[[d1, d2]] /= n_k[k];
                            }
                        }

                        for d in 0..n_features {
                            cov[[d, d]] += robust_reg;
                        }
                    } else {
                        // Empty component: fall back to a regularized identity
                        for d in 0..n_features {
                            cov[[d, d]] = 1.0 + self.reg_covar * 10.0;
                        }
                    }

                    covariances.push(cov);
                }
            }
            CovarianceType::Diagonal => {
                for k in 0..self.n_components {
                    let mut cov = Array2::zeros((n_features, n_features));

                    if n_k[k] > 1e-10 {
                        let mean_k = means.row(k);

                        for d in 0..n_features {
                            let mut var = 0.0;
                            for i in 0..n_samples {
                                let diff = X[[i, d]] - mean_k[d];
                                var += responsibilities[[i, k]] * diff * diff;
                            }
                            var /= n_k[k];

                            cov[[d, d]] = var + robust_reg;
                        }
                    } else {
                        for d in 0..n_features {
                            cov[[d, d]] = 1.0 + self.reg_covar * 10.0;
                        }
                    }

                    covariances.push(cov);
                }
            }
            CovarianceType::Tied => {
                let mut cov = Array2::zeros((n_features, n_features));
                let total_responsibility: f64 = n_k.sum();

                if total_responsibility > 1e-10 {
                    for k in 0..self.n_components {
                        let mean_k = means.row(k);

                        for i in 0..n_samples {
                            let sample = X.row(i);
                            let diff = &sample - &mean_k;

                            for d1 in 0..n_features {
                                for d2 in 0..n_features {
                                    cov[[d1, d2]] += responsibilities[[i, k]] * diff[d1] * diff[d2];
                                }
                            }
                        }
                    }

                    for d1 in 0..n_features {
                        for d2 in 0..n_features {
                            cov[[d1, d2]] /= total_responsibility;
                        }
                    }

                    for d in 0..n_features {
                        cov[[d, d]] += robust_reg;
                    }
                } else {
                    for d in 0..n_features {
                        cov[[d, d]] = 1.0 + self.reg_covar * 10.0;
                    }
                }

                for _ in 0..self.n_components {
                    covariances.push(cov.clone());
                }
            }
            CovarianceType::Spherical => {
                for k in 0..self.n_components {
                    let mut cov = Array2::zeros((n_features, n_features));

                    if n_k[k] > 1e-10 {
                        let mean_k = means.row(k);
                        let mut total_var = 0.0;

                        for i in 0..n_samples {
                            for d in 0..n_features {
                                let diff = X[[i, d]] - mean_k[d];
                                total_var += responsibilities[[i, k]] * diff * diff;
                            }
                        }

                        total_var /= n_k[k] * n_features as f64;

                        let variance = total_var + robust_reg;

                        for d in 0..n_features {
                            cov[[d, d]] = variance;
                        }
                    } else {
                        for d in 0..n_features {
                            cov[[d, d]] = 1.0 + self.reg_covar * 10.0;
                        }
                    }

                    covariances.push(cov);
                }
            }
        }

        Ok(covariances)
    }

    /// Compute trimmed log-likelihood (excluding outliers)
    fn compute_trimmed_log_likelihood(
        &self,
        X: &Array2<f64>,
        weights: &Array1<f64>,
        means: &Array2<f64>,
        covariances: &[Array2<f64>],
        outlier_mask: &Array1<bool>,
    ) -> SklResult<f64> {
        let (n_samples, _) = X.dim();
        let mut total_log_likelihood = 0.0;
        let mut n_included = 0;

        for i in 0..n_samples {
            if !outlier_mask[i] {
                // Only include non-outliers
                let sample = X.row(i);
                let mut log_prob_sum = f64::NEG_INFINITY;

                for k in 0..self.n_components {
                    let mean = means.row(k);
                    let cov = &covariances[k];
                    let log_weight = weights[k].ln();
                    let log_likelihood = self.multivariate_normal_log_pdf(&sample, &mean, cov)?;
                    let log_prob = log_weight + log_likelihood;
                    log_prob_sum = log_sum_exp(log_prob_sum, log_prob);
                }

                total_log_likelihood += log_prob_sum;
                n_included += 1;
            }
        }

        // Return the average log-likelihood of the non-outliers
        if n_included > 0 {
            Ok(total_log_likelihood / n_included as f64)
        } else {
            Ok(f64::NEG_INFINITY)
        }
    }

    /// Compute the multivariate normal log probability density function.
    ///
    /// Note: for every covariance type (including `Full`) only the diagonal of
    /// the covariance matrix is used, i.e. the density is evaluated under a
    /// diagonal approximation.
    fn multivariate_normal_log_pdf(
        &self,
        x: &ArrayView1<f64>,
        mean: &ArrayView1<f64>,
        cov: &Array2<f64>,
    ) -> SklResult<f64> {
        let d = x.len() as f64;
        let diff: Array1<f64> = x - mean;

        let mut log_det = 0.0;
        let mut quad_form = 0.0;

        for i in 0..diff.len() {
            if cov[[i, i]] <= 0.0 {
                return Err(SklearsError::InvalidInput(
                    "Covariance matrix has non-positive diagonal elements".to_string(),
                ));
            }
            log_det += cov[[i, i]].ln();
            quad_form += diff[i] * diff[i] / cov[[i, i]];
        }

        Ok(-0.5 * (d * (2.0 * PI).ln() + log_det + quad_form))
    }
}

/// Trained state for RobustGaussianMixture
#[derive(Debug, Clone)]
pub struct RobustGaussianMixtureTrained {
    /// Mixture component weights
    pub weights: Array1<f64>,
    /// Component means
    pub means: Array2<f64>,
    /// Component covariance matrices or parameters
    pub covariances: Vec<Array2<f64>>,
    /// Log likelihood of the fitted model
    pub log_likelihood: f64,
    /// Number of iterations performed
    pub n_iter: usize,
    /// Whether the algorithm converged
    pub converged: bool,
    /// Bayesian Information Criterion
    pub bic: f64,
    /// Akaike Information Criterion
    pub aic: f64,
    /// Mask indicating which samples are detected as outliers
    pub outlier_mask: Array1<bool>,
    /// Number of detected outliers
    pub n_outliers: usize,
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for RobustGaussianMixture<RobustGaussianMixtureTrained>
{
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let (n_samples, _) = X.dim();
        let mut predictions = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut max_log_prob = f64::NEG_INFINITY;
            let mut best_component = 0;

            for k in 0..self.n_components {
                let mean = self.state.means.row(k);
                let cov = &self.state.covariances[k];
                let log_weight = self.state.weights[k].ln();

                if let Ok(log_likelihood) = self.multivariate_normal_log_pdf(&sample, &mean, cov) {
                    let log_prob = log_weight + log_likelihood;
                    if log_prob > max_log_prob {
                        max_log_prob = log_prob;
                        best_component = k;
                    }
                }
            }

            predictions[i] = best_component as i32;
        }

        Ok(predictions)
    }
}

impl RobustGaussianMixture<RobustGaussianMixtureTrained> {
    /// Get the mixture weights
    pub fn weights(&self) -> &Array1<f64> {
        &self.state.weights
    }

    /// Get the component means
    pub fn means(&self) -> &Array2<f64> {
        &self.state.means
    }

    /// Get the component covariances
    pub fn covariances(&self) -> &[Array2<f64>] {
        &self.state.covariances
    }

    /// Get the log likelihood of the fitted model
    pub fn log_likelihood(&self) -> f64 {
        self.state.log_likelihood
    }

    /// Get the number of iterations performed
    pub fn n_iter(&self) -> usize {
        self.state.n_iter
    }

    /// Check if the algorithm converged
    pub fn converged(&self) -> bool {
        self.state.converged
    }

    /// Get the Bayesian Information Criterion
    pub fn bic(&self) -> f64 {
        self.state.bic
    }

    /// Get the Akaike Information Criterion
    pub fn aic(&self) -> f64 {
        self.state.aic
    }

    /// Get the outlier mask
    pub fn outlier_mask(&self) -> &Array1<bool> {
        &self.state.outlier_mask
    }

    /// Get the number of detected outliers
    pub fn n_outliers(&self) -> usize {
        self.state.n_outliers
    }

    /// Predict probabilities for each component
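    ///
    /// # Examples
    ///
    /// A minimal sketch that reuses the fitting pattern from the type-level
    /// example above (the toy data here is illustrative only):
    ///
    /// ```
    /// use sklears_mixture::RobustGaussianMixture;
    /// use sklears_core::traits::Fit;
    /// use scirs2_core::ndarray::array;
    ///
    /// let X = array![[0.0, 0.0], [1.0, 1.0], [10.0, 10.0], [11.0, 11.0]];
    /// let fitted = RobustGaussianMixture::new()
    ///     .n_components(2)
    ///     .fit(&X.view(), &())
    ///     .unwrap();
    /// let proba = fitted.predict_proba(&X.view()).unwrap();
    /// // Each row is a probability distribution over the components
    /// assert!((proba.row(0).sum() - 1.0).abs() < 1e-9);
    /// ```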
    #[allow(non_snake_case)]
    pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let (n_samples, _) = X.dim();
        let mut probabilities = Array2::zeros((n_samples, self.n_components));

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut log_prob_sum = f64::NEG_INFINITY;
            let mut log_probs = Vec::new();

            // Compute log probabilities for each component
            for k in 0..self.n_components {
                let mean = self.state.means.row(k);
                let cov = &self.state.covariances[k];
                let log_weight = self.state.weights[k].ln();
                let log_likelihood = self.multivariate_normal_log_pdf(&sample, &mean, cov)?;
                let log_prob = log_weight + log_likelihood;
                log_probs.push(log_prob);
                log_prob_sum = log_sum_exp(log_prob_sum, log_prob);
            }

            // Normalize to get probabilities
            for k in 0..self.n_components {
                probabilities[[i, k]] = (log_probs[k] - log_prob_sum).exp();
            }
        }

        Ok(probabilities)
    }

    /// Compute the per-sample log-likelihood
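    ///
    /// # Examples
    ///
    /// A minimal sketch (same illustrative data as `predict_proba`); `score`
    /// below is simply the mean of these per-sample values:
    ///
    /// ```
    /// use sklears_mixture::RobustGaussianMixture;
    /// use sklears_core::traits::Fit;
    /// use scirs2_core::ndarray::array;
    ///
    /// let X = array![[0.0, 0.0], [1.0, 1.0], [10.0, 10.0], [11.0, 11.0]];
    /// let fitted = RobustGaussianMixture::new()
    ///     .n_components(2)
    ///     .fit(&X.view(), &())
    ///     .unwrap();
    /// let scores = fitted.score_samples(&X.view()).unwrap();
    /// assert_eq!(scores.len(), 4);
    /// assert!((fitted.score(&X.view()).unwrap() - scores.mean().unwrap()).abs() < 1e-12);
    /// ```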
    #[allow(non_snake_case)]
    pub fn score_samples(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
        let (n_samples, _) = X.dim();
        let mut scores = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut log_prob_sum = f64::NEG_INFINITY;

            for k in 0..self.n_components {
                let mean = self.state.means.row(k);
                let cov = &self.state.covariances[k];
                let log_weight = self.state.weights[k].ln();
                let log_likelihood = self.multivariate_normal_log_pdf(&sample, &mean, cov)?;
                let log_prob = log_weight + log_likelihood;
                log_prob_sum = log_sum_exp(log_prob_sum, log_prob);
            }

            scores[i] = log_prob_sum;
        }

        Ok(scores)
    }

    /// Compute the average log-likelihood
    pub fn score(&self, X: &ArrayView2<'_, Float>) -> SklResult<f64> {
        let scores = self.score_samples(X)?;
        Ok(scores.mean().unwrap_or(0.0))
    }

    /// Detect outliers in new data
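    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming a fitted model as in the examples above;
    /// with `outlier_fraction = 0.25` and four samples, at most one point
    /// (the one with the strictly lowest score) is flagged:
    ///
    /// ```
    /// use sklears_mixture::RobustGaussianMixture;
    /// use sklears_core::traits::Fit;
    /// use scirs2_core::ndarray::array;
    ///
    /// let X = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [100.0, 100.0]];
    /// let fitted = RobustGaussianMixture::new()
    ///     .n_components(1)
    ///     .outlier_fraction(0.25)
    ///     .fit(&X.view(), &())
    ///     .unwrap();
    /// let mask = fitted.detect_outliers(&X.view()).unwrap();
    /// assert_eq!(mask.len(), 4);
    /// assert!(mask.iter().filter(|&&m| m).count() <= 1);
    /// ```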
    pub fn detect_outliers(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<bool>> {
        let scores = self.score_samples(X)?;
        let n_samples = scores.len();
        let mut outlier_mask = Array1::from_elem(n_samples, false);

        if n_samples == 0 {
            return Ok(outlier_mask);
        }

        // Use the same quantile-based outlier rule as in training
        let mut sorted_scores = scores.to_vec();
        sorted_scores.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let threshold_idx = (self.outlier_fraction * n_samples as f64) as usize;
        let threshold_idx = threshold_idx.min(n_samples - 1);
        let score_threshold = sorted_scores[threshold_idx];

        for i in 0..n_samples {
            outlier_mask[i] = scores[i] < score_threshold;
        }

        Ok(outlier_mask)
    }

    /// Diagonal-approximation log-density, mirroring the helper on the
    /// untrained type: only the diagonal of `cov` is used for every
    /// covariance type, including `Full`.
    fn multivariate_normal_log_pdf(
        &self,
        x: &ArrayView1<f64>,
        mean: &ArrayView1<f64>,
        cov: &Array2<f64>,
    ) -> SklResult<f64> {
        let d = x.len() as f64;
        let diff: Array1<f64> = x - mean;

        let mut log_det = 0.0;
        let mut quad_form = 0.0;

        for i in 0..diff.len() {
            if cov[[i, i]] <= 0.0 {
                return Err(SklearsError::InvalidInput(
                    "Covariance matrix has non-positive diagonal elements".to_string(),
                ));
            }
            log_det += cov[[i, i]].ln();
            quad_form += diff[i] * diff[i] / cov[[i, i]];
        }

        Ok(-0.5 * (d * (2.0 * PI).ln() + log_det + quad_form))
    }
}
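
// Minimal sanity tests for the numerical helpers above. These are small,
// self-contained sketches of the intended behavior rather than an exhaustive
// suite; the expected values are derived from the formulas in this file.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn log_sum_exp_is_numerically_stable() {
        // Matches the naive computation for moderate inputs...
        let naive = (1.0f64.exp() + 2.0f64.exp()).ln();
        assert!((log_sum_exp(1.0, 2.0) - naive).abs() < 1e-12);
        // ...and stays finite where a naive exp(1000) would overflow.
        assert!((log_sum_exp(1000.0, 1000.0) - (1000.0 + 2.0f64.ln())).abs() < 1e-9);
        // NEG_INFINITY acts as the identity element.
        assert_eq!(log_sum_exp(f64::NEG_INFINITY, 3.0), 3.0);
    }

    #[test]
    fn robust_scale_has_a_positive_floor() {
        // A column whose MAD is zero must still yield a positive scale,
        // thanks to the 1e-6 floor in estimate_robust_scale.
        let x = array![[0.0], [0.0], [0.0], [0.0], [100.0]];
        let model = RobustGaussianMixture::new();
        let scale = model.estimate_robust_scale(&x).unwrap();
        assert!(scale >= 1e-6);
    }

    #[test]
    fn outlier_rule_flags_the_lowest_likelihoods() {
        // With outlier_fraction = 0.25 and four samples, the threshold is the
        // second-lowest likelihood; only values strictly below it are flagged.
        let model = RobustGaussianMixture::new().outlier_fraction(0.25);
        let likelihoods = Array1::from(vec![-10.0, -1.0, -1.5, -0.5]);
        let mut mask = Array1::from_elem(4, false);
        model.detect_outliers(&likelihoods, &mut mask).unwrap();
        assert_eq!(mask.iter().filter(|&&m| m).count(), 1);
        assert!(mask[0]);
    }
}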