// trustformers_optim/traits.rs
//! Advanced optimizer trait hierarchy for TrustformeRS.
//!
//! This module extends the base `Optimizer` trait from `trustformers-core` with
//! additional specialized traits for different categories of optimizers, providing
//! better organization and extensibility.
//!
//! # Trait Hierarchy
//!
//! ```text
//! Optimizer (from trustformers-core)
//! │
//! ├── StatefulOptimizer
//! │   ├── MomentumOptimizer
//! │   │   ├── AdaptiveMomentumOptimizer (Adam, AdamW, etc.)
//! │   │   └── ClassicalMomentumOptimizer (SGD with momentum)
//! │   └── SecondOrderOptimizer (L-BFGS, Newton-CG, etc.)
//! │
//! ├── DistributedOptimizer
//! │   ├── GradientCompressionOptimizer
//! │   ├── FederatedOptimizer
//! │   └── AsyncOptimizer
//! │
//! ├── HardwareOptimizer
//! │   ├── SIMDOptimizer
//! │   ├── GPUOptimizer
//! │   └── EdgeOptimizer
//! │
//! └── MetaOptimizer
//!     ├── LookaheadOptimizer
//!     ├── ScheduledOptimizer
//!     └── CompositeOptimizer
//! ```
use std::collections::HashMap;

use trustformers_core::errors::Result;
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;

use crate::common::StateMemoryStats;
39
/// Extended optimizer trait with state management capabilities.
///
/// This trait builds on the base `Optimizer` trait to provide standardized
/// state management, serialization, and configuration access.
pub trait StatefulOptimizer: Optimizer {
    /// The configuration type for this optimizer.
    type Config: Clone + Send + Sync;

    /// The state type used by this optimizer.
    type State: Send + Sync;

    /// Gets a reference to the optimizer's configuration.
    fn config(&self) -> &Self::Config;

    /// Gets a reference to the optimizer's internal state.
    fn state(&self) -> &Self::State;

    /// Gets a mutable reference to the optimizer's internal state.
    fn state_mut(&mut self) -> &mut Self::State;

    /// Saves the optimizer state to a dictionary for checkpointing.
    fn state_dict(&self) -> Result<HashMap<String, Tensor>>;

    /// Loads optimizer state from a dictionary during checkpoint restoration.
    ///
    /// The `state` map is expected to have the layout produced by
    /// [`state_dict`](Self::state_dict).
    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()>;

    /// Gets memory usage statistics for this optimizer's state.
    fn memory_usage(&self) -> StateMemoryStats;

    /// Resets the optimizer state (useful for training restarts).
    fn reset_state(&mut self);

    /// Returns the number of parameters being optimized.
    fn num_parameters(&self) -> usize;
}
75
/// Trait for optimizers that use momentum-based updates.
///
/// This includes both classical momentum (SGD) and adaptive momentum (Adam family).
pub trait MomentumOptimizer: StatefulOptimizer {
    /// Gets the momentum decay coefficient (β1 in Adam, `momentum` in SGD).
    fn momentum_coeff(&self) -> f32;

    /// Sets the momentum decay coefficient.
    fn set_momentum_coeff(&mut self, coeff: f32);

    /// Gets the current momentum buffers (for debugging/analysis).
    /// NOTE(review): keys are presumably parameter names — confirm with implementors.
    fn momentum_buffers(&self) -> &HashMap<String, Vec<f32>>;

    /// Clears all momentum buffers (useful for fine-tuning).
    fn clear_momentum(&mut self);
}
92
/// Trait for adaptive momentum optimizers (Adam, AdamW, RAdam, etc.).
///
/// These optimizers maintain both first and second moment estimates.
pub trait AdaptiveMomentumOptimizer: MomentumOptimizer {
    /// Gets the second moment decay coefficient (β2 in Adam).
    fn variance_coeff(&self) -> f32;

    /// Sets the second moment decay coefficient.
    fn set_variance_coeff(&mut self, coeff: f32);

    /// Gets the epsilon value for numerical stability.
    fn epsilon(&self) -> f32;

    /// Sets the epsilon value.
    fn set_epsilon(&mut self, eps: f32);

    /// Gets the current variance (second moment) buffers (for debugging/analysis).
    fn variance_buffers(&self) -> &HashMap<String, Vec<f32>>;

    /// Clears variance buffers.
    fn clear_variance(&mut self);

    /// Applies bias correction to momentum and variance estimates.
    ///
    /// Returns the corrected `(momentum, variance)` pair for the given `step`.
    fn apply_bias_correction(&self, momentum: f32, variance: f32, step: usize) -> (f32, f32);
}
118
/// Trait for classical momentum optimizers (SGD variants).
pub trait ClassicalMomentumOptimizer: MomentumOptimizer {
    /// Gets the dampening factor.
    fn dampening(&self) -> f32;

    /// Sets the dampening factor.
    fn set_dampening(&mut self, dampening: f32);

    /// Whether Nesterov momentum is enabled.
    fn nesterov(&self) -> bool;

    /// Enables or disables Nesterov momentum.
    fn set_nesterov(&mut self, nesterov: bool);
}
133
/// Trait for second-order optimization methods.
///
/// These optimizers use curvature information (Hessian approximations).
pub trait SecondOrderOptimizer: StatefulOptimizer {
    /// The type used to represent curvature information.
    type CurvatureInfo;

    /// Updates the curvature approximation with new gradient information.
    fn update_curvature(&mut self, gradients: &[Tensor]) -> Result<()>;

    /// Gets the current curvature approximation.
    fn curvature_info(&self) -> &Self::CurvatureInfo;

    /// Applies the inverse Hessian approximation to compute a search direction
    /// from the given gradient.
    fn apply_inverse_hessian(&self, gradient: &Tensor) -> Result<Tensor>;

    /// Gets the maximum number of curvature pairs stored (for L-BFGS).
    fn history_size(&self) -> usize;
}
153
/// Trait for distributed optimization capabilities.
///
/// Provides interfaces for gradient synchronization and distributed training.
pub trait DistributedOptimizer: Optimizer {
    /// The communicator type used for distributed operations.
    type Communicator;

    /// Performs an all-reduce operation on gradients, in place.
    fn all_reduce_gradients(&mut self, gradients: &mut [Tensor]) -> Result<()>;

    /// Broadcasts parameters from rank 0 to all other ranks.
    fn broadcast_parameters(&mut self, parameters: &mut [Tensor]) -> Result<()>;

    /// Gets the current rank in the distributed group.
    fn rank(&self) -> usize;

    /// Gets the total number of ranks in the distributed group.
    fn world_size(&self) -> usize;

    /// Synchronizes optimizer state across all ranks.
    fn sync_state(&mut self) -> Result<()>;
}
176
/// Trait for optimizers with gradient compression capabilities.
pub trait GradientCompressionOptimizer: DistributedOptimizer {
    /// The compression method used.
    type CompressionMethod;

    /// Compresses gradients into a byte buffer before communication.
    fn compress_gradients(&self, gradients: &[Tensor]) -> Result<Vec<u8>>;

    /// Decompresses received gradient data back into tensors.
    ///
    /// `data` is expected to be in the format produced by
    /// [`compress_gradients`](Self::compress_gradients).
    fn decompress_gradients(&self, data: &[u8]) -> Result<Vec<Tensor>>;

    /// Gets the compression ratio achieved.
    fn compression_ratio(&self) -> f32;

    /// Sets the compression parameters.
    fn set_compression_config(&mut self, config: Self::CompressionMethod);
}
194
/// Trait for federated learning optimizers.
pub trait FederatedOptimizer: DistributedOptimizer {
    /// Client information type.
    type ClientInfo;

    /// Aggregates model updates from multiple clients into a single update.
    ///
    /// `updates` and `clients` are expected to correspond element-wise.
    fn aggregate_updates(
        &mut self,
        updates: &[Tensor],
        clients: &[Self::ClientInfo],
    ) -> Result<Tensor>;

    /// Selects clients for the next round of training.
    ///
    /// Returns indices into `available_clients`.
    fn select_clients(
        &self,
        available_clients: &[Self::ClientInfo],
        num_clients: usize,
    ) -> Vec<usize>;

    /// Applies differential privacy to an update, in place.
    fn apply_differential_privacy(&mut self, update: &mut Tensor) -> Result<()>;
}
217
/// Trait for asynchronous optimization methods.
pub trait AsyncOptimizer: DistributedOptimizer {
    /// Applies delayed gradients with staleness compensation.
    ///
    /// `staleness` is the delay (in steps) of the supplied gradients.
    fn apply_delayed_gradients(&mut self, gradients: &[Tensor], staleness: usize) -> Result<()>;

    /// Gets the maximum allowed staleness.
    fn max_staleness(&self) -> usize;

    /// Sets the staleness compensation method.
    fn set_staleness_compensation(&mut self, method: StalenessCompensation);
}
229
/// Staleness compensation methods for asynchronous optimization.
///
/// Used with `AsyncOptimizer::set_staleness_compensation` to select how
/// delayed gradients are adjusted before being applied.
// `PartialEq` added so callers and tests can compare variants directly instead
// of resorting to `match` + `assert!(true/false)`. `Eq` is intentionally
// omitted because `Polynomial(f32)` carries a float.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum StalenessCompensation {
    /// No compensation for staleness.
    None,
    /// Linear scaling by staleness factor.
    Linear,
    /// Exponential decay based on staleness.
    Exponential,
    /// Polynomial scaling with configurable degree.
    Polynomial(f32),
}
242
/// Trait for hardware-specific optimizer optimizations.
pub trait HardwareOptimizer: Optimizer {
    /// The target hardware type.
    type HardwareTarget;

    /// Optimizes this optimizer's execution for specific hardware.
    fn optimize_for_hardware(&mut self, target: Self::HardwareTarget) -> Result<()>;

    /// Gets hardware utilization statistics.
    fn hardware_utilization(&self) -> HardwareStats;

    /// Checks if the optimizer is compatible with the current hardware.
    fn is_hardware_compatible(&self) -> bool;
}
257
/// Hardware utilization statistics.
///
/// Plain data snapshot returned by `HardwareOptimizer::hardware_utilization`.
// `PartialEq` added so snapshots can be compared directly in tests and
// monitoring code; backward-compatible derive addition.
#[derive(Debug, Clone, PartialEq)]
pub struct HardwareStats {
    /// Memory bandwidth utilization (0.0 to 1.0).
    pub memory_bandwidth_utilization: f32,
    /// Compute utilization (0.0 to 1.0).
    pub compute_utilization: f32,
    /// Cache hit rate (0.0 to 1.0).
    pub cache_hit_rate: f32,
    /// FLOPS per second achieved.
    pub flops_per_second: f64,
}
270
/// Trait for SIMD-optimized operations.
pub trait SIMDOptimizer: HardwareOptimizer {
    /// The SIMD instruction set being used.
    type SIMDType;

    /// Checks if SIMD operations are available on this machine.
    fn simd_available(&self) -> bool;

    /// Gets the SIMD vector width.
    /// NOTE(review): unit (lanes vs. bytes) is not specified here — confirm
    /// with implementors.
    fn vector_width(&self) -> usize;

    /// Applies SIMD-optimized parameter updates, in place.
    fn simd_update(&mut self, parameters: &mut [Tensor], gradients: &[Tensor]) -> Result<()>;
}
285
/// Trait for GPU-accelerated optimizers.
pub trait GPUOptimizer: HardwareOptimizer {
    /// The GPU compute capability.
    type ComputeCapability;

    /// Transfers optimizer state to GPU.
    fn to_gpu(&mut self) -> Result<()>;

    /// Transfers optimizer state back to CPU.
    fn to_cpu(&mut self) -> Result<()>;

    /// Launches GPU kernels for parameter updates, in place.
    fn gpu_update(&mut self, parameters: &mut [Tensor], gradients: &[Tensor]) -> Result<()>;

    /// Gets GPU memory usage.
    fn gpu_memory_usage(&self) -> GPUMemoryStats;
}
303
/// GPU memory usage statistics.
///
/// Plain data snapshot returned by `GPUOptimizer::gpu_memory_usage`.
/// All values are in bytes.
// `PartialEq` (and `Eq`, since all fields are integers) added so snapshots can
// be compared directly; backward-compatible derive addition.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GPUMemoryStats {
    /// Total GPU memory in bytes.
    pub total_memory: usize,
    /// Used GPU memory in bytes.
    pub used_memory: usize,
    /// Available GPU memory in bytes.
    pub available_memory: usize,
    /// Memory used by optimizer state, in bytes.
    pub optimizer_memory: usize,
}
316
/// Trait for edge-device-optimized optimizers.
pub trait EdgeOptimizer: HardwareOptimizer {
    /// Power consumption statistics type.
    type PowerStats;

    /// Optimizes for low power consumption.
    fn optimize_for_power(&mut self) -> Result<()>;

    /// Gets current power consumption statistics.
    fn power_stats(&self) -> Self::PowerStats;

    /// Reduces numeric precision to `bits` bits to save memory and power.
    fn reduce_precision(&mut self, bits: u8) -> Result<()>;
}
331
/// Trait for meta-optimizers that wrap other optimizers.
pub trait MetaOptimizer: Optimizer {
    /// The base optimizer type being wrapped.
    type BaseOptimizer: Optimizer;

    /// Gets a reference to the base optimizer.
    fn base_optimizer(&self) -> &Self::BaseOptimizer;

    /// Gets a mutable reference to the base optimizer.
    fn base_optimizer_mut(&mut self) -> &mut Self::BaseOptimizer;

    /// Applies the meta-optimization strategy to `parameters` using `gradients`.
    fn apply_meta_strategy(
        &mut self,
        parameters: &mut [Tensor],
        gradients: &[Tensor],
    ) -> Result<()>;
}
350
/// Trait for lookahead meta-optimizers.
pub trait LookaheadOptimizer: MetaOptimizer {
    /// Gets the lookahead step size (α).
    fn lookahead_alpha(&self) -> f32;

    /// Sets the lookahead step size.
    fn set_lookahead_alpha(&mut self, alpha: f32);

    /// Gets the lookahead update frequency (k).
    fn lookahead_k(&self) -> usize;

    /// Sets the lookahead update frequency.
    fn set_lookahead_k(&mut self, k: usize);

    /// Gets the slow weights (for debugging).
    fn slow_weights(&self) -> &HashMap<String, Vec<f32>>;
}
368
/// Trait for scheduled optimizers with learning rate scheduling.
pub trait ScheduledOptimizer: Optimizer {
    /// The scheduler type.
    type Scheduler;

    /// Gets a reference to the scheduler.
    fn scheduler(&self) -> &Self::Scheduler;

    /// Gets a mutable reference to the scheduler.
    fn scheduler_mut(&mut self) -> &mut Self::Scheduler;

    /// Updates the learning rate based on the scheduler.
    fn update_lr(&mut self) -> Result<()>;

    /// Gets the current scheduled learning rate.
    fn current_lr(&self) -> f32;
}
386
/// Trait for composite optimizers that combine multiple optimization strategies.
pub trait CompositeOptimizer: Optimizer {
    /// The component optimizer types.
    type Components;

    /// Gets references to all component optimizers.
    fn components(&self) -> &Self::Components;

    /// Gets mutable references to all component optimizers.
    fn components_mut(&mut self) -> &mut Self::Components;

    /// Applies updates from all component optimizers to `parameters`.
    fn apply_composite_update(
        &mut self,
        parameters: &mut [Tensor],
        gradients: &[Tensor],
    ) -> Result<()>;

    /// Gets the weight assigned to each component.
    fn component_weights(&self) -> Vec<f32>;

    /// Sets the weights for each component.
    ///
    /// Returns an error for invalid weight sets (e.g. wrong length) —
    /// exact validation rules are implementation-defined.
    fn set_component_weights(&mut self, weights: Vec<f32>) -> Result<()>;
}
411
/// Optimizer factory trait for creating optimizers with different configurations.
pub trait OptimizerFactory {
    /// The optimizer type produced by this factory.
    type Optimizer: Optimizer;

    /// The configuration type for the optimizer.
    type Config;

    /// Creates a new optimizer with the given configuration.
    fn create(&self, config: Self::Config) -> Result<Self::Optimizer>;

    /// Lists all available optimizer variant names.
    fn available_variants(&self) -> Vec<&'static str>;

    /// Creates an optimizer by name with default configuration.
    ///
    /// `name` should be one of the values returned by
    /// [`available_variants`](Self::available_variants).
    fn create_by_name(&self, name: &str) -> Result<Self::Optimizer>;
}
429
/// Trait for optimizers that can be serialized and restored.
pub trait SerializableOptimizer: Optimizer {
    /// Serializes the optimizer to bytes.
    fn serialize(&self) -> Result<Vec<u8>>;

    /// Deserializes an optimizer from bytes previously produced by
    /// [`serialize`](Self::serialize).
    fn deserialize(data: &[u8]) -> Result<Self>
    where
        Self: Sized;

    /// Gets the serialization format version.
    fn version(&self) -> u32;
}
443
#[cfg(test)]
mod tests {
    use super::*;

    /// One gibibyte in bytes, used by the GPU memory test.
    const GIB: usize = 1024 * 1024 * 1024;

    #[test]
    fn test_staleness_compensation() {
        // `matches!` replaces the previous `match` with `assert!(true)` /
        // `assert!(false)` arms, which clippy flags (`assertions_on_constants`).
        let compensation = StalenessCompensation::Linear;
        assert!(matches!(compensation, StalenessCompensation::Linear));
        // A payload-carrying variant is matchable too.
        let poly = StalenessCompensation::Polynomial(2.0);
        assert!(matches!(poly, StalenessCompensation::Polynomial(_)));
    }

    #[test]
    fn test_hardware_stats() {
        let stats = HardwareStats {
            memory_bandwidth_utilization: 0.8,
            compute_utilization: 0.9,
            cache_hit_rate: 0.95,
            flops_per_second: 1e12,
        };

        assert_eq!(stats.memory_bandwidth_utilization, 0.8);
        assert_eq!(stats.compute_utilization, 0.9);
        assert_eq!(stats.cache_hit_rate, 0.95);
        assert_eq!(stats.flops_per_second, 1e12);
    }

    #[test]
    fn test_gpu_memory_stats() {
        let stats = GPUMemoryStats {
            total_memory: 16 * GIB,
            used_memory: 8 * GIB,
            available_memory: 8 * GIB,
            // Named constant avoids clippy's `identity_op` on `1 * 1024 * ...`.
            optimizer_memory: GIB,
        };

        assert_eq!(stats.total_memory, 16 * GIB);
        // Used + available should account for all GPU memory.
        assert_eq!(
            stats.used_memory + stats.available_memory,
            stats.total_memory
        );
        // Optimizer state must fit within used memory.
        assert!(stats.optimizer_memory <= stats.used_memory);
    }
}
488}