use crate::common::{CovarianceType, ModelSelection};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::f64::consts::PI;

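/// Numerically stable `ln(exp(a) + exp(b))`.
///
/// Shifting both terms by the maximum keeps the intermediate `exp` calls in
/// range; if the maximum is non-finite (e.g. both inputs are `-inf`) it is
/// returned unchanged.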
fn log_sum_exp(a: f64, b: f64) -> f64 {
    let max_val = a.max(b);
    if max_val.is_finite() {
        max_val + ((a - max_val).exp() + (b - max_val).exp()).ln()
    } else {
        max_val
    }
}

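/// Gaussian mixture model with trimmed-likelihood EM for outlier robustness.
///
/// During fitting, the lowest-likelihood samples (up to `outlier_fraction`)
/// are flagged as outliers and downweighted in the E- and M-steps, so a few
/// contaminated points do not drag the component means and covariances.
///
/// A minimal usage sketch (assuming a two-column `Array2<f64>` input and the
/// builder defaults below; illustrative, not a doctest):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let X = array![[0.0, 0.1], [0.2, 0.0], [10.0, 10.0], [0.1, 0.2]];
/// let model = RobustGaussianMixture::new()
///     .n_components(1)
///     .outlier_fraction(0.25)
///     .fit(&X.view(), &())?;
/// let labels = model.predict(&X.view())?;
/// let outliers = model.outlier_mask();
/// ```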
#[derive(Debug, Clone)]
pub struct RobustGaussianMixture<S = Untrained> {
    pub(crate) state: S,
    n_components: usize,
    covariance_type: CovarianceType,
    tol: f64,
    reg_covar: f64,
    max_iter: usize,
    n_init: usize,
    outlier_fraction: f64,
    outlier_threshold: f64,
    robust_covariance: bool,
    random_state: Option<u64>,
}

impl RobustGaussianMixture<Untrained> {
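    /// Creates an untrained model with the defaults below: one component,
    /// full covariance, `tol = 1e-3`, `reg_covar = 1e-6`, `max_iter = 100`,
    /// `n_init = 1`, `outlier_fraction = 0.1`, `outlier_threshold = 3.0`,
    /// and robust covariance estimation enabled.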
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_components: 1,
            covariance_type: CovarianceType::Full,
            tol: 1e-3,
            reg_covar: 1e-6,
            max_iter: 100,
            n_init: 1,
            outlier_fraction: 0.1,
            outlier_threshold: 3.0,
            robust_covariance: true,
            random_state: None,
        }
    }

    /// Sets the number of mixture components.
    pub fn n_components(mut self, n_components: usize) -> Self {
        self.n_components = n_components;
        self
    }

    /// Sets the covariance parameterization (full, diagonal, tied, or spherical).
    pub fn covariance_type(mut self, covariance_type: CovarianceType) -> Self {
        self.covariance_type = covariance_type;
        self
    }

    /// Sets the convergence tolerance on the change in trimmed log-likelihood.
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Sets the regularization added to covariance diagonals for stability.
    pub fn reg_covar(mut self, reg_covar: f64) -> Self {
        self.reg_covar = reg_covar;
        self
    }

    /// Sets the maximum number of EM iterations per initialization.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Sets how many initializations to run; the best by trimmed
    /// log-likelihood is kept.
    pub fn n_init(mut self, n_init: usize) -> Self {
        self.n_init = n_init;
        self
    }

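    /// Sets the expected fraction of outlying samples; values are clamped to
    /// `[0.0, 0.5]`. This fraction of the lowest-likelihood samples is
    /// trimmed (downweighted) during fitting.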
    pub fn outlier_fraction(mut self, outlier_fraction: f64) -> Self {
        self.outlier_fraction = outlier_fraction.clamp(0.0, 0.5);
        self
    }

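    /// Sets the outlier threshold (in standard deviations). Note: the fitting
    /// path in this module trims by likelihood quantile, so this value is
    /// stored but not consulted by `fit`.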
    pub fn outlier_threshold(mut self, outlier_threshold: f64) -> Self {
        self.outlier_threshold = outlier_threshold;
        self
    }

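    /// Enables robust covariance estimation: MAD-based initial scales and a
    /// stronger regularizer (10x `reg_covar`) on updated covariances.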
    pub fn robust_covariance(mut self, robust_covariance: bool) -> Self {
        self.robust_covariance = robust_covariance;
        self
    }

    /// Sets the random seed used to vary the per-run initialization.
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }
}

impl Default for RobustGaussianMixture<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for RobustGaussianMixture<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for RobustGaussianMixture<Untrained> {
    type Fitted = RobustGaussianMixture<RobustGaussianMixtureTrained>;

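    /// Fits the mixture with a trimmed EM loop: each iteration computes
    /// outlier-aware responsibilities (E-step), re-estimates parameters from
    /// the downweighted responsibilities (M-step), and checks convergence on
    /// the mean log-likelihood of the non-outlier samples. The best of
    /// `n_init` runs is kept.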
    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if self.n_components == 0 {
            return Err(SklearsError::InvalidInput(
                "Number of components must be positive".to_string(),
            ));
        }

        if n_samples < self.n_components {
            return Err(SklearsError::InvalidInput(
                "Number of samples must be at least the number of components".to_string(),
            ));
        }

        let mut best_params = None;
        let mut best_log_likelihood = f64::NEG_INFINITY;
        let mut best_n_iter = 0;
        let mut best_converged = false;
        let mut best_outlier_mask = None;

        for init_run in 0..self.n_init {
            // Derive a per-run seed so repeated initializations differ.
            let seed = self.random_state.map(|s| s + init_run as u64);

            let (mut weights, mut means, mut covariances) =
                self.initialize_parameters(&X, seed)?;

            let mut log_likelihood = f64::NEG_INFINITY;
            let mut converged = false;
            let mut n_iter = 0;
            let mut outlier_mask = Array1::from_elem(n_samples, false);

            for iteration in 0..self.max_iter {
                n_iter = iteration + 1;

                // E-step: responsibilities with outlier detection and downweighting.
                let responsibilities = self.compute_robust_responsibilities(
                    &X,
                    &weights,
                    &means,
                    &covariances,
                    &mut outlier_mask,
                )?;

                // M-step: re-estimate parameters from the downweighted responsibilities.
                let (new_weights, new_means, new_covariances) =
                    self.update_robust_parameters(&X, &responsibilities, &outlier_mask)?;

                // Trimmed objective: mean log-likelihood over non-outlier samples.
                let new_log_likelihood = self.compute_trimmed_log_likelihood(
                    &X,
                    &new_weights,
                    &new_means,
                    &new_covariances,
                    &outlier_mask,
                )?;

                if iteration > 0 && (new_log_likelihood - log_likelihood).abs() < self.tol {
                    converged = true;
                }

                weights = new_weights;
                means = new_means;
                covariances = new_covariances;
                log_likelihood = new_log_likelihood;

                if converged {
                    break;
                }
            }

            if log_likelihood > best_log_likelihood {
                best_log_likelihood = log_likelihood;
                best_params = Some((weights, means, covariances));
                best_n_iter = n_iter;
                best_converged = converged;
                best_outlier_mask = Some(outlier_mask);
            }
        }

        // Guard against `n_init == 0` rather than panicking on `unwrap`.
        let (weights, means, covariances) = best_params.ok_or_else(|| {
            SklearsError::InvalidInput("n_init must be at least 1".to_string())
        })?;
        let outlier_mask = best_outlier_mask.expect("set together with best_params");

        // Model-selection criteria from the best trimmed log-likelihood.
        let n_params =
            ModelSelection::n_parameters(self.n_components, n_features, &self.covariance_type);
        let bic = ModelSelection::bic(best_log_likelihood, n_params, n_samples);
        let aic = ModelSelection::aic(best_log_likelihood, n_params);

        let n_outliers = outlier_mask.iter().filter(|&&x| x).count();

        Ok(RobustGaussianMixture {
            state: RobustGaussianMixtureTrained {
                weights,
                means,
                covariances,
                log_likelihood: best_log_likelihood,
                n_iter: best_n_iter,
                converged: best_converged,
                bic,
                aic,
                outlier_mask,
                n_outliers,
            },
            n_components: self.n_components,
            covariance_type: self.covariance_type,
            tol: self.tol,
            reg_covar: self.reg_covar,
            max_iter: self.max_iter,
            n_init: self.n_init,
            outlier_fraction: self.outlier_fraction,
            outlier_threshold: self.outlier_threshold,
            robust_covariance: self.robust_covariance,
            random_state: self.random_state,
        })
    }
}

impl RobustGaussianMixture<Untrained> {
    /// Uniform weights, sample-based means, and scale-matched covariances.
    fn initialize_parameters(
        &self,
        X: &Array2<f64>,
        seed: Option<u64>,
    ) -> SklResult<(Array1<f64>, Array2<f64>, Vec<Array2<f64>>)> {
        let weights = Array1::from_elem(self.n_components, 1.0 / self.n_components as f64);

        let means = self.initialize_robust_means(X, seed)?;

        let covariances = self.initialize_robust_covariances(X, &means)?;

        Ok((weights, means, covariances))
    }

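    /// Picks initial means from samples at evenly spaced row indices
    /// (deterministic, data-dependent initialization).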
    fn initialize_robust_means(
        &self,
        X: &Array2<f64>,
        seed: Option<u64>,
    ) -> SklResult<Array2<f64>> {
        let (n_samples, n_features) = X.dim();
        let mut means = Array2::zeros((self.n_components, n_features));

        for i in 0..self.n_components {
            let step = n_samples / self.n_components;
            let sample_idx = if step == 0 {
                i.min(n_samples - 1)
            } else {
                (i * step).min(n_samples - 1)
            };

            let mut mean = means.row_mut(i);
            mean.assign(&X.row(sample_idx));

            // When a seed is supplied, apply a small deterministic perturbation
            // so components starting from similar rows do not coincide. (A
            // placeholder for true seeded jitter; the seed value itself is not
            // consumed.)
            if let Some(_seed) = seed {
                for j in 0..n_features {
                    mean[j] += 0.01 * (i as f64 - self.n_components as f64 / 2.0);
                }
            }
        }

        Ok(means)
    }

    fn initialize_robust_covariances(
        &self,
        X: &Array2<f64>,
        _means: &Array2<f64>,
    ) -> SklResult<Vec<Array2<f64>>> {
        let (_, n_features) = X.dim();
        let mut covariances = Vec::new();

        // Match the initial scale to the data: MAD-based when robust
        // covariance estimation is enabled, unit scale otherwise.
        let robust_scale = if self.robust_covariance {
            self.estimate_robust_scale(X)?
        } else {
            1.0
        };

        // Every covariance type starts from the same scaled diagonal matrix
        // stored in full form; the type only changes how it is updated later.
        match self.covariance_type {
            CovarianceType::Full | CovarianceType::Diagonal | CovarianceType::Spherical => {
                for _ in 0..self.n_components {
                    let mut cov = Array2::zeros((n_features, n_features));
                    for i in 0..n_features {
                        cov[[i, i]] = robust_scale + self.reg_covar;
                    }
                    covariances.push(cov);
                }
            }
            CovarianceType::Tied => {
                // One shared matrix, cloned across components.
                let mut cov = Array2::zeros((n_features, n_features));
                for i in 0..n_features {
                    cov[[i, i]] = robust_scale + self.reg_covar;
                }
                for _ in 0..self.n_components {
                    covariances.push(cov.clone());
                }
            }
        }

        Ok(covariances)
    }

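    /// Estimates a global robust scale as the median absolute deviation (MAD)
    /// of all features from their per-feature medians, scaled by 1.4826 so it
    /// is consistent with the standard deviation under Gaussian data, and
    /// floored at 1e-6.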
    fn estimate_robust_scale(&self, X: &Array2<f64>) -> SklResult<f64> {
        let (n_samples, n_features) = X.dim();
        let mut all_deviations = Vec::new();

        for j in 0..n_features {
            // Per-feature median (total_cmp avoids panicking on NaN).
            let mut feature_values: Vec<f64> = X.column(j).to_vec();
            feature_values.sort_by(|a, b| a.total_cmp(b));
            let median = if n_samples % 2 == 0 {
                (feature_values[n_samples / 2 - 1] + feature_values[n_samples / 2]) / 2.0
            } else {
                feature_values[n_samples / 2]
            };

            for i in 0..n_samples {
                all_deviations.push((X[[i, j]] - median).abs());
            }
        }

        // Median of the pooled absolute deviations.
        all_deviations.sort_by(|a, b| a.total_cmp(b));
        let mad = if all_deviations.len() % 2 == 0 {
            (all_deviations[all_deviations.len() / 2 - 1]
                + all_deviations[all_deviations.len() / 2])
                / 2.0
        } else {
            all_deviations[all_deviations.len() / 2]
        };

        Ok((mad * 1.4826).max(1e-6))
    }

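    /// E-step: computes normalized responsibilities in log space, refreshes
    /// the outlier mask from the per-sample likelihoods, and downweights the
    /// responsibilities of flagged samples.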
    fn compute_robust_responsibilities(
        &self,
        X: &Array2<f64>,
        weights: &Array1<f64>,
        means: &Array2<f64>,
        covariances: &[Array2<f64>],
        outlier_mask: &mut Array1<bool>,
    ) -> SklResult<Array2<f64>> {
        let (n_samples, _) = X.dim();
        let mut responsibilities = Array2::zeros((n_samples, self.n_components));
        let mut sample_likelihoods = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut log_prob_sum = f64::NEG_INFINITY;
            let mut log_probs = Vec::new();

            for k in 0..self.n_components {
                let mean = means.row(k);
                let cov = &covariances[k];
                let log_weight = weights[k].ln();
                let log_likelihood = self.multivariate_normal_log_pdf(&sample, &mean, cov)?;
                let log_prob = log_weight + log_likelihood;
                log_probs.push(log_prob);
                log_prob_sum = log_sum_exp(log_prob_sum, log_prob);
            }

            sample_likelihoods[i] = log_prob_sum;

            // Normalize in log space to avoid underflow.
            for k in 0..self.n_components {
                responsibilities[[i, k]] = (log_probs[k] - log_prob_sum).exp();
            }
        }

        self.detect_outliers(&sample_likelihoods, outlier_mask)?;

        // Downweight (rather than discard) flagged samples.
        for i in 0..n_samples {
            if outlier_mask[i] {
                for k in 0..self.n_components {
                    responsibilities[[i, k]] *= 0.1;
                }
            }
        }

        Ok(responsibilities)
    }

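    /// Flags the lowest-likelihood `outlier_fraction` of samples as outliers
    /// by thresholding at the corresponding likelihood quantile.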
    fn detect_outliers(
        &self,
        sample_likelihoods: &Array1<f64>,
        outlier_mask: &mut Array1<bool>,
    ) -> SklResult<()> {
        let n_samples = sample_likelihoods.len();

        let mut sorted_likelihoods = sample_likelihoods.to_vec();
        sorted_likelihoods.sort_by(|a, b| a.total_cmp(b));

        // Threshold at the `outlier_fraction` quantile of the ascending
        // likelihoods, so roughly that fraction of samples falls below it.
        // (Using `1.0 - outlier_fraction` here would invert the rule and flag
        // almost everything.)
        let threshold_idx = (self.outlier_fraction * n_samples as f64) as usize;
        let threshold_idx = threshold_idx.min(n_samples - 1);
        let likelihood_threshold = sorted_likelihoods[threshold_idx];

        for i in 0..n_samples {
            outlier_mask[i] = sample_likelihoods[i] < likelihood_threshold;
        }

        Ok(())
    }

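    /// M-step: re-estimates weights, means, and covariances from the
    /// responsibilities, with outlier rows downweighted a further 10x.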
    fn update_robust_parameters(
        &self,
        X: &Array2<f64>,
        responsibilities: &Array2<f64>,
        outlier_mask: &Array1<bool>,
    ) -> SklResult<(Array1<f64>, Array2<f64>, Vec<Array2<f64>>)> {
        let (n_samples, n_features) = X.dim();

        // Downweight outliers again here; combined with the E-step scaling,
        // flagged samples contribute with an effective factor of 0.01.
        let mut effective_responsibilities = responsibilities.clone();
        for i in 0..n_samples {
            if outlier_mask[i] {
                for k in 0..self.n_components {
                    effective_responsibilities[[i, k]] *= 0.1;
                }
            }
        }

        // Mixing weights from the effective component masses.
        let n_k: Array1<f64> = effective_responsibilities.sum_axis(Axis(0));
        let total_weight = n_k.sum();
        let weights = &n_k / total_weight;

        // Weighted means; components with negligible mass keep zero means.
        let mut means = Array2::zeros((self.n_components, n_features));
        for k in 0..self.n_components {
            if n_k[k] > 1e-10 {
                for i in 0..n_samples {
                    for j in 0..n_features {
                        means[[k, j]] += effective_responsibilities[[i, k]] * X[[i, j]];
                    }
                }
                for j in 0..n_features {
                    means[[k, j]] /= n_k[k];
                }
            }
        }

        let covariances =
            self.update_robust_covariances(X, &effective_responsibilities, &means, &n_k)?;

        Ok((weights, means, covariances))
    }

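    /// Re-estimates covariances for the configured covariance type. With
    /// robust estimation enabled, the diagonal regularizer is 10x `reg_covar`;
    /// empty components fall back to identity-like covariances.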
    fn update_robust_covariances(
        &self,
        X: &Array2<f64>,
        responsibilities: &Array2<f64>,
        means: &Array2<f64>,
        n_k: &Array1<f64>,
    ) -> SklResult<Vec<Array2<f64>>> {
        let (n_samples, n_features) = X.dim();
        let mut covariances = Vec::new();

        match self.covariance_type {
            CovarianceType::Full => {
                for k in 0..self.n_components {
                    let mut cov = Array2::zeros((n_features, n_features));

                    if n_k[k] > 1e-10 {
                        let mean_k = means.row(k);

                        // Weighted outer products of the deviations.
                        for i in 0..n_samples {
                            let sample = X.row(i);
                            let diff = &sample - &mean_k;

                            for d1 in 0..n_features {
                                for d2 in 0..n_features {
                                    cov[[d1, d2]] +=
                                        responsibilities[[i, k]] * diff[d1] * diff[d2];
                                }
                            }
                        }

                        for d1 in 0..n_features {
                            for d2 in 0..n_features {
                                cov[[d1, d2]] /= n_k[k];
                            }
                        }

                        // Stronger regularization in robust mode.
                        let robust_reg = if self.robust_covariance {
                            self.reg_covar * 10.0
                        } else {
                            self.reg_covar
                        };

                        for d in 0..n_features {
                            cov[[d, d]] += robust_reg;
                        }
                    } else {
                        // Empty component: fall back to a regularized identity.
                        for d in 0..n_features {
                            cov[[d, d]] = 1.0 + self.reg_covar * 10.0;
                        }
                    }

                    covariances.push(cov);
                }
            }
            CovarianceType::Diagonal => {
                for k in 0..self.n_components {
                    let mut cov = Array2::zeros((n_features, n_features));

                    if n_k[k] > 1e-10 {
                        let mean_k = means.row(k);

                        for d in 0..n_features {
                            let mut var = 0.0;
                            for i in 0..n_samples {
                                let diff = X[[i, d]] - mean_k[d];
                                var += responsibilities[[i, k]] * diff * diff;
                            }
                            var /= n_k[k];

                            let robust_reg = if self.robust_covariance {
                                self.reg_covar * 10.0
                            } else {
                                self.reg_covar
                            };

                            cov[[d, d]] = var + robust_reg;
                        }
                    } else {
                        for d in 0..n_features {
                            cov[[d, d]] = 1.0 + self.reg_covar * 10.0;
                        }
                    }

                    covariances.push(cov);
                }
            }
            CovarianceType::Tied => {
                // One covariance shared by all components, pooled over the
                // whole dataset.
                let mut cov = Array2::zeros((n_features, n_features));
                let total_responsibility: f64 = n_k.sum();

                if total_responsibility > 1e-10 {
                    for k in 0..self.n_components {
                        let mean_k = means.row(k);

                        for i in 0..n_samples {
                            let sample = X.row(i);
                            let diff = &sample - &mean_k;

                            for d1 in 0..n_features {
                                for d2 in 0..n_features {
                                    cov[[d1, d2]] +=
                                        responsibilities[[i, k]] * diff[d1] * diff[d2];
                                }
                            }
                        }
                    }

                    for d1 in 0..n_features {
                        for d2 in 0..n_features {
                            cov[[d1, d2]] /= total_responsibility;
                        }
                    }

                    let robust_reg = if self.robust_covariance {
                        self.reg_covar * 10.0
                    } else {
                        self.reg_covar
                    };

                    for d in 0..n_features {
                        cov[[d, d]] += robust_reg;
                    }
                } else {
                    for d in 0..n_features {
                        cov[[d, d]] = 1.0 + self.reg_covar * 10.0;
                    }
                }

                for _ in 0..self.n_components {
                    covariances.push(cov.clone());
                }
            }
            CovarianceType::Spherical => {
                for k in 0..self.n_components {
                    let mut cov = Array2::zeros((n_features, n_features));

                    if n_k[k] > 1e-10 {
                        let mean_k = means.row(k);
                        let mut total_var = 0.0;

                        // Pool the variance over all features.
                        for i in 0..n_samples {
                            for d in 0..n_features {
                                let diff = X[[i, d]] - mean_k[d];
                                total_var += responsibilities[[i, k]] * diff * diff;
                            }
                        }

                        total_var /= n_k[k] * n_features as f64;

                        let robust_reg = if self.robust_covariance {
                            self.reg_covar * 10.0
                        } else {
                            self.reg_covar
                        };

                        let variance = total_var + robust_reg;

                        for d in 0..n_features {
                            cov[[d, d]] = variance;
                        }
                    } else {
                        for d in 0..n_features {
                            cov[[d, d]] = 1.0 + self.reg_covar * 10.0;
                        }
                    }

                    covariances.push(cov);
                }
            }
        }

        Ok(covariances)
    }

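    /// Returns the mean log-likelihood over the samples not flagged as
    /// outliers (the trimmed objective used for convergence and model
    /// selection), or `-inf` if every sample is flagged.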
    fn compute_trimmed_log_likelihood(
        &self,
        X: &Array2<f64>,
        weights: &Array1<f64>,
        means: &Array2<f64>,
        covariances: &[Array2<f64>],
        outlier_mask: &Array1<bool>,
    ) -> SklResult<f64> {
        let (n_samples, _) = X.dim();
        let mut total_log_likelihood = 0.0;
        let mut n_included = 0;

        for i in 0..n_samples {
            if !outlier_mask[i] {
                let sample = X.row(i);
                let mut log_prob_sum = f64::NEG_INFINITY;

                for k in 0..self.n_components {
                    let mean = means.row(k);
                    let cov = &covariances[k];
                    let log_weight = weights[k].ln();
                    let log_likelihood = self.multivariate_normal_log_pdf(&sample, &mean, cov)?;
                    let log_prob = log_weight + log_likelihood;
                    log_prob_sum = log_sum_exp(log_prob_sum, log_prob);
                }

                total_log_likelihood += log_prob_sum;
                n_included += 1;
            }
        }

        if n_included > 0 {
            Ok(total_log_likelihood / n_included as f64)
        } else {
            Ok(f64::NEG_INFINITY)
        }
    }

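    /// Log-density of a multivariate normal evaluated with a diagonal
    /// approximation: for every covariance type, including `Full`, only the
    /// diagonal entries are used, so the log-determinant and quadratic form
    /// reduce to per-dimension sums and off-diagonal structure is ignored.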
    fn multivariate_normal_log_pdf(
        &self,
        x: &ArrayView1<f64>,
        mean: &ArrayView1<f64>,
        cov: &Array2<f64>,
    ) -> SklResult<f64> {
        let d = x.len() as f64;
        let diff: Array1<f64> = x - mean;

        // All covariance types share the same diagonal evaluation, so the
        // previous per-type match (whose arms were identical) is collapsed.
        let mut log_det = 0.0;
        let mut quad_form = 0.0;

        for i in 0..diff.len() {
            if cov[[i, i]] <= 0.0 {
                return Err(SklearsError::InvalidInput(
                    "Covariance matrix has non-positive diagonal elements".to_string(),
                ));
            }
            log_det += cov[[i, i]].ln();
            quad_form += diff[i] * diff[i] / cov[[i, i]];
        }

        Ok(-0.5 * (d * (2.0 * PI).ln() + log_det + quad_form))
    }
}

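/// Learned parameters and fit diagnostics for a fitted [`RobustGaussianMixture`].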
#[derive(Debug, Clone)]
pub struct RobustGaussianMixtureTrained {
    /// Mixing weights, one per component.
    pub weights: Array1<f64>,
    /// Component means, shape `(n_components, n_features)`.
    pub means: Array2<f64>,
    /// Per-component covariance matrices.
    pub covariances: Vec<Array2<f64>>,
    /// Final trimmed (mean per included sample) log-likelihood.
    pub log_likelihood: f64,
    /// Number of EM iterations run in the best initialization.
    pub n_iter: usize,
    /// Whether the best run met the tolerance before `max_iter`.
    pub converged: bool,
    /// Bayesian information criterion of the fitted model.
    pub bic: f64,
    /// Akaike information criterion of the fitted model.
    pub aic: f64,
    /// Per-sample outlier flags from the training data.
    pub outlier_mask: Array1<bool>,
    /// Number of training samples flagged as outliers.
    pub n_outliers: usize,
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for RobustGaussianMixture<RobustGaussianMixtureTrained>
{
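    /// Assigns each sample to the component with the highest weighted
    /// log-density (hard assignment).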
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let X = X.to_owned();
        let (n_samples, _) = X.dim();
        let mut predictions = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut max_log_prob = f64::NEG_INFINITY;
            let mut best_component = 0;

            for k in 0..self.n_components {
                let mean = self.state.means.row(k);
                let cov = &self.state.covariances[k];
                let log_weight = self.state.weights[k].ln();

                if let Ok(log_likelihood) = self.multivariate_normal_log_pdf(&sample, &mean, cov) {
                    let log_prob = log_weight + log_likelihood;
                    if log_prob > max_log_prob {
                        max_log_prob = log_prob;
                        best_component = k;
                    }
                }
            }

            predictions[i] = best_component as i32;
        }

        Ok(predictions)
    }
}

impl RobustGaussianMixture<RobustGaussianMixtureTrained> {
    /// Mixing weights of the fitted components.
    pub fn weights(&self) -> &Array1<f64> {
        &self.state.weights
    }

    /// Means of the fitted components.
    pub fn means(&self) -> &Array2<f64> {
        &self.state.means
    }

    /// Covariance matrices of the fitted components.
    pub fn covariances(&self) -> &[Array2<f64>] {
        &self.state.covariances
    }

    /// Final trimmed log-likelihood of the best run.
    pub fn log_likelihood(&self) -> f64 {
        self.state.log_likelihood
    }

    /// Number of EM iterations run in the best initialization.
    pub fn n_iter(&self) -> usize {
        self.state.n_iter
    }

    /// Whether the best run converged within `max_iter`.
    pub fn converged(&self) -> bool {
        self.state.converged
    }

    /// Bayesian information criterion of the fit.
    pub fn bic(&self) -> f64 {
        self.state.bic
    }

    /// Akaike information criterion of the fit.
    pub fn aic(&self) -> f64 {
        self.state.aic
    }

    /// Outlier flags for the training samples.
    pub fn outlier_mask(&self) -> &Array1<bool> {
        &self.state.outlier_mask
    }

    /// Number of training samples flagged as outliers.
    pub fn n_outliers(&self) -> usize {
        self.state.n_outliers
    }

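    /// Posterior component probabilities for each sample, computed in log
    /// space and normalized per row.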
    #[allow(non_snake_case)]
    pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let (n_samples, _) = X.dim();
        let mut probabilities = Array2::zeros((n_samples, self.n_components));

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut log_prob_sum = f64::NEG_INFINITY;
            let mut log_probs = Vec::new();

            for k in 0..self.n_components {
                let mean = self.state.means.row(k);
                let cov = &self.state.covariances[k];
                let log_weight = self.state.weights[k].ln();
                let log_likelihood = self.multivariate_normal_log_pdf(&sample, &mean, cov)?;
                let log_prob = log_weight + log_likelihood;
                log_probs.push(log_prob);
                log_prob_sum = log_sum_exp(log_prob_sum, log_prob);
            }

            for k in 0..self.n_components {
                probabilities[[i, k]] = (log_probs[k] - log_prob_sum).exp();
            }
        }

        Ok(probabilities)
    }

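    /// Per-sample log-likelihood under the fitted mixture.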
    #[allow(non_snake_case)]
    pub fn score_samples(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
        let X = X.to_owned();
        let (n_samples, _) = X.dim();
        let mut scores = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let sample = X.row(i);
            let mut log_prob_sum = f64::NEG_INFINITY;

            for k in 0..self.n_components {
                let mean = self.state.means.row(k);
                let cov = &self.state.covariances[k];
                let log_weight = self.state.weights[k].ln();
                let log_likelihood = self.multivariate_normal_log_pdf(&sample, &mean, cov)?;
                let log_prob = log_weight + log_likelihood;
                log_prob_sum = log_sum_exp(log_prob_sum, log_prob);
            }

            scores[i] = log_prob_sum;
        }

        Ok(scores)
    }

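    /// Mean per-sample log-likelihood (0.0 for empty input).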
    pub fn score(&self, X: &ArrayView2<'_, Float>) -> SklResult<f64> {
        let scores = self.score_samples(X)?;
        Ok(scores.mean().unwrap_or(0.0))
    }

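    /// Flags the lowest-scoring `outlier_fraction` of the given samples as
    /// outliers, using the same quantile rule as during fitting.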
    pub fn detect_outliers(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<bool>> {
        let scores = self.score_samples(X)?;
        let n_samples = scores.len();
        let mut outlier_mask = Array1::from_elem(n_samples, false);

        // Nothing to flag (also avoids indexing an empty sort below).
        if n_samples == 0 {
            return Ok(outlier_mask);
        }

        let mut sorted_scores = scores.to_vec();
        sorted_scores.sort_by(|a, b| a.total_cmp(b));

        // Threshold at the `outlier_fraction` quantile of the ascending scores,
        // mirroring the corrected rule used during fitting.
        let threshold_idx = (self.outlier_fraction * n_samples as f64) as usize;
        let threshold_idx = threshold_idx.min(n_samples - 1);
        let score_threshold = sorted_scores[threshold_idx];

        for i in 0..n_samples {
            outlier_mask[i] = scores[i] < score_threshold;
        }

        Ok(outlier_mask)
    }

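    /// Log-density under the same diagonal approximation used during fitting
    /// (off-diagonal covariance entries are ignored for all types).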
    fn multivariate_normal_log_pdf(
        &self,
        x: &ArrayView1<f64>,
        mean: &ArrayView1<f64>,
        cov: &Array2<f64>,
    ) -> SklResult<f64> {
        let d = x.len() as f64;
        let diff: Array1<f64> = x - mean;

        // Identical to the untrained-state helper: every covariance type is
        // evaluated through its diagonal, so the former per-type match (with
        // identical arms) is collapsed.
        let mut log_det = 0.0;
        let mut quad_form = 0.0;

        for i in 0..diff.len() {
            if cov[[i, i]] <= 0.0 {
                return Err(SklearsError::InvalidInput(
                    "Covariance matrix has non-positive diagonal elements".to_string(),
                ));
            }
            log_det += cov[[i, i]].ln();
            quad_form += diff[i] * diff[i] / cov[[i, i]];
        }

        Ok(-0.5 * (d * (2.0 * PI).ln() + log_det + quad_form))
    }
}