1use train_station::{
33 optimizers::{Adam, AdamConfig, Optimizer},
34 Tensor,
35};
36
/// Hyperparameters for one Adam training run of the linear-regression demo.
#[derive(Debug, Clone, PartialEq)]
struct TrainingConfig {
    /// Number of optimization epochs to run.
    pub epochs: usize,
    /// Adam step size (alpha).
    pub learning_rate: f32,
    /// L2 weight-decay coefficient; 0.0 disables decay.
    pub weight_decay: f32,
    /// Exponential decay rate for the first-moment (mean) estimate.
    pub beta1: f32,
    /// Exponential decay rate for the second-moment (variance) estimate.
    pub beta2: f32,
}
46
impl Default for TrainingConfig {
    /// Baseline configuration: 100 epochs, lr = 0.01, no weight decay,
    /// and the conventional Adam betas (0.9, 0.999).
    fn default() -> Self {
        Self {
            epochs: 100,
            learning_rate: 0.01,
            weight_decay: 0.0,
            beta1: 0.9,
            beta2: 0.999,
        }
    }
}
58
/// Outcome of a single training run produced by `train_with_config`.
#[derive(Debug, Clone)]
// Some fields (e.g. `loss_history`) are kept for inspection/Debug output
// but are not read by every demo, hence the dead_code allowance.
#[allow(dead_code)]
struct TrainingStats {
    /// The configuration this run was trained with.
    pub config: TrainingConfig,
    /// Loss value after the last epoch.
    pub final_loss: f32,
    /// Loss recorded at every epoch, in order.
    pub loss_history: Vec<f32>,
    /// First epoch where loss dropped below the convergence threshold;
    /// equals `config.epochs` if the run never converged.
    pub convergence_epoch: usize,
    /// Norm of the learned weight tensor at the end of training.
    pub weight_norm: f32,
}
69
70fn main() -> Result<(), Box<dyn std::error::Error>> {
71 println!("=== Adam Configurations Example ===\n");
72
73 demonstrate_default_adam()?;
74 demonstrate_learning_rate_comparison()?;
75 demonstrate_weight_decay_comparison()?;
76 demonstrate_beta_parameter_tuning()?;
77 demonstrate_configuration_benchmarking()?;
78
79 println!("\n=== Example completed successfully! ===");
80 Ok(())
81}
82
83fn demonstrate_default_adam() -> Result<(), Box<dyn std::error::Error>> {
85 println!("--- Default Adam Configuration ---");
86
87 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
89 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
90
91 let mut weight = Tensor::randn(vec![1, 1], Some(42)).with_requires_grad();
93 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
94
95 let mut optimizer = Adam::new();
97 optimizer.add_parameter(&weight);
98 optimizer.add_parameter(&bias);
99
100 println!("Default Adam configuration:");
101 println!(" Learning rate: {}", optimizer.learning_rate());
102 println!(" Initial weight: {:.6}", weight.value());
103 println!(" Initial bias: {:.6}", bias.value());
104
105 let num_epochs = 50;
107 let mut losses = Vec::new();
108
109 for epoch in 0..num_epochs {
110 let y_pred = x_data.matmul(&weight) + &bias;
112 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
113
114 loss.backward(None);
116
117 optimizer.step(&mut [&mut weight, &mut bias]);
119 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
120
121 losses.push(loss.value());
122
123 if epoch % 10 == 0 || epoch == num_epochs - 1 {
124 println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
125 }
126 }
127
128 let _final_predictions = x_data.matmul(&weight) + &bias;
130 println!("\nFinal model:");
131 println!(" Learned weight: {:.6} (target: 2.0)", weight.value());
132 println!(" Learned bias: {:.6} (target: 1.0)", bias.value());
133 println!(" Final loss: {:.6}", losses[losses.len() - 1]);
134
135 Ok(())
136}
137
138fn demonstrate_learning_rate_comparison() -> Result<(), Box<dyn std::error::Error>> {
140 println!("\n--- Learning Rate Comparison ---");
141
142 let learning_rates = [0.001, 0.01, 0.1];
143 let mut results = Vec::new();
144
145 for &lr in &learning_rates {
146 println!("\nTesting learning rate: {}", lr);
147
148 let stats = train_with_config(TrainingConfig {
149 learning_rate: lr,
150 ..Default::default()
151 })?;
152
153 results.push((lr, stats.clone()));
154
155 println!(" Final loss: {:.6}", stats.final_loss);
156 println!(" Convergence epoch: {}", stats.convergence_epoch);
157 }
158
159 println!("\nLearning Rate Comparison Summary:");
161 for (lr, stats) in &results {
162 println!(
163 " LR={:6}: Loss={:.6}, Converged@{}",
164 lr, stats.final_loss, stats.convergence_epoch
165 );
166 }
167
168 Ok(())
169}
170
171fn demonstrate_weight_decay_comparison() -> Result<(), Box<dyn std::error::Error>> {
173 println!("\n--- Weight Decay Comparison ---");
174
175 let weight_decays = [0.0, 0.001, 0.01];
176 let mut results = Vec::new();
177
178 for &wd in &weight_decays {
179 println!("\nTesting weight decay: {}", wd);
180
181 let stats = train_with_config(TrainingConfig {
182 weight_decay: wd,
183 ..Default::default()
184 })?;
185
186 results.push((wd, stats.clone()));
187
188 println!(" Final loss: {:.6}", stats.final_loss);
189 println!(" Final weight norm: {:.6}", stats.weight_norm);
190 }
191
192 println!("\nWeight Decay Comparison Summary:");
194 for (wd, stats) in &results {
195 println!(
196 " WD={:6}: Loss={:.6}, Weight Norm={:.6}",
197 wd, stats.final_loss, stats.weight_norm
198 );
199 }
200
201 Ok(())
202}
203
204fn demonstrate_beta_parameter_tuning() -> Result<(), Box<dyn std::error::Error>> {
206 println!("\n--- Beta Parameter Tuning ---");
207
208 let beta_configs = [
209 (0.9, 0.999), (0.8, 0.999), (0.95, 0.999), (0.9, 0.99), ];
214
215 let mut results = Vec::new();
216
217 for (i, (beta1, beta2)) in beta_configs.iter().enumerate() {
218 println!(
219 "\nTesting beta configuration {}: beta1={}, beta2={}",
220 i + 1,
221 beta1,
222 beta2
223 );
224
225 let config = TrainingConfig {
226 beta1: *beta1,
227 beta2: *beta2,
228 ..Default::default()
229 };
230
231 let stats = train_with_config(config)?;
232 results.push(((*beta1, *beta2), stats.clone()));
233
234 println!(" Final loss: {:.6}", stats.final_loss);
235 println!(" Convergence epoch: {}", stats.convergence_epoch);
236 }
237
238 println!("\nBeta Parameter Comparison Summary:");
240 for ((beta1, beta2), stats) in &results {
241 println!(
242 " B1={:4}, B2={:5}: Loss={:.6}, Converged@{}",
243 beta1, beta2, stats.final_loss, stats.convergence_epoch
244 );
245 }
246
247 Ok(())
248}
249
250fn demonstrate_configuration_benchmarking() -> Result<(), Box<dyn std::error::Error>> {
252 println!("\n--- Configuration Benchmarking ---");
253
254 let configs = vec![
256 (
257 "Conservative",
258 TrainingConfig {
259 learning_rate: 0.001,
260 weight_decay: 0.001,
261 beta1: 0.95,
262 ..Default::default()
263 },
264 ),
265 (
266 "Balanced",
267 TrainingConfig {
268 learning_rate: 0.01,
269 weight_decay: 0.0,
270 beta1: 0.9,
271 ..Default::default()
272 },
273 ),
274 (
275 "Aggressive",
276 TrainingConfig {
277 learning_rate: 0.1,
278 weight_decay: 0.0,
279 beta1: 0.8,
280 ..Default::default()
281 },
282 ),
283 ];
284
285 let mut benchmark_results = Vec::new();
286
287 for (name, config) in configs {
288 println!("\nBenchmarking {} configuration:", name);
289
290 let start_time = std::time::Instant::now();
291 let stats = train_with_config(config.clone())?;
292 let elapsed = start_time.elapsed();
293
294 println!(" Training time: {:.2}ms", elapsed.as_millis());
295 println!(" Final loss: {:.6}", stats.final_loss);
296 println!(" Convergence: {} epochs", stats.convergence_epoch);
297
298 benchmark_results.push((name.to_string(), stats, elapsed));
299 }
300
301 println!("\nBenchmarking Summary:");
303 for (name, stats, elapsed) in &benchmark_results {
304 println!(
305 " {:12}: Loss={:.6}, Time={:4}ms, Converged@{}",
306 name,
307 stats.final_loss,
308 elapsed.as_millis(),
309 stats.convergence_epoch
310 );
311 }
312
313 Ok(())
314}
315
316fn train_with_config(config: TrainingConfig) -> Result<TrainingStats, Box<dyn std::error::Error>> {
318 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
320 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
321
322 let mut weight = Tensor::randn(vec![1, 1], Some(123)).with_requires_grad();
324 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
325
326 let adam_config = AdamConfig {
328 learning_rate: config.learning_rate,
329 beta1: config.beta1,
330 beta2: config.beta2,
331 eps: 1e-8,
332 weight_decay: config.weight_decay,
333 amsgrad: false,
334 };
335
336 let mut optimizer = Adam::with_config(adam_config);
337 optimizer.add_parameter(&weight);
338 optimizer.add_parameter(&bias);
339
340 let mut losses = Vec::new();
342 let mut convergence_epoch = config.epochs;
343
344 for epoch in 0..config.epochs {
345 let y_pred = x_data.matmul(&weight) + &bias;
347 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
348
349 loss.backward(None);
351
352 optimizer.step(&mut [&mut weight, &mut bias]);
354 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
355
356 let loss_value = loss.value();
357 losses.push(loss_value);
358
359 if loss_value < 0.01 && convergence_epoch == config.epochs {
361 convergence_epoch = epoch;
362 }
363 }
364
365 Ok(TrainingStats {
366 config,
367 final_loss: losses[losses.len() - 1],
368 loss_history: losses,
369 convergence_epoch,
370 weight_norm: weight.norm().value(),
371 })
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration should both converge and reach a small loss.
    #[test]
    fn test_default_adam_convergence() {
        let config = TrainingConfig::default();
        let stats = train_with_config(config).unwrap();

        assert!(stats.final_loss < 1.0);
        // Bug fix: the original read `config.epochs` here, but `config` was
        // already moved into `train_with_config` (use after move, E0382).
        // The run's configuration is stored on the stats, so read it there.
        assert!(stats.convergence_epoch < stats.config.epochs);
    }

    /// A larger learning rate should converge at least as fast on this
    /// simple convex problem.
    #[test]
    fn test_learning_rate_effect() {
        let config_slow = TrainingConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let config_fast = TrainingConfig {
            learning_rate: 0.1,
            ..Default::default()
        };

        let stats_slow = train_with_config(config_slow).unwrap();
        let stats_fast = train_with_config(config_fast).unwrap();

        assert!(stats_fast.convergence_epoch <= stats_slow.convergence_epoch);
    }

    /// Weight decay should not increase the final weight norm.
    #[test]
    fn test_weight_decay_effect() {
        let config_no_decay = TrainingConfig {
            weight_decay: 0.0,
            ..Default::default()
        };
        let config_with_decay = TrainingConfig {
            weight_decay: 0.01,
            ..Default::default()
        };

        let stats_no_decay = train_with_config(config_no_decay).unwrap();
        let stats_with_decay = train_with_config(config_with_decay).unwrap();

        assert!(stats_with_decay.weight_norm <= stats_no_decay.weight_norm);
    }
}