use ndarray::{s, Array, Array2, Array3, Array4, ArrayD, IxDyn};
use scirs2_neural::layers::{
    AdaptiveMaxPool2D, BatchNorm, Conv2D, Dense, Dropout, PaddingMode, Sequential,
};
use scirs2_neural::losses::{CrossEntropyLoss, MeanSquaredError};
use scirs2_neural::prelude::*;

use rand::prelude::*;
use rand::rngs::SmallRng;

type StdResult<T> = std::result::Result<T, Box<dyn std::error::Error>>;
/// Configuration shared by the VAE and GAN examples.
#[derive(Debug, Clone)]
pub struct GenerativeConfig {
    pub input_size: (usize, usize),
    pub latent_dim: usize,
    pub hidden_dims: Vec<usize>,
    /// Weight of the KL-divergence term in the VAE loss.
    pub beta: f32,
}

impl Default for GenerativeConfig {
    fn default() -> Self {
        Self {
            input_size: (32, 32),
            latent_dim: 16,
            hidden_dims: vec![128, 64, 32],
            beta: 1.0,
        }
    }
}

/// Synthetic dataset that renders simple procedural patterns
/// (circles, stripes, checkerboards, gradients).
pub struct GenerativeDataset {
    config: GenerativeConfig,
    rng: SmallRng,
}

impl GenerativeDataset {
    pub fn new(config: GenerativeConfig, seed: u64) -> Self {
        Self {
            config,
            rng: SmallRng::seed_from_u64(seed),
        }
    }

    /// Render a single 1-channel image containing one random pattern type.
    pub fn generate_sample(&mut self) -> Array3<f32> {
        let (height, width) = self.config.input_size;
        let mut image = Array3::<f32>::zeros((1, height, width));

        let pattern_type = self.rng.random_range(0..4);

        match pattern_type {
            0 => {
                // Filled circles at random positions.
                let num_circles = self.rng.random_range(1..4);
                for _ in 0..num_circles {
                    let center_x = self.rng.random_range(5..(width - 5)) as f32;
                    let center_y = self.rng.random_range(5..(height - 5)) as f32;
                    let radius = self.rng.random_range(3..8) as f32;
                    let intensity = self.rng.random_range(0.5..1.0);

                    for i in 0..height {
                        for j in 0..width {
                            let dx = j as f32 - center_x;
                            let dy = i as f32 - center_y;
                            if dx * dx + dy * dy <= radius * radius {
                                image[[0, i, j]] = intensity;
                            }
                        }
                    }
                }
            }
            1 => {
                // Horizontal stripes.
                let stripe_width = self.rng.random_range(2..6);
                let intensity = self.rng.random_range(0.5..1.0);
                for i in 0..height {
                    if (i / stripe_width) % 2 == 0 {
                        for j in 0..width {
                            image[[0, i, j]] = intensity;
                        }
                    }
                }
            }
            2 => {
                // Checkerboard.
                let square_size = self.rng.random_range(3..8);
                let intensity = self.rng.random_range(0.5..1.0);
                for i in 0..height {
                    for j in 0..width {
                        if ((i / square_size) + (j / square_size)) % 2 == 0 {
                            image[[0, i, j]] = intensity;
                        }
                    }
                }
            }
            _ => {
                // Linear gradient, either vertical or horizontal.
                let direction = self.rng.random_range(0..2);
                let intensity = self.rng.random_range(0.5..1.0);
                for i in 0..height {
                    for j in 0..width {
                        let gradient_val = if direction == 0 {
                            i as f32 / height as f32
                        } else {
                            j as f32 / width as f32
                        };
                        image[[0, i, j]] = intensity * gradient_val;
                    }
                }
            }
        }

        // Add uniform noise and clamp pixels back into the valid [0, 1] range.
        for elem in image.iter_mut() {
            *elem += self.rng.random_range(-0.1..0.1);
            *elem = elem.clamp(0.0, 1.0);
        }

        image
    }

    pub fn generate_batch(&mut self, batch_size: usize) -> Array4<f32> {
        let (height, width) = self.config.input_size;
        let mut images = Array4::<f32>::zeros((batch_size, 1, height, width));

        for i in 0..batch_size {
            let image = self.generate_sample();
            images.slice_mut(s![i, .., .., ..]).assign(&image);
        }

        images
    }
}

/// Convolutional encoder mapping images to the mean and log-variance
/// of the approximate posterior q(z|x).
pub struct VAEEncoder {
    feature_extractor: Sequential<f32>,
    mean_head: Sequential<f32>,
    logvar_head: Sequential<f32>,
    #[allow(dead_code)]
    config: GenerativeConfig,
}

impl VAEEncoder {
    pub fn new(config: GenerativeConfig, rng: &mut SmallRng) -> StdResult<Self> {
        let (_height, _width) = config.input_size;

        // Three stride-2 convolutions downsample 32x32 -> 16x16 -> 8x8 -> 4x4.
        let mut feature_extractor = Sequential::new();
        feature_extractor.add(Conv2D::new(1, 32, (3, 3), (2, 2), PaddingMode::Same, rng)?);
        feature_extractor.add(BatchNorm::new(32, 1e-5, 0.1, rng)?);

        feature_extractor.add(Conv2D::new(32, 64, (3, 3), (2, 2), PaddingMode::Same, rng)?);
        feature_extractor.add(BatchNorm::new(64, 1e-5, 0.1, rng)?);

        feature_extractor.add(Conv2D::new(
            64,
            128,
            (3, 3),
            (2, 2),
            PaddingMode::Same,
            rng,
        )?);
        feature_extractor.add(BatchNorm::new(128, 1e-5, 0.1, rng)?);

        feature_extractor.add(AdaptiveMaxPool2D::new((4, 4), None)?);

        let feature_size = 128 * 4 * 4;

        // Separate dense heads for the mean and log-variance of the latent Gaussian.
        let mut mean_head = Sequential::new();
        mean_head.add(Dense::new(
            feature_size,
            config.hidden_dims[0],
            Some("relu"),
            rng,
        )?);
        mean_head.add(Dropout::new(0.2, rng)?);
        mean_head.add(Dense::new(
            config.hidden_dims[0],
            config.latent_dim,
            None,
            rng,
        )?);

        let mut logvar_head = Sequential::new();
        logvar_head.add(Dense::new(
            feature_size,
            config.hidden_dims[0],
            Some("relu"),
            rng,
        )?);
        logvar_head.add(Dropout::new(0.2, rng)?);
        logvar_head.add(Dense::new(
            config.hidden_dims[0],
            config.latent_dim,
            None,
            rng,
        )?);

        Ok(Self {
            feature_extractor,
            mean_head,
            logvar_head,
            config,
        })
    }

    pub fn forward(&self, input: &ArrayD<f32>) -> StdResult<(ArrayD<f32>, ArrayD<f32>)> {
        let features = self.feature_extractor.forward(input)?;

        // Flatten the (batch, 128, 4, 4) feature map to (batch, 2048) for the heads.
        let batch_size = features.shape()[0];
        let feature_dim = features.len() / batch_size;
        let flattened = features
            .to_shape(IxDyn(&[batch_size, feature_dim]))?
            .to_owned();

        let mean = self.mean_head.forward(&flattened)?;
        let logvar = self.logvar_head.forward(&flattened)?;

        Ok((mean, logvar))
    }
}

/// Decoder mapping latent vectors back to 1-channel images.
pub struct VAEDecoder {
    latent_projection: Sequential<f32>,
    feature_layers: Sequential<f32>,
    output_conv: Conv2D<f32>,
    config: GenerativeConfig,
}

impl VAEDecoder {
    pub fn new(config: GenerativeConfig, rng: &mut SmallRng) -> StdResult<Self> {
        // Project the latent vector up to a 128x4x4 feature volume.
        let mut latent_projection = Sequential::new();
        latent_projection.add(Dense::new(
            config.latent_dim,
            config.hidden_dims[0],
            Some("relu"),
            rng,
        )?);
        latent_projection.add(Dense::new(
            config.hidden_dims[0],
            128 * 4 * 4,
            Some("relu"),
            rng,
        )?);

        let mut feature_layers = Sequential::new();
        feature_layers.add(Conv2D::new(
            128,
            64,
            (3, 3),
            (1, 1),
            PaddingMode::Same,
            rng,
        )?);
        feature_layers.add(BatchNorm::new(64, 1e-5, 0.1, rng)?);

        feature_layers.add(Conv2D::new(64, 32, (3, 3), (1, 1), PaddingMode::Same, rng)?);
        feature_layers.add(BatchNorm::new(32, 1e-5, 0.1, rng)?);

        let output_conv = Conv2D::new(32, 1, (3, 3), (1, 1), PaddingMode::Same, rng)?;

        Ok(Self {
            latent_projection,
            feature_layers,
            output_conv,
            config,
        })
    }

    pub fn forward(&self, latent: &ArrayD<f32>) -> StdResult<ArrayD<f32>> {
        let projected = self.latent_projection.forward(latent)?;

        let batch_size = projected.shape()[0];
        let reshaped = projected.into_shape_with_order(IxDyn(&[batch_size, 128, 4, 4]))?;

        let upsampled = self.upsample(&reshaped)?;

        let features = self.feature_layers.forward(&upsampled)?;

        let output = self.output_conv.forward(&features)?;

        Ok(output)
    }

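    /// Nearest-neighbour upsampling, written out by hand since this example
    /// does not use a dedicated upsampling layer. Each source pixel is copied
    /// into a scale_h x scale_w block; the integer division below assumes the
    /// target size is a whole multiple of the input size (32 = 8 * 4 for the
    /// default config).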
    fn upsample(&self, input: &ArrayD<f32>) -> StdResult<ArrayD<f32>> {
        let shape = input.shape();
        let batch_size = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];

        let (target_height, target_width) = self.config.input_size;
        let scale_h = target_height / height;
        let scale_w = target_width / width;

        let mut upsampled =
            Array4::<f32>::zeros((batch_size, channels, target_height, target_width));

        for b in 0..batch_size {
            for c in 0..channels {
                for i in 0..height {
                    for j in 0..width {
                        let value = input[[b, c, i, j]];
                        for di in 0..scale_h {
                            for dj in 0..scale_w {
                                let new_i = i * scale_h + di;
                                let new_j = j * scale_w + dj;
                                if new_i < target_height && new_j < target_width {
                                    upsampled[[b, c, new_i, new_j]] = value;
                                }
                            }
                        }
                    }
                }
            }
        }

        Ok(upsampled.into_dyn())
    }
}

/// Complete VAE: encoder, reparameterized sampling, and decoder.
pub struct VAEModel {
    encoder: VAEEncoder,
    decoder: VAEDecoder,
    config: GenerativeConfig,
}

impl VAEModel {
    pub fn new(config: GenerativeConfig, rng: &mut SmallRng) -> StdResult<Self> {
        let encoder = VAEEncoder::new(config.clone(), rng)?;
        let decoder = VAEDecoder::new(config.clone(), rng)?;

        Ok(Self {
            encoder,
            decoder,
            config,
        })
    }

    pub fn forward(
        &self,
        input: &ArrayD<f32>,
    ) -> StdResult<(ArrayD<f32>, ArrayD<f32>, ArrayD<f32>)> {
        // Encode, sample a latent vector, then decode.
        let (mean, logvar) = self.encoder.forward(input)?;

        let latent = self.reparameterize(&mean, &logvar)?;

        let reconstruction = self.decoder.forward(&latent)?;

        Ok((reconstruction, mean, logvar))
    }

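    /// Reparameterization trick: z = mu + sigma * eps with
    /// sigma = exp(0.5 * logvar), so gradients can flow through mu and logvar
    /// while the randomness stays in eps. Note that eps is drawn from
    /// uniform(-1, 1) here to keep the example simple; a faithful VAE samples
    /// eps from a standard normal distribution.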
    fn reparameterize(&self, mean: &ArrayD<f32>, logvar: &ArrayD<f32>) -> StdResult<ArrayD<f32>> {
        let mut epsilon = Array::zeros(mean.raw_dim());
        // A fixed seed keeps the example deterministic.
        let mut rng = SmallRng::seed_from_u64(42);
        for elem in epsilon.iter_mut() {
            *elem = rng.random_range(-1.0..1.0);
        }

        let mut result = Array::zeros(mean.raw_dim());
        for (((&m, &lv), &eps), res) in mean
            .iter()
            .zip(logvar.iter())
            .zip(epsilon.iter())
            .zip(result.iter_mut())
        {
            let std = (0.5 * lv).exp();
            *res = m + std * eps;
        }

        Ok(result)
    }

    /// Decode random latent vectors into new samples.
    pub fn generate(&self, batch_size: usize) -> StdResult<ArrayD<f32>> {
        let mut latent = Array2::<f32>::zeros((batch_size, self.config.latent_dim));
        let mut rng = SmallRng::seed_from_u64(123);

        for elem in latent.iter_mut() {
            *elem = rng.random_range(-1.0..1.0);
        }

        let latent_dyn = latent.into_dyn();

        self.decoder.forward(&latent_dyn)
    }

    /// Decode a linear interpolation between two latent vectors.
    pub fn interpolate(
        &self,
        latent1: &ArrayD<f32>,
        latent2: &ArrayD<f32>,
        steps: usize,
    ) -> StdResult<Vec<ArrayD<f32>>> {
        let mut results = Vec::new();

        for i in 0..steps {
            // Guard against a division by zero when steps == 1.
            let alpha = if steps > 1 {
                i as f32 / (steps - 1) as f32
            } else {
                0.0
            };

            let mut interpolated = Array::zeros(latent1.raw_dim());
            for ((&l1, &l2), interp) in latent1
                .iter()
                .zip(latent2.iter())
                .zip(interpolated.iter_mut())
            {
                *interp = (1.0 - alpha) * l1 + alpha * l2;
            }

            let generated = self.decoder.forward(&interpolated)?;
            results.push(generated);
        }

        Ok(results)
    }
}

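/// VAE loss: total = reconstruction + beta * KL. For a diagonal Gaussian
/// posterior q(z|x) = N(mu, sigma^2) against a standard normal prior, the KL
/// term has the closed form
///   KL = -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).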
pub struct VAELoss {
    reconstruction_loss: MeanSquaredError,
    beta: f32,
}

impl VAELoss {
    pub fn new(beta: f32) -> Self {
        Self {
            reconstruction_loss: MeanSquaredError::new(),
            beta,
        }
    }

    pub fn compute_loss(
        &self,
        reconstruction: &ArrayD<f32>,
        target: &ArrayD<f32>,
        mean: &ArrayD<f32>,
        logvar: &ArrayD<f32>,
    ) -> StdResult<(f32, f32, f32)> {
        let recon_loss = self.reconstruction_loss.forward(reconstruction, target)?;

        let mut kl_loss = 0.0f32;
        for (&m, &lv) in mean.iter().zip(logvar.iter()) {
            kl_loss += -0.5 * (1.0 + lv - m * m - lv.exp());
        }
        kl_loss /= mean.len() as f32;

        let total_loss = recon_loss + self.beta * kl_loss;

        Ok((total_loss, recon_loss, kl_loss))
    }
}

/// Fully-connected GAN generator mapping noise vectors to images.
pub struct GANGenerator {
    layers: Sequential<f32>,
    config: GenerativeConfig,
}

impl GANGenerator {
    pub fn new(config: GenerativeConfig, rng: &mut SmallRng) -> StdResult<Self> {
        let mut layers = Sequential::new();

        layers.add(Dense::new(
            config.latent_dim,
            config.hidden_dims[0],
            Some("relu"),
            rng,
        )?);
        layers.add(BatchNorm::new(config.hidden_dims[0], 1e-5, 0.1, rng)?);

        layers.add(Dense::new(
            config.hidden_dims[0],
            config.hidden_dims[1] * 2,
            Some("relu"),
            rng,
        )?);
        layers.add(BatchNorm::new(config.hidden_dims[1] * 2, 1e-5, 0.1, rng)?);

        // NOTE: tanh outputs lie in [-1, 1], while the synthetic dataset is
        // scaled to [0, 1]; a sigmoid output (or rescaling) would match the
        // real data range more closely.
        let output_size = config.input_size.0 * config.input_size.1;
        layers.add(Dense::new(
            config.hidden_dims[1] * 2,
            output_size,
            Some("tanh"),
            rng,
        )?);

        Ok(Self { layers, config })
    }

    pub fn forward(&self, noise: &ArrayD<f32>) -> StdResult<ArrayD<f32>> {
        let output = self.layers.forward(noise)?;

        // Reshape the flat output to (batch, 1, height, width).
        let batch_size = output.shape()[0];
        let (height, width) = self.config.input_size;
        let reshaped = output
            .to_shape(IxDyn(&[batch_size, 1, height, width]))?
            .to_owned();

        Ok(reshaped)
    }
}

/// Fully-connected GAN discriminator with a sigmoid output in (0, 1).
pub struct GANDiscriminator {
    layers: Sequential<f32>,
    config: GenerativeConfig,
}

impl GANDiscriminator {
    pub fn new(config: GenerativeConfig, rng: &mut SmallRng) -> StdResult<Self> {
        let mut layers = Sequential::new();

        let input_size = config.input_size.0 * config.input_size.1;

        layers.add(Dense::new(
            input_size,
            config.hidden_dims[0],
            Some("relu"),
            rng,
        )?);
        layers.add(Dropout::new(0.3, rng)?);

        layers.add(Dense::new(
            config.hidden_dims[0],
            config.hidden_dims[1],
            Some("relu"),
            rng,
        )?);
        layers.add(Dropout::new(0.3, rng)?);

        layers.add(Dense::new(config.hidden_dims[1], 1, Some("sigmoid"), rng)?);

        Ok(Self { layers, config })
    }

    pub fn forward(&self, input: &ArrayD<f32>) -> StdResult<ArrayD<f32>> {
        // Flatten images to (batch, height * width) before the dense layers.
        let batch_size = input.shape()[0];
        let input_size = self.config.input_size.0 * self.config.input_size.1;
        let flattened = input.to_shape(IxDyn(&[batch_size, input_size]))?.to_owned();

        Ok(self.layers.forward(&flattened)?)
    }
}

/// Simple evaluation metrics for generative models.
pub struct GenerativeMetrics {
    #[allow(dead_code)]
    config: GenerativeConfig,
}

impl GenerativeMetrics {
    pub fn new(config: GenerativeConfig) -> Self {
        Self { config }
    }

    /// Mean squared error between original and reconstructed samples.
    pub fn reconstruction_error(&self, original: &ArrayD<f32>, reconstructed: &ArrayD<f32>) -> f32 {
        let mut mse = 0.0f32;
        let mut count = 0;

        for (&orig, &recon) in original.iter().zip(reconstructed.iter()) {
            let diff = orig - recon;
            mse += diff * diff;
            count += 1;
        }

        if count > 0 {
            mse / count as f32
        } else {
            0.0
        }
    }

    /// Mean per-position variance across the batch; higher values indicate
    /// more diverse samples.
    pub fn sample_diversity(&self, samples: &ArrayD<f32>) -> f32 {
        let batch_size = samples.shape()[0];
        if batch_size < 2 {
            return 0.0;
        }

        let sample_size = samples.len() / batch_size;
        // Collect once instead of calling .iter().nth(idx) per element, which
        // would be quadratic in the number of elements.
        let flat: Vec<f32> = samples.iter().copied().collect();

        let mut total_variance = 0.0f32;
        for i in 0..sample_size {
            let values: Vec<f32> = (0..batch_size).map(|b| flat[b * sample_size + i]).collect();

            let mean = values.iter().sum::<f32>() / values.len() as f32;
            let variance = values.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>()
                / values.len() as f32;
            total_variance += variance;
        }

        total_variance / sample_size as f32
    }
}

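// NOTE: the two training functions below only run forward passes and compute
// losses; no backward pass or optimizer step is wired up (hence the unused
// `_learning_rate`), so the printed losses are illustrative and will not
// decrease over epochs.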
fn train_vae_model() -> StdResult<()> {
    println!("Starting VAE Training");

    let mut rng = SmallRng::seed_from_u64(42);
    let config = GenerativeConfig::default();

    println!("Building VAE model...");
    let vae = VAEModel::new(config.clone(), &mut rng)?;
    println!("VAE created with latent dimension {}", config.latent_dim);

    let mut dataset = GenerativeDataset::new(config.clone(), 123);

    let loss_fn = VAELoss::new(config.beta);

    let metrics = GenerativeMetrics::new(config.clone());

    println!("Training configuration:");
    println!(" - Input size: {:?}", config.input_size);
    println!(" - Latent dimension: {}", config.latent_dim);
    println!(" - Beta (KL weight): {}", config.beta);
    println!(" - Hidden dimensions: {:?}", config.hidden_dims);

    let num_epochs = 20;
    let batch_size = 4;
    let _learning_rate = 0.001;

    for epoch in 0..num_epochs {
        println!("\nEpoch {}/{}", epoch + 1, num_epochs);

        let mut epoch_total_loss = 0.0;
        let mut epoch_recon_loss = 0.0;
        let mut epoch_kl_loss = 0.0;
        let num_batches = 12;

        for batch_idx in 0..num_batches {
            let images = dataset.generate_batch(batch_size);
            let images_dyn = images.into_dyn();

            let (reconstruction, mean, logvar) = vae.forward(&images_dyn)?;

            let (total_loss, recon_loss, kl_loss) =
                loss_fn.compute_loss(&reconstruction, &images_dyn, &mean, &logvar)?;

            epoch_total_loss += total_loss;
            epoch_recon_loss += recon_loss;
            epoch_kl_loss += kl_loss;

            if batch_idx % 6 == 0 {
                print!(
                    "Batch {}/{} - Total: {:.4}, Recon: {:.4}, KL: {:.4}    \r",
                    batch_idx + 1,
                    num_batches,
                    total_loss,
                    recon_loss,
                    kl_loss
                );
            }
        }

        let avg_total = epoch_total_loss / num_batches as f32;
        let avg_recon = epoch_recon_loss / num_batches as f32;
        let avg_kl = epoch_kl_loss / num_batches as f32;

        println!(
            "Epoch {} - Total: {:.4}, Recon: {:.4}, KL: {:.4}",
            epoch + 1,
            avg_total,
            avg_recon,
            avg_kl
        );

        // Every five epochs: evaluate reconstruction, generation, and interpolation.
        if (epoch + 1) % 5 == 0 {
            println!("Running evaluation and generation...");

            let test_images = dataset.generate_batch(batch_size);
            let test_images_dyn = test_images.into_dyn();
            let (test_reconstruction, _, _) = vae.forward(&test_images_dyn)?;

            let recon_error = metrics.reconstruction_error(&test_images_dyn, &test_reconstruction);
            println!("Reconstruction MSE: {:.6}", recon_error);

            let generated_samples = vae.generate(8)?;
            let diversity = metrics.sample_diversity(&generated_samples);
            println!("Sample diversity: {:.6}", diversity);

            let latent1 = Array2::<f32>::from_elem((1, config.latent_dim), -1.0).into_dyn();
            let latent2 = Array2::<f32>::from_elem((1, config.latent_dim), 1.0).into_dyn();
            let interpolated = vae.interpolate(&latent1, &latent2, 5)?;

            println!("Generated {} interpolated samples", interpolated.len());
        }
    }

    println!("\nVAE training completed!");
    Ok(())
}

fn train_gan_model() -> StdResult<()> {
    println!("Starting GAN Training");

    let mut rng = SmallRng::seed_from_u64(42);
    let config = GenerativeConfig::default();

    println!("Building GAN models...");
    let generator = GANGenerator::new(config.clone(), &mut rng)?;
    let discriminator = GANDiscriminator::new(config.clone(), &mut rng)?;
    println!("GAN models created");

    let mut dataset = GenerativeDataset::new(config.clone(), 456);

    let _adversarial_loss = CrossEntropyLoss::new(1e-7);

    println!("GAN training configuration:");
    println!(" - Generator latent dim: {}", config.latent_dim);
    println!(" - Discriminator architecture: {:?}", config.hidden_dims);

    let num_epochs = 15;
    let batch_size = 4;

    for epoch in 0..num_epochs {
        println!("\nEpoch {}/{}", epoch + 1, num_epochs);

        let mut d_loss_total = 0.0;
        let mut g_loss_total = 0.0;
        let num_batches = 8;

        for batch_idx in 0..num_batches {
            let real_images = dataset.generate_batch(batch_size);
            let real_images_dyn = real_images.into_dyn();

            // Sample noise and generate a batch of fake images.
            let mut noise = Array2::<f32>::zeros((batch_size, config.latent_dim));
            for elem in noise.iter_mut() {
                *elem = rng.random_range(-1.0..1.0);
            }
            let noise_dyn = noise.into_dyn();
            let fake_images = generator.forward(&noise_dyn)?;

            let real_pred = discriminator.forward(&real_images_dyn)?;
            let fake_pred = discriminator.forward(&fake_images)?;

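            // Binary cross-entropy losses for the standard GAN objective:
            //   discriminator: L_D = -ln(D(x)) - ln(1 - D(G(z)))
            //   generator (non-saturating): L_G = -ln(D(G(z)))
            // Predictions are clamped away from 0 to avoid ln(0) = -inf.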
            let mut d_loss_real = 0.0f32;
            let mut d_loss_fake = 0.0f32;

            for &pred in real_pred.iter() {
                d_loss_real += -pred.max(1e-7).ln();
            }

            for &pred in fake_pred.iter() {
                d_loss_fake += -(1.0 - pred).max(1e-7).ln();
            }

            let d_loss = (d_loss_real + d_loss_fake) / (batch_size * 2) as f32;
            d_loss_total += d_loss;

            let fake_pred_for_g = discriminator.forward(&fake_images)?;
            let mut g_loss = 0.0f32;

            for &pred in fake_pred_for_g.iter() {
                g_loss += -pred.max(1e-7).ln();
            }
            g_loss /= batch_size as f32;
            g_loss_total += g_loss;

            if batch_idx % 4 == 0 {
                print!(
                    "Batch {}/{} - D Loss: {:.4}, G Loss: {:.4}    \r",
                    batch_idx + 1,
                    num_batches,
                    d_loss,
                    g_loss
                );
            }
        }

        let avg_d_loss = d_loss_total / num_batches as f32;
        let avg_g_loss = g_loss_total / num_batches as f32;

        println!(
            "Epoch {} - D Loss: {:.4}, G Loss: {:.4}",
            epoch + 1,
            avg_d_loss,
            avg_g_loss
        );

        if (epoch + 1) % 5 == 0 {
            println!("Generating samples...");

            let mut sample_noise = Array2::<f32>::zeros((4, config.latent_dim));
            for elem in sample_noise.iter_mut() {
                *elem = rng.random_range(-1.0..1.0);
            }
            let sample_noise_dyn = sample_noise.into_dyn();
            let generated = generator.forward(&sample_noise_dyn)?;

            println!("Generated {} samples", generated.shape()[0]);
        }
    }

    println!("\nGAN training completed!");
    Ok(())
}

fn main() -> StdResult<()> {
    println!("Generative Models Complete Example");
    println!("==================================");
    println!();
    println!("This example demonstrates:");
    println!("• Variational Autoencoder (VAE) implementation");
    println!("• Generative Adversarial Network (GAN) basics");
    println!("• Synthetic pattern dataset generation");
    println!("• VAE loss (reconstruction + KL divergence)");
    println!("• Latent space interpolation");
    println!("• Sample generation and evaluation");
    println!();

    train_vae_model()?;

    println!("\n{}", "=".repeat(50));

    train_gan_model()?;

    println!("\nKey Concepts Demonstrated:");
    println!(" • Variational inference and reparameterization trick");
    println!(" • KL divergence regularization");
    println!(" • Adversarial training dynamics");
    println!(" • Latent space manipulation");
    println!(" • Reconstruction vs generation quality");
    println!(" • Sample diversity metrics");
    println!();
    println!("For production use:");
    println!(" • Implement β-VAE, WAE, or other VAE variants");
    println!(" • Add convolutional layers for better image modeling");
    println!(" • Implement DCGAN, StyleGAN, or other advanced GANs");
    println!(" • Add progressive training and spectral normalization");
    println!(" • Use FID, IS, or other advanced evaluation metrics");
    println!(" • Implement conditional generation (cVAE, cGAN)");
    println!(" • Add attention mechanisms and self-attention");

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generative_config() {
        let config = GenerativeConfig::default();
        assert_eq!(config.input_size, (32, 32));
        assert_eq!(config.latent_dim, 16);
        assert_eq!(config.beta, 1.0);
        assert!(!config.hidden_dims.is_empty());
    }

    #[test]
    fn test_dataset_generation() {
        let config = GenerativeConfig::default();
        let mut dataset = GenerativeDataset::new(config.clone(), 42);

        let image = dataset.generate_sample();
        assert_eq!(
            image.shape(),
            &[1, config.input_size.0, config.input_size.1]
        );

        // All pixel values must stay within the valid [0, 1] range.
        for &val in image.iter() {
            assert!((0.0..=1.0).contains(&val));
        }
    }

    #[test]
    fn test_vae_creation() -> StdResult<()> {
        let mut rng = SmallRng::seed_from_u64(42);
        let config = GenerativeConfig::default();

        let vae = VAEModel::new(config.clone(), &mut rng)?;

        let batch_size = 2;
        let input = Array4::<f32>::ones((batch_size, 1, config.input_size.0, config.input_size.1))
            .into_dyn();
        let (reconstruction, mean, logvar) = vae.forward(&input)?;

        assert_eq!(reconstruction.shape()[0], batch_size);
        assert_eq!(mean.shape()[1], config.latent_dim);
        assert_eq!(logvar.shape()[1], config.latent_dim);

        Ok(())
    }

    #[test]
    fn test_gan_creation() -> StdResult<()> {
        let mut rng = SmallRng::seed_from_u64(42);
        let config = GenerativeConfig::default();

        let generator = GANGenerator::new(config.clone(), &mut rng)?;
        let discriminator = GANDiscriminator::new(config.clone(), &mut rng)?;

        let batch_size = 2;
        let noise = Array2::<f32>::ones((batch_size, config.latent_dim)).into_dyn();
        let generated = generator.forward(&noise)?;

        assert_eq!(generated.shape()[0], batch_size);
        assert_eq!(generated.shape()[1], 1);

        let pred = discriminator.forward(&generated)?;
        assert_eq!(pred.shape()[0], batch_size);
        assert_eq!(pred.shape()[1], 1);

        Ok(())
    }

    #[test]
    fn test_vae_loss() -> StdResult<()> {
        let loss_fn = VAELoss::new(1.0);

        let reconstruction = Array2::<f32>::ones((2, 10)).into_dyn();
        let target = Array2::<f32>::zeros((2, 10)).into_dyn();
        let mean = Array2::<f32>::zeros((2, 5)).into_dyn();
        let logvar = Array2::<f32>::zeros((2, 5)).into_dyn();

        let (total_loss, recon_loss, kl_loss) =
            loss_fn.compute_loss(&reconstruction, &target, &mean, &logvar)?;

        assert!(total_loss > 0.0);
        assert!(recon_loss > 0.0);
        // With mean = 0 and logvar = 0 the KL term is exactly zero.
        assert!(kl_loss.abs() < 1e-6);

        Ok(())
    }

    #[test]
    fn test_generative_metrics() {
        let config = GenerativeConfig::default();
        let metrics = GenerativeMetrics::new(config);

        let original = Array2::<f32>::ones((2, 10)).into_dyn();
        let reconstructed = Array2::<f32>::zeros((2, 10)).into_dyn();

        let error = metrics.reconstruction_error(&original, &reconstructed);
        assert_eq!(error, 1.0);

        let samples = Array2::<f32>::from_shape_fn((3, 4), |(i, j)| i as f32 + j as f32).into_dyn();
        let diversity = metrics.sample_diversity(&samples);
        assert!(diversity > 0.0);
    }
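
    // Additional sanity check (a sketch added for this example): latent
    // interpolation should return exactly `steps` decoded frames, each with
    // the same batch dimension as the input latents.
    #[test]
    fn test_latent_interpolation() -> StdResult<()> {
        let mut rng = SmallRng::seed_from_u64(42);
        let config = GenerativeConfig::default();
        let vae = VAEModel::new(config.clone(), &mut rng)?;

        let latent1 = Array2::<f32>::from_elem((1, config.latent_dim), -1.0).into_dyn();
        let latent2 = Array2::<f32>::from_elem((1, config.latent_dim), 1.0).into_dyn();
        let frames = vae.interpolate(&latent1, &latent2, 5)?;

        assert_eq!(frames.len(), 5);
        for frame in &frames {
            assert_eq!(frame.shape()[0], 1);
        }

        Ok(())
    }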
}