sklears_semi_supervised/
mixture_discriminant_analysis.rs

1//! Mixture Discriminant Analysis for semi-supervised learning
2//!
3//! This module provides Mixture Discriminant Analysis (MDA), a probabilistic
4//! semi-supervised learning method that extends discriminant analysis to
5//! handle both labeled and unlabeled data through mixture modeling.
6
7use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2};
8use sklears_core::{
9    error::{Result as SklResult, SklearsError},
10    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
11    types::Float,
12};
13
14/// Mixture Discriminant Analysis
15///
16/// MDA is a semi-supervised extension of Linear/Quadratic Discriminant Analysis
17/// that uses both labeled and unlabeled data to learn class-conditional densities.
18/// It models each class as a mixture of Gaussians and uses EM algorithm to
19/// estimate parameters from both labeled and unlabeled samples.
20///
21/// # Parameters
22///
23/// * `n_components` - Number of mixture components per class
24/// * `covariance_type` - Type of covariance matrix ('full', 'tied', 'diag', 'spherical')
25/// * `reg_covar` - Regularization added to diagonal of covariance
26/// * `max_iter` - Maximum number of EM iterations
27/// * `tol` - Convergence tolerance
28/// * `n_init` - Number of initializations
29/// * `random_state` - Random state for reproducible results
30///
31/// # Examples
32///
33/// ```
34/// use scirs2_core::array;
35/// use sklears_semi_supervised::MixtureDiscriminantAnalysis;
36/// use sklears_core::traits::{Predict, Fit};
37///
38///
39/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
40/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
41///
42/// let mda = MixtureDiscriminantAnalysis::new()
43///     .n_components(2)
44///     .covariance_type("full".to_string())
45///     .max_iter(100);
46/// let fitted = mda.fit(&X.view(), &y.view()).unwrap();
47/// let predictions = fitted.predict(&X.view()).unwrap();
48/// ```
49#[derive(Debug, Clone)]
50pub struct MixtureDiscriminantAnalysis<S = Untrained> {
51    state: S,
52    n_components: usize,
53    covariance_type: String,
54    reg_covar: f64,
55    max_iter: usize,
56    tol: f64,
57    n_init: usize,
58    random_state: Option<u64>,
59}
60
61impl MixtureDiscriminantAnalysis<Untrained> {
62    /// Create a new MixtureDiscriminantAnalysis instance
63    pub fn new() -> Self {
64        Self {
65            state: Untrained,
66            n_components: 1,
67            covariance_type: "full".to_string(),
68            reg_covar: 1e-6,
69            max_iter: 100,
70            tol: 1e-3,
71            n_init: 1,
72            random_state: None,
73        }
74    }
75
76    /// Set the number of mixture components per class
77    pub fn n_components(mut self, n_components: usize) -> Self {
78        self.n_components = n_components;
79        self
80    }
81
82    /// Set the covariance type
83    pub fn covariance_type(mut self, covariance_type: String) -> Self {
84        self.covariance_type = covariance_type;
85        self
86    }
87
88    /// Set the covariance regularization
89    pub fn reg_covar(mut self, reg_covar: f64) -> Self {
90        self.reg_covar = reg_covar;
91        self
92    }
93
94    /// Set the maximum number of iterations
95    pub fn max_iter(mut self, max_iter: usize) -> Self {
96        self.max_iter = max_iter;
97        self
98    }
99
100    /// Set the convergence tolerance
101    pub fn tol(mut self, tol: f64) -> Self {
102        self.tol = tol;
103        self
104    }
105
106    /// Set the number of initializations
107    pub fn n_init(mut self, n_init: usize) -> Self {
108        self.n_init = n_init;
109        self
110    }
111
112    /// Set the random state for reproducible results
113    pub fn random_state(mut self, random_state: u64) -> Self {
114        self.random_state = Some(random_state);
115        self
116    }
117
118    /// Initialize mixture parameters
119    #[allow(clippy::type_complexity)]
120    fn initialize_parameters(
121        &self,
122        X: &Array2<f64>,
123        labeled_indices: &[usize],
124        y: &Array1<i32>,
125        classes: &[i32],
126    ) -> SklResult<(
127        Vec<Vec<Array1<f64>>>,
128        Vec<Vec<Array2<f64>>>,
129        Vec<Array1<f64>>,
130        Array1<f64>,
131    )> {
132        let n_features = X.ncols();
133        let n_classes = classes.len();
134
135        // Initialize means, covariances, component weights, and class priors
136        let mut means = Vec::new();
137        let mut covariances = Vec::new();
138        let mut component_weights = Vec::new();
139        let mut class_priors = Array1::zeros(n_classes);
140
141        for (class_idx, &class_label) in classes.iter().enumerate() {
142            // Find labeled samples for this class
143            let class_samples: Vec<usize> = labeled_indices
144                .iter()
145                .filter(|&&i| y[i] == class_label)
146                .copied()
147                .collect();
148
149            if class_samples.is_empty() {
150                return Err(SklearsError::InvalidInput(format!(
151                    "No labeled samples for class {}",
152                    class_label
153                )));
154            }
155
156            // Initialize means for this class
157            let mut class_means = Vec::new();
158            let mut class_covariances = Vec::new();
159
160            // Simple initialization: use labeled samples as initial means
161            for comp_idx in 0..self.n_components {
162                let sample_idx = class_samples[comp_idx % class_samples.len()];
163                let mean = X.row(sample_idx).to_owned();
164                class_means.push(mean);
165
166                // Initialize covariance based on type
167                let cov = match self.covariance_type.as_str() {
168                    "full" => {
169                        let mut cov = Array2::eye(n_features) * self.reg_covar;
170                        // Add small random perturbation
171                        for i in 0..n_features {
172                            for j in 0..n_features {
173                                if i == j {
174                                    cov[[i, j]] += 1.0;
175                                }
176                            }
177                        }
178                        cov
179                    }
180                    "diag" => Array2::eye(n_features),
181                    "spherical" => Array2::eye(n_features),
182                    "tied" => Array2::eye(n_features),
183                    _ => {
184                        return Err(SklearsError::InvalidInput(format!(
185                            "Unknown covariance type: {}",
186                            self.covariance_type
187                        )));
188                    }
189                };
190                class_covariances.push(cov);
191            }
192
193            means.push(class_means);
194            covariances.push(class_covariances);
195
196            // Initialize component weights (uniform)
197            let weights = Array1::from_elem(self.n_components, 1.0 / self.n_components as f64);
198            component_weights.push(weights);
199
200            // Set class prior based on labeled samples
201            class_priors[class_idx] = class_samples.len() as f64 / labeled_indices.len() as f64;
202        }
203
204        Ok((means, covariances, component_weights, class_priors))
205    }
206
207    /// Compute multivariate Gaussian PDF
208    fn multivariate_gaussian_pdf(
209        &self,
210        x: &ArrayView1<f64>,
211        mean: &Array1<f64>,
212        cov: &Array2<f64>,
213    ) -> f64 {
214        let n_features = x.len();
215        let diff = x - mean;
216
217        // Compute determinant and inverse (simplified)
218        let det = match self.covariance_type.as_str() {
219            "spherical" => cov[[0, 0]].powf(n_features as f64),
220            "diag" => cov.diag().iter().product(),
221            _ => {
222                // Simplified determinant calculation
223                let mut det = 1.0;
224                for i in 0..n_features {
225                    det *= cov[[i, i]];
226                }
227                det
228            }
229        };
230
231        if det <= 0.0 {
232            return 1e-10; // Avoid numerical issues
233        }
234
235        // Simplified Mahalanobis distance calculation
236        let mut mahal_dist = 0.0;
237        match self.covariance_type.as_str() {
238            "spherical" => {
239                let var = cov[[0, 0]];
240                mahal_dist = diff.mapv(|x| x * x).sum() / var;
241            }
242            "diag" => {
243                for i in 0..n_features {
244                    mahal_dist += diff[i] * diff[i] / cov[[i, i]];
245                }
246            }
247            _ => {
248                // Simplified full covariance
249                for i in 0..n_features {
250                    mahal_dist += diff[i] * diff[i] / cov[[i, i]];
251                }
252            }
253        }
254
255        let normalization =
256            1.0 / ((2.0 * std::f64::consts::PI).powf(n_features as f64 / 2.0) * det.sqrt());
257        normalization * (-0.5 * mahal_dist).exp()
258    }
259
260    /// E-step: Compute responsibilities
261    #[allow(clippy::too_many_arguments, clippy::type_complexity)]
262    fn e_step(
263        &self,
264        X: &Array2<f64>,
265        means: &[Vec<Array1<f64>>],
266        covariances: &[Vec<Array2<f64>>],
267        component_weights: &[Array1<f64>],
268        class_priors: &Array1<f64>,
269        labeled_indices: &[usize],
270        y: &Array1<i32>,
271        classes: &[i32],
272    ) -> (Array2<f64>, f64) {
273        let n_samples = X.nrows();
274        let n_classes = classes.len();
275        let total_components = n_classes * self.n_components;
276
277        let mut responsibilities = Array2::zeros((n_samples, total_components));
278        let mut log_likelihood = 0.0;
279
280        for i in 0..n_samples {
281            let x = X.row(i);
282            let mut total_prob = 0.0;
283            let mut probs = Vec::new();
284
285            // Compute probabilities for each class and component
286            for (class_idx, &class_label) in classes.iter().enumerate() {
287                for comp_idx in 0..self.n_components {
288                    let comp_global_idx = class_idx * self.n_components + comp_idx;
289                    let prob = class_priors[class_idx]
290                        * component_weights[class_idx][comp_idx]
291                        * self.multivariate_gaussian_pdf(
292                            &x,
293                            &means[class_idx][comp_idx],
294                            &covariances[class_idx][comp_idx],
295                        );
296                    probs.push(prob);
297                    total_prob += prob;
298                }
299            }
300
301            // Normalize and assign responsibilities
302            if total_prob > 0.0 {
303                for (comp_idx, &prob) in probs.iter().enumerate() {
304                    responsibilities[[i, comp_idx]] = prob / total_prob;
305                }
306                log_likelihood += total_prob.ln();
307            } else {
308                // Uniform assignment if no valid probability
309                for comp_idx in 0..total_components {
310                    responsibilities[[i, comp_idx]] = 1.0 / total_components as f64;
311                }
312            }
313
314            // Hard assignment for labeled samples
315            if labeled_indices.contains(&i) {
316                if let Some(class_idx) = classes.iter().position(|&c| c == y[i]) {
317                    // Set responsibility to 1 for true class, 0 for others
318                    for comp_idx in 0..total_components {
319                        responsibilities[[i, comp_idx]] = 0.0;
320                    }
321                    // Uniform distribution within the true class
322                    for comp_idx in 0..self.n_components {
323                        let global_comp_idx = class_idx * self.n_components + comp_idx;
324                        responsibilities[[i, global_comp_idx]] = 1.0 / self.n_components as f64;
325                    }
326                }
327            }
328        }
329
330        (responsibilities, log_likelihood)
331    }
332
333    /// M-step: Update parameters
334    #[allow(clippy::type_complexity)]
335    fn m_step(
336        &self,
337        X: &Array2<f64>,
338        responsibilities: &Array2<f64>,
339        classes: &[i32],
340    ) -> (
341        Vec<Vec<Array1<f64>>>,
342        Vec<Vec<Array2<f64>>>,
343        Vec<Array1<f64>>,
344        Array1<f64>,
345    ) {
346        let n_samples = X.nrows();
347        let n_features = X.ncols();
348        let n_classes = classes.len();
349
350        let mut means = Vec::new();
351        let mut covariances = Vec::new();
352        let mut component_weights = Vec::new();
353        let mut class_priors = Array1::zeros(n_classes);
354
355        for class_idx in 0..n_classes {
356            let mut class_means = Vec::new();
357            let mut class_covariances = Vec::new();
358            let mut class_component_weights = Array1::zeros(self.n_components);
359
360            let mut class_total_responsibility = 0.0;
361
362            for comp_idx in 0..self.n_components {
363                let global_comp_idx = class_idx * self.n_components + comp_idx;
364                let comp_responsibilities = responsibilities.column(global_comp_idx);
365                let comp_total_resp: f64 = comp_responsibilities.sum();
366
367                class_total_responsibility += comp_total_resp;
368
369                if comp_total_resp > 1e-10 {
370                    // Update mean
371                    let mut new_mean = Array1::zeros(n_features);
372                    for i in 0..n_samples {
373                        for j in 0..n_features {
374                            new_mean[j] += comp_responsibilities[i] * X[[i, j]];
375                        }
376                    }
377                    new_mean /= comp_total_resp;
378
379                    // Update covariance
380                    let mut new_cov = Array2::zeros((n_features, n_features));
381                    for i in 0..n_samples {
382                        let diff = &X.row(i) - &new_mean;
383                        let weight = comp_responsibilities[i];
384                        for j in 0..n_features {
385                            for k in 0..n_features {
386                                new_cov[[j, k]] += weight * diff[j] * diff[k];
387                            }
388                        }
389                    }
390                    new_cov /= comp_total_resp;
391
392                    // Add regularization
393                    for i in 0..n_features {
394                        new_cov[[i, i]] += self.reg_covar;
395                    }
396
397                    class_means.push(new_mean);
398                    class_covariances.push(new_cov);
399                    class_component_weights[comp_idx] = comp_total_resp;
400                } else {
401                    // Fallback for empty components
402                    class_means.push(Array1::zeros(n_features));
403                    class_covariances.push(Array2::eye(n_features) * self.reg_covar);
404                    class_component_weights[comp_idx] = 1e-10;
405                }
406            }
407
408            // Normalize component weights
409            let total_weight = class_component_weights.sum();
410            if total_weight > 0.0 {
411                class_component_weights /= total_weight;
412            } else {
413                class_component_weights.fill(1.0 / self.n_components as f64);
414            }
415
416            means.push(class_means);
417            covariances.push(class_covariances);
418            component_weights.push(class_component_weights);
419
420            // Update class prior
421            class_priors[class_idx] = class_total_responsibility / n_samples as f64;
422        }
423
424        // Normalize class priors
425        let total_prior = class_priors.sum();
426        if total_prior > 0.0 {
427            class_priors /= total_prior;
428        } else {
429            class_priors.fill(1.0 / n_classes as f64);
430        }
431
432        (means, covariances, component_weights, class_priors)
433    }
434}
435
436impl Default for MixtureDiscriminantAnalysis<Untrained> {
437    fn default() -> Self {
438        Self::new()
439    }
440}
441
442impl Estimator for MixtureDiscriminantAnalysis<Untrained> {
443    type Config = ();
444    type Error = SklearsError;
445    type Float = Float;
446
447    fn config(&self) -> &Self::Config {
448        &()
449    }
450}
451
452/// Trained state for MixtureDiscriminantAnalysis
453#[derive(Debug, Clone)]
454pub struct MixtureDiscriminantAnalysisTrained {
455    /// means
456    pub means: Vec<Vec<Array1<f64>>>,
457    /// covariances
458    pub covariances: Vec<Vec<Array2<f64>>>,
459    /// component_weights
460    pub component_weights: Vec<Array1<f64>>,
461    /// class_priors
462    pub class_priors: Array1<f64>,
463    /// classes
464    pub classes: Array1<i32>,
465    /// n_components
466    pub n_components: usize,
467    /// covariance_type
468    pub covariance_type: String,
469}
470
471impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for MixtureDiscriminantAnalysis<Untrained> {
472    type Fitted = MixtureDiscriminantAnalysis<MixtureDiscriminantAnalysisTrained>;
473
474    #[allow(non_snake_case)]
475    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
476        let X = X.to_owned();
477        let y = y.to_owned();
478
479        // Identify labeled samples and classes
480        let mut labeled_indices = Vec::new();
481        let mut classes = std::collections::HashSet::new();
482
483        for (i, &label) in y.iter().enumerate() {
484            if label != -1 {
485                labeled_indices.push(i);
486                classes.insert(label);
487            }
488        }
489
490        if labeled_indices.is_empty() {
491            return Err(SklearsError::InvalidInput(
492                "No labeled samples provided".to_string(),
493            ));
494        }
495
496        let classes: Vec<i32> = classes.into_iter().collect();
497
498        // Initialize parameters
499        let (mut means, mut covariances, mut component_weights, mut class_priors) =
500            self.initialize_parameters(&X, &labeled_indices, &y, &classes)?;
501
502        let mut prev_log_likelihood = f64::NEG_INFINITY;
503
504        // EM algorithm
505        for iteration in 0..self.max_iter {
506            // E-step
507            let (responsibilities, log_likelihood) = self.e_step(
508                &X,
509                &means,
510                &covariances,
511                &component_weights,
512                &class_priors,
513                &labeled_indices,
514                &y,
515                &classes,
516            );
517
518            // M-step
519            let (new_means, new_covariances, new_component_weights, new_class_priors) =
520                self.m_step(&X, &responsibilities, &classes);
521
522            means = new_means;
523            covariances = new_covariances;
524            component_weights = new_component_weights;
525            class_priors = new_class_priors;
526
527            // Check convergence
528            if iteration > 0 && (log_likelihood - prev_log_likelihood).abs() < self.tol {
529                break;
530            }
531
532            prev_log_likelihood = log_likelihood;
533        }
534
535        Ok(MixtureDiscriminantAnalysis {
536            state: MixtureDiscriminantAnalysisTrained {
537                means,
538                covariances,
539                component_weights,
540                class_priors,
541                classes: Array1::from(classes),
542                n_components: self.n_components,
543                covariance_type: self.covariance_type.clone(),
544            },
545            n_components: self.n_components,
546            covariance_type: self.covariance_type,
547            reg_covar: self.reg_covar,
548            max_iter: self.max_iter,
549            tol: self.tol,
550            n_init: self.n_init,
551            random_state: self.random_state,
552        })
553    }
554}
555
556impl Predict<ArrayView2<'_, Float>, Array1<i32>>
557    for MixtureDiscriminantAnalysis<MixtureDiscriminantAnalysisTrained>
558{
559    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
560        let probas = self.predict_proba(X)?;
561        let n_test = probas.nrows();
562        let mut predictions = Array1::zeros(n_test);
563
564        for i in 0..n_test {
565            let max_idx = probas
566                .row(i)
567                .iter()
568                .enumerate()
569                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
570                .unwrap()
571                .0;
572
573            predictions[i] = self.state.classes[max_idx];
574        }
575
576        Ok(predictions)
577    }
578}
579
580impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
581    for MixtureDiscriminantAnalysis<MixtureDiscriminantAnalysisTrained>
582{
583    #[allow(non_snake_case)]
584    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
585        let X = X.to_owned();
586        let n_test = X.nrows();
587        let n_classes = self.state.classes.len();
588        let mut probas = Array2::zeros((n_test, n_classes));
589
590        for i in 0..n_test {
591            let x = X.row(i);
592            let mut class_probs = Array1::zeros(n_classes);
593
594            // Compute probability for each class
595            for class_idx in 0..n_classes {
596                let mut class_prob = 0.0;
597
598                // Sum over all components in the class
599                for comp_idx in 0..self.state.n_components {
600                    let component_prob = self.state.component_weights[class_idx][comp_idx]
601                        * self.multivariate_gaussian_pdf(
602                            &x,
603                            &self.state.means[class_idx][comp_idx],
604                            &self.state.covariances[class_idx][comp_idx],
605                        );
606                    class_prob += component_prob;
607                }
608
609                class_probs[class_idx] = self.state.class_priors[class_idx] * class_prob;
610            }
611
612            // Normalize probabilities
613            let total_prob = class_probs.sum();
614            if total_prob > 0.0 {
615                class_probs /= total_prob;
616            } else {
617                class_probs.fill(1.0 / n_classes as f64);
618            }
619
620            for j in 0..n_classes {
621                probas[[i, j]] = class_probs[j];
622            }
623        }
624
625        Ok(probas)
626    }
627}
628
629impl MixtureDiscriminantAnalysis<MixtureDiscriminantAnalysisTrained> {
630    /// Compute multivariate Gaussian PDF (same as training method)
631    fn multivariate_gaussian_pdf(
632        &self,
633        x: &ArrayView1<f64>,
634        mean: &Array1<f64>,
635        cov: &Array2<f64>,
636    ) -> f64 {
637        let n_features = x.len();
638        let diff = x - mean;
639
640        // Compute determinant and Mahalanobis distance (simplified)
641        let det = match self.state.covariance_type.as_str() {
642            "spherical" => cov[[0, 0]].powf(n_features as f64),
643            "diag" => cov.diag().iter().product(),
644            _ => {
645                let mut det = 1.0;
646                for i in 0..n_features {
647                    det *= cov[[i, i]];
648                }
649                det
650            }
651        };
652
653        if det <= 0.0 {
654            return 1e-10;
655        }
656
657        let mut mahal_dist = 0.0;
658        match self.state.covariance_type.as_str() {
659            "spherical" => {
660                let var = cov[[0, 0]];
661                mahal_dist = diff.mapv(|x| x * x).sum() / var;
662            }
663            "diag" => {
664                for i in 0..n_features {
665                    mahal_dist += diff[i] * diff[i] / cov[[i, i]];
666                }
667            }
668            _ => {
669                for i in 0..n_features {
670                    mahal_dist += diff[i] * diff[i] / cov[[i, i]];
671                }
672            }
673        }
674
675        let normalization =
676            1.0 / ((2.0 * std::f64::consts::PI).powf(n_features as f64 / 2.0) * det.sqrt());
677        normalization * (-0.5 * mahal_dist).exp()
678    }
679}
680
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    /// Four collinear 2-D points. Sample 0 is class 0, sample 1 is class 1, and
    /// samples 2 and 3 are unlabeled (`-1`).
    fn toy_problem() -> (Array2<f64>, Array1<i32>) {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1];
        (X, y)
    }

    #[test]
    fn test_mixture_discriminant_analysis() {
        let (X, y) = toy_problem();
        let fitted = MixtureDiscriminantAnalysis::new()
            .n_components(1)
            .max_iter(10) // kept small so the test stays fast
            .fit(&X.view(), &y.view())
            .unwrap();

        assert_eq!(fitted.predict(&X.view()).unwrap().len(), 4);

        let probas = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probas.dim(), (4, 2));
        // Every row must be a probability distribution over the two classes.
        assert!(probas
            .outer_iter()
            .all(|row| (row.sum() - 1.0).abs() < 1e-8));
    }

    #[test]
    fn test_mda_parameters() {
        let configured = MixtureDiscriminantAnalysis::new()
            .n_components(3)
            .covariance_type("diag".to_string())
            .reg_covar(1e-5)
            .max_iter(200)
            .tol(1e-6)
            .n_init(5)
            .random_state(42);

        assert_eq!(configured.n_components, 3);
        assert_eq!(configured.covariance_type, "diag");
        assert_eq!(configured.reg_covar, 1e-5);
        assert_eq!(configured.max_iter, 200);
        assert_eq!(configured.tol, 1e-6);
        assert_eq!(configured.n_init, 5);
        assert_eq!(configured.random_state, Some(42));
    }

    #[test]
    fn test_mda_covariance_types() {
        let (X, y) = toy_problem();

        for cov_type in ["full", "diag", "spherical", "tied"].iter() {
            let fitted = MixtureDiscriminantAnalysis::new()
                .covariance_type(cov_type.to_string())
                .max_iter(5)
                .fit(&X.view(), &y.view())
                .unwrap();

            assert_eq!(fitted.predict(&X.view()).unwrap().len(), 4);
        }
    }

    #[test]
    fn test_mda_multiple_components() {
        // Two tight clusters around (1, 1) and (5, 5) plus two points in between.
        let X = array![
            [1.0, 1.0],
            [1.1, 1.1],
            [1.2, 1.2],
            [5.0, 5.0],
            [5.1, 5.1],
            [5.2, 5.2],
            [3.0, 3.0],
            [3.1, 3.1]
        ];
        let y = array![0, 0, -1, 1, 1, -1, -1, -1];

        let fitted = MixtureDiscriminantAnalysis::new()
            .n_components(2)
            .max_iter(20)
            .fit(&X.view(), &y.view())
            .unwrap();

        assert_eq!(fitted.predict(&X.view()).unwrap().len(), 8);
        assert_eq!(fitted.predict_proba(&X.view()).unwrap().dim(), (8, 2));
    }

    #[test]
    fn test_mda_error_cases() {
        // Every sample is unlabeled, so fitting must fail.
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![-1, -1];

        assert!(MixtureDiscriminantAnalysis::new()
            .fit(&X.view(), &y.view())
            .is_err());
    }

    #[test]
    fn test_mda_gaussian_pdf() {
        let model = MixtureDiscriminantAnalysis::new().covariance_type("diag".to_string());

        let point = array![1.0, 2.0];
        let center = array![1.0, 2.0];
        let identity = Array2::eye(2);

        let density = model.multivariate_gaussian_pdf(&point.view(), &center, &identity);
        assert!(density > 0.0 && density <= 1.0);
    }
}