use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2, Axis};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
    types::Float,
};
use std::collections::HashSet;

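/// Semi-supervised Gaussian mixture model classifier.
///
/// Fits one Gaussian component per class with expectation-maximization,
/// treating samples labeled `-1` as unlabeled: their responsibilities are
/// inferred in the E-step, while labeled samples are clamped to their class
/// according to `labeled_weight`.
///
/// A minimal usage sketch (marked `ignore` for doctests; it assumes `Float`
/// is `f64`, matching the conversions used internally):
///
/// ```ignore
/// use scirs2_core::ndarray_ext::{Array1, Array2};
///
/// // Two 2-D blobs; `-1` marks unlabeled samples.
/// let x = Array2::from_shape_vec(
///     (4, 2),
///     vec![0.0, 0.0, 0.1, 0.1, 5.0, 5.0, 5.1, 4.9],
/// ).unwrap();
/// let y = Array1::from(vec![0, -1, 1, -1]);
///
/// let model = SemiSupervisedGMM::new()
///     .n_components(2)
///     .max_iter(50)
///     .fit(&x.view(), &y.view())
///     .unwrap();
/// let predictions = model.predict(&x.view()).unwrap();
/// ```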
#[derive(Debug, Clone)]
pub struct SemiSupervisedGMM<S = Untrained> {
    state: S,
    /// Number of mixture components (must equal the number of classes).
    n_components: usize,
    /// Maximum number of EM iterations.
    max_iter: usize,
    /// Convergence tolerance on the change in log-likelihood.
    tol: f64,
    /// Covariance structure: "full", "diag", or "spherical".
    covariance_type: String,
    /// Regularization added to covariance diagonals for numerical stability.
    reg_covar: f64,
    /// Responsibility mass assigned to the true class of labeled samples.
    labeled_weight: f64,
    /// Random seed (stored for reproducibility; the current initialization
    /// is deterministic and does not consume it).
    random_seed: Option<u64>,
}

impl SemiSupervisedGMM<Untrained> {
    /// Creates an untrained model with default hyperparameters.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_components: 2,
            max_iter: 100,
            tol: 1e-6,
            covariance_type: "full".to_string(),
            reg_covar: 1e-6,
            labeled_weight: 1.0,
            random_seed: None,
        }
    }

    /// Sets the number of mixture components.
    pub fn n_components(mut self, n_components: usize) -> Self {
        self.n_components = n_components;
        self
    }

    /// Sets the maximum number of EM iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Sets the convergence tolerance on the log-likelihood change.
    pub fn tol(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Sets the covariance structure ("full", "diag", or "spherical").
    pub fn covariance_type(mut self, covariance_type: String) -> Self {
        self.covariance_type = covariance_type;
        self
    }

    /// Sets the covariance regularization term.
    pub fn reg_covar(mut self, reg_covar: f64) -> Self {
        self.reg_covar = reg_covar;
        self
    }

    /// Sets the responsibility mass given to the true class of labeled
    /// samples (1.0 means hard clamping).
    pub fn labeled_weight(mut self, labeled_weight: f64) -> Self {
        self.labeled_weight = labeled_weight;
        self
    }

    /// Sets the random seed.
    pub fn random_seed(mut self, seed: u64) -> Self {
        self.random_seed = Some(seed);
        self
    }

    /// Initializes mixture parameters: uniform weights, per-component means
    /// taken from contiguous chunks of the data, and regularized identity
    /// covariances.
    fn initialize_parameters(
        &self,
        X: &Array2<f64>,
        n_classes: usize,
    ) -> (Array1<f64>, Array2<f64>, Vec<Array2<f64>>) {
        let (n_samples, n_features) = X.dim();

        // Uniform mixing weights.
        let weights = Array1::from_elem(n_classes, 1.0 / n_classes as f64);

        // Initialize each component's mean from its chunk of samples.
        let mut means = Array2::zeros((n_classes, n_features));
        for k in 0..n_classes {
            let start_idx = (k * n_samples) / n_classes;
            let end_idx = ((k + 1) * n_samples) / n_classes;

            for j in 0..n_features {
                let mut sum = 0.0;
                let mut count = 0;
                for i in start_idx..end_idx.min(n_samples) {
                    sum += X[[i, j]];
                    count += 1;
                }
                means[[k, j]] = if count > 0 { sum / count as f64 } else { 0.0 };
            }
        }

        // Identity covariances with diagonal regularization.
        let mut covariances = Vec::new();
        for _k in 0..n_classes {
            let mut cov = Array2::eye(n_features);
            for i in 0..n_features {
                cov[[i, i]] += self.reg_covar;
            }
            covariances.push(cov);
        }

        (weights, means, covariances)
    }
}

impl<S> SemiSupervisedGMM<S> {
    /// Multivariate normal density
    /// N(x; mu, Sigma) = (2*pi)^(-n/2) * |Sigma|^(-1/2)
    ///                   * exp(-1/2 * (x - mu)^T * Sigma^(-1) * (x - mu)).
    ///
    /// Note: every covariance type is currently evaluated with a diagonal
    /// approximation; for "full" covariances the determinant and inverse
    /// are computed from the diagonal only.
    fn multivariate_normal_pdf(
        &self,
        x: &Array1<f64>,
        mean: &Array1<f64>,
        cov: &Array2<f64>,
    ) -> f64 {
        let n = x.len();
        let diff = x - mean;

        // Determinant of the (diagonal approximation of the) covariance.
        let det = cov.diag().iter().product::<f64>();

        if det <= 0.0 {
            // Degenerate covariance: return a tiny density instead of NaN.
            return 1e-10;
        }

        // Diagonal inverse; the "full" branch adds extra regularization.
        let inv_cov = match self.covariance_type.as_str() {
            "diag" | "spherical" => {
                let mut inv = Array2::zeros(cov.dim());
                for i in 0..n {
                    inv[[i, i]] = 1.0 / cov[[i, i]];
                }
                inv
            }
            _ => {
                let mut inv = Array2::zeros(cov.dim());
                for i in 0..n {
                    inv[[i, i]] = 1.0 / (cov[[i, i]] + self.reg_covar);
                }
                inv
            }
        };

        let mahalanobis = diff.dot(&inv_cov.dot(&diff));
        let norm_factor = 1.0 / ((2.0 * std::f64::consts::PI).powi(n as i32) * det).sqrt();

        norm_factor * (-0.5 * mahalanobis).exp()
    }
}

impl SemiSupervisedGMM<Untrained> {
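    /// E-step: computes the responsibility matrix gamma, where for an
    /// unlabeled sample x_i
    ///
    ///     gamma_ik = w_k * N(x_i; mu_k, Sigma_k) / sum_j w_j * N(x_i; mu_j, Sigma_j),
    ///
    /// and for a labeled sample the mass is clamped onto its class:
    /// gamma_ik = labeled_weight for the true class and
    /// (1 - labeled_weight) / (K - 1) for the remaining K - 1 classes.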
    #[allow(clippy::too_many_arguments)]
    fn expectation_step(
        &self,
        X: &Array2<f64>,
        weights: &Array1<f64>,
        means: &Array2<f64>,
        covariances: &[Array2<f64>],
        labeled_indices: &[usize],
        y_labeled: &Array1<i32>,
        classes: &[i32],
    ) -> Array2<f64> {
        let n_samples = X.nrows();
        let n_classes = classes.len();
        let mut responsibilities = Array2::zeros((n_samples, n_classes));

        for i in 0..n_samples {
            let x = X.row(i).to_owned();
            let mut total_likelihood = 0.0;
            let mut likelihoods = vec![0.0; n_classes];

            for k in 0..n_classes {
                let mean = means.row(k).to_owned();
                let likelihood = self.multivariate_normal_pdf(&x, &mean, &covariances[k]);
                likelihoods[k] = weights[k] * likelihood;
                total_likelihood += likelihoods[k];
            }

            // Soft (unsupervised) responsibilities by default; fall back to
            // uniform when all component densities underflow to zero.
            for k in 0..n_classes {
                responsibilities[[i, k]] = if total_likelihood > 0.0 {
                    likelihoods[k] / total_likelihood
                } else {
                    1.0 / n_classes as f64
                };
            }

            // Labeled samples with a known class are clamped onto it.
            if let Some(labeled_pos) = labeled_indices.iter().position(|&idx| idx == i) {
                let true_label = y_labeled[labeled_pos];
                if let Some(class_idx) = classes.iter().position(|&c| c == true_label) {
                    for k in 0..n_classes {
                        responsibilities[[i, k]] = if k == class_idx {
                            self.labeled_weight
                        } else {
                            (1.0 - self.labeled_weight) / (n_classes - 1) as f64
                        };
                    }
                }
            }
        }

        responsibilities
    }

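    /// M-step: re-estimates the mixture parameters from the responsibilities:
    ///
    ///     N_k  = sum_i gamma_ik
    ///     w_k  = N_k / n
    ///     mu_k = (1 / N_k) * sum_i gamma_ik * x_i
    ///     Sigma_k[j][j] = (1 / N_k) * sum_i gamma_ik * (x_ij - mu_kj)^2 + reg_covar
    ///
    /// Only the covariance diagonal is accumulated, for every covariance
    /// type; empty components fall back to regularized unit variances.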
    fn maximization_step(
        &self,
        X: &Array2<f64>,
        responsibilities: &Array2<f64>,
    ) -> (Array1<f64>, Array2<f64>, Vec<Array2<f64>>) {
        let (n_samples, n_features) = X.dim();
        let n_classes = responsibilities.ncols();

        // Effective number of samples per component, and mixing weights.
        let n_k = responsibilities.sum_axis(Axis(0));
        let weights = &n_k / n_samples as f64;

        // Responsibility-weighted means.
        let mut means = Array2::zeros((n_classes, n_features));
        for k in 0..n_classes {
            if n_k[k] > 0.0 {
                for j in 0..n_features {
                    let mut weighted_sum = 0.0;
                    for i in 0..n_samples {
                        weighted_sum += responsibilities[[i, k]] * X[[i, j]];
                    }
                    means[[k, j]] = weighted_sum / n_k[k];
                }
            }
        }

        // Responsibility-weighted covariances. Only the diagonal is
        // accumulated, regardless of covariance type ("full" is thus
        // approximated by its diagonal).
        let mut covariances = Vec::new();
        for k in 0..n_classes {
            let mut cov = Array2::zeros((n_features, n_features));

            if n_k[k] > 0.0 {
                let mean_k = means.row(k).to_owned();

                for i in 0..n_samples {
                    let diff = &X.row(i).to_owned() - &mean_k;
                    let weight = responsibilities[[i, k]];

                    for j in 0..n_features {
                        cov[[j, j]] += weight * diff[j] * diff[j];
                    }
                }

                for j in 0..n_features {
                    cov[[j, j]] = (cov[[j, j]] / n_k[k]) + self.reg_covar;
                }
            } else {
                // Empty component: regularized unit variances.
                for j in 0..n_features {
                    cov[[j, j]] = 1.0 + self.reg_covar;
                }
            }

            covariances.push(cov);
        }

        (weights, means, covariances)
    }

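    /// Total log-likelihood of the data under the current mixture,
    ///
    ///     ln L = sum_i ln( sum_k w_k * N(x_i; mu_k, Sigma_k) ),
    ///
    /// used as the EM convergence criterion. Samples whose mixture
    /// likelihood underflows to zero are skipped.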
    fn compute_log_likelihood(
        &self,
        X: &Array2<f64>,
        weights: &Array1<f64>,
        means: &Array2<f64>,
        covariances: &[Array2<f64>],
    ) -> f64 {
        let n_samples = X.nrows();
        let n_classes = weights.len();
        let mut log_likelihood = 0.0;

        for i in 0..n_samples {
            let x = X.row(i).to_owned();
            let mut sample_likelihood = 0.0;

            for k in 0..n_classes {
                let mean = means.row(k).to_owned();
                let likelihood = self.multivariate_normal_pdf(&x, &mean, &covariances[k]);
                sample_likelihood += weights[k] * likelihood;
            }

            // Skip samples whose mixture likelihood underflowed to zero.
            if sample_likelihood > 0.0 {
                log_likelihood += sample_likelihood.ln();
            }
        }

        log_likelihood
    }
}

impl Default for SemiSupervisedGMM<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for SemiSupervisedGMM<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

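/// Fitting via expectation-maximization.
///
/// Label convention: `y[i] == -1` marks an unlabeled sample; any other value
/// is treated as a class label. Fitting fails if no sample is labeled or if
/// the number of distinct labels differs from `n_components`.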
impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for SemiSupervisedGMM<Untrained> {
    type Fitted = SemiSupervisedGMM<SemiSupervisedGMMTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let y = y.to_owned();

        // Separate labeled samples (label != -1) and collect the class set.
        let mut labeled_indices = Vec::new();
        let mut y_labeled = Vec::new();
        let mut classes = HashSet::new();

        for (i, &label) in y.iter().enumerate() {
            if label != -1 {
                labeled_indices.push(i);
                y_labeled.push(label);
                classes.insert(label);
            }
        }

        if labeled_indices.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No labeled samples provided".to_string(),
            ));
        }

        // Sort the classes so the component-to-class mapping is
        // deterministic across runs (HashSet iteration order is not).
        let mut classes: Vec<i32> = classes.into_iter().collect();
        classes.sort_unstable();
        let y_labeled = Array1::from(y_labeled);
        let n_classes = classes.len();

        if n_classes != self.n_components {
            return Err(SklearsError::InvalidInput(format!(
                "Number of components ({}) must equal the number of distinct \
                 labels ({}) for semi-supervised learning",
                self.n_components, n_classes
            )));
        }

        let (mut weights, mut means, mut covariances) = self.initialize_parameters(&X, n_classes);
        let mut prev_log_likelihood = f64::NEG_INFINITY;

        // EM loop: alternate E- and M-steps until the log-likelihood change
        // drops below `tol` or `max_iter` is reached.
        for _iter in 0..self.max_iter {
            let responsibilities = self.expectation_step(
                &X,
                &weights,
                &means,
                &covariances,
                &labeled_indices,
                &y_labeled,
                &classes,
            );

            let (new_weights, new_means, new_covariances) =
                self.maximization_step(&X, &responsibilities);
            weights = new_weights;
            means = new_means;
            covariances = new_covariances;

            let log_likelihood = self.compute_log_likelihood(&X, &weights, &means, &covariances);
            if (log_likelihood - prev_log_likelihood).abs() < self.tol {
                break;
            }
            prev_log_likelihood = log_likelihood;
        }

        Ok(SemiSupervisedGMM {
            state: SemiSupervisedGMMTrained {
                X_train: X,
                y_train: y,
                classes: Array1::from(classes),
                weights,
                means,
                covariances,
            },
            n_components: self.n_components,
            max_iter: self.max_iter,
            tol: self.tol,
            covariance_type: self.covariance_type,
            reg_covar: self.reg_covar,
            labeled_weight: self.labeled_weight,
            random_seed: self.random_seed,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array1<i32>> for SemiSupervisedGMM<SemiSupervisedGMMTrained> {
    /// Predicts the class with the highest posterior probability.
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
        let probas = self.predict_proba(X)?;
        let n_test = probas.nrows();
        let mut predictions = Array1::zeros(n_test);

        for i in 0..n_test {
            // Arg-max over the class posteriors.
            let max_idx = probas
                .row(i)
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .unwrap()
                .0;
            predictions[i] = self.state.classes[max_idx];
        }

        Ok(predictions)
    }
}

impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
    for SemiSupervisedGMM<SemiSupervisedGMMTrained>
{
    /// Returns the posterior probability of each class for each sample,
    /// computed from the fitted mixture.
    #[allow(non_snake_case)]
    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
        let X = X.to_owned();
        let n_test = X.nrows();
        let n_classes = self.state.classes.len();
        let mut probas = Array2::zeros((n_test, n_classes));

        for i in 0..n_test {
            let x = X.row(i).to_owned();
            let mut total_likelihood = 0.0;
            let mut likelihoods = vec![0.0; n_classes];

            #[allow(clippy::needless_range_loop)]
            for k in 0..n_classes {
                let mean = self.state.means.row(k).to_owned();
                let likelihood =
                    self.multivariate_normal_pdf(&x, &mean, &self.state.covariances[k]);
                likelihoods[k] = self.state.weights[k] * likelihood;
                total_likelihood += likelihoods[k];
            }

            // Normalize to posteriors; fall back to uniform if every
            // component density underflowed to zero.
            for k in 0..n_classes {
                probas[[i, k]] = if total_likelihood > 0.0 {
                    likelihoods[k] / total_likelihood
                } else {
                    1.0 / n_classes as f64
                };
            }
        }

        Ok(probas)
    }
}

/// Trained state for [`SemiSupervisedGMM`], holding the fitted mixture
/// parameters and the training data.
#[derive(Debug, Clone)]
#[allow(non_snake_case)]
pub struct SemiSupervisedGMMTrained {
    /// Training inputs.
    pub X_train: Array2<f64>,
    /// Training labels (`-1` marks unlabeled samples).
    pub y_train: Array1<i32>,
    /// Distinct class labels, in component order.
    pub classes: Array1<i32>,
    /// Mixing weights.
    pub weights: Array1<f64>,
    /// Component means (one row per class).
    pub means: Array2<f64>,
    /// Component covariance matrices.
    pub covariances: Vec<Array2<f64>>,
}
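
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke-test sketch: fits two well-separated 2-D blobs with partial
    /// labels and checks basic invariants of the outputs. Assumes `Float`
    /// is `f64`, matching the internal `to_owned()` conversions above.
    #[test]
    fn fit_predict_two_blobs() {
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![
                0.0, 0.0, 0.1, 0.1, 0.2, 0.0, // blob A
                5.0, 5.0, 5.1, 4.9, 4.8, 5.2, // blob B
            ],
        )
        .unwrap();
        // `-1` marks unlabeled samples.
        let y = Array1::from(vec![0, 0, -1, 1, 1, -1]);

        let model = SemiSupervisedGMM::new()
            .n_components(2)
            .max_iter(50)
            .fit(&x.view(), &y.view())
            .unwrap();

        // Posterior rows are probability distributions over the classes.
        let probas = model.predict_proba(&x.view()).unwrap();
        assert_eq!(probas.dim(), (6, 2));
        for i in 0..6 {
            assert!((probas.row(i).sum() - 1.0).abs() < 1e-9);
        }

        // Predictions are drawn from the observed label set {0, 1}.
        let preds = model.predict(&x.view()).unwrap();
        assert_eq!(preds.len(), 6);
        for &p in preds.iter() {
            assert!(p == 0 || p == 1);
        }
    }
}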