//! Accelerated EM algorithms for Gaussian mixture models. Currently the
//! Aitken extrapolation path is implemented in `fit`; the SQUAREM and
//! quasi-Newton selections fall back to the standard EM updates.

use crate::common::CovarianceType;
use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use scirs2_core::random::thread_rng;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::f64::consts::PI;

/// Acceleration scheme applied on top of the standard EM updates.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AccelerationType {
    /// Plain EM with no acceleration.
    None,
    /// Aitken delta-squared extrapolation of the parameter sequence.
    Aitken,
    /// Squared iterative methods (SQUAREM) extrapolation.
    SQUAREM,
    /// Quasi-Newton acceleration of the EM fixed-point map.
    QuasiNewton,
}

/// Quasi-Newton update rule used by [`QuasiNewtonGMM`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum QuasiNewtonMethod {
    /// Broyden-Fletcher-Goldfarb-Shanno update.
    BFGS,
    /// Limited-memory BFGS keeping `memory` correction pairs.
    LBFGS { memory: usize },
    /// Davidon-Fletcher-Powell update.
    DFP,
    /// Broyden's rank-one update.
    Broyden,
}

/// EM estimator for diagonal-covariance Gaussian mixtures with optional
/// acceleration of the parameter updates. Configure it through
/// [`AcceleratedEMBuilder`] and train it with [`Fit::fit`].
#[derive(Debug, Clone)]
pub struct AcceleratedEM<S = Untrained> {
    n_components: usize,
    acceleration: AccelerationType,
    covariance_type: CovarianceType,
    max_iter: usize,
    tol: f64,
    reg_covar: f64,
    random_state: Option<u64>,
    _phantom: std::marker::PhantomData<S>,
}

/// Summary of a fitted [`AcceleratedEM`] model.
#[derive(Debug, Clone)]
pub struct AcceleratedEMTrained {
    /// Mixing weights, one per component (normalized to sum to 1).
    pub weights: Array1<f64>,
    /// Component means, shape `(n_components, n_features)`.
    pub means: Array2<f64>,
    /// Per-component diagonal covariances, shape `(n_components, n_features)`.
    pub covariances: Array2<f64>,
    /// Log-likelihood recorded at each EM iteration.
    pub log_likelihood_history: Vec<f64>,
    /// Number of EM iterations actually performed.
    pub n_iter: usize,
    /// Whether the tolerance criterion was met before `max_iter`.
    pub converged: bool,
    /// Acceleration scheme that was used.
    pub acceleration: AccelerationType,
    /// Nominal speed-up attributed to the scheme (a fixed heuristic, not a
    /// measured timing).
    pub speedup_factor: f64,
}

/// Builder for [`AcceleratedEM`].
#[derive(Debug, Clone)]
pub struct AcceleratedEMBuilder {
    n_components: usize,
    acceleration: AccelerationType,
    covariance_type: CovarianceType,
    max_iter: usize,
    tol: f64,
    reg_covar: f64,
    random_state: Option<u64>,
}

impl AcceleratedEMBuilder {
    /// Creates a builder with defaults: one component, SQUAREM acceleration,
    /// diagonal covariance, 100 iterations, tolerance 1e-3.
    pub fn new() -> Self {
        Self {
            n_components: 1,
            acceleration: AccelerationType::SQUAREM,
            covariance_type: CovarianceType::Diagonal,
            max_iter: 100,
            tol: 1e-3,
            reg_covar: 1e-6,
            random_state: None,
        }
    }

    /// Sets the number of mixture components.
    pub fn n_components(mut self, n: usize) -> Self {
        self.n_components = n;
        self
    }

    /// Sets the acceleration scheme.
    pub fn acceleration(mut self, acc: AccelerationType) -> Self {
        self.acceleration = acc;
        self
    }

    /// Sets the covariance structure.
    pub fn covariance_type(mut self, cov_type: CovarianceType) -> Self {
        self.covariance_type = cov_type;
        self
    }

    /// Sets the maximum number of EM iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Sets the convergence tolerance on the log-likelihood improvement.
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Sets the regularization added to covariance diagonals.
    pub fn reg_covar(mut self, reg: f64) -> Self {
        self.reg_covar = reg;
        self
    }

    /// Sets the random seed (currently recorded but not yet used; see `fit`).
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Finalizes the configuration into an untrained estimator.
    pub fn build(self) -> AcceleratedEM<Untrained> {
        AcceleratedEM {
            n_components: self.n_components,
            acceleration: self.acceleration,
            covariance_type: self.covariance_type,
            max_iter: self.max_iter,
            tol: self.tol,
            reg_covar: self.reg_covar,
            random_state: self.random_state,
            _phantom: std::marker::PhantomData,
        }
    }
}

impl Default for AcceleratedEMBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl AcceleratedEM<Untrained> {
    /// Returns a builder with default configuration.
    pub fn builder() -> AcceleratedEMBuilder {
        AcceleratedEMBuilder::new()
    }

    /// Aitken step coefficient for three successive parameter vectors.
    /// Returns 0.0 when the denominator is numerically degenerate, which
    /// disables the extrapolation for that iteration.
    fn aitken_coefficient(
        theta_old: &Array1<f64>,
        theta_curr: &Array1<f64>,
        theta_new: &Array1<f64>,
    ) -> f64 {
        let diff1 = theta_curr - theta_old;
        let diff2 = theta_new - theta_curr;
        let diff_diff = &diff2 - &diff1;

        let numerator = (&diff1 * &diff1).sum();
        let denominator = (&diff1 * &diff_diff).sum();

        if denominator.abs() < 1e-10 {
            0.0
        } else {
            -numerator / denominator
        }
    }
}
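
// Relationship used by the Aitken branch of `fit` below: with d1 = θ1 - θ0
// and d2 = θ2 - θ1, the coefficient above is
//
//     α = -⟨d1, d1⟩ / ⟨d1, d2 - d1⟩,
//
// and the latest update is rescaled by 1 / (1 - α) whenever α ∈ (0, 1).
// For a geometrically converging sequence with d2 = ρ·d1 this gives
// α = 1 / (1 - ρ), so the (0, 1) guard effectively limits the extrapolation
// to oscillating update directions (ρ < 0).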

impl Estimator for AcceleratedEM<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for AcceleratedEM<Untrained> {
    type Fitted = AcceleratedEM<AcceleratedEMTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X_owned = X.to_owned();
        let (n_samples, n_features) = X_owned.dim();

        if n_samples < self.n_components {
            return Err(SklearsError::InvalidInput(
                "Number of samples must be >= number of components".to_string(),
            ));
        }

        // NOTE: `random_state` is recorded but not yet used to seed the RNG,
        // so initialization is not reproducible across runs.
        let mut rng = thread_rng();

        // Initialize means from distinct randomly chosen samples.
        let mut means = Array2::zeros((self.n_components, n_features));
        let mut used_indices = Vec::new();
        for k in 0..self.n_components {
            let idx = loop {
                let candidate = rng.gen_range(0..n_samples);
                if !used_indices.contains(&candidate) {
                    used_indices.push(candidate);
                    break candidate;
                }
            };
            means.row_mut(k).assign(&X_owned.row(idx));
        }

        // Uniform initial weights; diagonal covariances stored one row per
        // component so each component keeps its own scale.
        let mut weights = Array1::from_elem(self.n_components, 1.0 / self.n_components as f64);
        let mut covariances =
            Array2::<f64>::from_elem((self.n_components, n_features), 1.0 + self.reg_covar);

        let mut log_likelihood_history = Vec::new();
        let mut converged = false;

        // Parameter vectors from the two previous iterations, kept for
        // Aitken extrapolation.
        let mut prev_params: Option<Array1<f64>> = None;
        let mut prev_prev_params: Option<Array1<f64>> = None;

        for iter in 0..self.max_iter {
            // E-step: responsibilities via the log-sum-exp trick, while
            // accumulating the log-likelihood of the current parameters.
            let mut responsibilities = Array2::zeros((n_samples, self.n_components));
            let mut log_lik = 0.0;

            for i in 0..n_samples {
                let x = X_owned.row(i);
                let mut log_probs = Vec::new();

                for k in 0..self.n_components {
                    let mean = means.row(k);
                    let diff = &x.to_owned() - &mean.to_owned();

                    // Mahalanobis distance under component k's diagonal covariance.
                    let mahal = diff
                        .iter()
                        .zip(covariances.row(k).iter())
                        .map(|(d, c): (&f64, &f64)| d * d / c.max(self.reg_covar))
                        .sum::<f64>();

                    let log_det = covariances
                        .row(k)
                        .iter()
                        .map(|c| c.max(self.reg_covar).ln())
                        .sum::<f64>();

                    let log_prob = weights[k].ln()
                        - 0.5 * (n_features as f64 * (2.0 * PI).ln() + log_det)
                        - 0.5 * mahal;

                    log_probs.push(log_prob);
                }

                let max_log = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                let sum_exp: f64 = log_probs.iter().map(|&lp| (lp - max_log).exp()).sum();

                // Per-sample log-likelihood: log Σ_k exp(log_prob_k).
                log_lik += max_log + sum_exp.ln();

                for k in 0..self.n_components {
                    responsibilities[[i, k]] =
                        ((log_probs[k] - max_log).exp() / sum_exp).max(1e-10);
                }
            }

            // M-step: re-estimate weights, means, and per-component diagonal
            // covariances from the responsibilities.
            for k in 0..self.n_components {
                let resps = responsibilities.column(k);
                let nk = resps.sum().max(1e-10);

                weights[k] = nk / n_samples as f64;

                let mut new_mean = Array1::zeros(n_features);
                for i in 0..n_samples {
                    new_mean += &(X_owned.row(i).to_owned() * resps[i]);
                }
                new_mean /= nk;
                means.row_mut(k).assign(&new_mean);

                let mut new_cov = Array1::zeros(n_features);
                for i in 0..n_samples {
                    let diff = &X_owned.row(i).to_owned() - &new_mean;
                    new_cov += &(diff.mapv(|x| x * x) * resps[i]);
                }
                new_cov = new_cov / nk + Array1::from_elem(n_features, self.reg_covar);
                covariances.row_mut(k).assign(&new_cov);
            }

            weights /= weights.sum();

            // Aitken acceleration: extrapolate the flattened means using the
            // two previous parameter vectors. The history is tracked on every
            // iteration so that two previous vectors are available from
            // iteration 2 onward.
            if self.acceleration == AccelerationType::Aitken {
                let current_params = means.iter().cloned().collect::<Array1<f64>>();

                if iter >= 2 {
                    if let (Some(prev), Some(prev_prev)) = (&prev_params, &prev_prev_params) {
                        let alpha = Self::aitken_coefficient(prev_prev, prev, &current_params);
                        if alpha > 0.0 && alpha < 1.0 {
                            let accelerated =
                                prev + &((&current_params - prev) * (1.0 / (1.0 - alpha)));
                            let mut idx = 0;
                            for k in 0..self.n_components {
                                for j in 0..n_features {
                                    if idx < accelerated.len() {
                                        means[[k, j]] = accelerated[idx];
                                        idx += 1;
                                    }
                                }
                            }
                        }
                    }
                }

                prev_prev_params = prev_params.clone();
                prev_params = Some(current_params);
            }

            // Log-likelihood of the parameters used in this iteration's
            // E-step, accumulated sample-by-sample above.
            log_likelihood_history.push(log_lik);

            // Stop once the log-likelihood improvement falls below `tol`.
            if iter > 0 {
                let improvement = (log_lik - log_likelihood_history[iter - 1]).abs();
                if improvement < self.tol {
                    converged = true;
                    break;
                }
            }
        }

        // Nominal speed-up attributed to each scheme. These are fixed
        // heuristic constants, not measured timings.
        let speedup_factor = match self.acceleration {
            AccelerationType::None => 1.0,
            AccelerationType::Aitken => 1.5,
            AccelerationType::SQUAREM => 2.0,
            AccelerationType::QuasiNewton => 2.5,
        };

        let n_iter = log_likelihood_history.len();
        let trained_state = AcceleratedEMTrained {
            weights,
            means,
            covariances,
            log_likelihood_history,
            n_iter,
            converged,
            acceleration: self.acceleration,
            speedup_factor,
        };

        Ok(AcceleratedEM {
            n_components: self.n_components,
            acceleration: self.acceleration,
            covariance_type: self.covariance_type,
            max_iter: self.max_iter,
            tol: self.tol,
            reg_covar: self.reg_covar,
            random_state: self.random_state,
            _phantom: std::marker::PhantomData,
        }
        .with_state(trained_state))
    }
}
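
// Usage sketch (hypothetical data; note that the fitted value currently only
// carries the type-state marker, so the numeric results are not retrievable
// from it; see `with_state` below):
//
//     use scirs2_core::ndarray::array;
//
//     let X = array![[0.0, 0.0], [0.1, 0.2], [5.0, 5.0], [5.1, 5.2]];
//     let fitted = AcceleratedEM::builder()
//         .n_components(2)
//         .acceleration(AccelerationType::Aitken)
//         .build()
//         .fit(&X.view(), &())?;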

impl AcceleratedEM<Untrained> {
    /// Transitions the estimator to the trained type-state.
    ///
    /// NOTE: the trained parameters are currently discarded here; `S` is a
    /// phantom marker only, so the fitted estimator does not retain the
    /// [`AcceleratedEMTrained`] values. Retaining them would require storing
    /// the state on the struct.
    fn with_state(self, _state: AcceleratedEMTrained) -> AcceleratedEM<AcceleratedEMTrained> {
        AcceleratedEM {
            n_components: self.n_components,
            acceleration: self.acceleration,
            covariance_type: self.covariance_type,
            max_iter: self.max_iter,
            tol: self.tol,
            reg_covar: self.reg_covar,
            random_state: self.random_state,
            _phantom: std::marker::PhantomData,
        }
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<usize>> for AcceleratedEM<AcceleratedEMTrained> {
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<usize>> {
        // Placeholder: component parameters are not retained on the typed
        // struct (see `with_state`), so every sample is assigned component 0.
        let (n_samples, _) = X.dim();
        Ok(Array1::zeros(n_samples))
    }
}

/// Quasi-Newton optimizer for GMM parameters (configuration only; training
/// is not implemented in this module yet).
#[derive(Debug, Clone)]
pub struct QuasiNewtonGMM<S = Untrained> {
    n_components: usize,
    method: QuasiNewtonMethod,
    covariance_type: CovarianceType,
    max_iter: usize,
    tol: f64,
    reg_covar: f64,
    random_state: Option<u64>,
    _phantom: std::marker::PhantomData<S>,
}

/// State produced by fitting a [`QuasiNewtonGMM`] model.
#[derive(Debug, Clone)]
pub struct QuasiNewtonGMMTrained {
    pub weights: Array1<f64>,
    pub means: Array2<f64>,
    pub covariances: Array2<f64>,
    pub log_likelihood_history: Vec<f64>,
    pub n_iter: usize,
    pub converged: bool,
}

/// Builder for [`QuasiNewtonGMM`].
#[derive(Debug, Clone)]
pub struct QuasiNewtonGMMBuilder {
    n_components: usize,
    method: QuasiNewtonMethod,
    covariance_type: CovarianceType,
    max_iter: usize,
    tol: f64,
    reg_covar: f64,
    random_state: Option<u64>,
}

impl QuasiNewtonGMMBuilder {
    pub fn new() -> Self {
        Self {
            n_components: 1,
            method: QuasiNewtonMethod::LBFGS { memory: 10 },
            covariance_type: CovarianceType::Diagonal,
            max_iter: 100,
            tol: 1e-3,
            reg_covar: 1e-6,
            random_state: None,
        }
    }

    pub fn n_components(mut self, n: usize) -> Self {
        self.n_components = n;
        self
    }

    pub fn method(mut self, m: QuasiNewtonMethod) -> Self {
        self.method = m;
        self
    }

    pub fn build(self) -> QuasiNewtonGMM<Untrained> {
        QuasiNewtonGMM {
            n_components: self.n_components,
            method: self.method,
            covariance_type: self.covariance_type,
            max_iter: self.max_iter,
            tol: self.tol,
            reg_covar: self.reg_covar,
            random_state: self.random_state,
            _phantom: std::marker::PhantomData,
        }
    }
}

impl Default for QuasiNewtonGMMBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl QuasiNewtonGMM<Untrained> {
    pub fn builder() -> QuasiNewtonGMMBuilder {
        QuasiNewtonGMMBuilder::new()
    }
}
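
// Only the configuration surface of `QuasiNewtonGMM` is present in this
// module; no `Fit` implementation is provided here yet. Builder sketch:
//
//     let gmm = QuasiNewtonGMM::builder()
//         .n_components(3)
//         .method(QuasiNewtonMethod::LBFGS { memory: 8 })
//         .build();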

/// Natural-gradient trainer for GMMs (configuration only in this module).
#[derive(Debug, Clone)]
pub struct NaturalGradientGMM<S = Untrained> {
    n_components: usize,
    learning_rate: f64,
    use_fisher: bool,
    _phantom: std::marker::PhantomData<S>,
}

/// State produced by fitting a [`NaturalGradientGMM`] model.
#[derive(Debug, Clone)]
pub struct NaturalGradientGMMTrained {
    pub weights: Array1<f64>,
    pub means: Array2<f64>,
    pub fisher_info: Array2<f64>,
}

/// Builder for [`NaturalGradientGMM`].
#[derive(Debug, Clone)]
pub struct NaturalGradientGMMBuilder {
    n_components: usize,
    learning_rate: f64,
    use_fisher: bool,
}

impl NaturalGradientGMMBuilder {
    pub fn new() -> Self {
        Self {
            n_components: 1,
            learning_rate: 0.01,
            use_fisher: true,
        }
    }

    pub fn n_components(mut self, n: usize) -> Self {
        self.n_components = n;
        self
    }

    pub fn learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    pub fn use_fisher(mut self, use_f: bool) -> Self {
        self.use_fisher = use_f;
        self
    }

    pub fn build(self) -> NaturalGradientGMM<Untrained> {
        NaturalGradientGMM {
            n_components: self.n_components,
            learning_rate: self.learning_rate,
            use_fisher: self.use_fisher,
            _phantom: std::marker::PhantomData,
        }
    }
}

impl Default for NaturalGradientGMMBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl NaturalGradientGMM<Untrained> {
    pub fn builder() -> NaturalGradientGMMBuilder {
        NaturalGradientGMMBuilder::new()
    }
}
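
// Intended update rule for `NaturalGradientGMM` (inferred from its fields;
// training is not implemented in this module yet): natural-gradient ascent
//
//     θ ← θ + η · F⁻¹ ∇θ ℓ(θ),
//
// where η is `learning_rate` and F is the Fisher information matrix; with
// `use_fisher == false` this reduces to plain gradient ascent (F = identity).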

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_accelerated_em_builder() {
        let model = AcceleratedEM::builder()
            .n_components(3)
            .acceleration(AccelerationType::SQUAREM)
            .max_iter(50)
            .build();

        assert_eq!(model.n_components, 3);
        assert_eq!(model.acceleration, AccelerationType::SQUAREM);
        assert_eq!(model.max_iter, 50);
    }

    #[test]
    fn test_acceleration_types() {
        let types = vec![
            AccelerationType::None,
            AccelerationType::Aitken,
            AccelerationType::SQUAREM,
            AccelerationType::QuasiNewton,
        ];

        for acc_type in types {
            let model = AcceleratedEM::builder()
                .n_components(2)
                .acceleration(acc_type)
                .build();
            assert_eq!(model.acceleration, acc_type);
        }
    }

    #[test]
    fn test_accelerated_em_fit() {
        let X = array![[1.0, 2.0], [1.5, 2.5], [10.0, 11.0], [10.5, 11.5]];

        let model = AcceleratedEM::builder()
            .n_components(2)
            .acceleration(AccelerationType::None)
            .max_iter(20)
            .build();

        let result = model.fit(&X.view(), &());
        assert!(result.is_ok());
    }

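    // Added check (hypothetical data): the Aitken extrapolation path should
    // also fit cleanly on well-separated clusters.
    #[test]
    fn test_accelerated_em_fit_aitken() {
        let X = array![[0.0, 0.0], [0.2, 0.1], [8.0, 8.0], [8.2, 8.1]];

        let model = AcceleratedEM::builder()
            .n_components(2)
            .acceleration(AccelerationType::Aitken)
            .max_iter(30)
            .build();

        assert!(model.fit(&X.view(), &()).is_ok());
    }
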
    #[test]
    fn test_quasi_newton_gmm_builder() {
        let model = QuasiNewtonGMM::builder()
            .n_components(2)
            .method(QuasiNewtonMethod::LBFGS { memory: 5 })
            .build();

        assert_eq!(model.n_components, 2);
        assert!(matches!(
            model.method,
            QuasiNewtonMethod::LBFGS { memory: 5 }
        ));
    }

    #[test]
    fn test_quasi_newton_methods() {
        let methods = vec![
            QuasiNewtonMethod::BFGS,
            QuasiNewtonMethod::LBFGS { memory: 10 },
            QuasiNewtonMethod::DFP,
            QuasiNewtonMethod::Broyden,
        ];

        for method in methods {
            let model = QuasiNewtonGMM::builder()
                .n_components(2)
                .method(method)
                .build();
            assert_eq!(model.method, method);
        }
    }

    #[test]
    fn test_natural_gradient_gmm_builder() {
        let model = NaturalGradientGMM::builder()
            .n_components(3)
            .learning_rate(0.05)
            .use_fisher(false)
            .build();

        assert_eq!(model.n_components, 3);
        assert_eq!(model.learning_rate, 0.05);
        assert!(!model.use_fisher);
    }

    #[test]
    fn test_aitken_coefficient() {
        let theta_old = array![1.0, 2.0, 3.0];
        let theta_curr = array![1.5, 2.5, 3.5];
        let theta_new = array![1.8, 2.8, 3.8];

        // d1 = [0.5; 3] and d2 = [0.3; 3], so
        // alpha = -(3 * 0.25) / (3 * 0.5 * -0.2) = 2.5.
        let alpha = AcceleratedEM::aitken_coefficient(&theta_old, &theta_curr, &theta_new);
        assert!((alpha - 2.5).abs() < 1e-9);
    }
}
709}