sklears_semi_supervised/
semi_supervised_gmm.rs

//! Semi-supervised Gaussian Mixture Model
//!
//! This module implements a semi-supervised Gaussian Mixture Model that can utilize
//! both labeled and unlabeled data for training. It uses the Expectation-Maximization
//! (EM) algorithm with partial supervision from labeled data.

use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
    types::Float,
};
use std::collections::HashSet;

/// Semi-supervised Gaussian Mixture Model
///
/// This classifier extends the standard Gaussian Mixture Model to handle
/// semi-supervised learning scenarios where only a subset of the training data
/// is labeled. It uses the EM algorithm with constraints from the labeled data.
///
/// # Parameters
///
/// * `n_components` - Number of mixture components
/// * `max_iter` - Maximum number of EM iterations
/// * `tol` - Convergence tolerance
/// * `covariance_type` - Type of covariance matrix ("full", "diag", "spherical")
/// * `reg_covar` - Regularization parameter for covariance matrices
/// * `labeled_weight` - Probability mass placed on the true class for labeled
///   samples (clamped to `[0, 1]` during the E-step)
/// * `random_seed` - Random seed for initialization
///
/// # Examples
///
/// ```
/// use scirs2_core::array;
/// use sklears_semi_supervised::SemiSupervisedGMM;
/// use sklears_core::traits::{Predict, Fit};
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
///
/// let gmm = SemiSupervisedGMM::new()
///     .n_components(2)
///     .max_iter(100)
///     .labeled_weight(0.95);
/// let fitted = gmm.fit(&X.view(), &y.view()).unwrap();
/// let predictions = fitted.predict(&X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct SemiSupervisedGMM<S = Untrained> {
    state: S,
    n_components: usize,
    max_iter: usize,
    tol: f64,
    covariance_type: String,
    reg_covar: f64,
    labeled_weight: f64,
    random_seed: Option<u64>,
}

impl SemiSupervisedGMM<Untrained> {
    /// Create a new SemiSupervisedGMM instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_components: 2,
            max_iter: 100,
            tol: 1e-6,
            covariance_type: "full".to_string(),
            reg_covar: 1e-6,
            labeled_weight: 1.0,
            random_seed: None,
        }
    }

    /// Set the number of mixture components
    pub fn n_components(mut self, n_components: usize) -> Self {
        self.n_components = n_components;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set the covariance type ("full", "diag", or "spherical")
    pub fn covariance_type(mut self, covariance_type: String) -> Self {
        self.covariance_type = covariance_type;
        self
    }

    /// Set the covariance regularization parameter
    pub fn reg_covar(mut self, reg_covar: f64) -> Self {
        self.reg_covar = reg_covar;
        self
    }

    /// Set the probability mass placed on the true class for labeled samples
    /// (clamped to `[0, 1]` during the E-step)
    pub fn labeled_weight(mut self, labeled_weight: f64) -> Self {
        self.labeled_weight = labeled_weight;
        self
    }

    /// Set the random seed
    pub fn random_seed(mut self, seed: u64) -> Self {
        self.random_seed = Some(seed);
        self
    }

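    /// Initialize mixture parameters: uniform weights, means taken as
    /// per-feature averages over contiguous row chunks of `X` (a cheap
    /// stand-in for k-means seeding), and regularized identity covariances.
    ///
    /// Note that `random_seed` is currently not consumed here: initialization
    /// is deterministic in the row order of `X`.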
    fn initialize_parameters(
        &self,
        X: &Array2<f64>,
        n_classes: usize,
    ) -> (Array1<f64>, Array2<f64>, Vec<Array2<f64>>) {
        let (n_samples, n_features) = X.dim();

        // Initialize mixture weights
        let weights = Array1::from_elem(n_classes, 1.0 / n_classes as f64);

        // Initialize means using simple clustering
        let mut means = Array2::zeros((n_classes, n_features));
        for k in 0..n_classes {
            let start_idx = (k * n_samples) / n_classes;
            let end_idx = ((k + 1) * n_samples) / n_classes;

            for j in 0..n_features {
                let mut sum = 0.0;
                let mut count = 0;
                for i in start_idx..end_idx.min(n_samples) {
                    sum += X[[i, j]];
                    count += 1;
                }
                means[[k, j]] = if count > 0 { sum / count as f64 } else { 0.0 };
            }
        }

        // Initialize covariances
        let mut covariances = Vec::new();
        for _k in 0..n_classes {
            let mut cov = Array2::eye(n_features);
            // Add regularization
            for i in 0..n_features {
                cov[[i, i]] += self.reg_covar;
            }
            covariances.push(cov);
        }

        (weights, means, covariances)
    }

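    /// Simplified multivariate Gaussian density. Only the diagonal of `cov`
    /// is used, so the "full" covariance type is approximated as diagonal:
    ///
    /// ```text
    /// N(x; mu, S) = (2*pi)^(-d/2) * det(S)^(-1/2)
    ///               * exp(-0.5 * (x - mu)^T * S^(-1) * (x - mu))
    /// ```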
    fn multivariate_normal_pdf(
        &self,
        x: &Array1<f64>,
        mean: &Array1<f64>,
        cov: &Array2<f64>,
    ) -> f64 {
        let n = x.len();
        let diff = x - mean;

        // Compute determinant (simplified for diagonal/spherical case)
        let det = match self.covariance_type.as_str() {
            "diag" | "spherical" => cov.diag().iter().product::<f64>(),
            _ => {
                // For full covariance, use a simple approximation
                cov.diag().iter().product::<f64>()
            }
        };

        if det <= 0.0 {
            return 1e-10; // Small non-zero value to avoid numerical issues
        }

        // Compute inverse (simplified)
        let inv_cov = match self.covariance_type.as_str() {
            "diag" | "spherical" => {
                let mut inv = Array2::zeros(cov.dim());
                for i in 0..n {
                    inv[[i, i]] = 1.0 / cov[[i, i]];
                }
                inv
            }
            _ => {
                // Simplified inverse for numerical stability
                let mut inv = Array2::zeros(cov.dim());
                for i in 0..n {
                    inv[[i, i]] = 1.0 / (cov[[i, i]] + self.reg_covar);
                }
                inv
            }
        };

        let mahalanobis = diff.dot(&inv_cov.dot(&diff));
        let norm_factor = 1.0 / ((2.0 * std::f64::consts::PI).powi(n as i32) * det).sqrt();

        norm_factor * (-0.5 * mahalanobis).exp()
    }

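    /// E-step: compute the responsibility matrix, where entry `(i, k)` is the
    /// posterior probability that sample `i` came from component `k`:
    ///
    /// ```text
    /// r[i, k] = w_k * N(x_i; mu_k, S_k) / sum_j w_j * N(x_i; mu_j, S_j)
    /// ```
    ///
    /// Labeled samples are instead pinned to their known class with mass
    /// `labeled_weight` (clamped to `[0, 1]`).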
    #[allow(clippy::too_many_arguments)]
    fn expectation_step(
        &self,
        X: &Array2<f64>,
        weights: &Array1<f64>,
        means: &Array2<f64>,
        covariances: &[Array2<f64>],
        labeled_indices: &[usize],
        y_labeled: &Array1<i32>,
        classes: &[i32],
    ) -> Array2<f64> {
        let n_samples = X.nrows();
        let n_classes = classes.len();
        let mut responsibilities = Array2::zeros((n_samples, n_classes));

        for i in 0..n_samples {
            let x = X.row(i).to_owned();
            let mut total_likelihood = 0.0;
            let mut likelihoods = vec![0.0; n_classes];

            // Compute weighted likelihoods for each component
            for k in 0..n_classes {
                let mean = means.row(k).to_owned();
                let likelihood = self.multivariate_normal_pdf(&x, &mean, &covariances[k]);
                likelihoods[k] = weights[k] * likelihood;
                total_likelihood += likelihoods[k];
            }

            // Check whether this sample is labeled
            if let Some(labeled_pos) = labeled_indices.iter().position(|&idx| idx == i) {
                let true_label = y_labeled[labeled_pos];
                if let Some(class_idx) = classes.iter().position(|&c| c == true_label) {
                    // For labeled samples, place probability mass `w` on the
                    // true class. Clamping to [0, 1] keeps each row a valid
                    // probability distribution (an unclamped weight above 1
                    // would otherwise produce negative responsibilities).
                    let w = self.labeled_weight.clamp(0.0, 1.0);
                    for k in 0..n_classes {
                        responsibilities[[i, k]] = if k == class_idx {
                            w
                        } else {
                            (1.0 - w) / (n_classes - 1) as f64
                        };
                    }
                } else {
                    // Fallback: use standard posterior probabilities
                    for k in 0..n_classes {
                        responsibilities[[i, k]] = if total_likelihood > 0.0 {
                            likelihoods[k] / total_likelihood
                        } else {
                            1.0 / n_classes as f64
                        };
                    }
                }
            } else {
                // For unlabeled samples, use standard posterior probabilities
                for k in 0..n_classes {
                    responsibilities[[i, k]] = if total_likelihood > 0.0 {
                        likelihoods[k] / total_likelihood
                    } else {
                        1.0 / n_classes as f64
                    };
                }
            }
        }

        responsibilities
    }

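    /// M-step: re-estimate parameters from the responsibilities. With
    /// `N_k = sum_i r[i, k]`, the standard updates are:
    ///
    /// ```text
    /// w_k  = N_k / n
    /// mu_k = (1 / N_k) * sum_i r[i, k] * x_i
    /// S_k  = (1 / N_k) * sum_i r[i, k] * (x_i - mu_k)(x_i - mu_k)^T + reg_covar * I
    /// ```
    ///
    /// Only the diagonal of `S_k` is estimated here, for every covariance type.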
    fn maximization_step(
        &self,
        X: &Array2<f64>,
        responsibilities: &Array2<f64>,
    ) -> (Array1<f64>, Array2<f64>, Vec<Array2<f64>>) {
        let (n_samples, n_features) = X.dim();
        let n_classes = responsibilities.ncols();

        // Update weights
        let n_k = responsibilities.sum_axis(Axis(0));
        let weights = &n_k / n_samples as f64;

        // Update means
        let mut means = Array2::zeros((n_classes, n_features));
        for k in 0..n_classes {
            if n_k[k] > 0.0 {
                for j in 0..n_features {
                    let mut weighted_sum = 0.0;
                    for i in 0..n_samples {
                        weighted_sum += responsibilities[[i, k]] * X[[i, j]];
                    }
                    means[[k, j]] = weighted_sum / n_k[k];
                }
            }
        }

        // Update covariances (simplified for diagonal case)
        let mut covariances = Vec::new();
        for k in 0..n_classes {
            let mut cov = Array2::zeros((n_features, n_features));

            if n_k[k] > 0.0 {
                let mean_k = means.row(k).to_owned();

                for i in 0..n_samples {
                    let diff = &X.row(i).to_owned() - &mean_k;
                    let weight = responsibilities[[i, k]];

                    // Accumulate diagonal entries only; the "full" covariance
                    // type is simplified to diagonal for numerical stability.
                    for j in 0..n_features {
                        cov[[j, j]] += weight * diff[j] * diff[j];
                    }
                }

                // Normalize and add regularization
                for j in 0..n_features {
                    cov[[j, j]] = (cov[[j, j]] / n_k[k]) + self.reg_covar;
                }
            } else {
                // Fallback for empty components
                for j in 0..n_features {
                    cov[[j, j]] = 1.0 + self.reg_covar;
                }
            }

            covariances.push(cov);
        }

        (weights, means, covariances)
    }

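    /// Observed-data log-likelihood under the current mixture parameters:
    ///
    /// ```text
    /// log L = sum_i log( sum_k w_k * N(x_i; mu_k, S_k) )
    /// ```
    ///
    /// Used only as the EM convergence criterion; samples with zero
    /// likelihood are skipped to avoid `ln(0)`.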
    fn compute_log_likelihood(
        &self,
        X: &Array2<f64>,
        weights: &Array1<f64>,
        means: &Array2<f64>,
        covariances: &[Array2<f64>],
    ) -> f64 {
        let n_samples = X.nrows();
        let n_classes = weights.len();
        let mut log_likelihood = 0.0;

        for i in 0..n_samples {
            let x = X.row(i).to_owned();
            let mut sample_likelihood = 0.0;

            for k in 0..n_classes {
                let mean = means.row(k).to_owned();
                let likelihood = self.multivariate_normal_pdf(&x, &mean, &covariances[k]);
                sample_likelihood += weights[k] * likelihood;
            }

            if sample_likelihood > 0.0 {
                log_likelihood += sample_likelihood.ln();
            }
        }

        log_likelihood
    }
}

impl Default for SemiSupervisedGMM<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for SemiSupervisedGMM<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

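/// Fit via EM: labels equal to `-1` mark unlabeled samples; all other labels
/// constrain the E-step. Iterates E- and M-steps up to `max_iter` times,
/// stopping early once the log-likelihood changes by less than `tol`.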
impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for SemiSupervisedGMM<Untrained> {
    type Fitted = SemiSupervisedGMM<SemiSupervisedGMMTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();

        // Identify labeled samples
        let mut labeled_indices = Vec::new();
        let mut y_labeled = Vec::new();
        let mut classes = HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label != -1 {
                labeled_indices.push(i);
                y_labeled.push(label);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        // Sort so the class-to-component mapping is deterministic
        // (HashSet iteration order is unspecified).
        let mut classes: Vec<i32> = classes.into_iter().collect();
        classes.sort_unstable();
        let y_labeled = Array1::from(y_labeled);
        let n_classes = classes.len();

        if n_classes != self.n_components {
            return Err(SklearsError::InvalidInput(format!(
                "n_components ({}) must equal the number of observed classes ({})",
                self.n_components, n_classes
            )));
        }

        // Initialize parameters
        let (mut weights, mut means, mut covariances) = self.initialize_parameters(&X, n_classes);
        let mut prev_log_likelihood = f64::NEG_INFINITY;

        // EM algorithm
        for _iter in 0..self.max_iter {
            // E-step
            let responsibilities = self.expectation_step(
                &X,
                &weights,
                &means,
                &covariances,
                &labeled_indices,
                &y_labeled,
                &classes,
            );

            // M-step
            let (new_weights, new_means, new_covariances) =
                self.maximization_step(&X, &responsibilities);
            weights = new_weights;
            means = new_means;
            covariances = new_covariances;

            // Check convergence
            let log_likelihood = self.compute_log_likelihood(&X, &weights, &means, &covariances);
            if (log_likelihood - prev_log_likelihood).abs() < self.tol {
                break;
            }
            prev_log_likelihood = log_likelihood;
        }

        Ok(SemiSupervisedGMM {
            state: SemiSupervisedGMMTrained {
                X_train: X,
                y_train: y,
                classes: Array1::from(classes),
                weights,
                means,
                covariances,
            },
            n_components: self.n_components,
            max_iter: self.max_iter,
            tol: self.tol,
            covariance_type: self.covariance_type,
            reg_covar: self.reg_covar,
            labeled_weight: self.labeled_weight,
            random_seed: self.random_seed,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>> for SemiSupervisedGMM<SemiSupervisedGMMTrained> {
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let probas = self.predict_proba(X)?;
        let n_test = probas.nrows();
        let mut predictions = Array1::zeros(n_test);

        for i in 0..n_test {
            // Pick the class with the highest posterior probability
            let max_idx = probas
                .row(i)
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0;
            predictions[i] = self.state.classes[max_idx];
        }

        Ok(predictions)
    }
}

impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
    for SemiSupervisedGMM<SemiSupervisedGMMTrained>
{
    #[allow(non_snake_case)]
    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let n_classes = self.state.classes.len();
        let mut probas = Array2::zeros((n_test, n_classes));

        for i in 0..n_test {
            let x = X.row(i).to_owned();
            let mut total_likelihood = 0.0;
            let mut likelihoods = vec![0.0; n_classes];

            // Compute likelihoods for each component
            #[allow(clippy::needless_range_loop)]
            for k in 0..n_classes {
                let mean = self.state.means.row(k).to_owned();
                let likelihood =
                    self.multivariate_normal_pdf(&x, &mean, &self.state.covariances[k]);
                likelihoods[k] = self.state.weights[k] * likelihood;
                total_likelihood += likelihoods[k];
            }

            // Normalize to get probabilities
            for k in 0..n_classes {
                probas[[i, k]] = if total_likelihood > 0.0 {
                    likelihoods[k] / total_likelihood
                } else {
                    1.0 / n_classes as f64
                };
            }
        }

        Ok(probas)
    }
}

impl SemiSupervisedGMM<SemiSupervisedGMMTrained> {
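    /// Same simplified diagonal Gaussian density used during training;
    /// duplicated here because the trained and untrained states are distinct
    /// types. See `SemiSupervisedGMM::<Untrained>::multivariate_normal_pdf`.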
    fn multivariate_normal_pdf(
        &self,
        x: &Array1<f64>,
        mean: &Array1<f64>,
        cov: &Array2<f64>,
    ) -> f64 {
        let n = x.len();
        let diff = x - mean;

        // Compute determinant (simplified for diagonal/spherical case)
        let det = match self.covariance_type.as_str() {
            "diag" | "spherical" => cov.diag().iter().product::<f64>(),
            _ => {
                // For full covariance, use a simple approximation
                cov.diag().iter().product::<f64>()
            }
        };

        if det <= 0.0 {
            return 1e-10; // Small non-zero value to avoid numerical issues
        }

        // Compute inverse (simplified)
        let inv_cov = match self.covariance_type.as_str() {
            "diag" | "spherical" => {
                let mut inv = Array2::zeros(cov.dim());
                for i in 0..n {
                    inv[[i, i]] = 1.0 / cov[[i, i]];
                }
                inv
            }
            _ => {
                // Simplified inverse for numerical stability
                let mut inv = Array2::zeros(cov.dim());
                for i in 0..n {
                    inv[[i, i]] = 1.0 / (cov[[i, i]] + self.reg_covar);
                }
                inv
            }
        };

        let mahalanobis = diff.dot(&inv_cov.dot(&diff));
        let norm_factor = 1.0 / ((2.0 * std::f64::consts::PI).powi(n as i32) * det).sqrt();

        norm_factor * (-0.5 * mahalanobis).exp()
    }
}

/// Trained state for SemiSupervisedGMM
#[derive(Debug, Clone)]
pub struct SemiSupervisedGMMTrained {
    /// Training feature matrix
    pub X_train: Array2<f64>,
    /// Training labels (-1 for unlabeled samples)
    pub y_train: Array1<i32>,
    /// Sorted class labels, one per mixture component
    pub classes: Array1<i32>,
    /// Mixture weights
    pub weights: Array1<f64>,
    /// Component means (one row per component)
    pub means: Array2<f64>,
    /// Component covariance matrices
    pub covariances: Vec<Array2<f64>>,
}
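
// A minimal smoke-test sketch. It assumes the `array!` macro from
// `scirs2_core` (as used in the doc example above) and only checks output
// shapes and probability normalization, not clustering quality.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    fn fit_predict_smoke() {
        let x = array![[1.0, 2.0], [1.5, 2.5], [8.0, 9.0], [8.5, 9.5]];
        let y = array![0, -1, 1, -1]; // -1 marks unlabeled samples

        let gmm = SemiSupervisedGMM::new().n_components(2).max_iter(50);
        let fitted = gmm.fit(&x.view(), &y.view()).unwrap();

        // Predictions must come from the observed class labels
        let predictions = fitted.predict(&x.view()).unwrap();
        assert_eq!(predictions.len(), 4);
        for &p in predictions.iter() {
            assert!(p == 0 || p == 1);
        }

        // Each row of predict_proba must sum to 1
        let probas = fitted.predict_proba(&x.view()).unwrap();
        assert_eq!(probas.dim(), (4, 2));
        for row in probas.axis_iter(Axis(0)) {
            assert!((row.sum() - 1.0).abs() < 1e-9);
        }
    }
}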