adam_configurations/adam_configurations.rs

//! Adam Configurations Example
//!
//! This example demonstrates different Adam optimizer configurations and their
//! impact on neural network training convergence and performance:
//! - Default Adam configuration for baseline performance
//! - Custom learning rates and their effects
//! - Weight decay regularization techniques
//! - Beta parameter tuning for momentum control (see the update rule sketch below)
//! - Performance comparison across configurations
//!
//! # Learning Objectives
//!
//! - Understand Adam hyperparameter configuration
//! - Learn how learning rate affects convergence
//! - Explore weight decay for regularization
//! - Compare beta parameters for momentum control
//! - Implement configuration benchmarking workflows
//!
//! # Prerequisites
//!
//! - Basic Rust knowledge
//! - Understanding of tensor operations
//! - Familiarity with neural network training loops
//! - Knowledge of optimization concepts
//!
//! # Usage
//!
//! ```bash
//! cargo run --example adam_configurations
//! ```
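//!
//! # Adam Update Rule (Reference)
//!
//! For orientation, the standard Adam update that these hyperparameters feed
//! into is sketched below; the exact coupling of `weight_decay` and other
//! implementation details inside `train_station` may differ:
//!
//! ```text
//! m_t = beta1 * m_(t-1) + (1 - beta1) * g_t      // first moment (momentum)
//! v_t = beta2 * v_(t-1) + (1 - beta2) * g_t^2    // second moment
//! m_hat = m_t / (1 - beta1^t)                    // bias correction
//! v_hat = v_t / (1 - beta2^t)
//! theta_t = theta_(t-1) - lr * m_hat / (sqrt(v_hat) + eps)
//! ```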

use train_station::{
    optimizers::{Adam, AdamConfig, Optimizer},
    Tensor,
};

/// Configuration for training experiments
#[derive(Debug, Clone, PartialEq)]
struct TrainingConfig {
    pub epochs: usize,
    pub learning_rate: f32,
    pub weight_decay: f32,
    pub beta1: f32,
    pub beta2: f32,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            epochs: 100,
            learning_rate: 0.01,
            weight_decay: 0.0,
            beta1: 0.9,
            beta2: 0.999,
        }
    }
}

/// Training statistics for performance analysis
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct TrainingStats {
    pub config: TrainingConfig,
    pub final_loss: f32,
    pub loss_history: Vec<f32>,
    pub convergence_epoch: usize,
    pub weight_norm: f32,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Adam Configurations Example ===\n");

    demonstrate_default_adam()?;
    demonstrate_learning_rate_comparison()?;
    demonstrate_weight_decay_comparison()?;
    demonstrate_beta_parameter_tuning()?;
    demonstrate_configuration_benchmarking()?;

    println!("\n=== Example completed successfully! ===");
    Ok(())
}

/// Demonstrate default Adam configuration
fn demonstrate_default_adam() -> Result<(), Box<dyn std::error::Error>> {
    println!("--- Default Adam Configuration ---");

    // Create a simple regression problem: y = 2*x + 1
    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
    let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();

    // Create model parameters
    let mut weight = Tensor::randn(vec![1, 1], Some(42)).with_requires_grad();
    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();

    // Create Adam optimizer with default configuration
    let mut optimizer = Adam::new();
    optimizer.add_parameter(&weight);
    optimizer.add_parameter(&bias);

    println!("Default Adam configuration:");
    println!("  Learning rate: {}", optimizer.learning_rate());
    println!("  Initial weight: {:.6}", weight.value());
    println!("  Initial bias: {:.6}", bias.value());

    // Training loop
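    // Each epoch runs the usual autograd cycle: the forward pass builds the
    // computation graph, `backward` populates parameter gradients, `step`
    // applies the Adam update, and `zero_grad` clears the gradients so they
    // do not carry over into the next epoch.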
    let num_epochs = 50;
    let mut losses = Vec::new();

    for epoch in 0..num_epochs {
        // Forward pass
        let y_pred = x_data.matmul(&weight) + &bias;
        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();

        // Backward pass
        loss.backward(None);

        // Optimizer step
        optimizer.step(&mut [&mut weight, &mut bias]);
        optimizer.zero_grad(&mut [&mut weight, &mut bias]);

        losses.push(loss.value());

        if epoch % 10 == 0 || epoch == num_epochs - 1 {
            println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
        }
    }

    // Evaluate final model
    let _final_predictions = x_data.matmul(&weight) + &bias;
    println!("\nFinal model:");
    println!("  Learned weight: {:.6} (target: 2.0)", weight.value());
    println!("  Learned bias: {:.6} (target: 1.0)", bias.value());
    println!("  Final loss: {:.6}", losses[losses.len() - 1]);

    Ok(())
}

/// Demonstrate learning rate comparison
fn demonstrate_learning_rate_comparison() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Learning Rate Comparison ---");

    let learning_rates = [0.001, 0.01, 0.1];
    let mut results = Vec::new();

    for &lr in &learning_rates {
        println!("\nTesting learning rate: {}", lr);

        let stats = train_with_config(TrainingConfig {
            learning_rate: lr,
            ..Default::default()
        })?;

        results.push((lr, stats.clone()));

        println!("  Final loss: {:.6}", stats.final_loss);
        println!("  Convergence epoch: {}", stats.convergence_epoch);
    }

    // Compare results
    println!("\nLearning Rate Comparison Summary:");
    for (lr, stats) in &results {
        println!(
            "  LR={:6}: Loss={:.6}, Converged@{}",
            lr, stats.final_loss, stats.convergence_epoch
        );
    }

    Ok(())
}

/// Demonstrate weight decay comparison
fn demonstrate_weight_decay_comparison() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Weight Decay Comparison ---");

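    // Weight decay penalizes parameter magnitude, pulling weights toward zero
    // each step (classically by adding `wd * w` to the gradient; whether this
    // optimizer couples it that way or AdamW-style is a crate implementation
    // detail). Larger `wd` values should show up as smaller weight norms.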
    let weight_decays = [0.0, 0.001, 0.01];
    let mut results = Vec::new();

    for &wd in &weight_decays {
        println!("\nTesting weight decay: {}", wd);

        let stats = train_with_config(TrainingConfig {
            weight_decay: wd,
            ..Default::default()
        })?;

        results.push((wd, stats.clone()));

        println!("  Final loss: {:.6}", stats.final_loss);
        println!("  Final weight norm: {:.6}", stats.weight_norm);
    }

    // Compare results
    println!("\nWeight Decay Comparison Summary:");
    for (wd, stats) in &results {
        println!(
            "  WD={:6}: Loss={:.6}, Weight Norm={:.6}",
            wd, stats.final_loss, stats.weight_norm
        );
    }

    Ok(())
}

/// Demonstrate beta parameter tuning
fn demonstrate_beta_parameter_tuning() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Beta Parameter Tuning ---");

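    // beta1 sets the decay rate of the gradient moving average (momentum) and
    // beta2 that of the squared-gradient average; higher values mean longer
    // memory. A rough horizon is 1 / (1 - beta) steps, so beta1 = 0.9 averages
    // over ~10 recent gradients and beta2 = 0.999 over ~1000.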
    let beta_configs = [
        (0.9, 0.999),  // Default
        (0.8, 0.999),  // Shorter gradient memory (less momentum smoothing)
        (0.95, 0.999), // Longer gradient memory (more momentum smoothing)
        (0.9, 0.99),   // Faster second moment decay
    ];

    let mut results = Vec::new();

    for (i, (beta1, beta2)) in beta_configs.iter().enumerate() {
        println!(
            "\nTesting beta configuration {}: beta1={}, beta2={}",
            i + 1,
            beta1,
            beta2
        );

        let config = TrainingConfig {
            beta1: *beta1,
            beta2: *beta2,
            ..Default::default()
        };

        let stats = train_with_config(config)?;
        results.push(((*beta1, *beta2), stats.clone()));

        println!("  Final loss: {:.6}", stats.final_loss);
        println!("  Convergence epoch: {}", stats.convergence_epoch);
    }

    // Compare results
    println!("\nBeta Parameter Comparison Summary:");
    for ((beta1, beta2), stats) in &results {
        println!(
            "  B1={:4}, B2={:5}: Loss={:.6}, Converged@{}",
            beta1, beta2, stats.final_loss, stats.convergence_epoch
        );
    }

    Ok(())
}

/// Demonstrate configuration benchmarking
fn demonstrate_configuration_benchmarking() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Configuration Benchmarking ---");

    // Define configurations to benchmark
    let configs = vec![
        (
            "Conservative",
            TrainingConfig {
                learning_rate: 0.001,
                weight_decay: 0.001,
                beta1: 0.95,
                ..Default::default()
            },
        ),
        (
            "Balanced",
            TrainingConfig {
                learning_rate: 0.01,
                weight_decay: 0.0,
                beta1: 0.9,
                ..Default::default()
            },
        ),
        (
            "Aggressive",
            TrainingConfig {
                learning_rate: 0.1,
                weight_decay: 0.0,
                beta1: 0.8,
                ..Default::default()
            },
        ),
    ];

    let mut benchmark_results = Vec::new();

    for (name, config) in configs {
        println!("\nBenchmarking {} configuration:", name);

        let start_time = std::time::Instant::now();
        let stats = train_with_config(config.clone())?;
        let elapsed = start_time.elapsed();

        println!("  Training time: {:.2}ms", elapsed.as_secs_f64() * 1000.0);
        println!("  Final loss: {:.6}", stats.final_loss);
        println!("  Convergence: {} epochs", stats.convergence_epoch);

        benchmark_results.push((name.to_string(), stats, elapsed));
    }

    // Summary
    println!("\nBenchmarking Summary:");
    for (name, stats, elapsed) in &benchmark_results {
        println!(
            "  {:12}: Loss={:.6}, Time={:4}ms, Converged@{}",
            name,
            stats.final_loss,
            elapsed.as_millis(),
            stats.convergence_epoch
        );
    }

    Ok(())
}

/// Helper function to train with specific configuration
fn train_with_config(config: TrainingConfig) -> Result<TrainingStats, Box<dyn std::error::Error>> {
    // Create training data
    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
    let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();

    // Create model parameters
    let mut weight = Tensor::randn(vec![1, 1], Some(123)).with_requires_grad();
    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();

    // Create optimizer with custom configuration
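    // `eps` guards the division by sqrt(v_hat) when the second moment is near
    // zero, and `amsgrad: false` leaves the AMSGrad variant (which tracks a
    // running maximum of the second moment) disabled.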
    let adam_config = AdamConfig {
        learning_rate: config.learning_rate,
        beta1: config.beta1,
        beta2: config.beta2,
        eps: 1e-8,
        weight_decay: config.weight_decay,
        amsgrad: false,
    };

    let mut optimizer = Adam::with_config(adam_config);
    optimizer.add_parameter(&weight);
    optimizer.add_parameter(&bias);

    // Training loop
    let mut losses = Vec::new();
    let mut convergence_epoch = config.epochs;

    for epoch in 0..config.epochs {
        // Forward pass
        let y_pred = x_data.matmul(&weight) + &bias;
        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();

        // Backward pass
        loss.backward(None);

        // Optimizer step
        optimizer.step(&mut [&mut weight, &mut bias]);
        optimizer.zero_grad(&mut [&mut weight, &mut bias]);

        let loss_value = loss.value();
        losses.push(loss_value);

        // Check for convergence (loss < 0.01)
        if loss_value < 0.01 && convergence_epoch == config.epochs {
            convergence_epoch = epoch;
        }
    }

    Ok(TrainingStats {
        config,
        final_loss: losses[losses.len() - 1],
        loss_history: losses,
        convergence_epoch,
        weight_norm: weight.norm().value(),
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_adam_convergence() {
        let config = TrainingConfig::default();
        let stats = train_with_config(config.clone()).unwrap();

        assert!(stats.final_loss < 1.0);
        assert!(stats.convergence_epoch < config.epochs);
    }

    #[test]
    fn test_learning_rate_effect() {
        let config_slow = TrainingConfig {
            learning_rate: 0.001,
            ..Default::default()
        };
        let config_fast = TrainingConfig {
            learning_rate: 0.1,
            ..Default::default()
        };

        let stats_slow = train_with_config(config_slow).unwrap();
        let stats_fast = train_with_config(config_fast).unwrap();

        // On this simple problem, the higher learning rate should reach the
        // convergence threshold at least as early as the lower one
        assert!(stats_fast.convergence_epoch <= stats_slow.convergence_epoch);
    }

    #[test]
    fn test_weight_decay_effect() {
        let config_no_decay = TrainingConfig {
            weight_decay: 0.0,
            ..Default::default()
        };
        let config_with_decay = TrainingConfig {
            weight_decay: 0.01,
            ..Default::default()
        };

        let stats_no_decay = train_with_config(config_no_decay).unwrap();
        let stats_with_decay = train_with_config(config_with_decay).unwrap();

        // Weight decay should result in smaller weight norms
        assert!(stats_with_decay.weight_norm <= stats_no_decay.weight_norm);
    }
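
    /// Additional smoke test (a minimal sketch added alongside the original
    /// tests): both beta1 settings exercised above should still train to a
    /// small, finite loss on this trivially learnable linear problem.
    #[test]
    fn test_beta_parameter_smoke() {
        for beta1 in [0.8_f32, 0.95] {
            let config = TrainingConfig {
                beta1,
                ..Default::default()
            };
            let stats = train_with_config(config).unwrap();

            assert!(stats.final_loss.is_finite());
            assert!(stats.final_loss < 1.0);
        }
    }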
}