//! Complete Generative Models Example
//!
//! This example demonstrates building generative models using scirs2-neural.
//! It includes:
//! - Variational Autoencoder (VAE) implementation
//! - Generator and discriminator for a simple GAN
//! - Synthetic dataset generation
//! - VAE loss (reconstruction + KL divergence)
//! - GAN training with adversarial loss
//! - Sample generation and evaluation metrics
//! - Latent space interpolation

use ndarray::{s, Array, Array2, Array3, Array4, ArrayD, IxDyn};
use scirs2_neural::layers::{
    AdaptiveMaxPool2D, BatchNorm, Conv2D, Dense, Dropout, PaddingMode, Sequential,
};
use scirs2_neural::losses::{CrossEntropyLoss, MeanSquaredError};
use scirs2_neural::prelude::*;

// Type alias to avoid conflicts with scirs2-neural's Result
type StdResult<T> = std::result::Result<T, Box<dyn std::error::Error>>;
use rand::prelude::*;
use rand::rngs::SmallRng;
// use std::collections::HashMap;
use std::f32::consts::PI; // Used by the Box-Muller sampling in `reparameterize`

/// Configuration for generative models
#[derive(Debug, Clone)]
pub struct GenerativeConfig {
    pub input_size: (usize, usize),
    pub latent_dim: usize,
    pub hidden_dims: Vec<usize>,
    pub beta: f32, // Beta parameter for beta-VAE
}

impl Default for GenerativeConfig {
    fn default() -> Self {
        Self {
            input_size: (32, 32),
            latent_dim: 16,
            hidden_dims: vec![128, 64, 32],
            beta: 1.0,
        }
    }
}

/// Synthetic dataset generator for generative modeling
pub struct GenerativeDataset {
    config: GenerativeConfig,
    rng: SmallRng,
}

impl GenerativeDataset {
    pub fn new(config: GenerativeConfig, seed: u64) -> Self {
        Self {
            config,
            rng: SmallRng::seed_from_u64(seed),
        }
    }

    /// Generate a synthetic image (simple patterns)
    pub fn generate_sample(&mut self) -> Array3<f32> {
        let (height, width) = self.config.input_size;
        let mut image = Array3::<f32>::zeros((1, height, width)); // Grayscale

        let pattern_type = self.rng.random_range(0..4);

        match pattern_type {
            0 => {
                // Circles
                let num_circles = self.rng.random_range(1..4);
                for _ in 0..num_circles {
                    let center_x = self.rng.random_range(5..(width - 5)) as f32;
                    let center_y = self.rng.random_range(5..(height - 5)) as f32;
                    let radius = self.rng.random_range(3..8) as f32;
                    let intensity = self.rng.random_range(0.5..1.0);

                    for i in 0..height {
                        for j in 0..width {
                            let dx = j as f32 - center_x;
                            let dy = i as f32 - center_y;
                            if dx * dx + dy * dy <= radius * radius {
                                image[[0, i, j]] = intensity;
                            }
                        }
                    }
                }
            }
            1 => {
                // Stripes
                let stripe_width = self.rng.random_range(2..6);
                let intensity = self.rng.random_range(0.5..1.0);
                for i in 0..height {
                    if (i / stripe_width) % 2 == 0 {
                        for j in 0..width {
                            image[[0, i, j]] = intensity;
                        }
                    }
                }
            }
            2 => {
                // Checkerboard
                let square_size = self.rng.random_range(3..8);
                let intensity = self.rng.random_range(0.5..1.0);
                for i in 0..height {
                    for j in 0..width {
                        if ((i / square_size) + (j / square_size)) % 2 == 0 {
                            image[[0, i, j]] = intensity;
                        }
                    }
                }
            }
            _ => {
                // Gradient
                let direction = self.rng.random_range(0..2);
                let intensity = self.rng.random_range(0.5..1.0);
                for i in 0..height {
                    for j in 0..width {
                        let gradient_val = if direction == 0 {
                            i as f32 / height as f32
                        } else {
                            j as f32 / width as f32
                        };
                        image[[0, i, j]] = intensity * gradient_val;
                    }
                }
            }
        }

        // Add noise and clamp to the valid [0, 1] intensity range
        for elem in image.iter_mut() {
            *elem += self.rng.random_range(-0.1..0.1);
            *elem = elem.clamp(0.0, 1.0);
        }

        image
    }

    /// Generate a batch of samples
    pub fn generate_batch(&mut self, batch_size: usize) -> Array4<f32> {
        let (height, width) = self.config.input_size;
        let mut images = Array4::<f32>::zeros((batch_size, 1, height, width));

        for i in 0..batch_size {
            let image = self.generate_sample();
            images.slice_mut(s![i, .., .., ..]).assign(&image);
        }

        images
    }
}

/// VAE Encoder that outputs mean and log variance for latent distribution
pub struct VAEEncoder {
    feature_extractor: Sequential<f32>,
    mean_head: Sequential<f32>,
    logvar_head: Sequential<f32>,
    #[allow(dead_code)]
    config: GenerativeConfig,
}

impl VAEEncoder {
    pub fn new(config: GenerativeConfig, rng: &mut SmallRng) -> StdResult<Self> {
        let (_height, _width) = config.input_size;

        // Feature extraction layers
        let mut feature_extractor = Sequential::new();
        feature_extractor.add(Conv2D::new(1, 32, (3, 3), (2, 2), PaddingMode::Same, rng)?);
        feature_extractor.add(BatchNorm::new(32, 1e-5, 0.1, rng)?);

        feature_extractor.add(Conv2D::new(32, 64, (3, 3), (2, 2), PaddingMode::Same, rng)?);
        feature_extractor.add(BatchNorm::new(64, 1e-5, 0.1, rng)?);

        feature_extractor.add(Conv2D::new(
            64,
            128,
            (3, 3),
            (2, 2),
            PaddingMode::Same,
            rng,
        )?);
        feature_extractor.add(BatchNorm::new(128, 1e-5, 0.1, rng)?);

        feature_extractor.add(AdaptiveMaxPool2D::new((4, 4), None)?);

        // Calculate flattened feature size
        let feature_size = 128 * 4 * 4;

        // Mean head
        let mut mean_head = Sequential::new();
        mean_head.add(Dense::new(
            feature_size,
            config.hidden_dims[0],
            Some("relu"),
            rng,
        )?);
        mean_head.add(Dropout::new(0.2, rng)?);
        mean_head.add(Dense::new(
            config.hidden_dims[0],
            config.latent_dim,
            None,
            rng,
        )?);

        // Log variance head
        let mut logvar_head = Sequential::new();
        logvar_head.add(Dense::new(
            feature_size,
            config.hidden_dims[0],
            Some("relu"),
            rng,
        )?);
        logvar_head.add(Dropout::new(0.2, rng)?);
        logvar_head.add(Dense::new(
            config.hidden_dims[0],
            config.latent_dim,
            None,
            rng,
        )?);

        Ok(Self {
            feature_extractor,
            mean_head,
            logvar_head,
            config,
        })
    }

    pub fn forward(&self, input: &ArrayD<f32>) -> StdResult<(ArrayD<f32>, ArrayD<f32>)> {
        // Extract features
        let features = self.feature_extractor.forward(input)?;

        // Flatten features
        let batch_size = features.shape()[0];
        let feature_dim = features.len() / batch_size;
        let flattened = features
            .to_shape(IxDyn(&[batch_size, feature_dim]))?
            .to_owned();

        // Get mean and log variance
        let mean = self.mean_head.forward(&flattened)?;
        let logvar = self.logvar_head.forward(&flattened)?;

        Ok((mean, logvar))
    }
}

/// VAE Decoder that reconstructs images from latent codes
pub struct VAEDecoder {
    latent_projection: Sequential<f32>,
    feature_layers: Sequential<f32>,
    output_conv: Conv2D<f32>,
    config: GenerativeConfig,
}

impl VAEDecoder {
    pub fn new(config: GenerativeConfig, rng: &mut SmallRng) -> StdResult<Self> {
        // Project latent to feature space
        let mut latent_projection = Sequential::new();
        latent_projection.add(Dense::new(
            config.latent_dim,
            config.hidden_dims[0],
            Some("relu"),
            rng,
        )?);
        latent_projection.add(Dense::new(
            config.hidden_dims[0],
            128 * 4 * 4,
            Some("relu"),
            rng,
        )?);

        // Feature reconstruction layers (simplified transpose convolutions)
        let mut feature_layers = Sequential::new();
        feature_layers.add(Conv2D::new(
            128,
            64,
            (3, 3),
            (1, 1),
            PaddingMode::Same,
            rng,
        )?);
        feature_layers.add(BatchNorm::new(64, 1e-5, 0.1, rng)?);

        feature_layers.add(Conv2D::new(64, 32, (3, 3), (1, 1), PaddingMode::Same, rng)?);
        feature_layers.add(BatchNorm::new(32, 1e-5, 0.1, rng)?);

        // Output layer
        let output_conv = Conv2D::new(32, 1, (3, 3), (1, 1), PaddingMode::Same, rng)?;

        Ok(Self {
            latent_projection,
            feature_layers,
            output_conv,
            config,
        })
    }

    pub fn forward(&self, latent: &ArrayD<f32>) -> StdResult<ArrayD<f32>> {
        // Project latent to feature space
        let projected = self.latent_projection.forward(latent)?;

        // Reshape to spatial format
        let batch_size = projected.shape()[0];
        let reshaped = projected.into_shape_with_order(IxDyn(&[batch_size, 128, 4, 4]))?;

        // Upsample to target size (nearest-neighbor)
        let upsampled = self.upsample(&reshaped)?;

        // Apply feature layers
        let features = self.feature_layers.forward(&upsampled)?;

        // Generate output
        let output = self.output_conv.forward(&features)?;

        Ok(output)
    }

    /// Nearest-neighbor upsampling from the 4x4 feature map to the configured input size
    fn upsample(&self, input: &ArrayD<f32>) -> StdResult<ArrayD<f32>> {
        let shape = input.shape();
        let batch_size = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];

        let (target_height, target_width) = self.config.input_size;
        let scale_h = target_height / height;
        let scale_w = target_width / width;

        let mut upsampled =
            Array4::<f32>::zeros((batch_size, channels, target_height, target_width));

        // Replicate each source pixel into a scale_h x scale_w block
        for b in 0..batch_size {
            for c in 0..channels {
                for i in 0..height {
                    for j in 0..width {
                        let value = input[[b, c, i, j]];
                        for di in 0..scale_h {
                            for dj in 0..scale_w {
                                let new_i = i * scale_h + di;
                                let new_j = j * scale_w + dj;
                                if new_i < target_height && new_j < target_width {
                                    upsampled[[b, c, new_i, new_j]] = value;
                                }
                            }
                        }
                    }
                }
            }
        }

        Ok(upsampled.into_dyn())
    }
}

/// Complete VAE model
pub struct VAEModel {
    encoder: VAEEncoder,
    decoder: VAEDecoder,
    config: GenerativeConfig,
}

impl VAEModel {
    pub fn new(config: GenerativeConfig, rng: &mut SmallRng) -> StdResult<Self> {
        let encoder = VAEEncoder::new(config.clone(), rng)?;
        let decoder = VAEDecoder::new(config.clone(), rng)?;

        Ok(Self {
            encoder,
            decoder,
            config,
        })
    }

    pub fn forward(
        &self,
        input: &ArrayD<f32>,
    ) -> StdResult<(ArrayD<f32>, ArrayD<f32>, ArrayD<f32>)> {
        // Encode
        let (mean, logvar) = self.encoder.forward(input)?;

        // Reparameterization trick (fixed-seed epsilon for reproducibility)
        let latent = self.reparameterize(&mean, &logvar)?;

        // Decode
        let reconstruction = self.decoder.forward(&latent)?;

        Ok((reconstruction, mean, logvar))
    }

    fn reparameterize(&self, mean: &ArrayD<f32>, logvar: &ArrayD<f32>) -> StdResult<ArrayD<f32>> {
        // Sample epsilon ~ N(0, 1) via the Box-Muller transform
        let mut epsilon = Array::zeros(mean.raw_dim());
        let mut rng = SmallRng::seed_from_u64(42); // Fixed seed for reproducibility

        for elem in epsilon.iter_mut() {
            let u1: f32 = rng.random_range(f32::EPSILON..1.0);
            let u2: f32 = rng.random_range(0.0..1.0);
            *elem = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
        }

        // z = mean + std * epsilon, where std = exp(0.5 * logvar)
        let mut result = Array::zeros(mean.raw_dim());
        for (((&m, &lv), &eps), res) in mean
            .iter()
            .zip(logvar.iter())
            .zip(epsilon.iter())
            .zip(result.iter_mut())
        {
            let std = (0.5 * lv).exp();
            *res = m + std * eps;
        }

        Ok(result)
    }

    /// Generate new samples from random latent codes
    pub fn generate(&self, batch_size: usize) -> StdResult<ArrayD<f32>> {
        // Sample random latent codes (uniform here, as a rough stand-in for the N(0, 1) prior)
        let mut latent = Array2::<f32>::zeros((batch_size, self.config.latent_dim));
        let mut rng = SmallRng::seed_from_u64(123);

        for elem in latent.iter_mut() {
            *elem = rng.random_range(-1.0..1.0);
        }

        let latent_dyn = latent.into_dyn();

        // Decode to generate images
        self.decoder.forward(&latent_dyn)
    }

    /// Interpolate between two latent codes
    pub fn interpolate(
        &self,
        latent1: &ArrayD<f32>,
        latent2: &ArrayD<f32>,
        steps: usize,
    ) -> StdResult<Vec<ArrayD<f32>>> {
        let mut results = Vec::new();

        for i in 0..steps {
            // Guard against division by zero when a single step is requested
            let alpha = if steps > 1 {
                i as f32 / (steps - 1) as f32
            } else {
                0.0
            };

            // Linear interpolation
            let mut interpolated = Array::zeros(latent1.raw_dim());
            for ((&l1, &l2), interp) in latent1
                .iter()
                .zip(latent2.iter())
                .zip(interpolated.iter_mut())
            {
                *interp = (1.0 - alpha) * l1 + alpha * l2;
            }

            let generated = self.decoder.forward(&interpolated)?;
            results.push(generated);
        }

        Ok(results)
    }
}

/// VAE Loss combining reconstruction and KL divergence
pub struct VAELoss {
    reconstruction_loss: MeanSquaredError,
    beta: f32,
}

impl VAELoss {
    pub fn new(beta: f32) -> Self {
        Self {
            reconstruction_loss: MeanSquaredError::new(),
            beta,
        }
    }

    pub fn compute_loss(
        &self,
        reconstruction: &ArrayD<f32>,
        target: &ArrayD<f32>,
        mean: &ArrayD<f32>,
        logvar: &ArrayD<f32>,
    ) -> StdResult<(f32, f32, f32)> {
        // Reconstruction loss
        let recon_loss = self.reconstruction_loss.forward(reconstruction, target)?;

        // KL divergence loss: -0.5 * sum(1 + logvar - mean^2 - exp(logvar))
        let mut kl_loss = 0.0f32;
        for (&m, &lv) in mean.iter().zip(logvar.iter()) {
            kl_loss += -0.5 * (1.0 + lv - m * m - lv.exp());
        }
        kl_loss /= mean.len() as f32; // Average over elements

        let total_loss = recon_loss + self.beta * kl_loss;

        Ok((total_loss, recon_loss, kl_loss))
    }
}

/// Simple GAN Generator
pub struct GANGenerator {
    layers: Sequential<f32>,
    config: GenerativeConfig,
}

impl GANGenerator {
    pub fn new(config: GenerativeConfig, rng: &mut SmallRng) -> StdResult<Self> {
        let mut layers = Sequential::new();

        // Project noise to feature space
        layers.add(Dense::new(
            config.latent_dim,
            config.hidden_dims[0],
            Some("relu"),
            rng,
        )?);
        layers.add(BatchNorm::new(config.hidden_dims[0], 1e-5, 0.1, rng)?);

        layers.add(Dense::new(
            config.hidden_dims[0],
            config.hidden_dims[1] * 2,
            Some("relu"),
            rng,
        )?);
        layers.add(BatchNorm::new(config.hidden_dims[1] * 2, 1e-5, 0.1, rng)?);

        // Output layer
        let output_size = config.input_size.0 * config.input_size.1;
        layers.add(Dense::new(
            config.hidden_dims[1] * 2,
            output_size,
            Some("tanh"),
            rng,
        )?);

        Ok(Self { layers, config })
    }

    pub fn forward(&self, noise: &ArrayD<f32>) -> StdResult<ArrayD<f32>> {
        let output = self.layers.forward(noise)?;

        // Reshape to image format
        let batch_size = output.shape()[0];
        let (height, width) = self.config.input_size;
        let reshaped = output
            .to_shape(IxDyn(&[batch_size, 1, height, width]))?
            .to_owned();

        Ok(reshaped)
    }
}

/// Simple GAN Discriminator
pub struct GANDiscriminator {
    layers: Sequential<f32>,
    config: GenerativeConfig,
}

impl GANDiscriminator {
    pub fn new(config: GenerativeConfig, rng: &mut SmallRng) -> StdResult<Self> {
        let mut layers = Sequential::new();

        let input_size = config.input_size.0 * config.input_size.1;

        layers.add(Dense::new(
            input_size,
            config.hidden_dims[0],
            Some("relu"),
            rng,
        )?);
        layers.add(Dropout::new(0.3, rng)?);

        layers.add(Dense::new(
            config.hidden_dims[0],
            config.hidden_dims[1],
            Some("relu"),
            rng,
        )?);
        layers.add(Dropout::new(0.3, rng)?);

        // Output probability of being real
        layers.add(Dense::new(config.hidden_dims[1], 1, Some("sigmoid"), rng)?);

        Ok(Self { layers, config })
    }

    pub fn forward(&self, input: &ArrayD<f32>) -> StdResult<ArrayD<f32>> {
        // Flatten input
        let batch_size = input.shape()[0];
        let input_size = self.config.input_size.0 * self.config.input_size.1;
        let flattened = input.to_shape(IxDyn(&[batch_size, input_size]))?.to_owned();

        Ok(self.layers.forward(&flattened)?)
    }
}

/// Generative model evaluation metrics
pub struct GenerativeMetrics {
    #[allow(dead_code)]
    config: GenerativeConfig,
}

impl GenerativeMetrics {
    pub fn new(config: GenerativeConfig) -> Self {
        Self { config }
    }

    /// Calculate reconstruction error
    pub fn reconstruction_error(&self, original: &ArrayD<f32>, reconstructed: &ArrayD<f32>) -> f32 {
        let mut mse = 0.0f32;
        let mut count = 0;

        for (&orig, &recon) in original.iter().zip(reconstructed.iter()) {
            let diff = orig - recon;
            mse += diff * diff;
            count += 1;
        }

        if count > 0 {
            mse / count as f32
        } else {
            0.0
        }
    }

    /// Calculate sample diversity (simplified variance measure)
    pub fn sample_diversity(&self, samples: &ArrayD<f32>) -> f32 {
        let batch_size = samples.shape()[0];
        if batch_size < 2 {
            return 0.0;
        }

        let sample_size = samples.len() / batch_size;
        let flat: Vec<f32> = samples.iter().copied().collect();
        let mut total_variance = 0.0f32;

        // Average the per-position variance across the batch
        for i in 0..sample_size {
            let values: Vec<f32> = (0..batch_size).map(|b| flat[b * sample_size + i]).collect();

            let mean = values.iter().sum::<f32>() / values.len() as f32;
            let variance = values.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>()
                / values.len() as f32;
            total_variance += variance;
        }

        total_variance / sample_size as f32
    }
}

/// Training function for VAE
fn train_vae_model() -> StdResult<()> {
    println!("šŸŽØ Starting VAE Training");

    let mut rng = SmallRng::seed_from_u64(42);
    let config = GenerativeConfig::default();

    // Initialize JIT context (currently not implemented)
    // let _jit_context: JitContext<f32> = JitContext::new(JitStrategy::Aggressive);
    // println!("šŸš€ JIT compilation enabled with aggressive strategy");

    // Create model
    println!("šŸ—ļø Building VAE model...");
    let vae = VAEModel::new(config.clone(), &mut rng)?;
    println!("āœ… VAE created with latent dimension {}", config.latent_dim);

    // Create dataset
    let mut dataset = GenerativeDataset::new(config.clone(), 123);

    // Create loss function
    let loss_fn = VAELoss::new(config.beta);

    // Create metrics
    let metrics = GenerativeMetrics::new(config.clone());

    println!("šŸ“Š Training configuration:");
    println!("   - Input size: {:?}", config.input_size);
    println!("   - Latent dimension: {}", config.latent_dim);
    println!("   - Beta (KL weight): {}", config.beta);
    println!("   - Hidden dimensions: {:?}", config.hidden_dims);

    // Training loop. Note: this example only runs forward passes and tracks the
    // losses; no backpropagation or parameter updates are performed.
    let num_epochs = 20;
    let batch_size = 4;
    let _learning_rate = 0.001;

    for epoch in 0..num_epochs {
        println!("\nšŸ“ˆ Epoch {}/{}", epoch + 1, num_epochs);

        let mut epoch_total_loss = 0.0;
        let mut epoch_recon_loss = 0.0;
        let mut epoch_kl_loss = 0.0;
        let num_batches = 12;

        for batch_idx in 0..num_batches {
            // Generate training batch
            let images = dataset.generate_batch(batch_size);
            let images_dyn = images.into_dyn();

            // Forward pass
            let (reconstruction, mean, logvar) = vae.forward(&images_dyn)?;

            // Compute loss
            let (total_loss, recon_loss, kl_loss) =
                loss_fn.compute_loss(&reconstruction, &images_dyn, &mean, &logvar)?;

            epoch_total_loss += total_loss;
            epoch_recon_loss += recon_loss;
            epoch_kl_loss += kl_loss;

            if batch_idx % 6 == 0 {
                print!(
                    "šŸ”„ Batch {}/{} - Total: {:.4}, Recon: {:.4}, KL: {:.4}           \r",
                    batch_idx + 1,
                    num_batches,
                    total_loss,
                    recon_loss,
                    kl_loss
                );
            }
        }

        let avg_total = epoch_total_loss / num_batches as f32;
        let avg_recon = epoch_recon_loss / num_batches as f32;
        let avg_kl = epoch_kl_loss / num_batches as f32;

        println!(
            "āœ… Epoch {} - Total: {:.4}, Recon: {:.4}, KL: {:.4}",
            epoch + 1,
            avg_total,
            avg_recon,
            avg_kl
        );

        // Evaluation and generation every few epochs
        if (epoch + 1) % 5 == 0 {
            println!("šŸ” Running evaluation and generation...");

            // Test reconstruction
            let test_images = dataset.generate_batch(batch_size);
            let test_images_dyn = test_images.into_dyn();
            let (test_reconstruction, _, _) = vae.forward(&test_images_dyn)?;

            let recon_error = metrics.reconstruction_error(&test_images_dyn, &test_reconstruction);
            println!("šŸ“Š Reconstruction MSE: {:.6}", recon_error);

            // Generate new samples
            let generated_samples = vae.generate(8)?;
            let diversity = metrics.sample_diversity(&generated_samples);
            println!("šŸŽ² Sample diversity: {:.6}", diversity);

            // Test interpolation
            let latent1 = Array2::<f32>::from_elem((1, config.latent_dim), -1.0).into_dyn();
            let latent2 = Array2::<f32>::from_elem((1, config.latent_dim), 1.0).into_dyn();
            let interpolated = vae.interpolate(&latent1, &latent2, 5)?;

            println!("šŸ”„ Generated {} interpolated samples", interpolated.len());
        }
    }

    println!("\nšŸŽ‰ VAE training completed!");
    Ok(())
}

/// Training function for simple GAN
fn train_gan_model() -> StdResult<()> {
    println!("āš”ļø Starting GAN Training");

    let mut rng = SmallRng::seed_from_u64(42);
    let config = GenerativeConfig::default();

    // Create models
    println!("šŸ—ļø Building GAN models...");
    let generator = GANGenerator::new(config.clone(), &mut rng)?;
    let discriminator = GANDiscriminator::new(config.clone(), &mut rng)?;
    println!("āœ… GAN models created");

    // Create dataset
    let mut dataset = GenerativeDataset::new(config.clone(), 456);

    // Create loss functions (the adversarial loss below is computed by hand)
    let _adversarial_loss = CrossEntropyLoss::new(1e-7);

    println!("šŸ“Š GAN training configuration:");
    println!("   - Generator latent dim: {}", config.latent_dim);
    println!("   - Discriminator architecture: {:?}", config.hidden_dims);

    // Training loop (simplified: forward passes only, no parameter updates)
    let num_epochs = 15;
    let batch_size = 4;

    for epoch in 0..num_epochs {
        println!("\nšŸ“ˆ Epoch {}/{}", epoch + 1, num_epochs);

        let mut d_loss_total = 0.0;
        let mut g_loss_total = 0.0;
        let num_batches = 8;

        for batch_idx in 0..num_batches {
            // Train Discriminator
            let real_images = dataset.generate_batch(batch_size);
            let real_images_dyn = real_images.into_dyn();

            // Generate fake images
            let mut noise = Array2::<f32>::zeros((batch_size, config.latent_dim));
            for elem in noise.iter_mut() {
                *elem = rng.random_range(-1.0..1.0);
            }
            let noise_dyn = noise.into_dyn();
            let fake_images = generator.forward(&noise_dyn)?;

            // Discriminator predictions
            let real_pred = discriminator.forward(&real_images_dyn)?;
            let fake_pred = discriminator.forward(&fake_images)?;

            // Binary cross-entropy terms (labels: real = 1, fake = 0), with a
            // small epsilon to keep the logarithms finite
            let eps = 1e-7f32;
            let mut d_loss_real = 0.0f32;
            let mut d_loss_fake = 0.0f32;

            for &pred in real_pred.iter() {
                d_loss_real += -pred.max(eps).ln(); // -log D(x) for real samples
            }

            for &pred in fake_pred.iter() {
                d_loss_fake += -(1.0 - pred).max(eps).ln(); // -log(1 - D(G(z))) for fakes
            }

            let d_loss = (d_loss_real + d_loss_fake) / (batch_size * 2) as f32;
            d_loss_total += d_loss;

            // Train Generator (simplified): the generator wants D(G(z)) -> 1
            let fake_pred_for_g = discriminator.forward(&fake_images)?;
            let mut g_loss = 0.0f32;

            for &pred in fake_pred_for_g.iter() {
                g_loss += -pred.max(eps).ln(); // non-saturating loss -log D(G(z))
            }
            g_loss /= batch_size as f32;
            g_loss_total += g_loss;

            if batch_idx % 4 == 0 {
                print!(
                    "šŸ”„ Batch {}/{} - D Loss: {:.4}, G Loss: {:.4}        \r",
                    batch_idx + 1,
                    num_batches,
                    d_loss,
                    g_loss
                );
            }
        }

        let avg_d_loss = d_loss_total / num_batches as f32;
        let avg_g_loss = g_loss_total / num_batches as f32;

        println!(
            "āœ… Epoch {} - D Loss: {:.4}, G Loss: {:.4}",
            epoch + 1,
            avg_d_loss,
            avg_g_loss
        );

        // Generate samples every few epochs
        if (epoch + 1) % 5 == 0 {
            println!("šŸŽ² Generating samples...");

            let mut sample_noise = Array2::<f32>::zeros((4, config.latent_dim));
            for elem in sample_noise.iter_mut() {
                *elem = rng.random_range(-1.0..1.0);
            }
            let sample_noise_dyn = sample_noise.into_dyn();
            let generated = generator.forward(&sample_noise_dyn)?;

            println!("šŸ“Š Generated {} samples", generated.shape()[0]);
        }
    }

    println!("\nšŸŽ‰ GAN training completed!");
    Ok(())
}

fn main() -> StdResult<()> {
    println!("šŸŽØ Generative Models Complete Example");
    println!("=====================================");
    println!();
    println!("This example demonstrates:");
    println!("• Variational Autoencoder (VAE) implementation");
    println!("• Generative Adversarial Network (GAN) basics");
    println!("• Synthetic pattern dataset generation");
    println!("• VAE loss (reconstruction + KL divergence)");
    println!("• Latent space interpolation");
    println!("• Sample generation and evaluation");
    println!();

    // Train VAE
    train_vae_model()?;

    println!("\n{}", "=".repeat(50));

    // Train GAN
    train_gan_model()?;

    println!("\nšŸ’” Key Concepts Demonstrated:");
    println!("   šŸ”¹ Variational inference and reparameterization trick");
    println!("   šŸ”¹ KL divergence regularization");
    println!("   šŸ”¹ Adversarial training dynamics");
    println!("   šŸ”¹ Latent space manipulation");
    println!("   šŸ”¹ Reconstruction vs generation quality");
    println!("   šŸ”¹ Sample diversity metrics");
    println!();
    println!("šŸš€ For production use:");
    println!("   • Implement β-VAE, WAE, or other VAE variants");
    println!("   • Add convolutional layers for better image modeling");
    println!("   • Implement DCGAN, StyleGAN, or other advanced GANs");
    println!("   • Add progressive training and spectral normalization");
    println!("   • Use FID, IS, or other advanced evaluation metrics");
    println!("   • Implement conditional generation (cVAE, cGAN)");
    println!("   • Add attention mechanisms and self-attention");

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generative_config() {
        let config = GenerativeConfig::default();
        assert_eq!(config.input_size, (32, 32));
        assert_eq!(config.latent_dim, 16);
        assert_eq!(config.beta, 1.0);
        assert!(!config.hidden_dims.is_empty());
    }

    #[test]
    fn test_dataset_generation() {
        let config = GenerativeConfig::default();
        let mut dataset = GenerativeDataset::new(config.clone(), 42);

        let image = dataset.generate_sample();
        assert_eq!(
            image.shape(),
            &[1, config.input_size.0, config.input_size.1]
        );

        // Check values are in valid range
        for &val in image.iter() {
            assert!((0.0..=1.0).contains(&val));
        }
    }

    #[test]
    fn test_vae_creation() -> StdResult<()> {
        let mut rng = SmallRng::seed_from_u64(42);
        let config = GenerativeConfig::default();

        let vae = VAEModel::new(config.clone(), &mut rng)?;

        // Test forward pass
        let batch_size = 2;
        let input = Array4::<f32>::ones((batch_size, 1, config.input_size.0, config.input_size.1))
            .into_dyn();
        let (reconstruction, mean, logvar) = vae.forward(&input)?;

        assert_eq!(reconstruction.shape()[0], batch_size);
        assert_eq!(mean.shape()[1], config.latent_dim);
        assert_eq!(logvar.shape()[1], config.latent_dim);

        Ok(())
    }

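    // A few extra sanity checks on the VAE internals, sketched against this
    // file's own helpers: the reparameterization z = mean + std * epsilon should
    // collapse onto the mean as std -> 0, nearest-neighbor upsampling should
    // replicate each source pixel into a block, and interpolation should yield
    // one decoded image per step.
    #[test]
    fn test_reparameterization_collapses_to_mean() -> StdResult<()> {
        let mut rng = SmallRng::seed_from_u64(42);
        let config = GenerativeConfig::default();
        let vae = VAEModel::new(config.clone(), &mut rng)?;

        // logvar = -40 gives std = exp(-20) ~ 2e-9, so z should be ~= mean
        let mean = Array2::<f32>::from_elem((2, config.latent_dim), 0.7).into_dyn();
        let logvar = Array2::<f32>::from_elem((2, config.latent_dim), -40.0).into_dyn();
        let z = vae.reparameterize(&mean, &logvar)?;
        for &v in z.iter() {
            assert!((v - 0.7).abs() < 1e-4);
        }
        Ok(())
    }

    #[test]
    fn test_nearest_neighbor_upsample() -> StdResult<()> {
        let mut rng = SmallRng::seed_from_u64(42);
        let config = GenerativeConfig::default();
        let decoder = VAEDecoder::new(config.clone(), &mut rng)?;

        // With a 32x32 target and a 4x4 source, each pixel maps to an 8x8 block
        let mut input = Array4::<f32>::zeros((1, 128, 4, 4));
        input[[0, 0, 1, 2]] = 3.0;
        let up = decoder.upsample(&input.into_dyn())?;
        assert_eq!(up.shape(), &[1, 128, 32, 32]);
        assert_eq!(up[[0, 0, 8, 16]], 3.0); // top-left corner of the block
        assert_eq!(up[[0, 0, 15, 23]], 3.0); // bottom-right corner of the block
        assert_eq!(up[[0, 0, 7, 16]], 0.0); // just outside the block
        Ok(())
    }

    #[test]
    fn test_latent_interpolation() -> StdResult<()> {
        let mut rng = SmallRng::seed_from_u64(42);
        let config = GenerativeConfig::default();
        let vae = VAEModel::new(config.clone(), &mut rng)?;

        let latent1 = Array2::<f32>::from_elem((1, config.latent_dim), -1.0).into_dyn();
        let latent2 = Array2::<f32>::from_elem((1, config.latent_dim), 1.0).into_dyn();
        let frames = vae.interpolate(&latent1, &latent2, 5)?;

        // One decoded image per interpolation step, at the configured size
        assert_eq!(frames.len(), 5);
        for frame in &frames {
            assert_eq!(frame.shape()[2], config.input_size.0);
            assert_eq!(frame.shape()[3], config.input_size.1);
        }
        Ok(())
    }
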
    #[test]
    fn test_gan_creation() -> StdResult<()> {
        let mut rng = SmallRng::seed_from_u64(42);
        let config = GenerativeConfig::default();

        let generator = GANGenerator::new(config.clone(), &mut rng)?;
        let discriminator = GANDiscriminator::new(config.clone(), &mut rng)?;

        // Test generator
        let batch_size = 2;
        let noise = Array2::<f32>::ones((batch_size, config.latent_dim)).into_dyn();
        let generated = generator.forward(&noise)?;

        assert_eq!(generated.shape()[0], batch_size);
        assert_eq!(generated.shape()[1], 1);

        // Test discriminator
        let pred = discriminator.forward(&generated)?;
        assert_eq!(pred.shape()[0], batch_size);
        assert_eq!(pred.shape()[1], 1);

        Ok(())
    }

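    // The discriminator ends in a sigmoid, so its scores should always be valid
    // probabilities; a quick property check on a constant gray batch.
    #[test]
    fn test_discriminator_output_range() -> StdResult<()> {
        let mut rng = SmallRng::seed_from_u64(7);
        let config = GenerativeConfig::default();
        let discriminator = GANDiscriminator::new(config.clone(), &mut rng)?;

        let images =
            Array4::<f32>::from_elem((3, 1, config.input_size.0, config.input_size.1), 0.5)
                .into_dyn();
        let pred = discriminator.forward(&images)?;

        assert_eq!(pred.shape(), &[3, 1]);
        for &p in pred.iter() {
            assert!((0.0..=1.0).contains(&p));
        }
        Ok(())
    }
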
    #[test]
    fn test_vae_loss() -> StdResult<()> {
        let loss_fn = VAELoss::new(1.0);

        let reconstruction = Array2::<f32>::ones((2, 10)).into_dyn();
        let target = Array2::<f32>::zeros((2, 10)).into_dyn();
        let mean = Array2::<f32>::zeros((2, 5)).into_dyn();
        let logvar = Array2::<f32>::zeros((2, 5)).into_dyn();

        let (total_loss, recon_loss, kl_loss) =
            loss_fn.compute_loss(&reconstruction, &target, &mean, &logvar)?;

        assert!(total_loss > 0.0);
        assert!(recon_loss > 0.0);
        // KL loss should be 0 for mean=0, logvar=0
        assert!(kl_loss.abs() < 1e-6);

        Ok(())
    }

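    // For logvar = 0 the KL term reduces to 0.5 * mean^2 per element, so a unit
    // mean everywhere should give an average KL of exactly 0.5. This complements
    // test_vae_loss, which checks the KL-free case.
    #[test]
    fn test_vae_loss_kl_nonzero_mean() -> StdResult<()> {
        let loss_fn = VAELoss::new(1.0);

        let reconstruction = Array2::<f32>::zeros((2, 10)).into_dyn();
        let target = Array2::<f32>::zeros((2, 10)).into_dyn();
        let mean = Array2::<f32>::ones((2, 5)).into_dyn();
        let logvar = Array2::<f32>::zeros((2, 5)).into_dyn();

        let (total_loss, recon_loss, kl_loss) =
            loss_fn.compute_loss(&reconstruction, &target, &mean, &logvar)?;

        assert!(recon_loss.abs() < 1e-6); // identical reconstruction and target
        assert!((kl_loss - 0.5).abs() < 1e-6);
        assert!((total_loss - 0.5).abs() < 1e-6); // beta = 1.0
        Ok(())
    }
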
    #[test]
    fn test_generative_metrics() {
        let config = GenerativeConfig::default();
        let metrics = GenerativeMetrics::new(config);

        // Test reconstruction error
        let original = Array2::<f32>::ones((2, 10)).into_dyn();
        let reconstructed = Array2::<f32>::zeros((2, 10)).into_dyn();

        let error = metrics.reconstruction_error(&original, &reconstructed);
        assert_eq!(error, 1.0); // MSE between all 1s and all 0s

        // Test diversity
        let samples = Array2::<f32>::from_shape_fn((3, 4), |(i, j)| i as f32 + j as f32).into_dyn();
        let diversity = metrics.sample_diversity(&samples);
        assert!(diversity > 0.0);
    }
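
    // Identical samples have zero variance at every position, so the diversity
    // measure should vanish; this pins down the metric's lower bound.
    #[test]
    fn test_sample_diversity_identical_samples() {
        let config = GenerativeConfig::default();
        let metrics = GenerativeMetrics::new(config);

        let samples = Array2::<f32>::ones((3, 4)).into_dyn();
        let diversity = metrics.sample_diversity(&samples);
        assert!(diversity.abs() < 1e-9);
    }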
}