High-performance optimization algorithms for machine learning training
This module provides optimization algorithms designed for high performance and compatibility with modern machine learning workflows. All optimizers are implemented with zero external dependencies and use SIMD-optimized parameter updates.
§Purpose
The optimizer module serves as the core parameter optimization layer for the Train Station machine learning library, providing:
- High-performance implementations: SIMD-optimized parameter updates with AVX2 support
- PyTorch compatibility: Familiar interfaces and parameter semantics for easy migration
- GradTrack integration: Seamless integration with the automatic differentiation system
- Memory efficiency: Optimized state management with minimal memory overhead
- Thread safety: All optimizers are thread-safe and support concurrent training
- Serialization support: Complete state serialization for model checkpointing
§Supported Optimizers
§Adam Optimizer
- Adaptive learning rates: Per-parameter adaptive learning rate adjustment
- Momentum: First and second moment estimation for stable convergence
- Bias correction: Proper bias correction for early training stability (see the update rule after this list)
- AMSGrad variant: Optional AMSGrad variant for improved convergence
- Weight decay: L2 regularization support for model regularization
- SIMD optimization: AVX2-optimized parameter updates for maximum performance
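For reference, these features correspond to the standard Adam update rule (Kingma & Ba, 2015). With learning rate \(\alpha\), gradient \(g_t\), and decay rates \(\beta_1, \beta_2\):

\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat{m}_t &= m_t / (1 - \beta_1^t), \qquad \hat{v}_t = v_t / (1 - \beta_2^t) \\
\theta_t &= \theta_{t-1} - \alpha\, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)
\end{aligned}

With weight_decay \(\lambda\), the gradient becomes \(g_t + \lambda \theta_{t-1}\) (L2 regularization); the AMSGrad variant replaces \(\hat{v}_t\) with \(\max(\hat{v}_{t-1}, \hat{v}_t)\) so the effective step size never grows.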
§Design Philosophy
§Performance First
- SIMD optimization: All parameter updates use vectorized operations when available
- Memory efficiency: Minimal memory overhead with optimized state storage
- Zero allocations: Hot paths avoid memory allocations for maximum performance
- Cache-friendly: Memory access patterns optimized for CPU cache efficiency
§PyTorch Compatibility
- Familiar interfaces: Method names and semantics match PyTorch conventions
- Parameter linking: Explicit parameter registration for type safety
- Learning rate scheduling: Support for dynamic learning rate adjustment
- State management: Complete optimizer state serialization and restoration
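A minimal scheduling sketch: this page does not show the scheduling API, so the set_learning_rate call below is a hypothetical setter; only the step-decay arithmetic is the point.

use train_station::optimizers::{Adam, Optimizer};

let mut optimizer = Adam::new();
let base_lr: f32 = 0.001;
for epoch in 0..100 {
    // Step decay: halve the learning rate every 30 epochs
    let lr = base_lr * 0.5_f32.powi((epoch / 30) as i32);
    optimizer.set_learning_rate(lr); // hypothetical setter, not confirmed by this module
    // ... forward pass, backward pass, and optimizer.step() as in the training loop below ...
}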
§Thread Safety
- Concurrent training: All optimizers support multi-threaded parameter updates
- Exclusive access: Parameter updates require mutable references for safety
- State isolation: Each optimizer instance maintains independent state
- Atomic operations: Thread-safe operations where required
§Usage Patterns
§Basic Training Loop
use train_station::{Tensor, optimizers::{Adam, Optimizer}};
// Create model parameters
let mut weight = Tensor::randn(vec![10, 5], None).with_requires_grad();
let mut bias = Tensor::zeros(vec![10]).with_requires_grad();
// Create optimizer and link parameters
let mut optimizer = Adam::new();
optimizer.add_parameter(&weight);
optimizer.add_parameter(&bias);
// Training loop
for _epoch in 0..100 {
    // Forward pass: compute predictions and a squared-error loss
    let input = Tensor::randn(vec![5, 3], None);
    let output = weight.matmul(&input);
    let output_with_bias = output + &bias.unsqueeze(1); // Broadcast bias to [10, 3]
    let target = Tensor::randn(vec![10, 3], None);
    let mut loss = (output_with_bias - &target).pow_scalar(2.0).sum();
    // Clear stale gradients, then run the backward pass
    optimizer.zero_grad(&mut [&mut weight, &mut bias]);
    loss.backward(None);
    // Parameter update
    optimizer.step(&mut [&mut weight, &mut bias]);
}
§Custom Configuration
use train_station::optimizers::{Adam, AdamConfig, Optimizer};
// Create custom configuration
let config = AdamConfig {
    learning_rate: 0.001,
    beta1: 0.9,
    beta2: 0.999,
    eps: 1e-8,
    weight_decay: 0.01,
    amsgrad: false,
};
// Create optimizer with custom configuration
let mut optimizer = Adam::with_config(config);
§State Serialization
use train_station::optimizers::{Adam, Optimizer};
use train_station::serialization::{Serializable, Format};
let mut optimizer = Adam::new();
// ... training ...
// Save optimizer state
optimizer.save("optimizer.json", Format::Json).unwrap();
// Load optimizer state
let mut loaded_optimizer = Adam::load("optimizer.json", Format::Json).unwrap();
§Performance Characteristics
§SIMD Optimization
- AVX2 support: Vectorized operations on x86_64 with AVX2 support
- Fallback paths: Optimized scalar implementations for non-SIMD hardware
- Automatic detection: Runtime CPU feature detection for optimal performance
- Memory alignment: Proper memory alignment for vectorized operations
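As a sketch of the dispatch pattern described above (an assumed shape for illustration, not this crate's actual internals), a fused scale-and-add kernel might select an AVX2 path at runtime and fall back to scalar code elsewhere:

fn scaled_add(dst: &mut [f32], src: &[f32], alpha: f32) {
    assert_eq!(dst.len(), src.len());
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // Safe: we just verified AVX2 is available on this CPU
            unsafe { scaled_add_avx2(dst, src, alpha) };
            return;
        }
    }
    scaled_add_scalar(dst, src, alpha); // portable fallback path
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn scaled_add_avx2(dst: &mut [f32], src: &[f32], alpha: f32) {
    use std::arch::x86_64::*;
    let a = _mm256_set1_ps(alpha);
    let lanes = dst.len() / 8 * 8;
    for i in (0..lanes).step_by(8) {
        // dst[i..i+8] += alpha * src[i..i+8], eight lanes at a time
        let d = _mm256_loadu_ps(dst.as_ptr().add(i));
        let s = _mm256_loadu_ps(src.as_ptr().add(i));
        _mm256_storeu_ps(dst.as_mut_ptr().add(i), _mm256_add_ps(d, _mm256_mul_ps(a, s)));
    }
    // Scalar tail for lengths that are not a multiple of 8
    for i in lanes..dst.len() {
        dst[i] += alpha * src[i];
    }
}

fn scaled_add_scalar(dst: &mut [f32], src: &[f32], alpha: f32) {
    for (d, s) in dst.iter_mut().zip(src) {
        *d += alpha * s;
    }
}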
§Memory Efficiency
- Minimal overhead: Optimized state storage with minimal memory footprint
- Lazy allocation: State allocated only when parameters are linked
- Memory reuse: Efficient memory reuse patterns to minimize allocations
- Cache optimization: Memory access patterns optimized for CPU cache
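One way lazy allocation can be realized (an illustrative layout, not necessarily this crate's): per-parameter moment buffers live in an Option and are allocated on the first update that touches them.

struct MomentState {
    m: Vec<f32>, // first-moment estimate
    v: Vec<f32>, // second-moment estimate
}

struct LazyState {
    slots: Vec<Option<MomentState>>, // one slot per linked parameter
}

impl LazyState {
    fn state_for(&mut self, index: usize, len: usize) -> &mut MomentState {
        // Allocate zeroed moment buffers only when this parameter is first updated
        self.slots[index].get_or_insert_with(|| MomentState {
            m: vec![0.0; len],
            v: vec![0.0; len],
        })
    }
}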
§Scalability
- Large models: Efficient handling of models with millions of parameters
- Batch processing: Optimized for typical machine learning batch sizes
- Concurrent training: Thread-safe operations for parallel training
- Memory scaling: Linear memory scaling with parameter count
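A back-of-envelope check of that linear scaling, assuming the textbook Adam state of two f32 moment buffers per parameter (three with AMSGrad), which this page does not confirm byte-for-byte:

// Assumption: two f32 moment buffers (m and v) per parameter
let n_params: usize = 10_000_000;
let state_bytes = n_params * 2 * std::mem::size_of::<f32>();
assert_eq!(state_bytes, 80_000_000); // ~80 MB of optimizer state for a 10M-parameter model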
§Thread Safety
All optimizers in this module are designed to be thread-safe:
- Exclusive access: Parameter updates require mutable references
- State isolation: Each optimizer instance maintains independent state
- Concurrent safe: Multiple optimizers can run concurrently on different parameters
- Atomic operations: Thread-safe operations where required for correctness
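A sketch of that isolation in practice, using only the API shown on this page: two optimizer instances drive disjoint parameters on separate threads.

use std::thread;
use train_station::{Tensor, optimizers::{Adam, Optimizer}};

let handles: Vec<_> = (0..2)
    .map(|_| {
        thread::spawn(|| {
            // Each thread owns its parameter and its optimizer state
            let mut param = Tensor::randn(vec![4, 4], None).with_requires_grad();
            let mut opt = Adam::new();
            opt.add_parameter(&param);
            let input = Tensor::randn(vec![4, 2], None);
            let mut loss = param.matmul(&input).sum();
            opt.zero_grad(&mut [&mut param]);
            loss.backward(None);
            opt.step(&mut [&mut param]);
        })
    })
    .collect();
for handle in handles {
    handle.join().unwrap();
}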
§Integration with GradTrack
The optimizers integrate seamlessly with the GradTrack automatic differentiation system:
- Gradient access: Automatic access to computed gradients from tensors
- Gradient clearing: Efficient gradient clearing before backward passes
- Computation graph: Proper integration with the computation graph system
- Memory management: Efficient gradient memory management during optimization
Structs§
- Adam
- Adam optimizer for neural network parameter optimization
- AdamConfig
- Configuration for the Adam optimization algorithm
Traits§
- Optimizer
- Universal trait for parameter optimization algorithms