1mod kde;
16mod variational;
17
18pub use kde::*;
19pub use variational::*;
20
21use crate::error::{StatsError, StatsResult};
22use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
23use scirs2_core::numeric::{Float, FromPrimitive, One, Zero};
24use scirs2_core::random::Rng;
25use scirs2_core::{simd_ops::SimdUnifiedOps, validation::*};
26use std::marker::PhantomData;
27
/// Gaussian Mixture Model fitted with the EM algorithm.
///
/// Created unfitted via [`GaussianMixtureModel::new`]; `fit` populates
/// `parameters` and `convergence_history`.
pub struct GaussianMixtureModel<F> {
    /// Number of mixture components.
    pub n_components: usize,
    /// Fitting configuration (tolerances, initialization, covariance type, ...).
    pub config: GMMConfig,
    /// Fitted parameters; `None` until `fit` has succeeded.
    pub parameters: Option<GMMParameters<F>>,
    /// Total log-likelihood recorded at each EM iteration of the last `fit`.
    pub convergence_history: Vec<F>,
    // Ties the unused type parameter `F` to the struct.
    _phantom: PhantomData<F>,
}
44
#[derive(Debug, Clone)]
/// Configuration options controlling GMM fitting.
pub struct GMMConfig {
    /// Maximum number of EM iterations.
    pub max_iter: usize,
    /// Log-likelihood improvement threshold for convergence.
    pub tolerance: f64,
    /// Parameter-change threshold (declared but see `fit`: convergence is
    /// currently decided on the log-likelihood tolerance only).
    pub param_tolerance: f64,
    /// Structure imposed on component covariance matrices.
    pub covariance_type: CovarianceType,
    /// Regularization added to covariance diagonals for numerical stability.
    pub reg_covar: f64,
    /// Strategy used to initialize the component means.
    pub init_method: InitializationMethod,
    /// Number of independent initializations to attempt.
    pub n_init: usize,
    /// RNG seed; `None` draws a seed from the thread RNG.
    pub seed: Option<u64>,
    /// Enable parallel execution where supported.
    pub parallel: bool,
    /// Enable SIMD-accelerated operations where supported.
    pub use_simd: bool,
    /// Reuse previous parameters as the starting point of a new fit.
    pub warm_start: bool,
    /// Enable robust (outlier-aware) EM.
    pub robust_em: bool,
    /// Threshold used by the robust-EM outlier machinery.
    pub outlier_threshold: f64,
    /// Stop early based on a held-out validation score.
    pub early_stopping: bool,
    /// Fraction of data held out when early stopping is enabled.
    pub validation_fraction: f64,
    /// Iterations without improvement tolerated before early stopping.
    pub patience: usize,
}
81
#[derive(Debug, Clone, PartialEq)]
/// Structural constraint applied to each component's covariance matrix.
pub enum CovarianceType {
    /// Unconstrained symmetric positive-definite covariance per component.
    Full,
    /// Diagonal covariance per component (off-diagonals forced to zero).
    Diagonal,
    /// One covariance matrix shared by all components.
    Tied,
    /// Isotropic covariance: a single variance per component.
    Spherical,
    /// Factor-analysis-style low-rank-plus-diagonal covariance.
    Factor {
        /// Number of latent factors.
        n_factors: usize,
    },
    /// Full covariance subject to an explicit constraint.
    Constrained {
        /// Constraint enforced on the covariance.
        constraint: CovarianceConstraint,
    },
}
104
#[derive(Debug, Clone, PartialEq)]
/// Constraints available for `CovarianceType::Constrained`.
pub enum CovarianceConstraint {
    /// Eigenvalues must not drop below the given floor.
    MinEigenvalue(f64),
    /// Condition number must not exceed the given ceiling.
    MaxCondition(f64),
    /// Only the listed `(row, col)` entries may be nonzero.
    Sparse(Vec<(usize, usize)>),
}
115
#[derive(Debug, Clone, PartialEq)]
/// Strategy used to pick the initial component means.
pub enum InitializationMethod {
    /// Means drawn uniformly at random from the data points.
    Random,
    /// k-means++-style distance-weighted seeding.
    KMeansPlus,
    /// Full k-means runs; best of `n_runs` restarts.
    KMeans {
        /// Number of k-means restarts.
        n_runs: usize,
    },
    /// Greedy furthest-first traversal of the data.
    FurthestFirst,
    /// Caller-supplied means.
    Custom,
    /// Means placed at evenly spaced sample quantile positions.
    Quantile,
    /// PCA-based initialization.
    PCA,
    /// Spectral-clustering-based initialization.
    Spectral,
}
139
#[derive(Debug, Clone)]
/// Fitted GMM parameters plus fitting diagnostics.
pub struct GMMParameters<F> {
    /// Mixing weights, one per component (normalized in the M-step).
    pub weights: Array1<F>,
    /// Component means, shape `(n_components, n_features)`.
    pub means: Array2<F>,
    /// One covariance matrix per component.
    pub covariances: Vec<Array2<F>>,
    /// Final total log-likelihood of the training data.
    pub log_likelihood: F,
    /// Number of EM iterations actually executed.
    pub n_iter: usize,
    /// Whether the tolerance criterion was met before `max_iter`.
    pub converged: bool,
    /// Which criterion terminated the fit.
    pub convergence_reason: ConvergenceReason,
    /// Information criteria computed at the end of fitting.
    pub model_selection: ModelSelectionCriteria<F>,
    /// Per-component health metrics.
    pub component_diagnostics: Vec<ComponentDiagnostics<F>>,
    /// Per-sample outlier scores (populated by the robust variant only).
    pub outlier_scores: Option<Array1<F>>,
    /// Final responsibilities, shape `(n_samples, n_components)`.
    pub responsibilities: Option<Array2<F>>,
    /// Per-iteration snapshots (currently left empty by `fit`).
    pub parameter_history: Vec<ParameterSnapshot<F>>,
}
168
#[derive(Debug, Clone, PartialEq)]
/// Why EM terminated.
pub enum ConvergenceReason {
    /// Log-likelihood improvement fell below `tolerance`.
    LogLikelihoodTolerance,
    /// Parameter change fell below `param_tolerance`.
    ParameterTolerance,
    /// Iteration budget `max_iter` exhausted.
    MaxIterations,
    /// Validation-based early stopping triggered.
    EarlyStopping,
    /// Fit aborted due to numerical problems.
    NumericalInstability,
}
183
#[derive(Debug, Clone)]
/// Information criteria for comparing fitted models.
pub struct ModelSelectionCriteria<F> {
    /// Akaike information criterion: `-2*LL + 2*p`.
    pub aic: F,
    /// Bayesian information criterion: `-2*LL + p*ln(n)`.
    pub bic: F,
    /// Integrated completed likelihood (BIC plus an entropy penalty).
    pub icl: F,
    /// Hannan-Quinn criterion: `-2*LL + 2*p*ln(ln(n))`.
    pub hqic: F,
    /// Cross-validated log-likelihood, if computed.
    pub cv_log_likelihood: Option<F>,
    /// Number of free parameters `p` used in the criteria above.
    pub n_parameters: usize,
}
200
#[derive(Debug, Clone)]
/// Per-component health metrics computed after fitting.
pub struct ComponentDiagnostics<F> {
    /// Sum of responsibilities assigned to this component.
    pub effective_samplesize: F,
    /// Cheap diagonal-ratio estimate of the covariance condition number.
    pub condition_number: F,
    /// Absolute determinant of the component covariance.
    pub covariance_determinant: F,
    /// Euclidean distance to the nearest other component mean.
    pub component_separation: F,
    /// Weight stability metric (currently always zero).
    pub weight_stability: F,
}
215
#[derive(Debug, Clone)]
/// State captured at one EM iteration (for `parameter_history`).
pub struct ParameterSnapshot<F> {
    /// Iteration index the snapshot was taken at.
    pub iteration: usize,
    /// Total log-likelihood at that iteration.
    pub log_likelihood: F,
    /// Magnitude of the parameter update at that iteration.
    pub parameter_change: F,
    /// Mixing weights at that iteration.
    pub weights: Array1<F>,
}
228
impl Default for GMMConfig {
    /// General-purpose defaults: full covariances, k-means++ initialization,
    /// a single restart, SIMD/parallel enabled, robustness features off.
    fn default() -> Self {
        Self {
            max_iter: 100,
            tolerance: 1e-3,
            param_tolerance: 1e-4,
            covariance_type: CovarianceType::Full,
            reg_covar: 1e-6,
            init_method: InitializationMethod::KMeansPlus,
            n_init: 1,
            seed: None,
            parallel: true,
            use_simd: true,
            warm_start: false,
            robust_em: false,
            outlier_threshold: 0.01,
            early_stopping: false,
            validation_fraction: 0.1,
            patience: 10,
        }
    }
}
251
/// Blanket trait alias bundling every bound the GMM code needs from its
/// scalar type (float arithmetic, SIMD ops, conversions, threading markers).
pub trait GmmFloat:
    Float
    + Zero
    + One
    + Copy
    + Send
    + Sync
    + SimdUnifiedOps
    + FromPrimitive
    + std::fmt::Display
    + std::iter::Sum
    + scirs2_core::ndarray::ScalarOperand
{
}
271
// Blanket impl: any type meeting the bounds is automatically a `GmmFloat`,
// so `f32`/`f64` (and compatible types) need no explicit implementation.
impl<F> GmmFloat for F where
    F: Float
        + Zero
        + One
        + Copy
        + Send
        + Sync
        + SimdUnifiedOps
        + FromPrimitive
        + std::fmt::Display
        + std::iter::Sum
        + scirs2_core::ndarray::ScalarOperand
{
}
286
287fn f64_to_f<F: Float + FromPrimitive>(v: f64, ctx: &str) -> StatsResult<F> {
292 F::from(v).ok_or_else(|| {
293 StatsError::ComputationError(format!("Failed to convert f64 ({v}) to float ({ctx})"))
294 })
295}
296
impl<F: GmmFloat> GaussianMixtureModel<F> {
    /// Creates an unfitted model with `n_components` components.
    ///
    /// # Errors
    /// Fails when `n_components` is zero.
    pub fn new(n_components: usize, config: GMMConfig) -> StatsResult<Self> {
        check_positive(n_components, "n_components")?;

        Ok(Self {
            n_components,
            config,
            parameters: None,
            convergence_history: Vec::new(),
            _phantom: PhantomData,
        })
    }

    /// Fits the mixture to `data` (rows = samples) with EM and returns the
    /// stored parameters.
    ///
    /// Convergence is declared when the absolute log-likelihood improvement
    /// drops below `config.tolerance` (checked from the second iteration on).
    /// NOTE(review): `config.param_tolerance` is declared but not consulted
    /// here — convergence is log-likelihood-based only.
    ///
    /// # Errors
    /// Fails on non-finite data, `n_samples < n_components`, or numerical
    /// failures inside the E/M steps.
    pub fn fit(&mut self, data: &ArrayView2<F>) -> StatsResult<&GMMParameters<F>> {
        checkarray_finite(data, "data")?;

        let (n_samples, n_features) = data.dim();

        if n_samples < self.n_components {
            return Err(StatsError::InvalidArgument(format!(
                "Number of samples ({n_samples}) must be >= number of components ({})",
                self.n_components
            )));
        }

        // Uniform initial mixing weights; means/covariances per config.
        let inv_k: F = f64_to_f(1.0 / self.n_components as f64, "inv_k")?;
        let mut weights = Array1::from_elem(self.n_components, inv_k);
        let mut means = self.initialize_means(data)?;
        let mut covariances = self.initialize_covariances(data, &means)?;

        let mut log_likelihood = F::neg_infinity();
        let mut converged = false;
        self.convergence_history.clear();

        // Deferred single assignment: set on whichever exit path is taken
        // (early convergence return inside the loop, or loop exhaustion).
        let n_iter_used;

        for iter_idx in 0..self.config.max_iter {
            // E-step then M-step (weights, means, covariances in order).
            let responsibilities = self.e_step(data, &weights, &means, &covariances)?;
            let new_weights = self.m_step_weights(&responsibilities)?;
            let new_means = self.m_step_means(data, &responsibilities)?;
            let new_covariances = self.m_step_covariances(data, &responsibilities, &new_means)?;

            let new_ll =
                self.compute_log_likelihood(data, &new_weights, &new_means, &new_covariances)?;

            self.convergence_history.push(new_ll);

            // First iteration is skipped: `log_likelihood` starts at -inf.
            let improvement = new_ll - log_likelihood;
            let tol: F = f64_to_f(self.config.tolerance, "tolerance")?;
            if improvement.abs() < tol && iter_idx > 0 {
                converged = true;
            }

            weights = new_weights;
            means = new_means;
            covariances = new_covariances;
            log_likelihood = new_ll;

            if converged {
                n_iter_used = iter_idx + 1;
                self.store_parameters(
                    weights,
                    means,
                    covariances,
                    log_likelihood,
                    n_iter_used,
                    converged,
                    n_samples,
                    n_features,
                    data,
                )?;
                return self
                    .parameters
                    .as_ref()
                    .ok_or_else(|| StatsError::ComputationError("Parameters not stored".into()));
            }
        }

        // Iteration budget exhausted without meeting the tolerance.
        n_iter_used = self.config.max_iter;
        self.store_parameters(
            weights,
            means,
            covariances,
            log_likelihood,
            n_iter_used,
            false,
            n_samples,
            n_features,
            data,
        )?;

        self.parameters
            .as_ref()
            .ok_or_else(|| StatsError::ComputationError("Parameters not stored".into()))
    }

    /// Hard-assigns each sample to the component with highest responsibility.
    ///
    /// # Errors
    /// Fails if the model is unfitted or the E-step fails.
    pub fn predict(&self, data: &ArrayView2<F>) -> StatsResult<Array1<usize>> {
        let params = self.require_fitted()?;
        let responsibilities =
            self.e_step(data, &params.weights, &params.means, &params.covariances)?;

        let mut predictions = Array1::zeros(data.nrows());
        for i in 0..data.nrows() {
            // Argmax over components for sample i.
            let mut max_resp = F::neg_infinity();
            let mut best = 0usize;
            for k in 0..self.n_components {
                if responsibilities[[i, k]] > max_resp {
                    max_resp = responsibilities[[i, k]];
                    best = k;
                }
            }
            predictions[i] = best;
        }
        Ok(predictions)
    }

    /// Returns per-sample component membership probabilities
    /// (the responsibilities), shape `(n_samples, n_components)`.
    ///
    /// # Errors
    /// Fails if the model is unfitted or the E-step fails.
    pub fn predict_proba(&self, data: &ArrayView2<F>) -> StatsResult<Array2<F>> {
        let params = self.require_fitted()?;
        self.e_step(data, &params.weights, &params.means, &params.covariances)
    }

    /// Mean per-sample log-likelihood of `data` under the fitted model.
    ///
    /// # Errors
    /// Fails if the model is unfitted.
    pub fn score(&self, data: &ArrayView2<F>) -> StatsResult<F> {
        let params = self.require_fitted()?;
        let total_ll =
            self.compute_log_likelihood(data, &params.weights, &params.means, &params.covariances)?;
        let n: F = f64_to_f(data.nrows() as f64, "n_samples")?;
        Ok(total_ll / n)
    }

    /// Per-sample log-likelihoods of `data` under the fitted model.
    ///
    /// # Errors
    /// Fails if the model is unfitted.
    pub fn score_samples(&self, data: &ArrayView2<F>) -> StatsResult<Array1<F>> {
        let params = self.require_fitted()?;
        self.per_sample_log_likelihood(data, &params.weights, &params.means, &params.covariances)
    }

    /// Draws `n` samples from the fitted mixture.
    ///
    /// Per sample: a component is chosen by inverse-CDF on the weights, a
    /// standard normal vector is generated via Box-Muller, colored with the
    /// component's Cholesky factor, and shifted by the component mean.
    ///
    /// # Errors
    /// Fails if the model is unfitted or a numeric conversion fails.
    pub fn sample(&self, n: usize, seed: Option<u64>) -> StatsResult<Array2<F>> {
        let params = self.require_fitted()?;
        let n_features = params.means.ncols();

        use scirs2_core::random::Random;
        let mut init_rng = scirs2_core::random::thread_rng();
        let mut rng = match seed {
            Some(s) => Random::seed(s),
            None => Random::seed(init_rng.random()),
        };

        let mut samples = Array2::zeros((n, n_features));

        for i in 0..n {
            // Pick a component by walking the cumulative weight distribution.
            let u: f64 = rng.random_f64();
            let mut cumsum = 0.0;
            let mut chosen_k = self.n_components - 1;
            for k in 0..self.n_components {
                let wk = params.weights[k].to_f64().ok_or_else(|| {
                    StatsError::ComputationError("Weight conversion failed".into())
                })?;
                cumsum += wk;
                if u < cumsum {
                    chosen_k = k;
                    break;
                }
            }

            let mean = params.means.row(chosen_k);
            let cov = &params.covariances[chosen_k];

            // Box-Muller: generate standard normals two at a time; the
            // `max(1e-300)` guards against ln(0).
            let mut z = Array1::<f64>::zeros(n_features);
            for j in (0..n_features).step_by(2) {
                let u1: f64 = rng.random_f64().max(1e-300);
                let u2: f64 = rng.random_f64();
                let r = (-2.0 * u1.ln()).sqrt();
                let theta = 2.0 * std::f64::consts::PI * u2;
                z[j] = r * theta.cos();
                if j + 1 < n_features {
                    z[j + 1] = r * theta.sin();
                }
            }

            // Color the standard normals: x = mean + L z with L Lᵀ = cov.
            let cov_f64 = cov.mapv(|x| x.to_f64().unwrap_or(0.0));
            let chol = cholesky_lower(&cov_f64)?;
            let sampled = chol.dot(&z);
            for j in 0..n_features {
                let val: F = f64_to_f(sampled[j], "sample_val")?;
                samples[[i, j]] = mean[j] + val;
            }
        }

        Ok(samples)
    }

    /// BIC of the fitted model (computed at fit time; `_data` is unused).
    pub fn bic(&self, _data: &ArrayView2<F>) -> StatsResult<F> {
        let params = self.require_fitted()?;
        Ok(params.model_selection.bic)
    }

    /// AIC of the fitted model (computed at fit time; `_data` is unused).
    pub fn aic(&self, _data: &ArrayView2<F>) -> StatsResult<F> {
        let params = self.require_fitted()?;
        Ok(params.model_selection.aic)
    }

    /// Number of free parameters of the fitted model.
    pub fn n_parameters(&self) -> StatsResult<usize> {
        let params = self.require_fitted()?;
        Ok(params.model_selection.n_parameters)
    }

    /// Returns the fitted parameters or an error if `fit` has not succeeded.
    fn require_fitted(&self) -> StatsResult<&GMMParameters<F>> {
        self.parameters
            .as_ref()
            .ok_or_else(|| StatsError::InvalidArgument("Model must be fitted before use".into()))
    }

    /// Finalizes a fit: computes information criteria and per-component
    /// diagnostics, then stores everything in `self.parameters`.
    #[allow(clippy::too_many_arguments)]
    fn store_parameters(
        &mut self,
        weights: Array1<F>,
        means: Array2<F>,
        covariances: Vec<Array2<F>>,
        log_likelihood: F,
        n_iter: usize,
        converged: bool,
        n_samples: usize,
        n_features: usize,
        data: &ArrayView2<F>,
    ) -> StatsResult<()> {
        let n_params = self.compute_n_parameters(n_features);
        let n_f: F = f64_to_f(n_samples as f64, "n_samples")?;
        let p_f: F = f64_to_f(n_params as f64, "n_params")?;
        let two: F = f64_to_f(2.0, "two")?;

        // Standard information criteria from the final log-likelihood.
        let aic = -two * log_likelihood + two * p_f;
        let bic = -two * log_likelihood + p_f * n_f.ln();
        let hqic = -two * log_likelihood + two * p_f * n_f.ln().ln();

        // ICL = BIC - 2*Σ r ln r; the entropy sum is <= 0, so this adds a
        // positive penalty for soft (ambiguous) assignments.
        let responsibilities = self.e_step(data, &weights, &means, &covariances)?;
        let entropy = self.responsibility_entropy(&responsibilities);
        let icl = bic - two * entropy;

        let mut diagnostics = Vec::with_capacity(self.n_components);
        for k in 0..self.n_components {
            let nk = responsibilities.column(k).sum();
            let cov_f64 = covariances[k].mapv(|x| x.to_f64().unwrap_or(0.0));
            // Determinant failure falls back to 1.0 rather than aborting.
            let det = scirs2_linalg::det(&cov_f64.view(), None).unwrap_or(1.0);
            let cond = self.estimate_condition_number(&cov_f64);
            let sep = self.compute_component_separation(k, &means, &covariances);

            diagnostics.push(ComponentDiagnostics {
                effective_samplesize: nk,
                condition_number: f64_to_f(cond, "cond").unwrap_or(F::one()),
                covariance_determinant: f64_to_f(det.abs(), "det").unwrap_or(F::one()),
                component_separation: sep,
                weight_stability: F::zero(),
            });
        }

        let parameters = GMMParameters {
            weights,
            means,
            covariances,
            log_likelihood,
            n_iter,
            converged,
            convergence_reason: if converged {
                ConvergenceReason::LogLikelihoodTolerance
            } else {
                ConvergenceReason::MaxIterations
            },
            model_selection: ModelSelectionCriteria {
                aic,
                bic,
                icl,
                hqic,
                cv_log_likelihood: None,
                n_parameters: n_params,
            },
            component_diagnostics: diagnostics,
            outlier_scores: None,
            responsibilities: Some(responsibilities),
            parameter_history: Vec::new(),
        };

        self.parameters = Some(parameters);
        Ok(())
    }

    /// Number of free parameters for `d` features under the configured
    /// covariance structure (weights contribute `k - 1` due to the sum-to-one
    /// constraint).
    fn compute_n_parameters(&self, d: usize) -> usize {
        let k = self.n_components;
        let weight_params = k - 1;
        let mean_params = k * d;
        let cov_params = match &self.config.covariance_type {
            CovarianceType::Full => k * d * (d + 1) / 2,
            CovarianceType::Diagonal => k * d,
            CovarianceType::Tied => d * (d + 1) / 2,
            CovarianceType::Spherical => k,
            CovarianceType::Factor { n_factors } => k * (d * n_factors + d),
            CovarianceType::Constrained { .. } => k * d * (d + 1) / 2,
        };
        weight_params + mean_params + cov_params
    }

    /// Returns Σ r·ln(r) over all responsibilities (a non-positive value);
    /// entries below `eps` are skipped to avoid `ln(0)`.
    fn responsibility_entropy(&self, resp: &Array2<F>) -> F {
        let mut entropy = F::zero();
        let eps: F = f64_to_f(1e-300, "eps").unwrap_or(F::min_positive_value());
        for row in resp.rows() {
            for &r in row.iter() {
                if r > eps {
                    entropy = entropy + r * r.ln();
                }
            }
        }
        entropy
    }

    /// Cheap condition-number proxy: ratio of largest to smallest diagonal
    /// magnitude (not a true singular-value condition number).
    fn estimate_condition_number(&self, cov: &Array2<f64>) -> f64 {
        let diag: Vec<f64> = (0..cov.nrows()).map(|i| cov[[i, i]].abs()).collect();
        let max_d = diag.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let min_d = diag
            .iter()
            .copied()
            .filter(|&v| v > 1e-300)
            .fold(f64::INFINITY, f64::min);
        if min_d > 0.0 {
            max_d / min_d
        } else {
            f64::INFINITY
        }
    }

    /// Euclidean distance from component `k`'s mean to the nearest other
    /// component mean (covariances are currently not used).
    fn compute_component_separation(&self, k: usize, means: &Array2<F>, _covs: &[Array2<F>]) -> F {
        let mut min_dist = F::infinity();
        let mean_k = means.row(k);
        for j in 0..self.n_components {
            if j == k {
                continue;
            }
            let mean_j = means.row(j);
            let d: F = mean_k
                .iter()
                .zip(mean_j.iter())
                .map(|(&a, &b)| (a - b) * (a - b))
                .sum();
            let d_sqrt = d.sqrt();
            if d_sqrt < min_dist {
                min_dist = d_sqrt;
            }
        }
        min_dist
    }

    /// Dispatches mean initialization according to `config.init_method`;
    /// unimplemented methods fall back to k-means++.
    fn initialize_means(&self, data: &ArrayView2<F>) -> StatsResult<Array2<F>> {
        let (n_samples, n_features) = data.dim();
        let mut means = Array2::zeros((self.n_components, n_features));

        match self.config.init_method {
            InitializationMethod::Random => {
                use scirs2_core::random::Random;
                let mut init_rng = scirs2_core::random::thread_rng();
                let mut rng = match self.config.seed {
                    Some(seed) => Random::seed(seed),
                    None => Random::seed(init_rng.random()),
                };
                for i in 0..self.n_components {
                    let idx = rng.random_range(0..n_samples);
                    means.row_mut(i).assign(&data.row(idx));
                }
            }
            InitializationMethod::KMeansPlus => {
                means = self.kmeans_plus_plus_init(data)?;
            }
            InitializationMethod::FurthestFirst => {
                means = self.furthest_first_init(data)?;
            }
            InitializationMethod::Quantile => {
                means = self.quantile_init(data)?;
            }
            // KMeans/Custom/PCA/Spectral are not implemented here; fall back.
            _ => {
                means = self.kmeans_plus_plus_init(data)?;
            }
        }

        Ok(means)
    }

    /// k-means++ seeding: the first mean is a random sample; each subsequent
    /// mean is a sample picked with probability proportional to its squared
    /// distance from the nearest already-chosen mean.
    fn kmeans_plus_plus_init(&self, data: &ArrayView2<F>) -> StatsResult<Array2<F>> {
        use scirs2_core::random::Random;
        let mut init_rng = scirs2_core::random::thread_rng();
        let mut rng = match self.config.seed {
            Some(seed) => Random::seed(seed),
            None => Random::seed(init_rng.random()),
        };

        let (n_samples, n_features) = data.dim();
        let mut means = Array2::zeros((self.n_components, n_features));
        let first_idx = rng.random_range(0..n_samples);
        means.row_mut(0).assign(&data.row(first_idx));

        for i in 1..self.n_components {
            // Squared distance of each sample to its nearest chosen mean.
            let mut distances = Array1::zeros(n_samples);
            for j in 0..n_samples {
                let mut min_dist = F::infinity();
                for k_idx in 0..i {
                    let dist = self.squared_distance(&data.row(j), &means.row(k_idx));
                    min_dist = min_dist.min(dist);
                }
                distances[j] = min_dist;
            }

            // Degenerate case (all distances zero): pick uniformly at random.
            let total_dist: F = distances.sum();
            if total_dist <= F::zero() {
                let idx = rng.random_range(0..n_samples);
                means.row_mut(i).assign(&data.row(idx));
                continue;
            }

            // Inverse-CDF sampling over the distance distribution.
            let threshold_f64: f64 = rng.random_f64();
            let threshold_ratio: F = F::from(threshold_f64)
                .ok_or_else(|| StatsError::ComputationError("threshold conversion".into()))?;
            let threshold: F = threshold_ratio * total_dist;
            let mut cumsum = F::zero();
            let mut picked = false;
            for j in 0..n_samples {
                cumsum = cumsum + distances[j];
                if cumsum >= threshold {
                    means.row_mut(i).assign(&data.row(j));
                    picked = true;
                    break;
                }
            }
            // Floating-point round-off fallback: take the last sample.
            if !picked {
                means
                    .row_mut(i)
                    .assign(&data.row(n_samples.saturating_sub(1)));
            }
        }

        Ok(means)
    }

    /// Furthest-first seeding: each subsequent mean is the sample maximizing
    /// the distance to its nearest already-chosen mean.
    fn furthest_first_init(&self, data: &ArrayView2<F>) -> StatsResult<Array2<F>> {
        use scirs2_core::random::Random;
        let mut init_rng = scirs2_core::random::thread_rng();
        let mut rng = match self.config.seed {
            Some(s) => Random::seed(s),
            None => Random::seed(init_rng.random()),
        };

        let (n_samples, n_features) = data.dim();
        let mut means = Array2::zeros((self.n_components, n_features));
        let first_idx = rng.random_range(0..n_samples);
        means.row_mut(0).assign(&data.row(first_idx));

        for i in 1..self.n_components {
            let mut best_idx = 0;
            let mut best_dist = F::neg_infinity();
            for j in 0..n_samples {
                let mut min_dist = F::infinity();
                for k_idx in 0..i {
                    let d = self.squared_distance(&data.row(j), &means.row(k_idx));
                    min_dist = min_dist.min(d);
                }
                if min_dist > best_dist {
                    best_dist = min_dist;
                    best_idx = j;
                }
            }
            means.row_mut(i).assign(&data.row(best_idx));
        }

        Ok(means)
    }

    /// Quantile seeding: means placed at evenly spaced row positions.
    /// NOTE(review): rows are taken by index, not by sorted order — this is a
    /// positional, not a statistical, quantile.
    fn quantile_init(&self, data: &ArrayView2<F>) -> StatsResult<Array2<F>> {
        let (n_samples, n_features) = data.dim();
        let mut means = Array2::zeros((self.n_components, n_features));
        for i in 0..self.n_components {
            let frac = (i as f64 + 0.5) / self.n_components as f64;
            let idx = ((frac * n_samples as f64) as usize).min(n_samples.saturating_sub(1));
            means.row_mut(i).assign(&data.row(idx));
        }
        Ok(means)
    }

    /// Initializes every component covariance to the (regularized) diagonal
    /// of the global per-feature data variance; zero-variance features are
    /// replaced with unit variance.
    fn initialize_covariances(
        &self,
        data: &ArrayView2<F>,
        _means: &Array2<F>,
    ) -> StatsResult<Vec<Array2<F>>> {
        let n_features = data.ncols();
        let n_samples = data.nrows();
        let mut covariances = Vec::with_capacity(self.n_components);
        let n_f: F = f64_to_f(n_samples as f64, "n_samples_init")?;
        let reg: F = f64_to_f(self.config.reg_covar, "reg_covar")?;

        let mut data_var = Array1::zeros(n_features);
        for j in 0..n_features {
            let col_mean: F = data.column(j).sum() / n_f;
            let var: F = data
                .column(j)
                .iter()
                .map(|&x| (x - col_mean) * (x - col_mean))
                .sum::<F>()
                / n_f;
            data_var[j] = if var > F::zero() { var } else { F::one() };
        }

        for _i in 0..self.n_components {
            let mut cov = Array2::zeros((n_features, n_features));
            for j in 0..n_features {
                cov[[j, j]] = data_var[j] + reg;
            }
            covariances.push(cov);
        }
        Ok(covariances)
    }

    /// E-step: per-sample, per-component responsibilities computed via the
    /// log-sum-exp trick for numerical stability. If every component assigns
    /// -inf log-probability to a sample, its responsibilities fall back to
    /// uniform.
    fn e_step(
        &self,
        data: &ArrayView2<F>,
        weights: &Array1<F>,
        means: &Array2<F>,
        covariances: &[Array2<F>],
    ) -> StatsResult<Array2<F>> {
        let n_samples = data.shape()[0];
        let mut responsibilities = Array2::zeros((n_samples, self.n_components));

        for i in 0..n_samples {
            let sample = data.row(i);
            let mut log_probs = Array1::zeros(self.n_components);

            for k in 0..self.n_components {
                let mean = means.row(k);
                let log_prob = self.log_multivariate_normal_pdf(&sample, &mean, &covariances[k])?;
                log_probs[k] = weights[k].ln() + log_prob;
            }

            let max_lp = log_probs.iter().copied().fold(F::neg_infinity(), F::max);
            if max_lp == F::neg_infinity() {
                // Sample unsupported by every component: uniform fallback.
                let uni: F = f64_to_f(1.0 / self.n_components as f64, "uniform")?;
                for k in 0..self.n_components {
                    responsibilities[[i, k]] = uni;
                }
                continue;
            }
            let log_sum_exp = (log_probs.mapv(|x| (x - max_lp).exp()).sum()).ln() + max_lp;

            for k in 0..self.n_components {
                responsibilities[[i, k]] = (log_probs[k] - log_sum_exp).exp();
            }
        }
        Ok(responsibilities)
    }

    /// M-step for weights: mean responsibility per component.
    fn m_step_weights(&self, responsibilities: &Array2<F>) -> StatsResult<Array1<F>> {
        let n_f: F = f64_to_f(responsibilities.nrows() as f64, "n_samples_m")?;
        let mut weights = Array1::zeros(self.n_components);
        for k in 0..self.n_components {
            weights[k] = responsibilities.column(k).sum() / n_f;
        }
        Ok(weights)
    }

    /// M-step for means: responsibility-weighted sample averages.
    /// A component whose responsibility mass is below `eps` keeps a zero
    /// mean row (degenerate-component guard).
    fn m_step_means(
        &self,
        data: &ArrayView2<F>,
        responsibilities: &Array2<F>,
    ) -> StatsResult<Array2<F>> {
        let n_features = data.ncols();
        let mut means = Array2::zeros((self.n_components, n_features));
        let eps: F = f64_to_f(1e-10, "eps_m")?;

        for k in 0..self.n_components {
            let resp_sum = responsibilities.column(k).sum();
            if resp_sum > eps {
                for j in 0..n_features {
                    let weighted_sum: F = data
                        .column(j)
                        .iter()
                        .zip(responsibilities.column(k).iter())
                        .map(|(&x, &r)| x * r)
                        .sum();
                    means[[k, j]] = weighted_sum / resp_sum;
                }
            }
        }
        Ok(means)
    }

    /// M-step for covariances: responsibility-weighted outer products, plus
    /// `reg_covar` on the diagonal, then projected onto the configured
    /// covariance structure (Diagonal zeroes off-diagonals; Spherical
    /// replaces the matrix with mean-variance times identity).
    fn m_step_covariances(
        &self,
        data: &ArrayView2<F>,
        responsibilities: &Array2<F>,
        means: &Array2<F>,
    ) -> StatsResult<Vec<Array2<F>>> {
        let n_features = data.ncols();
        let mut covariances = Vec::with_capacity(self.n_components);
        let eps: F = f64_to_f(1e-10, "eps_cov")?;
        let reg: F = f64_to_f(self.config.reg_covar, "reg_covar")?;

        for k in 0..self.n_components {
            let resp_sum = responsibilities.column(k).sum();
            let mean_k = means.row(k);
            let mut cov = Array2::zeros((n_features, n_features));

            if resp_sum > eps {
                for i in 0..data.nrows() {
                    let diff = &data.row(i) - &mean_k;
                    let resp = responsibilities[[i, k]];
                    for j in 0..n_features {
                        for l in 0..n_features {
                            cov[[j, l]] = cov[[j, l]] + resp * diff[j] * diff[l];
                        }
                    }
                }
                cov = cov / resp_sum;
            }

            // Diagonal regularization for numerical stability.
            for i in 0..n_features {
                cov[[i, i]] = cov[[i, i]] + reg;
            }

            match self.config.covariance_type {
                CovarianceType::Diagonal => {
                    for i in 0..n_features {
                        for j in 0..n_features {
                            if i != j {
                                cov[[i, j]] = F::zero();
                            }
                        }
                    }
                }
                CovarianceType::Spherical => {
                    let n_feat_f: F = f64_to_f(n_features as f64, "n_feat")?;
                    let trace = cov.diag().sum() / n_feat_f;
                    cov = Array2::eye(n_features) * trace;
                }
                // Full/Tied/Factor/Constrained: matrix kept as computed.
                _ => {}
            }

            covariances.push(cov);
        }
        Ok(covariances)
    }

    /// Log-density of a multivariate normal at `x`. Computed in `f64` via an
    /// explicit determinant and inverse; returns -inf for a non-positive
    /// determinant instead of erroring.
    fn log_multivariate_normal_pdf(
        &self,
        x: &ArrayView1<F>,
        mean: &ArrayView1<F>,
        cov: &Array2<F>,
    ) -> StatsResult<F> {
        let d = x.len();
        let diff = x - mean;

        let cov_f64 = cov.mapv(|v| v.to_f64().unwrap_or(0.0));
        let det = scirs2_linalg::det(&cov_f64.view(), None).map_err(|e| {
            StatsError::ComputationError(format!("Determinant computation failed: {e}"))
        })?;

        if det <= 0.0 {
            return Ok(F::neg_infinity());
        }

        let log_det = det.ln();
        let cov_inv = scirs2_linalg::inv(&cov_f64.view(), None)
            .map_err(|e| StatsError::ComputationError(format!("Matrix inversion failed: {e}")))?;

        let diff_f64 = diff.mapv(|v| v.to_f64().unwrap_or(0.0));
        let quad_form = diff_f64.dot(&cov_inv.dot(&diff_f64));

        // log N(x) = -0.5 (d ln(2π) + ln|Σ| + (x-μ)ᵀ Σ⁻¹ (x-μ))
        let log_pdf = -0.5 * (d as f64 * (2.0 * std::f64::consts::PI).ln() + log_det + quad_form);
        f64_to_f(log_pdf, "log_pdf")
    }

    /// Total log-likelihood: sum of the per-sample log-likelihoods.
    fn compute_log_likelihood(
        &self,
        data: &ArrayView2<F>,
        weights: &Array1<F>,
        means: &Array2<F>,
        covariances: &[Array2<F>],
    ) -> StatsResult<F> {
        let per_sample = self.per_sample_log_likelihood(data, weights, means, covariances)?;
        Ok(per_sample.sum())
    }

    /// Per-sample mixture log-likelihood, computed with the log-sum-exp trick
    /// over the weighted component log-densities.
    fn per_sample_log_likelihood(
        &self,
        data: &ArrayView2<F>,
        weights: &Array1<F>,
        means: &Array2<F>,
        covariances: &[Array2<F>],
    ) -> StatsResult<Array1<F>> {
        let n_samples = data.nrows();
        let mut scores = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let sample = data.row(i);
            let mut log_probs = Array1::zeros(self.n_components);

            for k in 0..self.n_components {
                let mean = means.row(k);
                let log_prob = self.log_multivariate_normal_pdf(&sample, &mean, &covariances[k])?;
                log_probs[k] = weights[k].ln() + log_prob;
            }

            let max_lp = log_probs.iter().copied().fold(F::neg_infinity(), F::max);
            let log_sum_exp = (log_probs.mapv(|x| (x - max_lp).exp()).sum()).ln() + max_lp;
            scores[i] = log_sum_exp;
        }
        Ok(scores)
    }

    /// Squared Euclidean distance between two vectors.
    fn squared_distance(&self, a: &ArrayView1<F>, b: &ArrayView1<F>) -> F {
        a.iter()
            .zip(b.iter())
            .map(|(&x, &y)| (x - y) * (x - y))
            .sum()
    }
}
1037
1038fn cholesky_lower(a: &Array2<f64>) -> StatsResult<Array2<f64>> {
1043 let n = a.nrows();
1044 if n != a.ncols() {
1045 return Err(StatsError::DimensionMismatch(
1046 "Cholesky requires a square matrix".into(),
1047 ));
1048 }
1049 let mut l = Array2::<f64>::zeros((n, n));
1050 for i in 0..n {
1051 for j in 0..=i {
1052 let mut sum = 0.0;
1053 for k in 0..j {
1054 sum += l[[i, k]] * l[[j, k]];
1055 }
1056 if i == j {
1057 let diag = a[[i, i]] - sum;
1058 if diag <= 0.0 {
1059 l[[i, j]] = (diag.abs() + 1e-10).sqrt();
1060 } else {
1061 l[[i, j]] = diag.sqrt();
1062 }
1063 } else if l[[j, j]].abs() > 1e-300 {
1064 l[[i, j]] = (a[[i, j]] - sum) / l[[j, j]];
1065 }
1066 }
1067 }
1068 Ok(l)
1069}
1070
1071pub fn gaussian_mixture_model<F: GmmFloat>(
1077 data: &ArrayView2<F>,
1078 n_components: usize,
1079 config: Option<GMMConfig>,
1080) -> StatsResult<GMMParameters<F>> {
1081 let config = config.unwrap_or_default();
1082 let mut gmm = GaussianMixtureModel::new(n_components, config)?;
1083 Ok(gmm.fit(data)?.clone())
1084}
1085
1086pub fn gmm_model_selection<F: GmmFloat>(
1088 data: &ArrayView2<F>,
1089 min_components: usize,
1090 max_components: usize,
1091 config: Option<GMMConfig>,
1092) -> StatsResult<(usize, GMMParameters<F>)> {
1093 let config = config.unwrap_or_default();
1094 let mut best_n = min_components;
1095 let mut best_bic = F::infinity();
1096 let mut best_params: Option<GMMParameters<F>> = None;
1097
1098 for n_comp in min_components..=max_components {
1099 let mut gmm = GaussianMixtureModel::new(n_comp, config.clone())?;
1100 let params = gmm.fit(data)?;
1101
1102 if params.model_selection.bic < best_bic {
1103 best_bic = params.model_selection.bic;
1104 best_n = n_comp;
1105 best_params = Some(params.clone());
1106 }
1107 }
1108
1109 let params = best_params.ok_or_else(|| {
1110 StatsError::ComputationError("No valid model found during selection".into())
1111 })?;
1112 Ok((best_n, params))
1113}
1114
1115pub fn select_n_components<F: GmmFloat>(
1119 data: &ArrayView2<F>,
1120 max_k: usize,
1121 criterion: &str,
1122) -> StatsResult<(usize, Vec<f64>)> {
1123 if max_k == 0 {
1124 return Err(StatsError::InvalidArgument("max_k must be >= 1".into()));
1125 }
1126
1127 let mut scores = Vec::with_capacity(max_k);
1128 let mut best_k = 1usize;
1129 let mut best_score = f64::INFINITY;
1130
1131 for k in 1..=max_k {
1132 let config = GMMConfig {
1133 max_iter: 100,
1134 ..Default::default()
1135 };
1136 let mut gmm = GaussianMixtureModel::<F>::new(k, config)?;
1137 let params = gmm.fit(data)?;
1138
1139 let score_f64 = match criterion {
1140 "aic" | "AIC" => params.model_selection.aic.to_f64().unwrap_or(f64::INFINITY),
1141 _ => params.model_selection.bic.to_f64().unwrap_or(f64::INFINITY),
1142 };
1143
1144 scores.push(score_f64);
1145
1146 if score_f64 < best_score {
1147 best_score = score_f64;
1148 best_k = k;
1149 }
1150 }
1151
1152 Ok((best_k, scores))
1153}
1154
/// Outlier-aware wrapper around [`GaussianMixtureModel`].
///
/// After fitting, per-sample outlier scores (negative log-likelihoods) are
/// attached to the parameters and thresholded by `contamination`.
pub struct RobustGMM<F> {
    /// Underlying mixture model.
    pub gmm: GaussianMixtureModel<F>,
    /// Score threshold mirrored into the config at construction.
    pub outlier_threshold: F,
    /// Expected fraction of outliers used to set the adaptive threshold.
    pub contamination: F,
    // Ties the type parameter `F` to the struct.
    _phantom: PhantomData<F>,
}
1169
impl<F: GmmFloat> RobustGMM<F> {
    /// Builds a robust GMM, forcing `robust_em = true` on the given config
    /// and mirroring `outlier_threshold` into it (falling back to `0.01`
    /// when the threshold cannot be represented as `f64`).
    pub fn new(
        n_components: usize,
        outlier_threshold: F,
        contamination: F,
        mut config: GMMConfig,
    ) -> StatsResult<Self> {
        config.robust_em = true;
        config.outlier_threshold = outlier_threshold.to_f64().unwrap_or(0.01);

        let gmm = GaussianMixtureModel::new(n_components, config)?;
        Ok(Self {
            gmm,
            outlier_threshold,
            contamination,
            _phantom: PhantomData,
        })
    }

    /// Fits the underlying GMM, then computes and attaches per-sample
    /// outlier scores to the stored parameters.
    ///
    /// # Errors
    /// Propagates any fitting or scoring failure.
    pub fn fit(&mut self, data: &ArrayView2<F>) -> StatsResult<&GMMParameters<F>> {
        self.gmm.fit(data)?;
        let outlier_scores = self.compute_outlier_scores(data)?;

        if let Some(ref mut params) = self.gmm.parameters {
            params.outlier_scores = Some(outlier_scores);
        }

        self.gmm.parameters.as_ref().ok_or_else(|| {
            StatsError::ComputationError("Parameters not stored after robust fit".into())
        })
    }

    /// Outlier score per sample: the negated per-sample log-likelihood
    /// (higher score = less likely under the fitted mixture).
    fn compute_outlier_scores(&self, data: &ArrayView2<F>) -> StatsResult<Array1<F>> {
        let params = self.gmm.require_fitted()?;
        let per_sample_ll = self.gmm.per_sample_log_likelihood(
            data,
            &params.weights,
            &params.means,
            &params.covariances,
        )?;
        Ok(per_sample_ll.mapv(|x| -x))
    }

    /// Flags the top `contamination` fraction of stored outlier scores as
    /// outliers, using the `(1 - contamination)` quantile of the scores as
    /// the adaptive threshold. `_data` is unused; scores come from `fit`.
    ///
    /// # Errors
    /// Fails if the model is unfitted or no outlier scores were stored.
    pub fn detect_outliers(&self, _data: &ArrayView2<F>) -> StatsResult<Array1<bool>> {
        let params = self.gmm.require_fitted()?;

        let outlier_scores = params.outlier_scores.as_ref().ok_or_else(|| {
            StatsError::InvalidArgument("Robust EM must be enabled for outlier detection".into())
        })?;

        // Sort scores ascending; NaN-incomparable pairs compare as equal.
        let mut sorted: Vec<F> = outlier_scores.iter().copied().collect();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        // Index of the (1 - contamination) quantile, clamped into range.
        let threshold_idx_f =
            (F::one() - self.contamination) * f64_to_f(sorted.len() as f64, "sorted_len")?;
        let threshold_idx = threshold_idx_f
            .to_usize()
            .unwrap_or(sorted.len().saturating_sub(1))
            .min(sorted.len().saturating_sub(1));
        let adaptive_threshold = sorted[threshold_idx];

        let outliers = outlier_scores.mapv(|score| score > adaptive_threshold);
        Ok(outliers)
    }
}
1238
/// Online / mini-batch wrapper around [`GaussianMixtureModel`].
///
/// The first batch performs a full fit; subsequent batches blend batch
/// statistics into running estimates via exponential decay.
pub struct StreamingGMM<F> {
    /// Underlying mixture model (holds the current parameters).
    pub gmm: GaussianMixtureModel<F>,
    /// Weight given to each new batch's statistics.
    pub learning_rate: F,
    /// Multiplier applied to the running estimates before blending.
    pub decay_factor: F,
    /// Total number of samples processed so far.
    pub n_samples_seen: usize,
    /// Running mean estimates; `None` until the first batch.
    pub running_means: Option<Array2<F>>,
    /// Running covariance estimates; `None` until the first batch.
    /// NOTE(review): tracked but not updated by `online_update` yet.
    pub running_covariances: Option<Vec<Array2<F>>>,
    /// Running mixing-weight estimates; `None` until the first batch.
    pub running_weights: Option<Array1<F>>,
    // Ties the type parameter `F` to the struct.
    _phantom: PhantomData<F>,
}
1259
impl<F: GmmFloat> StreamingGMM<F> {
    /// Creates a streaming GMM with the given blending coefficients; no
    /// running state exists until the first `partial_fit`.
    pub fn new(
        n_components: usize,
        learning_rate: F,
        decay_factor: F,
        config: GMMConfig,
    ) -> StatsResult<Self> {
        let gmm = GaussianMixtureModel::new(n_components, config)?;
        Ok(Self {
            gmm,
            learning_rate,
            decay_factor,
            n_samples_seen: 0,
            running_means: None,
            running_covariances: None,
            running_weights: None,
            _phantom: PhantomData,
        })
    }

    /// Processes one batch: a full EM fit on the first batch (seeding the
    /// running estimates), an incremental blend on subsequent batches.
    ///
    /// # Errors
    /// Propagates fitting / E-step failures.
    pub fn partial_fit(&mut self, batch: &ArrayView2<F>) -> StatsResult<()> {
        let batchsize = batch.nrows();

        if self.n_samples_seen == 0 {
            self.gmm.fit(batch)?;
            let params = self.gmm.require_fitted()?;
            self.running_means = Some(params.means.clone());
            self.running_covariances = Some(params.covariances.clone());
            self.running_weights = Some(params.weights.clone());
        } else {
            self.online_update(batch)?;
        }

        self.n_samples_seen += batchsize;
        Ok(())
    }

    /// Blends batch statistics into the running weights and means:
    /// `running = decay * running + lr * batch`, with weights re-normalized
    /// to sum to one. Covariances are currently left unchanged.
    fn online_update(&mut self, batch: &ArrayView2<F>) -> StatsResult<()> {
        let params = self.gmm.require_fitted()?;

        let responsibilities =
            self.gmm
                .e_step(batch, &params.weights, &params.means, &params.covariances)?;

        let batch_weights = self.gmm.m_step_weights(&responsibilities)?;
        let batch_means = self.gmm.m_step_means(batch, &responsibilities)?;

        let lr = self.learning_rate;
        let decay = self.decay_factor;

        if let (Some(ref mut r_weights), Some(ref mut r_means)) =
            (&mut self.running_weights, &mut self.running_means)
        {
            *r_weights = r_weights.mapv(|x| x * decay) + batch_weights.mapv(|x| x * lr);
            // Re-normalize so the blended weights stay a distribution.
            let weight_sum = r_weights.sum();
            if weight_sum > F::zero() {
                *r_weights = r_weights.mapv(|x| x / weight_sum);
            }
            *r_means = r_means.mapv(|x| x * decay) + batch_means.mapv(|x| x * lr);
        }

        // Publish the running estimates back into the stored parameters.
        if let Some(ref mut p) = self.gmm.parameters {
            if let Some(ref rw) = self.running_weights {
                p.weights = rw.clone();
            }
            if let Some(ref rm) = self.running_means {
                p.means = rm.clone();
            }
        }

        Ok(())
    }

    /// Current fitted parameters, if any batch has been processed.
    pub fn get_parameters(&self) -> Option<&GMMParameters<F>> {
        self.gmm.parameters.as_ref()
    }
}
1340
1341pub fn hierarchical_gmm_init<F: GmmFloat>(
1347 data: &ArrayView2<F>,
1348 n_components: usize,
1349 config: GMMConfig,
1350) -> StatsResult<GMMParameters<F>> {
1351 let mut init_config = config;
1352 init_config.init_method = InitializationMethod::FurthestFirst;
1353 gaussian_mixture_model(data, n_components, Some(init_config))
1354}
1355
1356pub fn gmm_cross_validation<F: GmmFloat>(
1362 data: &ArrayView2<F>,
1363 n_components: usize,
1364 n_folds: usize,
1365 config: GMMConfig,
1366) -> StatsResult<F> {
1367 let n_samples = data.nrows();
1368 if n_folds < 2 || n_folds > n_samples {
1369 return Err(StatsError::InvalidArgument(format!(
1370 "n_folds ({n_folds}) must be in [2, n_samples ({n_samples})]"
1371 )));
1372 }
1373 let foldsize = n_samples / n_folds;
1374 let mut cv_scores = Vec::with_capacity(n_folds);
1375
1376 for fold in 0..n_folds {
1377 let val_start = fold * foldsize;
1378 let val_end = if fold == n_folds - 1 {
1379 n_samples
1380 } else {
1381 (fold + 1) * foldsize
1382 };
1383
1384 let mut train_indices = Vec::new();
1385 for i in 0..n_samples {
1386 if i < val_start || i >= val_end {
1387 train_indices.push(i);
1388 }
1389 }
1390
1391 let traindata = Array2::from_shape_fn((train_indices.len(), data.ncols()), |(i, j)| {
1392 data[[train_indices[i], j]]
1393 });
1394 let valdata = data.slice(s![val_start..val_end, ..]);
1395
1396 let mut gmm = GaussianMixtureModel::new(n_components, config.clone())?;
1397 let params = gmm.fit(&traindata.view())?.clone();
1398
1399 let val_ll = gmm.compute_log_likelihood(
1400 &valdata,
1401 ¶ms.weights,
1402 ¶ms.means,
1403 ¶ms.covariances,
1404 )?;
1405 cv_scores.push(val_ll);
1406 }
1407
1408 let n_folds_f: F = f64_to_f(cv_scores.len() as f64, "cv_n")?;
1409 let avg_score: F = cv_scores.iter().copied().sum::<F>() / n_folds_f;
1410 Ok(avg_score)
1411}
1412
1413pub fn benchmark_mixture_models<F: GmmFloat>(
1419 data: &ArrayView2<F>,
1420 methods: &[(
1421 &str,
1422 Box<dyn Fn(&ArrayView2<F>) -> StatsResult<GMMParameters<F>>>,
1423 )],
1424) -> StatsResult<Vec<(String, std::time::Duration, F)>> {
1425 let mut results = Vec::new();
1426 for (name, method) in methods {
1427 let start_time = std::time::Instant::now();
1428 let params = method(data)?;
1429 let duration = start_time.elapsed();
1430 results.push((name.to_string(), duration, params.log_likelihood));
1431 }
1432 Ok(results)
1433}
1434
1435#[cfg(test)]
1440mod tests;