advanced_optimizers_example/advanced_optimizers_example.rs

use ndarray::Array2;
use rand::prelude::*;
use rand::rngs::SmallRng;
use rand::SeedableRng;
use scirs2_neural::layers::{Dense, Dropout};
use scirs2_neural::losses::CrossEntropyLoss;
use scirs2_neural::models::{Model, Sequential};
use scirs2_neural::optimizers::{Adam, AdamW, Optimizer, RAdam, RMSprop, SGD};
use std::time::Instant;

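// Train the same two-layer network on the same synthetic binary-classification data
// with five different optimizers (SGD with momentum, Adam, AdamW, RAdam, RMSprop),
// then compare their loss curves and train/test accuracy.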
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("Advanced Optimizers Example");

    let mut rng = SmallRng::seed_from_u64(42);

    let num_samples = 1000;
    let num_features = 20;
    let num_classes = 2;

    println!(
        "Generating synthetic dataset with {} samples, {} features...",
        num_samples, num_features
    );

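    // Feature matrix: each entry drawn uniformly from [-1, 1).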
    let mut x_data = Array2::<f32>::zeros((num_samples, num_features));
    for i in 0..num_samples {
        for j in 0..num_features {
            x_data[[i, j]] = rng.random_range(-1.0..1.0);
        }
    }

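    // Ground-truth linear model (random weights and bias) that defines the labels.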
    let mut true_weights = Array2::<f32>::zeros((num_features, 1));
    for i in 0..num_features {
        true_weights[[i, 0]] = rng.random_range(-1.0..1.0);
    }
    let true_bias = rng.random_range(-1.0..1.0);

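    // Labels: pass the linear score through a sigmoid and one-hot encode the class
    // (column 1 if the probability exceeds 0.5, column 0 otherwise).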
    let mut y_data = Array2::<f32>::zeros((num_samples, num_classes));
    for i in 0..num_samples {
        let mut logit = true_bias;
        for j in 0..num_features {
            logit += x_data[[i, j]] * true_weights[[j, 0]];
        }

        let prob = 1.0 / (1.0 + (-logit).exp());

        if prob > 0.5 {
            y_data[[i, 1]] = 1.0;
        } else {
            y_data[[i, 0]] = 1.0;
        }
    }

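    // 80/20 train/test split.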
    let train_size = (num_samples as f32 * 0.8) as usize;
    let test_size = num_samples - train_size;

    let x_train = x_data.slice(ndarray::s![0..train_size, ..]).to_owned();
    let y_train = y_data.slice(ndarray::s![0..train_size, ..]).to_owned();
    let x_test = x_data.slice(ndarray::s![train_size.., ..]).to_owned();
    let y_test = y_data.slice(ndarray::s![train_size.., ..]).to_owned();

    println!("Training set: {} samples", train_size);
    println!("Test set: {} samples", test_size);

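    // Model factory: Dense(20 -> 64, ReLU) -> Dropout(0.2) -> Dense(64 -> 2, softmax).
    // Cloning the seeded RNG for each layer means every call produces identically
    // initialized weights, so the optimizer runs are directly comparable.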
    let hidden_size = 64;
    let dropout_rate = 0.2;
    let seed_rng = SmallRng::seed_from_u64(42);

    let create_model = || -> Result<Sequential<f32>, Box<dyn std::error::Error>> {
        let mut model = Sequential::new();

        let dense1 = Dense::new(
            num_features,
            hidden_size,
            Some("relu"),
            &mut seed_rng.clone(),
        )?;
        model.add_layer(dense1);

        let dropout = Dropout::new(dropout_rate, &mut seed_rng.clone())?;
        model.add_layer(dropout);

        let dense2 = Dense::new(
            hidden_size,
            num_classes,
            Some("softmax"),
            &mut seed_rng.clone(),
        )?;
        model.add_layer(dense2);

        Ok(model)
    };

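    // One identically-initialized model per optimizer.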
    let mut sgd_model = create_model()?;
    let mut adam_model = create_model()?;
    let mut adamw_model = create_model()?;
    let mut radam_model = create_model()?;
    let mut rmsprop_model = create_model()?;

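    // Cross-entropy loss; the small constant is presumably an epsilon added inside the
    // log for numerical stability.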
    let loss_fn = CrossEntropyLoss::new(1e-10);

    let learning_rate = 0.001;
    let batch_size = 32;
    let epochs = 20;

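    // Optimizer configurations (same learning rate everywhere). The positional arguments
    // appear to be: momentum and weight decay for SGD; (beta1, beta2, epsilon) for the
    // Adam family, plus a decoupled weight decay of 0.01 for AdamW; and
    // (decay rate, epsilon, weight decay) for RMSprop.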
    let mut sgd_optimizer = SGD::new_with_config(learning_rate, 0.9, 0.0);
    let mut adam_optimizer = Adam::new(learning_rate, 0.9, 0.999, 1e-8);
    let mut adamw_optimizer = AdamW::new(learning_rate, 0.9, 0.999, 1e-8, 0.01);
    let mut radam_optimizer = RAdam::new(learning_rate, 0.9, 0.999, 1e-8, 0.0);
    let mut rmsprop_optimizer = RMSprop::new_with_config(learning_rate, 0.9, 1e-8, 0.0);

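    // Accuracy = fraction of samples whose argmax over the softmax outputs matches the
    // one-hot label.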
    let compute_accuracy = |model: &Sequential<f32>, x: &Array2<f32>, y: &Array2<f32>| -> f32 {
        let predictions = model.forward(&x.clone().into_dyn()).unwrap();
        let mut correct = 0;

        for i in 0..x.shape()[0] {
            let mut max_idx = 0;
            let mut max_val = predictions[[i, 0]];

            for j in 1..num_classes {
                if predictions[[i, j]] > max_val {
                    max_val = predictions[[i, j]];
                    max_idx = j;
                }
            }

            let true_idx = if y[[i, 0]] < y[[i, 1]] { 1 } else { 0 };
            if max_idx == true_idx {
                correct += 1;
            }
        }

        correct as f32 / x.shape()[0] as f32
    };

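    // Shared mini-batch training loop: shuffle the training indices each epoch, run
    // train_batch() on every batch, and log the averaged loss plus train/test accuracy
    // every 5 epochs. Returns the per-epoch loss history.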
    let mut train_model =
        |model: &mut Sequential<f32>, optimizer: &mut dyn Optimizer<f32>, name: &str| -> Vec<f32> {
            println!("\nTraining with {} optimizer...", name);
            let start_time = Instant::now();

            let mut train_losses = Vec::new();
            let num_batches = train_size.div_ceil(batch_size);

            for epoch in 0..epochs {
                let mut epoch_loss = 0.0;

                let mut indices: Vec<usize> = (0..train_size).collect();
                indices.shuffle(&mut rng);

                for batch_idx in 0..num_batches {
                    let start = batch_idx * batch_size;
                    let end = (start + batch_size).min(train_size);
                    let batch_indices = &indices[start..end];

                    let mut x_batch = Array2::<f32>::zeros((batch_indices.len(), num_features));
                    let mut y_batch = Array2::<f32>::zeros((batch_indices.len(), num_classes));

                    for (i, &idx) in batch_indices.iter().enumerate() {
                        for j in 0..num_features {
                            x_batch[[i, j]] = x_train[[idx, j]];
                        }
                        for j in 0..num_classes {
                            y_batch[[i, j]] = y_train[[idx, j]];
                        }
                    }

                    let x_batch_dyn = x_batch.into_dyn();
                    let y_batch_dyn = y_batch.into_dyn();

                    let batch_loss = model
                        .train_batch(&x_batch_dyn, &y_batch_dyn, &loss_fn, optimizer)
                        .unwrap();
                    epoch_loss += batch_loss;
                }

                epoch_loss /= num_batches as f32;
                train_losses.push(epoch_loss);

                if epoch % 5 == 0 || epoch == epochs - 1 {
                    let train_accuracy = compute_accuracy(model, &x_train, &y_train);
                    let test_accuracy = compute_accuracy(model, &x_test, &y_test);

                    println!(
                        "Epoch {}/{}: loss = {:.6}, train_acc = {:.2}%, test_acc = {:.2}%",
                        epoch + 1,
                        epochs,
                        epoch_loss,
                        train_accuracy * 100.0,
                        test_accuracy * 100.0
                    );
                }
            }

            let elapsed = start_time.elapsed();
            println!("{} training completed in {:.2?}", name, elapsed);

            let train_accuracy = compute_accuracy(model, &x_train, &y_train);
            let test_accuracy = compute_accuracy(model, &x_test, &y_test);
            println!("Final metrics for {}:", name);
            println!("  Train accuracy: {:.2}%", train_accuracy * 100.0);
            println!("  Test accuracy: {:.2}%", test_accuracy * 100.0);

            train_losses
        };

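    // Run the comparison. Note that the shuffling RNG is shared across runs, so the
    // batch order is not identical between optimizers, but the data, initial weights,
    // and hyperparameters are.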
    let sgd_losses = train_model(&mut sgd_model, &mut sgd_optimizer, "SGD");
    let adam_losses = train_model(&mut adam_model, &mut adam_optimizer, "Adam");
    let adamw_losses = train_model(&mut adamw_model, &mut adamw_optimizer, "AdamW");
    let radam_losses = train_model(&mut radam_model, &mut radam_optimizer, "RAdam");
    let rmsprop_losses = train_model(&mut rmsprop_model, &mut rmsprop_optimizer, "RMSprop");

    println!("\nOptimizer Comparison Summary:");
    println!("----------------------------");
    println!("Initial learning rate: {}", learning_rate);
    println!("Batch size: {}", batch_size);
    println!("Epochs: {}", epochs);
    println!();

    println!("Final Loss Values:");
    println!("  SGD:     {:.6}", sgd_losses.last().unwrap());
    println!("  Adam:    {:.6}", adam_losses.last().unwrap());
    println!("  AdamW:   {:.6}", adamw_losses.last().unwrap());
    println!("  RAdam:   {:.6}", radam_losses.last().unwrap());
    println!("  RMSprop: {:.6}", rmsprop_losses.last().unwrap());

    println!("\nLoss progression (first value, middle value, last value):");
    println!(
        "  SGD:     {:.6}, {:.6}, {:.6}",
        sgd_losses.first().unwrap(),
        sgd_losses[epochs / 2],
        sgd_losses.last().unwrap()
    );
    println!(
        "  Adam:    {:.6}, {:.6}, {:.6}",
        adam_losses.first().unwrap(),
        adam_losses[epochs / 2],
        adam_losses.last().unwrap()
    );
    println!(
        "  AdamW:   {:.6}, {:.6}, {:.6}",
        adamw_losses.first().unwrap(),
        adamw_losses[epochs / 2],
        adamw_losses.last().unwrap()
    );
    println!(
        "  RAdam:   {:.6}, {:.6}, {:.6}",
        radam_losses.first().unwrap(),
        radam_losses[epochs / 2],
        radam_losses.last().unwrap()
    );
    println!(
        "  RMSprop: {:.6}, {:.6}, {:.6}",
        rmsprop_losses.first().unwrap(),
        rmsprop_losses[epochs / 2],
        rmsprop_losses.last().unwrap()
    );

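    // Ratio of first-epoch loss to last-epoch loss: larger means a bigger relative
    // reduction in training loss over the run.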
    println!("\nLoss improvement ratio (first loss / last loss):");
    println!(
        "  SGD:     {:.2}x",
        sgd_losses.first().unwrap() / sgd_losses.last().unwrap()
    );
    println!(
        "  Adam:    {:.2}x",
        adam_losses.first().unwrap() / adam_losses.last().unwrap()
    );
    println!(
        "  AdamW:   {:.2}x",
        adamw_losses.first().unwrap() / adamw_losses.last().unwrap()
    );
    println!(
        "  RAdam:   {:.2}x",
        radam_losses.first().unwrap() / radam_losses.last().unwrap()
    );
    println!(
        "  RMSprop: {:.2}x",
        rmsprop_losses.first().unwrap() / rmsprop_losses.last().unwrap()
    );

    println!("\nAdvanced optimizers demo completed successfully!");

    Ok(())
}