pub trait Optimizer {
// Required methods
fn step(&mut self, parameters: &mut [&mut Tensor]);
fn zero_grad(&mut self, parameters: &mut [&mut Tensor]);
fn learning_rate(&self) -> f32;
fn set_learning_rate(&mut self, lr: f32);
}
Universal trait for parameter optimization algorithms
This trait provides a unified interface for all optimization algorithms in the Train Station library, ensuring consistent behavior and API compatibility across different optimizers. The trait follows PyTorch conventions for familiar usage patterns while providing high-performance implementations optimized for the Train Station ecosystem.
§Design Principles
The Optimizer trait is designed around several key principles:
§Type Safety
- Parameter linking: Explicit parameter registration prevents runtime errors
- Mutable references: Parameter updates require exclusive access for thread safety
- Compile-time guarantees: Type system ensures correct usage patterns
- Memory safety: All operations are memory-safe with proper lifetime management
§Performance
- Zero-cost abstractions: Trait methods compile to direct function calls
- SIMD optimization: Implementations use vectorized operations when available
- Memory efficiency: Minimal overhead with optimized state management
- Cache-friendly: Memory access patterns optimized for CPU cache performance
§PyTorch Compatibility
- Familiar methods: Method names and semantics match PyTorch conventions
- Parameter management: Similar parameter linking and state management
- Learning rate control: Dynamic learning rate adjustment support
- Training workflows: Compatible with standard training loop patterns
§Required Methods
All optimizers must implement these core methods:
- step() - Perform parameter updates based on current gradients
- zero_grad() - Clear accumulated gradients before the backward pass
- learning_rate() - Get the current learning rate for monitoring
- set_learning_rate() - Update the learning rate for scheduling
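To make the contract concrete, here is a minimal sketch of a custom implementor using a plain SGD update (param = param - lr * grad). The bodies of step() and zero_grad() are left as comments because the Tensor gradient accessors they would call are not documented on this page; everything else follows directly from the trait signatures above.
use train_station::{Tensor, optimizers::Optimizer};

/// Minimal SGD-style optimizer sketch.
struct Sgd {
    lr: f32,
}

impl Optimizer for Sgd {
    fn step(&mut self, parameters: &mut [&mut Tensor]) {
        for _param in parameters.iter_mut() {
            // Hypothetical update: the real code would read this tensor's
            // gradient and subtract self.lr * grad from its data in place.
        }
    }

    fn zero_grad(&mut self, parameters: &mut [&mut Tensor]) {
        for _param in parameters.iter_mut() {
            // Hypothetical: reset this tensor's accumulated gradient to zero.
        }
    }

    fn learning_rate(&self) -> f32 {
        self.lr
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.lr = lr;
    }
}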
§Usage Patterns
§Basic Usage
use train_station::{Tensor, optimizers::{Adam, Optimizer}};
// Create parameters and optimizer
let mut param = Tensor::randn(vec![10, 10], None).with_requires_grad();
let mut optimizer = Adam::new();
optimizer.add_parameter(&param);
// Training step
optimizer.zero_grad(&mut [&mut param]);
// ... forward pass and loss computation ...
// loss.backward(None);
optimizer.step(&mut [&mut param]);
§Learning Rate Scheduling
use train_station::optimizers::{Adam, Optimizer};
let mut optimizer = Adam::new();
// ... parameter setup ...
for epoch in 0..100 {
// Decay learning rate every 10 epochs
if epoch % 10 == 0 && epoch > 0 {
let current_lr = optimizer.learning_rate();
optimizer.set_learning_rate(current_lr * 0.9);
}
// Training step
// ... training logic ...
}
§Thread Safety
All optimizer implementations are required to be thread-safe:
- Send + Sync: Optimizers can be moved between threads and shared safely
- Exclusive access: Parameter updates require mutable references
- State isolation: Each optimizer instance maintains independent state
- Concurrent training: Multiple optimizers can run concurrently
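As a sketch of the concurrent-training guarantee, each thread below owns an independent optimizer instance and its own parameters, so no extra synchronization is needed. This assumes tensors can be created and owned per thread, as in the other examples on this page.
use std::thread;
use train_station::{Tensor, optimizers::{Adam, Optimizer}};

// Each worker owns its optimizer and parameters, relying on the
// state-isolation guarantee above; nothing is shared between threads.
let handles: Vec<_> = (0..2)
    .map(|_| {
        thread::spawn(|| {
            let mut param = Tensor::randn(vec![4, 4], None).with_requires_grad();
            let mut optimizer = Adam::new();
            optimizer.add_parameter(&param);
            optimizer.zero_grad(&mut [&mut param]);
            // ... per-thread forward pass and backward pass ...
            optimizer.step(&mut [&mut param]);
        })
    })
    .collect();
for handle in handles {
    handle.join().expect("worker thread panicked");
}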
§Performance Characteristics
Optimizer implementations are expected to provide:
- O(n) complexity: Linear time complexity with parameter count
- Minimal allocations: Avoid memory allocations in hot paths
- SIMD optimization: Use vectorized operations when available
- Cache efficiency: Optimize memory access patterns for CPU cache
§Implementors
Current optimizer implementations:
- Adam - Adaptive Moment Estimation with momentum and bias correction (update rule sketched below)
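For reference, this is the standard Adam update rule from Kingma and Ba, which the momentum and bias-correction language above refers to; the library's exact formulation (epsilon placement, default hyperparameters) is not documented here and may differ in detail:

$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \\
\theta_t &= \theta_{t-1} - \alpha \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}
$$

where $g_t$ is the gradient at step $t$, $\alpha$ is the rate returned by learning_rate(), and $\hat{m}_t$, $\hat{v}_t$ are the bias-corrected first and second moment estimates.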
Future implementations may include:
- SGD - Stochastic Gradient Descent with momentum
- RMSprop - Root Mean Square Propagation
- AdamW - Adam with decoupled weight decay
§Required Methods
fn step(&mut self, parameters: &mut [&mut Tensor])
Perform a single optimization step to update parameters
This method performs the core optimization algorithm, updating all provided parameters based on their current gradients. The specific update rule depends on the optimizer implementation (Adam, SGD, etc.). Parameters must be linked to the optimizer before calling this method to ensure proper state management.
§Arguments
parameters- Mutable slice of parameter tensor references to update
§Behavior
The method performs these operations:
- Gradient validation: Ensures all parameters have computed gradients
- State update: Updates internal optimizer state (momentum, velocity, etc.)
- Parameter update: Applies the optimization algorithm to update parameter values
- Bias correction: Applies bias correction if required by the algorithm
§Requirements
- Parameter linking: All parameters must be linked via add_parameter()
- Gradient computation: Parameters must have gradients from a backward() call
- Exclusive access: Requires mutable references for thread-safe updates
- Consistent state: Optimizer state must be consistent with parameter count
§Performance
- SIMD optimization: Uses vectorized operations when available
- Memory efficiency: Minimizes memory allocations during updates
- Cache-friendly: Optimized memory access patterns for performance
- Linear complexity: O(n) time complexity with parameter count
§Examples
use train_station::{Tensor, optimizers::{Adam, Optimizer}};
let mut param = Tensor::randn(vec![10, 10], None).with_requires_grad();
let mut optimizer = Adam::new();
optimizer.add_parameter(&param);
// After forward pass and backward pass
optimizer.step(&mut [&mut param]);
fn zero_grad(&mut self, parameters: &mut [&mut Tensor])
Clear accumulated gradients for all parameters
This method resets all parameter gradients to zero, preparing for a new backward pass. It should be called before each backward pass to prevent gradient accumulation across multiple forward/backward cycles. This is essential for correct training behavior as gradients accumulate by default in the GradTrack system.
§Arguments
parameters- Mutable slice of parameter tensor references to clear gradients for
§Behavior
The method performs these operations:
- Gradient clearing: Sets all parameter gradients to zero
- Memory management: Efficiently manages gradient memory allocation
- State consistency: Maintains consistent gradient state across parameters
- GradTrack integration: Properly integrates with the automatic differentiation system
§Usage Pattern
This method should be called at the beginning of each training iteration:
1. Clear gradients: Call zero_grad() to reset gradients
2. Forward pass: Compute model output and loss
3. Backward pass: Call loss.backward() to compute gradients
4. Parameter update: Call step() to update parameters
§Performance
- Efficient clearing: Optimized gradient clearing with minimal overhead
- Memory reuse: Reuses existing gradient memory when possible
- SIMD optimization: Uses vectorized operations for large parameter tensors
- Linear complexity: O(n) time complexity with total parameter count
§Examples
use train_station::{Tensor, optimizers::{Adam, Optimizer}};
let mut param = Tensor::randn(vec![10, 10], None).with_requires_grad();
let mut optimizer = Adam::new();
optimizer.add_parameter(&param);
// Training iteration
optimizer.zero_grad(&mut [&mut param]); // Clear gradients
// ... forward pass and loss computation ...
// loss.backward(None); // Compute gradients
optimizer.step(&mut [&mut param]); // Update parameters
§Integration with GradTrack
The method integrates seamlessly with the GradTrack automatic differentiation system:
- Gradient storage: Clears gradients stored in tensor gradient fields
- Computation graph: Maintains proper computation graph state
- Memory efficiency: Efficiently manages gradient memory allocation
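To make the accumulation behavior concrete: gradients sum across backward passes unless cleared, so each iteration should begin with zero_grad(). A minimal loop, following the same conventions as the examples above:
use train_station::{Tensor, optimizers::{Adam, Optimizer}};

let mut param = Tensor::randn(vec![2, 2], None).with_requires_grad();
let mut optimizer = Adam::new();
optimizer.add_parameter(&param);

for _step in 0..2 {
    // Omitting this call would leave the previous iteration's gradient in
    // place, and the next backward pass would add to it.
    optimizer.zero_grad(&mut [&mut param]);
    // ... forward pass and loss computation ...
    // loss.backward(None);
    optimizer.step(&mut [&mut param]);
}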
fn learning_rate(&self) -> f32
Get the current learning rate for monitoring and scheduling
This method returns the current learning rate used by the optimizer for parameter updates. For optimizers with adaptive learning rates, this returns the base learning rate that is modified by the adaptive algorithm. This method is essential for learning rate monitoring and implementing learning rate scheduling strategies.
§Returns
The current learning rate as a 32-bit floating-point value
§Behavior
The returned value represents:
- Base learning rate: The configured learning rate for the optimizer
- Current rate: The learning rate currently being used for updates
- Scheduling support: The rate that can be modified by learning rate schedulers
- Monitoring value: The rate that should be logged for training monitoring
§Usage Patterns
§Learning Rate Monitoring
use train_station::optimizers::{Adam, Optimizer};
let optimizer = Adam::new();
println!("Current learning rate: {}", optimizer.learning_rate());§Learning Rate Scheduling
use train_station::optimizers::{Adam, Optimizer};
let mut optimizer = Adam::new();
for epoch in 0..100 {
// Exponential decay every 10 epochs
if epoch % 10 == 0 && epoch > 0 {
let current_lr = optimizer.learning_rate();
optimizer.set_learning_rate(current_lr * 0.9);
}
}
§Training Loop Integration
use train_station::optimizers::{Adam, Optimizer};
let mut optimizer = Adam::new();
// Training loop with learning rate logging
for epoch in 0..100 {
let lr = optimizer.learning_rate();
println!("Epoch {}: Learning rate = {:.6}", epoch, lr);
// ... training logic ...
}
§Performance
- Constant time: O(1) time complexity for learning rate retrieval
- No allocations: No memory allocations during learning rate access
- Minimal overhead: Negligible performance impact for monitoring
§Thread Safety
This method is thread-safe and can be called concurrently with other read operations. It does not modify optimizer state and can be safely used for monitoring in multi-threaded training scenarios.
fn set_learning_rate(&mut self, lr: f32)
Update the learning rate for dynamic scheduling and adjustment
This method updates the learning rate used by the optimizer for parameter updates. It enables dynamic learning rate adjustment during training, which is essential for implementing learning rate scheduling strategies, adaptive training, and fine-tuning workflows. The new learning rate takes effect immediately for subsequent parameter updates.
§Arguments
lr- The new learning rate value (must be positive for meaningful optimization)
§Behavior
The method performs these operations:
- Rate validation: Ensures the learning rate is a valid positive value
- State update: Updates internal optimizer configuration with new rate
- Immediate effect: New rate applies to subsequent step() calls
- Consistency: Maintains optimizer state consistency across all parameters
§Learning Rate Scheduling
Common scheduling patterns supported:
- Exponential decay: Multiply by decay factor periodically
- Step decay: Reduce by fixed amount at specific epochs
- Cosine annealing: Smooth cosine-based learning rate schedule (sketched after this list)
- Adaptive adjustment: Dynamic adjustment based on training metrics
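Exponential decay, step decay, and adaptive adjustment each have an example under Usage Patterns below. Cosine annealing has none there, but it can be sketched with nothing beyond learning_rate() and set_learning_rate(); the schedule constants here are illustrative, not library defaults:
use std::f32::consts::PI;
use train_station::optimizers::{Adam, Optimizer};

let mut optimizer = Adam::new();
let (max_lr, min_lr, total_epochs) = (0.001_f32, 0.00001_f32, 100);
for epoch in 0..total_epochs {
    // Cosine annealing: smoothly interpolate from max_lr down to min_lr
    // over the full run, following lr = min + 0.5*(max-min)*(1+cos(pi*t/T)).
    let progress = epoch as f32 / total_epochs as f32;
    let lr = min_lr + 0.5 * (max_lr - min_lr) * (1.0 + (PI * progress).cos());
    optimizer.set_learning_rate(lr);
    // ... training logic ...
}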
§Usage Patterns
§Exponential Decay Scheduling
use train_station::optimizers::{Adam, Optimizer};
let mut optimizer = Adam::new();
for epoch in 0..100 {
// Decay learning rate every 10 epochs
if epoch % 10 == 0 && epoch > 0 {
let current_lr = optimizer.learning_rate();
optimizer.set_learning_rate(current_lr * 0.95);
}
// ... training logic ...
}
§Step-based Scheduling
use train_station::optimizers::{Adam, Optimizer};
let mut optimizer = Adam::new();
let initial_lr = 0.001;
for epoch in 0..100 {
// Step decay at specific epochs
let lr = match epoch {
0..=29 => initial_lr,
30..=59 => initial_lr * 0.1,
60..=89 => initial_lr * 0.01,
_ => initial_lr * 0.001,
};
optimizer.set_learning_rate(lr);
// ... training logic ...
}
§Adaptive Adjustment
use train_station::optimizers::{Adam, Optimizer};
let mut optimizer = Adam::new();
let mut best_loss = f32::INFINITY;
let mut patience = 0;
for epoch in 0..100 {
// ... training and validation ...
let current_loss = 0.5; // Example validation loss
if current_loss < best_loss {
best_loss = current_loss;
patience = 0;
} else {
patience += 1;
if patience >= 5 {
// Reduce learning rate when loss plateaus
let current_lr = optimizer.learning_rate();
optimizer.set_learning_rate(current_lr * 0.5);
patience = 0;
}
}
}
§Performance
- Constant time: O(1) time complexity for learning rate updates
- No allocations: No memory allocations during rate updates
- Immediate effect: Changes take effect for next parameter update
- Minimal overhead: Negligible performance impact on training
§Thread Safety
This method requires exclusive access (&mut self) and is thread-safe when
used with proper synchronization. Multiple threads should not modify the
learning rate concurrently without external synchronization.
§Validation
While the trait does not enforce validation, implementations should:
- Accept positive learning rates for normal optimization
- Handle zero learning rate (effectively disables updates)
- Consider very large rates that may cause numerical instability
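As one possible policy (not mandated by the trait), an implementor that stores its rate in an lr: f32 field, like the Sgd sketch shown earlier, might apply these checks inside set_learning_rate():
// Inside `impl Optimizer for Sgd`; the exact policy (assert, clamp, or
// log a warning) is the implementor's choice, not part of the trait.
fn set_learning_rate(&mut self, lr: f32) {
    // Reject NaN, infinite, and negative rates outright.
    assert!(
        lr.is_finite() && lr >= 0.0,
        "learning rate must be finite and non-negative, got {}",
        lr
    );
    // A zero rate is permitted: it simply freezes parameter updates.
    // Very large rates are also permitted but often numerically unstable.
    self.lr = lr;
}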