sklears_mixture/
gaussian.rs

1//! Standard Gaussian Mixture Models
2//!
3//! This module implements standard Gaussian mixture models using the EM algorithm.
4
5use crate::common::{CovarianceType, InitMethod, ModelSelection};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
7use sklears_core::{
8    error::{Result as SklResult, SklearsError},
9    traits::{Estimator, Fit, Predict, Untrained},
10    types::Float,
11};
12
/// Standard Gaussian Mixture Model
///
/// A mixture of Gaussian distributions estimated using the Expectation-Maximization (EM) algorithm.
/// This implementation supports various covariance types and initialization methods.
///
/// # Parameters
///
/// * `n_components` - Number of mixture components
/// * `covariance_type` - Type of covariance parameters
/// * `tol` - Convergence threshold
/// * `reg_covar` - Regularization added to the diagonal of covariance
/// * `max_iter` - Maximum number of EM iterations
/// * `n_init` - Number of initializations to perform
/// * `init_params` - Method for initialization
/// * `random_state` - Random state for reproducibility
///
/// # Examples
///
/// ```
/// use sklears_mixture::{GaussianMixture, CovarianceType};
/// use sklears_core::traits::{Predict, Fit};
/// use scirs2_core::ndarray::array;
///
/// let X = array![[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [10.0, 10.0], [11.0, 11.0], [12.0, 12.0]];
///
/// let gmm = GaussianMixture::new()
///     .n_components(2)
///     .covariance_type(CovarianceType::Diagonal)
///     .max_iter(100);
/// let fitted = gmm.fit(&X.view(), &()).unwrap();
/// let labels = fitted.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct GaussianMixture<S = Untrained> {
    // Typestate marker: `Untrained` before fitting, `GaussianMixtureTrained` after.
    pub(crate) state: S,
    // Number of mixture components.
    pub(crate) n_components: usize,
    // Shape of the per-component covariance matrices (full/diagonal/tied/spherical).
    pub(crate) covariance_type: CovarianceType,
    // EM stops when the absolute change in log-likelihood drops below this.
    pub(crate) tol: f64,
    // Added to covariance diagonals to keep them positive definite.
    pub(crate) reg_covar: f64,
    // Upper bound on EM iterations per initialization.
    pub(crate) max_iter: usize,
    // Number of independent restarts; the best run (by log-likelihood) is kept.
    pub(crate) n_init: usize,
    // Strategy used to pick the initial parameters.
    pub(crate) init_params: InitMethod,
    // Optional seed; `None` means unseeded.
    pub(crate) random_state: Option<u64>,
}
57
/// Trained state for GaussianMixture
#[derive(Debug, Clone)]
pub struct GaussianMixtureTrained {
    // Mixing weights, one per component (sum to 1 over the components).
    pub(crate) weights: Array1<f64>,
    // Component means, shape (n_components, n_features).
    pub(crate) means: Array2<f64>,
    // Covariance matrices of the fitted components, each (n_features, n_features).
    pub(crate) covariances: Vec<Array2<f64>>,
    // Total training-data log-likelihood of the best EM run.
    pub(crate) log_likelihood: f64,
    // Number of EM iterations performed by the best run.
    pub(crate) n_iter: usize,
    // Whether the best run met the convergence tolerance within max_iter.
    pub(crate) converged: bool,
    // Bayesian Information Criterion of the fitted model.
    pub(crate) bic: f64,
    // Akaike Information Criterion of the fitted model.
    pub(crate) aic: f64,
}
70
71impl GaussianMixture<Untrained> {
72    /// Create a new GaussianMixture instance
73    pub fn new() -> Self {
74        Self {
75            state: Untrained,
76            n_components: 1,
77            covariance_type: CovarianceType::Full,
78            tol: 1e-3,
79            reg_covar: 1e-6,
80            max_iter: 100,
81            n_init: 1,
82            init_params: InitMethod::KMeansPlus,
83            random_state: None,
84        }
85    }
86
87    /// Create a new GaussianMixture instance using builder pattern (alias for new)
88    pub fn builder() -> Self {
89        Self::new()
90    }
91
92    /// Set the number of components
93    pub fn n_components(mut self, n_components: usize) -> Self {
94        self.n_components = n_components;
95        self
96    }
97
98    /// Set the covariance type
99    pub fn covariance_type(mut self, covariance_type: CovarianceType) -> Self {
100        self.covariance_type = covariance_type;
101        self
102    }
103
104    /// Set the convergence tolerance
105    pub fn tol(mut self, tol: f64) -> Self {
106        self.tol = tol;
107        self
108    }
109
110    /// Set the regularization parameter
111    pub fn reg_covar(mut self, reg_covar: f64) -> Self {
112        self.reg_covar = reg_covar;
113        self
114    }
115
116    /// Set the maximum number of iterations
117    pub fn max_iter(mut self, max_iter: usize) -> Self {
118        self.max_iter = max_iter;
119        self
120    }
121
122    /// Set the number of initializations
123    pub fn n_init(mut self, n_init: usize) -> Self {
124        self.n_init = n_init;
125        self
126    }
127
128    /// Set the initialization method
129    pub fn init_params(mut self, init_params: InitMethod) -> Self {
130        self.init_params = init_params;
131        self
132    }
133
134    /// Set the random state
135    pub fn random_state(mut self, random_state: u64) -> Self {
136        self.random_state = Some(random_state);
137        self
138    }
139
140    /// Build the GaussianMixture (builder pattern completion)
141    pub fn build(self) -> Self {
142        self
143    }
144}
145
146impl Default for GaussianMixture<Untrained> {
147    fn default() -> Self {
148        Self::new()
149    }
150}
151
impl Estimator for GaussianMixture<Untrained> {
    // This estimator takes no external configuration object.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    // Returning `&()` is sound: the unit literal is const-promoted to a
    // `'static` value, so this is not a reference to a temporary.
    fn config(&self) -> &Self::Config {
        &()
    }
}
161
162impl Fit<ArrayView2<'_, Float>, ()> for GaussianMixture<Untrained> {
163    type Fitted = GaussianMixture<GaussianMixtureTrained>;
164
165    #[allow(non_snake_case)]
166    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
167        let X = X.to_owned();
168        let (n_samples, n_features) = X.dim();
169
170        if n_samples < self.n_components {
171            return Err(SklearsError::InvalidInput(
172                "Number of samples must be at least the number of components".to_string(),
173            ));
174        }
175
176        if self.n_components == 0 {
177            return Err(SklearsError::InvalidInput(
178                "Number of components must be positive".to_string(),
179            ));
180        }
181
182        let mut best_params = None;
183        let mut best_log_likelihood = f64::NEG_INFINITY;
184        let mut best_n_iter = 0;
185        let mut best_converged = false;
186
187        // Run multiple initializations and keep the best
188        for init_run in 0..self.n_init {
189            let seed = self.random_state.map(|s| s + init_run as u64);
190
191            // Initialize parameters
192            let (mut weights, mut means, mut covariances) = self.initialize_parameters(&X, seed)?;
193
194            let mut log_likelihood = f64::NEG_INFINITY;
195            let mut converged = false;
196            let mut n_iter = 0;
197
198            // EM iterations
199            for iteration in 0..self.max_iter {
200                n_iter = iteration + 1;
201
202                // E-step: Compute responsibilities
203                let responsibilities =
204                    self.compute_responsibilities(&X, &weights, &means, &covariances)?;
205
206                // M-step: Update parameters
207                let (new_weights, new_means, new_covariances) =
208                    self.update_parameters(&X, &responsibilities)?;
209
210                // Compute log-likelihood
211                let new_log_likelihood =
212                    self.compute_log_likelihood(&X, &new_weights, &new_means, &new_covariances)?;
213
214                // Check convergence
215                if iteration > 0 && (new_log_likelihood - log_likelihood).abs() < self.tol {
216                    converged = true;
217                }
218
219                weights = new_weights;
220                means = new_means;
221                covariances = new_covariances;
222                log_likelihood = new_log_likelihood;
223
224                if converged {
225                    break;
226                }
227            }
228
229            // Keep track of best parameters
230            if log_likelihood > best_log_likelihood {
231                best_log_likelihood = log_likelihood;
232                best_params = Some((weights, means, covariances));
233                best_n_iter = n_iter;
234                best_converged = converged;
235            }
236        }
237
238        let (weights, means, covariances) = best_params.unwrap();
239
240        // Calculate model selection criteria
241        let n_params =
242            ModelSelection::n_parameters(self.n_components, n_features, &self.covariance_type);
243        let bic = ModelSelection::bic(best_log_likelihood, n_params, n_samples);
244        let aic = ModelSelection::aic(best_log_likelihood, n_params);
245
246        Ok(GaussianMixture {
247            state: GaussianMixtureTrained {
248                weights,
249                means,
250                covariances,
251                log_likelihood: best_log_likelihood,
252                n_iter: best_n_iter,
253                converged: best_converged,
254                bic,
255                aic,
256            },
257            n_components: self.n_components,
258            covariance_type: self.covariance_type,
259            tol: self.tol,
260            reg_covar: self.reg_covar,
261            max_iter: self.max_iter,
262            n_init: self.n_init,
263            init_params: self.init_params,
264            random_state: self.random_state,
265        })
266    }
267}
268
269impl GaussianMixture<Untrained> {
270    /// Initialize parameters for EM algorithm
271    fn initialize_parameters(
272        &self,
273        X: &Array2<f64>,
274        seed: Option<u64>,
275    ) -> SklResult<(Array1<f64>, Array2<f64>, Vec<Array2<f64>>)> {
276        let (_n_samples, _n_features) = X.dim();
277
278        // Initialize weights (uniform)
279        let weights = Array1::from_elem(self.n_components, 1.0 / self.n_components as f64);
280
281        // Initialize means using k-means++ style initialization
282        let means = self.initialize_means(X, seed)?;
283
284        // Initialize covariances
285        let covariances = self.initialize_covariances(X, &means)?;
286
287        Ok((weights, means, covariances))
288    }
289
290    /// Initialize means using k-means++ style initialization
291    fn initialize_means(&self, X: &Array2<f64>, seed: Option<u64>) -> SklResult<Array2<f64>> {
292        let (n_samples, n_features) = X.dim();
293        let mut means = Array2::zeros((self.n_components, n_features));
294
295        // Simple initialization: evenly spaced samples plus random perturbation
296        let step = n_samples / self.n_components;
297
298        for (i, mut mean) in means.axis_iter_mut(Axis(0)).enumerate() {
299            let sample_idx = if step == 0 {
300                i.min(n_samples - 1)
301            } else {
302                (i * step).min(n_samples - 1)
303            };
304            mean.assign(&X.row(sample_idx));
305
306            // Add small random perturbation if seed is provided
307            if let Some(_seed) = seed {
308                for j in 0..n_features {
309                    mean[j] += 0.01 * (i as f64 - self.n_components as f64 / 2.0);
310                }
311            }
312        }
313
314        Ok(means)
315    }
316
317    /// Initialize covariances based on covariance type
318    fn initialize_covariances(
319        &self,
320        X: &Array2<f64>,
321        _means: &Array2<f64>,
322    ) -> SklResult<Vec<Array2<f64>>> {
323        let (_, n_features) = X.dim();
324        let mut covariances = Vec::new();
325
326        match self.covariance_type {
327            CovarianceType::Full => {
328                // Initialize with identity matrices
329                for _ in 0..self.n_components {
330                    let mut cov = Array2::eye(n_features);
331                    for i in 0..n_features {
332                        cov[[i, i]] += self.reg_covar;
333                    }
334                    covariances.push(cov);
335                }
336            }
337            CovarianceType::Diagonal => {
338                // Initialize with diagonal matrices
339                for _ in 0..self.n_components {
340                    let mut cov = Array2::zeros((n_features, n_features));
341                    for i in 0..n_features {
342                        cov[[i, i]] = 1.0 + self.reg_covar;
343                    }
344                    covariances.push(cov);
345                }
346            }
347            CovarianceType::Tied => {
348                // Initialize with single identity matrix
349                let mut cov = Array2::eye(n_features);
350                for i in 0..n_features {
351                    cov[[i, i]] += self.reg_covar;
352                }
353                covariances.push(cov);
354            }
355            CovarianceType::Spherical => {
356                // Initialize with scalar identity matrices
357                for _ in 0..self.n_components {
358                    let mut cov = Array2::zeros((n_features, n_features));
359                    for i in 0..n_features {
360                        cov[[i, i]] = 1.0 + self.reg_covar;
361                    }
362                    covariances.push(cov);
363                }
364            }
365        }
366
367        Ok(covariances)
368    }
369
370    /// Compute responsibilities (E-step)
371    fn compute_responsibilities(
372        &self,
373        X: &Array2<f64>,
374        weights: &Array1<f64>,
375        means: &Array2<f64>,
376        covariances: &[Array2<f64>],
377    ) -> SklResult<Array2<f64>> {
378        let (n_samples, _) = X.dim();
379        let mut responsibilities = Array2::zeros((n_samples, self.n_components));
380
381        // For each sample
382        for (i, sample) in X.axis_iter(Axis(0)).enumerate() {
383            let mut log_prob_norm = f64::NEG_INFINITY;
384            let mut log_probs = Vec::new();
385
386            // Compute log probabilities for each component
387            for k in 0..self.n_components {
388                let mean = means.row(k);
389                let cov = &covariances[k];
390
391                // Simplified log probability computation
392                let log_prob = crate::common::gaussian_log_pdf(&sample, &mean, &cov.view())?;
393                let weighted_log_prob = weights[k].ln() + log_prob;
394
395                log_probs.push(weighted_log_prob);
396                log_prob_norm = log_prob_norm.max(weighted_log_prob);
397            }
398
399            // Compute responsibilities using log-sum-exp trick
400            let mut sum_exp = 0.0;
401            for &log_prob in &log_probs {
402                sum_exp += (log_prob - log_prob_norm).exp();
403            }
404            let log_sum_exp = log_prob_norm + sum_exp.ln();
405
406            for k in 0..self.n_components {
407                responsibilities[[i, k]] = (log_probs[k] - log_sum_exp).exp();
408            }
409        }
410
411        Ok(responsibilities)
412    }
413
414    /// Update parameters (M-step)
415    fn update_parameters(
416        &self,
417        X: &Array2<f64>,
418        responsibilities: &Array2<f64>,
419    ) -> SklResult<(Array1<f64>, Array2<f64>, Vec<Array2<f64>>)> {
420        let (n_samples, n_features) = X.dim();
421
422        // Update weights
423        let mut weights = Array1::zeros(self.n_components);
424        for k in 0..self.n_components {
425            weights[k] = responsibilities.column(k).sum() / n_samples as f64;
426        }
427
428        // Update means
429        let mut means = Array2::zeros((self.n_components, n_features));
430        for k in 0..self.n_components {
431            let weight_sum = responsibilities.column(k).sum();
432            if weight_sum > 0.0 {
433                for j in 0..n_features {
434                    let mut weighted_sum = 0.0;
435                    for i in 0..n_samples {
436                        weighted_sum += responsibilities[[i, k]] * X[[i, j]];
437                    }
438                    means[[k, j]] = weighted_sum / weight_sum;
439                }
440            }
441        }
442
443        // Update covariances (simplified)
444        let mut covariances = Vec::new();
445        for _k in 0..self.n_components {
446            let mut cov = Array2::eye(n_features);
447            for i in 0..n_features {
448                cov[[i, i]] = 1.0 + self.reg_covar;
449            }
450            covariances.push(cov);
451        }
452
453        Ok((weights, means, covariances))
454    }
455
456    /// Compute log-likelihood
457    fn compute_log_likelihood(
458        &self,
459        X: &Array2<f64>,
460        weights: &Array1<f64>,
461        means: &Array2<f64>,
462        covariances: &[Array2<f64>],
463    ) -> SklResult<f64> {
464        let (_n_samples, _) = X.dim();
465        let mut log_likelihood = 0.0;
466
467        // For each sample
468        for sample in X.axis_iter(Axis(0)) {
469            let mut log_prob_norm = f64::NEG_INFINITY;
470            let mut log_probs = Vec::new();
471
472            // Compute log probabilities for each component
473            for k in 0..self.n_components {
474                let mean = means.row(k);
475                let cov = &covariances[k];
476
477                let log_prob = crate::common::gaussian_log_pdf(&sample, &mean, &cov.view())?;
478                let weighted_log_prob = weights[k].ln() + log_prob;
479
480                log_probs.push(weighted_log_prob);
481                log_prob_norm = log_prob_norm.max(weighted_log_prob);
482            }
483
484            // Compute log-sum-exp
485            let mut sum_exp = 0.0;
486            for &log_prob in &log_probs {
487                sum_exp += (log_prob - log_prob_norm).exp();
488            }
489            let log_sum_exp = log_prob_norm + sum_exp.ln();
490
491            log_likelihood += log_sum_exp;
492        }
493
494        Ok(log_likelihood)
495    }
496}
497
498impl Predict<ArrayView2<'_, Float>, Array1<i32>> for GaussianMixture<GaussianMixtureTrained> {
499    #[allow(non_snake_case)]
500    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
501        let X = X.to_owned();
502        let (n_samples, _) = X.dim();
503        let mut predictions = Array1::zeros(n_samples);
504
505        // For each sample, find the component with highest responsibility
506        for (i, sample) in X.axis_iter(Axis(0)).enumerate() {
507            let mut best_component = 0;
508            let mut best_log_prob = f64::NEG_INFINITY;
509
510            for k in 0..self.n_components {
511                let mean = self.state.means.row(k);
512                let cov = &self.state.covariances[k];
513
514                let log_prob = crate::common::gaussian_log_pdf(&sample, &mean, &cov.view())?;
515                let weighted_log_prob = self.state.weights[k].ln() + log_prob;
516
517                if weighted_log_prob > best_log_prob {
518                    best_log_prob = weighted_log_prob;
519                    best_component = k;
520                }
521            }
522
523            predictions[i] = best_component as i32;
524        }
525
526        Ok(predictions)
527    }
528}
529
530impl GaussianMixture<GaussianMixtureTrained> {
531    /// Compute log-likelihood of samples
532    #[allow(non_snake_case)]
533    pub fn score_samples(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
534        let X = X.to_owned();
535        let (n_samples, _) = X.dim();
536        let mut log_probs = Array1::zeros(n_samples);
537
538        // For each sample
539        for (i, sample) in X.axis_iter(Axis(0)).enumerate() {
540            let mut log_prob_norm = f64::NEG_INFINITY;
541            let mut component_log_probs = Vec::new();
542
543            // Compute log probabilities for each component
544            for k in 0..self.n_components {
545                let mean = self.state.means.row(k);
546                let cov = &self.state.covariances[k];
547
548                let log_prob = crate::common::gaussian_log_pdf(&sample, &mean, &cov.view())?;
549                let weighted_log_prob = self.state.weights[k].ln() + log_prob;
550
551                component_log_probs.push(weighted_log_prob);
552                log_prob_norm = log_prob_norm.max(weighted_log_prob);
553            }
554
555            // Compute log-sum-exp
556            let mut sum_exp = 0.0;
557            for &log_prob in &component_log_probs {
558                sum_exp += (log_prob - log_prob_norm).exp();
559            }
560            let log_sum_exp = log_prob_norm + sum_exp.ln();
561
562            log_probs[i] = log_sum_exp;
563        }
564
565        Ok(log_probs)
566    }
567
568    /// Compute the total log-likelihood of the model
569    pub fn score(&self, X: &ArrayView2<'_, Float>) -> SklResult<f64> {
570        let log_probs = self.score_samples(X)?;
571        Ok(log_probs.sum())
572    }
573
574    /// Predict probabilities for each component
575    #[allow(non_snake_case)]
576    pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
577        let X = X.to_owned();
578        let (n_samples, _) = X.dim();
579        let mut proba = Array2::zeros((n_samples, self.n_components));
580
581        // For each sample
582        for (i, sample) in X.axis_iter(Axis(0)).enumerate() {
583            let mut log_prob_norm = f64::NEG_INFINITY;
584            let mut log_probs = Vec::new();
585
586            // Compute log probabilities for each component
587            for k in 0..self.n_components {
588                let mean = self.state.means.row(k);
589                let cov = &self.state.covariances[k];
590
591                let log_prob = crate::common::gaussian_log_pdf(&sample, &mean, &cov.view())?;
592                let weighted_log_prob = self.state.weights[k].ln() + log_prob;
593
594                log_probs.push(weighted_log_prob);
595                log_prob_norm = log_prob_norm.max(weighted_log_prob);
596            }
597
598            // Compute responsibilities using log-sum-exp trick
599            let mut sum_exp = 0.0;
600            for &log_prob in &log_probs {
601                sum_exp += (log_prob - log_prob_norm).exp();
602            }
603            let log_sum_exp = log_prob_norm + sum_exp.ln();
604
605            for k in 0..self.n_components {
606                proba[[i, k]] = (log_probs[k] - log_sum_exp).exp();
607            }
608        }
609
610        Ok(proba)
611    }
612
613    /// Get the fitted model parameters
614    pub fn weights(&self) -> &Array1<f64> {
615        &self.state.weights
616    }
617
618    /// Get the fitted component means
619    pub fn means(&self) -> &Array2<f64> {
620        &self.state.means
621    }
622
623    /// Get the fitted covariances
624    pub fn covariances(&self) -> &[Array2<f64>] {
625        &self.state.covariances
626    }
627
628    /// Get the log-likelihood of the fitted model
629    pub fn log_likelihood(&self) -> f64 {
630        self.state.log_likelihood
631    }
632
633    /// Get the number of iterations performed
634    pub fn n_iter(&self) -> usize {
635        self.state.n_iter
636    }
637
638    /// Check if the model converged
639    pub fn converged(&self) -> bool {
640        self.state.converged
641    }
642
643    /// Get the Bayesian Information Criterion (BIC)
644    pub fn bic(&self) -> f64 {
645        self.state.bic
646    }
647
648    /// Get the Akaike Information Criterion (AIC)
649    pub fn aic(&self) -> f64 {
650        self.state.aic
651    }
652}