// trustformers_optim/eva.rs
//! # EVA Optimizer
//!
//! EVA (Exponential Moving Average with Variance Adaptation) is an optimizer
//! that adapts the learning rate based on the variance of gradient estimates.
//!
//! ## Key Features
//!
//! - **Adaptive Learning Rate**: Uses gradient variance to adapt the learning rate
//! - **Exponential Moving Averages**: Maintains momentum and variance estimates
//! - **Robustness**: More stable than Adam in certain scenarios
//! - **Computational Efficiency**: Low overhead compared to second-order methods
//!
//! ## Algorithm
//!
//! EVA updates parameters using:
//! 1. An exponential moving average of gradients (momentum)
//! 2. An exponential moving average of squared gradients (variance)
//! 3. Variance-adapted learning-rate scaling
//! 4. Optional bias correction
//!
//! ## Usage Example
//!
//! ```rust,no_run
//! use trustformers_optim::EVA;
//!
//! let mut optimizer = EVA::new(
//!     1e-3,   // learning_rate
//!     0.9,    // beta1
//!     0.999,  // beta2
//!     1e-8,   // epsilon
//!     0.01,   // weight_decay
//!     true,   // variance_adaptation
//! );
//! ```
35
use std::collections::HashMap;

use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;

use crate::common::{OptimizerState, StateMemoryStats};
use crate::traits::StatefulOptimizer;
42
/// Configuration for the EVA optimizer.
///
/// Defaults (via [`Default`]): `lr = 1e-3`, `beta1 = 0.9`, `beta2 = 0.999`,
/// `eps = 1e-8`, `weight_decay = 0.01`, variance adaptation and bias
/// correction enabled, `adaptation_strength = 1.0`.
#[derive(Debug, Clone)]
pub struct EVAConfig {
    /// Step size applied to each parameter update.
    pub lr: f32,
    /// Exponential decay rate for the first-moment (momentum) estimate.
    pub beta1: f32,
    /// Exponential decay rate for the second-moment (squared-gradient) estimate.
    pub beta2: f32,
    /// Small constant added to denominators for numerical stability.
    pub eps: f32,
    /// Weight decay coefficient; added to the gradient as an L2 penalty.
    pub weight_decay: f32,
    /// When true, the learning rate is scaled by a factor derived from the
    /// variance of the incoming gradient on each step.
    pub variance_adaptation: bool,
    /// When true, Adam-style bias correction is applied to both moment
    /// estimates.
    pub bias_correction: bool,
    /// Multiplier controlling how strongly gradient variance shrinks the
    /// effective learning rate (used only when `variance_adaptation` is set).
    pub adaptation_strength: f32,
}
63
64impl Default for EVAConfig {
65    fn default() -> Self {
66        Self {
67            lr: 1e-3,
68            beta1: 0.9,
69            beta2: 0.999,
70            eps: 1e-8,
71            weight_decay: 0.01,
72            variance_adaptation: true,
73            bias_correction: true,
74            adaptation_strength: 1.0,
75        }
76    }
77}
78
/// EVA (Exponential Moving Average with Variance Adaptation) optimizer.
///
/// Per-parameter state is keyed by the parameter buffer's memory address
/// (formatted with `{:p}` in `Optimizer::update`), so state for a parameter
/// is lost if its backing buffer is reallocated.
#[derive(Debug)]
pub struct EVA {
    // Hyperparameters controlling the update rule.
    config: EVAConfig,
    // Generic bookkeeping exposed through `StatefulOptimizer`.
    state: OptimizerState,
    // First-moment (momentum) buffer per parameter.
    exp_avg: HashMap<String, Vec<f32>>,
    // Second-moment (squared-gradient) buffer per parameter.
    exp_avg_sq: HashMap<String, Vec<f32>>,
    // Running mean of absolute gradients, maintained only when
    // `config.variance_adaptation` is true.
    var_adaptation: HashMap<String, Vec<f32>>,
    // Number of `update` calls performed so far; drives bias correction.
    step_count: usize,
}
89
90impl EVA {
91    /// Creates a new EVA optimizer with default configuration.
92    pub fn new(
93        lr: f32,
94        beta1: f32,
95        beta2: f32,
96        eps: f32,
97        weight_decay: f32,
98        variance_adaptation: bool,
99    ) -> Self {
100        let config = EVAConfig {
101            lr,
102            beta1,
103            beta2,
104            eps,
105            weight_decay,
106            variance_adaptation,
107            bias_correction: true,
108            adaptation_strength: 1.0,
109        };
110
111        Self::with_config(config)
112    }
113
114    /// Creates a new EVA optimizer with custom configuration.
115    pub fn with_config(config: EVAConfig) -> Self {
116        Self {
117            config,
118            state: OptimizerState::new(),
119            exp_avg: HashMap::new(),
120            exp_avg_sq: HashMap::new(),
121            var_adaptation: HashMap::new(),
122            step_count: 0,
123        }
124    }
125
126    /// Convenience constructor for EVA with AdamW-like settings.
127    pub fn adamw_like(lr: f32, weight_decay: f32) -> Self {
128        Self::new(lr, 0.9, 0.999, 1e-8, weight_decay, true)
129    }
130
131    /// Convenience constructor for EVA with variance adaptation disabled.
132    pub fn no_variance_adaptation(lr: f32, beta1: f32, beta2: f32, eps: f32) -> Self {
133        Self::new(lr, beta1, beta2, eps, 0.0, false)
134    }
135
136    /// Gets the current learning rate.
137    pub fn get_lr(&self) -> f32 {
138        self.config.lr
139    }
140
141    /// Sets the learning rate.
142    pub fn set_lr(&mut self, lr: f32) {
143        self.config.lr = lr;
144    }
145
146    /// Gets current optimizer configuration.
147    pub fn config(&self) -> &EVAConfig {
148        &self.config
149    }
150
151    /// Gets memory statistics for the optimizer state.
152    pub fn memory_stats(&self) -> StateMemoryStats {
153        let mut total_parameters = 0;
154        #[allow(dead_code)]
155        let mut _total_buffers = 0;
156        #[allow(unused_assignments)]
157        for buffer in self.exp_avg.values() {
158            total_parameters += buffer.len();
159            _total_buffers += 1;
160        }
161
162        for buffer in self.exp_avg_sq.values() {
163            total_parameters += buffer.len();
164            _total_buffers += 1;
165        }
166
167        if self.config.variance_adaptation {
168            for buffer in self.var_adaptation.values() {
169                total_parameters += buffer.len();
170                _total_buffers += 1;
171            }
172        }
173
174        StateMemoryStats {
175            momentum_elements: total_parameters,
176            variance_elements: total_parameters,
177            third_moment_elements: if self.config.variance_adaptation {
178                total_parameters
179            } else {
180                0
181            },
182            total_bytes: total_parameters * 4, // f32 = 4 bytes
183            num_parameters: total_parameters,
184        }
185    }
186
187    /// Computes variance adaptation factor.
188    #[allow(dead_code)]
189    fn compute_variance_adaptation(&self, grad_var: f32, step: usize) -> f32 {
190        if !self.config.variance_adaptation || step == 0 {
191            return 1.0;
192        }
193
194        let adaptation = (grad_var + self.config.eps).sqrt();
195        let strength = self.config.adaptation_strength;
196
197        // Apply strength and clamp to reasonable range
198        let factor = 1.0 / (1.0 + strength * adaptation);
199        factor.clamp(0.1, 2.0)
200    }
201}
202
impl Optimizer for EVA {
    /// Applies one EVA step to `parameter` in place using `grad`.
    ///
    /// Only `Tensor::F32` inputs are supported; any other variant returns a
    /// `tensor_op_error`.
    ///
    /// NOTE(review): `step_count` is shared across all parameters and is
    /// incremented once per `update` call, so with several parameter tensors
    /// the bias-correction exponent advances per tensor rather than per
    /// optimization step — confirm this is intended.
    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
        self.step_count += 1;

        match (parameter, grad) {
            (Tensor::F32(param), Tensor::F32(grad_data)) => {
                // State is keyed by the parameter buffer's address; a
                // reallocated tensor silently gets a fresh state entry.
                let param_id = format!("{:p}", param.as_ptr());
                let size = grad_data.len();

                // Lazily create zeroed state buffers on first sight of this
                // parameter.
                let exp_avg =
                    self.exp_avg.entry(param_id.clone()).or_insert_with(|| vec![0.0; size]);
                let exp_avg_sq =
                    self.exp_avg_sq.entry(param_id.clone()).or_insert_with(|| vec![0.0; size]);
                let mut var_adapt = if self.config.variance_adaptation {
                    Some(
                        self.var_adaptation
                            .entry(param_id.clone())
                            .or_insert_with(|| vec![0.0; size]),
                    )
                } else {
                    None
                };

                // Guard against a gradient whose length no longer matches the
                // stored buffers (e.g. a reused buffer address with a new shape).
                if exp_avg.len() != size || exp_avg_sq.len() != size {
                    return Err(TrustformersError::tensor_op_error(
                        "EVA buffer size mismatch",
                        "EVA::update",
                    ));
                }

                if let Some(ref va) = var_adapt {
                    if va.len() != size {
                        return Err(TrustformersError::tensor_op_error(
                            "EVA variance adaptation buffer size mismatch",
                            "EVA::update",
                        ));
                    }
                }

                // Adam-style bias-correction denominators: 1 - beta^t.
                let bias_correction1 = if self.config.bias_correction {
                    1.0 - self.config.beta1.powi(self.step_count as i32)
                } else {
                    1.0
                };

                let bias_correction2 = if self.config.bias_correction {
                    1.0 - self.config.beta2.powi(self.step_count as i32)
                } else {
                    1.0
                };

                // Population variance of this step's gradient entries, used to
                // scale the learning rate below.
                let grad_var = if self.config.variance_adaptation {
                    let mean_grad = grad_data.iter().sum::<f32>() / size as f32;
                    grad_data.iter().map(|&g| (g - mean_grad).powi(2)).sum::<f32>() / size as f32
                } else {
                    0.0
                };

                // Inline duplicate of `compute_variance_adaptation` —
                // presumably inlined because a `&self` call would conflict
                // with the mutable borrows of the state buffers above.
                // NOTE(review): `self.step_count > 0` is always true here
                // (incremented at the top of this function).
                let variance_factor = if self.config.variance_adaptation && self.step_count > 0 {
                    let adaptation = (grad_var + self.config.eps).sqrt();
                    let strength = self.config.adaptation_strength;
                    let factor = 1.0 / (1.0 + strength * adaptation);
                    factor.clamp(0.1, 2.0)
                } else {
                    1.0
                };

                // Element-wise parameter update.
                for (i, ((&g, p), (m, v))) in grad_data
                    .iter()
                    .zip(param.iter_mut())
                    .zip(exp_avg.iter_mut().zip(exp_avg_sq.iter_mut()))
                    .enumerate()
                {
                    // L2-style weight decay folded into the gradient (Adam
                    // semantics, not AdamW's decoupled decay).
                    let grad_with_decay = if self.config.weight_decay > 0.0 {
                        g + self.config.weight_decay * (*p)
                    } else {
                        g
                    };

                    // First moment: m = beta1 * m + (1 - beta1) * g
                    *m = self.config.beta1 * (*m) + (1.0 - self.config.beta1) * grad_with_decay;

                    // Second moment: v = beta2 * v + (1 - beta2) * g^2
                    *v = self.config.beta2 * (*v)
                        + (1.0 - self.config.beta2) * grad_with_decay * grad_with_decay;

                    // Running mean of |g| per element.
                    // NOTE(review): this buffer is updated and serialized but
                    // never read by the update rule below — confirm whether it
                    // should feed into the step.
                    if let Some(ref mut va) = var_adapt {
                        va[i] = 0.9 * va[i] + 0.1 * grad_with_decay.abs();
                    }

                    // Bias-corrected moment estimates.
                    let m_hat = *m / bias_correction1;
                    let v_hat = *v / bias_correction2;

                    // Variance-adapted learning rate for this step.
                    let adapted_lr = self.config.lr * variance_factor;

                    // p -= lr' * m_hat / (sqrt(v_hat) + eps)
                    *p -= adapted_lr * m_hat / (v_hat.sqrt() + self.config.eps);
                }

                Ok(())
            },
            _ => Err(TrustformersError::tensor_op_error(
                "EVA optimizer only supports F32 tensors",
                "EVA::update",
            )),
        }
    }

    /// EVA does not accumulate gradients, so there is nothing to clear.
    fn zero_grad(&mut self) {
        // EVA doesn't accumulate gradients, so this is a no-op
    }

    /// The update is applied per parameter in `update`, so `step` is a no-op.
    fn step(&mut self) {
        // Update is called per parameter, so step is a no-op
    }

    fn get_lr(&self) -> f32 {
        self.config.lr
    }

    fn set_lr(&mut self, lr: f32) {
        self.config.lr = lr;
    }
}
336
337impl StatefulOptimizer for EVA {
338    type Config = EVAConfig;
339    type State = OptimizerState;
340
341    fn state(&self) -> &OptimizerState {
342        &self.state
343    }
344
345    fn state_mut(&mut self) -> &mut OptimizerState {
346        &mut self.state
347    }
348
349    fn config(&self) -> &Self::Config {
350        &self.config
351    }
352
353    fn memory_usage(&self) -> StateMemoryStats {
354        self.memory_stats()
355    }
356
357    fn reset_state(&mut self) {
358        self.exp_avg.clear();
359        self.exp_avg_sq.clear();
360        self.var_adaptation.clear();
361        self.step_count = 0;
362        self.state = OptimizerState::new();
363    }
364
365    fn num_parameters(&self) -> usize {
366        self.exp_avg.values().map(|v| v.len()).sum()
367    }
368
369    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
370        let mut dict = HashMap::new();
371
372        for (key, value) in &self.exp_avg {
373            dict.insert(format!("exp_avg_{}", key), Tensor::new(value.clone())?);
374        }
375
376        for (key, value) in &self.exp_avg_sq {
377            dict.insert(format!("exp_avg_sq_{}", key), Tensor::new(value.clone())?);
378        }
379
380        if self.config.variance_adaptation {
381            for (key, value) in &self.var_adaptation {
382                dict.insert(
383                    format!("var_adaptation_{}", key),
384                    Tensor::new(value.clone())?,
385                );
386            }
387        }
388
389        dict.insert(
390            "step_count".to_string(),
391            Tensor::new(vec![self.step_count as f32])?,
392        );
393
394        Ok(dict)
395    }
396
397    fn load_state_dict(&mut self, state_dict: HashMap<String, Tensor>) -> Result<()> {
398        // Load step count
399        if let Some(Tensor::F32(data)) = state_dict.get("step_count") {
400            if !data.is_empty() {
401                self.step_count = data[0] as usize;
402            }
403        }
404
405        // Load exp_avg
406        for (key, value) in &state_dict {
407            if let Some(param_key) = key.strip_prefix("exp_avg_") {
408                if let Tensor::F32(data) = value {
409                    self.exp_avg.insert(param_key.to_string(), data.as_slice().unwrap().to_vec());
410                }
411            }
412        }
413
414        // Load exp_avg_sq
415        for (key, value) in &state_dict {
416            if let Some(param_key) = key.strip_prefix("exp_avg_sq_") {
417                if let Tensor::F32(data) = value {
418                    self.exp_avg_sq
419                        .insert(param_key.to_string(), data.as_slice().unwrap().to_vec());
420                }
421            }
422        }
423
424        // Load variance adaptation
425        if self.config.variance_adaptation {
426            for (key, value) in &state_dict {
427                if let Some(param_key) = key.strip_prefix("var_adaptation_") {
428                    if let Tensor::F32(data) = value {
429                        self.var_adaptation
430                            .insert(param_key.to_string(), data.as_slice().unwrap().to_vec());
431                    }
432                }
433            }
434        }
435
436        Ok(())
437    }
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use trustformers_core::tensor::Tensor;

    #[test]
    fn test_eva_creation() {
        let opt = EVA::new(1e-3, 0.9, 0.999, 1e-8, 0.01, true);
        assert_eq!(opt.get_lr(), 1e-3);
        let cfg = opt.config();
        assert_eq!(cfg.beta1, 0.9);
        assert_eq!(cfg.beta2, 0.999);
        assert_eq!(cfg.eps, 1e-8);
        assert_eq!(cfg.weight_decay, 0.01);
        assert!(cfg.variance_adaptation);
    }

    #[test]
    fn test_eva_adamw_like() {
        let opt = EVA::adamw_like(1e-3, 0.01);
        assert_eq!(opt.get_lr(), 1e-3);
        assert_eq!(opt.config().weight_decay, 0.01);
        assert!(opt.config().variance_adaptation);
    }

    #[test]
    fn test_eva_no_variance_adaptation() {
        let opt = EVA::no_variance_adaptation(1e-3, 0.9, 0.999, 1e-8);
        assert_eq!(opt.get_lr(), 1e-3);
        assert!(!opt.config().variance_adaptation);
    }

    #[test]
    fn test_eva_lr_setter() {
        let mut opt = EVA::new(1e-3, 0.9, 0.999, 1e-8, 0.01, true);
        opt.set_lr(2e-3);
        assert_eq!(opt.get_lr(), 2e-3);
    }

    #[test]
    fn test_eva_memory_stats() {
        // A freshly constructed optimizer has no per-parameter state.
        let stats = EVA::new(1e-3, 0.9, 0.999, 1e-8, 0.01, true).memory_stats();
        assert_eq!(stats.num_parameters, 0);
        assert_eq!(stats.total_bytes, 0);
    }

    #[test]
    fn test_eva_variance_adaptation() {
        let opt = EVA::new(1e-3, 0.9, 0.999, 1e-8, 0.01, true);
        let factor = opt.compute_variance_adaptation(0.1, 1);
        assert!(factor > 0.1);
        assert!(factor < 2.0);
    }

    #[test]
    fn test_eva_state_dict() {
        let opt = EVA::new(1e-3, 0.9, 0.999, 1e-8, 0.01, true);
        let dict = opt.state_dict().unwrap();
        assert!(dict.contains_key("step_count"));
    }

    #[test]
    fn test_eva_load_state_dict() {
        let mut opt = EVA::new(1e-3, 0.9, 0.999, 1e-8, 0.01, true);
        let mut dict = HashMap::new();
        dict.insert("step_count".to_string(), Tensor::new(vec![10.0]).unwrap());

        opt.load_state_dict(dict).unwrap();
        assert_eq!(opt.step_count, 10);
    }
}