
trustformers_optim/novograd.rs

//! # NovoGrad: Memory-Efficient Adaptive Optimizer
//!
//! NovoGrad is an adaptive gradient method designed for large-scale deep learning
//! training. Its key innovation lies in performing gradient normalization per
//! parameter tensor (layer) rather than per individual weight element, providing
//! significant memory savings for large models.
//!
//! ## Key Features
//!
//! - **Layer-wise Gradient Normalization**: Gradients are scaled by a per-layer
//!   second-moment estimate instead of per-element statistics
//! - **Memory Efficient**: Second-moment state is O(L) in the number of layers L
//!   (first-moment state remains per-parameter)
//! - **Large-scale Training**: Optimized for models with millions/billions of parameters
//! - **Adaptive Learning**: Combines benefits of Adam with a reduced memory footprint
//! - **Gradient Clipping**: Built-in gradient norm clipping for training stability
//!
//! ## Research Reference
//!
//! "Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training
//! of Deep Networks" - Ginsburg et al., 2019 (arXiv:1905.11286), with additional
//! heuristics (adaptive per-layer learning rates and adaptive weight decay) in
//! this implementation
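//!
//! ## Example
//!
//! A minimal usage sketch (marked `ignore` so it is not compiled as a doctest);
//! the module path and the `Tensor::from_vec` constructor are assumed from their
//! use elsewhere in this file:
//!
//! ```ignore
//! use std::collections::HashMap;
//! use trustformers_core::tensor::Tensor;
//! use trustformers_optim::novograd::{NovoGrad, NovoGradConfig};
//!
//! let mut opt = NovoGrad::new(NovoGradConfig::default());
//!
//! // One named gradient tensor per layer.
//! let shape = vec![3];
//! let mut grads: HashMap<String, Tensor> = HashMap::new();
//! grads.insert(
//!     "layer0.weight".to_string(),
//!     Tensor::from_vec(vec![0.1, -0.2, 0.3], &shape).unwrap(),
//! );
//!
//! // Apply one layer-wise NovoGrad step over all provided gradients.
//! opt.step_batch(&grads).unwrap();
//! ```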

use crate::{
    common::{BiasCorrection, OptimizerState, ParameterUpdate, StateMemoryStats},
    traits::StatefulOptimizer,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::{errors::Result, tensor::Tensor, traits::Optimizer};

/// Configuration for NovoGrad optimizer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NovoGradConfig {
    /// Base learning rate
    pub learning_rate: f32,
    /// First momentum coefficient (exponential moving average of gradients)
    pub beta1: f32,
    /// Second momentum coefficient (exponential moving average of layer norms)
    pub beta2: f32,
    /// Small constant for numerical stability
    pub epsilon: f32,
    /// Weight decay coefficient (L2 regularization)
    pub weight_decay: f32,
    /// Gradient clipping threshold (None = no clipping)
    pub grad_clipping: Option<f32>,
    /// Use bias correction for momentum estimates
    pub bias_correction: bool,
    /// Adaptive weight decay based on layer size
    pub adaptive_weight_decay: bool,
    /// Memory optimization factor; in the current `step_batch` implementation it
    /// is used as the coefficient on the gradient term of the update
    pub memory_factor: f32,
    /// Enable layer-wise adaptive learning rates
    pub layer_wise_adaptation: bool,
}

impl Default for NovoGradConfig {
    fn default() -> Self {
        Self {
            learning_rate: 1e-3,
            beta1: 0.95, // Higher than Adam for better convergence
            beta2: 0.98, // Layer-wise second moment coefficient
            epsilon: 1e-8,
            weight_decay: 0.0,
            grad_clipping: Some(1.0),
            bias_correction: true,
            adaptive_weight_decay: true,
            memory_factor: 0.8,
            layer_wise_adaptation: true,
        }
    }
}

impl NovoGradConfig {
    /// Configuration optimized for very large language models (>1B parameters)
    pub fn for_large_language_models() -> Self {
        Self {
            learning_rate: 1e-3,
            beta1: 0.95,
            beta2: 0.999,  // More conservative for large models
            epsilon: 1e-6, // Better numerical stability for large models
            weight_decay: 1e-2,
            grad_clipping: Some(1.0),
            bias_correction: true,
            adaptive_weight_decay: true,
            memory_factor: 0.9, // High memory efficiency
            layer_wise_adaptation: true,
        }
    }

    /// Configuration for computer vision models with batch normalization
    pub fn for_vision_models() -> Self {
        Self {
            learning_rate: 1e-3,
            beta1: 0.9, // Standard momentum for vision
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 1e-4,
            grad_clipping: Some(2.0), // Higher clipping for vision models
            bias_correction: true,
            adaptive_weight_decay: false, // Fixed weight decay for vision
            memory_factor: 0.7,
            layer_wise_adaptation: false, // Standard adaptation for vision
        }
    }

    /// Configuration for memory-constrained environments
    pub fn for_memory_constrained() -> Self {
        Self {
            learning_rate: 1e-3,
            beta1: 0.95,
            beta2: 0.98,
            epsilon: 1e-8,
            weight_decay: 0.0,
            grad_clipping: Some(1.0),
            bias_correction: false, // Disable for memory savings
            adaptive_weight_decay: false,
            memory_factor: 1.0, // Maximum memory efficiency
            layer_wise_adaptation: false,
        }
    }

    /// Configuration for scientific computing and neural ODEs
    pub fn for_scientific_computing() -> Self {
        Self {
            learning_rate: 1e-4, // Conservative LR for scientific applications
            beta1: 0.99,         // High momentum for smooth optimization
            beta2: 0.999,
            epsilon: 1e-10,           // Higher precision for scientific computing
            weight_decay: 1e-6,       // Minimal regularization
            grad_clipping: Some(0.5), // Tight gradient clipping
            bias_correction: true,
            adaptive_weight_decay: true,
            memory_factor: 0.8,
            layer_wise_adaptation: true,
        }
    }
}

/// NovoGrad optimizer implementation with layer-wise gradient normalization
#[derive(Debug)]
pub struct NovoGrad {
    config: NovoGradConfig,
    state: OptimizerState,
    /// Layer-wise second moment estimates (v)
    layer_second_moments: HashMap<String, f32>,
    /// Layer-wise gradient norms for statistics
    layer_grad_norms: HashMap<String, f32>,
    /// Adaptive learning rate factors per layer
    layer_lr_factors: HashMap<String, f32>,
    /// Current step number
    current_step: usize,
    /// Total number of parameters for memory tracking
    total_parameters: usize,
}

impl NovoGrad {
    /// Create a new NovoGrad optimizer
    pub fn new(config: NovoGradConfig) -> Self {
        Self {
            config,
            state: OptimizerState::new(),
            layer_second_moments: HashMap::new(),
            layer_grad_norms: HashMap::new(),
            layer_lr_factors: HashMap::new(),
            current_step: 0,
            total_parameters: 0,
        }
    }

    /// Create NovoGrad for large language models
    pub fn for_large_language_models() -> Self {
        Self::new(NovoGradConfig::for_large_language_models())
    }

    /// Create NovoGrad for vision models
    pub fn for_vision_models() -> Self {
        Self::new(NovoGradConfig::for_vision_models())
    }

    /// Create NovoGrad for memory-constrained environments
    pub fn for_memory_constrained() -> Self {
        Self::new(NovoGradConfig::for_memory_constrained())
    }

    /// Create NovoGrad for scientific computing
    pub fn for_scientific_computing() -> Self {
        Self::new(NovoGradConfig::for_scientific_computing())
    }

    /// Compute layer-wise gradient norm (NovoGrad's key innovation)
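    ///
    /// For a layer gradient `g` this is the Euclidean (L2) norm,
    /// `||g||_2 = sqrt(g_1^2 + ... + g_n^2)`; e.g. `[3.0, 4.0]` gives `5.0`.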
    fn compute_layer_grad_norm(&self, gradient: &[f32]) -> f32 {
        let grad_norm_squared: f32 = gradient.iter().map(|g| g * g).sum();
        grad_norm_squared.sqrt()
    }

    /// Apply layer-wise adaptive learning rate
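    ///
    /// The per-layer factor is nudged toward 0.8 when the layer's gradient norm
    /// grows by more than 20% between steps, toward 1.1 when it shrinks by more
    /// than 20%, and is smoothed as `factor = 0.9 * factor_prev + 0.1 * target`;
    /// the effective learning rate is `learning_rate * factor`.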
    fn compute_adaptive_lr(&mut self, layer_id: &str, grad_norm: f32) -> f32 {
        if !self.config.layer_wise_adaptation {
            return self.config.learning_rate;
        }

        // Adaptive learning rate based on layer gradient norms
        let base_lr = self.config.learning_rate;
        let prev_norm = self.layer_grad_norms.get(layer_id).copied().unwrap_or(1.0);

        // Compute adaptation factor based on gradient norm change
        let norm_ratio = if prev_norm > 1e-8 { grad_norm / prev_norm } else { 1.0 };

        // Adaptive factor: decrease LR if gradients are growing, increase if shrinking
        let adaptation_factor = if norm_ratio > 1.2 {
            0.8 // Reduce LR for growing gradients
        } else if norm_ratio < 0.8 {
            1.1 // Increase LR for shrinking gradients
        } else {
            1.0 // Keep LR stable
        };

        // Smooth the adaptation factor
        let current_factor = self.layer_lr_factors.get(layer_id).copied().unwrap_or(1.0);
        let new_factor = 0.9 * current_factor + 0.1 * adaptation_factor;
        self.layer_lr_factors.insert(layer_id.to_string(), new_factor);

        base_lr * new_factor
    }

    /// Apply adaptive weight decay based on layer size
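    ///
    /// The decay is scaled as `weight_decay / (1 + sqrt(layer_size) * 0.001)` and
    /// floored at 10% of the configured value; e.g. with `weight_decay = 1e-4`, a
    /// 10_000-parameter layer gets roughly `1e-4 / 1.1 ≈ 9.1e-5`.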
    fn compute_adaptive_weight_decay(&self, layer_size: usize) -> f32 {
        if !self.config.adaptive_weight_decay {
            return self.config.weight_decay;
        }

        // Reduce weight decay for larger layers to prevent over-regularization
        let size_factor = (layer_size as f32).sqrt();
        let adapted_wd = self.config.weight_decay / (1.0 + size_factor * 0.001);
        adapted_wd.max(self.config.weight_decay * 0.1) // Minimum 10% of original
    }

    /// Get memory efficiency statistics
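    ///
    /// The baseline assumes Adam-style state of two `f32` values (`m` and `v`) per
    /// parameter; NovoGrad's state is the per-parameter momentum plus one scalar
    /// second moment per layer.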
    pub fn memory_efficiency(&self) -> MemoryEfficiencyStats {
        let traditional_adam_memory = self.total_parameters * 2 * std::mem::size_of::<f32>(); // m + v
        let novograd_memory = self.state.momentum.values().map(|v| v.len()).sum::<usize>()
            * std::mem::size_of::<f32>()
            + self.layer_second_moments.len() * std::mem::size_of::<f32>();

        let memory_savings = if traditional_adam_memory > 0 {
            1.0 - (novograd_memory as f32) / (traditional_adam_memory as f32)
        } else {
            0.0
        };

        MemoryEfficiencyStats {
            traditional_adam_memory_bytes: traditional_adam_memory,
            novograd_memory_bytes: novograd_memory,
            memory_savings_ratio: memory_savings,
            layer_count: self.layer_second_moments.len(),
            average_layer_size: if !self.layer_second_moments.is_empty() {
                self.total_parameters / self.layer_second_moments.len()
            } else {
                0
            },
        }
    }

    /// Get current learning rate
    pub fn learning_rate(&self) -> f32 {
        self.config.learning_rate
    }

    /// Set learning rate
    pub fn set_learning_rate(&mut self, lr: f32) {
        self.config.learning_rate = lr;
    }
}

/// Memory efficiency statistics for NovoGrad
#[derive(Debug, Clone)]
pub struct MemoryEfficiencyStats {
    pub traditional_adam_memory_bytes: usize,
    pub novograd_memory_bytes: usize,
    pub memory_savings_ratio: f32,
    pub layer_count: usize,
    pub average_layer_size: usize,
}

impl Optimizer for NovoGrad {
    fn update(&mut self, _parameter: &mut Tensor, _gradient: &Tensor) -> Result<()> {
        // Single-parameter updates are not implemented in this version; use
        // `step_batch`, which applies the layer-wise NovoGrad update across all
        // named gradients.
        Ok(())
    }

    fn step(&mut self) {
        // Step counter increment - called after all parameter updates
        self.current_step += 1;
        self.state.step();
    }

    fn zero_grad(&mut self) {
        // Gradients are typically zeroed by the training framework
    }

    fn get_lr(&self) -> f32 {
        self.config.learning_rate
    }

    fn set_lr(&mut self, lr: f32) {
        self.config.learning_rate = lr;
    }
}

// Additional method for batch parameter updates (non-trait)
impl NovoGrad {
    /// Process multiple parameters at once with NovoGrad's layer-wise approach
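    ///
    /// A sketch of the update as implemented below (pseudocode; parameter
    /// application itself is still a placeholder, see the loop comments):
    ///
    /// ```text
    /// v_layer  = beta2 * v_layer + (1 - beta2) * ||g||^2        // one scalar per layer
    /// m_i      = EMA(m_i, g_i; beta1)                           // per-parameter momentum
    /// scale    = adaptive_lr / (sqrt(v_layer / bc2) + epsilon)
    /// update_i = scale * (m_i / bc1 + memory_factor * g_i_wd)   // g_i_wd: grad + weight decay
    /// ```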
    pub fn step_batch(&mut self, gradients: &HashMap<String, Tensor>) -> Result<()> {
        self.current_step += 1;

        for (param_name, gradient) in gradients.iter() {
            let grad_data = gradient.data()?;
            if grad_data.is_empty() {
                continue;
            }

            let param_size = grad_data.len();
            self.total_parameters = self
                .total_parameters
                .max(self.state.momentum.values().map(|v| v.len()).sum::<usize>() + param_size);

            // Apply gradient clipping if enabled
            let mut clipped_grad = grad_data.clone();
            if let Some(clip_value) = self.config.grad_clipping {
                let grad_norm = self.compute_layer_grad_norm(&clipped_grad);
                if grad_norm > clip_value {
                    let scale = clip_value / grad_norm;
                    for g in clipped_grad.iter_mut() {
                        *g *= scale;
                    }
                }
            }

            // Compute layer-wise gradient norm (NovoGrad's key innovation)
            let grad_norm = self.compute_layer_grad_norm(&clipped_grad);
            self.layer_grad_norms.insert(param_name.clone(), grad_norm);

            // Update layer-wise second moment estimate
            let prev_layer_v = self.layer_second_moments.get(param_name).copied().unwrap_or(0.0);
            let layer_v = self.config.beta2 * prev_layer_v
                + (1.0 - self.config.beta2) * grad_norm * grad_norm;

            // Get momentum separately to avoid borrowing conflicts
            let momentum = {
                let momentum = self.state.get_or_create_momentum(param_name.clone(), param_size);
                momentum.clone()
            };

            // Compute bias corrections if enabled
            let (bias_correction1, bias_correction2) = if self.config.bias_correction {
                BiasCorrection::compute_adam_corrections(
                    self.config.beta1,
                    self.config.beta2,
                    self.current_step,
                )
            } else {
                (1.0, 1.0)
            };

            // Update first moment estimate (per-parameter)
            let mut updated_momentum = momentum;
            for i in 0..param_size {
                ParameterUpdate::update_ema(
                    &mut updated_momentum[i],
                    clipped_grad[i],
                    self.config.beta1,
                );
            }

            // Compute adaptive learning rate for this layer
            let adaptive_lr = self.compute_adaptive_lr(param_name, grad_norm);

            // Compute adaptive weight decay
            let adaptive_wd = self.compute_adaptive_weight_decay(param_size);

            // Bias-corrected second moment (layer-wise)
            let v_hat = layer_v / bias_correction2;
            let layer_lr_scale = adaptive_lr / (v_hat.sqrt() + self.config.epsilon);

            // NovoGrad update rule: use layer-wise second moment for all parameters in the layer
            for i in 0..param_size {
                let m_hat = updated_momentum[i] / bias_correction1;

                // Apply weight decay if specified
                let grad_with_wd = if adaptive_wd > 0.0 {
                    // Note: In real implementation, this would use actual parameter values
                    clipped_grad[i] + adaptive_wd * 0.0 // placeholder for parameter value
                } else {
                    clipped_grad[i]
                };

                // NovoGrad parameter update with layer-wise normalization
                let _update = layer_lr_scale * (m_hat + self.config.memory_factor * grad_with_wd);
                // Note: In real implementation, this would update the actual parameters
                // parameter[i] -= update;
            }

            // Store updated states back
            self.state.momentum.insert(param_name.clone(), updated_momentum);
            self.layer_second_moments.insert(param_name.clone(), layer_v);
        }

        Ok(())
    }
}

impl StatefulOptimizer for NovoGrad {
    type Config = NovoGradConfig;
    type State = OptimizerState;

    fn config(&self) -> &Self::Config {
        &self.config
    }

    fn state(&self) -> &Self::State {
        &self.state
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state = HashMap::new();

        // Save step count
        state.insert(
            "step".to_string(),
            Tensor::new(vec![self.current_step as f32])?,
        );

        // Save momentum states
        for (name, momentum) in &self.state.momentum {
            let shape = vec![momentum.len()];
            state.insert(
                format!("momentum_{}", name),
                Tensor::from_vec(momentum.clone(), &shape)?,
            );
        }

        // Save NovoGrad-specific states (layer-wise second moments)
        for (name, v) in &self.layer_second_moments {
            state.insert(format!("layer_v_{}", name), Tensor::new(vec![*v])?);
        }

        // Save layer-wise learning rate factors
        for (name, factor) in &self.layer_lr_factors {
            state.insert(format!("lr_factor_{}", name), Tensor::new(vec![*factor])?);
        }

        Ok(state)
    }

    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
        // Load step count
        if let Some(step_tensor) = state.get("step") {
            if let Ok(step_data) = step_tensor.data() {
                if !step_data.is_empty() {
                    self.current_step = step_data[0] as usize;
                    self.state.step = self.current_step;
                }
            }
        }

        // Load momentum states
        for (key, tensor) in &state {
            if let Some(name) = key.strip_prefix("momentum_") {
                if let Ok(data) = tensor.data() {
                    self.state.momentum.insert(name.to_string(), data);
                }
            } else if let Some(name) = key.strip_prefix("layer_v_") {
                if let Ok(data) = tensor.data() {
                    if !data.is_empty() {
                        self.layer_second_moments.insert(name.to_string(), data[0]);
                    }
                }
            } else if let Some(name) = key.strip_prefix("lr_factor_") {
                if let Ok(data) = tensor.data() {
                    if !data.is_empty() {
                        self.layer_lr_factors.insert(name.to_string(), data[0]);
                    }
                }
            }
        }

        Ok(())
    }

    fn memory_usage(&self) -> StateMemoryStats {
        let momentum_elements: usize = self.state.momentum.values().map(|v| v.len()).sum();
        let layer_elements = self.layer_second_moments.len() + self.layer_lr_factors.len();

        StateMemoryStats {
            momentum_elements,
            variance_elements: 0, // NovoGrad doesn't use per-parameter variance
            third_moment_elements: layer_elements, // Layer-wise scalars (second moments + LR factors)
            total_bytes: momentum_elements * std::mem::size_of::<f32>()
                + layer_elements * std::mem::size_of::<f32>(),
            num_parameters: self.state.momentum.len(),
        }
    }

    fn reset_state(&mut self) {
        self.state.clear();
        self.layer_second_moments.clear();
        self.layer_grad_norms.clear();
        self.layer_lr_factors.clear();
        self.current_step = 0;
        self.total_parameters = 0;
    }

    fn num_parameters(&self) -> usize {
        self.state.momentum.len()
    }
}

/// Comprehensive NovoGrad statistics
#[derive(Debug, Clone)]
pub struct NovoGradStats {
    pub current_step: usize,
    pub total_parameters: usize,
    pub layer_count: usize,
    pub average_grad_norm: f32,
    pub max_grad_norm: f32,
    pub min_grad_norm: f32,
    pub memory_efficiency: MemoryEfficiencyStats,
    pub adaptive_lr_range: (f32, f32), // (min, max) adaptive learning rates
}

impl NovoGrad {
    /// Reset all optimizer state (convenience method)
    pub fn reset(&mut self) {
        self.reset_state();
    }

    /// Get comprehensive NovoGrad statistics
    pub fn get_stats(&self) -> NovoGradStats {
        let grad_norms: Vec<f32> = self.layer_grad_norms.values().copied().collect();
        let lr_factors: Vec<f32> = self.layer_lr_factors.values().copied().collect();

        let avg_grad_norm = if !grad_norms.is_empty() {
            grad_norms.iter().sum::<f32>() / grad_norms.len() as f32
        } else {
            0.0
        };

        let (min_grad_norm, max_grad_norm) = if !grad_norms.is_empty() {
            let min = grad_norms.iter().fold(f32::INFINITY, |a, &b| a.min(b));
            let max = grad_norms.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
            (min, max)
        } else {
            (0.0, 0.0)
        };

        let adaptive_lr_range = if !lr_factors.is_empty() {
            let min_factor = lr_factors.iter().fold(f32::INFINITY, |a, &b| a.min(b));
            let max_factor = lr_factors.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
            (
                self.config.learning_rate * min_factor,
                self.config.learning_rate * max_factor,
            )
        } else {
            (self.config.learning_rate, self.config.learning_rate)
        };

        NovoGradStats {
            current_step: self.current_step,
            total_parameters: self.total_parameters,
            layer_count: self.layer_second_moments.len(),
            average_grad_norm: avg_grad_norm,
            max_grad_norm,
            min_grad_norm,
            memory_efficiency: self.memory_efficiency(),
            adaptive_lr_range,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_novograd_creation() {
        let optimizer = NovoGrad::new(NovoGradConfig::default());
        assert_eq!(optimizer.learning_rate(), 1e-3);
        assert_eq!(optimizer.config.beta1, 0.95);
        assert_eq!(optimizer.config.beta2, 0.98);
    }

    #[test]
    fn test_novograd_presets() {
        let llm_opt = NovoGrad::for_large_language_models();
        assert_eq!(llm_opt.config.beta2, 0.999);
        assert_eq!(llm_opt.config.memory_factor, 0.9);

        let vision_opt = NovoGrad::for_vision_models();
        assert_eq!(vision_opt.config.beta1, 0.9);
        assert!(!vision_opt.config.layer_wise_adaptation);

        let memory_opt = NovoGrad::for_memory_constrained();
        assert_eq!(memory_opt.config.memory_factor, 1.0);
        assert!(!memory_opt.config.bias_correction);

        let sci_opt = NovoGrad::for_scientific_computing();
        assert_eq!(sci_opt.config.learning_rate, 1e-4);
        assert_eq!(sci_opt.config.epsilon, 1e-10);
    }

    #[test]
    fn test_layer_grad_norm_computation() {
        let optimizer = NovoGrad::new(NovoGradConfig::default());
        let gradient = vec![3.0, 4.0]; // Norm should be 5.0
        let norm = optimizer.compute_layer_grad_norm(&gradient);
        assert!((norm - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_adaptive_weight_decay() {
        let optimizer = NovoGrad::new(NovoGradConfig {
            adaptive_weight_decay: true,
            weight_decay: 1e-4,
            ..Default::default()
        });

        let small_layer_wd = optimizer.compute_adaptive_weight_decay(100);
        let large_layer_wd = optimizer.compute_adaptive_weight_decay(10000);

        // Larger layers should have smaller weight decay
        assert!(large_layer_wd < small_layer_wd);
        assert!(large_layer_wd >= 1e-5); // Should not be less than 10% of original
    }

    #[test]
    fn test_learning_rate_getter_setter() {
        let mut optimizer = NovoGrad::new(NovoGradConfig::default());
        assert_eq!(optimizer.learning_rate(), 1e-3);

        optimizer.set_learning_rate(2e-3);
        assert_eq!(optimizer.learning_rate(), 2e-3);
    }

    #[test]
    fn test_memory_efficiency_tracking() {
        let optimizer = NovoGrad::new(NovoGradConfig::default());
        let efficiency = optimizer.memory_efficiency();

        assert_eq!(efficiency.layer_count, 0);
        assert_eq!(efficiency.average_layer_size, 0);
        assert_eq!(efficiency.novograd_memory_bytes, 0);
    }

    #[test]
    fn test_memory_usage_tracking() {
        let optimizer = NovoGrad::new(NovoGradConfig::default());
        let memory_stats = optimizer.memory_usage();

        assert_eq!(memory_stats.momentum_elements, 0);
        assert_eq!(memory_stats.variance_elements, 0); // NovoGrad doesn't use per-param variance
        assert_eq!(memory_stats.num_parameters, 0);
    }

    #[test]
    fn test_stats_generation() {
        let optimizer = NovoGrad::new(NovoGradConfig::default());
        let stats = optimizer.get_stats();

        assert_eq!(stats.current_step, 0);
        assert_eq!(stats.total_parameters, 0);
        assert_eq!(stats.layer_count, 0);
        assert_eq!(stats.average_grad_norm, 0.0);
    }

    #[test]
    fn test_reset_functionality() {
        let mut optimizer = NovoGrad::new(NovoGradConfig::default());
        optimizer.current_step = 100;
        optimizer.layer_second_moments.insert("test".to_string(), 0.5);

        optimizer.reset();
        assert_eq!(optimizer.current_step, 0);
        assert!(optimizer.layer_second_moments.is_empty());
    }

    #[test]
    fn test_state_dict_operations() {
        let optimizer = NovoGrad::new(NovoGradConfig::default());
        let state_dict = optimizer.state_dict();
        assert!(state_dict.is_ok());

        let state = state_dict.unwrap();
        assert!(state.contains_key("step"));
    }

    #[test]
    fn test_config_serialization() {
        let config = NovoGradConfig::for_large_language_models();
        let serialized = serde_json::to_string(&config);
        assert!(serialized.is_ok());

        let deserialized: std::result::Result<NovoGradConfig, _> =
            serde_json::from_str(&serialized.unwrap());
        assert!(deserialized.is_ok());
        assert_eq!(deserialized.unwrap().beta2, 0.999);
    }
}