// train_station/optimizers/mod.rs
//! High-performance optimization algorithms for machine learning training
//!
//! This module provides a comprehensive suite of optimization algorithms designed for
//! maximum performance and compatibility with modern machine learning workflows. All
//! optimizers are implemented with zero external dependencies and feature SIMD-optimized
//! parameter updates for optimal training performance.
//!
//! # Purpose
//!
//! The optimizer module serves as the core parameter optimization layer for the Train Station
//! machine learning library, providing:
//! - **High-performance implementations**: SIMD-optimized parameter updates with AVX2 support
//! - **PyTorch compatibility**: Familiar interfaces and parameter semantics for easy migration
//! - **GradTrack integration**: Seamless integration with the automatic differentiation system
//! - **Memory efficiency**: Optimized state management with minimal memory overhead
//! - **Thread safety**: All optimizers are thread-safe and support concurrent training
//! - **Serialization support**: Complete state serialization for model checkpointing
//!
//! # Supported Optimizers
//!
//! ## Adam Optimizer
//! - **Adaptive learning rates**: Per-parameter adaptive learning rate adjustment
//! - **Momentum**: First and second moment estimation for stable convergence
//! - **Bias correction**: Proper bias correction for early training stability
//! - **AMSGrad variant**: Optional AMSGrad variant for improved convergence
//! - **Weight decay**: L2 regularization support for model regularization
//! - **SIMD optimization**: AVX2-optimized parameter updates for maximum performance
//!
//! # Design Philosophy
//!
//! ## Performance First
//! - **SIMD optimization**: All parameter updates use vectorized operations when available
//! - **Memory efficiency**: Minimal memory overhead with optimized state storage
//! - **Zero allocations**: Hot paths avoid memory allocations for maximum performance
//! - **Cache-friendly**: Memory access patterns optimized for CPU cache efficiency
//!
//! ## PyTorch Compatibility
//! - **Familiar interfaces**: Method names and semantics match PyTorch conventions
//! - **Parameter linking**: Explicit parameter registration for type safety
//! - **Learning rate scheduling**: Support for dynamic learning rate adjustment
//! - **State management**: Complete optimizer state serialization and restoration
//!
//! ## Thread Safety
//! - **Concurrent training**: All optimizers support multi-threaded parameter updates
//! - **Exclusive access**: Parameter updates require mutable references for safety
//! - **State isolation**: Each optimizer instance maintains independent state
//! - **Atomic operations**: Thread-safe operations where required
//!
//! # Usage Patterns
//!
//! ## Basic Training Loop
//! ```
//! use train_station::{Tensor, optimizers::{Adam, Optimizer}};
//!
//! // Create model parameters
//! let mut weight = Tensor::randn(vec![10, 5], None).with_requires_grad();
//! let mut bias = Tensor::zeros(vec![10]).with_requires_grad();
//!
//! // Create optimizer and link parameters
//! let mut optimizer = Adam::new();
//! optimizer.add_parameter(&weight);
//! optimizer.add_parameter(&bias);
//!
//! // Training loop
//! for epoch in 0..100 {
//!     // Forward pass (compute loss)
//!     let input = Tensor::randn(vec![5, 3], None);
//!     let output = weight.matmul(&input);
//!     let output_with_bias = output + &bias.unsqueeze(1); // Broadcast bias to [10, 3]
//!     let target = Tensor::randn(vec![10, 3], None);
//!     let mut loss = (output_with_bias - &target).pow_scalar(2.0).sum();
//!
//!     // Backward pass
//!     optimizer.zero_grad(&mut [&mut weight, &mut bias]);
//!     loss.backward(None);
//!
//!     // Parameter update
//!     optimizer.step(&mut [&mut weight, &mut bias]);
//! }
//! ```
//!
//! ## Custom Configuration
//! ```
//! use train_station::optimizers::{Adam, AdamConfig, Optimizer};
//!
//! // Create custom configuration
//! let config = AdamConfig {
//!     learning_rate: 0.001,
//!     beta1: 0.9,
//!     beta2: 0.999,
//!     eps: 1e-8,
//!     weight_decay: 0.01,
//!     amsgrad: false,
//! };
//!
//! // Create optimizer with custom configuration
//! let mut optimizer = Adam::with_config(config);
//! ```
//!
//! ## State Serialization
//! ```
//! use train_station::optimizers::{Adam, Optimizer};
//! use train_station::serialization::{Serializable, Format};
//!
//! let mut optimizer = Adam::new();
//! // ... training ...
//!
//! // Save optimizer state
//! optimizer.save("optimizer.json", Format::Json).unwrap();
//!
//! // Load optimizer state
//! let mut loaded_optimizer = Adam::load("optimizer.json", Format::Json).unwrap();
//! ```
//!
//! # Performance Characteristics
//!
//! ## SIMD Optimization
//! - **AVX2 support**: Vectorized operations on x86_64 with AVX2 support
//! - **Fallback paths**: Optimized scalar implementations for non-SIMD hardware
//! - **Automatic detection**: Runtime CPU feature detection for optimal performance
//! - **Memory alignment**: Proper memory alignment for vectorized operations
//!
//! ## Memory Efficiency
//! - **Minimal overhead**: Optimized state storage with minimal memory footprint
//! - **Lazy allocation**: State allocated only when parameters are linked
//! - **Memory reuse**: Efficient memory reuse patterns to minimize allocations
//! - **Cache optimization**: Memory access patterns optimized for CPU cache
//!
//! ## Scalability
//! - **Large models**: Efficient handling of models with millions of parameters
//! - **Batch processing**: Optimized for typical machine learning batch sizes
//! - **Concurrent training**: Thread-safe operations for parallel training
//! - **Memory scaling**: Linear memory scaling with parameter count
//!
//! # Thread Safety
//!
//! All optimizers in this module are designed to be thread-safe:
//!
//! - **Exclusive access**: Parameter updates require mutable references
//! - **State isolation**: Each optimizer instance maintains independent state
//! - **Concurrent safe**: Multiple optimizers can run concurrently on different parameters
//! - **Atomic operations**: Thread-safe operations where required for correctness
//!
//! # Integration with GradTrack
//!
//! The optimizers integrate seamlessly with the GradTrack automatic differentiation system:
//!
//! - **Gradient access**: Automatic access to computed gradients from tensors
//! - **Gradient clearing**: Efficient gradient clearing before backward passes
//! - **Computation graph**: Proper integration with the computation graph system
//! - **Memory management**: Efficient gradient memory management during optimization

153mod adam;
154
155pub use adam::{Adam, AdamConfig};
156
/// Universal trait for parameter optimization algorithms
///
/// This trait provides a unified interface for all optimization algorithms in the Train Station
/// library, ensuring consistent behavior and API compatibility across different optimizers.
/// The trait follows PyTorch conventions for familiar usage patterns while providing
/// high-performance implementations optimized for the Train Station ecosystem.
///
/// # Design Principles
///
/// The Optimizer trait is designed around several key principles:
///
/// ## Type Safety
/// - **Parameter linking**: Explicit parameter registration prevents runtime errors
/// - **Mutable references**: Parameter updates require exclusive access for thread safety
/// - **Compile-time guarantees**: Type system ensures correct usage patterns
/// - **Memory safety**: All operations are memory-safe with proper lifetime management
///
/// ## Performance
/// - **Zero-cost abstractions**: Trait methods compile to direct function calls
/// - **SIMD optimization**: Implementations use vectorized operations when available
/// - **Memory efficiency**: Minimal overhead with optimized state management
/// - **Cache-friendly**: Memory access patterns optimized for CPU cache performance
///
/// ## PyTorch Compatibility
/// - **Familiar methods**: Method names and semantics match PyTorch conventions
/// - **Parameter management**: Similar parameter linking and state management
/// - **Learning rate control**: Dynamic learning rate adjustment support
/// - **Training workflows**: Compatible with standard training loop patterns
///
/// # Required Methods
///
/// All optimizers must implement these core methods:
///
/// * `step()` - Perform parameter updates based on current gradients
/// * `zero_grad()` - Clear accumulated gradients before backward pass
/// * `learning_rate()` - Get current learning rate for monitoring
/// * `set_learning_rate()` - Update learning rate for scheduling
///
/// # Usage Patterns
///
/// ## Basic Usage
/// ```
/// use train_station::{Tensor, optimizers::{Adam, Optimizer}};
///
/// // Create parameters and optimizer
/// let mut param = Tensor::randn(vec![10, 10], None).with_requires_grad();
/// let mut optimizer = Adam::new();
/// optimizer.add_parameter(&param);
///
/// // Training step
/// optimizer.zero_grad(&mut [&mut param]);
/// // ... forward pass and loss computation ...
/// // loss.backward(None);
/// optimizer.step(&mut [&mut param]);
/// ```
///
/// ## Learning Rate Scheduling
/// ```
/// use train_station::optimizers::{Adam, Optimizer};
///
/// let mut optimizer = Adam::new();
/// // ... parameter setup ...
///
/// for epoch in 0..100 {
///     // Decay learning rate every 10 epochs
///     if epoch % 10 == 0 {
///         let current_lr = optimizer.learning_rate();
///         optimizer.set_learning_rate(current_lr * 0.9);
///     }
///
///     // Training step
///     // ... training logic ...
/// }
/// ```
///
/// # Thread Safety
///
/// All optimizer implementations are required to be thread-safe:
///
/// - **Send + Sync**: Optimizers can be moved between threads and shared safely
/// - **Exclusive access**: Parameter updates require mutable references
/// - **State isolation**: Each optimizer instance maintains independent state
/// - **Concurrent training**: Multiple optimizers can run concurrently
///
/// # Performance Characteristics
///
/// Optimizer implementations are expected to provide:
///
/// - **O(n) complexity**: Linear time complexity with parameter count
/// - **Minimal allocations**: Avoid memory allocations in hot paths
/// - **SIMD optimization**: Use vectorized operations when available
/// - **Cache efficiency**: Optimize memory access patterns for CPU cache
///
/// # Implementors
///
/// Current optimizer implementations:
///
/// * `Adam` - Adaptive Moment Estimation with momentum and bias correction
///
/// Future implementations may include:
/// * SGD - Stochastic Gradient Descent with momentum
/// * RMSprop - Root Mean Square Propagation
/// * AdamW - Adam with decoupled weight decay
pub trait Optimizer {
    /// Perform a single optimization step to update parameters
    ///
    /// This method performs the core optimization algorithm, updating all provided parameters
    /// based on their current gradients. The specific update rule depends on the optimizer
    /// implementation (Adam, SGD, etc.). Parameters must be linked to the optimizer before
    /// calling this method to ensure proper state management.
    ///
    /// # Arguments
    ///
    /// * `parameters` - Mutable slice of parameter tensor references to update
    ///
    /// # Behavior
    ///
    /// The method performs these operations:
    /// 1. **Gradient validation**: Ensures all parameters have computed gradients
    /// 2. **State update**: Updates internal optimizer state (momentum, velocity, etc.)
    /// 3. **Parameter update**: Applies the optimization algorithm to update parameter values
    /// 4. **Bias correction**: Applies bias correction if required by the algorithm
    ///
    /// # Requirements
    ///
    /// - **Parameter linking**: All parameters must be linked via `add_parameter()`
    /// - **Gradient computation**: Parameters must have gradients from `backward()` call
    /// - **Exclusive access**: Requires mutable references for thread-safe updates
    /// - **Consistent state**: Optimizer state must be consistent with parameter count
    ///
    /// # Performance
    ///
    /// - **SIMD optimization**: Uses vectorized operations when available
    /// - **Memory efficiency**: Minimizes memory allocations during updates
    /// - **Cache-friendly**: Optimized memory access patterns for performance
    /// - **Linear complexity**: O(n) time complexity with parameter count
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::{Tensor, optimizers::{Adam, Optimizer}};
    ///
    /// let mut param = Tensor::randn(vec![10, 10], None).with_requires_grad();
    /// let mut optimizer = Adam::new();
    /// optimizer.add_parameter(&param);
    ///
    /// // After forward pass and backward pass
    /// optimizer.step(&mut [&mut param]);
    /// ```
    #[track_caller]
    fn step(&mut self, parameters: &mut [&mut crate::tensor::core::Tensor]);

    /// Clear accumulated gradients for all parameters
    ///
    /// This method resets all parameter gradients to zero, preparing for a new backward pass.
    /// It should be called before each backward pass to prevent gradient accumulation across
    /// multiple forward/backward cycles. This is essential for correct training behavior as
    /// gradients accumulate by default in the GradTrack system.
    ///
    /// # Arguments
    ///
    /// * `parameters` - Mutable slice of parameter tensor references to clear gradients for
    ///
    /// # Behavior
    ///
    /// The method performs these operations:
    /// 1. **Gradient clearing**: Sets all parameter gradients to zero
    /// 2. **Memory management**: Efficiently manages gradient memory allocation
    /// 3. **State consistency**: Maintains consistent gradient state across parameters
    /// 4. **GradTrack integration**: Properly integrates with the automatic differentiation system
    ///
    /// # Usage Pattern
    ///
    /// This method should be called at the beginning of each training iteration:
    /// 1. **Clear gradients**: Call `zero_grad()` to reset gradients
    /// 2. **Forward pass**: Compute model output and loss
    /// 3. **Backward pass**: Call `loss.backward()` to compute gradients
    /// 4. **Parameter update**: Call `step()` to update parameters
    ///
    /// # Performance
    ///
    /// - **Efficient clearing**: Optimized gradient clearing with minimal overhead
    /// - **Memory reuse**: Reuses existing gradient memory when possible
    /// - **SIMD optimization**: Uses vectorized operations for large parameter tensors
    /// - **Linear complexity**: O(n) time complexity with total parameter count
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::{Tensor, optimizers::{Adam, Optimizer}};
    ///
    /// let mut param = Tensor::randn(vec![10, 10], None).with_requires_grad();
    /// let mut optimizer = Adam::new();
    /// optimizer.add_parameter(&param);
    ///
    /// // Training iteration
    /// optimizer.zero_grad(&mut [&mut param]); // Clear gradients
    /// // ... forward pass and loss computation ...
    /// // loss.backward(None); // Compute gradients
    /// optimizer.step(&mut [&mut param]); // Update parameters
    /// ```
    ///
    /// # Integration with GradTrack
    ///
    /// The method integrates seamlessly with the GradTrack automatic differentiation system:
    /// - **Gradient storage**: Clears gradients stored in tensor gradient fields
    /// - **Computation graph**: Maintains proper computation graph state
    /// - **Memory efficiency**: Efficiently manages gradient memory allocation
    #[track_caller]
    fn zero_grad(&mut self, parameters: &mut [&mut crate::tensor::core::Tensor]);

    /// Get the current learning rate for monitoring and scheduling
    ///
    /// This method returns the current learning rate used by the optimizer for parameter
    /// updates. For optimizers with adaptive learning rates, this returns the base learning
    /// rate that is modified by the adaptive algorithm. This method is essential for
    /// learning rate monitoring and implementing learning rate scheduling strategies.
    ///
    /// # Returns
    ///
    /// The current learning rate as a 32-bit floating-point value
    ///
    /// # Behavior
    ///
    /// The returned value represents:
    /// - **Base learning rate**: The configured learning rate for the optimizer
    /// - **Current rate**: The learning rate currently being used for updates
    /// - **Scheduling support**: The rate that can be modified by learning rate schedulers
    /// - **Monitoring value**: The rate that should be logged for training monitoring
    ///
    /// # Usage Patterns
    ///
    /// ## Learning Rate Monitoring
    /// ```
    /// use train_station::optimizers::{Adam, Optimizer};
    ///
    /// let optimizer = Adam::new();
    /// println!("Current learning rate: {}", optimizer.learning_rate());
    /// ```
    ///
    /// ## Learning Rate Scheduling
    /// ```
    /// use train_station::optimizers::{Adam, Optimizer};
    ///
    /// let mut optimizer = Adam::new();
    ///
    /// for epoch in 0..100 {
    ///     // Exponential decay every 10 epochs
    ///     if epoch % 10 == 0 && epoch > 0 {
    ///         let current_lr = optimizer.learning_rate();
    ///         optimizer.set_learning_rate(current_lr * 0.9);
    ///     }
    /// }
    /// ```
    ///
    /// ## Training Loop Integration
    /// ```
    /// use train_station::optimizers::{Adam, Optimizer};
    ///
    /// let mut optimizer = Adam::new();
    ///
    /// // Training loop with learning rate logging
    /// for epoch in 0..100 {
    ///     let lr = optimizer.learning_rate();
    ///     println!("Epoch {}: Learning rate = {:.6}", epoch, lr);
    ///
    ///     // ... training logic ...
    /// }
    /// ```
    ///
    /// # Performance
    ///
    /// - **Constant time**: O(1) time complexity for learning rate retrieval
    /// - **No allocations**: No memory allocations during learning rate access
    /// - **Minimal overhead**: Negligible performance impact for monitoring
    ///
    /// # Thread Safety
    ///
    /// This method is thread-safe and can be called concurrently with other read operations.
    /// It does not modify optimizer state and can be safely used for monitoring in
    /// multi-threaded training scenarios.
    #[track_caller]
    fn learning_rate(&self) -> f32;

    /// Update the learning rate for dynamic scheduling and adjustment
    ///
    /// This method updates the learning rate used by the optimizer for parameter updates.
    /// It enables dynamic learning rate adjustment during training, which is essential
    /// for implementing learning rate scheduling strategies, adaptive training, and
    /// fine-tuning workflows. The new learning rate takes effect immediately for
    /// subsequent parameter updates.
    ///
    /// # Arguments
    ///
    /// * `lr` - The new learning rate value (must be positive for meaningful optimization)
    ///
    /// # Behavior
    ///
    /// The method performs these operations:
    /// 1. **Rate validation**: Ensures the learning rate is a valid positive value
    /// 2. **State update**: Updates internal optimizer configuration with new rate
    /// 3. **Immediate effect**: New rate applies to subsequent `step()` calls
    /// 4. **Consistency**: Maintains optimizer state consistency across all parameters
    ///
    /// # Learning Rate Scheduling
    ///
    /// Common scheduling patterns supported:
    /// - **Exponential decay**: Multiply by decay factor periodically
    /// - **Step decay**: Reduce by fixed amount at specific epochs
    /// - **Cosine annealing**: Smooth cosine-based learning rate schedule
    /// - **Adaptive adjustment**: Dynamic adjustment based on training metrics
    ///
    /// # Usage Patterns
    ///
    /// ## Exponential Decay Scheduling
    /// ```
    /// use train_station::optimizers::{Adam, Optimizer};
    ///
    /// let mut optimizer = Adam::new();
    ///
    /// for epoch in 0..100 {
    ///     // Decay learning rate every 10 epochs
    ///     if epoch % 10 == 0 && epoch > 0 {
    ///         let current_lr = optimizer.learning_rate();
    ///         optimizer.set_learning_rate(current_lr * 0.95);
    ///     }
    ///
    ///     // ... training logic ...
    /// }
    /// ```
    ///
    /// ## Step-based Scheduling
    /// ```
    /// use train_station::optimizers::{Adam, Optimizer};
    ///
    /// let mut optimizer = Adam::new();
    /// let initial_lr = 0.001;
    ///
    /// for epoch in 0..100 {
    ///     // Step decay at specific epochs
    ///     let lr = match epoch {
    ///         0..=29 => initial_lr,
    ///         30..=59 => initial_lr * 0.1,
    ///         60..=89 => initial_lr * 0.01,
    ///         _ => initial_lr * 0.001,
    ///     };
    ///     optimizer.set_learning_rate(lr);
    ///
    ///     // ... training logic ...
    /// }
    /// ```
    ///
    /// ## Adaptive Adjustment
    /// ```
    /// use train_station::optimizers::{Adam, Optimizer};
    ///
    /// let mut optimizer = Adam::new();
    /// let mut best_loss = f32::INFINITY;
    /// let mut patience = 0;
    ///
    /// for epoch in 0..100 {
    ///     // ... training and validation ...
    ///     let current_loss = 0.5; // Example validation loss
    ///
    ///     if current_loss < best_loss {
    ///         best_loss = current_loss;
    ///         patience = 0;
    ///     } else {
    ///         patience += 1;
    ///         if patience >= 5 {
    ///             // Reduce learning rate when loss plateaus
    ///             let current_lr = optimizer.learning_rate();
    ///             optimizer.set_learning_rate(current_lr * 0.5);
    ///             patience = 0;
    ///         }
    ///     }
    /// }
    /// ```
    ///
    /// # Performance
    ///
    /// - **Constant time**: O(1) time complexity for learning rate updates
    /// - **No allocations**: No memory allocations during rate updates
    /// - **Immediate effect**: Changes take effect for next parameter update
    /// - **Minimal overhead**: Negligible performance impact on training
    ///
    /// # Thread Safety
    ///
    /// This method requires exclusive access (`&mut self`) and is thread-safe when
    /// used with proper synchronization. Multiple threads should not modify the
    /// learning rate concurrently without external synchronization.
    ///
    /// # Validation
    ///
    /// While the trait does not enforce validation, implementations should:
    /// - Accept positive learning rates for normal optimization
    /// - Handle zero learning rate (effectively disables updates)
    /// - Consider very large rates that may cause numerical instability
    #[track_caller]
    fn set_learning_rate(&mut self, lr: f32);
}
557}