1use scirs2_core::ndarray_ext::{s, Array1, Array2, ArrayView1, ArrayView2};
8use scirs2_core::random::Random;
9use sklears_core::{
10 error::{Result as SklResult, SklearsError},
11 traits::{Estimator, Fit, Predict, PredictProba, Untrained},
12 types::Float,
13};
14
/// Fully-connected generator network: maps a noise vector to a synthetic sample.
#[derive(Debug, Clone)]
pub struct Generator {
    /// Per-layer weight matrices, each shaped `(layer_output_dim, layer_input_dim)`.
    pub weights: Vec<Array2<f64>>,
    /// Per-layer bias vectors, each of length `layer_output_dim`.
    pub biases: Vec<Array1<f64>>,
    /// Layer sizes: `[noise_dim, hidden_dims..., output_dim]`.
    pub architecture: Vec<usize>,
    /// Dimensionality of the input noise vector.
    pub noise_dim: usize,
}
27
28impl Generator {
29 pub fn new(noise_dim: usize, output_dim: usize, hidden_dims: Vec<usize>) -> Self {
31 let mut architecture = vec![noise_dim];
32 architecture.extend(hidden_dims);
33 architecture.push(output_dim);
34
35 let mut weights = Vec::new();
36 let mut biases = Vec::new();
37
38 for i in 0..architecture.len() - 1 {
39 let input_dim = architecture[i];
40 let output_dim = architecture[i + 1];
41
42 let scale = (2.0 / (input_dim + output_dim) as f64).sqrt();
44 let mut rng = Random::default();
45 let mut w = Array2::zeros((output_dim, input_dim));
46 for i in 0..output_dim {
47 for j in 0..input_dim {
48 w[[i, j]] = rng.random_range(-3.0..3.0) / 3.0 * scale;
50 }
51 }
52 let w = w;
53 let b = Array1::zeros(output_dim);
54
55 weights.push(w);
56 biases.push(b);
57 }
58
59 Self {
60 weights,
61 biases,
62 architecture,
63 noise_dim,
64 }
65 }
66
67 pub fn forward(&self, noise: &ArrayView1<f64>) -> SklResult<Array1<f64>> {
69 let mut current = noise.to_owned();
70
71 for (i, (weights, biases)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
72 let linear = weights.dot(¤t) + biases;
73
74 current = if i < self.weights.len() - 1 {
76 linear.mapv(|x| x.tanh())
77 } else {
78 linear
79 };
80 }
81
82 Ok(current)
83 }
84
85 pub fn generate(&self, n_samples: usize) -> SklResult<Array2<f64>> {
87 let output_dim = *self.architecture.last().unwrap();
88 let mut samples = Array2::zeros((n_samples, output_dim));
89
90 for i in 0..n_samples {
91 let mut rng = Random::default();
92 let mut noise = Array1::zeros(self.noise_dim);
93 for j in 0..self.noise_dim {
94 noise[j] = rng.random_range(-3.0..3.0) / 3.0;
96 }
97 let generated = self.forward(&noise.view())?;
98 samples.row_mut(i).assign(&generated);
99 }
100
101 Ok(samples)
102 }
103}
104
/// Fully-connected discriminator/classifier network.
///
/// The output layer has `n_classes + 1` units: one logit per real class plus a
/// trailing unit for the "generated/fake" class (see `new`).
#[derive(Debug, Clone)]
pub struct Discriminator {
    /// Per-layer weight matrices, each shaped `(layer_output_dim, layer_input_dim)`.
    pub weights: Vec<Array2<f64>>,
    /// Per-layer bias vectors, each of length `layer_output_dim`.
    pub biases: Vec<Array1<f64>>,
    /// Layer sizes: `[input_dim, hidden_dims..., n_classes + 1]`.
    pub architecture: Vec<usize>,
    /// Number of real classes (the fake class is not counted here).
    pub n_classes: usize,
}
117
118impl Discriminator {
119 pub fn new(input_dim: usize, n_classes: usize, hidden_dims: Vec<usize>) -> Self {
121 let mut architecture = vec![input_dim];
122 architecture.extend(hidden_dims);
123 architecture.push(n_classes + 1); let mut weights = Vec::new();
126 let mut biases = Vec::new();
127
128 for i in 0..architecture.len() - 1 {
129 let input_dim = architecture[i];
130 let output_dim = architecture[i + 1];
131
132 let scale = (2.0 / (input_dim + output_dim) as f64).sqrt();
134 let mut rng = Random::default();
135 let mut w = Array2::zeros((output_dim, input_dim));
136 for i in 0..output_dim {
137 for j in 0..input_dim {
138 w[[i, j]] = rng.random_range(-3.0..3.0) / 3.0 * scale;
140 }
141 }
142 let w = w;
143 let b = Array1::zeros(output_dim);
144
145 weights.push(w);
146 biases.push(b);
147 }
148
149 Self {
150 weights,
151 biases,
152 architecture,
153 n_classes,
154 }
155 }
156
157 pub fn forward(&self, x: &ArrayView1<f64>) -> SklResult<Array1<f64>> {
159 let mut current = x.to_owned();
160
161 for (i, (weights, biases)) in self.weights.iter().zip(self.biases.iter()).enumerate() {
162 let linear = weights.dot(¤t) + biases;
163
164 current = if i < self.weights.len() - 1 {
166 linear.mapv(|x| if x > 0.0 { x } else { 0.01 * x })
167 } else {
168 linear
169 };
170 }
171
172 Ok(current)
173 }
174
175 pub fn predict_proba(&self, x: &ArrayView1<f64>) -> SklResult<Array1<f64>> {
177 let logits = self.forward(x)?;
178 Ok(self.softmax(&logits.view()))
179 }
180
181 fn softmax(&self, x: &ArrayView1<f64>) -> Array1<f64> {
183 let max_val = x.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
184 let exp_x = x.mapv(|v| (v - max_val).exp());
185 let sum_exp = exp_x.sum();
186 exp_x / sum_exp
187 }
188
189 pub fn get_real_fake_proba(&self, x: &ArrayView1<f64>) -> SklResult<f64> {
191 let probs = self.predict_proba(x)?;
192 let real_prob = probs.slice(s![..self.n_classes]).sum();
194 Ok(real_prob)
195 }
196}
197
/// Semi-supervised GAN classifier using the typestate pattern
/// (`Untrained` by default; `SemiSupervisedGANTrained` after fitting).
#[derive(Debug, Clone)]
pub struct SemiSupervisedGAN<S = Untrained> {
    // Typestate marker (untrained) or trained payload.
    state: S,
    // Generator network; populated by `initialize_networks` during fitting.
    generator: Option<Generator>,
    // Discriminator network; populated by `initialize_networks` during fitting.
    discriminator: Option<Discriminator>,
    // Dimensionality of the generator's noise input.
    noise_dim: usize,
    // Number of real classes (excluding the discriminator's fake class).
    n_classes: usize,
    // Learning-rate hyperparameter.
    learning_rate: f64,
    // Number of training epochs.
    epochs: usize,
    // Mini-batch size for the training loop.
    batch_size: usize,
    // The generator is updated on epochs where `epoch % gen_freq == 0`.
    gen_freq: usize,
    // NOTE(review): `disc_freq` is configurable but never read by `train` —
    // confirm whether discriminator update frequency should be wired in.
    disc_freq: usize,
    // Hidden-layer sizes for the generator.
    gen_hidden_dims: Vec<usize>,
    // Hidden-layer sizes for the discriminator.
    disc_hidden_dims: Vec<usize>,
    // NOTE(review): `random_state` is stored but never used for seeding —
    // confirm whether it should be passed to `Random`.
    random_state: Option<u64>,
}
227
228impl Default for SemiSupervisedGAN<Untrained> {
229 fn default() -> Self {
230 Self::new()
231 }
232}
233
234impl SemiSupervisedGAN<Untrained> {
235 pub fn new() -> Self {
237 Self {
238 state: Untrained,
239 generator: None,
240 discriminator: None,
241 noise_dim: 100,
242 n_classes: 2,
243 learning_rate: 0.0002,
244 epochs: 100,
245 batch_size: 32,
246 gen_freq: 1,
247 disc_freq: 2,
248 gen_hidden_dims: vec![128, 256],
249 disc_hidden_dims: vec![256, 128],
250 random_state: None,
251 }
252 }
253
254 pub fn noise_dim(mut self, noise_dim: usize) -> Self {
256 self.noise_dim = noise_dim;
257 self
258 }
259
260 pub fn learning_rate(mut self, lr: f64) -> Self {
262 self.learning_rate = lr;
263 self
264 }
265
266 pub fn epochs(mut self, epochs: usize) -> Self {
268 self.epochs = epochs;
269 self
270 }
271
272 pub fn batch_size(mut self, batch_size: usize) -> Self {
274 self.batch_size = batch_size;
275 self
276 }
277
278 pub fn gen_freq(mut self, freq: usize) -> Self {
280 self.gen_freq = freq;
281 self
282 }
283
284 pub fn disc_freq(mut self, freq: usize) -> Self {
286 self.disc_freq = freq;
287 self
288 }
289
290 pub fn gen_hidden_dims(mut self, dims: Vec<usize>) -> Self {
292 self.gen_hidden_dims = dims;
293 self
294 }
295
296 pub fn disc_hidden_dims(mut self, dims: Vec<usize>) -> Self {
298 self.disc_hidden_dims = dims;
299 self
300 }
301
302 pub fn random_state(mut self, seed: u64) -> Self {
304 self.random_state = Some(seed);
305 self
306 }
307
308 fn initialize_networks(&mut self, input_dim: usize, n_classes: usize) {
310 self.n_classes = n_classes;
311
312 self.generator = Some(Generator::new(
313 self.noise_dim,
314 input_dim,
315 self.gen_hidden_dims.clone(),
316 ));
317
318 self.discriminator = Some(Discriminator::new(
319 input_dim,
320 n_classes,
321 self.disc_hidden_dims.clone(),
322 ));
323 }
324
325 fn train(&mut self, x: &ArrayView2<f64>, y: &ArrayView1<i32>) -> SklResult<()> {
327 let n_samples = x.nrows();
328 let n_features = x.ncols();
329
330 self.initialize_networks(n_features, self.n_classes);
332
333 let mut labeled_indices = Vec::new();
335 let mut unlabeled_indices = Vec::new();
336
337 for (i, &label) in y.iter().enumerate() {
338 if label >= 0 {
339 labeled_indices.push(i);
340 } else {
341 unlabeled_indices.push(i);
342 }
343 }
344
345 for epoch in 0..self.epochs {
347 let mut total_d_loss = 0.0;
348 let mut total_g_loss = 0.0;
349
350 for batch_start in (0..labeled_indices.len()).step_by(self.batch_size) {
352 let batch_end = (batch_start + self.batch_size).min(labeled_indices.len());
353
354 total_d_loss += 1.0; }
358
359 for batch_start in (0..unlabeled_indices.len()).step_by(self.batch_size) {
361 let batch_end = (batch_start + self.batch_size).min(unlabeled_indices.len());
362
363 total_d_loss += 1.0; }
366
367 if epoch % self.gen_freq == 0 {
369 total_g_loss += 1.0; }
372
373 if epoch % 10 == 0 {
374 println!(
375 "Epoch {}: D_loss = {:.4}, G_loss = {:.4}",
376 epoch, total_d_loss, total_g_loss
377 );
378 }
379 }
380
381 Ok(())
382 }
383}
384
/// Trained-state payload carried by a fitted [`SemiSupervisedGAN`].
#[derive(Debug, Clone)]
pub struct SemiSupervisedGANTrained {
    /// Trained generator network.
    pub generator: Generator,
    /// Trained discriminator network.
    pub discriminator: Discriminator,
    /// Sorted distinct class labels observed among labeled samples during fitting.
    pub classes: Array1<i32>,
    /// Noise dimensionality used by the generator.
    pub noise_dim: usize,
    /// Number of real classes.
    pub n_classes: usize,
    /// Learning rate that was configured for training.
    pub learning_rate: f64,
}
401
402impl SemiSupervisedGAN<SemiSupervisedGANTrained> {
403 pub fn generate_samples(&self, n_samples: usize) -> SklResult<Array2<f64>> {
405 self.state.generator.generate(n_samples)
406 }
407}
408
impl Estimator for SemiSupervisedGAN<Untrained> {
    // This estimator carries no separate configuration object; the unit type
    // stands in for it.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns the (empty) configuration.
    /// `&()` is sound here: the unit literal is promoted to a `'static` constant.
    fn config(&self) -> &Self::Config {
        &()
    }
}
418
419impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for SemiSupervisedGAN<Untrained> {
420 type Fitted = SemiSupervisedGAN<SemiSupervisedGANTrained>;
421
422 fn fit(self, x: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
423 let x = x.to_owned();
424 let y = y.to_owned();
425
426 if x.nrows() != y.len() {
427 return Err(SklearsError::InvalidInput(
428 "Number of samples in X and y must match".to_string(),
429 ));
430 }
431
432 if x.nrows() == 0 {
433 return Err(SklearsError::InvalidInput(
434 "No samples provided".to_string(),
435 ));
436 }
437
438 let labeled_count = y.iter().filter(|&&label| label >= 0).count();
440 if labeled_count == 0 {
441 return Err(SklearsError::InvalidInput(
442 "No labeled samples provided".to_string(),
443 ));
444 }
445
446 let mut unique_classes: Vec<i32> = y.iter().filter(|&&label| label >= 0).cloned().collect();
448 unique_classes.sort_unstable();
449 unique_classes.dedup();
450
451 let mut model = self.clone();
452 model.n_classes = unique_classes.len();
453
454 model.initialize_networks(x.ncols(), model.n_classes);
456 model.train(&x.view(), &y.view())?;
457
458 Ok(SemiSupervisedGAN {
459 state: SemiSupervisedGANTrained {
460 generator: model.generator.unwrap(),
461 discriminator: model.discriminator.unwrap(),
462 classes: Array1::from(unique_classes),
463 noise_dim: model.noise_dim,
464 n_classes: model.n_classes,
465 learning_rate: model.learning_rate,
466 },
467 generator: None,
468 discriminator: None,
469 noise_dim: 0,
470 n_classes: 0,
471 learning_rate: 0.0,
472 epochs: 0,
473 batch_size: 0,
474 gen_freq: 0,
475 disc_freq: 0,
476 gen_hidden_dims: Vec::new(),
477 disc_hidden_dims: Vec::new(),
478 random_state: None,
479 })
480 }
481}
482
483impl Predict<ArrayView2<'_, Float>, Array1<i32>> for SemiSupervisedGAN<SemiSupervisedGANTrained> {
484 fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
485 let x = x.to_owned();
486 let mut predictions = Array1::zeros(x.nrows());
487
488 for i in 0..x.nrows() {
489 let probs = self.state.discriminator.predict_proba(&x.row(i))?;
490
491 let real_probs = probs.slice(s![..self.state.n_classes]);
493
494 let max_idx = real_probs
495 .iter()
496 .enumerate()
497 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
498 .map(|(idx, _)| idx)
499 .unwrap_or(0);
500
501 predictions[i] = self.state.classes[max_idx];
502 }
503
504 Ok(predictions)
505 }
506}
507
508impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
509 for SemiSupervisedGAN<SemiSupervisedGANTrained>
510{
511 fn predict_proba(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
512 let x = x.to_owned();
513 let mut probabilities = Array2::zeros((x.nrows(), self.state.n_classes));
514
515 for i in 0..x.nrows() {
516 let probs = self.state.discriminator.predict_proba(&x.row(i))?;
517
518 let real_probs = probs.slice(s![..self.state.n_classes]);
520
521 let sum_real_probs = real_probs.sum();
523 if sum_real_probs > 0.0 {
524 let normalized_probs = &real_probs / sum_real_probs;
525 probabilities.row_mut(i).assign(&normalized_probs);
526 } else {
527 probabilities
529 .row_mut(i)
530 .fill(1.0 / self.state.n_classes as f64);
531 }
532 }
533
534 Ok(probabilities)
535 }
536}
537
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;
    // Removed unused imports (`s`, `ArrayView1`, `ArrayView2`) that triggered
    // unused-import warnings — no test referenced them.

    #[test]
    fn test_generator_creation() {
        let gen = Generator::new(100, 10, vec![128, 64]);
        assert_eq!(gen.noise_dim, 100);
        assert_eq!(gen.architecture, vec![100, 128, 64, 10]);
        assert_eq!(gen.weights.len(), 3);
        assert_eq!(gen.biases.len(), 3);
    }

    #[test]
    fn test_generator_forward() {
        let gen = Generator::new(5, 3, vec![8]);
        let noise = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let result = gen.forward(&noise.view());
        assert!(result.is_ok());

        let output = result.unwrap();
        assert_eq!(output.len(), 3);
    }

    #[test]
    fn test_generator_generate() {
        let gen = Generator::new(5, 3, vec![8]);

        let result = gen.generate(10);
        assert!(result.is_ok());

        let samples = result.unwrap();
        assert_eq!(samples.dim(), (10, 3));
    }

    #[test]
    fn test_discriminator_creation() {
        let disc = Discriminator::new(10, 3, vec![64, 32]);
        assert_eq!(disc.n_classes, 3);
        // Output layer is n_classes + 1 (extra unit for the fake class).
        assert_eq!(disc.architecture, vec![10, 64, 32, 4]);
        assert_eq!(disc.weights.len(), 3);
        assert_eq!(disc.biases.len(), 3);
    }

    #[test]
    fn test_discriminator_forward() {
        let disc = Discriminator::new(5, 2, vec![8]);
        let x = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let result = disc.forward(&x.view());
        assert!(result.is_ok());

        let output = result.unwrap();
        // n_classes + 1 = 3 logits.
        assert_eq!(output.len(), 3);
    }

    #[test]
    fn test_discriminator_predict_proba() {
        let disc = Discriminator::new(5, 2, vec![8]);
        let x = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let result = disc.predict_proba(&x.view());
        assert!(result.is_ok());

        let probs = result.unwrap();
        assert_eq!(probs.len(), 3);
        assert!((probs.sum() - 1.0).abs() < 1e-10);
        assert!(probs.iter().all(|&p| p >= 0.0 && p <= 1.0));
    }

    #[test]
    fn test_semi_supervised_gan_creation() {
        let gan = SemiSupervisedGAN::new()
            .noise_dim(50)
            .learning_rate(0.001)
            .epochs(50)
            .batch_size(16);

        assert_eq!(gan.noise_dim, 50);
        assert_eq!(gan.learning_rate, 0.001);
        assert_eq!(gan.epochs, 50);
        assert_eq!(gan.batch_size, 16);
    }

    #[test]
    fn test_semi_supervised_gan_fit_predict() {
        let X = array![
            [1.0, 2.0],
            [2.0, 3.0],
            [3.0, 4.0],
            [4.0, 5.0],
            [5.0, 6.0],
            [6.0, 7.0]
        ];
        // -1 marks unlabeled samples.
        let y = array![0, 1, 0, 1, -1, -1];

        let gan = SemiSupervisedGAN::new()
            .noise_dim(5)
            .learning_rate(0.01)
            .epochs(5)
            .batch_size(2);

        let result = gan.fit(&X.view(), &y.view());
        assert!(result.is_ok());

        let fitted = result.unwrap();
        assert_eq!(fitted.state.classes.len(), 2);

        let predictions = fitted.predict(&X.view());
        assert!(predictions.is_ok());

        let pred = predictions.unwrap();
        assert_eq!(pred.len(), 6);

        let probabilities = fitted.predict_proba(&X.view());
        assert!(probabilities.is_ok());

        let proba = probabilities.unwrap();
        assert_eq!(proba.dim(), (6, 2));

        // Each row of probabilities must sum to 1.
        for i in 0..6 {
            let sum: f64 = proba.row(i).sum();
            assert!((sum - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_semi_supervised_gan_insufficient_labeled_samples() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        // All samples unlabeled — fitting must fail.
        let y = array![-1, -1];

        let gan = SemiSupervisedGAN::new();
        let result = gan.fit(&X.view(), &y.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_semi_supervised_gan_invalid_dimensions() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        // Mismatched lengths — fitting must fail.
        let y = array![0];

        let gan = SemiSupervisedGAN::new();
        let result = gan.fit(&X.view(), &y.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_discriminator_real_fake_proba() {
        let disc = Discriminator::new(3, 2, vec![4]);
        let x = array![1.0, 2.0, 3.0];

        let result = disc.get_real_fake_proba(&x.view());
        assert!(result.is_ok());

        let real_prob = result.unwrap();
        assert!(real_prob >= 0.0 && real_prob <= 1.0);
    }

    #[test]
    fn test_semi_supervised_gan_generate_samples() {
        let X = array![
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0]
        ];
        let y = array![0, 1, 0, -1];

        let gan = SemiSupervisedGAN::new().noise_dim(5).epochs(3);

        let fitted = gan.fit(&X.view(), &y.view()).unwrap();

        let generated = fitted.generate_samples(5);
        assert!(generated.is_ok());

        let samples = generated.unwrap();
        assert_eq!(samples.dim(), (5, 3));
    }
}