optimizer_basics/
optimizer_basics.rs

//! Optimizer Basics Example
//!
//! This example demonstrates how to use optimizers in Train Station:
//! - Setting up an Adam optimizer with parameters
//! - Training a simple linear regression model
//! - Learning rate scheduling and monitoring
//! - Advanced training patterns and analysis
//!
//! # Learning Objectives
//!
//! - Understand optimizer setup and parameter management
//! - Learn to implement basic training loops
//! - Explore learning rate scheduling techniques
//! - Monitor training progress and convergence
//!
//! # Prerequisites
//!
//! - Basic Rust knowledge
//! - Understanding of tensor basics (see tensor_basics.rs)
//! - Familiarity with gradient descent concepts
//!
//! # Usage
//!
//! ```bash
//! cargo run --example optimizer_basics
//! ```

use train_station::{
    optimizers::{Adam, AdamConfig, Optimizer},
    Tensor,
};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Optimizer Basics Example ===\n");

    demonstrate_basic_optimizer_setup();
    demonstrate_linear_regression()?;
    demonstrate_advanced_training()?;
    demonstrate_learning_rate_scheduling()?;
    demonstrate_training_monitoring()?;

    println!("\n=== Example completed successfully! ===");
    Ok(())
}

/// Demonstrate basic optimizer setup and parameter management
fn demonstrate_basic_optimizer_setup() {
    println!("--- Basic Optimizer Setup ---");

    // Create parameters that require gradients
    let weight = Tensor::randn(vec![3, 2], Some(42)).with_requires_grad();
    let bias = Tensor::zeros(vec![2]).with_requires_grad();

    println!("Created parameters:");
    println!(
        "  Weight: shape {:?}, requires_grad: {}",
        weight.shape().dims,
        weight.requires_grad()
    );
    println!(
        "  Bias: shape {:?}, requires_grad: {}",
        bias.shape().dims,
        bias.requires_grad()
    );

    // Create Adam optimizer with default configuration
    let mut optimizer = Adam::new();
    println!(
        "Created Adam optimizer with learning rate: {}",
        optimizer.learning_rate()
    );

    // Add parameters to optimizer
    optimizer.add_parameter(&weight);
    optimizer.add_parameter(&bias);
    println!(
        "Added {} parameters to optimizer",
        optimizer.parameter_count()
    );

    // Create optimizer with custom configuration
    let config = AdamConfig {
        learning_rate: 0.01,
        beta1: 0.9,        // decay rate for the first-moment (mean) estimate
        beta2: 0.999,      // decay rate for the second-moment (variance) estimate
        eps: 1e-8,         // small constant added for numerical stability
        weight_decay: 0.0, // L2 regularization strength
        amsgrad: false,    // whether to use the AMSGrad variant
    };

    let mut custom_optimizer = Adam::with_config(config);
    custom_optimizer.add_parameter(&weight);
    custom_optimizer.add_parameter(&bias);

    println!(
        "Created custom optimizer with learning rate: {}",
        custom_optimizer.learning_rate()
    );

    // Both optimizers now track the same weight and bias parameters
    println!("Parameter linking completed successfully");
}

/// Demonstrate simple linear regression training
fn demonstrate_linear_regression() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Linear Regression Training ---");

    // Create model parameters
    let mut weight = Tensor::randn(vec![1, 1], Some(43)).with_requires_grad();
    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();

    // Create optimizer
    let mut optimizer = Adam::with_learning_rate(0.01);
    optimizer.add_parameter(&weight);
    optimizer.add_parameter(&bias);

    // Create simple training data: y = 2*x + 1
    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
    let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();

    println!("Training data:");
    println!("  X: {:?}", x_data.data());
    println!("  Y: {:?}", y_true.data());
    println!("  Target: y = 2*x + 1");

    // Training loop
    let num_epochs = 100;
    let mut losses = Vec::new();

    for epoch in 0..num_epochs {
        // Forward pass: y_pred = x * weight + bias
        let y_pred = x_data.matmul(&weight) + &bias;

        // Compute loss: MSE
        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();

        // Backward pass
        loss.backward(None);

        // Optimizer step
        optimizer.step(&mut [&mut weight, &mut bias]);
        optimizer.zero_grad(&mut [&mut weight, &mut bias]);

        losses.push(loss.value());

        // Print progress every 20 epochs
        if epoch % 20 == 0 || epoch == num_epochs - 1 {
            println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
        }
    }

    // Evaluate final model
    let final_predictions = x_data.matmul(&weight) + &bias;
    println!("\nFinal model evaluation:");
    println!("  Learned weight: {:.6}", weight.value());
    println!("  Learned bias: {:.6}", bias.value());
    println!("  Predictions vs True:");

    for i in 0..5 {
        let x1 = x_data.data()[i];
        let pred = final_predictions.data()[i];
        let true_val = y_true.data()[i];
        println!(
            "    x={:.1}: pred={:.3}, true={:.1}, error={:.3}",
            x1,
            pred,
            true_val,
            (pred - true_val).abs()
        );
    }

    Ok(())
}
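
/// Mean squared error between predictions and targets, matching the loss
/// expression used in the training loops above. A convenience sketch that
/// relies only on the `&Tensor` arithmetic, `pow_scalar`, and `mean`
/// operations demonstrated in this example; not called by the demos.
#[allow(dead_code)]
fn mse_loss(y_pred: &Tensor, y_true: &Tensor) -> Tensor {
    (y_pred - y_true).pow_scalar(2.0).mean()
}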

/// Demonstrate advanced training patterns
fn demonstrate_advanced_training() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Advanced Training Patterns ---");

    // Create a model with two outputs per input
    let mut weight = Tensor::randn(vec![1, 2], Some(44)).with_requires_grad();
    let mut bias = Tensor::zeros(vec![2]).with_requires_grad();

    // Create optimizer with a smaller learning rate
    let mut optimizer = Adam::with_learning_rate(0.005);
    optimizer.add_parameter(&weight);
    optimizer.add_parameter(&bias);

    // Create training data with two targets per input: y = [2*x + 1, 2*x + 4]
    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
    let y_true = Tensor::from_slice(
        &[3.0, 6.0, 5.0, 8.0, 7.0, 10.0, 9.0, 12.0, 11.0, 14.0],
        vec![5, 2],
    )
    .unwrap();

    println!("Advanced training with monitoring:");
    println!("  Initial learning rate: {}", optimizer.learning_rate());

    // Training loop with monitoring
    let num_epochs = 50;
    let mut losses = Vec::new();
    let mut weight_norms = Vec::new();
    let mut gradient_norms = Vec::new();

    for epoch in 0..num_epochs {
        // Forward pass
        let y_pred = x_data.matmul(&weight) + &bias;
        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();

        // Backward pass
        loss.backward(None);

        // Compute gradient norm before optimizer step
        let gradient_norm = weight.grad_by_value().unwrap().norm();

        // Optimizer step
        optimizer.step(&mut [&mut weight, &mut bias]);
        optimizer.zero_grad(&mut [&mut weight, &mut bias]);

        // Learning rate scheduling: halve the learning rate every 10 epochs
        if epoch > 0 && epoch % 10 == 0 {
            let current_lr = optimizer.learning_rate();
            let new_lr = current_lr * 0.5;
            optimizer.set_learning_rate(new_lr);
            println!(
                "Epoch {:2}: Reduced learning rate from {:.5} to {:.5}",
                epoch, current_lr, new_lr
            );
        }

        // Record metrics
        losses.push(loss.value());
        weight_norms.push(weight.norm().value());
        gradient_norms.push(gradient_norm.value());

        // Print detailed progress
        if epoch % 10 == 0 || epoch == num_epochs - 1 {
            println!(
                "Epoch {:2}: Loss = {:.6}, Weight Norm = {:.6}, Gradient Norm = {:.6}",
                epoch,
                loss.value(),
                weight.norm().value(),
                gradient_norm.value()
            );
        }
    }

    println!("Final learning rate: {}", optimizer.learning_rate());

    // Analyze training progression
    let initial_loss = losses[0];
    let final_loss = losses[losses.len() - 1];
    let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;

    println!("\nTraining Analysis:");
    println!("  Initial loss: {:.6}", initial_loss);
    println!("  Final loss: {:.6}", final_loss);
    println!("  Loss reduction: {:.1}%", loss_reduction);
    println!("  Final weight norm: {:.6}", weight.norm().value());
    println!("  Final bias: {:?}", bias.data());

    Ok(())
}
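
/// Total L2 gradient norm across a set of parameters, combined from the
/// `grad_by_value` and `norm` calls used in `demonstrate_advanced_training`.
/// A monitoring sketch (not called by the demos) that assumes `grad_by_value`
/// returns an `Option` of the gradient tensor, as it is unwrapped above;
/// parameters without a gradient are skipped.
#[allow(dead_code)]
fn total_gradient_norm(params: &[&Tensor]) -> f32 {
    params
        .iter()
        .filter_map(|p| p.grad_by_value())
        .map(|grad| {
            let norm = grad.norm().value();
            norm * norm
        })
        .sum::<f32>()
        .sqrt()
}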

/// Demonstrate learning rate scheduling
fn demonstrate_learning_rate_scheduling() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Learning Rate Scheduling ---");

    // Create simple model
    let mut weight = Tensor::randn(vec![1, 1], Some(45)).with_requires_grad();
    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();

    // Create optimizer with high initial learning rate
    let mut optimizer = Adam::with_learning_rate(0.1);
    optimizer.add_parameter(&weight);
    optimizer.add_parameter(&bias);

    // Simple data
    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3, 1]).unwrap();
    let y_true = Tensor::from_slice(&[2.0, 4.0, 6.0], vec![3, 1]).unwrap();

    println!("Initial learning rate: {}", optimizer.learning_rate());

    // Training loop with learning rate scheduling
    let num_epochs = 50;
    let mut losses = Vec::new();

    for epoch in 0..num_epochs {
        // Forward pass
        let y_pred = x_data.matmul(&weight) + &bias;
        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();

        // Backward pass
        loss.backward(None);

        // Optimizer step
        optimizer.step(&mut [&mut weight, &mut bias]);
        optimizer.zero_grad(&mut [&mut weight, &mut bias]);

        // Learning rate scheduling: reduce every 10 epochs
        if epoch > 0 && epoch % 10 == 0 {
            let current_lr = optimizer.learning_rate();
            let new_lr = current_lr * 0.5;
            optimizer.set_learning_rate(new_lr);
            println!(
                "Epoch {:2}: Reduced learning rate from {:.3} to {:.3}",
                epoch, current_lr, new_lr
            );
        }

        losses.push(loss.value());

        // Print progress
        if epoch % 10 == 0 || epoch == num_epochs - 1 {
            println!(
                "Epoch {:2}: Loss = {:.6}, LR = {:.3}",
                epoch,
                loss.value(),
                optimizer.learning_rate()
            );
        }
    }

    println!("Final learning rate: {}", optimizer.learning_rate());

    Ok(())
}
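
/// Step-decay schedule: multiply `initial_lr` by `factor` once per `step_size`
/// epochs (`step_size` must be non-zero). A plain-Rust sketch that mirrors the
/// manual halving done in the loops above; pair it with
/// `optimizer.set_learning_rate(...)` at the top of each epoch.
#[allow(dead_code)]
fn step_decay_lr(initial_lr: f32, factor: f32, step_size: usize, epoch: usize) -> f32 {
    initial_lr * factor.powi((epoch / step_size) as i32)
}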

/// Demonstrate training monitoring and analysis
fn demonstrate_training_monitoring() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Training Monitoring ---");

    // Create model
    let mut weight = Tensor::randn(vec![1, 1], Some(46)).with_requires_grad();
    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();

    // Create optimizer
    let mut optimizer = Adam::with_learning_rate(0.01);
    optimizer.add_parameter(&weight);
    optimizer.add_parameter(&bias);

    // Training data
    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
    let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0], vec![4, 1]).unwrap();

    // Training loop with comprehensive monitoring
    let num_epochs = 30;
    let mut losses = Vec::new();
    let mut weight_history = Vec::new();
    let mut bias_history = Vec::new();

    for epoch in 0..num_epochs {
        // Forward pass
        let y_pred = x_data.matmul(&weight) + &bias;
        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();

        // Backward pass
        loss.backward(None);

        // Optimizer step
        optimizer.step(&mut [&mut weight, &mut bias]);
        optimizer.zero_grad(&mut [&mut weight, &mut bias]);

        // Record history
        losses.push(loss.value());
        weight_history.push(weight.value());
        bias_history.push(bias.value());

        // Print detailed monitoring
        if epoch % 5 == 0 || epoch == num_epochs - 1 {
            println!(
                "Epoch {:2}: Loss = {:.6}, Weight = {:.6}, Bias = {:.6}",
                epoch,
                loss.value(),
                weight.value(),
                bias.value()
            );
        }
    }

    // Analyze training progression
    println!("\nTraining Analysis:");
    println!("  Initial loss: {:.6}", losses[0]);
    println!("  Final loss: {:.6}", losses[losses.len() - 1]);
    println!(
        "  Loss reduction: {:.1}%",
        (losses[0] - losses[losses.len() - 1]) / losses[0] * 100.0
    );

    // Compute statistics
    let loss_mean = compute_mean(&losses);
    let loss_std = compute_std(&losses);
    let weight_change = (weight_history[weight_history.len() - 1] - weight_history[0]).abs();
    let bias_change = (bias_history[bias_history.len() - 1] - bias_history[0]).abs();

    println!("  Average loss: {:.6} ± {:.6}", loss_mean, loss_std);
    println!("  Weight change: {:.6}", weight_change);
    println!("  Bias change: {:.6}", bias_change);
    println!("  Final weight norm: {:.6}", weight.norm().value());
    println!("  Final bias: {:.6}", bias.value());

    Ok(())
}

/// Compute the mean of a slice of f32 values
fn compute_mean(values: &[f32]) -> f32 {
    values.iter().sum::<f32>() / values.len() as f32
}

/// Compute the (population) standard deviation of a slice of f32 values
fn compute_std(values: &[f32]) -> f32 {
    let mean = compute_mean(values);
    let variance = values.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / values.len() as f32;
    variance.sqrt()
}
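
/// Simple moving average over a recorded metric (e.g., the per-epoch losses
/// collected above), useful for smoothing noisy curves before judging
/// convergence. Plain Rust, no library-specific API; returns an empty vector
/// when `window` is zero or longer than the input.
#[allow(dead_code)]
fn moving_average(values: &[f32], window: usize) -> Vec<f32> {
    if window == 0 || window > values.len() {
        return Vec::new();
    }
    values
        .windows(window)
        .map(|w| w.iter().sum::<f32>() / window as f32)
        .collect()
}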

#[cfg(test)]
mod tests {
    use super::*;

    /// Test basic optimizer functionality
    #[test]
    fn test_basic_optimizer() {
        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);

        // Simulate a training step
        let mut loss = weight.sum();
        loss.backward(None);
        optimizer.step(&mut [&mut weight]);
        optimizer.zero_grad(&mut [&mut weight]);

        assert_eq!(optimizer.parameter_count(), 1);
        assert!(optimizer.learning_rate() > 0.0);
    }

    /// Test linear regression training
    #[test]
    fn test_linear_regression() {
        let mut weight = Tensor::randn(vec![1, 1], Some(47)).with_requires_grad();
        let mut bias = Tensor::zeros(vec![1]).with_requires_grad();

        let mut optimizer = Adam::with_learning_rate(0.01);
        optimizer.add_parameter(&weight);
        optimizer.add_parameter(&bias);

        let x = Tensor::ones(vec![1, 1]);
        let y_true = Tensor::ones(vec![1, 1]);

        // Single training step
        let y_pred = x.matmul(&weight) + &bias;
        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();

        loss.backward(None);
        optimizer.step(&mut [&mut weight, &mut bias]);
        optimizer.zero_grad(&mut [&mut weight, &mut bias]);

        // Loss should be finite
        assert!(loss.value().is_finite());
    }

    /// Test learning rate scheduling
    #[test]
    fn test_learning_rate_scheduling() {
        let mut optimizer = Adam::with_learning_rate(0.1);
        assert_eq!(optimizer.learning_rate(), 0.1);

        optimizer.set_learning_rate(0.05);
        assert_eq!(optimizer.learning_rate(), 0.05);
    }
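
    /// Test the plain-Rust helpers defined in this example (statistics,
    /// scheduling, smoothing, and the MSE sketch); expected values are
    /// computed by hand.
    #[test]
    fn test_helper_functions() {
        let values = [1.0, 2.0, 3.0, 4.0];
        assert!((compute_mean(&values) - 2.5).abs() < 1e-6);
        // Population standard deviation of [1, 2, 3, 4] is sqrt(1.25)
        assert!((compute_std(&values) - 1.25f32.sqrt()).abs() < 1e-6);

        // Step decay halves the rate every 10 epochs
        assert!((step_decay_lr(0.1, 0.5, 10, 0) - 0.1).abs() < 1e-6);
        assert!((step_decay_lr(0.1, 0.5, 10, 10) - 0.05).abs() < 1e-6);
        assert!((step_decay_lr(0.1, 0.5, 10, 25) - 0.025).abs() < 1e-6);

        // Moving average with a window of 2
        assert_eq!(moving_average(&values, 2), vec![1.5, 2.5, 3.5]);

        // MSE between all-ones predictions and all-zeros targets is 1.0
        let pred = Tensor::ones(vec![3, 1]);
        let target = Tensor::zeros(vec![3, 1]);
        assert!((mse_loss(&pred, &target).value() - 1.0).abs() < 1e-6);
    }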
}