1use crate::common::{CovarianceType, InitMethod, ModelSelection};
6use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
7use sklears_core::{
8 error::{Result as SklResult, SklearsError},
9 traits::{Estimator, Fit, Predict, Untrained},
10 types::Float,
11};
12
/// Gaussian Mixture Model estimator, parameterized by its training state.
///
/// `S` is [`Untrained`] for a freshly configured model and
/// `GaussianMixtureTrained` once `fit` has produced parameters; the
/// typestate keeps prediction methods unavailable before fitting.
#[derive(Debug, Clone)]
pub struct GaussianMixture<S = Untrained> {
    /// Training state: `Untrained` marker or the fitted parameters.
    pub(crate) state: S,
    /// Number of mixture components (clusters).
    pub(crate) n_components: usize,
    /// Structure imposed on the component covariance matrices.
    pub(crate) covariance_type: CovarianceType,
    /// EM convergence threshold on the change in log-likelihood.
    pub(crate) tol: f64,
    /// Non-negative value added to covariance diagonals for numerical stability.
    pub(crate) reg_covar: f64,
    /// Maximum number of EM iterations per initialization run.
    pub(crate) max_iter: usize,
    /// Number of independent initialization runs; the best fit is kept.
    pub(crate) n_init: usize,
    /// Strategy used to initialize the component parameters.
    pub(crate) init_params: InitMethod,
    /// Optional base seed for reproducible initialization.
    pub(crate) random_state: Option<u64>,
}
57
/// Fitted parameters and fit diagnostics produced by `fit`.
#[derive(Debug, Clone)]
pub struct GaussianMixtureTrained {
    /// Mixture weights, one per component; expected to sum to 1.
    pub(crate) weights: Array1<f64>,
    /// Component means, shape `(n_components, n_features)`.
    pub(crate) means: Array2<f64>,
    /// Per-component covariance matrices, each `(n_features, n_features)`.
    pub(crate) covariances: Vec<Array2<f64>>,
    /// Log-likelihood of the training data under the best run.
    pub(crate) log_likelihood: f64,
    /// Number of EM iterations performed by the best run.
    pub(crate) n_iter: usize,
    /// Whether the best run met the `tol` convergence criterion.
    pub(crate) converged: bool,
    /// Bayesian information criterion of the fitted model.
    pub(crate) bic: f64,
    /// Akaike information criterion of the fitted model.
    pub(crate) aic: f64,
}
70
71impl GaussianMixture<Untrained> {
72 pub fn new() -> Self {
74 Self {
75 state: Untrained,
76 n_components: 1,
77 covariance_type: CovarianceType::Full,
78 tol: 1e-3,
79 reg_covar: 1e-6,
80 max_iter: 100,
81 n_init: 1,
82 init_params: InitMethod::KMeansPlus,
83 random_state: None,
84 }
85 }
86
87 pub fn builder() -> Self {
89 Self::new()
90 }
91
92 pub fn n_components(mut self, n_components: usize) -> Self {
94 self.n_components = n_components;
95 self
96 }
97
98 pub fn covariance_type(mut self, covariance_type: CovarianceType) -> Self {
100 self.covariance_type = covariance_type;
101 self
102 }
103
104 pub fn tol(mut self, tol: f64) -> Self {
106 self.tol = tol;
107 self
108 }
109
110 pub fn reg_covar(mut self, reg_covar: f64) -> Self {
112 self.reg_covar = reg_covar;
113 self
114 }
115
116 pub fn max_iter(mut self, max_iter: usize) -> Self {
118 self.max_iter = max_iter;
119 self
120 }
121
122 pub fn n_init(mut self, n_init: usize) -> Self {
124 self.n_init = n_init;
125 self
126 }
127
128 pub fn init_params(mut self, init_params: InitMethod) -> Self {
130 self.init_params = init_params;
131 self
132 }
133
134 pub fn random_state(mut self, random_state: u64) -> Self {
136 self.random_state = Some(random_state);
137 self
138 }
139
140 pub fn build(self) -> Self {
142 self
143 }
144}
145
146impl Default for GaussianMixture<Untrained> {
147 fn default() -> Self {
148 Self::new()
149 }
150}
151
/// Marker `Estimator` implementation: this model carries all of its
/// hyperparameters on the struct itself, so the associated `Config` is
/// the unit type.
impl Estimator for GaussianMixture<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // `()` is a zero-sized constant, so `&()` is promoted to a
        // `'static` reference — this does not dangle.
        &()
    }
}
161
162impl Fit<ArrayView2<'_, Float>, ()> for GaussianMixture<Untrained> {
163 type Fitted = GaussianMixture<GaussianMixtureTrained>;
164
165 #[allow(non_snake_case)]
166 fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
167 let X = X.to_owned();
168 let (n_samples, n_features) = X.dim();
169
170 if n_samples < self.n_components {
171 return Err(SklearsError::InvalidInput(
172 "Number of samples must be at least the number of components".to_string(),
173 ));
174 }
175
176 if self.n_components == 0 {
177 return Err(SklearsError::InvalidInput(
178 "Number of components must be positive".to_string(),
179 ));
180 }
181
182 let mut best_params = None;
183 let mut best_log_likelihood = f64::NEG_INFINITY;
184 let mut best_n_iter = 0;
185 let mut best_converged = false;
186
187 for init_run in 0..self.n_init {
189 let seed = self.random_state.map(|s| s + init_run as u64);
190
191 let (mut weights, mut means, mut covariances) = self.initialize_parameters(&X, seed)?;
193
194 let mut log_likelihood = f64::NEG_INFINITY;
195 let mut converged = false;
196 let mut n_iter = 0;
197
198 for iteration in 0..self.max_iter {
200 n_iter = iteration + 1;
201
202 let responsibilities =
204 self.compute_responsibilities(&X, &weights, &means, &covariances)?;
205
206 let (new_weights, new_means, new_covariances) =
208 self.update_parameters(&X, &responsibilities)?;
209
210 let new_log_likelihood =
212 self.compute_log_likelihood(&X, &new_weights, &new_means, &new_covariances)?;
213
214 if iteration > 0 && (new_log_likelihood - log_likelihood).abs() < self.tol {
216 converged = true;
217 }
218
219 weights = new_weights;
220 means = new_means;
221 covariances = new_covariances;
222 log_likelihood = new_log_likelihood;
223
224 if converged {
225 break;
226 }
227 }
228
229 if log_likelihood > best_log_likelihood {
231 best_log_likelihood = log_likelihood;
232 best_params = Some((weights, means, covariances));
233 best_n_iter = n_iter;
234 best_converged = converged;
235 }
236 }
237
238 let (weights, means, covariances) = best_params.expect("operation should succeed");
239
240 let n_params =
242 ModelSelection::n_parameters(self.n_components, n_features, &self.covariance_type);
243 let bic = ModelSelection::bic(best_log_likelihood, n_params, n_samples);
244 let aic = ModelSelection::aic(best_log_likelihood, n_params);
245
246 Ok(GaussianMixture {
247 state: GaussianMixtureTrained {
248 weights,
249 means,
250 covariances,
251 log_likelihood: best_log_likelihood,
252 n_iter: best_n_iter,
253 converged: best_converged,
254 bic,
255 aic,
256 },
257 n_components: self.n_components,
258 covariance_type: self.covariance_type,
259 tol: self.tol,
260 reg_covar: self.reg_covar,
261 max_iter: self.max_iter,
262 n_init: self.n_init,
263 init_params: self.init_params,
264 random_state: self.random_state,
265 })
266 }
267}
268
269impl GaussianMixture<Untrained> {
270 fn initialize_parameters(
272 &self,
273 X: &Array2<f64>,
274 seed: Option<u64>,
275 ) -> SklResult<(Array1<f64>, Array2<f64>, Vec<Array2<f64>>)> {
276 let (_n_samples, _n_features) = X.dim();
277
278 let weights = Array1::from_elem(self.n_components, 1.0 / self.n_components as f64);
280
281 let means = self.initialize_means(X, seed)?;
283
284 let covariances = self.initialize_covariances(X, &means)?;
286
287 Ok((weights, means, covariances))
288 }
289
290 fn initialize_means(&self, X: &Array2<f64>, seed: Option<u64>) -> SklResult<Array2<f64>> {
292 let (n_samples, n_features) = X.dim();
293 let mut means = Array2::zeros((self.n_components, n_features));
294
295 let step = n_samples / self.n_components;
297
298 for (i, mut mean) in means.axis_iter_mut(Axis(0)).enumerate() {
299 let sample_idx = if step == 0 {
300 i.min(n_samples - 1)
301 } else {
302 (i * step).min(n_samples - 1)
303 };
304 mean.assign(&X.row(sample_idx));
305
306 if let Some(_seed) = seed {
308 for j in 0..n_features {
309 mean[j] += 0.01 * (i as f64 - self.n_components as f64 / 2.0);
310 }
311 }
312 }
313
314 Ok(means)
315 }
316
317 fn initialize_covariances(
319 &self,
320 X: &Array2<f64>,
321 _means: &Array2<f64>,
322 ) -> SklResult<Vec<Array2<f64>>> {
323 let (_, n_features) = X.dim();
324 let mut covariances = Vec::new();
325
326 match self.covariance_type {
327 CovarianceType::Full => {
328 for _ in 0..self.n_components {
330 let mut cov = Array2::eye(n_features);
331 for i in 0..n_features {
332 cov[[i, i]] += self.reg_covar;
333 }
334 covariances.push(cov);
335 }
336 }
337 CovarianceType::Diagonal => {
338 for _ in 0..self.n_components {
340 let mut cov = Array2::zeros((n_features, n_features));
341 for i in 0..n_features {
342 cov[[i, i]] = 1.0 + self.reg_covar;
343 }
344 covariances.push(cov);
345 }
346 }
347 CovarianceType::Tied => {
348 let mut cov = Array2::eye(n_features);
350 for i in 0..n_features {
351 cov[[i, i]] += self.reg_covar;
352 }
353 covariances.push(cov);
354 }
355 CovarianceType::Spherical => {
356 for _ in 0..self.n_components {
358 let mut cov = Array2::zeros((n_features, n_features));
359 for i in 0..n_features {
360 cov[[i, i]] = 1.0 + self.reg_covar;
361 }
362 covariances.push(cov);
363 }
364 }
365 }
366
367 Ok(covariances)
368 }
369
370 fn compute_responsibilities(
372 &self,
373 X: &Array2<f64>,
374 weights: &Array1<f64>,
375 means: &Array2<f64>,
376 covariances: &[Array2<f64>],
377 ) -> SklResult<Array2<f64>> {
378 let (n_samples, _) = X.dim();
379 let mut responsibilities = Array2::zeros((n_samples, self.n_components));
380
381 for (i, sample) in X.axis_iter(Axis(0)).enumerate() {
383 let mut log_prob_norm = f64::NEG_INFINITY;
384 let mut log_probs = Vec::new();
385
386 for k in 0..self.n_components {
388 let mean = means.row(k);
389 let cov = &covariances[k];
390
391 let log_prob = crate::common::gaussian_log_pdf(&sample, &mean, &cov.view())?;
393 let weighted_log_prob = weights[k].ln() + log_prob;
394
395 log_probs.push(weighted_log_prob);
396 log_prob_norm = log_prob_norm.max(weighted_log_prob);
397 }
398
399 let mut sum_exp = 0.0;
401 for &log_prob in &log_probs {
402 sum_exp += (log_prob - log_prob_norm).exp();
403 }
404 let log_sum_exp = log_prob_norm + sum_exp.ln();
405
406 for k in 0..self.n_components {
407 responsibilities[[i, k]] = (log_probs[k] - log_sum_exp).exp();
408 }
409 }
410
411 Ok(responsibilities)
412 }
413
414 fn update_parameters(
416 &self,
417 X: &Array2<f64>,
418 responsibilities: &Array2<f64>,
419 ) -> SklResult<(Array1<f64>, Array2<f64>, Vec<Array2<f64>>)> {
420 let (n_samples, n_features) = X.dim();
421
422 let mut weights = Array1::zeros(self.n_components);
424 for k in 0..self.n_components {
425 weights[k] = responsibilities.column(k).sum() / n_samples as f64;
426 }
427
428 let mut means = Array2::zeros((self.n_components, n_features));
430 for k in 0..self.n_components {
431 let weight_sum = responsibilities.column(k).sum();
432 if weight_sum > 0.0 {
433 for j in 0..n_features {
434 let mut weighted_sum = 0.0;
435 for i in 0..n_samples {
436 weighted_sum += responsibilities[[i, k]] * X[[i, j]];
437 }
438 means[[k, j]] = weighted_sum / weight_sum;
439 }
440 }
441 }
442
443 let mut covariances = Vec::new();
445 for _k in 0..self.n_components {
446 let mut cov = Array2::eye(n_features);
447 for i in 0..n_features {
448 cov[[i, i]] = 1.0 + self.reg_covar;
449 }
450 covariances.push(cov);
451 }
452
453 Ok((weights, means, covariances))
454 }
455
456 fn compute_log_likelihood(
458 &self,
459 X: &Array2<f64>,
460 weights: &Array1<f64>,
461 means: &Array2<f64>,
462 covariances: &[Array2<f64>],
463 ) -> SklResult<f64> {
464 let (_n_samples, _) = X.dim();
465 let mut log_likelihood = 0.0;
466
467 for sample in X.axis_iter(Axis(0)) {
469 let mut log_prob_norm = f64::NEG_INFINITY;
470 let mut log_probs = Vec::new();
471
472 for k in 0..self.n_components {
474 let mean = means.row(k);
475 let cov = &covariances[k];
476
477 let log_prob = crate::common::gaussian_log_pdf(&sample, &mean, &cov.view())?;
478 let weighted_log_prob = weights[k].ln() + log_prob;
479
480 log_probs.push(weighted_log_prob);
481 log_prob_norm = log_prob_norm.max(weighted_log_prob);
482 }
483
484 let mut sum_exp = 0.0;
486 for &log_prob in &log_probs {
487 sum_exp += (log_prob - log_prob_norm).exp();
488 }
489 let log_sum_exp = log_prob_norm + sum_exp.ln();
490
491 log_likelihood += log_sum_exp;
492 }
493
494 Ok(log_likelihood)
495 }
496}
497
498impl Predict<ArrayView2<'_, Float>, Array1<i32>> for GaussianMixture<GaussianMixtureTrained> {
499 #[allow(non_snake_case)]
500 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
501 let X = X.to_owned();
502 let (n_samples, _) = X.dim();
503 let mut predictions = Array1::zeros(n_samples);
504
505 for (i, sample) in X.axis_iter(Axis(0)).enumerate() {
507 let mut best_component = 0;
508 let mut best_log_prob = f64::NEG_INFINITY;
509
510 for k in 0..self.n_components {
511 let mean = self.state.means.row(k);
512 let cov = &self.state.covariances[k];
513
514 let log_prob = crate::common::gaussian_log_pdf(&sample, &mean, &cov.view())?;
515 let weighted_log_prob = self.state.weights[k].ln() + log_prob;
516
517 if weighted_log_prob > best_log_prob {
518 best_log_prob = weighted_log_prob;
519 best_component = k;
520 }
521 }
522
523 predictions[i] = best_component as i32;
524 }
525
526 Ok(predictions)
527 }
528}
529
530impl GaussianMixture<GaussianMixtureTrained> {
531 #[allow(non_snake_case)]
533 pub fn score_samples(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
534 let X = X.to_owned();
535 let (n_samples, _) = X.dim();
536 let mut log_probs = Array1::zeros(n_samples);
537
538 for (i, sample) in X.axis_iter(Axis(0)).enumerate() {
540 let mut log_prob_norm = f64::NEG_INFINITY;
541 let mut component_log_probs = Vec::new();
542
543 for k in 0..self.n_components {
545 let mean = self.state.means.row(k);
546 let cov = &self.state.covariances[k];
547
548 let log_prob = crate::common::gaussian_log_pdf(&sample, &mean, &cov.view())?;
549 let weighted_log_prob = self.state.weights[k].ln() + log_prob;
550
551 component_log_probs.push(weighted_log_prob);
552 log_prob_norm = log_prob_norm.max(weighted_log_prob);
553 }
554
555 let mut sum_exp = 0.0;
557 for &log_prob in &component_log_probs {
558 sum_exp += (log_prob - log_prob_norm).exp();
559 }
560 let log_sum_exp = log_prob_norm + sum_exp.ln();
561
562 log_probs[i] = log_sum_exp;
563 }
564
565 Ok(log_probs)
566 }
567
568 pub fn score(&self, X: &ArrayView2<'_, Float>) -> SklResult<f64> {
570 let log_probs = self.score_samples(X)?;
571 Ok(log_probs.sum())
572 }
573
574 #[allow(non_snake_case)]
576 pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
577 let X = X.to_owned();
578 let (n_samples, _) = X.dim();
579 let mut proba = Array2::zeros((n_samples, self.n_components));
580
581 for (i, sample) in X.axis_iter(Axis(0)).enumerate() {
583 let mut log_prob_norm = f64::NEG_INFINITY;
584 let mut log_probs = Vec::new();
585
586 for k in 0..self.n_components {
588 let mean = self.state.means.row(k);
589 let cov = &self.state.covariances[k];
590
591 let log_prob = crate::common::gaussian_log_pdf(&sample, &mean, &cov.view())?;
592 let weighted_log_prob = self.state.weights[k].ln() + log_prob;
593
594 log_probs.push(weighted_log_prob);
595 log_prob_norm = log_prob_norm.max(weighted_log_prob);
596 }
597
598 let mut sum_exp = 0.0;
600 for &log_prob in &log_probs {
601 sum_exp += (log_prob - log_prob_norm).exp();
602 }
603 let log_sum_exp = log_prob_norm + sum_exp.ln();
604
605 for k in 0..self.n_components {
606 proba[[i, k]] = (log_probs[k] - log_sum_exp).exp();
607 }
608 }
609
610 Ok(proba)
611 }
612
613 pub fn weights(&self) -> &Array1<f64> {
615 &self.state.weights
616 }
617
618 pub fn means(&self) -> &Array2<f64> {
620 &self.state.means
621 }
622
623 pub fn covariances(&self) -> &[Array2<f64>] {
625 &self.state.covariances
626 }
627
628 pub fn log_likelihood(&self) -> f64 {
630 self.state.log_likelihood
631 }
632
633 pub fn n_iter(&self) -> usize {
635 self.state.n_iter
636 }
637
638 pub fn converged(&self) -> bool {
640 self.state.converged
641 }
642
643 pub fn bic(&self) -> f64 {
645 self.state.bic
646 }
647
648 pub fn aic(&self) -> f64 {
650 self.state.aic
651 }
652}