scirs2_cluster/gmm.rs

//! Gaussian Mixture Models (GMM) for clustering
//!
//! This module implements Gaussian Mixture Models, a probabilistic model that assumes
//! data is generated from a mixture of a finite number of Gaussian distributions.
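//!
//! A GMM models the density as `p(x) = Σ_k π_k · N(x | μ_k, Σ_k)`, where the
//! mixture weights `π_k` sum to one. Fitting uses the Expectation-Maximization
//! (EM) algorithm: the E-step computes per-sample component responsibilities,
//! and the M-step re-estimates weights, means, and covariances from them.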

use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::random::{Rng, SeedableRng};
use std::f64::consts::PI;
use std::fmt::Debug;
use std::iter::Sum;

use crate::error::{ClusteringError, Result};
use crate::vq::kmeans_plus_plus;

/// Type alias for GMM parameters
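/// `(weights, means, covariances)`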
type GMMParams<F> = (Array1<F>, Array2<F>, Vec<Array2<F>>);

/// Type alias for GMM fit result
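/// `(weights, means, covariances, lower_bound, n_iter, converged)`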
type GMMFitResult<F> = (Array1<F>, Array2<F>, Vec<Array2<F>>, F, usize, bool);

/// Covariance type for GMM
#[derive(Debug, Clone, Copy)]
pub enum CovarianceType {
    /// Each component has its own general covariance matrix
    Full,
    /// Each component has its own diagonal covariance matrix
    Diagonal,
    /// All components share the same general covariance matrix
    Tied,
    /// Each component has a single variance shared across features (spherical)
    Spherical,
}

/// GMM initialization method
#[derive(Debug, Clone, Copy)]
pub enum GMMInit {
    /// Initialize using K-means++
    KMeans,
    /// Random initialization
    Random,
}

/// Options for Gaussian Mixture Model
#[derive(Debug, Clone)]
pub struct GMMOptions<F: Float> {
    /// Number of mixture components
    pub n_components: usize,
    /// Type of covariance parameters
    pub covariance_type: CovarianceType,
    /// Convergence threshold
    pub tol: F,
    /// Maximum number of iterations
    pub max_iter: usize,
    /// Number of initializations to perform
    pub n_init: usize,
    /// Initialization method
    pub init_method: GMMInit,
    /// Random seed
    pub random_seed: Option<u64>,
    /// Regularization added to the diagonal of covariance matrices
    pub reg_covar: F,
}

impl<F: Float + FromPrimitive> Default for GMMOptions<F> {
    fn default() -> Self {
        Self {
            n_components: 1,
            covariance_type: CovarianceType::Full,
            tol: F::from(1e-3).unwrap(),
            max_iter: 100,
            n_init: 1,
            init_method: GMMInit::KMeans,
            random_seed: None,
            reg_covar: F::from(1e-6).unwrap(),
        }
    }
}

/// Gaussian Mixture Model
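///
/// Typical usage is `new` → `fit` → `predict`. A minimal sketch (assuming the
/// crate is exposed as `scirs2_cluster`, matching the doctest on
/// [`gaussian_mixture`]):
///
/// ```
/// use scirs2_core::ndarray::Array2;
/// use scirs2_cluster::gmm::{GMMOptions, GaussianMixture};
///
/// let data = Array2::from_shape_vec((4, 2), vec![
///     1.0, 2.0,
///     1.2, 1.8,
///     4.0, 5.0,
///     4.2, 4.8,
/// ]).unwrap();
///
/// let options = GMMOptions {
///     n_components: 2,
///     ..Default::default()
/// };
///
/// // Fit on the data, then assign each sample to its most likely component.
/// let mut gmm = GaussianMixture::new(options);
/// gmm.fit(data.view()).unwrap();
/// let labels = gmm.predict(data.view()).unwrap();
/// assert_eq!(labels.len(), 4);
/// ```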
pub struct GaussianMixture<F: Float> {
    /// Options
    options: GMMOptions<F>,
    /// Weights of each mixture component
    weights: Option<Array1<F>>,
    /// Means of each mixture component
    means: Option<Array2<F>>,
    /// Covariances of each mixture component
    covariances: Option<Vec<Array2<F>>>,
    /// Lower bound value (log-likelihood)
    lower_bound: Option<F>,
    /// Number of iterations run
    n_iter: Option<usize>,
    /// Whether the model has converged
    converged: bool,
}

impl<F: Float + FromPrimitive + Debug + ScalarOperand + Sum + std::borrow::Borrow<f64>>
    GaussianMixture<F>
{
    /// Create a new GMM instance
    pub fn new(options: GMMOptions<F>) -> Self {
        Self {
            options,
            weights: None,
            means: None,
            covariances: None,
            lower_bound: None,
            n_iter: None,
            converged: false,
        }
    }

    /// Fit the Gaussian Mixture Model to data
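    ///
    /// Runs EM `n_init` times and keeps the parameters from the run with the
    /// highest lower bound (mean per-sample log-likelihood).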
    pub fn fit(&mut self, data: ArrayView2<F>) -> Result<()> {
        let n_samples = data.shape()[0];

        if n_samples < self.options.n_components {
            return Err(ClusteringError::InvalidInput(
                "Number of samples must be >= number of components".to_string(),
            ));
        }

        let mut best_lower_bound = F::neg_infinity();
        let mut best_params = None;

        // Try multiple initializations and keep the best run
        for _ in 0..self.options.n_init {
            let (weights, means, covariances, lower_bound, n_iter, converged) =
                self.fit_single(data)?;

            if lower_bound > best_lower_bound {
                best_lower_bound = lower_bound;
                best_params = Some((weights, means, covariances, lower_bound, n_iter, converged));
            }
        }

        if let Some((weights, means, covariances, lower_bound, n_iter, converged)) = best_params {
            self.weights = Some(weights);
            self.means = Some(means);
            self.covariances = Some(covariances);
            self.lower_bound = Some(lower_bound);
            self.n_iter = Some(n_iter);
            self.converged = converged;
        }

        Ok(())
    }

    /// Single run of the EM algorithm
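    ///
    /// Alternates E- and M-steps until the change in the lower bound (mean
    /// per-sample log-likelihood) falls below `tol`, or `max_iter` is reached.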
    fn fit_single(&self, data: ArrayView2<F>) -> Result<GMMFitResult<F>> {
        // Initialize parameters
        let (mut weights, mut means, mut covariances) = self.initialize_params(data)?;

        let mut lower_bound = F::neg_infinity();
        let mut converged = false;

        for iter in 0..self.options.max_iter {
            // E-step: compute responsibilities
            let (resp, new_lower_bound) = self.e_step(data, &weights, &means, &covariances)?;

            // Check convergence
            let change = (new_lower_bound - lower_bound).abs();
            if change < self.options.tol {
                converged = true;
                return Ok((
                    weights,
                    means,
                    covariances,
                    new_lower_bound,
                    iter + 1,
                    converged,
                ));
            }
            lower_bound = new_lower_bound;

            // M-step: update parameters
            (weights, means, covariances) = self.m_step(data, resp)?;
        }

        Ok((
            weights,
            means,
            covariances,
            lower_bound,
            self.options.max_iter,
            converged,
        ))
    }

    /// Initialize GMM parameters
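    ///
    /// Weights start uniform at `1 / n_components`; means come from k-means++
    /// or randomly sampled data points; covariances start diagonal, using the
    /// per-feature sample variance (averaged across features for `Spherical`).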
    fn initialize_params(&self, data: ArrayView2<F>) -> Result<GMMParams<F>> {
        let n_samples = data.shape()[0];
        let n_features = data.shape()[1];
        let n_components = self.options.n_components;

        // Initialize weights uniformly
        let weights = Array1::from_elem(n_components, F::one() / F::from(n_components).unwrap());

        // Initialize means
        let means = match self.options.init_method {
            GMMInit::KMeans => {
                // Use k-means++ initialization
                kmeans_plus_plus(data, n_components, self.options.random_seed)?
            }
            GMMInit::Random => {
                // Random selection from data points
                let mut rng = match self.options.random_seed {
                    Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
                    None => scirs2_core::random::rngs::StdRng::seed_from_u64(
                        scirs2_core::random::rng().random::<u64>(),
                    ),
                };

                let mut means = Array2::zeros((n_components, n_features));
                for i in 0..n_components {
                    let idx = rng.random_range(0..n_samples);
                    means.slice_mut(s![i, ..]).assign(&data.slice(s![idx, ..]));
                }
                means
            }
        };

        // Initialize covariances based on data variance
        let mut covariances = Vec::with_capacity(n_components);

        // Compute per-feature sample variance for initialization
        let data_mean = data.mean_axis(Axis(0)).unwrap();
        let mut variance = Array1::<F>::zeros(n_features);

        for i in 0..n_samples {
            let diff = &data.slice(s![i, ..]) - &data_mean;
            variance = variance + &diff.mapv(|x| x * x);
        }
        variance = variance / F::from(n_samples - 1).unwrap();

        match self.options.covariance_type {
            CovarianceType::Spherical => {
                // Isotropic: average the per-feature variances
                let avg_variance = variance.sum() / F::from(variance.len()).unwrap();
                for _ in 0..n_components {
                    let mut cov = Array2::<F>::zeros((n_features, n_features));
                    for i in 0..n_features {
                        cov[[i, i]] = avg_variance;
                    }
                    covariances.push(cov);
                }
            }
            CovarianceType::Diagonal | CovarianceType::Full | CovarianceType::Tied => {
                // Initialize with diagonal covariance
                for _ in 0..n_components {
                    let mut cov = Array2::<F>::zeros((n_features, n_features));
                    for i in 0..n_features {
                        cov[[i, i]] = variance[i];
                    }
                    covariances.push(cov);
                }
            }
        }

        Ok((weights, means, covariances))
    }

    /// E-step: compute responsibilities
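    ///
    /// For sample `i` and component `k`, the responsibility is
    /// `γ_ik = π_k · N(x_i | μ_k, Σ_k) / Σ_j π_j · N(x_i | μ_j, Σ_j)`,
    /// computed in log space via `logsumexp` for numerical stability.
    /// Also returns the mean per-sample log-likelihood as the lower bound.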
    fn e_step(
        &self,
        data: ArrayView2<F>,
        weights: &Array1<F>,
        means: &Array2<F>,
        covariances: &[Array2<F>],
    ) -> Result<(Array2<F>, F)> {
        let n_samples = data.shape()[0];
        let n_components = self.options.n_components;

        let mut log_prob = Array2::zeros((n_samples, n_components));

        // Compute log probabilities for each component
        for (k, covariance) in covariances.iter().enumerate().take(n_components) {
            let log_prob_k = self.log_multivariate_normal_density(
                data,
                means.slice(s![k, ..]).view(),
                covariance,
            )?;
            log_prob.slice_mut(s![.., k]).assign(&log_prob_k);
        }

        // Add log weights
        for k in 0..n_components {
            let log_weight = weights[k].ln();
            log_prob
                .slice_mut(s![.., k])
                .mapv_inplace(|x| x + log_weight);
        }

        // Compute log normalization
        let log_prob_norm = self.logsumexp(log_prob.view(), Axis(1))?;

        // Compute responsibilities
        let mut resp = log_prob.clone();
        for i in 0..n_samples {
            for k in 0..n_components {
                resp[[i, k]] = (resp[[i, k]] - log_prob_norm[i]).exp();
            }
        }

        // Compute lower bound (mean per-sample log-likelihood)
        let lower_bound = log_prob_norm.sum() / F::from(log_prob_norm.len()).unwrap();

        Ok((resp, lower_bound))
    }

    /// M-step: update parameters
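    ///
    /// With `N_k = Σ_i γ_ik`: `π_k = N_k / n`, `μ_k = (Σ_i γ_ik · x_i) / N_k`,
    /// and `Σ_k = (Σ_i γ_ik · (x_i - μ_k)(x_i - μ_k)ᵀ) / N_k + reg_covar · I`
    /// (restricted to the diagonal for non-`Full` covariance types).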
    fn m_step(&self, data: ArrayView2<F>, resp: Array2<F>) -> Result<GMMParams<F>> {
        let n_samples = data.shape()[0];
        let n_features = data.shape()[1];
        let n_components = self.options.n_components;

        // Compute weights
        let nk = resp.sum_axis(Axis(0));
        let weights = &nk / F::from(n_samples).unwrap();

        // Compute means
        let mut means = Array2::zeros((n_components, n_features));
        for k in 0..n_components {
            let mut mean_k = Array1::zeros(n_features);
            for i in 0..n_samples {
                mean_k = mean_k + &data.slice(s![i, ..]) * resp[[i, k]];
            }
            means.slice_mut(s![k, ..]).assign(&(&mean_k / nk[k]));
        }

        // Compute covariances
        let mut covariances = Vec::with_capacity(n_components);

        match self.options.covariance_type {
            CovarianceType::Full => {
                for k in 0..n_components {
                    let mean_k = means.slice(s![k, ..]);
                    let mut cov = Array2::zeros((n_features, n_features));

                    for i in 0..n_samples {
                        let diff = &data.slice(s![i, ..]) - &mean_k;
                        let outer = self.outer_product(diff.view(), diff.view());
                        cov = cov + &outer * resp[[i, k]];
                    }

                    cov = cov / nk[k];
                    // Add regularization to the diagonal
                    for i in 0..n_features {
                        cov[[i, i]] = cov[[i, i]] + self.options.reg_covar;
                    }

                    covariances.push(cov);
                }
            }
            _ => {
                // Simplified: use diagonal covariances for the other types
                for k in 0..n_components {
                    let mean_k = means.slice(s![k, ..]);
                    let mut cov = Array2::zeros((n_features, n_features));

                    for i in 0..n_samples {
                        let diff = &data.slice(s![i, ..]) - &mean_k;
                        for j in 0..n_features {
                            cov[[j, j]] = cov[[j, j]] + diff[j] * diff[j] * resp[[i, k]];
                        }
                    }

                    for j in 0..n_features {
                        cov[[j, j]] = cov[[j, j]] / nk[k] + self.options.reg_covar;
                    }

                    covariances.push(cov);
                }
            }
        }

        Ok((weights, means, covariances))
    }

    /// Compute log probability under a multivariate Gaussian distribution
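    ///
    /// Treats `covariance` as diagonal, so for `d` features this evaluates
    /// `ln N(x | μ, Σ) = -(d·ln(2π) + ln|Σ| + Σ_j (x_j - μ_j)² / Σ_jj) / 2`.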
    fn log_multivariate_normal_density(
        &self,
        data: ArrayView2<F>,
        mean: ArrayView1<F>,
        covariance: &Array2<F>,
    ) -> Result<Array1<F>> {
        let n_samples = data.shape()[0];
        let n_features = data.shape()[1];

        // For simplicity, assume diagonal covariance
        let mut log_prob = Array1::zeros(n_samples);

        // Log-determinant (sum of log diagonal elements for a diagonal matrix)
        let mut log_det = F::zero();
        for i in 0..n_features {
            log_det = log_det + covariance[[i, i]].ln();
        }

        let norm_const = F::from(n_features as f64 * (2.0 * PI).ln()).unwrap() + log_det;

        for i in 0..n_samples {
            let diff = &data.slice(s![i, ..]) - &mean;
            let mut mahalanobis = F::zero();

            // For diagonal covariance, the quadratic form simplifies
            for j in 0..n_features {
                mahalanobis = mahalanobis + diff[j] * diff[j] / covariance[[j, j]];
            }

            log_prob[i] = F::from(-0.5).unwrap() * (norm_const + mahalanobis);
        }

        Ok(log_prob)
    }

    /// Compute log-sum-exp along an axis
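    ///
    /// Uses the shifted form `max(x) + ln Σ_j exp(x_j - max(x))` to avoid
    /// overflow. Only `Axis(1)` (row-wise reduction) is supported.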
    fn logsumexp(&self, arr: ArrayView2<F>, axis: Axis) -> Result<Array1<F>> {
        let max_vals = arr.fold_axis(axis, F::neg_infinity(), |&a, &b| a.max(b));
        let mut result = Array1::zeros(max_vals.len());

        match axis {
            Axis(1) => {
                for i in 0..arr.shape()[0] {
                    let mut sum = F::zero();
                    for j in 0..arr.shape()[1] {
                        sum = sum + (arr[[i, j]] - max_vals[i]).exp();
                    }
                    result[i] = max_vals[i] + sum.ln();
                }
            }
            _ => {
                return Err(ClusteringError::InvalidInput(
                    "Only axis 1 is supported for logsumexp".to_string(),
                ));
            }
        }

        Ok(result)
    }

    /// Compute outer product of two vectors
    fn outer_product(&self, a: ArrayView1<F>, b: ArrayView1<F>) -> Array2<F> {
        let n = a.len();
        let m = b.len();
        let mut result = Array2::zeros((n, m));

        for i in 0..n {
            for j in 0..m {
                result[[i, j]] = a[i] * b[j];
            }
        }

        result
    }

    /// Predict cluster labels
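    ///
    /// Runs the E-step with the fitted parameters and assigns each sample to
    /// the component with the highest responsibility. Returns an error if the
    /// model has not been fitted.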
    pub fn predict(&self, data: ArrayView2<F>) -> Result<Array1<i32>> {
        if self.weights.is_none() || self.means.is_none() || self.covariances.is_none() {
            return Err(ClusteringError::InvalidInput(
                "Model has not been fitted yet".to_string(),
            ));
        }

        let weights = self.weights.as_ref().unwrap();
        let means = self.means.as_ref().unwrap();
        let covariances = self.covariances.as_ref().unwrap();

        let (resp, _) = self.e_step(data, weights, means, covariances)?;

        // Assign each sample to the component with the highest responsibility
        let mut labels = Array1::zeros(data.shape()[0]);
        for i in 0..data.shape()[0] {
            let mut max_resp = F::neg_infinity();
            let mut best_k = 0;

            for k in 0..self.options.n_components {
                if resp[[i, k]] > max_resp {
                    max_resp = resp[[i, k]];
                    best_k = k;
                }
            }

            labels[i] = best_k as i32;
        }

        Ok(labels)
    }
}

/// Fit a Gaussian Mixture Model
///
/// # Arguments
///
/// * `data` - Input data (n_samples × n_features)
/// * `options` - GMM options
///
/// # Returns
///
/// * Array of cluster labels
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::Array2;
/// use scirs2_cluster::gmm::{gaussian_mixture, GMMOptions};
///
/// let data = Array2::from_shape_vec((6, 2), vec![
///     1.0, 2.0,
///     1.2, 1.8,
///     0.8, 1.9,
///     4.0, 5.0,
///     4.2, 4.8,
///     3.9, 5.1,
/// ]).unwrap();
///
/// let options = GMMOptions {
///     n_components: 2,
///     ..Default::default()
/// };
///
/// let labels = gaussian_mixture(data.view(), options).unwrap();
/// ```
#[allow(dead_code)]
pub fn gaussian_mixture<F>(data: ArrayView2<F>, options: GMMOptions<F>) -> Result<Array1<i32>>
where
    F: Float + FromPrimitive + Debug + ScalarOperand + Sum + std::borrow::Borrow<f64>,
{
    let mut gmm = GaussianMixture::new(options);
    gmm.fit(data)?;
    gmm.predict(data)
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_gmm_simple() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .unwrap();

        let options = GMMOptions {
            n_components: 2,
            max_iter: 10,
            ..Default::default()
        };

        let result = gaussian_mixture(data.view(), options);
        assert!(result.is_ok());

        let labels = result.unwrap();
        assert_eq!(labels.len(), 6);

        // Check that we have at most 2 clusters
        let unique_labels: std::collections::HashSet<_> = labels.iter().cloned().collect();
        assert!(unique_labels.len() <= 2);
    }

    #[test]
    fn test_gmm_different_covariance_types() {
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.1, 1.1, 0.9, 0.9, 1.2, 0.8, 5.0, 5.0, 5.1, 5.1, 4.9, 4.9, 5.2, 4.8,
            ],
        )
        .unwrap();

        let covariance_types = vec![
            CovarianceType::Full,
            CovarianceType::Diagonal,
            CovarianceType::Spherical,
            CovarianceType::Tied,
        ];

        for cov_type in covariance_types {
            let options = GMMOptions {
                n_components: 2,
                covariance_type: cov_type,
                max_iter: 50,
                ..Default::default()
            };

            let result = gaussian_mixture(data.view(), options);
            assert!(
                result.is_ok(),
                "Failed with covariance type: {:?}",
                cov_type
            );

            let labels = result.unwrap();
            assert_eq!(labels.len(), 8);
        }
    }

    #[test]
    fn test_gmm_initialization_methods() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .unwrap();

        let init_methods = vec![GMMInit::KMeans, GMMInit::Random];

        for init_method in init_methods {
            let options = GMMOptions {
                n_components: 2,
                init_method,
                random_seed: Some(42),
                max_iter: 20,
                ..Default::default()
            };

            let result = gaussian_mixture(data.view(), options);
            assert!(result.is_ok(), "Failed with init method: {:?}", init_method);

            let labels = result.unwrap();
            assert_eq!(labels.len(), 6);
        }
    }

    #[test]
    fn test_gmm_parameter_validation() {
        let data =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0]).unwrap();

        // Test with n_components = 0 (invalid)
        let options = GMMOptions {
            n_components: 0,
            ..Default::default()
        };
        let result = gaussian_mixture(data.view(), options);
        assert!(result.is_err());

        // Test with n_components > n_samples: fit() rejects this because it
        // requires n_samples >= n_components
        let options = GMMOptions {
            n_components: 10,
            max_iter: 5, // Keep low to avoid long convergence
            ..Default::default()
        };
        let result = gaussian_mixture(data.view(), options);
        assert!(result.is_err());
    }

    #[test]
    fn test_gmm_convergence_criteria() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .unwrap();

        // Test with different tolerance values
        let tolerances = vec![1e-3, 1e-6, 1e-9];

        for tol in tolerances {
            let options = GMMOptions {
                n_components: 2,
                tol,
                max_iter: 100,
                ..Default::default()
            };

            let result = gaussian_mixture(data.view(), options);
            assert!(result.is_ok(), "Failed with tolerance: {}", tol);
        }
    }

    #[test]
    fn test_gmm_single_component() {
        let data =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 1.1, 2.1]).unwrap();

        let options = GMMOptions {
            n_components: 1,
            max_iter: 20,
            ..Default::default()
        };

        let result = gaussian_mixture(data.view(), options);
        assert!(result.is_ok());

        let labels = result.unwrap();
        assert_eq!(labels.len(), 4);

        // All labels should be 0 for a single component
        assert!(labels.iter().all(|&l| l == 0));
    }

    #[test]
    fn test_gmm_reproducibility_with_seed() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .unwrap();

        let options1 = GMMOptions {
            n_components: 2,
            random_seed: Some(42),
            max_iter: 50,
            ..Default::default()
        };

        let options2 = GMMOptions {
            n_components: 2,
            random_seed: Some(42),
            max_iter: 50,
            ..Default::default()
        };

        let labels1 = gaussian_mixture(data.view(), options1).unwrap();
        let labels2 = gaussian_mixture(data.view(), options2).unwrap();

        // With the same seed, results should be consistent in clustering structure.
        // Note: cluster labels might be swapped (0->1, 1->0) but the clustering should be the same.
        assert_eq!(labels1.len(), labels2.len());

        // Check that the number of unique clusters is the same
        let unique1: std::collections::HashSet<_> = labels1.iter().cloned().collect();
        let unique2: std::collections::HashSet<_> = labels2.iter().cloned().collect();
        assert_eq!(unique1.len(), unique2.len());
    }

    #[test]
    fn test_gmm_many_components() {
        let data = Array2::from_shape_vec(
            (10, 2),
            vec![
                1.0, 1.0, 1.1, 1.1, 1.2, 1.2, 3.0, 3.0, 3.1, 3.1, 3.2, 3.2, 5.0, 5.0, 5.1, 5.1,
                5.2, 5.2, 7.0, 7.0,
            ],
        )
        .unwrap();

        let options = GMMOptions {
            n_components: 3,
            max_iter: 50,
            ..Default::default()
        };

        let result = gaussian_mixture(data.view(), options);
        assert!(result.is_ok());

        let labels = result.unwrap();
        assert_eq!(labels.len(), 10);

        // Should find up to 3 clusters
        let unique_labels: std::collections::HashSet<_> = labels.iter().cloned().collect();
        assert!(unique_labels.len() <= 3);
        assert!(!unique_labels.is_empty());
    }

    #[test]
    fn test_gmm_regularization() {
        let data = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 2.0, 1.2, 1.8, 0.8, 1.9, 4.0, 5.0, 4.2, 4.8, 3.9, 5.1],
        )
        .unwrap();

        // Test with different regularization values
        let reg_values = vec![1e-6, 1e-3, 1e-1];

        for reg_covar in reg_values {
            let options = GMMOptions {
                n_components: 2,
                reg_covar,
                max_iter: 20,
                ..Default::default()
            };

            let result = gaussian_mixture(data.view(), options);
            assert!(result.is_ok(), "Failed with reg_covar: {}", reg_covar);
        }
    }

    #[test]
    fn test_gmm_fit_predict_workflow() {
        let data = Array2::from_shape_vec(
            (8, 2),
            vec![
                1.0, 1.0, 1.1, 1.1, 0.9, 0.9, 1.2, 0.8, 5.0, 5.0, 5.1, 5.1, 4.9, 4.9, 5.2, 4.8,
            ],
        )
        .unwrap();

        let options = GMMOptions {
            n_components: 2,
            max_iter: 50,
            random_seed: Some(42),
            ..Default::default()
        };

        // Test the fit-predict workflow using the struct directly
        let mut gmm = GaussianMixture::new(options);

        // Fit the model
        let fit_result = gmm.fit(data.view());
        assert!(fit_result.is_ok());

        // Predict on the same data
        let predict_result = gmm.predict(data.view());
        assert!(predict_result.is_ok());

        let labels = predict_result.unwrap();
        assert_eq!(labels.len(), 8);

        // Predict on new data (should work after fitting)
        let new_data = Array2::from_shape_vec((2, 2), vec![1.0, 1.0, 5.0, 5.0]).unwrap();

        let new_labels = gmm.predict(new_data.view());
        assert!(new_labels.is_ok());
        assert_eq!(new_labels.unwrap().len(), 2);
    }
}