// trustformers_optim/traits.rs
//! Advanced optimizer trait hierarchy for TrustformeRS.
//!
//! This module extends the base `Optimizer` trait from `trustformers-core` with
//! additional specialized traits for different categories of optimizers, providing
//! better organization and extensibility.
//!
//! # Trait Hierarchy
//!
//! ```text
//! Optimizer (from trustformers-core)
//!     │
//!     ├── StatefulOptimizer
//!     │   ├── MomentumOptimizer
//!     │   │   ├── AdaptiveMomentumOptimizer  (Adam, AdamW, etc.)
//!     │   │   └── ClassicalMomentumOptimizer (SGD with momentum)
//!     │   └── SecondOrderOptimizer (L-BFGS, Newton-CG, etc.)
//!     │
//!     ├── DistributedOptimizer
//!     │   ├── GradientCompressionOptimizer
//!     │   ├── FederatedOptimizer
//!     │   └── AsyncOptimizer
//!     │
//!     ├── HardwareOptimizer
//!     │   ├── SIMDOptimizer
//!     │   ├── GPUOptimizer
//!     │   └── EdgeOptimizer
//!     │
//!     └── MetaOptimizer
//!         ├── LookaheadOptimizer
//!         ├── ScheduledOptimizer
//!         └── CompositeOptimizer
//! ```

use crate::common::StateMemoryStats;
use std::collections::HashMap;
use trustformers_core::errors::Result;
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;

/// Extended optimizer trait with state management capabilities.
///
/// This trait builds on the base `Optimizer` trait to provide standardized
/// state management, serialization, and configuration access. Implementors
/// are expected to keep all mutable training state inside `Self::State` so
/// that checkpointing via `state_dict`/`load_state_dict` captures it fully.
pub trait StatefulOptimizer: Optimizer {
    /// The configuration type for this optimizer (hyperparameters; immutable
    /// during a training run).
    type Config: Clone + Send + Sync;

    /// The state type used by this optimizer (moment buffers, step counters, etc.).
    type State: Send + Sync;

    /// Gets a reference to the optimizer's configuration.
    fn config(&self) -> &Self::Config;

    /// Gets a reference to the optimizer's internal state.
    fn state(&self) -> &Self::State;

    /// Gets a mutable reference to the optimizer's internal state.
    fn state_mut(&mut self) -> &mut Self::State;

    /// Saves the optimizer state to a dictionary for checkpointing.
    ///
    /// Keys are implementation-defined — presumably parameter names plus a
    /// buffer suffix; confirm against concrete implementors.
    fn state_dict(&self) -> Result<HashMap<String, Tensor>>;

    /// Loads optimizer state from a dictionary during checkpoint restoration.
    ///
    /// The map should come from a prior `state_dict` call on a compatible
    /// optimizer; mismatched keys are an implementation-defined error.
    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()>;

    /// Gets memory usage statistics for this optimizer.
    fn memory_usage(&self) -> StateMemoryStats;

    /// Resets the optimizer state (useful for training restarts).
    fn reset_state(&mut self);

    /// Returns the number of parameters being optimized.
    fn num_parameters(&self) -> usize;
}
75
/// Trait for optimizers that use momentum-based updates.
///
/// This includes both classical momentum (SGD) and adaptive momentum (Adam family).
pub trait MomentumOptimizer: StatefulOptimizer {
    /// Gets the momentum decay coefficient (β1 in Adam, momentum in SGD).
    fn momentum_coeff(&self) -> f32;

    /// Sets the momentum decay coefficient.
    ///
    /// NOTE(review): valid range is presumably [0, 1); implementors should
    /// document whether out-of-range values are clamped or rejected.
    fn set_momentum_coeff(&mut self, coeff: f32);

    /// Gets the current momentum buffers (for debugging/analysis).
    ///
    /// Buffers are presumably keyed by parameter name — confirm against
    /// concrete implementors.
    fn momentum_buffers(&self) -> &HashMap<String, Vec<f32>>;

    /// Clears all momentum buffers (useful for fine-tuning).
    fn clear_momentum(&mut self);
}
92
/// Trait for adaptive momentum optimizers (Adam, AdamW, RAdam, etc.).
///
/// These optimizers maintain both first and second moment estimates.
pub trait AdaptiveMomentumOptimizer: MomentumOptimizer {
    /// Gets the second moment decay coefficient (β2 in Adam).
    fn variance_coeff(&self) -> f32;

    /// Sets the second moment decay coefficient.
    fn set_variance_coeff(&mut self, coeff: f32);

    /// Gets the epsilon value for numerical stability (the small constant
    /// added to the denominator in Adam-style updates).
    fn epsilon(&self) -> f32;

    /// Sets the epsilon value.
    fn set_epsilon(&mut self, eps: f32);

    /// Gets the current variance buffers (for debugging/analysis).
    ///
    /// Presumably keyed by parameter name, mirroring `momentum_buffers`.
    fn variance_buffers(&self) -> &HashMap<String, Vec<f32>>;

    /// Clears variance buffers.
    fn clear_variance(&mut self);

    /// Applies bias correction to momentum and variance estimates.
    ///
    /// Returns the bias-corrected `(momentum, variance)` pair for the given
    /// step count (1-based in Adam's formulation — TODO confirm convention).
    fn apply_bias_correction(&self, momentum: f32, variance: f32, step: usize) -> (f32, f32);
}
118
/// Trait for classical momentum optimizers (SGD variants).
pub trait ClassicalMomentumOptimizer: MomentumOptimizer {
    /// Gets the dampening factor (reduces the contribution of the current
    /// gradient to the momentum buffer, as in PyTorch SGD).
    fn dampening(&self) -> f32;

    /// Sets the dampening factor.
    fn set_dampening(&mut self, dampening: f32);

    /// Whether Nesterov momentum is enabled.
    fn nesterov(&self) -> bool;

    /// Enables or disables Nesterov momentum.
    fn set_nesterov(&mut self, nesterov: bool);
}
133
/// Trait for second-order optimization methods.
///
/// These optimizers use curvature information (Hessian approximations).
pub trait SecondOrderOptimizer: StatefulOptimizer {
    /// The type used to represent curvature information (e.g. L-BFGS
    /// history pairs or a Hessian approximation).
    type CurvatureInfo;

    /// Updates the curvature approximation with new gradient information.
    fn update_curvature(&mut self, gradients: &[Tensor]) -> Result<()>;

    /// Gets the current curvature approximation.
    fn curvature_info(&self) -> &Self::CurvatureInfo;

    /// Applies the inverse Hessian approximation to compute search direction.
    fn apply_inverse_hessian(&self, gradient: &Tensor) -> Result<Tensor>;

    /// Gets the maximum number of curvature pairs stored (for L-BFGS).
    fn history_size(&self) -> usize;
}
153
/// Trait for distributed optimization capabilities.
///
/// Provides interfaces for gradient synchronization and distributed training.
pub trait DistributedOptimizer: Optimizer {
    /// The communicator type used for distributed operations
    /// (e.g. an NCCL/MPI handle — implementation-defined).
    type Communicator;

    /// Performs all-reduce operation on gradients (in place).
    fn all_reduce_gradients(&mut self, gradients: &mut [Tensor]) -> Result<()>;

    /// Broadcasts parameters from rank 0 to all other ranks.
    fn broadcast_parameters(&mut self, parameters: &mut [Tensor]) -> Result<()>;

    /// Gets the current rank in the distributed group (0-based).
    fn rank(&self) -> usize;

    /// Gets the total number of ranks in the distributed group.
    fn world_size(&self) -> usize;

    /// Synchronizes optimizer state across all ranks.
    fn sync_state(&mut self) -> Result<()>;
}
176
/// Trait for optimizers with gradient compression capabilities.
///
/// Compression reduces communication volume in distributed training at the
/// cost of approximation error (the ratio is reported by `compression_ratio`).
pub trait GradientCompressionOptimizer: DistributedOptimizer {
    /// The compression method used (quantization, sparsification, etc. —
    /// implementation-defined).
    type CompressionMethod;

    /// Compresses gradients before communication.
    fn compress_gradients(&self, gradients: &[Tensor]) -> Result<Vec<u8>>;

    /// Decompresses received gradient data.
    ///
    /// Input must be bytes produced by a compatible `compress_gradients`.
    fn decompress_gradients(&self, data: &[u8]) -> Result<Vec<Tensor>>;

    /// Gets the compression ratio achieved.
    ///
    /// NOTE(review): whether this is original/compressed or its inverse is
    /// not established here — confirm against implementors.
    fn compression_ratio(&self) -> f32;

    /// Sets the compression parameters.
    fn set_compression_config(&mut self, config: Self::CompressionMethod);
}
194
/// Trait for federated learning optimizers.
pub trait FederatedOptimizer: DistributedOptimizer {
    /// Client information type (identity, data size, availability, etc. —
    /// implementation-defined).
    type ClientInfo;

    /// Aggregates model updates from multiple clients into a single update.
    ///
    /// `updates` and `clients` are presumably parallel slices (one update
    /// per client) — confirm against implementors.
    fn aggregate_updates(
        &mut self,
        updates: &[Tensor],
        clients: &[Self::ClientInfo],
    ) -> Result<Tensor>;

    /// Selects clients for the next round of training.
    ///
    /// Returns indices into `available_clients`; at most `num_clients` entries.
    fn select_clients(
        &self,
        available_clients: &[Self::ClientInfo],
        num_clients: usize,
    ) -> Vec<usize>;

    /// Applies differential privacy to updates (in place).
    fn apply_differential_privacy(&mut self, update: &mut Tensor) -> Result<()>;
}
217
/// Trait for asynchronous optimization methods.
pub trait AsyncOptimizer: DistributedOptimizer {
    /// Applies delayed gradients with staleness compensation.
    ///
    /// `staleness` is the number of optimization steps that elapsed between
    /// gradient computation and application.
    fn apply_delayed_gradients(&mut self, gradients: &[Tensor], staleness: usize) -> Result<()>;

    /// Gets the maximum allowed staleness (gradients older than this are
    /// presumably dropped or rejected — confirm against implementors).
    fn max_staleness(&self) -> usize;

    /// Sets the staleness compensation method.
    fn set_staleness_compensation(&mut self, method: StalenessCompensation);
}
229
/// Staleness compensation methods for asynchronous optimization.
///
/// Determines how gradients delayed by some number of steps are re-weighted
/// before being applied (see `AsyncOptimizer::apply_delayed_gradients`).
///
/// `PartialEq` is derived so callers can compare configured modes; `Eq` is
/// not possible because `Polynomial` carries an `f32` degree.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum StalenessCompensation {
    /// No compensation for staleness.
    None,
    /// Linear scaling by staleness factor.
    Linear,
    /// Exponential decay based on staleness.
    Exponential,
    /// Polynomial scaling with configurable degree.
    Polynomial(f32),
}
242
/// Trait for hardware-specific optimizer optimizations.
pub trait HardwareOptimizer: Optimizer {
    /// The target hardware type (CPU features, GPU model, edge device, etc. —
    /// implementation-defined).
    type HardwareTarget;

    /// Optimizes the optimizer for specific hardware.
    fn optimize_for_hardware(&mut self, target: Self::HardwareTarget) -> Result<()>;

    /// Gets hardware utilization statistics.
    fn hardware_utilization(&self) -> HardwareStats;

    /// Checks if the optimizer is compatible with the current hardware.
    fn is_hardware_compatible(&self) -> bool;
}
257
/// Hardware utilization statistics.
///
/// Plain data snapshot reported by `HardwareOptimizer::hardware_utilization`.
/// `PartialEq` is derived so snapshots can be compared in tests and
/// monitoring code (`Eq` is not possible for float fields).
#[derive(Debug, Clone, PartialEq)]
pub struct HardwareStats {
    /// Memory bandwidth utilization (0.0 to 1.0).
    pub memory_bandwidth_utilization: f32,
    /// Compute utilization (0.0 to 1.0).
    pub compute_utilization: f32,
    /// Cache hit rate (0.0 to 1.0).
    pub cache_hit_rate: f32,
    /// FLOPS per second achieved.
    pub flops_per_second: f64,
}
270
/// Trait for SIMD-optimized operations.
pub trait SIMDOptimizer: HardwareOptimizer {
    /// The SIMD instruction set being used (e.g. AVX2, NEON —
    /// implementation-defined).
    type SIMDType;

    /// Checks if SIMD operations are available on this machine.
    fn simd_available(&self) -> bool;

    /// Gets the SIMD vector width.
    ///
    /// NOTE(review): unit (elements vs. bytes per vector) is not established
    /// here — confirm against implementors.
    fn vector_width(&self) -> usize;

    /// Applies SIMD-optimized parameter updates.
    fn simd_update(&mut self, parameters: &mut [Tensor], gradients: &[Tensor]) -> Result<()>;
}
285
/// Trait for GPU-accelerated optimizers.
pub trait GPUOptimizer: HardwareOptimizer {
    /// The GPU compute capability (implementation-defined descriptor).
    type ComputeCapability;

    /// Transfers optimizer state to GPU.
    fn to_gpu(&mut self) -> Result<()>;

    /// Transfers optimizer state to CPU.
    fn to_cpu(&mut self) -> Result<()>;

    /// Launches GPU kernels for parameter updates.
    fn gpu_update(&mut self, parameters: &mut [Tensor], gradients: &[Tensor]) -> Result<()>;

    /// Gets GPU memory usage.
    fn gpu_memory_usage(&self) -> GPUMemoryStats;
}
303
/// GPU memory usage statistics.
///
/// Plain data snapshot reported by `GPUOptimizer::gpu_memory_usage`. All
/// fields are byte counts, so `PartialEq`/`Eq` are derived for exact
/// comparison in tests and monitoring code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GPUMemoryStats {
    /// Total GPU memory in bytes.
    pub total_memory: usize,
    /// Used GPU memory in bytes.
    pub used_memory: usize,
    /// Available GPU memory in bytes.
    pub available_memory: usize,
    /// Memory usage by optimizer state, in bytes.
    pub optimizer_memory: usize,
}
316
/// Trait for edge device optimized optimizers.
pub trait EdgeOptimizer: HardwareOptimizer {
    /// Power consumption statistics (implementation-defined).
    type PowerStats;

    /// Optimizes for low power consumption.
    fn optimize_for_power(&mut self) -> Result<()>;

    /// Gets current power consumption statistics.
    fn power_stats(&self) -> Self::PowerStats;

    /// Reduces precision to save memory and power.
    ///
    /// `bits` is the target precision in bits — behavior for unsupported
    /// widths is implementation-defined (presumably an error).
    fn reduce_precision(&mut self, bits: u8) -> Result<()>;
}
331
/// Trait for meta-optimizers that wrap other optimizers.
pub trait MetaOptimizer: Optimizer {
    /// The base optimizer type being wrapped.
    type BaseOptimizer: Optimizer;

    /// Gets a reference to the base optimizer.
    fn base_optimizer(&self) -> &Self::BaseOptimizer;

    /// Gets a mutable reference to the base optimizer.
    fn base_optimizer_mut(&mut self) -> &mut Self::BaseOptimizer;

    /// Applies the meta-optimization strategy (e.g. lookahead synchronization
    /// or composite weighting) on top of the base optimizer's update.
    fn apply_meta_strategy(
        &mut self,
        parameters: &mut [Tensor],
        gradients: &[Tensor],
    ) -> Result<()>;
}
350
/// Trait for lookahead meta-optimizers.
pub trait LookaheadOptimizer: MetaOptimizer {
    /// Gets the lookahead step size (α), the interpolation factor used when
    /// the slow weights are synchronized toward the fast weights.
    fn lookahead_alpha(&self) -> f32;

    /// Sets the lookahead step size.
    fn set_lookahead_alpha(&mut self, alpha: f32);

    /// Gets the lookahead update frequency (k), the number of fast-weight
    /// steps between slow-weight synchronizations.
    fn lookahead_k(&self) -> usize;

    /// Sets the lookahead update frequency.
    fn set_lookahead_k(&mut self, k: usize);

    /// Gets the slow weights (for debugging).
    ///
    /// Presumably keyed by parameter name — confirm against implementors.
    fn slow_weights(&self) -> &HashMap<String, Vec<f32>>;
}
368
/// Trait for scheduled optimizers with learning rate scheduling.
pub trait ScheduledOptimizer: Optimizer {
    /// The scheduler type.
    type Scheduler;

    /// Gets a reference to the scheduler.
    fn scheduler(&self) -> &Self::Scheduler;

    /// Gets a mutable reference to the scheduler.
    fn scheduler_mut(&mut self) -> &mut Self::Scheduler;

    /// Updates the learning rate based on the scheduler.
    ///
    /// NOTE(review): the expected call cadence (per step vs. per epoch) is
    /// not established here — confirm against implementors.
    fn update_lr(&mut self) -> Result<()>;

    /// Gets the current scheduled learning rate.
    fn current_lr(&self) -> f32;
}
386
/// Trait for composite optimizers that combine multiple optimization strategies.
pub trait CompositeOptimizer: Optimizer {
    /// The component optimizer types (e.g. a tuple or struct of optimizers —
    /// implementation-defined).
    type Components;

    /// Gets references to all component optimizers.
    fn components(&self) -> &Self::Components;

    /// Gets mutable references to all component optimizers.
    fn components_mut(&mut self) -> &mut Self::Components;

    /// Applies updates from all component optimizers.
    fn apply_composite_update(
        &mut self,
        parameters: &mut [Tensor],
        gradients: &[Tensor],
    ) -> Result<()>;

    /// Gets the weight assigned to each component.
    ///
    /// Presumably ordered to match the components; whether weights must sum
    /// to 1.0 is implementation-defined.
    fn component_weights(&self) -> Vec<f32>;

    /// Sets the weights for each component.
    ///
    /// Returns an error for invalid weights (e.g. wrong length) —
    /// implementation-defined.
    fn set_component_weights(&mut self, weights: Vec<f32>) -> Result<()>;
}
411
/// Optimizer factory trait for creating optimizers with different configurations.
pub trait OptimizerFactory {
    /// The optimizer type produced by this factory.
    type Optimizer: Optimizer;

    /// The configuration type for the optimizer.
    type Config;

    /// Creates a new optimizer with the given configuration.
    fn create(&self, config: Self::Config) -> Result<Self::Optimizer>;

    /// Lists all available optimizer variants (names accepted by
    /// `create_by_name`).
    fn available_variants(&self) -> Vec<&'static str>;

    /// Creates an optimizer by name with default configuration.
    ///
    /// Returns an error if `name` is not one of `available_variants`.
    fn create_by_name(&self, name: &str) -> Result<Self::Optimizer>;
}
429
/// Trait for optimizers that can be serialized and restored.
pub trait SerializableOptimizer: Optimizer {
    /// Serializes the optimizer to bytes.
    ///
    /// The byte format is implementation-defined; versioned via `version`.
    fn serialize(&self) -> Result<Vec<u8>>;

    /// Deserializes an optimizer from bytes.
    ///
    /// Input must come from a `serialize` call of a compatible version.
    fn deserialize(data: &[u8]) -> Result<Self>
    where
        Self: Sized;

    /// Gets the serialization format version.
    fn version(&self) -> u32;
}
443
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_staleness_compensation() {
        // `matches!` replaces the `assert!(true)` / `assert!(false)` match
        // anti-pattern (clippy `assertions_on_constants`).
        let compensation = StalenessCompensation::Linear;
        assert!(matches!(compensation, StalenessCompensation::Linear));

        // The polynomial variant carries its configurable degree.
        let poly = StalenessCompensation::Polynomial(2.0);
        assert!(matches!(poly, StalenessCompensation::Polynomial(d) if d == 2.0));
    }

    #[test]
    fn test_hardware_stats() {
        let stats = HardwareStats {
            memory_bandwidth_utilization: 0.8,
            compute_utilization: 0.9,
            cache_hit_rate: 0.95,
            flops_per_second: 1e12,
        };

        // Exact float comparison is intentional here: the fields hold the
        // literals assigned above, not results of arithmetic.
        assert_eq!(stats.memory_bandwidth_utilization, 0.8);
        assert_eq!(stats.compute_utilization, 0.9);
        assert_eq!(stats.cache_hit_rate, 0.95);
        assert_eq!(stats.flops_per_second, 1e12);
    }

    #[test]
    fn test_gpu_memory_stats() {
        /// One gibibyte in bytes; avoids repeated `1024 * 1024 * 1024` and
        /// the clippy `identity_op` lint on `1 * ...`.
        const GIB: usize = 1024 * 1024 * 1024;

        let stats = GPUMemoryStats {
            total_memory: 16 * GIB,
            used_memory: 8 * GIB,
            available_memory: 8 * GIB,
            optimizer_memory: GIB,
        };

        assert_eq!(stats.total_memory, 16 * GIB);
        // Used + available should account for all device memory.
        assert_eq!(
            stats.used_memory + stats.available_memory,
            stats.total_memory
        );
        // Optimizer state is a subset of used memory.
        assert!(stats.optimizer_memory <= stats.used_memory);
    }
}