advanced_optimizers_example/
advanced_optimizers_example.rs

use ndarray::Array2;
use rand::prelude::*;
use rand::rngs::SmallRng;
use rand::SeedableRng;
use scirs2_neural::layers::{Dense, Dropout};
use scirs2_neural::losses::CrossEntropyLoss;
use scirs2_neural::models::{Model, Sequential};
use scirs2_neural::optimizers::{Adam, AdamW, Optimizer, RAdam, RMSprop, SGD};
use std::time::Instant;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("Advanced Optimizers Example");

    // Initialize random number generator
    let mut rng = SmallRng::seed_from_u64(42);

    // Create a synthetic binary classification dataset
    let num_samples = 1000;
    let num_features = 20;
    let num_classes = 2;

    println!(
        "Generating synthetic dataset with {} samples, {} features...",
        num_samples, num_features
    );

    // Generate random input features
    let mut x_data = Array2::<f32>::zeros((num_samples, num_features));
    for i in 0..num_samples {
        for j in 0..num_features {
            x_data[[i, j]] = rng.random_range(-1.0..1.0);
        }
    }

    // Create true weights and bias for data generation
    let mut true_weights = Array2::<f32>::zeros((num_features, 1));
    for i in 0..num_features {
        true_weights[[i, 0]] = rng.random_range(-1.0..1.0);
    }
    let true_bias = rng.random_range(-1.0..1.0);

    // Generate binary labels (0 or 1) based on linear model with logistic function
    let mut y_data = Array2::<f32>::zeros((num_samples, num_classes));
    for i in 0..num_samples {
        let mut logit = true_bias;
        for j in 0..num_features {
            logit += x_data[[i, j]] * true_weights[[j, 0]];
        }

        // Apply sigmoid to get probability
        let prob = 1.0 / (1.0 + (-logit).exp());

        // Convert to one-hot encoding
        if prob > 0.5 {
            y_data[[i, 1]] = 1.0; // Class 1
        } else {
            y_data[[i, 0]] = 1.0; // Class 0
        }
    }

    // Split into train and test sets (80% train, 20% test)
    let train_size = (num_samples as f32 * 0.8) as usize;
    let test_size = num_samples - train_size;

    let x_train = x_data.slice(ndarray::s![0..train_size, ..]).to_owned();
    let y_train = y_data.slice(ndarray::s![0..train_size, ..]).to_owned();
    let x_test = x_data.slice(ndarray::s![train_size.., ..]).to_owned();
    let y_test = y_data.slice(ndarray::s![train_size.., ..]).to_owned();

    println!("Training set: {} samples", train_size);
    println!("Test set: {} samples", test_size);

    // Create a simple neural network model
    let hidden_size = 64;
    let dropout_rate = 0.2;
    let seed_rng = SmallRng::seed_from_u64(42);
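    // Cloning this same seeded RNG for each layer below means every model starts from identical initial weights.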

    // Shared function to create identical model architectures for fair comparison
    let create_model = || -> Result<Sequential<f32>, Box<dyn std::error::Error>> {
        let mut model = Sequential::new();

        // Input to hidden layer
        let dense1 = Dense::new(
            num_features,
            hidden_size,
            Some("relu"),
            &mut seed_rng.clone(),
        )?;
        model.add_layer(dense1);

        // Dropout for regularization
        let dropout = Dropout::new(dropout_rate, &mut seed_rng.clone())?;
        model.add_layer(dropout);

        // Hidden to output layer
        let dense2 = Dense::new(
            hidden_size,
            num_classes,
            Some("softmax"),
            &mut seed_rng.clone(),
        )?;
        model.add_layer(dense2);

        Ok(model)
    };

    // Create models for each optimizer
    let mut sgd_model = create_model()?;
    let mut adam_model = create_model()?;
    let mut adamw_model = create_model()?;
    let mut radam_model = create_model()?;
    let mut rmsprop_model = create_model()?;

    // Create the loss function
    let loss_fn = CrossEntropyLoss::new(1e-10);
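    // The 1e-10 argument is assumed to be a small numerical-stability epsilon that keeps the loss away from log(0).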

    // Create optimizers
    let learning_rate = 0.001;
    let batch_size = 32;
    let epochs = 20;

    let mut sgd_optimizer = SGD::new_with_config(learning_rate, 0.9, 0.0);
    let mut adam_optimizer = Adam::new(learning_rate, 0.9, 0.999, 1e-8);
    let mut adamw_optimizer = AdamW::new(learning_rate, 0.9, 0.999, 1e-8, 0.01);
    let mut radam_optimizer = RAdam::new(learning_rate, 0.9, 0.999, 1e-8, 0.0);
    let mut rmsprop_optimizer = RMSprop::new_with_config(learning_rate, 0.9, 1e-8, 0.0);
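    // The extra constructor arguments above are assumed to follow the conventional hyperparameter order:
    // momentum and weight decay for SGD; beta1, beta2, and epsilon for the Adam family (plus weight decay
    // for AdamW and RAdam); and smoothing factor, epsilon, and weight decay for RMSprop.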

    // Helper function to compute accuracy
    let compute_accuracy = |model: &Sequential<f32>, x: &Array2<f32>, y: &Array2<f32>| -> f32 {
        let predictions = model.forward(&x.clone().into_dyn()).unwrap();
        let mut correct = 0;

        for i in 0..x.shape()[0] {
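            // The predicted class is the argmax of the model output for this sample.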
            let mut max_idx = 0;
            let mut max_val = predictions[[i, 0]];

            for j in 1..num_classes {
                if predictions[[i, j]] > max_val {
                    max_val = predictions[[i, j]];
                    max_idx = j;
                }
            }

            // The labels are one-hot, so the true class is 1 when y[[i, 1]] is the larger entry.
            let true_idx = if y[[i, 1]] > y[[i, 0]] { 1 } else { 0 };
            if max_idx == true_idx {
                correct += 1;
            }
        }

        correct as f32 / x.shape()[0] as f32
    };

    // Helper function to train model
    let mut train_model =
        |model: &mut Sequential<f32>, optimizer: &mut dyn Optimizer<f32>, name: &str| -> Vec<f32> {
            println!("\nTraining with {} optimizer...", name);
            let start_time = Instant::now();

            let mut train_losses = Vec::new();
            let num_batches = train_size.div_ceil(batch_size);
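            // div_ceil rounds up, so a final partial batch is still processed when train_size is not a multiple of batch_size.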

            for epoch in 0..epochs {
                let mut epoch_loss = 0.0;

                // Create a permutation for shuffling the data
                let mut indices: Vec<usize> = (0..train_size).collect();
                indices.shuffle(&mut rng);

                for batch_idx in 0..num_batches {
                    let start = batch_idx * batch_size;
                    let end = (start + batch_size).min(train_size);
                    let batch_indices = &indices[start..end];

                    // Create batch data
                    let mut x_batch = Array2::<f32>::zeros((batch_indices.len(), num_features));
                    let mut y_batch = Array2::<f32>::zeros((batch_indices.len(), num_classes));

                    for (i, &idx) in batch_indices.iter().enumerate() {
                        for j in 0..num_features {
                            x_batch[[i, j]] = x_train[[idx, j]];
                        }
                        for j in 0..num_classes {
                            y_batch[[i, j]] = y_train[[idx, j]];
                        }
                    }

                    // Convert to dynamic dimension arrays
                    let x_batch_dyn = x_batch.into_dyn();
                    let y_batch_dyn = y_batch.into_dyn();

                    // Perform a training step
                    let batch_loss = model
                        .train_batch(&x_batch_dyn, &y_batch_dyn, &loss_fn, optimizer)
                        .unwrap();
                    epoch_loss += batch_loss;
                }

                epoch_loss /= num_batches as f32;
                train_losses.push(epoch_loss);

                // Calculate and print metrics every few epochs
                if epoch % 5 == 0 || epoch == epochs - 1 {
                    let train_accuracy = compute_accuracy(model, &x_train, &y_train);
                    let test_accuracy = compute_accuracy(model, &x_test, &y_test);

                    println!(
                        "Epoch {}/{}: loss = {:.6}, train_acc = {:.2}%, test_acc = {:.2}%",
                        epoch + 1,
                        epochs,
                        epoch_loss,
                        train_accuracy * 100.0,
                        test_accuracy * 100.0
                    );
                }
            }

            let elapsed = start_time.elapsed();
            println!("{} training completed in {:.2?}", name, elapsed);

            // Final evaluation
            let train_accuracy = compute_accuracy(model, &x_train, &y_train);
            let test_accuracy = compute_accuracy(model, &x_test, &y_test);
            println!("Final metrics for {}:", name);
            println!("  Train accuracy: {:.2}%", train_accuracy * 100.0);
            println!("  Test accuracy:  {:.2}%", test_accuracy * 100.0);

            train_losses
        };

    // Train models with different optimizers
    let sgd_losses = train_model(&mut sgd_model, &mut sgd_optimizer, "SGD");
    let adam_losses = train_model(&mut adam_model, &mut adam_optimizer, "Adam");
    let adamw_losses = train_model(&mut adamw_model, &mut adamw_optimizer, "AdamW");
    let radam_losses = train_model(&mut radam_model, &mut radam_optimizer, "RAdam");
    let rmsprop_losses = train_model(&mut rmsprop_model, &mut rmsprop_optimizer, "RMSprop");

    // Print comparison summary
    println!("\nOptimizer Comparison Summary:");
    println!("----------------------------");
    println!("Initial learning rate: {}", learning_rate);
    println!("Batch size: {}", batch_size);
    println!("Epochs: {}", epochs);
    println!();

    println!("Final Loss Values:");
    println!("  SGD:     {:.6}", sgd_losses.last().unwrap());
    println!("  Adam:    {:.6}", adam_losses.last().unwrap());
    println!("  AdamW:   {:.6}", adamw_losses.last().unwrap());
    println!("  RAdam:   {:.6}", radam_losses.last().unwrap());
    println!("  RMSprop: {:.6}", rmsprop_losses.last().unwrap());

    println!("\nLoss progression (first value, middle value, last value):");
    println!(
        "  SGD:     {:.6}, {:.6}, {:.6}",
        sgd_losses.first().unwrap(),
        sgd_losses[epochs / 2],
        sgd_losses.last().unwrap()
    );
    println!(
        "  Adam:    {:.6}, {:.6}, {:.6}",
        adam_losses.first().unwrap(),
        adam_losses[epochs / 2],
        adam_losses.last().unwrap()
    );
    println!(
        "  AdamW:   {:.6}, {:.6}, {:.6}",
        adamw_losses.first().unwrap(),
        adamw_losses[epochs / 2],
        adamw_losses.last().unwrap()
    );
    println!(
        "  RAdam:   {:.6}, {:.6}, {:.6}",
        radam_losses.first().unwrap(),
        radam_losses[epochs / 2],
        radam_losses.last().unwrap()
    );
    println!(
        "  RMSprop: {:.6}, {:.6}, {:.6}",
        rmsprop_losses.first().unwrap(),
        rmsprop_losses[epochs / 2],
        rmsprop_losses.last().unwrap()
    );

    println!("\nLoss improvement ratio (first loss / last loss):");
    println!(
        "  SGD:     {:.2}x",
        sgd_losses.first().unwrap() / sgd_losses.last().unwrap()
    );
    println!(
        "  Adam:    {:.2}x",
        adam_losses.first().unwrap() / adam_losses.last().unwrap()
    );
    println!(
        "  AdamW:   {:.2}x",
        adamw_losses.first().unwrap() / adamw_losses.last().unwrap()
    );
    println!(
        "  RAdam:   {:.2}x",
        radam_losses.first().unwrap() / radam_losses.last().unwrap()
    );
    println!(
        "  RMSprop: {:.2}x",
        rmsprop_losses.first().unwrap() / rmsprop_losses.last().unwrap()
    );

    println!("\nAdvanced optimizers demo completed successfully!");

    Ok(())
}