use crate::{
    common::{BiasCorrection, OptimizerState, ParameterUpdate, StateMemoryStats},
    traits::StatefulOptimizer,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::{errors::Result, tensor::Tensor, traits::Optimizer};

/// Configuration for the NovoGrad optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NovoGradConfig {
    /// Base learning rate.
    pub learning_rate: f32,
    /// Exponential decay rate for the first-moment (momentum) estimates.
    pub beta1: f32,
    /// Exponential decay rate for the per-layer second-moment estimates.
    pub beta2: f32,
    /// Small constant added to the denominator for numerical stability.
    pub epsilon: f32,
    /// Weight decay coefficient.
    pub weight_decay: f32,
    /// Optional threshold for per-layer gradient norm clipping.
    pub grad_clipping: Option<f32>,
    /// Whether to apply Adam-style bias correction to the moment estimates.
    pub bias_correction: bool,
    /// Whether to scale weight decay down for larger layers.
    pub adaptive_weight_decay: bool,
    /// Weight of the raw gradient term mixed into each update.
    pub memory_factor: f32,
    /// Whether to adapt the learning rate per layer from gradient norm trends.
    pub layer_wise_adaptation: bool,
}

impl Default for NovoGradConfig {
    fn default() -> Self {
        Self {
            learning_rate: 1e-3,
            beta1: 0.95,
            beta2: 0.98,
            epsilon: 1e-8,
            weight_decay: 0.0,
            grad_clipping: Some(1.0),
            bias_correction: true,
            adaptive_weight_decay: true,
            memory_factor: 0.8,
            layer_wise_adaptation: true,
        }
    }
}

impl NovoGradConfig {
    /// Preset tuned for large language model training.
    pub fn for_large_language_models() -> Self {
        Self {
            learning_rate: 1e-3,
            beta1: 0.95,
            beta2: 0.999,
            epsilon: 1e-6,
            weight_decay: 1e-2,
            grad_clipping: Some(1.0),
            bias_correction: true,
            adaptive_weight_decay: true,
            memory_factor: 0.9,
            layer_wise_adaptation: true,
        }
    }

    /// Preset tuned for vision models.
    pub fn for_vision_models() -> Self {
        Self {
            learning_rate: 1e-3,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 1e-4,
            grad_clipping: Some(2.0),
            bias_correction: true,
            adaptive_weight_decay: false,
            memory_factor: 0.7,
            layer_wise_adaptation: false,
        }
    }

    /// Preset that disables optional features (bias correction, adaptive
    /// weight decay, layer-wise adaptation) for memory-constrained settings.
    pub fn for_memory_constrained() -> Self {
        Self {
            learning_rate: 1e-3,
            beta1: 0.95,
            beta2: 0.98,
            epsilon: 1e-8,
            weight_decay: 0.0,
            grad_clipping: Some(1.0),
            bias_correction: false,
            adaptive_weight_decay: false,
            memory_factor: 1.0,
            layer_wise_adaptation: false,
        }
    }

    /// Preset with a conservative rate, tight clipping, and small epsilon for
    /// scientific computing workloads.
    pub fn for_scientific_computing() -> Self {
        Self {
            learning_rate: 1e-4,
            beta1: 0.99,
            beta2: 0.999,
            epsilon: 1e-10,
            weight_decay: 1e-6,
            grad_clipping: Some(0.5),
            bias_correction: true,
            adaptive_weight_decay: true,
            memory_factor: 0.8,
            layer_wise_adaptation: true,
        }
    }
}

/// NovoGrad optimizer with layer-wise second moments: instead of Adam's
/// per-parameter variance buffer, a single second-moment scalar is kept per
/// layer, which is the source of the memory savings reported below.
#[derive(Debug)]
pub struct NovoGrad {
    config: NovoGradConfig,
    state: OptimizerState,
    /// Per-layer second-moment estimates (one scalar per layer).
    layer_second_moments: HashMap<String, f32>,
    /// Most recent gradient norm observed for each layer.
    layer_grad_norms: HashMap<String, f32>,
    /// Smoothed per-layer learning-rate scaling factors.
    layer_lr_factors: HashMap<String, f32>,
    current_step: usize,
    total_parameters: usize,
}

impl NovoGrad {
    pub fn new(config: NovoGradConfig) -> Self {
        Self {
            config,
            state: OptimizerState::new(),
            layer_second_moments: HashMap::new(),
            layer_grad_norms: HashMap::new(),
            layer_lr_factors: HashMap::new(),
            current_step: 0,
            total_parameters: 0,
        }
    }

    pub fn for_large_language_models() -> Self {
        Self::new(NovoGradConfig::for_large_language_models())
    }

    pub fn for_vision_models() -> Self {
        Self::new(NovoGradConfig::for_vision_models())
    }

    pub fn for_memory_constrained() -> Self {
        Self::new(NovoGradConfig::for_memory_constrained())
    }

    pub fn for_scientific_computing() -> Self {
        Self::new(NovoGradConfig::for_scientific_computing())
    }

    fn compute_layer_grad_norm(&self, gradient: &[f32]) -> f32 {
        let grad_norm_squared: f32 = gradient.iter().map(|g| g * g).sum();
        grad_norm_squared.sqrt()
    }

    /// Computes a per-layer learning rate by smoothing an adaptation factor
    /// derived from the ratio of the current to the previous gradient norm.
    fn compute_adaptive_lr(&mut self, layer_id: &str, grad_norm: f32) -> f32 {
        if !self.config.layer_wise_adaptation {
            return self.config.learning_rate;
        }

        let base_lr = self.config.learning_rate;
        let prev_norm = self.layer_grad_norms.get(layer_id).copied().unwrap_or(1.0);

        let norm_ratio = if prev_norm > 1e-8 { grad_norm / prev_norm } else { 1.0 };

        // Damp the rate when the gradient norm grows sharply, nudge it up when
        // the norm shrinks, and leave it alone otherwise.
        let adaptation_factor = if norm_ratio > 1.2 {
            0.8
        } else if norm_ratio < 0.8 {
            1.1
        } else {
            1.0
        };

        // Exponentially smooth the factor so a single noisy step cannot swing
        // the effective rate.
        let current_factor = self.layer_lr_factors.get(layer_id).copied().unwrap_or(1.0);
        let new_factor = 0.9 * current_factor + 0.1 * adaptation_factor;
        self.layer_lr_factors.insert(layer_id.to_string(), new_factor);

        base_lr * new_factor
    }

    /// Scales weight decay down for larger layers, with a floor of 10% of the
    /// configured value.
    fn compute_adaptive_weight_decay(&self, layer_size: usize) -> f32 {
        if !self.config.adaptive_weight_decay {
            return self.config.weight_decay;
        }

        let size_factor = (layer_size as f32).sqrt();
        let adapted_wd = self.config.weight_decay / (1.0 + size_factor * 0.001);
        adapted_wd.max(self.config.weight_decay * 0.1)
    }

    /// Estimates memory use versus Adam, which keeps two per-parameter moment
    /// buffers; NovoGrad keeps one per-parameter momentum buffer plus one
    /// second-moment scalar per layer.
    pub fn memory_efficiency(&self) -> MemoryEfficiencyStats {
        let traditional_adam_memory = self.total_parameters * 2 * std::mem::size_of::<f32>();
        let novograd_memory = self.state.momentum.values().map(|v| v.len()).sum::<usize>()
            * std::mem::size_of::<f32>()
            + self.layer_second_moments.len() * std::mem::size_of::<f32>();

        let memory_savings = if traditional_adam_memory > 0 {
            1.0 - (novograd_memory as f32) / (traditional_adam_memory as f32)
        } else {
            0.0
        };

        MemoryEfficiencyStats {
            traditional_adam_memory_bytes: traditional_adam_memory,
            novograd_memory_bytes: novograd_memory,
            memory_savings_ratio: memory_savings,
            layer_count: self.layer_second_moments.len(),
            average_layer_size: if !self.layer_second_moments.is_empty() {
                self.total_parameters / self.layer_second_moments.len()
            } else {
                0
            },
        }
    }

    pub fn learning_rate(&self) -> f32 {
        self.config.learning_rate
    }

    pub fn set_learning_rate(&mut self, lr: f32) {
        self.config.learning_rate = lr;
    }
}

#[derive(Debug, Clone)]
pub struct MemoryEfficiencyStats {
    pub traditional_adam_memory_bytes: usize,
    pub novograd_memory_bytes: usize,
    pub memory_savings_ratio: f32,
    pub layer_count: usize,
    pub average_layer_size: usize,
}

impl Optimizer for NovoGrad {
    fn update(&mut self, _parameter: &mut Tensor, _gradient: &Tensor) -> Result<()> {
        // No-op: the per-tensor update path is not implemented. Use
        // `step_batch` with named gradients to drive the layer-wise state.
        Ok(())
    }

    fn step(&mut self) {
        self.current_step += 1;
        self.state.step();
    }

    fn zero_grad(&mut self) {
        // No-op: this optimizer does not own gradient buffers.
    }

    fn get_lr(&self) -> f32 {
        self.config.learning_rate
    }

    fn set_lr(&mut self, lr: f32) {
        self.config.learning_rate = lr;
    }
}

impl NovoGrad {
    /// Runs one NovoGrad step over a batch of named gradients, refreshing the
    /// per-layer second moments, momentum buffers, and adaptation state.
    pub fn step_batch(&mut self, gradients: &HashMap<String, Tensor>) -> Result<()> {
        self.current_step += 1;

        for (param_name, gradient) in gradients.iter() {
            let grad_data = gradient.data()?;
            if grad_data.is_empty() {
                continue;
            }

            let param_size = grad_data.len();
            self.total_parameters = self
                .total_parameters
                .max(self.state.momentum.values().map(|v| v.len()).sum::<usize>() + param_size);

            // Clip the layer gradient by its L2 norm if clipping is enabled.
            let mut clipped_grad = grad_data.clone();
            if let Some(clip_value) = self.config.grad_clipping {
                let grad_norm = self.compute_layer_grad_norm(&clipped_grad);
                if grad_norm > clip_value {
                    let scale = clip_value / grad_norm;
                    for g in clipped_grad.iter_mut() {
                        *g *= scale;
                    }
                }
            }

            let grad_norm = self.compute_layer_grad_norm(&clipped_grad);
            self.layer_grad_norms.insert(param_name.clone(), grad_norm);

            // Layer-wise second moment: a single scalar per layer, updated
            // from the layer's gradient norm rather than per-element squares.
            let prev_layer_v = self.layer_second_moments.get(param_name).copied().unwrap_or(0.0);
            let layer_v = self.config.beta2 * prev_layer_v
                + (1.0 - self.config.beta2) * grad_norm * grad_norm;

            let momentum = {
                let momentum = self.state.get_or_create_momentum(param_name.clone(), param_size);
                momentum.clone()
            };

            let (bias_correction1, bias_correction2) = if self.config.bias_correction {
                BiasCorrection::compute_adam_corrections(
                    self.config.beta1,
                    self.config.beta2,
                    self.current_step,
                )
            } else {
                (1.0, 1.0)
            };

            let mut updated_momentum = momentum;
            for i in 0..param_size {
                ParameterUpdate::update_ema(
                    &mut updated_momentum[i],
                    clipped_grad[i],
                    self.config.beta1,
                );
            }

            let adaptive_lr = self.compute_adaptive_lr(param_name, grad_norm);
            let adaptive_wd = self.compute_adaptive_weight_decay(param_size);

            // Scale the whole layer by the bias-corrected second moment.
            let v_hat = layer_v / bias_correction2;
            let layer_lr_scale = adaptive_lr / (v_hat.sqrt() + self.config.epsilon);

            for i in 0..param_size {
                let m_hat = updated_momentum[i] / bias_correction1;

                // Decoupled weight decay would add `adaptive_wd * parameter[i]`
                // here, but parameter values are not available in this
                // gradients-only entry point, so the term contributes nothing.
                let grad_with_wd = if adaptive_wd > 0.0 {
                    clipped_grad[i] + adaptive_wd * 0.0
                } else {
                    clipped_grad[i]
                };

                // The update direction is computed but intentionally not
                // applied; callers that hold the parameter tensors must apply
                // it themselves.
                let _update = layer_lr_scale * (m_hat + self.config.memory_factor * grad_with_wd);
            }

            self.state.momentum.insert(param_name.clone(), updated_momentum);
            self.layer_second_moments.insert(param_name.clone(), layer_v);
        }

        Ok(())
    }
}

impl StatefulOptimizer for NovoGrad {
    type Config = NovoGradConfig;
    type State = OptimizerState;

    fn config(&self) -> &Self::Config {
        &self.config
    }

    fn state(&self) -> &Self::State {
        &self.state
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state = HashMap::new();

        state.insert(
            "step".to_string(),
            Tensor::new(vec![self.current_step as f32])?,
        );

        for (name, momentum) in &self.state.momentum {
            let shape = vec![momentum.len()];
            state.insert(
                format!("momentum_{}", name),
                Tensor::from_vec(momentum.clone(), &shape)?,
            );
        }

        for (name, v) in &self.layer_second_moments {
            state.insert(format!("layer_v_{}", name), Tensor::new(vec![*v])?);
        }

        for (name, factor) in &self.layer_lr_factors {
            state.insert(format!("lr_factor_{}", name), Tensor::new(vec![*factor])?);
        }

        Ok(state)
    }

    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
        if let Some(step_tensor) = state.get("step") {
            if let Ok(step_data) = step_tensor.data() {
                if !step_data.is_empty() {
                    self.current_step = step_data[0] as usize;
                    self.state.step = self.current_step;
                }
            }
        }

        for (key, tensor) in &state {
            if let Some(name) = key.strip_prefix("momentum_") {
                if let Ok(data) = tensor.data() {
                    self.state.momentum.insert(name.to_string(), data);
                }
            } else if let Some(name) = key.strip_prefix("layer_v_") {
                if let Ok(data) = tensor.data() {
                    if !data.is_empty() {
                        self.layer_second_moments.insert(name.to_string(), data[0]);
                    }
                }
            } else if let Some(name) = key.strip_prefix("lr_factor_") {
                if let Ok(data) = tensor.data() {
                    if !data.is_empty() {
                        self.layer_lr_factors.insert(name.to_string(), data[0]);
                    }
                }
            }
        }

        Ok(())
    }

    fn memory_usage(&self) -> StateMemoryStats {
        let momentum_elements: usize = self.state.momentum.values().map(|v| v.len()).sum();
        let layer_elements = self.layer_second_moments.len() + self.layer_lr_factors.len();

        StateMemoryStats {
            momentum_elements,
            variance_elements: 0,
            // NovoGrad keeps no third-moment state; this slot reports the
            // per-layer scalar entries instead.
            third_moment_elements: layer_elements,
            total_bytes: momentum_elements * std::mem::size_of::<f32>()
                + layer_elements * std::mem::size_of::<f32>(),
            num_parameters: self.state.momentum.len(),
        }
    }

    fn reset_state(&mut self) {
        self.state.clear();
        self.layer_second_moments.clear();
        self.layer_grad_norms.clear();
        self.layer_lr_factors.clear();
        self.current_step = 0;
        self.total_parameters = 0;
    }

    fn num_parameters(&self) -> usize {
        self.state.momentum.len()
    }
}

#[derive(Debug, Clone)]
pub struct NovoGradStats {
    pub current_step: usize,
    pub total_parameters: usize,
    pub layer_count: usize,
    pub average_grad_norm: f32,
    pub max_grad_norm: f32,
    pub min_grad_norm: f32,
    pub memory_efficiency: MemoryEfficiencyStats,
    pub adaptive_lr_range: (f32, f32),
}

impl NovoGrad {
    pub fn reset(&mut self) {
        self.reset_state();
    }

    pub fn get_stats(&self) -> NovoGradStats {
        let grad_norms: Vec<f32> = self.layer_grad_norms.values().copied().collect();
        let lr_factors: Vec<f32> = self.layer_lr_factors.values().copied().collect();

        let avg_grad_norm = if !grad_norms.is_empty() {
            grad_norms.iter().sum::<f32>() / grad_norms.len() as f32
        } else {
            0.0
        };

        let (min_grad_norm, max_grad_norm) = if !grad_norms.is_empty() {
            let min = grad_norms.iter().fold(f32::INFINITY, |a, &b| a.min(b));
            let max = grad_norms.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
            (min, max)
        } else {
            (0.0, 0.0)
        };

        let adaptive_lr_range = if !lr_factors.is_empty() {
            let min_factor = lr_factors.iter().fold(f32::INFINITY, |a, &b| a.min(b));
            let max_factor = lr_factors.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
            (
                self.config.learning_rate * min_factor,
                self.config.learning_rate * max_factor,
            )
        } else {
            (self.config.learning_rate, self.config.learning_rate)
        };

        NovoGradStats {
            current_step: self.current_step,
            total_parameters: self.total_parameters,
            layer_count: self.layer_second_moments.len(),
            average_grad_norm: avg_grad_norm,
            max_grad_norm,
            min_grad_norm,
            memory_efficiency: self.memory_efficiency(),
            adaptive_lr_range,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_novograd_creation() {
        let optimizer = NovoGrad::new(NovoGradConfig::default());
        assert_eq!(optimizer.learning_rate(), 1e-3);
        assert_eq!(optimizer.config.beta1, 0.95);
        assert_eq!(optimizer.config.beta2, 0.98);
    }

    #[test]
    fn test_novograd_presets() {
        let llm_opt = NovoGrad::for_large_language_models();
        assert_eq!(llm_opt.config.beta2, 0.999);
        assert_eq!(llm_opt.config.memory_factor, 0.9);

        let vision_opt = NovoGrad::for_vision_models();
        assert_eq!(vision_opt.config.beta1, 0.9);
        assert!(!vision_opt.config.layer_wise_adaptation);

        let memory_opt = NovoGrad::for_memory_constrained();
        assert_eq!(memory_opt.config.memory_factor, 1.0);
        assert!(!memory_opt.config.bias_correction);

        let sci_opt = NovoGrad::for_scientific_computing();
        assert_eq!(sci_opt.config.learning_rate, 1e-4);
        assert_eq!(sci_opt.config.epsilon, 1e-10);
    }

    #[test]
    fn test_layer_grad_norm_computation() {
        let optimizer = NovoGrad::new(NovoGradConfig::default());
        let gradient = vec![3.0, 4.0];
        let norm = optimizer.compute_layer_grad_norm(&gradient);
        assert!((norm - 5.0).abs() < 1e-6);
    }
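
    // A sketch exercising `compute_adaptive_lr` directly: a gradient norm
    // that doubles (ratio > 1.2) should pull the smoothed factor below 1.0,
    // so the returned rate drops below the base learning rate; the expected
    // factor is 0.9 * 1.0 + 0.1 * 0.8 = 0.98 on the first observation.
    #[test]
    fn test_adaptive_lr_damps_growing_norms() {
        let mut optimizer = NovoGrad::new(NovoGradConfig::default());

        // No previous norm recorded, so prev_norm defaults to 1.0 and a norm
        // of 2.0 yields ratio 2.0 > 1.2, i.e. adaptation factor 0.8.
        let lr = optimizer.compute_adaptive_lr("layer1", 2.0);
        assert!((lr - 0.98e-3).abs() < 1e-6);

        // With layer-wise adaptation disabled, the base rate comes back.
        let mut plain = NovoGrad::new(NovoGradConfig {
            layer_wise_adaptation: false,
            ..Default::default()
        });
        assert_eq!(plain.compute_adaptive_lr("layer1", 2.0), 1e-3);
    }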

    #[test]
    fn test_adaptive_weight_decay() {
        let optimizer = NovoGrad::new(NovoGradConfig {
            adaptive_weight_decay: true,
            weight_decay: 1e-4,
            ..Default::default()
        });

        let small_layer_wd = optimizer.compute_adaptive_weight_decay(100);
        let large_layer_wd = optimizer.compute_adaptive_weight_decay(10000);

        assert!(large_layer_wd < small_layer_wd);
        assert!(large_layer_wd >= 1e-5);
    }

    #[test]
    fn test_learning_rate_getter_setter() {
        let mut optimizer = NovoGrad::new(NovoGradConfig::default());
        assert_eq!(optimizer.learning_rate(), 1e-3);

        optimizer.set_learning_rate(2e-3);
        assert_eq!(optimizer.learning_rate(), 2e-3);
    }
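
    // Smoke-test sketch for `step_batch`, the gradients-only entry point:
    // after one step the per-layer second moment, gradient norm, and momentum
    // buffer should all exist for the named layer. Assumes `Tensor::new`
    // accepts a flat Vec<f32>, as `state_dict` above already relies on.
    #[test]
    fn test_step_batch_populates_layer_state() {
        let mut optimizer = NovoGrad::new(NovoGradConfig::default());

        let mut gradients = HashMap::new();
        gradients.insert(
            "layer1".to_string(),
            Tensor::new(vec![0.1, -0.2, 0.3]).unwrap(),
        );

        optimizer.step_batch(&gradients).unwrap();

        assert_eq!(optimizer.current_step, 1);
        assert!(optimizer.layer_second_moments.contains_key("layer1"));
        assert!(optimizer.layer_grad_norms.contains_key("layer1"));
        assert!(optimizer.state.momentum.contains_key("layer1"));
    }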

    #[test]
    fn test_memory_efficiency_tracking() {
        let optimizer = NovoGrad::new(NovoGradConfig::default());
        let efficiency = optimizer.memory_efficiency();

        assert_eq!(efficiency.layer_count, 0);
        assert_eq!(efficiency.average_layer_size, 0);
        assert_eq!(efficiency.novograd_memory_bytes, 0);
    }

    #[test]
    fn test_memory_usage_tracking() {
        let optimizer = NovoGrad::new(NovoGradConfig::default());
        let memory_stats = optimizer.memory_usage();

        assert_eq!(memory_stats.momentum_elements, 0);
        assert_eq!(memory_stats.variance_elements, 0);
        assert_eq!(memory_stats.num_parameters, 0);
    }

    #[test]
    fn test_stats_generation() {
        let optimizer = NovoGrad::new(NovoGradConfig::default());
        let stats = optimizer.get_stats();

        assert_eq!(stats.current_step, 0);
        assert_eq!(stats.total_parameters, 0);
        assert_eq!(stats.layer_count, 0);
        assert_eq!(stats.average_grad_norm, 0.0);
    }
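
    // Checks the aggregation in `get_stats` by seeding the per-layer maps
    // directly, avoiding any tensor plumbing: average, min, and max gradient
    // norms and the adaptive LR range should follow from the inserted values.
    #[test]
    fn test_stats_aggregation_from_layer_maps() {
        let mut optimizer = NovoGrad::new(NovoGradConfig::default());
        optimizer.layer_grad_norms.insert("a".to_string(), 1.0);
        optimizer.layer_grad_norms.insert("b".to_string(), 3.0);
        optimizer.layer_lr_factors.insert("a".to_string(), 0.9);
        optimizer.layer_lr_factors.insert("b".to_string(), 1.1);

        let stats = optimizer.get_stats();
        assert_eq!(stats.average_grad_norm, 2.0);
        assert_eq!(stats.min_grad_norm, 1.0);
        assert_eq!(stats.max_grad_norm, 3.0);
        assert!((stats.adaptive_lr_range.0 - 0.9e-3).abs() < 1e-6);
        assert!((stats.adaptive_lr_range.1 - 1.1e-3).abs() < 1e-6);
    }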

    #[test]
    fn test_reset_functionality() {
        let mut optimizer = NovoGrad::new(NovoGradConfig::default());
        optimizer.current_step = 100;
        optimizer.layer_second_moments.insert("test".to_string(), 0.5);

        optimizer.reset();
        assert_eq!(optimizer.current_step, 0);
        assert!(optimizer.layer_second_moments.is_empty());
    }

    #[test]
    fn test_state_dict_operations() {
        let optimizer = NovoGrad::new(NovoGradConfig::default());
        let state_dict = optimizer.state_dict();
        assert!(state_dict.is_ok());

        let state = state_dict.unwrap();
        assert!(state.contains_key("step"));
    }
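
    // Round-trip sketch for the optimizer state: step counter, momentum, and
    // per-layer scalars saved by `state_dict` should load back into a fresh
    // optimizer via `load_state_dict`, assuming the Tensor type round-trips
    // f32 values exactly.
    #[test]
    fn test_state_dict_round_trip() {
        let mut optimizer = NovoGrad::new(NovoGradConfig::default());
        optimizer.current_step = 5;
        optimizer.state.momentum.insert("w".to_string(), vec![0.1, 0.2]);
        optimizer.layer_second_moments.insert("w".to_string(), 0.4);
        optimizer.layer_lr_factors.insert("w".to_string(), 0.95);

        let dict = optimizer.state_dict().unwrap();

        let mut restored = NovoGrad::new(NovoGradConfig::default());
        restored.load_state_dict(dict).unwrap();

        assert_eq!(restored.current_step, 5);
        assert_eq!(restored.state.momentum.get("w"), Some(&vec![0.1, 0.2]));
        assert_eq!(restored.layer_second_moments.get("w").copied(), Some(0.4));
        assert_eq!(restored.layer_lr_factors.get("w").copied(), Some(0.95));
    }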

    #[test]
    fn test_config_serialization() {
        let config = NovoGradConfig::for_large_language_models();
        let serialized = serde_json::to_string(&config);
        assert!(serialized.is_ok());

        let deserialized: std::result::Result<NovoGradConfig, _> =
            serde_json::from_str(&serialized.unwrap());
        assert!(deserialized.is_ok());
        assert_eq!(deserialized.unwrap().beta2, 0.999);
    }
}