1use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView1, ArrayView2, Axis};
9use scirs2_core::random::{thread_rng, Rng, SeedableRng};
10use sklears_core::{
11 error::{Result as SklResult, SklearsError},
12 traits::{Estimator, Fit, Predict, Trained, Untrained},
13};
14use std::f64::consts::PI;
15
16use crate::common::{CovarianceType, InitMethod, ModelSelection};
17
/// Family of structured variational approximation used during inference.
///
/// The variant determines the side length of the auxiliary structured
/// covariance block allocated per component (see
/// `initialize_structured_covariance`) and which heuristic coupling factors
/// the coordinate updates apply.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum StructuredFamily {
    /// Structure over mixture weights and assignments
    /// (block side length: n_components + 1).
    WeightAssignment,
    /// Structure coupling each component mean with its precision
    /// (block side length: d + d^2 for d features).
    MeanPrecision,
    /// Per-component structure over weight, mean, and precision
    /// (block side length: 1 + d + d^2).
    ComponentWise,
    /// Block-diagonal structure over feature blocks
    /// (block side length: 2 * d).
    BlockDiagonal,
}
30
/// Builder/estimator for a Gaussian mixture fitted with *structured*
/// variational inference: the configured [`StructuredFamily`] selects which
/// posterior dependencies are retained during coordinate ascent.
#[derive(Debug, Clone)]
pub struct StructuredVariationalGMM<S = Untrained> {
    // Type-state marker (Untrained / Trained).
    state: S,
    // Number of mixture components K.
    n_components: usize,
    // Which posterior dependencies are retained during inference.
    structured_family: StructuredFamily,
    // Covariance parameterization (full, diagonal, tied, spherical).
    covariance_type: CovarianceType,
    // Convergence threshold on the change of the evidence lower bound.
    tol: f64,
    // Maximum number of E/M iterations per restart.
    max_iter: usize,
    // Optional RNG seed for reproducible initialization.
    random_state: Option<u64>,
    // Ridge added to covariance diagonals for numerical stability.
    reg_covar: f64,
    // Dirichlet prior concentration for the mixture weights.
    weight_concentration: f64,
    // Prior precision (kappa) on the component means.
    mean_precision: f64,
    // Prior degrees of freedom for the Wishart over precisions.
    degrees_of_freedom: f64,
    // Strategy used to seed the component means.
    init_method: InitMethod,
    // Number of random restarts; the best ELBO wins.
    n_init: usize,
    // Coordinate-ascent sub-steps per component per M-step.
    max_coord_steps: usize,
    // Damping factor in (0, 1]: fraction of each new value blended in.
    damping: f64,
}
89
/// A fitted structured variational GMM holding the posterior parameters of
/// the restart with the highest evidence lower bound.
#[derive(Debug, Clone)]
pub struct StructuredVariationalGMMTrained {
    // Number of mixture components K.
    n_components: usize,
    // Structured family used during inference.
    structured_family: StructuredFamily,
    // Covariance parameterization used during inference.
    covariance_type: CovarianceType,
    // Dirichlet posterior concentrations, one per component.
    weight_concentration: Array1<f64>,
    // Posterior mean-precision (kappa) per component.
    mean_precision: Array1<f64>,
    // Posterior component means, shape (K, n_features).
    mean_values: Array2<f64>,
    // Posterior precision matrices, shape (K, n_features, n_features).
    precision_values: Array3<f64>,
    // Posterior Wishart degrees of freedom per component.
    degrees_of_freedom: Array1<f64>,
    // Posterior Wishart scale matrices, shape (K, n_features, n_features).
    scale_matrices: Array3<f64>,
    // Auxiliary structured covariance blocks, one per component.
    structured_cov: Array3<f64>,
    // Number of training samples seen at fit time.
    n_samples: usize,
    // Number of features expected by predict/score.
    n_features: usize,
    // Final evidence lower bound of the best restart.
    lower_bound: f64,
    // Training responsibilities, shape (n_samples, K).
    responsibilities: Array2<f64>,
    // AIC/BIC summary computed from the ELBO.
    model_selection: ModelSelection,
}
124
125impl StructuredVariationalGMM<Untrained> {
126 pub fn new() -> Self {
128 Self {
129 state: Untrained,
130 n_components: 2,
131 structured_family: StructuredFamily::MeanPrecision,
132 covariance_type: CovarianceType::Full,
133 tol: 1e-3,
134 max_iter: 100,
135 random_state: None,
136 reg_covar: 1e-6,
137 weight_concentration: 1.0,
138 mean_precision: 1.0,
139 degrees_of_freedom: 1.0,
140 init_method: InitMethod::KMeansPlus,
141 n_init: 1,
142 max_coord_steps: 10,
143 damping: 0.5,
144 }
145 }
146
147 pub fn n_components(mut self, n_components: usize) -> Self {
149 self.n_components = n_components;
150 self
151 }
152
153 pub fn structured_family(mut self, family: StructuredFamily) -> Self {
155 self.structured_family = family;
156 self
157 }
158
159 pub fn covariance_type(mut self, covariance_type: CovarianceType) -> Self {
161 self.covariance_type = covariance_type;
162 self
163 }
164
165 pub fn tol(mut self, tol: f64) -> Self {
167 self.tol = tol;
168 self
169 }
170
171 pub fn max_iter(mut self, max_iter: usize) -> Self {
173 self.max_iter = max_iter;
174 self
175 }
176
177 pub fn random_state(mut self, random_state: u64) -> Self {
179 self.random_state = Some(random_state);
180 self
181 }
182
183 pub fn reg_covar(mut self, reg_covar: f64) -> Self {
185 self.reg_covar = reg_covar;
186 self
187 }
188
189 pub fn weight_concentration(mut self, weight_concentration: f64) -> Self {
191 self.weight_concentration = weight_concentration;
192 self
193 }
194
195 pub fn mean_precision(mut self, mean_precision: f64) -> Self {
197 self.mean_precision = mean_precision;
198 self
199 }
200
201 pub fn degrees_of_freedom(mut self, degrees_of_freedom: f64) -> Self {
203 self.degrees_of_freedom = degrees_of_freedom;
204 self
205 }
206
207 pub fn init_method(mut self, init_method: InitMethod) -> Self {
209 self.init_method = init_method;
210 self
211 }
212
213 pub fn n_init(mut self, n_init: usize) -> Self {
215 self.n_init = n_init;
216 self
217 }
218
219 pub fn max_coord_steps(mut self, max_coord_steps: usize) -> Self {
221 self.max_coord_steps = max_coord_steps;
222 self
223 }
224
225 pub fn damping(mut self, damping: f64) -> Self {
227 self.damping = damping;
228 self
229 }
230}
231
232impl Default for StructuredVariationalGMM<Untrained> {
233 fn default() -> Self {
234 Self::new()
235 }
236}
237
impl Estimator<Untrained> for StructuredVariationalGMM<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = f64;

    /// This estimator carries no separate configuration object.
    fn config(&self) -> &Self::Config {
        &()
    }
}
247
248impl Fit<ArrayView2<'_, f64>, ()> for StructuredVariationalGMM<Untrained> {
249 type Fitted = StructuredVariationalGMMTrained;
250
251 fn fit(self, X: &ArrayView2<f64>, _y: &()) -> SklResult<Self::Fitted> {
252 let (n_samples, _n_features) = X.dim();
253
254 if n_samples < self.n_components {
255 return Err(SklearsError::InvalidInput(
256 "Number of samples must be greater than number of components".to_string(),
257 ));
258 }
259
260 let mut rng = match self.random_state {
262 Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
263 None => scirs2_core::random::rngs::StdRng::from_rng(&mut thread_rng()),
264 };
265
266 let mut best_model = None;
267 let mut best_lower_bound = f64::NEG_INFINITY;
268
269 for _ in 0..self.n_init {
270 let (
272 weight_concentration,
273 mean_precision,
274 mean_values,
275 precision_values,
276 degrees_of_freedom,
277 scale_matrices,
278 structured_cov,
279 ) = self.initialize_parameters(X, &mut rng)?;
280
281 let result = self.run_structured_inference(
283 X,
284 weight_concentration,
285 mean_precision,
286 mean_values,
287 precision_values,
288 degrees_of_freedom,
289 scale_matrices,
290 structured_cov,
291 &mut rng,
292 )?;
293
294 if result.lower_bound > best_lower_bound {
295 best_lower_bound = result.lower_bound;
296 best_model = Some(result);
297 }
298 }
299
300 match best_model {
301 Some(model) => Ok(model),
302 None => Err(SklearsError::ConvergenceError {
303 iterations: self.max_iter,
304 }),
305 }
306 }
307}
308
309impl StructuredVariationalGMM<Untrained> {
    /// Builds the initial variational parameters for one restart.
    ///
    /// Returns, in order: Dirichlet weight concentrations, mean precisions,
    /// component means, precision matrices, Wishart degrees of freedom,
    /// Wishart scale matrices, and the structured covariance blocks.
    fn initialize_parameters(
        &self,
        X: &ArrayView2<f64>,
        rng: &mut scirs2_core::random::rngs::StdRng,
    ) -> SklResult<(
        Array1<f64>,
        Array1<f64>,
        Array2<f64>,
        Array3<f64>,
        Array1<f64>,
        Array3<f64>,
        Array3<f64>,
    )> {
        let (_n_samples, n_features) = X.dim();

        // Symmetric Dirichlet prior on the weights.
        let weight_concentration = Array1::from_elem(self.n_components, self.weight_concentration);

        // Identical prior precision kappa_0 for every component mean.
        let mean_precision = Array1::from_elem(self.n_components, self.mean_precision);

        // k-means++-style seeding of the component means.
        let mean_values = self.initialize_means(X, rng)?;

        // Scaled-identity precisions derived from the data variance.
        let precision_values = self.initialize_precisions(X, n_features)?;

        // Degrees of freedom offset by the dimensionality, as needed for a
        // proper Wishart distribution.
        let degrees_of_freedom = Array1::from_elem(
            self.n_components,
            self.degrees_of_freedom + n_features as f64,
        );

        // All scale matrices start at the empirical covariance of the data.
        let scale_matrices = self.initialize_scale_matrices(X, n_features)?;

        // Family-dependent auxiliary covariance blocks.
        let structured_cov = self.initialize_structured_covariance(n_features)?;

        Ok((
            weight_concentration,
            mean_precision,
            mean_values,
            precision_values,
            degrees_of_freedom,
            scale_matrices,
            structured_cov,
        ))
    }
360
361 fn initialize_means(
363 &self,
364 X: &ArrayView2<f64>,
365 rng: &mut scirs2_core::random::rngs::StdRng,
366 ) -> SklResult<Array2<f64>> {
367 let (n_samples, n_features) = X.dim();
368 let mut means = Array2::zeros((self.n_components, n_features));
369
370 let first_idx = rng.gen_range(0..n_samples);
372 means
373 .slice_mut(s![0, ..])
374 .assign(&X.slice(s![first_idx, ..]));
375
376 for k in 1..self.n_components {
378 let mut distances = Array1::zeros(n_samples);
379
380 for i in 0..n_samples {
381 let mut min_dist = f64::INFINITY;
382 for j in 0..k {
383 let dist = self.squared_distance(&X.slice(s![i, ..]), &means.slice(s![j, ..]));
384 if dist < min_dist {
385 min_dist = dist;
386 }
387 }
388 distances[i] = min_dist;
389 }
390
391 let total_dist: f64 = distances.sum();
393 let mut prob = rng.gen::<f64>() * total_dist;
394 let mut chosen_idx = 0;
395
396 for i in 0..n_samples {
397 prob -= distances[i];
398 if prob <= 0.0 {
399 chosen_idx = i;
400 break;
401 }
402 }
403
404 means
405 .slice_mut(s![k, ..])
406 .assign(&X.slice(s![chosen_idx, ..]));
407 }
408
409 Ok(means)
410 }
411
412 fn initialize_precisions(
414 &self,
415 X: &ArrayView2<f64>,
416 n_features: usize,
417 ) -> SklResult<Array3<f64>> {
418 let mut precisions = Array3::zeros((self.n_components, n_features, n_features));
419
420 let data_var = X.var_axis(Axis(0), 0.0);
422 let avg_var = data_var.mean().unwrap_or(1.0);
423
424 for k in 0..self.n_components {
425 let mut precision = Array2::eye(n_features);
426 precision *= 1.0 / (avg_var + self.reg_covar);
427 precisions.slice_mut(s![k, .., ..]).assign(&precision);
428 }
429
430 Ok(precisions)
431 }
432
433 fn initialize_scale_matrices(
435 &self,
436 X: &ArrayView2<f64>,
437 n_features: usize,
438 ) -> SklResult<Array3<f64>> {
439 let mut scale_matrices = Array3::zeros((self.n_components, n_features, n_features));
440
441 let cov = self.compute_empirical_covariance(X)?;
443
444 for k in 0..self.n_components {
445 scale_matrices.slice_mut(s![k, .., ..]).assign(&cov);
446 }
447
448 Ok(scale_matrices)
449 }
450
451 fn initialize_structured_covariance(&self, n_features: usize) -> SklResult<Array3<f64>> {
453 let size = match self.structured_family {
454 StructuredFamily::WeightAssignment => self.n_components + 1,
455 StructuredFamily::MeanPrecision => n_features + n_features * n_features,
456 StructuredFamily::ComponentWise => 1 + n_features + n_features * n_features,
457 StructuredFamily::BlockDiagonal => 2 * n_features,
458 };
459
460 let mut structured_cov = Array3::zeros((self.n_components, size, size));
461
462 for k in 0..self.n_components {
464 let mut cov = Array2::eye(size);
465 cov *= 0.1; structured_cov.slice_mut(s![k, .., ..]).assign(&cov);
467 }
468
469 Ok(structured_cov)
470 }
471
    /// Empirical covariance of the whole data set (population normalization,
    /// i.e. dividing by n), with `reg_covar` added to the diagonal.
    fn compute_empirical_covariance(&self, X: &ArrayView2<f64>) -> SklResult<Array2<f64>> {
        let (n_samples, n_features) = X.dim();

        let mean = X.mean_axis(Axis(0)).unwrap();

        // Accumulate the sum of outer products of centred samples.
        let mut cov = Array2::zeros((n_features, n_features));
        for i in 0..n_samples {
            let diff = &X.slice(s![i, ..]) - &mean;
            for j in 0..n_features {
                for k in 0..n_features {
                    cov[[j, k]] += diff[j] * diff[k];
                }
            }
        }

        cov /= n_samples as f64;

        // Ridge regularization keeps the matrix safely positive definite.
        for i in 0..n_features {
            cov[[i, i]] += self.reg_covar;
        }

        Ok(cov)
    }
499
    /// Squared Euclidean distance ||a - b||^2 between two feature vectors.
    fn squared_distance(&self, a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> f64 {
        let diff = a - b;
        diff.dot(&diff)
    }
505
    /// Runs coordinate-ascent variational inference until the evidence lower
    /// bound changes by less than `tol` or `max_iter` is reached, then packs
    /// the posterior into a trained model.
    fn run_structured_inference(
        &self,
        X: &ArrayView2<f64>,
        mut weight_concentration: Array1<f64>,
        mut mean_precision: Array1<f64>,
        mut mean_values: Array2<f64>,
        mut precision_values: Array3<f64>,
        mut degrees_of_freedom: Array1<f64>,
        mut scale_matrices: Array3<f64>,
        mut structured_cov: Array3<f64>,
        rng: &mut scirs2_core::random::rngs::StdRng,
    ) -> SklResult<StructuredVariationalGMMTrained> {
        let (n_samples, n_features) = X.dim();
        let mut responsibilities = Array2::zeros((n_samples, self.n_components));

        let mut prev_lower_bound = f64::NEG_INFINITY;
        let mut lower_bound = f64::NEG_INFINITY;

        for _iter in 0..self.max_iter {
            // E-step: refresh responsibilities from the current posterior.
            self.structured_e_step(
                X,
                &weight_concentration,
                &mean_values,
                &precision_values,
                &degrees_of_freedom,
                &scale_matrices,
                &structured_cov,
                &mut responsibilities,
            )?;

            // M-step: coordinate updates of the variational parameters.
            self.structured_m_step(
                X,
                &responsibilities,
                &mut weight_concentration,
                &mut mean_precision,
                &mut mean_values,
                &mut precision_values,
                &mut degrees_of_freedom,
                &mut scale_matrices,
                &mut structured_cov,
                rng,
            )?;

            lower_bound = self.compute_structured_lower_bound(
                X,
                &responsibilities,
                &weight_concentration,
                &mean_precision,
                &mean_values,
                &precision_values,
                &degrees_of_freedom,
                &scale_matrices,
                &structured_cov,
            )?;

            // Converged when the bound stops moving.
            if (lower_bound - prev_lower_bound).abs() < self.tol {
                break;
            }

            prev_lower_bound = lower_bound;
        }

        // NOTE(review): AIC/BIC use the ELBO in place of the maximized
        // log-likelihood, so both criteria are approximations here.
        let n_params = self.count_parameters(n_features);
        let model_selection = ModelSelection {
            aic: -2.0 * lower_bound + 2.0 * n_params as f64,
            bic: -2.0 * lower_bound + (n_params as f64) * (n_samples as f64).ln(),
            log_likelihood: lower_bound,
            n_parameters: n_params,
        };

        Ok(StructuredVariationalGMMTrained {
            n_components: self.n_components,
            structured_family: self.structured_family,
            covariance_type: self.covariance_type.clone(),
            weight_concentration,
            mean_precision,
            mean_values,
            precision_values,
            degrees_of_freedom,
            scale_matrices,
            structured_cov,
            n_samples,
            n_features,
            lower_bound,
            responsibilities,
            model_selection,
        })
    }
600
    /// Variational E-step: recomputes responsibilities r_ik proportional to
    /// exp(E[ln pi_k] + E[ln p(x_i | theta_k)]), normalized per sample via
    /// log-sum-exp.
    fn structured_e_step(
        &self,
        X: &ArrayView2<f64>,
        weight_concentration: &Array1<f64>,
        mean_values: &Array2<f64>,
        precision_values: &Array3<f64>,
        degrees_of_freedom: &Array1<f64>,
        scale_matrices: &Array3<f64>,
        structured_cov: &Array3<f64>,
        responsibilities: &mut Array2<f64>,
    ) -> SklResult<()> {
        let (n_samples, _) = X.dim();

        let expected_log_weights = self.compute_expected_log_weights(weight_concentration)?;

        for i in 0..n_samples {
            let mut log_resp = Array1::zeros(self.n_components);

            for k in 0..self.n_components {
                let expected_log_likelihood = self.compute_expected_log_likelihood(
                    &X.slice(s![i, ..]),
                    &mean_values.slice(s![k, ..]),
                    &precision_values.slice(s![k, .., ..]),
                    &degrees_of_freedom[k],
                    &scale_matrices.slice(s![k, .., ..]),
                    &structured_cov.slice(s![k, .., ..]),
                )?;

                log_resp[k] = expected_log_weights[k] + expected_log_likelihood;
            }

            // Normalize in log space for numerical stability.
            let log_prob_norm = self.log_sum_exp_array(&log_resp);
            for k in 0..self.n_components {
                responsibilities[[i, k]] = (log_resp[k] - log_prob_norm).exp();
            }
        }

        Ok(())
    }
644
    /// Variational M-step: conjugate update of the Dirichlet concentrations
    /// followed by `max_coord_steps` coordinate-ascent passes over each
    /// component's remaining parameters.
    fn structured_m_step(
        &self,
        X: &ArrayView2<f64>,
        responsibilities: &Array2<f64>,
        weight_concentration: &mut Array1<f64>,
        mean_precision: &mut Array1<f64>,
        mean_values: &mut Array2<f64>,
        precision_values: &mut Array3<f64>,
        degrees_of_freedom: &mut Array1<f64>,
        scale_matrices: &mut Array3<f64>,
        structured_cov: &mut Array3<f64>,
        rng: &mut scirs2_core::random::rngs::StdRng,
    ) -> SklResult<()> {
        let (_n_samples, _n_features) = X.dim();

        // Effective number of samples assigned to each component.
        let n_k = responsibilities.sum_axis(Axis(0));

        // Dirichlet posterior: prior concentration plus effective counts.
        for k in 0..self.n_components {
            weight_concentration[k] = self.weight_concentration + n_k[k];
        }

        // Repeated coordinate steps let the structured dependencies settle.
        for k in 0..self.n_components {
            for _ in 0..self.max_coord_steps {
                self.update_structured_parameters(
                    X,
                    responsibilities,
                    k,
                    &n_k,
                    mean_precision,
                    mean_values,
                    precision_values,
                    degrees_of_freedom,
                    scale_matrices,
                    structured_cov,
                    rng,
                )?;
            }
        }

        Ok(())
    }
691
692 fn update_structured_parameters(
694 &self,
695 X: &ArrayView2<f64>,
696 responsibilities: &Array2<f64>,
697 k: usize,
698 n_k: &Array1<f64>,
699 mean_precision: &mut Array1<f64>,
700 mean_values: &mut Array2<f64>,
701 precision_values: &mut Array3<f64>,
702 degrees_of_freedom: &mut Array1<f64>,
703 scale_matrices: &mut Array3<f64>,
704 structured_cov: &mut Array3<f64>,
705 rng: &mut scirs2_core::random::rngs::StdRng,
706 ) -> SklResult<()> {
707 let (_n_samples, _n_features) = X.dim();
708
709 match self.structured_family {
710 StructuredFamily::WeightAssignment => {
711 self.update_weight_assignment_parameters(
713 X,
714 responsibilities,
715 k,
716 n_k,
717 mean_precision,
718 mean_values,
719 precision_values,
720 degrees_of_freedom,
721 scale_matrices,
722 structured_cov,
723 rng,
724 )?;
725 }
726 StructuredFamily::MeanPrecision => {
727 self.update_mean_precision_parameters(
729 X,
730 responsibilities,
731 k,
732 n_k,
733 mean_precision,
734 mean_values,
735 precision_values,
736 degrees_of_freedom,
737 scale_matrices,
738 structured_cov,
739 rng,
740 )?;
741 }
742 StructuredFamily::ComponentWise => {
743 self.update_component_wise_parameters(
745 X,
746 responsibilities,
747 k,
748 n_k,
749 mean_precision,
750 mean_values,
751 precision_values,
752 degrees_of_freedom,
753 scale_matrices,
754 structured_cov,
755 rng,
756 )?;
757 }
758 StructuredFamily::BlockDiagonal => {
759 self.update_block_diagonal_parameters(
761 X,
762 responsibilities,
763 k,
764 n_k,
765 mean_precision,
766 mean_values,
767 precision_values,
768 degrees_of_freedom,
769 scale_matrices,
770 structured_cov,
771 rng,
772 )?;
773 }
774 }
775
776 Ok(())
777 }
778
    /// Coordinate update for the weight–assignment family: damped conjugate
    /// Gaussian–Wishart moment updates for component `k`; no structured
    /// correction factor enters this family's update.
    fn update_weight_assignment_parameters(
        &self,
        X: &ArrayView2<f64>,
        responsibilities: &Array2<f64>,
        k: usize,
        n_k: &Array1<f64>,
        mean_precision: &mut Array1<f64>,
        mean_values: &mut Array2<f64>,
        _precision_values: &mut Array3<f64>,
        degrees_of_freedom: &mut Array1<f64>,
        scale_matrices: &mut Array3<f64>,
        _structured_cov: &mut Array3<f64>,
        _rng: &mut scirs2_core::random::rngs::StdRng,
    ) -> SklResult<()> {
        let (n_samples, n_features) = X.dim();

        // Responsibility-weighted sample mean of component k.
        let mut weighted_mean = Array1::zeros(n_features);
        for i in 0..n_samples {
            let weight = responsibilities[[i, k]];
            for j in 0..n_features {
                weighted_mean[j] += weight * X[[i, j]];
            }
        }

        if n_k[k] > 0.0 {
            weighted_mean /= n_k[k];
        }

        // Damped posterior mean update; the `* 0.0` term encodes a zero
        // prior mean.
        for j in 0..n_features {
            let old_mean = mean_values[[k, j]];
            let new_mean = (self.mean_precision * 0.0 + n_k[k] * weighted_mean[j])
                / (self.mean_precision + n_k[k]);
            mean_values[[k, j]] = (1.0 - self.damping) * old_mean + self.damping * new_mean;
        }

        // Standard conjugate count updates.
        mean_precision[k] = self.mean_precision + n_k[k];

        degrees_of_freedom[k] = self.degrees_of_freedom + n_k[k];

        // Responsibility-weighted scatter around the freshly updated mean.
        let mut scale_update = Array2::zeros((n_features, n_features));
        for i in 0..n_samples {
            let weight = responsibilities[[i, k]];
            let diff = &X.slice(s![i, ..]) - &mean_values.slice(s![k, ..]);
            for j in 0..n_features {
                for l in 0..n_features {
                    scale_update[[j, l]] += weight * diff[j] * diff[l];
                }
            }
        }

        // Damped blend of old and new scale matrices.
        let mut current_scale = scale_matrices.slice(s![k, .., ..]).to_owned();
        current_scale = (1.0 - self.damping) * current_scale + self.damping * scale_update;
        scale_matrices
            .slice_mut(s![k, .., ..])
            .assign(&current_scale);

        Ok(())
    }
844
    /// Coordinate update for the mean–precision family: like the base update
    /// but with a heuristic coupling factor, read from the structured
    /// block's leading entry, inflating the data contribution.
    fn update_mean_precision_parameters(
        &self,
        X: &ArrayView2<f64>,
        responsibilities: &Array2<f64>,
        k: usize,
        n_k: &Array1<f64>,
        mean_precision: &mut Array1<f64>,
        mean_values: &mut Array2<f64>,
        _precision_values: &mut Array3<f64>,
        degrees_of_freedom: &mut Array1<f64>,
        scale_matrices: &mut Array3<f64>,
        structured_cov: &mut Array3<f64>,
        _rng: &mut scirs2_core::random::rngs::StdRng,
    ) -> SklResult<()> {
        let (n_samples, n_features) = X.dim();

        // Responsibility-weighted sample mean of component k.
        let mut weighted_mean = Array1::zeros(n_features);
        for i in 0..n_samples {
            let weight = responsibilities[[i, k]];
            for j in 0..n_features {
                weighted_mean[j] += weight * X[[i, j]];
            }
        }

        if n_k[k] > 0.0 {
            weighted_mean /= n_k[k];
        }

        // Heuristic coupling strength taken from the structured block.
        let structured_factor = structured_cov[[k, 0, 0]];
        let correlation_adjustment = 1.0 + structured_factor.abs() * 0.1;

        // Damped posterior mean update (zero prior mean), with the coupling
        // factor scaling the data contribution on both sides of the ratio.
        for j in 0..n_features {
            let old_mean = mean_values[[k, j]];
            let new_mean = (self.mean_precision * 0.0
                + n_k[k] * weighted_mean[j] * correlation_adjustment)
                / (self.mean_precision + n_k[k] * correlation_adjustment);
            mean_values[[k, j]] = (1.0 - self.damping) * old_mean + self.damping * new_mean;
        }

        mean_precision[k] = (self.mean_precision + n_k[k]) * correlation_adjustment;

        degrees_of_freedom[k] = self.degrees_of_freedom + n_k[k];

        // Scatter update scaled by the same coupling factor.
        let mut scale_update = Array2::zeros((n_features, n_features));
        for i in 0..n_samples {
            let weight = responsibilities[[i, k]];
            let diff = &X.slice(s![i, ..]) - &mean_values.slice(s![k, ..]);
            for j in 0..n_features {
                for l in 0..n_features {
                    scale_update[[j, l]] += weight * diff[j] * diff[l] * correlation_adjustment;
                }
            }
        }

        // Damped blend of old and new scale matrices.
        let mut current_scale = scale_matrices.slice(s![k, .., ..]).to_owned();
        current_scale = (1.0 - self.damping) * current_scale + self.damping * scale_update;
        scale_matrices
            .slice_mut(s![k, .., ..])
            .assign(&current_scale);

        Ok(())
    }
915
    /// Coordinate update for the component-wise family: three heuristic
    /// factors derived from the structured block scale the weight, mean, and
    /// precision updates separately.
    fn update_component_wise_parameters(
        &self,
        X: &ArrayView2<f64>,
        responsibilities: &Array2<f64>,
        k: usize,
        n_k: &Array1<f64>,
        mean_precision: &mut Array1<f64>,
        mean_values: &mut Array2<f64>,
        _precision_values: &mut Array3<f64>,
        degrees_of_freedom: &mut Array1<f64>,
        scale_matrices: &mut Array3<f64>,
        structured_cov: &mut Array3<f64>,
        _rng: &mut scirs2_core::random::rngs::StdRng,
    ) -> SklResult<()> {
        let (n_samples, n_features) = X.dim();

        // Responsibility-weighted sample mean of component k.
        let mut weighted_mean = Array1::zeros(n_features);
        for i in 0..n_samples {
            let weight = responsibilities[[i, k]];
            for j in 0..n_features {
                weighted_mean[j] += weight * X[[i, j]];
            }
        }

        if n_k[k] > 0.0 {
            weighted_mean /= n_k[k];
        }

        // Per-quantity coupling factors from the structured block's leading
        // entry; constants 0.5 and 0.3 are heuristic attenuations.
        let structured_factor = structured_cov[[k, 0, 0]].abs() * 0.1;
        let weight_factor = 1.0 + structured_factor;
        let mean_factor = 1.0 + structured_factor * 0.5;
        let precision_factor = 1.0 + structured_factor * 0.3;

        // Damped posterior mean update (zero prior mean).
        for j in 0..n_features {
            let old_mean = mean_values[[k, j]];
            let new_mean = (self.mean_precision * 0.0 + n_k[k] * weighted_mean[j] * mean_factor)
                / (self.mean_precision + n_k[k] * mean_factor);
            mean_values[[k, j]] = (1.0 - self.damping) * old_mean + self.damping * new_mean;
        }

        mean_precision[k] = (self.mean_precision + n_k[k]) * precision_factor;

        degrees_of_freedom[k] = (self.degrees_of_freedom + n_k[k]) * weight_factor;

        // Scatter update scaled by the mean coupling factor.
        let mut scale_update = Array2::zeros((n_features, n_features));
        for i in 0..n_samples {
            let weight = responsibilities[[i, k]];
            let diff = &X.slice(s![i, ..]) - &mean_values.slice(s![k, ..]);
            for j in 0..n_features {
                for l in 0..n_features {
                    scale_update[[j, l]] += weight * diff[j] * diff[l] * mean_factor;
                }
            }
        }

        // Damped blend of old and new scale matrices.
        let mut current_scale = scale_matrices.slice(s![k, .., ..]).to_owned();
        current_scale = (1.0 - self.damping) * current_scale + self.damping * scale_update;
        scale_matrices
            .slice_mut(s![k, .., ..])
            .assign(&current_scale);

        Ok(())
    }
987
    /// Coordinate update for the block-diagonal family: the mean is updated
    /// one feature block at a time, each block with its own coupling factor
    /// read from the structured covariance.
    fn update_block_diagonal_parameters(
        &self,
        X: &ArrayView2<f64>,
        responsibilities: &Array2<f64>,
        k: usize,
        n_k: &Array1<f64>,
        mean_precision: &mut Array1<f64>,
        mean_values: &mut Array2<f64>,
        _precision_values: &mut Array3<f64>,
        degrees_of_freedom: &mut Array1<f64>,
        scale_matrices: &mut Array3<f64>,
        structured_cov: &mut Array3<f64>,
        _rng: &mut scirs2_core::random::rngs::StdRng,
    ) -> SklResult<()> {
        let (n_samples, n_features) = X.dim();

        // Responsibility-weighted sample mean of component k.
        let mut weighted_mean = Array1::zeros(n_features);
        for i in 0..n_samples {
            let weight = responsibilities[[i, k]];
            for j in 0..n_features {
                weighted_mean[j] += weight * X[[i, j]];
            }
        }

        if n_k[k] > 0.0 {
            weighted_mean /= n_k[k];
        }

        // Two feature blocks of roughly equal size (at least one feature).
        let block_size = (n_features / 2).max(1);
        for block_start in (0..n_features).step_by(block_size) {
            let block_end = (block_start + block_size).min(n_features);

            // Coupling factor for this block; the modulo keeps the index
            // inside the structured block, whose side length can differ from
            // n_features.
            let block_factor = structured_cov[[
                k,
                block_start % structured_cov.len_of(Axis(1)),
                block_start % structured_cov.len_of(Axis(2)),
            ]]
            .abs()
                * 0.1;
            let correlation_factor = 1.0 + block_factor;

            // Damped posterior mean update (zero prior mean) for this block.
            for j in block_start..block_end {
                let old_mean = mean_values[[k, j]];
                let new_mean = (self.mean_precision * 0.0
                    + n_k[k] * weighted_mean[j] * correlation_factor)
                    / (self.mean_precision + n_k[k] * correlation_factor);
                mean_values[[k, j]] = (1.0 - self.damping) * old_mean + self.damping * new_mean;
            }
        }

        // Standard conjugate count updates.
        mean_precision[k] = self.mean_precision + n_k[k];

        degrees_of_freedom[k] = self.degrees_of_freedom + n_k[k];

        // Responsibility-weighted scatter around the updated mean.
        let mut scale_update = Array2::zeros((n_features, n_features));
        for i in 0..n_samples {
            let weight = responsibilities[[i, k]];
            let diff = &X.slice(s![i, ..]) - &mean_values.slice(s![k, ..]);
            for j in 0..n_features {
                for l in 0..n_features {
                    scale_update[[j, l]] += weight * diff[j] * diff[l];
                }
            }
        }

        // Damped blend of old and new scale matrices.
        let mut current_scale = scale_matrices.slice(s![k, .., ..]).to_owned();
        current_scale = (1.0 - self.damping) * current_scale + self.damping * scale_update;
        scale_matrices
            .slice_mut(s![k, .., ..])
            .assign(&current_scale);

        Ok(())
    }
1070
1071 fn compute_expected_log_weights(
1073 &self,
1074 weight_concentration: &Array1<f64>,
1075 ) -> SklResult<Array1<f64>> {
1076 let concentration_sum: f64 = weight_concentration.sum();
1077 let mut expected_log_weights = Array1::zeros(self.n_components);
1078
1079 for k in 0..self.n_components {
1080 expected_log_weights[k] =
1082 Self::digamma(weight_concentration[k]) - Self::digamma(concentration_sum);
1083 }
1084
1085 Ok(expected_log_weights)
1086 }
1087
    /// Expected log-density of one sample under component parameters,
    /// combining the expected log-determinant of the precision (Wishart
    /// form), the expected Mahalanobis term, a small structured correction,
    /// and the Gaussian normalizing constant.
    fn compute_expected_log_likelihood(
        &self,
        x: &ArrayView1<f64>,
        mean: &ArrayView1<f64>,
        precision: &ArrayView2<f64>,
        degrees_of_freedom: &f64,
        _scale_matrix: &ArrayView2<f64>,
        structured_cov: &ArrayView2<f64>,
    ) -> SklResult<f64> {
        let n_features = x.len();
        let diff = x - mean;

        // E[ln |Lambda|] for a Wishart: sum_i digamma((nu + 1 - i)/2) + d ln 2
        // (the ln |W| term is omitted here).
        let mut expected_log_det = 0.0;
        for i in 0..n_features {
            expected_log_det += Self::digamma((degrees_of_freedom + 1.0 - i as f64) / 2.0);
        }
        expected_log_det += n_features as f64 * (2.0_f64).ln();

        // Small heuristic correction driven by the structured block.
        let structured_correction = structured_cov[[0, 0]].abs() * 0.01;
        expected_log_det += structured_correction;

        // Quadratic form (x - m)^T Lambda (x - m).
        let mut expected_quad_form = 0.0;
        for i in 0..n_features {
            for j in 0..n_features {
                expected_quad_form += diff[i] * precision[[i, j]] * diff[j];
            }
        }
        // NOTE(review): this scaling is singular at nu == 2 and negative for
        // nu < 2, producing non-finite or sign-flipped densities; confirm the
        // callers always keep degrees_of_freedom above 2 (initialization uses
        // prior dof + n_features, which equals 2 for 1-D data with the
        // default prior).
        expected_quad_form *= degrees_of_freedom / (degrees_of_freedom - 2.0);

        expected_quad_form += structured_correction * expected_quad_form.abs() * 0.01;

        let log_likelihood = 0.5 * expected_log_det
            - 0.5 * expected_quad_form
            - 0.5 * n_features as f64 * (2.0 * PI).ln();

        Ok(log_likelihood)
    }
1130
    /// Approximate evidence lower bound: expected complete-data
    /// log-likelihood, Dirichlet prior/posterior terms, a penalty on the
    /// structured-block magnitude, and the responsibility entropy.
    fn compute_structured_lower_bound(
        &self,
        X: &ArrayView2<f64>,
        responsibilities: &Array2<f64>,
        weight_concentration: &Array1<f64>,
        _mean_precision: &Array1<f64>,
        mean_values: &Array2<f64>,
        precision_values: &Array3<f64>,
        degrees_of_freedom: &Array1<f64>,
        scale_matrices: &Array3<f64>,
        structured_cov: &Array3<f64>,
    ) -> SklResult<f64> {
        let (n_samples, _n_features) = X.dim();
        let mut lower_bound = 0.0;

        let expected_log_weights = self.compute_expected_log_weights(weight_concentration)?;

        // E_q[ln p(X, Z | theta)]: responsibility-weighted expected joint;
        // negligible responsibilities are skipped.
        for i in 0..n_samples {
            for k in 0..self.n_components {
                let responsibility = responsibilities[[i, k]];
                if responsibility > 1e-10 {
                    let expected_log_likelihood = self.compute_expected_log_likelihood(
                        &X.slice(s![i, ..]),
                        &mean_values.slice(s![k, ..]),
                        &precision_values.slice(s![k, .., ..]),
                        &degrees_of_freedom[k],
                        &scale_matrices.slice(s![k, .., ..]),
                        &structured_cov.slice(s![k, .., ..]),
                    )?;

                    lower_bound +=
                        responsibility * (expected_log_weights[k] + expected_log_likelihood);
                }
            }
        }

        // Dirichlet prior/posterior contribution over the mixture weights.
        let concentration_sum: f64 = weight_concentration.sum();
        let prior_concentration_sum = self.weight_concentration * self.n_components as f64;

        lower_bound +=
            Self::log_gamma(concentration_sum) - Self::log_gamma(prior_concentration_sum);
        for k in 0..self.n_components {
            lower_bound += Self::log_gamma(self.weight_concentration)
                - Self::log_gamma(weight_concentration[k]);
            lower_bound += (weight_concentration[k] - self.weight_concentration)
                * (Self::digamma(weight_concentration[k]) - Self::digamma(concentration_sum));
        }

        // Heuristic penalty on the magnitude of the structured blocks.
        for k in 0..self.n_components {
            let structured_correction = structured_cov
                .slice(s![k, .., ..])
                .iter()
                .map(|&x| x.abs())
                .sum::<f64>()
                * 0.001;
            lower_bound -= structured_correction;
        }

        // Entropy of the responsibilities, -sum_ik r ln r.
        for i in 0..n_samples {
            for k in 0..self.n_components {
                let responsibility = responsibilities[[i, k]];
                if responsibility > 1e-10 {
                    lower_bound -= responsibility * responsibility.ln();
                }
            }
        }

        Ok(lower_bound)
    }
1206
1207 fn count_parameters(&self, n_features: usize) -> usize {
1209 let mut n_params = self.n_components - 1; n_params += self.n_components * n_features; match self.covariance_type {
1214 CovarianceType::Full => {
1215 n_params += self.n_components * n_features * (n_features + 1) / 2
1216 }
1217 CovarianceType::Diagonal => n_params += self.n_components * n_features,
1218 CovarianceType::Tied => n_params += n_features * (n_features + 1) / 2,
1219 CovarianceType::Spherical => n_params += self.n_components,
1220 }
1221
1222 let structured_params = match self.structured_family {
1224 StructuredFamily::WeightAssignment => self.n_components + 1,
1225 StructuredFamily::MeanPrecision => n_features + n_features * n_features,
1226 StructuredFamily::ComponentWise => 1 + n_features + n_features * n_features,
1227 StructuredFamily::BlockDiagonal => 2 * n_features,
1228 };
1229
1230 n_params += self.n_components * structured_params * structured_params;
1231
1232 n_params
1233 }
1234
1235 fn digamma(x: f64) -> f64 {
1237 if x < 8.0 {
1238 Self::digamma(x + 1.0) - 1.0 / x
1239 } else {
1240 let inv_x = 1.0 / x;
1241 let inv_x2 = inv_x * inv_x;
1242 x.ln() - 0.5 * inv_x - inv_x2 / 12.0 + inv_x2 * inv_x2 / 120.0
1243 }
1244 }
1245
1246 fn log_gamma(x: f64) -> f64 {
1248 if x < 0.5 {
1249 (PI / (PI * x).sin()).ln() - Self::log_gamma(1.0 - x)
1250 } else {
1251 let g = 7.0;
1252 let c = [
1253 0.999_999_999_999_809_9,
1254 676.5203681218851,
1255 -1259.1392167224028,
1256 771.323_428_777_653_1,
1257 -176.615_029_162_140_6,
1258 12.507343278686905,
1259 -0.13857109526572012,
1260 9.984_369_578_019_572e-6,
1261 1.5056327351493116e-7,
1262 ];
1263
1264 let z = x - 1.0;
1265 let mut x_sum = c[0];
1266 for (i, &c_val) in c.iter().enumerate().skip(1) {
1267 x_sum += c_val / (z + i as f64);
1268 }
1269 let t = z + g + 0.5;
1270 (2.0 * PI).sqrt().ln() + (z + 0.5) * t.ln() - t + x_sum.ln()
1271 }
1272 }
1273
1274 fn log_sum_exp_array(&self, arr: &Array1<f64>) -> f64 {
1276 let max_val = arr.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
1277 if max_val.is_finite() {
1278 max_val + arr.iter().map(|&x| (x - max_val).exp()).sum::<f64>().ln()
1279 } else {
1280 max_val
1281 }
1282 }
1283}
1284
impl Estimator<Trained> for StructuredVariationalGMMTrained {
    type Config = ();
    type Error = SklearsError;
    type Float = f64;

    /// The trained model carries no separate configuration object.
    fn config(&self) -> &Self::Config {
        &()
    }
}
1294
1295impl Predict<ArrayView2<'_, f64>, Array1<usize>> for StructuredVariationalGMMTrained {
1296 fn predict(&self, X: &ArrayView2<f64>) -> SklResult<Array1<usize>> {
1297 let probabilities = self.predict_proba(X)?;
1298 let mut predictions = Array1::zeros(X.nrows());
1299
1300 for i in 0..X.nrows() {
1301 let mut max_prob = 0.0;
1302 let mut best_class = 0;
1303
1304 for k in 0..self.n_components {
1305 if probabilities[[i, k]] > max_prob {
1306 max_prob = probabilities[[i, k]];
1307 best_class = k;
1308 }
1309 }
1310
1311 predictions[i] = best_class;
1312 }
1313
1314 Ok(predictions)
1315 }
1316}
1317
1318impl StructuredVariationalGMMTrained {
    /// Posterior probability of each component for every row of `X`,
    /// shape (n_samples, n_components); each row sums to 1.
    ///
    /// # Errors
    /// Returns `InvalidInput` when the feature count differs from training.
    pub fn predict_proba(&self, X: &ArrayView2<f64>) -> SklResult<Array2<f64>> {
        let (n_samples, n_features) = X.dim();

        if n_features != self.n_features {
            return Err(SklearsError::InvalidInput(format!(
                "Expected {} features, got {}",
                self.n_features, n_features
            )));
        }

        let mut probabilities = Array2::zeros((n_samples, self.n_components));

        let expected_log_weights = self.compute_expected_log_weights()?;

        for i in 0..n_samples {
            let mut log_probs = Array1::zeros(self.n_components);

            for k in 0..self.n_components {
                let expected_log_likelihood = self.compute_expected_log_likelihood(
                    &X.slice(s![i, ..]),
                    &self.mean_values.slice(s![k, ..]),
                    &self.precision_values.slice(s![k, .., ..]),
                    &self.degrees_of_freedom[k],
                    &self.scale_matrices.slice(s![k, .., ..]),
                    &self.structured_cov.slice(s![k, .., ..]),
                )?;

                log_probs[k] = expected_log_weights[k] + expected_log_likelihood;
            }

            // Normalize in log space for numerical stability.
            let log_prob_norm = self.log_sum_exp_array(&log_probs);
            for k in 0..self.n_components {
                probabilities[[i, k]] = (log_probs[k] - log_prob_norm).exp();
            }
        }

        Ok(probabilities)
    }
1360
1361 pub fn score(&self, X: &ArrayView2<f64>) -> SklResult<f64> {
1363 let (n_samples, n_features) = X.dim();
1364
1365 if n_features != self.n_features {
1366 return Err(SklearsError::InvalidInput(format!(
1367 "Expected {} features, got {}",
1368 self.n_features, n_features
1369 )));
1370 }
1371
1372 let expected_log_weights = self.compute_expected_log_weights()?;
1373 let mut total_log_likelihood = 0.0;
1374
1375 for i in 0..n_samples {
1376 let mut log_probs = Array1::zeros(self.n_components);
1377
1378 for k in 0..self.n_components {
1379 let expected_log_likelihood = self.compute_expected_log_likelihood(
1380 &X.slice(s![i, ..]),
1381 &self.mean_values.slice(s![k, ..]),
1382 &self.precision_values.slice(s![k, .., ..]),
1383 &self.degrees_of_freedom[k],
1384 &self.scale_matrices.slice(s![k, .., ..]),
1385 &self.structured_cov.slice(s![k, .., ..]),
1386 )?;
1387
1388 log_probs[k] = expected_log_weights[k] + expected_log_likelihood;
1389 }
1390
1391 total_log_likelihood += self.log_sum_exp_array(&log_probs);
1392 }
1393
1394 Ok(total_log_likelihood)
1395 }
1396
1397 pub fn model_selection(&self) -> &ModelSelection {
1399 &self.model_selection
1400 }
1401
1402 pub fn lower_bound(&self) -> f64 {
1404 self.lower_bound
1405 }
1406
1407 pub fn responsibilities(&self) -> &Array2<f64> {
1409 &self.responsibilities
1410 }
1411
1412 pub fn mean_values(&self) -> &Array2<f64> {
1414 &self.mean_values
1415 }
1416
1417 pub fn precision_values(&self) -> &Array3<f64> {
1419 &self.precision_values
1420 }
1421
1422 pub fn structured_cov(&self) -> &Array3<f64> {
1424 &self.structured_cov
1425 }
1426
1427 pub fn structured_family(&self) -> StructuredFamily {
1429 self.structured_family
1430 }
1431
1432 fn compute_expected_log_weights(&self) -> SklResult<Array1<f64>> {
1434 let concentration_sum: f64 = self.weight_concentration.sum();
1435 let mut expected_log_weights = Array1::zeros(self.n_components);
1436
1437 for k in 0..self.n_components {
1438 expected_log_weights[k] =
1439 Self::digamma(self.weight_concentration[k]) - Self::digamma(concentration_sum);
1440 }
1441
1442 Ok(expected_log_weights)
1443 }
1444
1445 fn compute_expected_log_likelihood(
1446 &self,
1447 x: &ArrayView1<f64>,
1448 mean: &ArrayView1<f64>,
1449 precision: &ArrayView2<f64>,
1450 degrees_of_freedom: &f64,
1451 _scale_matrix: &ArrayView2<f64>,
1452 structured_cov: &ArrayView2<f64>,
1453 ) -> SklResult<f64> {
1454 let n_features = x.len();
1455 let diff = x - mean;
1456
1457 let mut expected_log_det = 0.0;
1459 for i in 0..n_features {
1460 expected_log_det += Self::digamma((degrees_of_freedom + 1.0 - i as f64) / 2.0);
1461 }
1462 expected_log_det += n_features as f64 * (2.0_f64).ln();
1463
1464 let structured_correction = structured_cov[[0, 0]].abs() * 0.01;
1466 expected_log_det += structured_correction;
1467
1468 let mut expected_quad_form = 0.0;
1470 for i in 0..n_features {
1471 for j in 0..n_features {
1472 expected_quad_form += diff[i] * precision[[i, j]] * diff[j];
1473 }
1474 }
1475 expected_quad_form *= degrees_of_freedom / (degrees_of_freedom - 2.0);
1476
1477 expected_quad_form += structured_correction * expected_quad_form.abs() * 0.01;
1479
1480 let log_likelihood = 0.5 * expected_log_det
1481 - 0.5 * expected_quad_form
1482 - 0.5 * n_features as f64 * (2.0 * PI).ln();
1483
1484 Ok(log_likelihood)
1485 }
1486
1487 fn digamma(x: f64) -> f64 {
1488 if x < 8.0 {
1489 Self::digamma(x + 1.0) - 1.0 / x
1490 } else {
1491 let inv_x = 1.0 / x;
1492 let inv_x2 = inv_x * inv_x;
1493 x.ln() - 0.5 * inv_x - inv_x2 / 12.0 + inv_x2 * inv_x2 / 120.0
1494 }
1495 }
1496
1497 fn log_sum_exp_array(&self, arr: &Array1<f64>) -> f64 {
1498 let max_val = arr.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
1499 if max_val.is_finite() {
1500 max_val + arr.iter().map(|&x| (x - max_val).exp()).sum::<f64>().ln()
1501 } else {
1502 max_val
1503 }
1504 }
1505}
1506
// The module-level allow covers every test; per-function repeats are redundant.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;
    use sklears_core::traits::Predict;

    /// Builder setters are reflected on the untrained model.
    #[test]
    fn test_structured_variational_gmm_creation() {
        let gmm = StructuredVariationalGMM::new()
            .n_components(3)
            .structured_family(StructuredFamily::MeanPrecision)
            .tol(1e-4)
            .max_iter(200);

        assert_eq!(gmm.n_components, 3);
        assert_eq!(gmm.structured_family, StructuredFamily::MeanPrecision);
        assert_eq!(gmm.tol, 1e-4);
        assert_eq!(gmm.max_iter, 200);
    }

    /// Two well-separated clusters are recovered and labeled consistently.
    #[test]
    fn test_structured_variational_gmm_fit_predict() {
        let X = array![
            [0.0, 0.0],
            [0.5, 0.5],
            [1.0, 1.0],
            [10.0, 10.0],
            [10.5, 10.5],
            [11.0, 11.0]
        ];

        let gmm = StructuredVariationalGMM::new()
            .n_components(2)
            .structured_family(StructuredFamily::MeanPrecision)
            .random_state(42)
            .tol(1e-3)
            .max_iter(50);

        let fitted = gmm.fit(&X.view(), &()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 6);
        assert!(predictions.iter().all(|&label| label < 2));

        // The first three points share a cluster...
        let first_cluster = predictions[0];
        assert_eq!(predictions[1], first_cluster);
        assert_eq!(predictions[2], first_cluster);

        // ...and so do the last three, with a different label.
        let second_cluster = predictions[3];
        assert_eq!(predictions[4], second_cluster);
        assert_eq!(predictions[5], second_cluster);

        assert_ne!(first_cluster, second_cluster);
    }

    /// Every structured family fits and produces valid labels.
    #[test]
    fn test_structured_families() {
        let X = array![[0.0, 0.0], [1.0, 1.0], [10.0, 10.0], [11.0, 11.0]];

        let families = vec![
            StructuredFamily::WeightAssignment,
            StructuredFamily::MeanPrecision,
            StructuredFamily::ComponentWise,
            StructuredFamily::BlockDiagonal,
        ];

        for family in families {
            let gmm = StructuredVariationalGMM::new()
                .n_components(2)
                .structured_family(family)
                .random_state(42)
                .tol(1e-2)
                .max_iter(20);

            let fitted = gmm.fit(&X.view(), &()).unwrap();
            let predictions = fitted.predict(&X.view()).unwrap();

            assert_eq!(predictions.len(), 4);
            assert!(predictions.iter().all(|&label| label < 2));
        }
    }

    /// predict_proba rows are valid probability distributions.
    #[test]
    fn test_structured_variational_gmm_probabilities() {
        let X = array![[0.0, 0.0], [1.0, 1.0], [10.0, 10.0], [11.0, 11.0]];

        let gmm = StructuredVariationalGMM::new()
            .n_components(2)
            .structured_family(StructuredFamily::MeanPrecision)
            .random_state(42)
            .tol(1e-3)
            .max_iter(30);

        let fitted = gmm.fit(&X.view(), &()).unwrap();
        let probabilities = fitted.predict_proba(&X.view()).unwrap();

        assert_eq!(probabilities.dim(), (4, 2));

        // Each row must sum to one...
        for i in 0..4 {
            let sum: f64 = probabilities.slice(s![i, ..]).sum();
            assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-10);
        }

        // ...with no negative entries.
        assert!(probabilities.iter().all(|&p| p >= 0.0));
    }

    /// score returns a finite total log-likelihood on the training data.
    #[test]
    fn test_structured_variational_gmm_score() {
        let X = array![[0.0, 0.0], [1.0, 1.0], [10.0, 10.0], [11.0, 11.0]];

        let gmm = StructuredVariationalGMM::new()
            .n_components(2)
            .structured_family(StructuredFamily::MeanPrecision)
            .random_state(42)
            .tol(1e-3)
            .max_iter(30);

        let fitted = gmm.fit(&X.view(), &()).unwrap();
        let score = fitted.score(&X.view()).unwrap();

        assert!(score.is_finite());
    }

    /// Model-selection statistics are populated after fitting.
    #[test]
    fn test_structured_variational_gmm_model_selection() {
        let X = array![[0.0, 0.0], [1.0, 1.0], [10.0, 10.0], [11.0, 11.0]];

        let gmm = StructuredVariationalGMM::new()
            .n_components(2)
            .structured_family(StructuredFamily::MeanPrecision)
            .random_state(42)
            .tol(1e-3)
            .max_iter(30);

        let fitted = gmm.fit(&X.view(), &()).unwrap();
        let model_selection = fitted.model_selection();

        assert!(model_selection.aic.is_finite());
        assert!(model_selection.bic.is_finite());
        assert!(model_selection.log_likelihood.is_finite());
        assert!(model_selection.n_parameters > 0);
    }

    /// Every covariance type fits and yields valid labels.
    #[test]
    fn test_structured_variational_gmm_covariance_types() {
        let X = array![[0.0, 0.0], [1.0, 1.0], [10.0, 10.0], [11.0, 11.0]];

        let covariance_types = vec![
            CovarianceType::Full,
            CovarianceType::Diagonal,
            CovarianceType::Tied,
            CovarianceType::Spherical,
        ];

        for covariance_type in covariance_types {
            let gmm = StructuredVariationalGMM::new()
                .n_components(2)
                .structured_family(StructuredFamily::MeanPrecision)
                .covariance_type(covariance_type)
                .random_state(42)
                .tol(1e-2)
                .max_iter(20);

            let fitted = gmm.fit(&X.view(), &()).unwrap();
            let predictions = fitted.predict(&X.view()).unwrap();

            assert_eq!(predictions.len(), 4);
            assert!(predictions.iter().all(|&label| label < 2));
        }
    }

    /// Accessor shapes match (n_components, n_features, n_samples) expectations.
    #[test]
    fn test_structured_variational_gmm_parameter_access() {
        let X = array![[0.0, 0.0], [1.0, 1.0], [10.0, 10.0], [11.0, 11.0]];

        let gmm = StructuredVariationalGMM::new()
            .n_components(2)
            .structured_family(StructuredFamily::MeanPrecision)
            .random_state(42)
            .tol(1e-3)
            .max_iter(30);

        let fitted = gmm.fit(&X.view(), &()).unwrap();

        assert_eq!(fitted.mean_values().dim(), (2, 2));
        assert_eq!(fitted.precision_values().dim(), (2, 2, 2));
        assert_eq!(fitted.structured_cov().dim(), (2, 6, 6));
        assert_eq!(fitted.responsibilities().dim(), (4, 2));
        assert_eq!(fitted.structured_family(), StructuredFamily::MeanPrecision);

        assert!(fitted.lower_bound().is_finite());
    }

    /// Same seed and data produce identical predictions.
    #[test]
    fn test_structured_variational_gmm_reproducibility() {
        let X = array![[0.0, 0.0], [1.0, 1.0], [10.0, 10.0], [11.0, 11.0]];

        let gmm1 = StructuredVariationalGMM::new()
            .n_components(2)
            .structured_family(StructuredFamily::MeanPrecision)
            .random_state(42)
            .tol(1e-3)
            .max_iter(30);

        let gmm2 = StructuredVariationalGMM::new()
            .n_components(2)
            .structured_family(StructuredFamily::MeanPrecision)
            .random_state(42)
            .tol(1e-3)
            .max_iter(30);

        let fitted1 = gmm1.fit(&X.view(), &()).unwrap();
        let fitted2 = gmm2.fit(&X.view(), &()).unwrap();

        let predictions1 = fitted1.predict(&X.view()).unwrap();
        let predictions2 = fitted2.predict(&X.view()).unwrap();

        assert_eq!(predictions1, predictions2);
    }

    /// With a single component, everything is assigned label 0.
    #[test]
    fn test_structured_variational_gmm_single_component() {
        let X = array![[0.0, 0.0], [1.0, 1.0], [0.5, 0.5], [1.5, 1.5]];

        let gmm = StructuredVariationalGMM::new()
            .n_components(1)
            .structured_family(StructuredFamily::MeanPrecision)
            .random_state(42)
            .tol(1e-3)
            .max_iter(30);

        let fitted = gmm.fit(&X.view(), &()).unwrap();
        let predictions = fitted.predict(&X.view()).unwrap();

        assert_eq!(predictions.len(), 4);
        assert!(predictions.iter().all(|&label| label == 0));
    }

    /// Shapes stay consistent for 3-dimensional inputs.
    #[test]
    fn test_structured_variational_gmm_dimensional_consistency() {
        let X = array![
            [0.0, 0.0, 0.0],
            [1.0, 1.0, 1.0],
            [10.0, 10.0, 10.0],
            [11.0, 11.0, 11.0]
        ];

        let gmm = StructuredVariationalGMM::new()
            .n_components(2)
            .structured_family(StructuredFamily::MeanPrecision)
            .random_state(42)
            .tol(1e-3)
            .max_iter(30);

        let fitted = gmm.fit(&X.view(), &()).unwrap();

        assert_eq!(fitted.mean_values().dim(), (2, 3));
        assert_eq!(fitted.precision_values().dim(), (2, 3, 3));
        assert_eq!(fitted.responsibilities().dim(), (4, 2));

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 4);

        let probabilities = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probabilities.dim(), (4, 2));
    }

    /// Error paths: too many components for the data, and a feature-count
    /// mismatch at prediction time.
    #[test]
    fn test_structured_variational_gmm_error_handling() {
        let X = array![[0.0, 0.0], [1.0, 1.0]];

        // More components than samples must be rejected.
        let gmm = StructuredVariationalGMM::new()
            .n_components(5)
            .structured_family(StructuredFamily::MeanPrecision)
            .random_state(42);

        let result = gmm.fit(&X.view(), &());
        assert!(result.is_err());

        // A valid configuration on the same data; if fitting happens to fail
        // on this tiny dataset, skip the mismatch check (best-effort).
        let gmm2 = StructuredVariationalGMM::new()
            .n_components(2)
            .structured_family(StructuredFamily::MeanPrecision)
            .max_iter(10)
            .tol(1e-2)
            .random_state(42);

        let fitted = match gmm2.fit(&X.view(), &()) {
            Ok(fitted) => fitted,
            Err(_) => {
                return;
            }
        };

        // Predicting with the wrong number of features must error.
        let X_wrong = array![[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]];

        let result = fitted.predict(&X_wrong.view());
        assert!(result.is_err());
    }
}