1use crate::common::{OptimizerState, StateMemoryStats};
37use crate::traits::StatefulOptimizer;
38use std::collections::HashMap;
39use trustformers_core::errors::{Result, TrustformersError};
40use trustformers_core::tensor::Tensor;
41use trustformers_core::traits::Optimizer;
42
/// Configuration for the [`EVA`] optimizer.
#[derive(Debug, Clone)]
pub struct EVAConfig {
    /// Learning rate.
    pub lr: f32,
    /// EMA decay factor for the first-moment (mean) gradient estimate.
    pub beta1: f32,
    /// EMA decay factor for the second-moment (squared gradient) estimate.
    pub beta2: f32,
    /// Small constant added to denominators for numerical stability.
    pub eps: f32,
    /// Weight decay coefficient (applied by adding `weight_decay * param`
    /// to the gradient during the update).
    pub weight_decay: f32,
    /// When true, the learning rate is scaled by a factor derived from the
    /// per-tensor gradient variance on every step.
    pub variance_adaptation: bool,
    /// When true, Adam-style bias correction is applied to both moments.
    pub bias_correction: bool,
    /// Multiplier controlling how strongly gradient variance shrinks the
    /// adapted learning rate.
    pub adaptation_strength: f32,
}
63
64impl Default for EVAConfig {
65 fn default() -> Self {
66 Self {
67 lr: 1e-3,
68 beta1: 0.9,
69 beta2: 0.999,
70 eps: 1e-8,
71 weight_decay: 0.01,
72 variance_adaptation: true,
73 bias_correction: true,
74 adaptation_strength: 1.0,
75 }
76 }
77}
78
/// Adam-style stateful optimizer with optional learning-rate scaling based on
/// per-tensor gradient variance.
#[derive(Debug)]
pub struct EVA {
    /// Hyperparameters for the optimizer.
    config: EVAConfig,
    /// Shared optimizer bookkeeping state.
    state: OptimizerState,
    /// First-moment (mean gradient) EMA buffers, keyed per parameter tensor.
    exp_avg: HashMap<String, Vec<f32>>,
    /// Second-moment (squared gradient) EMA buffers, keyed per parameter tensor.
    exp_avg_sq: HashMap<String, Vec<f32>>,
    /// EMA of absolute gradients, maintained only when variance adaptation
    /// is enabled.
    var_adaptation: HashMap<String, Vec<f32>>,
    /// Number of `update` calls so far; drives bias correction.
    step_count: usize,
}
89
90impl EVA {
91 pub fn new(
93 lr: f32,
94 beta1: f32,
95 beta2: f32,
96 eps: f32,
97 weight_decay: f32,
98 variance_adaptation: bool,
99 ) -> Self {
100 let config = EVAConfig {
101 lr,
102 beta1,
103 beta2,
104 eps,
105 weight_decay,
106 variance_adaptation,
107 bias_correction: true,
108 adaptation_strength: 1.0,
109 };
110
111 Self::with_config(config)
112 }
113
114 pub fn with_config(config: EVAConfig) -> Self {
116 Self {
117 config,
118 state: OptimizerState::new(),
119 exp_avg: HashMap::new(),
120 exp_avg_sq: HashMap::new(),
121 var_adaptation: HashMap::new(),
122 step_count: 0,
123 }
124 }
125
126 pub fn adamw_like(lr: f32, weight_decay: f32) -> Self {
128 Self::new(lr, 0.9, 0.999, 1e-8, weight_decay, true)
129 }
130
131 pub fn no_variance_adaptation(lr: f32, beta1: f32, beta2: f32, eps: f32) -> Self {
133 Self::new(lr, beta1, beta2, eps, 0.0, false)
134 }
135
136 pub fn get_lr(&self) -> f32 {
138 self.config.lr
139 }
140
141 pub fn set_lr(&mut self, lr: f32) {
143 self.config.lr = lr;
144 }
145
146 pub fn config(&self) -> &EVAConfig {
148 &self.config
149 }
150
151 pub fn memory_stats(&self) -> StateMemoryStats {
153 let mut total_parameters = 0;
154 #[allow(dead_code)]
155 let mut _total_buffers = 0;
156 #[allow(unused_assignments)]
157 for buffer in self.exp_avg.values() {
158 total_parameters += buffer.len();
159 _total_buffers += 1;
160 }
161
162 for buffer in self.exp_avg_sq.values() {
163 total_parameters += buffer.len();
164 _total_buffers += 1;
165 }
166
167 if self.config.variance_adaptation {
168 for buffer in self.var_adaptation.values() {
169 total_parameters += buffer.len();
170 _total_buffers += 1;
171 }
172 }
173
174 StateMemoryStats {
175 momentum_elements: total_parameters,
176 variance_elements: total_parameters,
177 third_moment_elements: if self.config.variance_adaptation {
178 total_parameters
179 } else {
180 0
181 },
182 total_bytes: total_parameters * 4, num_parameters: total_parameters,
184 }
185 }
186
187 #[allow(dead_code)]
189 fn compute_variance_adaptation(&self, grad_var: f32, step: usize) -> f32 {
190 if !self.config.variance_adaptation || step == 0 {
191 return 1.0;
192 }
193
194 let adaptation = (grad_var + self.config.eps).sqrt();
195 let strength = self.config.adaptation_strength;
196
197 let factor = 1.0 / (1.0 + strength * adaptation);
199 factor.clamp(0.1, 2.0)
200 }
201}
202
impl Optimizer for EVA {
    /// Applies one EVA update step to `parameter` in place.
    ///
    /// Maintains Adam-style first/second moment EMAs per parameter tensor
    /// (with optional bias correction) and, when `variance_adaptation` is
    /// enabled, scales the learning rate by a factor derived from the
    /// gradient variance of the whole tensor. Only `Tensor::F32` is
    /// supported; other tensor variants yield an error.
    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
        // NOTE(review): the step counter advances even when the tensor-type
        // check below fails, which shifts subsequent bias correction —
        // confirm this is intended.
        self.step_count += 1;

        match (parameter, grad) {
            (Tensor::F32(param), Tensor::F32(grad_data)) => {
                // State buffers are keyed by the parameter's memory address.
                // NOTE(review): pointer identity breaks if the parameter
                // tensor is reallocated between steps — confirm parameter
                // storage is stable across calls.
                let param_id = format!("{:p}", param.as_ptr());
                let size = grad_data.len();

                // Lazily create zero-initialized moment buffers sized to the
                // current gradient.
                let exp_avg =
                    self.exp_avg.entry(param_id.clone()).or_insert_with(|| vec![0.0; size]);
                let exp_avg_sq =
                    self.exp_avg_sq.entry(param_id.clone()).or_insert_with(|| vec![0.0; size]);
                let mut var_adapt = if self.config.variance_adaptation {
                    Some(
                        self.var_adaptation
                            .entry(param_id.clone())
                            .or_insert_with(|| vec![0.0; size]),
                    )
                } else {
                    None
                };

                // Guard against a pointer-key collision with a tensor of a
                // different size (buffers above may be pre-existing).
                if exp_avg.len() != size || exp_avg_sq.len() != size {
                    return Err(TrustformersError::tensor_op_error(
                        "EVA buffer size mismatch",
                        "EVA::update",
                    ));
                }

                if let Some(ref va) = var_adapt {
                    if va.len() != size {
                        return Err(TrustformersError::tensor_op_error(
                            "EVA variance adaptation buffer size mismatch",
                            "EVA::update",
                        ));
                    }
                }

                // Adam bias-correction terms 1 - beta^t (1.0 disables them).
                let bias_correction1 = if self.config.bias_correction {
                    1.0 - self.config.beta1.powi(self.step_count as i32)
                } else {
                    1.0
                };

                let bias_correction2 = if self.config.bias_correction {
                    1.0 - self.config.beta2.powi(self.step_count as i32)
                } else {
                    1.0
                };

                // Population variance of the raw gradient over the whole
                // tensor (only needed for variance adaptation).
                let grad_var = if self.config.variance_adaptation {
                    let mean_grad = grad_data.iter().sum::<f32>() / size as f32;
                    grad_data.iter().map(|&g| (g - mean_grad).powi(2)).sum::<f32>() / size as f32
                } else {
                    0.0
                };

                // Learning-rate scale: 1 / (1 + strength * sqrt(var + eps)),
                // clamped to [0.1, 2.0]. Mirrors
                // `compute_variance_adaptation`; kept inline here.
                let variance_factor = if self.config.variance_adaptation && self.step_count > 0 {
                    let adaptation = (grad_var + self.config.eps).sqrt();
                    let strength = self.config.adaptation_strength;
                    let factor = 1.0 / (1.0 + strength * adaptation);
                    factor.clamp(0.1, 2.0)
                } else {
                    1.0
                };

                // Element-wise Adam update with the adapted learning rate.
                for (i, ((&g, p), (m, v))) in grad_data
                    .iter()
                    .zip(param.iter_mut())
                    .zip(exp_avg.iter_mut().zip(exp_avg_sq.iter_mut()))
                    .enumerate()
                {
                    // L2-style weight decay folded into the gradient.
                    let grad_with_decay = if self.config.weight_decay > 0.0 {
                        g + self.config.weight_decay * (*p)
                    } else {
                        g
                    };

                    // First-moment EMA.
                    *m = self.config.beta1 * (*m) + (1.0 - self.config.beta1) * grad_with_decay;

                    // Second-moment EMA.
                    *v = self.config.beta2 * (*v)
                        + (1.0 - self.config.beta2) * grad_with_decay * grad_with_decay;

                    // EMA of |grad| tracked per element (decay 0.9).
                    if let Some(ref mut va) = var_adapt {
                        va[i] = 0.9 * va[i] + 0.1 * grad_with_decay.abs();
                    }

                    let m_hat = *m / bias_correction1;
                    let v_hat = *v / bias_correction2;

                    let adapted_lr = self.config.lr * variance_factor;

                    *p -= adapted_lr * m_hat / (v_hat.sqrt() + self.config.eps);
                }

                Ok(())
            },
            _ => Err(TrustformersError::tensor_op_error(
                "EVA optimizer only supports F32 tensors",
                "EVA::update",
            )),
        }
    }

    /// No-op: EVA does not store gradients itself.
    fn zero_grad(&mut self) {
    }

    /// No-op: parameter updates happen in [`Optimizer::update`].
    fn step(&mut self) {
    }

    /// Returns the current learning rate.
    fn get_lr(&self) -> f32 {
        self.config.lr
    }

    /// Sets the learning rate.
    fn set_lr(&mut self, lr: f32) {
        self.config.lr = lr;
    }
}
336
337impl StatefulOptimizer for EVA {
338 type Config = EVAConfig;
339 type State = OptimizerState;
340
341 fn state(&self) -> &OptimizerState {
342 &self.state
343 }
344
345 fn state_mut(&mut self) -> &mut OptimizerState {
346 &mut self.state
347 }
348
349 fn config(&self) -> &Self::Config {
350 &self.config
351 }
352
353 fn memory_usage(&self) -> StateMemoryStats {
354 self.memory_stats()
355 }
356
357 fn reset_state(&mut self) {
358 self.exp_avg.clear();
359 self.exp_avg_sq.clear();
360 self.var_adaptation.clear();
361 self.step_count = 0;
362 self.state = OptimizerState::new();
363 }
364
365 fn num_parameters(&self) -> usize {
366 self.exp_avg.values().map(|v| v.len()).sum()
367 }
368
369 fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
370 let mut dict = HashMap::new();
371
372 for (key, value) in &self.exp_avg {
373 dict.insert(format!("exp_avg_{}", key), Tensor::new(value.clone())?);
374 }
375
376 for (key, value) in &self.exp_avg_sq {
377 dict.insert(format!("exp_avg_sq_{}", key), Tensor::new(value.clone())?);
378 }
379
380 if self.config.variance_adaptation {
381 for (key, value) in &self.var_adaptation {
382 dict.insert(
383 format!("var_adaptation_{}", key),
384 Tensor::new(value.clone())?,
385 );
386 }
387 }
388
389 dict.insert(
390 "step_count".to_string(),
391 Tensor::new(vec![self.step_count as f32])?,
392 );
393
394 Ok(dict)
395 }
396
397 fn load_state_dict(&mut self, state_dict: HashMap<String, Tensor>) -> Result<()> {
398 if let Some(Tensor::F32(data)) = state_dict.get("step_count") {
400 if !data.is_empty() {
401 self.step_count = data[0] as usize;
402 }
403 }
404
405 for (key, value) in &state_dict {
407 if let Some(param_key) = key.strip_prefix("exp_avg_") {
408 if let Tensor::F32(data) = value {
409 self.exp_avg.insert(param_key.to_string(), data.as_slice().unwrap().to_vec());
410 }
411 }
412 }
413
414 for (key, value) in &state_dict {
416 if let Some(param_key) = key.strip_prefix("exp_avg_sq_") {
417 if let Tensor::F32(data) = value {
418 self.exp_avg_sq
419 .insert(param_key.to_string(), data.as_slice().unwrap().to_vec());
420 }
421 }
422 }
423
424 if self.config.variance_adaptation {
426 for (key, value) in &state_dict {
427 if let Some(param_key) = key.strip_prefix("var_adaptation_") {
428 if let Tensor::F32(data) = value {
429 self.var_adaptation
430 .insert(param_key.to_string(), data.as_slice().unwrap().to_vec());
431 }
432 }
433 }
434 }
435
436 Ok(())
437 }
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use trustformers_core::tensor::Tensor;

    /// Constructor stores every hyperparameter verbatim.
    #[test]
    fn test_eva_creation() {
        let opt = EVA::new(1e-3, 0.9, 0.999, 1e-8, 0.01, true);
        assert_eq!(opt.get_lr(), 1e-3);
        let cfg = opt.config();
        assert_eq!(cfg.beta1, 0.9);
        assert_eq!(cfg.beta2, 0.999);
        assert_eq!(cfg.eps, 1e-8);
        assert_eq!(cfg.weight_decay, 0.01);
        assert!(cfg.variance_adaptation);
    }

    /// The AdamW-like preset keeps lr/weight decay and enables adaptation.
    #[test]
    fn test_eva_adamw_like() {
        let opt = EVA::adamw_like(1e-3, 0.01);
        assert_eq!(opt.get_lr(), 1e-3);
        let cfg = opt.config();
        assert_eq!(cfg.weight_decay, 0.01);
        assert!(cfg.variance_adaptation);
    }

    /// The no-adaptation preset disables variance adaptation.
    #[test]
    fn test_eva_no_variance_adaptation() {
        let opt = EVA::no_variance_adaptation(1e-3, 0.9, 0.999, 1e-8);
        assert_eq!(opt.get_lr(), 1e-3);
        assert!(!opt.config().variance_adaptation);
    }

    /// Learning rate can be changed after construction.
    #[test]
    fn test_eva_lr_setter() {
        let mut opt = EVA::new(1e-3, 0.9, 0.999, 1e-8, 0.01, true);
        opt.set_lr(2e-3);
        assert_eq!(opt.get_lr(), 2e-3);
    }

    /// A fresh optimizer reports zero parameters and bytes.
    #[test]
    fn test_eva_memory_stats() {
        let stats = EVA::new(1e-3, 0.9, 0.999, 1e-8, 0.01, true).memory_stats();
        assert_eq!(stats.num_parameters, 0);
        assert_eq!(stats.total_bytes, 0);
    }

    /// The adaptation factor stays strictly inside its clamp bounds here.
    #[test]
    fn test_eva_variance_adaptation() {
        let opt = EVA::new(1e-3, 0.9, 0.999, 1e-8, 0.01, true);
        let factor = opt.compute_variance_adaptation(0.1, 1);
        assert!(factor > 0.1 && factor < 2.0);
    }

    /// A serialized state dict always contains the step counter.
    #[test]
    fn test_eva_state_dict() {
        let opt = EVA::new(1e-3, 0.9, 0.999, 1e-8, 0.01, true);
        let dict = opt.state_dict().unwrap();
        assert!(dict.contains_key("step_count"));
    }

    /// Loading a state dict restores the step counter.
    #[test]
    fn test_eva_load_state_dict() {
        let mut opt = EVA::new(1e-3, 0.9, 0.999, 1e-8, 0.01, true);
        let mut dict = HashMap::new();
        dict.insert("step_count".to_string(), Tensor::new(vec![10.0]).unwrap());

        opt.load_state_dict(dict).unwrap();
        assert_eq!(opt.step_count, 10);
    }
}