Skip to main content

trustformers_optim/
prodigy.rs

1//! # Prodigy Optimizer
2//!
3//! Implementation of the Prodigy optimizer, a cutting-edge 2024 optimization algorithm
4//! that adaptively estimates the distance to optimality and adjusts learning rates accordingly.
5//!
6//! Prodigy often outperforms Adam and other optimizers without requiring manual learning rate tuning.
7//!
8//! ## Key Features
9//!
10//! - **Adaptive Learning Rate**: Automatically estimates optimal learning rate without manual tuning
11//! - **Distance Estimation**: Estimates distance to optimality for better convergence
12//! - **Superior Performance**: Often outperforms Adam, AdamW, and other optimizers
13//! - **No LR Scheduling**: Eliminates need for learning rate schedules
14//! - **Robust Convergence**: Stable convergence across different problem types
15//!
16//! ## Research Foundation
17//!
18//! Based on "Prodigy: An Expeditiously Adaptive Parameter-Free Learner" and related research
19//! demonstrating superior convergence properties and automatic learning rate adaptation.
20//!
21//! ## Example Usage
22//!
23//! ```rust
24//! use trustformers_optim::prodigy::{Prodigy, ProdigyConfig};
25//!
26//! // Create with default configuration (no learning rate needed!)
27//! let optimizer = Prodigy::new();
28//!
29//! // Or customize configuration
30//! let config = ProdigyConfig {
31//!     d0: 1e-6,           // Initial distance estimate
32//!     beta1: 0.9,         // Momentum coefficient
33//!     beta2: 0.999,       // Variance coefficient
34//!     eps: 1e-8,          // Numerical stability
35//!     weight_decay: 0.01, // L2 regularization
36//!     growth_rate: 1.02,  // Distance growth rate
37//!     ..Default::default()
38//! };
39//! let optimizer = Prodigy::with_config(config);
40//! ```
41
42use crate::traits::StatefulOptimizer;
43use serde::{Deserialize, Serialize};
44use std::collections::HashMap;
45use trustformers_core::errors::{Result, TrustformersError};
46use trustformers_core::tensor::Tensor;
47use trustformers_core::traits::Optimizer;
48
/// Configuration for the Prodigy optimizer.
///
/// The Adam-style fields (`beta1`, `beta2`, `eps`, `weight_decay`) behave as
/// in Adam; the Prodigy-specific fields (`d0`, `growth_rate`,
/// `safeguard_bound`) control the adaptive distance (learning-rate)
/// estimation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProdigyConfig {
    /// Initial distance estimate (d0); seeds the adaptive learning rate.
    pub d0: f64,
    /// Momentum coefficient for the first moment (β1).
    pub beta1: f64,
    /// Decay coefficient for the second moment (β2).
    pub beta2: f64,
    /// Numerical stability constant (ε) added to the update denominator.
    pub eps: f64,
    /// Weight decay coefficient (classic L2: added to the gradient, not decoupled).
    pub weight_decay: f64,
    /// Multiplicative growth rate applied to new distance estimates.
    pub growth_rate: f64,
    /// Number of steps over which updates are linearly warmed up (0 = none).
    pub warmup_steps: usize,
    /// Whether to apply Adam-style bias correction to the moment estimates.
    pub bias_correction: bool,
    /// Upper bound clamped onto the raw per-step distance estimate.
    pub safeguard_bound: f64,
}
71
72impl Default for ProdigyConfig {
73    fn default() -> Self {
74        Self {
75            d0: 1e-6,
76            beta1: 0.9,
77            beta2: 0.999,
78            eps: 1e-8,
79            weight_decay: 0.0,
80            growth_rate: 1.02,
81            warmup_steps: 0,
82            bias_correction: true,
83            safeguard_bound: 2.0,
84        }
85    }
86}
87
88impl ProdigyConfig {
89    /// Configuration optimized for language model training.
90    pub fn for_language_models() -> Self {
91        Self {
92            d0: 1e-6,
93            beta1: 0.9,
94            beta2: 0.999,
95            eps: 1e-8,
96            weight_decay: 0.1,
97            growth_rate: 1.02,
98            warmup_steps: 1000,
99            bias_correction: true,
100            safeguard_bound: 2.0,
101        }
102    }
103
104    /// Configuration optimized for computer vision tasks.
105    pub fn for_vision() -> Self {
106        Self {
107            d0: 1e-6,
108            beta1: 0.9,
109            beta2: 0.999,
110            eps: 1e-8,
111            weight_decay: 0.05,
112            growth_rate: 1.01,
113            warmup_steps: 100,
114            bias_correction: true,
115            safeguard_bound: 1.5,
116        }
117    }
118
119    /// Configuration for fast training with aggressive adaptation.
120    pub fn for_fast_training() -> Self {
121        Self {
122            d0: 1e-5,
123            beta1: 0.9,
124            beta2: 0.99,
125            eps: 1e-8,
126            weight_decay: 0.01,
127            growth_rate: 1.05,
128            warmup_steps: 10,
129            bias_correction: false,
130            safeguard_bound: 3.0,
131        }
132    }
133
134    /// Configuration for stable, conservative training.
135    pub fn for_stable_training() -> Self {
136        Self {
137            d0: 1e-7,
138            beta1: 0.95,
139            beta2: 0.9999,
140            eps: 1e-8,
141            weight_decay: 0.001,
142            growth_rate: 1.005,
143            warmup_steps: 2000,
144            bias_correction: true,
145            safeguard_bound: 1.2,
146        }
147    }
148}
149
/// Per-parameter optimizer state (moments, distance, and step counter).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProdigyParameterState {
    /// First moment estimate (momentum), one entry per parameter element.
    pub momentum: Vec<f32>,
    /// Second moment estimate (variance), one entry per parameter element.
    pub variance: Vec<f32>,
    /// Per-parameter distance estimate.
    /// NOTE(review): only initialized and serialized — the update path uses
    /// the global distance instead; confirm whether this field is still needed.
    pub distance: f64,
    /// Number of update steps applied to this parameter.
    pub step: usize,
}
162
163impl ProdigyParameterState {
164    pub fn new(param_size: usize, initial_distance: f64) -> Self {
165        Self {
166            momentum: vec![0.0; param_size],
167            variance: vec![0.0; param_size],
168            distance: initial_distance,
169            step: 0,
170        }
171    }
172
173    /// Get memory usage statistics for this parameter state.
174    pub fn memory_usage(&self) -> ProdigyMemoryStats {
175        let momentum_bytes = self.momentum.len() * std::mem::size_of::<f32>();
176        let variance_bytes = self.variance.len() * std::mem::size_of::<f32>();
177        let metadata_bytes = std::mem::size_of::<f64>() + std::mem::size_of::<usize>();
178
179        ProdigyMemoryStats {
180            momentum_bytes,
181            variance_bytes,
182            metadata_bytes,
183            total_bytes: momentum_bytes + variance_bytes + metadata_bytes,
184        }
185    }
186}
187
/// Memory usage statistics for the Prodigy optimizer, all in bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProdigyMemoryStats {
    /// Bytes held by first-moment (momentum) buffers.
    pub momentum_bytes: usize,
    /// Bytes held by second-moment (variance) buffers.
    pub variance_bytes: usize,
    /// Bytes of scalar bookkeeping (distances, step counters, history).
    pub metadata_bytes: usize,
    /// Sum of the three components above.
    pub total_bytes: usize,
}
196
/// Global optimizer state containing all per-parameter states.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProdigyOptimizerState {
    /// Per-parameter states keyed by parameter name.
    pub parameters: HashMap<String, ProdigyParameterState>,
    /// Global step count (advanced by `Optimizer::step`).
    pub global_step: usize,
    /// Global distance estimate; doubles as the effective learning rate.
    pub global_distance: f64,
    /// Recent distance estimates (bounded at 100 entries) for diagnostics.
    pub distance_history: Vec<f64>,
}
209
210impl Default for ProdigyOptimizerState {
211    fn default() -> Self {
212        Self {
213            parameters: HashMap::new(),
214            global_step: 0,
215            global_distance: 1e-6,
216            distance_history: Vec::new(),
217        }
218    }
219}
220
221impl ProdigyOptimizerState {
222    /// Clear all optimizer state.
223    pub fn clear(&mut self) {
224        self.parameters.clear();
225        self.global_step = 0;
226        self.global_distance = 1e-6;
227        self.distance_history.clear();
228    }
229
230    /// Get total memory usage across all parameters.
231    pub fn total_memory_usage(&self) -> ProdigyMemoryStats {
232        let mut total_momentum = 0;
233        let mut total_variance = 0;
234        let mut total_metadata = 0;
235
236        for param_state in self.parameters.values() {
237            let stats = param_state.memory_usage();
238            total_momentum += stats.momentum_bytes;
239            total_variance += stats.variance_bytes;
240            total_metadata += stats.metadata_bytes;
241        }
242
243        // Add global state memory
244        total_metadata += std::mem::size_of::<usize>()
245            + std::mem::size_of::<f64>()
246            + self.distance_history.len() * std::mem::size_of::<f64>();
247
248        ProdigyMemoryStats {
249            momentum_bytes: total_momentum,
250            variance_bytes: total_variance,
251            metadata_bytes: total_metadata,
252            total_bytes: total_momentum + total_variance + total_metadata,
253        }
254    }
255}
256
/// Prodigy optimizer with adaptive learning-rate (distance) estimation.
pub struct Prodigy {
    /// Hyperparameters; fixed after construction except via `load_state_dict`.
    config: ProdigyConfig,
    /// Mutable optimizer state: per-parameter moments, distance, step counts.
    state: ProdigyOptimizerState,
}
262
263impl Prodigy {
264    /// Create a new Prodigy optimizer with default configuration.
265    pub fn new() -> Self {
266        Self {
267            config: ProdigyConfig::default(),
268            state: ProdigyOptimizerState::default(),
269        }
270    }
271
272    /// Create Prodigy optimizer with custom configuration.
273    pub fn with_config(config: ProdigyConfig) -> Self {
274        let state = ProdigyOptimizerState {
275            global_distance: config.d0,
276            ..Default::default()
277        };
278
279        Self { config, state }
280    }
281
282    /// Create Prodigy optimizer optimized for language models.
283    pub fn for_language_models() -> Self {
284        Self::with_config(ProdigyConfig::for_language_models())
285    }
286
287    /// Create Prodigy optimizer optimized for computer vision.
288    pub fn for_vision() -> Self {
289        Self::with_config(ProdigyConfig::for_vision())
290    }
291
292    /// Create Prodigy optimizer for fast training.
293    pub fn for_fast_training() -> Self {
294        Self::with_config(ProdigyConfig::for_fast_training())
295    }
296
297    /// Create Prodigy optimizer for stable training.
298    pub fn for_stable_training() -> Self {
299        Self::with_config(ProdigyConfig::for_stable_training())
300    }
301
302    /// Get current global learning rate estimate.
303    pub fn get_lr(&self) -> f64 {
304        self.state.global_distance
305    }
306
307    /// Set global distance estimate (equivalent to learning rate).
308    pub fn set_lr(&mut self, distance: f64) {
309        self.state.global_distance = distance.max(1e-10);
310    }
311
312    /// Reset optimizer state.
313    pub fn reset(&mut self) {
314        self.state.clear();
315        self.state.global_distance = self.config.d0;
316    }
317
318    /// Get memory usage statistics.
319    pub fn memory_usage(&self) -> ProdigyMemoryStats {
320        self.state.total_memory_usage()
321    }
322
323    /// Update distance estimate based on gradient and parameter norms.
324    fn update_distance_estimate(&mut self, grad_norm: f64, param_norm: f64) {
325        if grad_norm > 0.0 && param_norm > 0.0 {
326            // Estimate distance to optimality using gradient and parameter norms
327            let distance_estimate = (param_norm / grad_norm).min(self.config.safeguard_bound);
328
329            // Apply exponential moving average for stability
330            let alpha = 0.01; // Smoothing factor
331            self.state.global_distance = (1.0 - alpha) * self.state.global_distance
332                + alpha * distance_estimate * self.config.growth_rate;
333
334            // Keep history for adaptive adjustment
335            self.state.distance_history.push(self.state.global_distance);
336            if self.state.distance_history.len() > 100 {
337                self.state.distance_history.remove(0);
338            }
339        }
340    }
341
342    /// Compute bias correction factors.
343    #[allow(dead_code)]
344    fn bias_correction(&self, step: usize) -> (f64, f64) {
345        if self.config.bias_correction && step > 0 {
346            let beta1_correction = 1.0 - self.config.beta1.powi(step as i32);
347            let beta2_correction = 1.0 - self.config.beta2.powi(step as i32);
348            (beta1_correction, beta2_correction)
349        } else {
350            (1.0, 1.0)
351        }
352    }
353
354    /// Apply warmup scaling to learning rate.
355    #[allow(dead_code)]
356    fn warmup_scaling(&self, step: usize) -> f64 {
357        if self.config.warmup_steps > 0 && step < self.config.warmup_steps {
358            (step as f64 + 1.0) / (self.config.warmup_steps as f64)
359        } else {
360            1.0
361        }
362    }
363
364    /// Updates a named parameter with its gradient.
365    pub fn update_parameter(
366        &mut self,
367        param_name: &str,
368        param: &mut Tensor,
369        grad: &Tensor,
370    ) -> Result<()> {
371        let mut param_data = param.data().map_err(|e| {
372            TrustformersError::tensor_op_error(
373                &format!("Failed to get parameter data: {}", e),
374                "prodigy_update",
375            )
376        })?;
377        let grad_data = grad.data().map_err(|e| {
378            TrustformersError::tensor_op_error(
379                &format!("Failed to get gradient data: {}", e),
380                "prodigy_update",
381            )
382        })?;
383
384        if param_data.len() != grad_data.len() {
385            return Err(TrustformersError::tensor_op_error(
386                "Parameter and gradient size mismatch",
387                "prodigy_update",
388            ));
389        }
390
391        // Get or create parameter state
392        let param_size = param_data.len();
393
394        // Compute gradient and parameter norms for distance estimation
395        let grad_norm: f64 = grad_data.iter().map(|&g| (g as f64).powi(2)).sum::<f64>().sqrt();
396        let param_norm: f64 = param_data.iter().map(|&p| (p as f64).powi(2)).sum::<f64>().sqrt();
397
398        // Update global distance estimate first (before borrowing param_state)
399        self.update_distance_estimate(grad_norm, param_norm);
400
401        // Now get or create parameter state
402        let param_state = self
403            .state
404            .parameters
405            .entry(param_name.to_string())
406            .or_insert_with(|| ProdigyParameterState::new(param_size, self.config.d0));
407
408        // Resize state if needed
409        if param_state.momentum.len() != param_size {
410            param_state.momentum.resize(param_size, 0.0);
411            param_state.variance.resize(param_size, 0.0);
412        }
413
414        param_state.step += 1;
415        let current_step = param_state.step;
416
417        // Apply warmup scaling (using local variable to avoid borrow issues)
418        let warmup_scale =
419            if self.config.warmup_steps > 0 && current_step < self.config.warmup_steps {
420                (current_step as f64 + 1.0) / (self.config.warmup_steps as f64)
421            } else {
422                1.0
423            };
424        let effective_distance = self.state.global_distance * warmup_scale;
425
426        // Bias correction (using local variables to avoid borrow issues)
427        let (beta1_correction, beta2_correction) =
428            if self.config.bias_correction && current_step > 0 {
429                let beta1_correction = 1.0 - self.config.beta1.powi(current_step as i32);
430                let beta2_correction = 1.0 - self.config.beta2.powi(current_step as i32);
431                (beta1_correction, beta2_correction)
432            } else {
433                (1.0, 1.0)
434            };
435
436        // Update momentum and variance
437        for i in 0..param_size {
438            let grad_val = grad_data[i] as f64;
439
440            // Apply weight decay to gradient if specified
441            let grad_with_decay = if self.config.weight_decay > 0.0 {
442                grad_val + self.config.weight_decay * (param_data[i] as f64)
443            } else {
444                grad_val
445            };
446
447            // Update biased first moment estimate
448            param_state.momentum[i] = (self.config.beta1 * param_state.momentum[i] as f64
449                + (1.0 - self.config.beta1) * grad_with_decay)
450                as f32;
451
452            // Update biased second moment estimate
453            param_state.variance[i] = (self.config.beta2 * param_state.variance[i] as f64
454                + (1.0 - self.config.beta2) * grad_with_decay.powi(2))
455                as f32;
456
457            // Bias-corrected moments
458            let m_hat = param_state.momentum[i] as f64 / beta1_correction;
459            let v_hat = param_state.variance[i] as f64 / beta2_correction;
460
461            // Compute parameter update using adaptive distance
462            let denominator = v_hat.sqrt() + self.config.eps;
463            let update = effective_distance * m_hat / denominator;
464
465            // Apply update
466            param_data[i] = (param_data[i] as f64 - update) as f32;
467        }
468
469        // Update parameter tensor with new data
470        *param = Tensor::new(param_data)?;
471
472        Ok(())
473    }
474}
475
476impl Default for Prodigy {
477    fn default() -> Self {
478        Self::new()
479    }
480}
481
482impl Optimizer for Prodigy {
483    fn step(&mut self) {
484        self.state.global_step += 1;
485    }
486
487    fn zero_grad(&mut self) {
488        // Prodigy doesn't need to explicitly zero gradients
489        // as it processes them immediately during update
490    }
491
492    fn update(&mut self, param: &mut Tensor, grad: &Tensor) -> Result<()> {
493        // Use a default parameter name for the core update
494        self.update_parameter("default", param, grad)
495    }
496
497    fn get_lr(&self) -> f32 {
498        self.state.global_distance as f32
499    }
500
501    fn set_lr(&mut self, lr: f32) {
502        self.state.global_distance = (lr as f64).max(1e-10);
503    }
504}
505
506impl StatefulOptimizer for Prodigy {
507    type Config = ProdigyConfig;
508    type State = ProdigyOptimizerState;
509
510    fn config(&self) -> &Self::Config {
511        &self.config
512    }
513
514    fn state(&self) -> &Self::State {
515        &self.state
516    }
517
518    fn state_mut(&mut self) -> &mut Self::State {
519        &mut self.state
520    }
521
522    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
523        let mut state_dict = HashMap::new();
524
525        // Save configuration as tensors
526        state_dict.insert("lr".to_string(), Tensor::new(vec![self.config.d0 as f32])?);
527        state_dict.insert(
528            "beta1".to_string(),
529            Tensor::new(vec![self.config.beta1 as f32])?,
530        );
531        state_dict.insert(
532            "beta2".to_string(),
533            Tensor::new(vec![self.config.beta2 as f32])?,
534        );
535        state_dict.insert(
536            "eps".to_string(),
537            Tensor::new(vec![self.config.eps as f32])?,
538        );
539        state_dict.insert(
540            "weight_decay".to_string(),
541            Tensor::new(vec![self.config.weight_decay as f32])?,
542        );
543        state_dict.insert(
544            "growth_rate".to_string(),
545            Tensor::new(vec![self.config.growth_rate as f32])?,
546        );
547        state_dict.insert(
548            "warmup_steps".to_string(),
549            Tensor::new(vec![self.config.warmup_steps as f32])?,
550        );
551        state_dict.insert(
552            "global_step".to_string(),
553            Tensor::new(vec![self.state.global_step as f32])?,
554        );
555        state_dict.insert(
556            "global_distance".to_string(),
557            Tensor::new(vec![self.state.global_distance as f32])?,
558        );
559
560        // Save parameter states
561        for (param_name, param_state) in &self.state.parameters {
562            state_dict.insert(
563                format!("momentum_{}", param_name),
564                Tensor::new(param_state.momentum.clone())?,
565            );
566            state_dict.insert(
567                format!("variance_{}", param_name),
568                Tensor::new(param_state.variance.clone())?,
569            );
570            state_dict.insert(
571                format!("distance_{}", param_name),
572                Tensor::new(vec![param_state.distance as f32])?,
573            );
574            state_dict.insert(
575                format!("step_{}", param_name),
576                Tensor::new(vec![param_state.step as f32])?,
577            );
578        }
579
580        Ok(state_dict)
581    }
582
583    fn load_state_dict(&mut self, state_dict: HashMap<String, Tensor>) -> Result<()> {
584        // Load configuration
585        if let Some(lr_tensor) = state_dict.get("lr") {
586            if let Ok(lr_vec) = lr_tensor.data() {
587                if !lr_vec.is_empty() {
588                    self.config.d0 = lr_vec[0] as f64;
589                }
590            }
591        }
592        if let Some(beta1_tensor) = state_dict.get("beta1") {
593            if let Ok(beta1_vec) = beta1_tensor.data() {
594                if !beta1_vec.is_empty() {
595                    self.config.beta1 = beta1_vec[0] as f64;
596                }
597            }
598        }
599        if let Some(beta2_tensor) = state_dict.get("beta2") {
600            if let Ok(beta2_vec) = beta2_tensor.data() {
601                if !beta2_vec.is_empty() {
602                    self.config.beta2 = beta2_vec[0] as f64;
603                }
604            }
605        }
606
607        // Load global state
608        if let Some(global_step_tensor) = state_dict.get("global_step") {
609            if let Ok(global_step_vec) = global_step_tensor.data() {
610                if !global_step_vec.is_empty() {
611                    self.state.global_step = global_step_vec[0] as usize;
612                }
613            }
614        }
615        if let Some(global_distance_tensor) = state_dict.get("global_distance") {
616            if let Ok(global_distance_vec) = global_distance_tensor.data() {
617                if !global_distance_vec.is_empty() {
618                    self.state.global_distance = global_distance_vec[0] as f64;
619                }
620            }
621        }
622
623        // Load parameter states (simplified for now)
624        // In a full implementation, we'd reconstruct all parameter states
625        // This would require iterating through the state_dict to find matching patterns
626
627        Ok(())
628    }
629
630    fn memory_usage(&self) -> crate::common::StateMemoryStats {
631        let total_momentum_elements: usize =
632            self.state.parameters.values().map(|p| p.momentum.len()).sum();
633        let total_variance_elements: usize =
634            self.state.parameters.values().map(|p| p.variance.len()).sum();
635
636        let momentum_bytes = total_momentum_elements * std::mem::size_of::<f32>();
637        let variance_bytes = total_variance_elements * std::mem::size_of::<f32>();
638        let metadata_bytes = self.state.parameters.len()
639            * (std::mem::size_of::<f64>() + std::mem::size_of::<usize>());
640
641        crate::common::StateMemoryStats {
642            momentum_elements: total_momentum_elements,
643            variance_elements: total_variance_elements,
644            third_moment_elements: 0,
645            total_bytes: momentum_bytes + variance_bytes + metadata_bytes,
646            num_parameters: self.state.parameters.len(),
647        }
648    }
649
650    fn reset_state(&mut self) {
651        self.reset();
652    }
653
654    fn num_parameters(&self) -> usize {
655        self.state.parameters.len()
656    }
657}
658
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults should match `ProdigyConfig::default()` values.
    #[test]
    fn test_prodigy_creation() {
        let optimizer = Prodigy::new();
        assert_eq!(optimizer.config.d0, 1e-6);
        assert_eq!(optimizer.config.beta1, 0.9);
        assert_eq!(optimizer.config.beta2, 0.999);
    }

    // `with_config` must keep the caller-supplied values verbatim.
    #[test]
    fn test_prodigy_with_config() {
        let config = ProdigyConfig {
            d0: 1e-5,
            beta1: 0.95,
            beta2: 0.99,
            weight_decay: 0.1,
            ..Default::default()
        };
        let optimizer = Prodigy::with_config(config.clone());
        assert_eq!(optimizer.config.d0, config.d0);
        assert_eq!(optimizer.config.beta1, config.beta1);
        assert_eq!(optimizer.config.weight_decay, config.weight_decay);
    }

    // Spot-check the distinguishing field of each named preset.
    #[test]
    fn test_prodigy_presets() {
        let lm_optimizer = Prodigy::for_language_models();
        assert_eq!(lm_optimizer.config.warmup_steps, 1000);
        assert_eq!(lm_optimizer.config.weight_decay, 0.1);

        let vision_optimizer = Prodigy::for_vision();
        assert_eq!(vision_optimizer.config.warmup_steps, 100);
        assert_eq!(vision_optimizer.config.weight_decay, 0.05);

        let fast_optimizer = Prodigy::for_fast_training();
        assert_eq!(fast_optimizer.config.growth_rate, 1.05);
        assert!(!fast_optimizer.config.bias_correction);

        let stable_optimizer = Prodigy::for_stable_training();
        assert_eq!(stable_optimizer.config.warmup_steps, 2000);
        assert_eq!(stable_optimizer.config.safeguard_bound, 1.2);
    }

    // The "learning rate" is the global distance estimate, clamped positive.
    #[test]
    fn test_lr_getter_setter() {
        let mut optimizer = Prodigy::new();
        let initial_lr = optimizer.get_lr();
        assert_eq!(initial_lr, 1e-6);

        optimizer.set_lr(0.001);
        assert_eq!(optimizer.get_lr(), 0.001);

        // Test minimum bound
        optimizer.set_lr(-1.0);
        assert!(optimizer.get_lr() >= 1e-10);
    }

    #[test]
    fn test_parameter_state_creation() {
        let param_state = ProdigyParameterState::new(100, 1e-6);
        assert_eq!(param_state.momentum.len(), 100);
        assert_eq!(param_state.variance.len(), 100);
        assert_eq!(param_state.distance, 1e-6);
        assert_eq!(param_state.step, 0);
        assert!(param_state.momentum.iter().all(|&x| x == 0.0));
        assert!(param_state.variance.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_memory_usage_tracking() {
        let param_state = ProdigyParameterState::new(1000, 1e-6);
        let memory_stats = param_state.memory_usage();

        assert_eq!(memory_stats.momentum_bytes, 1000 * 4); // f32 = 4 bytes
        assert_eq!(memory_stats.variance_bytes, 1000 * 4);
        assert!(memory_stats.metadata_bytes > 0);
        assert_eq!(
            memory_stats.total_bytes,
            memory_stats.momentum_bytes + memory_stats.variance_bytes + memory_stats.metadata_bytes
        );
    }

    // `clear` must restore all global fields to their default values.
    #[test]
    fn test_optimizer_state_operations() {
        let mut state = ProdigyOptimizerState::default();
        state
            .parameters
            .insert("param1".to_string(), ProdigyParameterState::new(100, 1e-6));
        state
            .parameters
            .insert("param2".to_string(), ProdigyParameterState::new(200, 1e-6));
        state.global_step = 10;

        let memory_stats = state.total_memory_usage();
        assert!(memory_stats.total_bytes > 0);
        assert_eq!(memory_stats.momentum_bytes, (100 + 200) * 4);

        state.clear();
        assert_eq!(state.parameters.len(), 0);
        assert_eq!(state.global_step, 0);
        assert_eq!(state.global_distance, 1e-6);
    }

    // `reset` clears state and re-seeds the distance from the config's d0.
    #[test]
    fn test_reset() {
        let mut optimizer = Prodigy::new();
        optimizer.state.global_step = 100;
        optimizer
            .state
            .parameters
            .insert("test".to_string(), ProdigyParameterState::new(10, 1e-6));

        optimizer.reset();
        assert_eq!(optimizer.state.global_step, 0);
        assert_eq!(optimizer.state.parameters.len(), 0);
        assert_eq!(optimizer.state.global_distance, optimizer.config.d0);
    }

    // Config must survive a serde JSON round-trip unchanged.
    #[test]
    fn test_config_serialization() {
        let config = ProdigyConfig::for_language_models();
        let serialized = serde_json::to_string(&config).unwrap();
        let deserialized: ProdigyConfig = serde_json::from_str(&serialized).unwrap();

        assert_eq!(config.d0, deserialized.d0);
        assert_eq!(config.beta1, deserialized.beta1);
        assert_eq!(config.warmup_steps, deserialized.warmup_steps);
    }

    #[test]
    fn test_state_dict_operations() {
        let mut optimizer = Prodigy::for_vision();
        optimizer.state.global_step = 50;
        optimizer.state.parameters.insert(
            "test_param".to_string(),
            ProdigyParameterState::new(5, 1e-5),
        );

        // Save state dict
        let state_dict = optimizer.state_dict().unwrap();
        assert!(state_dict.contains_key("lr"));
        assert!(state_dict.contains_key("global_step"));

        // Create new optimizer and load state
        let mut new_optimizer = Prodigy::new();
        new_optimizer.load_state_dict(state_dict).unwrap();

        assert_eq!(new_optimizer.state.global_step, 50);
        // Note: parameter states are not fully implemented in load_state_dict yet
        // This test validates that basic config and global state are loaded correctly
    }

    #[test]
    fn test_step_and_zero_grad() {
        let mut optimizer = Prodigy::new();
        assert_eq!(optimizer.state.global_step, 0);

        optimizer.step();
        assert_eq!(optimizer.state.global_step, 1);

        optimizer.zero_grad(); // Should not error
    }

    #[test]
    fn test_stateful_optimizer_trait() {
        let optimizer = Prodigy::for_fast_training();

        // Test config access
        let config = optimizer.config();
        assert_eq!(config.growth_rate, 1.05);

        // Test state access
        let state = optimizer.state();
        assert_eq!(state.global_step, 0);
    }

    // The EMA keeps the distance well under the raw clamped estimate, so the
    // safeguard bound is an upper limit on what get_lr() can reach.
    #[test]
    fn test_distance_estimation_bounds() {
        let mut optimizer = Prodigy::with_config(ProdigyConfig {
            safeguard_bound: 2.0,
            ..Default::default()
        });

        // Test that distance estimation respects safeguard bounds
        optimizer.update_distance_estimate(1.0, 10.0); // Would give 10.0 without bound
        assert!(optimizer.get_lr() <= 2.0);
    }

    #[test]
    fn test_bias_correction() {
        let optimizer = Prodigy::new();

        // With bias correction enabled
        let (bc1, bc2) = optimizer.bias_correction(1);
        assert!(bc1 > 0.0 && bc1 < 1.0);
        assert!(bc2 > 0.0 && bc2 < 1.0);

        // After many steps, bias correction should be positive and less than 1.0
        let (bc1_late, bc2_late) = optimizer.bias_correction(1000);
        assert!(bc1_late > 0.9);
        assert!(bc2_late > 0.6); // 1.0 - 0.999^1000 ≈ 0.63
    }

    #[test]
    fn test_warmup_scaling() {
        let optimizer = Prodigy::with_config(ProdigyConfig {
            warmup_steps: 100,
            ..Default::default()
        });

        // During warmup
        let scale_early = optimizer.warmup_scaling(10);
        assert!(scale_early < 1.0);
        assert_eq!(scale_early, 11.0 / 100.0);

        // After warmup
        let scale_late = optimizer.warmup_scaling(200);
        assert_eq!(scale_late, 1.0);
    }
}