// train_station/optimizers/mod.rs

//! High-performance optimization algorithms for machine learning training
//!
//! This module provides a comprehensive suite of optimization algorithms designed for
//! maximum performance and compatibility with modern machine learning workflows. All
//! optimizers are implemented with zero external dependencies and feature SIMD-optimized
//! parameter updates for optimal training performance.
//!
//! # Purpose
//!
//! The optimizer module serves as the core parameter optimization layer for the Train Station
//! machine learning library, providing:
//! - **High-performance implementations**: SIMD-optimized parameter updates with AVX2 support
//! - **PyTorch compatibility**: Familiar interfaces and parameter semantics for easy migration
//! - **GradTrack integration**: Seamless integration with the automatic differentiation system
//! - **Memory efficiency**: Optimized state management with minimal memory overhead
//! - **Thread safety**: All optimizers are thread-safe and support concurrent training
//! - **Serialization support**: Complete state serialization for model checkpointing
//!
//! # Supported Optimizers
//!
//! ## Adam Optimizer
//! - **Adaptive learning rates**: Per-parameter adaptive learning rate adjustment
//! - **Momentum**: First and second moment estimation for stable convergence
//! - **Bias correction**: Proper bias correction for early training stability
//! - **AMSGrad variant**: Optional AMSGrad variant for improved convergence
//! - **Weight decay**: L2 regularization support for model regularization
//! - **SIMD optimization**: AVX2-optimized parameter updates for maximum performance
//!
//! # Design Philosophy
//!
//! ## Performance First
//! - **SIMD optimization**: All parameter updates use vectorized operations when available
//! - **Memory efficiency**: Minimal memory overhead with optimized state storage
//! - **Zero allocations**: Hot paths avoid memory allocations for maximum performance
//! - **Cache-friendly**: Memory access patterns optimized for CPU cache efficiency
//!
//! ## PyTorch Compatibility
//! - **Familiar interfaces**: Method names and semantics match PyTorch conventions
//! - **Parameter linking**: Explicit parameter registration for type safety
//! - **Learning rate scheduling**: Support for dynamic learning rate adjustment
//! - **State management**: Complete optimizer state serialization and restoration
//!
//! ## Thread Safety
//! - **Concurrent training**: All optimizers support multi-threaded parameter updates
//! - **Exclusive access**: Parameter updates require mutable references for safety
//! - **State isolation**: Each optimizer instance maintains independent state
//! - **Atomic operations**: Thread-safe operations where required
//!
//! # Usage Patterns
//!
//! ## Basic Training Loop
//! ```
//! use train_station::{Tensor, optimizers::{Adam, Optimizer}};
//!
//! // Create model parameters
//! let mut weight = Tensor::randn(vec![10, 5], None).with_requires_grad();
//! let mut bias = Tensor::zeros(vec![10]).with_requires_grad();
//!
//! // Create optimizer and link parameters
//! let mut optimizer = Adam::new();
//! optimizer.add_parameter(&weight);
//! optimizer.add_parameter(&bias);
//!
//! // Training loop
//! for epoch in 0..100 {
//!     // Forward pass (compute loss)
//!     let input = Tensor::randn(vec![5, 3], None);
//!     let output = weight.matmul(&input);
//!     let output_with_bias = output + &bias.unsqueeze(1); // Broadcast bias to [10, 3]
//!     let target = Tensor::randn(vec![10, 3], None);
//!     let mut loss = (output_with_bias - &target).pow_scalar(2.0).sum();
//!
//!     // Backward pass
//!     optimizer.zero_grad(&mut [&mut weight, &mut bias]);
//!     loss.backward(None);
//!
//!     // Parameter update
//!     optimizer.step(&mut [&mut weight, &mut bias]);
//! }
//! ```
//!
//! ## Custom Configuration
//! ```
//! use train_station::optimizers::{Adam, AdamConfig, Optimizer};
//!
//! // Create custom configuration
//! let config = AdamConfig {
//!     learning_rate: 0.001,
//!     beta1: 0.9,
//!     beta2: 0.999,
//!     eps: 1e-8,
//!     weight_decay: 0.01,
//!     amsgrad: false,
//! };
//!
//! // Create optimizer with custom configuration
//! let mut optimizer = Adam::with_config(config);
//! ```
//!
//! ## State Serialization
//! ```
//! use train_station::{Adam, Optimizer};
//! use train_station::serialization::{Serializable, Format};
//!
//! let mut optimizer = Adam::new();
//! // ... training ...
//!
//! // Save optimizer state
//! optimizer.save("optimizer.json", Format::Json).unwrap();
//!
//! // Load optimizer state
//! let mut loaded_optimizer = Adam::load("optimizer.json", Format::Json).unwrap();
//! ```
//!
//! # Performance Characteristics
//!
//! ## SIMD Optimization
//! - **AVX2 support**: Vectorized operations on x86_64 with AVX2 support
//! - **Fallback paths**: Optimized scalar implementations for non-SIMD hardware
//! - **Automatic detection**: Runtime CPU feature detection for optimal performance
//! - **Memory alignment**: Proper memory alignment for vectorized operations
//!
//! ## Memory Efficiency
//! - **Minimal overhead**: Optimized state storage with minimal memory footprint
//! - **Lazy allocation**: State allocated only when parameters are linked
//! - **Memory reuse**: Efficient memory reuse patterns to minimize allocations
//! - **Cache optimization**: Memory access patterns optimized for CPU cache
//!
//! ## Scalability
//! - **Large models**: Efficient handling of models with millions of parameters
//! - **Batch processing**: Optimized for typical machine learning batch sizes
//! - **Concurrent training**: Thread-safe operations for parallel training
//! - **Memory scaling**: Linear memory scaling with parameter count
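//!
//! As a rough illustration of linear memory scaling: Adam-style optimizers keep
//! per-parameter moment buffers (first and second moment, plus a third buffer with
//! AMSGrad), so state size can be estimated directly from the parameter count. A
//! sketch, assuming `f32` moment buffers (`adam_state_bytes` is a hypothetical
//! helper, not part of this crate):
//!
//! ```
//! /// Estimated optimizer-state bytes for Adam with `f32` moment buffers.
//! fn adam_state_bytes(param_count: usize, amsgrad: bool) -> usize {
//!     // m and v buffers, plus the running max of v when AMSGrad is enabled
//!     let buffers = if amsgrad { 3 } else { 2 };
//!     param_count * buffers * std::mem::size_of::<f32>()
//! }
//!
//! // A 1M-parameter model needs roughly 8 MB of Adam state.
//! assert_eq!(adam_state_bytes(1_000_000, false), 8_000_000);
//! ```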
//!
//! # Thread Safety
//!
//! All optimizers in this module are designed to be thread-safe:
//!
//! - **Exclusive access**: Parameter updates require mutable references
//! - **State isolation**: Each optimizer instance maintains independent state
//! - **Concurrent safe**: Multiple optimizers can run concurrently on different parameters
//! - **Atomic operations**: Thread-safe operations where required for correctness
//!
//! # Integration with GradTrack
//!
//! The optimizers integrate seamlessly with the GradTrack automatic differentiation system:
//!
//! - **Gradient access**: Automatic access to computed gradients from tensors
//! - **Gradient clearing**: Efficient gradient clearing before backward passes
//! - **Computation graph**: Proper integration with the computation graph system
//! - **Memory management**: Efficient gradient memory management during optimization

mod adam;

pub use adam::{Adam, AdamConfig};

/// Universal trait for parameter optimization algorithms
///
/// This trait provides a unified interface for all optimization algorithms in the Train Station
/// library, ensuring consistent behavior and API compatibility across different optimizers.
/// The trait follows PyTorch conventions for familiar usage patterns while providing
/// high-performance implementations optimized for the Train Station ecosystem.
///
/// # Design Principles
///
/// The Optimizer trait is designed around several key principles:
///
/// ## Type Safety
/// - **Parameter linking**: Explicit parameter registration prevents runtime errors
/// - **Mutable references**: Parameter updates require exclusive access for thread safety
/// - **Compile-time guarantees**: Type system ensures correct usage patterns
/// - **Memory safety**: All operations are memory-safe with proper lifetime management
///
/// ## Performance
/// - **Zero-cost abstractions**: Trait methods compile to direct function calls
/// - **SIMD optimization**: Implementations use vectorized operations when available
/// - **Memory efficiency**: Minimal overhead with optimized state management
/// - **Cache-friendly**: Memory access patterns optimized for CPU cache performance
///
/// ## PyTorch Compatibility
/// - **Familiar methods**: Method names and semantics match PyTorch conventions
/// - **Parameter management**: Similar parameter linking and state management
/// - **Learning rate control**: Dynamic learning rate adjustment support
/// - **Training workflows**: Compatible with standard training loop patterns
///
/// # Required Methods
///
/// All optimizers must implement these core methods:
///
/// * `step()` - Perform parameter updates based on current gradients
/// * `zero_grad()` - Clear accumulated gradients before backward pass
/// * `learning_rate()` - Get current learning rate for monitoring
/// * `set_learning_rate()` - Update learning rate for scheduling
///
/// # Usage Patterns
///
/// ## Basic Usage
/// ```
/// use train_station::{Tensor, optimizers::{Adam, Optimizer}};
///
/// // Create parameters and optimizer
/// let mut param = Tensor::randn(vec![10, 10], None).with_requires_grad();
/// let mut optimizer = Adam::new();
/// optimizer.add_parameter(&param);
///
/// // Training step
/// optimizer.zero_grad(&mut [&mut param]);
/// // ... forward pass and loss computation ...
/// // loss.backward(None);
/// optimizer.step(&mut [&mut param]);
/// ```
///
/// ## Learning Rate Scheduling
/// ```
/// use train_station::optimizers::{Adam, Optimizer};
///
/// let mut optimizer = Adam::new();
/// // ... parameter setup ...
///
/// for epoch in 0..100 {
///     // Decay learning rate every 10 epochs
///     if epoch % 10 == 0 {
///         let current_lr = optimizer.learning_rate();
///         optimizer.set_learning_rate(current_lr * 0.9);
///     }
///
///     // Training step
///     // ... training logic ...
/// }
/// ```
///
/// # Thread Safety
///
/// All optimizer implementations are required to be thread-safe:
///
/// - **Send + Sync**: Optimizers can be moved between threads and shared safely
/// - **Exclusive access**: Parameter updates require mutable references
/// - **State isolation**: Each optimizer instance maintains independent state
/// - **Concurrent training**: Multiple optimizers can run concurrently
///
/// # Performance Characteristics
///
/// Optimizer implementations are expected to provide:
///
/// - **O(n) complexity**: Linear time complexity with parameter count
/// - **Minimal allocations**: Avoid memory allocations in hot paths
/// - **SIMD optimization**: Use vectorized operations when available
/// - **Cache efficiency**: Optimize memory access patterns for CPU cache
///
/// # Implementors
///
/// Current optimizer implementations:
///
/// * `Adam` - Adaptive Moment Estimation with momentum and bias correction
///
/// Future implementations may include:
/// * SGD - Stochastic Gradient Descent with momentum
/// * RMSprop - Root Mean Square Propagation
/// * AdamW - Adam with decoupled weight decay
pub trait Optimizer {
    /// Perform a single optimization step to update parameters
    ///
    /// This method performs the core optimization algorithm, updating all provided parameters
    /// based on their current gradients. The specific update rule depends on the optimizer
    /// implementation (Adam, SGD, etc.). Parameters must be linked to the optimizer before
    /// calling this method to ensure proper state management.
    ///
    /// # Arguments
    ///
    /// * `parameters` - Mutable slice of parameter tensor references to update
    ///
    /// # Behavior
    ///
    /// The method performs these operations:
    /// 1. **Gradient validation**: Ensures all parameters have computed gradients
    /// 2. **State update**: Updates internal optimizer state (momentum, velocity, etc.)
    /// 3. **Parameter update**: Applies the optimization algorithm to update parameter values
    /// 4. **Bias correction**: Applies bias correction if required by the algorithm
    ///
    /// # Requirements
    ///
    /// - **Parameter linking**: All parameters must be linked via `add_parameter()`
    /// - **Gradient computation**: Parameters must have gradients from `backward()` call
    /// - **Exclusive access**: Requires mutable references for thread-safe updates
    /// - **Consistent state**: Optimizer state must be consistent with parameter count
    ///
    /// # Performance
    ///
    /// - **SIMD optimization**: Uses vectorized operations when available
    /// - **Memory efficiency**: Minimizes memory allocations during updates
    /// - **Cache-friendly**: Optimized memory access patterns for performance
    /// - **Linear complexity**: O(n) time complexity with parameter count
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::{Tensor, optimizers::{Adam, Optimizer}};
    ///
    /// let mut param = Tensor::randn(vec![10, 10], None).with_requires_grad();
    /// let mut optimizer = Adam::new();
    /// optimizer.add_parameter(&param);
    ///
    /// // After forward pass and backward pass
    /// optimizer.step(&mut [&mut param]);
    /// ```
    #[track_caller]
    fn step(&mut self, parameters: &mut [&mut crate::tensor::core::Tensor]);

    /// Clear accumulated gradients for all parameters
    ///
    /// This method resets all parameter gradients to zero, preparing for a new backward pass.
    /// It should be called before each backward pass to prevent gradient accumulation across
    /// multiple forward/backward cycles. This is essential for correct training behavior as
    /// gradients accumulate by default in the GradTrack system.
    ///
    /// # Arguments
    ///
    /// * `parameters` - Mutable slice of parameter tensor references to clear gradients for
    ///
    /// # Behavior
    ///
    /// The method performs these operations:
    /// 1. **Gradient clearing**: Sets all parameter gradients to zero
    /// 2. **Memory management**: Efficiently manages gradient memory allocation
    /// 3. **State consistency**: Maintains consistent gradient state across parameters
    /// 4. **GradTrack integration**: Properly integrates with the automatic differentiation system
    ///
    /// # Usage Pattern
    ///
    /// This method should be called at the beginning of each training iteration:
    /// 1. **Clear gradients**: Call `zero_grad()` to reset gradients
    /// 2. **Forward pass**: Compute model output and loss
    /// 3. **Backward pass**: Call `loss.backward()` to compute gradients
    /// 4. **Parameter update**: Call `step()` to update parameters
    ///
    /// # Performance
    ///
    /// - **Efficient clearing**: Optimized gradient clearing with minimal overhead
    /// - **Memory reuse**: Reuses existing gradient memory when possible
    /// - **SIMD optimization**: Uses vectorized operations for large parameter tensors
    /// - **Linear complexity**: O(n) time complexity with total parameter count
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::{Tensor, optimizers::{Adam, Optimizer}};
    ///
    /// let mut param = Tensor::randn(vec![10, 10], None).with_requires_grad();
    /// let mut optimizer = Adam::new();
    /// optimizer.add_parameter(&param);
    ///
    /// // Training iteration
    /// optimizer.zero_grad(&mut [&mut param]);  // Clear gradients
    /// // ... forward pass and loss computation ...
    /// // loss.backward(None);                   // Compute gradients
    /// optimizer.step(&mut [&mut param]);       // Update parameters
    /// ```
    ///
    /// # Integration with GradTrack
    ///
    /// The method integrates seamlessly with the GradTrack automatic differentiation system:
    /// - **Gradient storage**: Clears gradients stored in tensor gradient fields
    /// - **Computation graph**: Maintains proper computation graph state
    /// - **Memory efficiency**: Efficiently manages gradient memory allocation
    #[track_caller]
    fn zero_grad(&mut self, parameters: &mut [&mut crate::tensor::core::Tensor]);

    /// Get the current learning rate for monitoring and scheduling
    ///
    /// This method returns the current learning rate used by the optimizer for parameter
    /// updates. For optimizers with adaptive learning rates, this returns the base learning
    /// rate that is modified by the adaptive algorithm. This method is essential for
    /// learning rate monitoring and implementing learning rate scheduling strategies.
    ///
    /// # Returns
    ///
    /// The current learning rate as a 32-bit floating-point value
    ///
    /// # Behavior
    ///
    /// The returned value represents:
    /// - **Base learning rate**: The configured learning rate for the optimizer
    /// - **Current rate**: The learning rate currently being used for updates
    /// - **Scheduling support**: The rate that can be modified by learning rate schedulers
    /// - **Monitoring value**: The rate that should be logged for training monitoring
    ///
    /// # Usage Patterns
    ///
    /// ## Learning Rate Monitoring
    /// ```
    /// use train_station::optimizers::{Adam, Optimizer};
    ///
    /// let optimizer = Adam::new();
    /// println!("Current learning rate: {}", optimizer.learning_rate());
    /// ```
    ///
    /// ## Learning Rate Scheduling
    /// ```
    /// use train_station::optimizers::{Adam, Optimizer};
    ///
    /// let mut optimizer = Adam::new();
    ///
    /// for epoch in 0..100 {
    ///     // Exponential decay every 10 epochs
    ///     if epoch % 10 == 0 && epoch > 0 {
    ///         let current_lr = optimizer.learning_rate();
    ///         optimizer.set_learning_rate(current_lr * 0.9);
    ///     }
    /// }
    /// ```
    ///
    /// ## Training Loop Integration
    /// ```
    /// use train_station::optimizers::{Adam, Optimizer};
    ///
    /// let mut optimizer = Adam::new();
    ///
    /// // Training loop with learning rate logging
    /// for epoch in 0..100 {
    ///     let lr = optimizer.learning_rate();
    ///     println!("Epoch {}: Learning rate = {:.6}", epoch, lr);
    ///
    ///     // ... training logic ...
    /// }
    /// ```
    ///
    /// # Performance
    ///
    /// - **Constant time**: O(1) time complexity for learning rate retrieval
    /// - **No allocations**: No memory allocations during learning rate access
    /// - **Minimal overhead**: Negligible performance impact for monitoring
    ///
    /// # Thread Safety
    ///
    /// This method is thread-safe and can be called concurrently with other read operations.
    /// It does not modify optimizer state and can be safely used for monitoring in
    /// multi-threaded training scenarios.
    #[track_caller]
    fn learning_rate(&self) -> f32;

    /// Update the learning rate for dynamic scheduling and adjustment
    ///
    /// This method updates the learning rate used by the optimizer for parameter updates.
    /// It enables dynamic learning rate adjustment during training, which is essential
    /// for implementing learning rate scheduling strategies, adaptive training, and
    /// fine-tuning workflows. The new learning rate takes effect immediately for
    /// subsequent parameter updates.
    ///
    /// # Arguments
    ///
    /// * `lr` - The new learning rate value (must be positive for meaningful optimization)
    ///
    /// # Behavior
    ///
    /// The method performs these operations:
    /// 1. **Rate validation**: Ensures the learning rate is a valid positive value
    /// 2. **State update**: Updates internal optimizer configuration with new rate
    /// 3. **Immediate effect**: New rate applies to subsequent `step()` calls
    /// 4. **Consistency**: Maintains optimizer state consistency across all parameters
    ///
    /// # Learning Rate Scheduling
    ///
    /// Common scheduling patterns supported:
    /// - **Exponential decay**: Multiply by decay factor periodically
    /// - **Step decay**: Reduce by fixed amount at specific epochs
    /// - **Cosine annealing**: Smooth cosine-based learning rate schedule
    /// - **Adaptive adjustment**: Dynamic adjustment based on training metrics
    ///
    /// # Usage Patterns
    ///
    /// ## Exponential Decay Scheduling
    /// ```
    /// use train_station::optimizers::{Adam, Optimizer};
    ///
    /// let mut optimizer = Adam::new();
    ///
    /// for epoch in 0..100 {
    ///     // Decay learning rate every 10 epochs
    ///     if epoch % 10 == 0 && epoch > 0 {
    ///         let current_lr = optimizer.learning_rate();
    ///         optimizer.set_learning_rate(current_lr * 0.95);
    ///     }
    ///
    ///     // ... training logic ...
    /// }
    /// ```
    ///
    /// ## Step-based Scheduling
    /// ```
    /// use train_station::optimizers::{Adam, Optimizer};
    ///
    /// let mut optimizer = Adam::new();
    /// let initial_lr = 0.001;
    ///
    /// for epoch in 0..100 {
    ///     // Step decay at specific epochs
    ///     let lr = match epoch {
    ///         0..=29 => initial_lr,
    ///         30..=59 => initial_lr * 0.1,
    ///         60..=89 => initial_lr * 0.01,
    ///         _ => initial_lr * 0.001,
    ///     };
    ///     optimizer.set_learning_rate(lr);
    ///
    ///     // ... training logic ...
    /// }
    /// ```
509    /// ## Adaptive Adjustment
510    /// ```
511    /// use train_station::optimizers::{Adam, Optimizer};
512    ///
513    /// let mut optimizer = Adam::new();
514    /// let mut best_loss = f32::INFINITY;
515    /// let mut patience = 0;
516    ///
517    /// for epoch in 0..100 {
518    ///     // ... training and validation ...
519    ///     let current_loss = 0.5; // Example validation loss
520    ///     
521    ///     if current_loss < best_loss {
522    ///         best_loss = current_loss;
523    ///         patience = 0;
524    ///     } else {
525    ///         patience += 1;
526    ///         if patience >= 5 {
527    ///             // Reduce learning rate when loss plateaus
528    ///             let current_lr = optimizer.learning_rate();
529    ///             optimizer.set_learning_rate(current_lr * 0.5);
530    ///             patience = 0;
531    ///         }
532    ///     }
533    /// }
534    /// ```
535    ///
536    /// # Performance
537    ///
538    /// - **Constant time**: O(1) time complexity for learning rate updates
539    /// - **No allocations**: No memory allocations during rate updates
540    /// - **Immediate effect**: Changes take effect for next parameter update
541    /// - **Minimal overhead**: Negligible performance impact on training
542    ///
543    /// # Thread Safety
544    ///
545    /// This method requires exclusive access (`&mut self`) and is thread-safe when
546    /// used with proper synchronization. Multiple threads should not modify the
547    /// learning rate concurrently without external synchronization.
548    ///
549    /// # Validation
550    ///
551    /// While the trait does not enforce validation, implementations should:
552    /// - Accept positive learning rates for normal optimization
553    /// - Handle zero learning rate (effectively disables updates)
554    /// - Consider very large rates that may cause numerical instability
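    ///
    /// A caller-side guard for these cases might look like the following sketch
    /// (`validated_lr` is a hypothetical helper, not part of the trait):
    ///
    /// ```
    /// // Reject NaN, infinity, and negative rates; zero is allowed but
    /// // effectively disables parameter updates.
    /// fn validated_lr(lr: f32) -> Result<f32, String> {
    ///     if lr.is_finite() && lr >= 0.0 {
    ///         Ok(lr)
    ///     } else {
    ///         Err(format!("invalid learning rate: {}", lr))
    ///     }
    /// }
    ///
    /// assert!(validated_lr(1e-3).is_ok());
    /// assert!(validated_lr(-1.0).is_err());
    /// assert!(validated_lr(f32::NAN).is_err());
    /// ```
    ///
    /// The `Ok` value would then be passed on to `set_learning_rate`.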
555    #[track_caller]
556    fn set_learning_rate(&mut self, lr: f32);
557}