use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
    types::Float,
};

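/// Mixture Discriminant Analysis (MDA).
///
/// Models each class as a mixture of `n_components` Gaussians and estimates
/// the mixture parameters with an EM loop. Labels equal to `-1` are treated
/// as unlabeled, so the estimator also works in a semi-supervised setting.
///
/// A minimal usage sketch, mirroring the unit tests at the bottom of this
/// module:
///
/// ```ignore
/// use scirs2_core::array;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1]; // -1 marks unlabeled samples
///
/// let fitted = MixtureDiscriminantAnalysis::new()
///     .n_components(1)
///     .max_iter(10)
///     .fit(&X.view(), &y.view())
///     .unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```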
#[derive(Debug, Clone)]
pub struct MixtureDiscriminantAnalysis<S = Untrained> {
    state: S,
    n_components: usize,
    covariance_type: String,
    reg_covar: f64,
    max_iter: usize,
    tol: f64,
    n_init: usize,
    random_state: Option<u64>,
}

impl MixtureDiscriminantAnalysis<Untrained> {
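    /// Creates an untrained estimator with the module defaults: one component
    /// per class, "full" covariance, `reg_covar = 1e-6`, `max_iter = 100`,
    /// `tol = 1e-3`, and a single initialization. Note that `n_init` and
    /// `random_state` are stored for API parity but are not consumed by the
    /// current deterministic, single-run `fit`.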
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_components: 1,
            covariance_type: "full".to_string(),
            reg_covar: 1e-6,
            max_iter: 100,
            tol: 1e-3,
            n_init: 1,
            random_state: None,
        }
    }

    pub fn n_components(mut self, n_components: usize) -> Self {
        self.n_components = n_components;
        self
    }

    pub fn covariance_type(mut self, covariance_type: String) -> Self {
        self.covariance_type = covariance_type;
        self
    }

    pub fn reg_covar(mut self, reg_covar: f64) -> Self {
        self.reg_covar = reg_covar;
        self
    }

    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    pub fn n_init(mut self, n_init: usize) -> Self {
        self.n_init = n_init;
        self
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

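    /// Seeds EM for each class: component means are copied from that class's
    /// labeled samples (cycling when there are fewer samples than
    /// components), covariances start as (regularized) identity matrices,
    /// component weights are uniform, and class priors are the labeled-class
    /// frequencies. Returns `(means, covariances, component_weights,
    /// class_priors)`, with the first two indexed as `[class][component]`.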
    #[allow(clippy::type_complexity)]
    fn initialize_parameters(
        &self,
        X: &Array2<f64>,
        labeled_indices: &[usize],
        y: &Array1<i32>,
        classes: &[i32],
    ) -> SklResult<(
        Vec<Vec<Array1<f64>>>,
        Vec<Vec<Array2<f64>>>,
        Vec<Array1<f64>>,
        Array1<f64>,
    )> {
        let n_features = X.ncols();
        let n_classes = classes.len();

        let mut means = Vec::new();
        let mut covariances = Vec::new();
        let mut component_weights = Vec::new();
        let mut class_priors = Array1::zeros(n_classes);

        for (class_idx, &class_label) in classes.iter().enumerate() {
            let class_samples: Vec<usize> = labeled_indices
                .iter()
                .filter(|&&i| y[i] == class_label)
                .copied()
                .collect();

            if class_samples.is_empty() {
                return Err(SklearsError::InvalidInput(format!(
                    "No labeled samples for class {}",
                    class_label
                )));
            }

            let mut class_means = Vec::new();
            let mut class_covariances = Vec::new();

            for comp_idx in 0..self.n_components {
                let sample_idx = class_samples[comp_idx % class_samples.len()];
                let mean = X.row(sample_idx).to_owned();
                class_means.push(mean);

                // All covariance types start from the identity; "full" also
                // carries the diagonal regularizer from the outset.
                let cov = match self.covariance_type.as_str() {
                    "full" => Array2::eye(n_features) * (1.0 + self.reg_covar),
                    "diag" | "spherical" | "tied" => Array2::eye(n_features),
                    _ => {
                        return Err(SklearsError::InvalidInput(format!(
                            "Unknown covariance type: {}",
                            self.covariance_type
                        )));
                    }
                };
                class_covariances.push(cov);
            }

            means.push(class_means);
            covariances.push(class_covariances);

            let weights = Array1::from_elem(self.n_components, 1.0 / self.n_components as f64);
            component_weights.push(weights);

            class_priors[class_idx] = class_samples.len() as f64 / labeled_indices.len() as f64;
        }

        Ok((means, covariances, component_weights, class_priors))
    }

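    /// Evaluates a multivariate Gaussian density at `x`. For the "full" and
    /// "tied" covariance types the fallback branch approximates the density
    /// using only the diagonal of `cov`, trading accuracy on correlated
    /// features for avoiding a matrix inversion.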
    fn multivariate_gaussian_pdf(
        &self,
        x: &ArrayView1<f64>,
        mean: &Array1<f64>,
        cov: &Array2<f64>,
    ) -> f64 {
        let n_features = x.len();
        let diff = x - mean;

        // Determinant of the covariance; the fallback branch ("full"/"tied")
        // deliberately uses only the diagonal entries.
        let det = match self.covariance_type.as_str() {
            "spherical" => cov[[0, 0]].powf(n_features as f64),
            "diag" => cov.diag().iter().product(),
            _ => {
                let mut det = 1.0;
                for i in 0..n_features {
                    det *= cov[[i, i]];
                }
                det
            }
        };

        if det <= 0.0 {
            return 1e-10;
        }

        // Mahalanobis distance under the same diagonal approximation.
        let mut mahal_dist = 0.0;
        match self.covariance_type.as_str() {
            "spherical" => {
                let var = cov[[0, 0]];
                mahal_dist = diff.mapv(|x| x * x).sum() / var;
            }
            "diag" => {
                for i in 0..n_features {
                    mahal_dist += diff[i] * diff[i] / cov[[i, i]];
                }
            }
            _ => {
                for i in 0..n_features {
                    mahal_dist += diff[i] * diff[i] / cov[[i, i]];
                }
            }
        }

        let normalization =
            1.0 / ((2.0 * std::f64::consts::PI).powf(n_features as f64 / 2.0) * det.sqrt());
        normalization * (-0.5 * mahal_dist).exp()
    }

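    /// E-step: computes per-sample responsibilities over all
    /// `n_classes * n_components` components together with the data
    /// log-likelihood. Responsibilities of labeled samples are clamped to
    /// their known class.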
    #[allow(clippy::too_many_arguments, clippy::type_complexity)]
    fn e_step(
        &self,
        X: &Array2<f64>,
        means: &[Vec<Array1<f64>>],
        covariances: &[Vec<Array2<f64>>],
        component_weights: &[Array1<f64>],
        class_priors: &Array1<f64>,
        labeled_indices: &[usize],
        y: &Array1<i32>,
        classes: &[i32],
    ) -> (Array2<f64>, f64) {
        let n_samples = X.nrows();
        let n_classes = classes.len();
        let total_components = n_classes * self.n_components;

        let mut responsibilities = Array2::zeros((n_samples, total_components));
        let mut log_likelihood = 0.0;

        for i in 0..n_samples {
            let x = X.row(i);
            let mut total_prob = 0.0;
            let mut probs = Vec::new();

            // Unnormalized joint probability of each (class, component) pair,
            // pushed in global component order.
            for (class_idx, _class_label) in classes.iter().enumerate() {
                for comp_idx in 0..self.n_components {
                    let prob = class_priors[class_idx]
                        * component_weights[class_idx][comp_idx]
                        * self.multivariate_gaussian_pdf(
                            &x,
                            &means[class_idx][comp_idx],
                            &covariances[class_idx][comp_idx],
                        );
                    probs.push(prob);
                    total_prob += prob;
                }
            }

            if total_prob > 0.0 {
                for (comp_idx, &prob) in probs.iter().enumerate() {
                    responsibilities[[i, comp_idx]] = prob / total_prob;
                }
                log_likelihood += total_prob.ln();
            } else {
                for comp_idx in 0..total_components {
                    responsibilities[[i, comp_idx]] = 1.0 / total_components as f64;
                }
            }

            // Clamp labeled samples to their known class: zero out every
            // component, then spread the mass uniformly over that class's
            // components.
            if labeled_indices.contains(&i) {
                if let Some(class_idx) = classes.iter().position(|&c| c == y[i]) {
                    for comp_idx in 0..total_components {
                        responsibilities[[i, comp_idx]] = 0.0;
                    }
                    for comp_idx in 0..self.n_components {
                        let global_comp_idx = class_idx * self.n_components + comp_idx;
                        responsibilities[[i, global_comp_idx]] = 1.0 / self.n_components as f64;
                    }
                }
            }
        }

        (responsibilities, log_likelihood)
    }

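    /// M-step: re-estimates component means, covariances, and weights from
    /// the responsibilities, then renormalizes the component weights within
    /// each class and the class priors across classes.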
    #[allow(clippy::type_complexity)]
    fn m_step(
        &self,
        X: &Array2<f64>,
        responsibilities: &Array2<f64>,
        classes: &[i32],
    ) -> (
        Vec<Vec<Array1<f64>>>,
        Vec<Vec<Array2<f64>>>,
        Vec<Array1<f64>>,
        Array1<f64>,
    ) {
        let n_samples = X.nrows();
        let n_features = X.ncols();
        let n_classes = classes.len();

        let mut means = Vec::new();
        let mut covariances = Vec::new();
        let mut component_weights = Vec::new();
        let mut class_priors = Array1::zeros(n_classes);

        for class_idx in 0..n_classes {
            let mut class_means = Vec::new();
            let mut class_covariances = Vec::new();
            let mut class_component_weights = Array1::zeros(self.n_components);

            let mut class_total_responsibility = 0.0;

            for comp_idx in 0..self.n_components {
                let global_comp_idx = class_idx * self.n_components + comp_idx;
                let comp_responsibilities = responsibilities.column(global_comp_idx);
                let comp_total_resp: f64 = comp_responsibilities.sum();

                class_total_responsibility += comp_total_resp;

                if comp_total_resp > 1e-10 {
                    // Responsibility-weighted mean.
                    let mut new_mean = Array1::zeros(n_features);
                    for i in 0..n_samples {
                        for j in 0..n_features {
                            new_mean[j] += comp_responsibilities[i] * X[[i, j]];
                        }
                    }
                    new_mean /= comp_total_resp;

                    // Responsibility-weighted covariance.
                    let mut new_cov = Array2::zeros((n_features, n_features));
                    for i in 0..n_samples {
                        let diff = &X.row(i) - &new_mean;
                        let weight = comp_responsibilities[i];
                        for j in 0..n_features {
                            for k in 0..n_features {
                                new_cov[[j, k]] += weight * diff[j] * diff[k];
                            }
                        }
                    }
                    new_cov /= comp_total_resp;

                    // Regularize the diagonal to keep the covariance well-conditioned.
                    for i in 0..n_features {
                        new_cov[[i, i]] += self.reg_covar;
                    }

                    class_means.push(new_mean);
                    class_covariances.push(new_cov);
                    class_component_weights[comp_idx] = comp_total_resp;
                } else {
                    // Degenerate component: reset to a regularized identity.
                    class_means.push(Array1::zeros(n_features));
                    class_covariances.push(Array2::eye(n_features) * self.reg_covar);
                    class_component_weights[comp_idx] = 1e-10;
                }
            }

            // Normalize component weights within the class.
            let total_weight = class_component_weights.sum();
            if total_weight > 0.0 {
                class_component_weights /= total_weight;
            } else {
                class_component_weights.fill(1.0 / self.n_components as f64);
            }

            means.push(class_means);
            covariances.push(class_covariances);
            component_weights.push(class_component_weights);

            class_priors[class_idx] = class_total_responsibility / n_samples as f64;
        }

        // Normalize class priors across classes.
        let total_prior = class_priors.sum();
        if total_prior > 0.0 {
            class_priors /= total_prior;
        } else {
            class_priors.fill(1.0 / n_classes as f64);
        }

        (means, covariances, component_weights, class_priors)
    }
}

impl Default for MixtureDiscriminantAnalysis<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for MixtureDiscriminantAnalysis<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

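/// Parameters learned by a successful `fit`, stored in the estimator's
/// type-state.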
#[derive(Debug, Clone)]
pub struct MixtureDiscriminantAnalysisTrained {
    /// Component means, indexed as `[class][component]`.
    pub means: Vec<Vec<Array1<f64>>>,
    /// Component covariance matrices, indexed as `[class][component]`.
    pub covariances: Vec<Vec<Array2<f64>>>,
    /// Per-class mixture weights over components.
    pub component_weights: Vec<Array1<f64>>,
    /// Prior probability of each class.
    pub class_priors: Array1<f64>,
    /// Sorted class labels.
    pub classes: Array1<i32>,
    /// Number of mixture components per class.
    pub n_components: usize,
    /// Covariance structure ("full", "diag", "spherical", or "tied").
    pub covariance_type: String,
}

impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for MixtureDiscriminantAnalysis<Untrained> {
    type Fitted = MixtureDiscriminantAnalysis<MixtureDiscriminantAnalysisTrained>;

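    /// Runs EM until `max_iter` iterations or until the absolute change in
    /// log-likelihood drops below `tol`. Samples labeled `-1` are treated as
    /// unlabeled; at least one labeled sample is required.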
    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();

        // Collect labeled samples; the label -1 marks an unlabeled sample.
        let mut labeled_indices = Vec::new();
        let mut classes = std::collections::HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label != -1 {
                labeled_indices.push(i);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        // Sort for a deterministic class order; HashSet iteration order is
        // unspecified and would otherwise shuffle the probability columns
        // between runs.
        let mut classes: Vec<i32> = classes.into_iter().collect();
        classes.sort_unstable();

        let (mut means, mut covariances, mut component_weights, mut class_priors) =
            self.initialize_parameters(&X, &labeled_indices, &y, &classes)?;

        let mut prev_log_likelihood = f64::NEG_INFINITY;

        // EM loop: alternate E- and M-steps until the log-likelihood change
        // falls below `tol` or `max_iter` is reached.
        for iteration in 0..self.max_iter {
            let (responsibilities, log_likelihood) = self.e_step(
                &X,
                &means,
                &covariances,
                &component_weights,
                &class_priors,
                &labeled_indices,
                &y,
                &classes,
            );

            let (new_means, new_covariances, new_component_weights, new_class_priors) =
                self.m_step(&X, &responsibilities, &classes);

            means = new_means;
            covariances = new_covariances;
            component_weights = new_component_weights;
            class_priors = new_class_priors;

            if iteration > 0 && (log_likelihood - prev_log_likelihood).abs() < self.tol {
                break;
            }

            prev_log_likelihood = log_likelihood;
        }

        Ok(MixtureDiscriminantAnalysis {
            state: MixtureDiscriminantAnalysisTrained {
                means,
                covariances,
                component_weights,
                class_priors,
                classes: Array1::from(classes),
                n_components: self.n_components,
                covariance_type: self.covariance_type.clone(),
            },
            n_components: self.n_components,
            covariance_type: self.covariance_type,
            reg_covar: self.reg_covar,
            max_iter: self.max_iter,
            tol: self.tol,
            n_init: self.n_init,
            random_state: self.random_state,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>>
    for MixtureDiscriminantAnalysis<MixtureDiscriminantAnalysisTrained>
{
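    /// Predicts, for each row of `X`, the class with the highest posterior
    /// probability according to `predict_proba`.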
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let probas = self.predict_proba(X)?;
        let n_test = probas.nrows();
        let mut predictions = Array1::zeros(n_test);

        for i in 0..n_test {
            let max_idx = probas
                .row(i)
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0;

            predictions[i] = self.state.classes[max_idx];
        }

        Ok(predictions)
    }
}

impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
    for MixtureDiscriminantAnalysis<MixtureDiscriminantAnalysisTrained>
{
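    /// Computes normalized class posteriors: each class's prior-weighted
    /// mixture likelihood, normalized per sample to sum to one.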
    #[allow(non_snake_case)]
    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let n_classes = self.state.classes.len();
        let mut probas = Array2::zeros((n_test, n_classes));

        for i in 0..n_test {
            let x = X.row(i);
            let mut class_probs = Array1::zeros(n_classes);

            for class_idx in 0..n_classes {
                // Mixture likelihood of the class: weighted sum over components.
                let mut class_prob = 0.0;

                for comp_idx in 0..self.state.n_components {
                    let component_prob = self.state.component_weights[class_idx][comp_idx]
                        * self.multivariate_gaussian_pdf(
                            &x,
                            &self.state.means[class_idx][comp_idx],
                            &self.state.covariances[class_idx][comp_idx],
                        );
                    class_prob += component_prob;
                }

                class_probs[class_idx] = self.state.class_priors[class_idx] * class_prob;
            }

            // Normalize to a proper posterior; fall back to uniform if all
            // likelihoods underflowed to zero.
            let total_prob = class_probs.sum();
            if total_prob > 0.0 {
                class_probs /= total_prob;
            } else {
                class_probs.fill(1.0 / n_classes as f64);
            }

            for j in 0..n_classes {
                probas[[i, j]] = class_probs[j];
            }
        }

        Ok(probas)
    }
}

impl MixtureDiscriminantAnalysis<MixtureDiscriminantAnalysisTrained> {
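    /// Density helper for the trained estimator; mirrors the untrained
    /// version, including the diagonal approximation used for the "full" and
    /// "tied" covariance types.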
    fn multivariate_gaussian_pdf(
        &self,
        x: &ArrayView1<f64>,
        mean: &Array1<f64>,
        cov: &Array2<f64>,
    ) -> f64 {
        let n_features = x.len();
        let diff = x - mean;

        // Determinant of the covariance; the fallback branch ("full"/"tied")
        // deliberately uses only the diagonal entries.
        let det = match self.state.covariance_type.as_str() {
            "spherical" => cov[[0, 0]].powf(n_features as f64),
            "diag" => cov.diag().iter().product(),
            _ => {
                let mut det = 1.0;
                for i in 0..n_features {
                    det *= cov[[i, i]];
                }
                det
            }
        };

        if det <= 0.0 {
            return 1e-10;
        }

        // Mahalanobis distance under the same diagonal approximation.
        let mut mahal_dist = 0.0;
        match self.state.covariance_type.as_str() {
            "spherical" => {
                let var = cov[[0, 0]];
                mahal_dist = diff.mapv(|x| x * x).sum() / var;
            }
            "diag" => {
                for i in 0..n_features {
                    mahal_dist += diff[i] * diff[i] / cov[[i, i]];
                }
            }
            _ => {
                for i in 0..n_features {
                    mahal_dist += diff[i] * diff[i] / cov[[i, i]];
                }
            }
        }

        let normalization =
            1.0 / ((2.0 * std::f64::consts::PI).powf(n_features as f64 / 2.0) * det.sqrt());
        normalization * (-0.5 * mahal_dist).exp()
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    #[allow(non_snake_case)]
    fn test_mixture_discriminant_analysis() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1]; // -1 marks unlabeled samples

        let mda = MixtureDiscriminantAnalysis::new()
            .n_components(1)
            .max_iter(10);
        let fitted = mda.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 4);

        let probas = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probas.dim(), (4, 2));

        // Each row of posteriors should sum to one.
        for i in 0..4 {
            let sum: f64 = probas.row(i).sum();
            assert!((sum - 1.0).abs() < 1e-8);
        }
    }

    #[test]
    fn test_mda_parameters() {
        let mda = MixtureDiscriminantAnalysis::new()
            .n_components(3)
            .covariance_type("diag".to_string())
            .reg_covar(1e-5)
            .max_iter(200)
            .tol(1e-6)
            .n_init(5)
            .random_state(42);

        assert_eq!(mda.n_components, 3);
        assert_eq!(mda.covariance_type, "diag");
        assert_eq!(mda.reg_covar, 1e-5);
        assert_eq!(mda.max_iter, 200);
        assert_eq!(mda.tol, 1e-6);
        assert_eq!(mda.n_init, 5);
        assert_eq!(mda.random_state, Some(42));
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_mda_covariance_types() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1];

        for cov_type in &["full", "diag", "spherical", "tied"] {
            let mda = MixtureDiscriminantAnalysis::new()
                .covariance_type(cov_type.to_string())
                .max_iter(5);
            let fitted = mda.fit(&X.view(), &y.view()).unwrap();

            let predictions = fitted.predict(&X.view()).unwrap();
            assert_eq!(predictions.len(), 4);
        }
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_mda_multiple_components() {
        let X = array![
            [1.0, 1.0],
            [1.1, 1.1],
            [1.2, 1.2],
            [5.0, 5.0],
            [5.1, 5.1],
            [5.2, 5.2],
            [3.0, 3.0],
            [3.1, 3.1]
        ];
        let y = array![0, 0, -1, 1, 1, -1, -1, -1];

        let mda = MixtureDiscriminantAnalysis::new()
            .n_components(2)
            .max_iter(20);
        let fitted = mda.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 8);

        let probas = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probas.dim(), (8, 2));
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_mda_error_cases() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![-1, -1]; // every sample unlabeled, so fit must fail

        let mda = MixtureDiscriminantAnalysis::new();
        let result = mda.fit(&X.view(), &y.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_mda_gaussian_pdf() {
        let mda = MixtureDiscriminantAnalysis::new().covariance_type("diag".to_string());

        let x = array![1.0, 2.0];
        let mean = array![1.0, 2.0];
        let cov = Array2::eye(2);

        let pdf = mda.multivariate_gaussian_pdf(&x.view(), &mean, &cov);
        assert!(pdf > 0.0);
        assert!(pdf <= 1.0);
    }
}