1use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis, ScalarOperand};
7use scirs2_core::numeric::{Float, FromPrimitive};
8use scirs2_core::random::{Rng, SeedableRng};
9use std::f64::consts::PI;
10use std::fmt::Debug;
11use std::iter::Sum;
12
13use crate::error::{ClusteringError, Result};
14use crate::vq::kmeans_plus_plus;
15use statrs::statistics::Statistics;
16
/// Model parameters as a tuple: (mixture weights, component means, per-component covariances).
type GMMParams<F> = (Array1<F>, Array2<F>, Vec<Array2<F>>);

/// Result of a single EM run:
/// (weights, means, covariances, final lower bound, iterations used, converged flag).
type GMMFitResult<F> = (Array1<F>, Array2<F>, Vec<Array2<F>>, F, usize, bool);
22
/// Shape constraint placed on each component's covariance matrix.
#[derive(Debug, Clone, Copy)]
pub enum CovarianceType {
    /// Each component has its own unconstrained covariance matrix.
    Full,
    /// Each component has a diagonal covariance matrix.
    Diagonal,
    /// All components share one covariance matrix.
    /// NOTE(review): the M-step currently updates `Tied` with the same
    /// per-component diagonal rule as `Diagonal` — confirm intent.
    Tied,
    /// Each component has a single variance shared by all features.
    Spherical,
}
35
/// Strategy used to pick the initial component means before running EM.
#[derive(Debug, Clone, Copy)]
pub enum GMMInit {
    /// Seed means with k-means++ centroid selection.
    KMeans,
    /// Seed means with randomly chosen data points.
    Random,
}
44
/// Configuration for fitting a Gaussian mixture model.
#[derive(Debug, Clone)]
pub struct GMMOptions<F: Float> {
    /// Number of mixture components.
    pub n_components: usize,
    /// Constraint on the component covariance matrices.
    pub covariance_type: CovarianceType,
    /// Convergence threshold on the change of the lower bound between EM iterations.
    pub tol: F,
    /// Maximum number of EM iterations per initialization.
    pub max_iter: usize,
    /// Number of independent initializations; the run with the highest lower bound is kept.
    pub n_init: usize,
    /// Strategy used to initialize the component means.
    pub init_method: GMMInit,
    /// Seed for reproducible initialization; `None` draws a fresh seed.
    pub random_seed: Option<u64>,
    /// Non-negative value added to covariance diagonals for numerical stability.
    pub reg_covar: F,
}
65
66impl<F: Float + FromPrimitive> Default for GMMOptions<F> {
67 fn default() -> Self {
68 Self {
69 n_components: 1,
70 covariance_type: CovarianceType::Full,
71 tol: F::from(1e-3).unwrap(),
72 max_iter: 100,
73 n_init: 1,
74 init_method: GMMInit::KMeans,
75 random_seed: None,
76 reg_covar: F::from(1e-6).unwrap(),
77 }
78 }
79}
80
/// Gaussian mixture model; the `Option` fields are populated by `fit`.
pub struct GaussianMixture<F: Float> {
    /// Fitting configuration.
    options: GMMOptions<F>,
    /// Mixture weights (one per component); `None` until fitted.
    weights: Option<Array1<F>>,
    /// Component means, one row per component; `None` until fitted.
    means: Option<Array2<F>>,
    /// Per-component covariance matrices; `None` until fitted.
    covariances: Option<Vec<Array2<F>>>,
    /// Final lower bound (per-sample average log-evidence) of the best run.
    lower_bound: Option<F>,
    /// Number of EM iterations used by the best run.
    n_iter: Option<usize>,
    /// Whether the best run met the tolerance before `max_iter`.
    converged: bool,
}
98
99impl<F: Float + FromPrimitive + Debug + ScalarOperand + Sum + std::borrow::Borrow<f64>>
100 GaussianMixture<F>
101{
102 pub fn new(options: GMMOptions<F>) -> Self {
104 Self {
105 options,
106 weights: None,
107 means: None,
108 covariances: None,
109 lower_bound: None,
110 n_iter: None,
111 converged: false,
112 }
113 }
114
115 pub fn fit(&mut self, data: ArrayView2<F>) -> Result<()> {
117 let n_samples = data.shape()[0];
118 let _n_features = data.shape()[1];
119
120 if n_samples < self.options.n_components {
121 return Err(ClusteringError::InvalidInput(
122 "Number of samples must be >= number of components".to_string(),
123 ));
124 }
125
126 let mut best_lower_bound = F::neg_infinity();
127 let mut best_params = None;
128
129 for _ in 0..self.options.n_init {
131 let (weights, means, covariances, lower_bound, n_iter, converged) =
132 self.fit_single(data)?;
133
134 if lower_bound > best_lower_bound {
135 best_lower_bound = lower_bound;
136 best_params = Some((weights, means, covariances, lower_bound, n_iter, converged));
137 }
138 }
139
140 if let Some((weights, means, covariances, lower_bound, n_iter, converged)) = best_params {
141 self.weights = Some(weights);
142 self.means = Some(means);
143 self.covariances = Some(covariances);
144 self.lower_bound = Some(lower_bound);
145 self.n_iter = Some(n_iter);
146 self.converged = converged;
147 }
148
149 Ok(())
150 }
151
    /// Runs one EM fit from a fresh initialization.
    ///
    /// Returns `(weights, means, covariances, lower_bound, n_iter, converged)`.
    /// Convergence is declared when the absolute change of the lower bound
    /// between consecutive iterations drops below `tol`.
    fn fit_single(&self, data: ArrayView2<F>) -> Result<GMMFitResult<F>> {
        let _n_samples = data.shape()[0];
        let _n_features = data.shape()[1];
        let _n_components = self.options.n_components;

        let (mut weights, mut means, mut covariances) = self.initialize_params(data)?;

        let mut lower_bound = F::neg_infinity();
        let mut converged = false;

        for iter in 0..self.options.max_iter {
            // E-step: responsibilities and the lower bound for the current parameters.
            let (resp_, new_lower_bound) = self.e_step(data, &weights, &means, &covariances)?;

            // Convergence is checked BEFORE the M-step so the returned
            // parameters are exactly the ones the bound was evaluated at.
            let change = (new_lower_bound - lower_bound).abs();
            if change < self.options.tol {
                converged = true;
                return Ok((
                    weights,
                    means,
                    covariances,
                    new_lower_bound,
                    iter + 1,
                    converged,
                ));
            }
            lower_bound = new_lower_bound;

            // M-step: re-estimate parameters from the responsibilities.
            (weights, means, covariances) = self.m_step(data, resp_)?;
        }

        // Iteration budget exhausted without meeting the tolerance.
        Ok((
            weights,
            means,
            covariances,
            lower_bound,
            self.options.max_iter,
            converged,
        ))
    }
196
197 fn initialize_params(&self, data: ArrayView2<F>) -> Result<GMMParams<F>> {
199 let n_samples = data.shape()[0];
200 let n_features = data.shape()[1];
201 let n_components = self.options.n_components;
202
203 let weights = Array1::from_elem(n_components, F::one() / F::from(n_components).unwrap());
205
206 let means = match self.options.init_method {
208 GMMInit::KMeans => {
209 kmeans_plus_plus(data, n_components, self.options.random_seed)?
211 }
212 GMMInit::Random => {
213 let mut rng = match self.options.random_seed {
215 Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
216 None => scirs2_core::random::rngs::StdRng::seed_from_u64(
217 scirs2_core::random::rng().random::<u64>(),
218 ),
219 };
220
221 let mut means = Array2::zeros((n_components, n_features));
222 for i in 0..n_components {
223 let idx = rng.random_range(0..n_samples);
224 means.slice_mut(s![i, ..]).assign(&data.slice(s![idx, ..]));
225 }
226 means
227 }
228 };
229
230 let mut covariances = Vec::with_capacity(n_components);
232
233 let data_mean = data.mean_axis(Axis(0)).unwrap();
235 let mut variance = Array1::<F>::zeros(n_features);
236
237 for i in 0..n_samples {
238 let diff = &data.slice(s![i, ..]) - &data_mean;
239 variance = variance + &diff.mapv(|x| x * x);
240 }
241 variance = variance / F::from(n_samples - 1).unwrap();
242
243 match self.options.covariance_type {
244 CovarianceType::Spherical => {
245 let avg_variance = variance.sum() / F::from(variance.len()).unwrap();
246 for _ in 0..n_components {
247 let mut cov = Array2::<F>::zeros((n_features, n_features));
248 for i in 0..n_features {
249 cov[[i, i]] = avg_variance;
250 }
251 covariances.push(cov);
252 }
253 }
254 CovarianceType::Diagonal => {
255 for _ in 0..n_components {
256 let mut cov = Array2::<F>::zeros((n_features, n_features));
257 for i in 0..n_features {
258 cov[[i, i]] = variance[i];
259 }
260 covariances.push(cov);
261 }
262 }
263 CovarianceType::Full | CovarianceType::Tied => {
264 for _ in 0..n_components {
266 let mut cov = Array2::<F>::zeros((n_features, n_features));
267 for i in 0..n_features {
268 cov[[i, i]] = variance[i];
269 }
270 covariances.push(cov);
271 }
272 }
273 }
274
275 Ok((weights, means, covariances))
276 }
277
    /// E-step: computes the responsibility matrix (`n_samples` x `n_components`)
    /// and the lower bound (mean over samples of the per-sample log-evidence).
    fn e_step(
        &self,
        data: ArrayView2<F>,
        weights: &Array1<F>,
        means: &Array2<F>,
        covariances: &[Array2<F>],
    ) -> Result<(Array2<F>, F)> {
        let n_samples = data.shape()[0];
        let n_components = self.options.n_components;

        let mut log_prob = Array2::zeros((n_samples, n_components));

        // Per-component log-density of every sample.
        for (k, covariance) in covariances.iter().enumerate().take(n_components) {
            let log_prob_k = self.log_multivariate_normal_density(
                data,
                means.slice(s![k, ..]).view(),
                covariance,
            )?;
            log_prob.slice_mut(s![.., k]).assign(&log_prob_k);
        }

        // Add the log mixture weights: log p(x, k) = log w_k + log N(x | k).
        for k in 0..n_components {
            let log_weight = weights[k].ln();
            log_prob
                .slice_mut(s![.., k])
                .mapv_inplace(|x| x + log_weight);
        }

        // Per-sample log-evidence via a numerically stable log-sum-exp over components.
        let log_prob_norm = self.logsumexp(log_prob.view(), Axis(1))?;

        // Responsibilities: normalize in log space, then exponentiate.
        let mut resp_ = log_prob.clone();
        for i in 0..n_samples {
            for k in 0..n_components {
                resp_[[i, k]] = (resp_[[i, k]] - log_prob_norm[i]).exp();
            }
        }

        let lower_bound = log_prob_norm.sum() / F::from(log_prob_norm.len()).unwrap();

        Ok((resp_, lower_bound))
    }
325
    /// M-step: re-estimates weights, means and covariances from the
    /// responsibilities produced by the E-step.
    ///
    /// NOTE(review): only `Full` performs a dense covariance update; `Diagonal`,
    /// `Spherical` and `Tied` all share the per-component diagonal update below,
    /// so `Tied` covariances are not actually shared across components here —
    /// confirm whether that is intended.
    fn m_step(&self, data: ArrayView2<F>, resp_: Array2<F>) -> Result<GMMParams<F>> {
        let n_samples = data.shape()[0];
        let n_features = data.shape()[1];
        let n_components = self.options.n_components;

        // nk[k] = effective number of samples assigned to component k.
        let nk = resp_.sum_axis(Axis(0));
        let weights = &nk / F::from(n_samples).unwrap();

        // Means: responsibility-weighted average of the samples.
        let mut means = Array2::zeros((n_components, n_features));
        for k in 0..n_components {
            let mut mean_k = Array1::zeros(n_features);
            for i in 0..n_samples {
                mean_k = mean_k + &data.slice(s![i, ..]) * resp_[[i, k]];
            }
            means.slice_mut(s![k, ..]).assign(&(&mean_k / nk[k]));
        }

        let mut covariances = Vec::with_capacity(n_components);

        match self.options.covariance_type {
            CovarianceType::Full => {
                for k in 0..n_components {
                    let mean_k = means.slice(s![k, ..]);
                    let mut cov = Array2::zeros((n_features, n_features));

                    // Responsibility-weighted outer products of the deviations.
                    for i in 0..n_samples {
                        let diff = &data.slice(s![i, ..]) - &mean_k;
                        let outer = self.outer_product(diff.view(), diff.view());
                        cov = cov + &outer * resp_[[i, k]];
                    }

                    cov = cov / nk[k];
                    // Regularize the diagonal to keep the matrix well-conditioned.
                    for i in 0..n_features {
                        cov[[i, i]] = cov[[i, i]] + self.options.reg_covar;
                    }

                    covariances.push(cov);
                }
            }
            _ => {
                // Diagonal-only update used for Diagonal, Spherical and Tied.
                for k in 0..n_components {
                    let mean_k = means.slice(s![k, ..]);
                    let mut cov = Array2::zeros((n_features, n_features));

                    for i in 0..n_samples {
                        let diff = &data.slice(s![i, ..]) - &mean_k;
                        for j in 0..n_features {
                            cov[[j, j]] = cov[[j, j]] + diff[j] * diff[j] * resp_[[i, k]];
                        }
                    }

                    // Normalize and regularize each diagonal entry.
                    for j in 0..n_features {
                        cov[[j, j]] = cov[[j, j]] / nk[k] + self.options.reg_covar;
                    }

                    covariances.push(cov);
                }
            }
        }

        Ok((weights, means, covariances))
    }
394
    /// Log-density of each sample under a Gaussian with the given mean and
    /// covariance.
    ///
    /// NOTE(review): only the DIAGONAL of `covariance` is used, both for the
    /// log-determinant and the Mahalanobis term, so for `Full` covariances any
    /// off-diagonal structure is ignored — confirm this approximation is intended.
    fn log_multivariate_normal_density(
        &self,
        data: ArrayView2<F>,
        mean: ArrayView1<F>,
        covariance: &Array2<F>,
    ) -> Result<Array1<F>> {
        let n_samples = data.shape()[0];
        let n_features = data.shape()[1];

        let mut log_prob = Array1::zeros(n_samples);

        // log|Sigma| for a diagonal matrix: sum of logs of the diagonal entries.
        let mut log_det = F::zero();
        for i in 0..n_features {
            log_det = log_det + covariance[[i, i]].ln();
        }

        // d * ln(2*pi) + log|Sigma|, shared by every sample.
        let norm_const = F::from(n_features as f64 * (2.0 * PI).ln()).unwrap() + log_det;

        for i in 0..n_samples {
            let diff = &data.slice(s![i, ..]) - &mean;
            let mut mahalanobis = F::zero();

            // Diagonal Mahalanobis distance: sum_j diff_j^2 / sigma_jj.
            for j in 0..n_features {
                mahalanobis = mahalanobis + diff[j] * diff[j] / covariance[[j, j]];
            }

            log_prob[i] = F::from(-0.5).unwrap() * (norm_const + mahalanobis);
        }

        Ok(log_prob)
    }
430
431 fn logsumexp(&self, arr: ArrayView2<F>, axis: Axis) -> Result<Array1<F>> {
433 let max_vals = arr.fold_axis(axis, F::neg_infinity(), |&a, &b| a.max(b));
434 let mut result = Array1::zeros(max_vals.len());
435
436 match axis {
437 Axis(1) => {
438 for i in 0..arr.shape()[0] {
439 let mut sum = F::zero();
440 for j in 0..arr.shape()[1] {
441 sum = sum + (arr[[i, j]] - max_vals[i]).exp();
442 }
443 result[i] = max_vals[i] + sum.ln();
444 }
445 }
446 _ => {
447 return Err(ClusteringError::InvalidInput(
448 "Only axis 1 is supported for logsumexp".to_string(),
449 ));
450 }
451 }
452
453 Ok(result)
454 }
455
456 fn outer_product(&self, a: ArrayView1<F>, b: ArrayView1<F>) -> Array2<F> {
458 let n = a.len();
459 let m = b.len();
460 let mut result = Array2::zeros((n, m));
461
462 for i in 0..n {
463 for j in 0..m {
464 result[[i, j]] = a[i] * b[j];
465 }
466 }
467
468 result
469 }
470
471 pub fn predict(&self, data: ArrayView2<F>) -> Result<Array1<i32>> {
473 if self.weights.is_none() || self.means.is_none() || self.covariances.is_none() {
474 return Err(ClusteringError::InvalidInput(
475 "Model has not been fitted yet".to_string(),
476 ));
477 }
478
479 let weights = self.weights.as_ref().unwrap();
480 let means = self.means.as_ref().unwrap();
481 let covariances = self.covariances.as_ref().unwrap();
482
483 let (resp__, _) = self.e_step(data, weights, means, covariances)?;
484
485 let mut labels = Array1::zeros(data.shape()[0]);
487 for i in 0..data.shape()[0] {
488 let mut max_resp_ = F::neg_infinity();
489 let mut best_k = 0;
490
491 for k in 0..self.options.n_components {
492 if resp__[[i, k]] > max_resp_ {
493 max_resp_ = resp__[[i, k]];
494 best_k = k;
495 }
496 }
497
498 labels[i] = best_k as i32;
499 }
500
501 Ok(labels)
502 }
503}
504
/// Convenience wrapper: fits a Gaussian mixture to `data` with `options` and
/// returns the hard component assignment of each sample.
///
/// # Errors
///
/// Propagates any fitting or prediction error from `GaussianMixture`.
#[allow(dead_code)]
pub fn gaussian_mixture<F>(data: ArrayView2<F>, options: GMMOptions<F>) -> Result<Array1<i32>>
where
    F: Float + FromPrimitive + Debug + ScalarOperand + Sum + std::borrow::Borrow<f64>,
{
    let mut gmm = GaussianMixture::new(options);
    gmm.fit(data)?;
    gmm.predict(data)
}
547
548#[cfg(test)]
549mod tests {
550 use super::*;
551 use scirs2_core::ndarray::Array2;
552
    /// Smoke test: two well-separated clusters fit without error and every
    /// sample gets a label from at most two components.
    #[test]
    fn test_gmm_simple() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .unwrap();

        let options = GMMOptions {
            n_components: 2,
            max_iter: 10,
            ..Default::default()
        };

        let result = gaussian_mixture(data.view(), options);
        assert!(result.is_ok());

        let labels = result.unwrap();
        assert_eq!(labels.len(), 6);

        // No more distinct labels than components requested.
        let unique_labels: std::collections::HashSet<_> = labels.iter().cloned().collect();
        assert!(unique_labels.len() <= 2);
    }
577
    /// Every covariance type must fit the same data successfully.
    #[test]
    fn test_gmm_different_covariance_types() {
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.1, 1.1, 0.9, 0.9, 1.2, 0.8, 5.0, 5.0, 5.1, 5.1, 4.9, 4.9, 5.2, 4.8,
            ],
        )
        .unwrap();

        let covariance_types = vec![
            CovarianceType::Full,
            CovarianceType::Diagonal,
            CovarianceType::Spherical,
            CovarianceType::Tied,
        ];

        for cov_type in covariance_types {
            let options = GMMOptions {
                n_components: 2,
                covariance_type: cov_type,
                max_iter: 50,
                ..Default::default()
            };

            let result = gaussian_mixture(data.view(), options);
            assert!(
                result.is_ok(),
                "Failed with covariance type: {:?}",
                cov_type
            );

            let labels = result.unwrap();
            assert_eq!(labels.len(), 8);
        }
    }
614
    /// Both initialization strategies (k-means++ and random) must succeed.
    #[test]
    fn test_gmm_initialization_methods() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .unwrap();

        let init_methods = vec![GMMInit::KMeans, GMMInit::Random];

        for init_method in init_methods {
            let options = GMMOptions {
                n_components: 2,
                init_method,
                // Fixed seed keeps the random init deterministic.
                random_seed: Some(42),
                max_iter: 20,
                ..Default::default()
            };

            let result = gaussian_mixture(data.view(), options);
            assert!(result.is_ok(), "Failed with init method: {:?}", init_method);

            let labels = result.unwrap();
            assert_eq!(labels.len(), 6);
        }
    }
641
642 #[test]
643 fn test_gmm_parameter_validation() {
644 let data =
645 Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0]).unwrap();
646
647 let options = GMMOptions {
649 n_components: 0,
650 ..Default::default()
651 };
652 let result = gaussian_mixture(data.view(), options);
653 assert!(result.is_err());
654
655 let options = GMMOptions {
657 n_components: 10,
658 max_iter: 5, ..Default::default()
660 };
661 let result = gaussian_mixture(data.view(), options);
662 let _result = result;
665 }
666
    /// Fitting must succeed across a range of tolerances, from loose to very tight.
    #[test]
    fn test_gmm_convergence_criteria() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .unwrap();

        let tolerances = vec![1e-3, 1e-6, 1e-9];

        for tol in tolerances {
            let options = GMMOptions {
                n_components: 2,
                tol,
                max_iter: 100,
                ..Default::default()
            };

            let result = gaussian_mixture(data.view(), options);
            assert!(result.is_ok(), "Failed with tolerance: {}", tol);
        }
    }
690
    /// With a single component, every sample must be assigned label 0.
    #[test]
    fn test_gmm_single_component() {
        let data =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 1.1, 2.1]).unwrap();

        let options = GMMOptions {
            n_components: 1,
            max_iter: 20,
            ..Default::default()
        };

        let result = gaussian_mixture(data.view(), options);
        assert!(result.is_ok());

        let labels = result.unwrap();
        assert_eq!(labels.len(), 4);

        // Only one component exists, so only label 0 is possible.
        assert!(labels.iter().all(|&l| l == 0));
    }
711
    /// Two fits with the same seed should produce structurally equivalent
    /// results (same number of labels and same number of distinct clusters;
    /// exact label identity is not compared since component order can permute).
    #[test]
    fn test_gmm_reproducibility_with_seed() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .unwrap();

        let options1 = GMMOptions {
            n_components: 2,
            random_seed: Some(42),
            max_iter: 50,
            ..Default::default()
        };

        let options2 = GMMOptions {
            n_components: 2,
            random_seed: Some(42),
            max_iter: 50,
            ..Default::default()
        };

        let labels1 = gaussian_mixture(data.view(), options1).unwrap();
        let labels2 = gaussian_mixture(data.view(), options2).unwrap();

        assert_eq!(labels1.len(), labels2.len());

        let unique1: std::collections::HashSet<_> = labels1.iter().cloned().collect();
        let unique2: std::collections::HashSet<_> = labels2.iter().cloned().collect();
        assert_eq!(unique1.len(), unique2.len());
    }
746
    /// Three components on roughly three clusters: fitting succeeds and the
    /// labels use at most three (and at least one) distinct values.
    #[test]
    fn test_gmm_many_components() {
        let data = Array2::from_shape_vec(
            (10, 2),
            vec![
                1.0, 1.0, 1.1, 1.1, 1.2, 1.2, 3.0, 3.0, 3.1, 3.1, 3.2, 3.2, 5.0, 5.0, 5.1, 5.1,
                5.2, 5.2, 7.0, 7.0,
            ],
        )
        .unwrap();

        let options = GMMOptions {
            n_components: 3,
            max_iter: 50,
            ..Default::default()
        };

        let result = gaussian_mixture(data.view(), options);
        assert!(result.is_ok());

        let labels = result.unwrap();
        assert_eq!(labels.len(), 10);

        let unique_labels: std::collections::HashSet<_> = labels.iter().cloned().collect();
        assert!(unique_labels.len() <= 3);
        assert!(!unique_labels.is_empty());
    }
775
    /// Fitting must succeed for several covariance regularization strengths.
    #[test]
    fn test_gmm_regularization() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .unwrap();

        let reg_values = vec![1e-6, 1e-3, 1e-1];

        for reg_covar in reg_values {
            let options = GMMOptions {
                n_components: 2,
                reg_covar,
                max_iter: 20,
                ..Default::default()
            };

            let result = gaussian_mixture(data.view(), options);
            assert!(result.is_ok(), "Failed with reg_covar: {}", reg_covar);
        }
    }
799
    /// Exercises the explicit fit-then-predict workflow, including predicting
    /// on data that was not part of the training set.
    #[test]
    fn test_gmm_fit_predict_workflow() {
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.1, 1.1, 0.9, 0.9, 1.2, 0.8, 5.0, 5.0, 5.1, 5.1, 4.9, 4.9, 5.2, 4.8,
            ],
        )
        .unwrap();

        let options = GMMOptions {
            n_components: 2,
            max_iter: 50,
            random_seed: Some(42),
            ..Default::default()
        };

        let mut gmm = GaussianMixture::new(options);

        let fit_result = gmm.fit(data.view());
        assert!(fit_result.is_ok());

        let predict_result = gmm.predict(data.view());
        assert!(predict_result.is_ok());

        let labels = predict_result.unwrap();
        assert_eq!(labels.len(), 8);

        // Predicting on unseen points also works once the model is fitted.
        let new_data = Array2::from_shape_vec((2, 2), vec![1.0, 1.0, 5.0, 5.0]).unwrap();

        let new_labels = gmm.predict(new_data.view());
        assert!(new_labels.is_ok());
        assert_eq!(new_labels.unwrap().len(), 2);
    }
838}