// tensorlogic_train/optimizers/prodigy.rs

//! Prodigy optimizer - Auto-tuning learning rate optimizer (2024)
//!
//! Reference: Mishchenko, K., & Defazio, A. (2024).
//! "Prodigy: An Expeditiously Adaptive Parameter-Free Learner"
//! <https://arxiv.org/abs/2306.06101>
//!
//! Key innovation: Prodigy automatically estimates the learning rate scale (D) without manual
//! tuning, using the distance from initialization to choose an appropriate step size.
//!
//! Benefits:
//! - No manual LR tuning required
//! - Works across different problem scales
//! - Adaptive to problem difficulty
//! - Combines benefits of Adam and D-Adaptation
//!
//! Usage:
//! ```rust
//! use tensorlogic_train::{ProdigyConfig, ProdigyOptimizer};
//!
//! // Create Prodigy optimizer with defaults
//! let config = ProdigyConfig::default()
//!     .with_d0(1e-6)      // Initial D estimate (small value)
//!     .with_d_coef(1.0)   // Coefficient for D adaptation
//!     .with_beta1(0.9)    // First moment decay
//!     .with_beta2(0.999); // Second moment decay
//!
//! let mut optimizer = ProdigyOptimizer::new(config);
//!
//! // Prodigy automatically adapts the learning rate!
//! // No need to manually tune LR or use schedulers
//! ```
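//!
//! A minimal end-to-end sketch of a training step (illustrative only; it assumes
//! the crate re-exports the `Optimizer` trait and that `scirs2_core` is visible
//! to doctests, which is why it is marked `ignore`):
//!
//! ```rust,ignore
//! use std::collections::HashMap;
//! use scirs2_core::ndarray::Array2;
//! use tensorlogic_train::{Optimizer, ProdigyConfig, ProdigyOptimizer};
//!
//! let mut optimizer = ProdigyOptimizer::new(ProdigyConfig::default());
//!
//! // Parameters and gradients are keyed by name, one `Array2<f64>` per tensor.
//! let mut params: HashMap<String, Array2<f64>> = HashMap::new();
//! params.insert("w".to_string(), Array2::from_elem((2, 2), 1.0));
//!
//! let mut grads: HashMap<String, Array2<f64>> = HashMap::new();
//! grads.insert("w".to_string(), Array2::from_elem((2, 2), 0.1));
//!
//! // Each call updates the parameters in place and adapts D internally.
//! optimizer.step(&mut params, &grads).unwrap();
//! ```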

use crate::error::TrainResult;
use crate::optimizer::{GradClipMode, Optimizer};
use scirs2_core::ndarray::Array2;
use scirs2_core::random::StdRng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Configuration for Prodigy optimizer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProdigyConfig {
    /// Initial D estimate (default: 1e-6)
    /// D represents the distance scale from initialization
    pub d0: f64,

    /// Coefficient for D adaptation (default: 1.0)
    /// Controls how aggressively D is adapted
    pub d_coef: f64,

    /// Learning rate (default: 1.0)
    /// Note: Prodigy is relatively insensitive to this value
    pub lr: f64,

    /// First moment decay rate (default: 0.9)
    pub beta1: f64,

    /// Second moment decay rate (default: 0.999)
    pub beta2: f64,

    /// Small constant for numerical stability (default: 1e-8)
    pub eps: f64,

    /// Weight decay coefficient (default: 0.0)
    /// Applied as decoupled weight decay (like AdamW)
    pub weight_decay: f64,

    /// Gradient clipping threshold (optional)
    pub grad_clip: Option<f64>,

    /// Gradient clipping mode
    pub grad_clip_mode: GradClipMode,

    /// Whether to use bias correction (default: true)
    pub bias_correction: bool,

    /// Maximum fractional growth of D per step (default: infinity, meaning no limit).
    /// For example, 0.1 caps each step's new D at 110% of the previous D.
    pub d_growth_rate: f64,
}

impl Default for ProdigyConfig {
    fn default() -> Self {
        Self {
            d0: 1e-6,
            d_coef: 1.0,
            lr: 1.0,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.0,
            grad_clip: None,
            grad_clip_mode: GradClipMode::Norm,
            bias_correction: true,
            d_growth_rate: f64::INFINITY,
        }
    }
}

impl ProdigyConfig {
    /// Create new Prodigy config with default values
    pub fn new() -> Self {
        Self::default()
    }

    /// Set initial D estimate
    pub fn with_d0(mut self, d0: f64) -> Self {
        self.d0 = d0;
        self
    }

    /// Set D coefficient
    pub fn with_d_coef(mut self, d_coef: f64) -> Self {
        self.d_coef = d_coef;
        self
    }

    /// Set learning rate
    pub fn with_lr(mut self, lr: f64) -> Self {
        self.lr = lr;
        self
    }

    /// Set beta1 (first moment decay)
    pub fn with_beta1(mut self, beta1: f64) -> Self {
        self.beta1 = beta1;
        self
    }

    /// Set beta2 (second moment decay)
    pub fn with_beta2(mut self, beta2: f64) -> Self {
        self.beta2 = beta2;
        self
    }

    /// Set epsilon for numerical stability
    pub fn with_eps(mut self, eps: f64) -> Self {
        self.eps = eps;
        self
    }

    /// Set weight decay
    pub fn with_weight_decay(mut self, weight_decay: f64) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Set gradient clipping
    pub fn with_grad_clip(mut self, grad_clip: f64) -> Self {
        self.grad_clip = Some(grad_clip);
        self
    }

    /// Set gradient clipping mode
    pub fn with_grad_clip_mode(mut self, mode: GradClipMode) -> Self {
        self.grad_clip_mode = mode;
        self
    }

    /// Enable or disable bias correction
    pub fn with_bias_correction(mut self, bias_correction: bool) -> Self {
        self.bias_correction = bias_correction;
        self
    }

    /// Set D growth rate limit
    pub fn with_d_growth_rate(mut self, rate: f64) -> Self {
        self.d_growth_rate = rate;
        self
    }
}

/// Prodigy optimizer
///
/// Auto-tuning adaptive learning rate optimizer that estimates the distance scale D
/// from initialization to automatically set appropriate step sizes.
///
/// Key features:
/// - No manual LR tuning needed
/// - Adapts to problem scale automatically
/// - Combines Adam-style updates with D-Adaptation
/// - Maintains first and second moment estimates
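///
/// The D estimate used by this implementation (a simplified variant of the
/// paper's estimator) is:
///
/// ```text
/// d = max(d0, d_coef * ||x - x_0|| / accumulated gradient norm)
/// ```
///
/// optionally capped so that D grows by at most a factor of `1 + d_growth_rate`
/// per step. The effective step size applied to parameters is `lr * d`.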
pub struct ProdigyOptimizer {
    config: ProdigyConfig,
    /// First moment estimates (momentum)
    first_moments: HashMap<String, Array2<f64>>,
    /// Second moment estimates (variance)
    second_moments: HashMap<String, Array2<f64>>,
    /// Initial parameters (for distance computation)
    initial_params: HashMap<String, Array2<f64>>,
    /// Current step count
    step: usize,
    /// Estimated distance scale D
    d: f64,
    /// Sum of gradient norms (for D estimation)
    sum_grad_norm: f64,
}

impl ProdigyOptimizer {
    /// Create new Prodigy optimizer with given config
    pub fn new(config: ProdigyConfig) -> Self {
        Self {
            config,
            first_moments: HashMap::new(),
            second_moments: HashMap::new(),
            initial_params: HashMap::new(),
            step: 0,
            d: 0.0, // Will be initialized to d0 on first step
            sum_grad_norm: 0.0,
        }
    }

    /// Get current D estimate
    pub fn get_d(&self) -> f64 {
        self.d
    }

    /// Get current step count
    pub fn get_step(&self) -> usize {
        self.step
    }

    /// Compute parameter distance from initialization
    fn compute_distance(&self, parameters: &HashMap<String, Array2<f64>>) -> f64 {
        let mut distance_sq = 0.0;

        for (name, param) in parameters {
            if let Some(init_param) = self.initial_params.get(name) {
                let diff = param - init_param;
                distance_sq += diff.mapv(|x| x * x).sum();
            }
        }

        distance_sq.sqrt()
    }

    /// Update D estimate based on gradients and parameters
    fn update_d(&mut self, parameters: &HashMap<String, Array2<f64>>, grad_norm: f64) {
        // Initialize D on first step
        if self.step == 1 {
            self.d = self.config.d0;
            return;
        }

        // Accumulate gradient norms
        self.sum_grad_norm += grad_norm;

        // Compute parameter distance from initialization
        let param_distance = self.compute_distance(parameters);

        // Estimate D based on ratio of distance to accumulated gradient norm
        if self.sum_grad_norm > 0.0 {
            let d_estimate = self.config.d_coef * param_distance / self.sum_grad_norm;

            // Apply growth rate limit if specified
            if self.config.d_growth_rate.is_finite() {
                let max_d = self.d * (1.0 + self.config.d_growth_rate);
                self.d = d_estimate.min(max_d).max(self.config.d0);
            } else {
                self.d = d_estimate.max(self.config.d0);
            }
        }
    }
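
    // Worked example for `update_d` (illustrative numbers only): with
    // `d_coef = 1.0`, a parameter distance ||x - x_0|| = 4e-6 and an
    // accumulated gradient norm of 0.4 give d_estimate = 1.0 * 4e-6 / 0.4
    // = 1e-5, so D grows from d0 = 1e-6 to 1e-5 (absent a growth cap).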

    /// Compute total gradient norm
    fn compute_gradient_norm(&self, gradients: &HashMap<String, Array2<f64>>) -> f64 {
        let mut norm_sq = 0.0;
        for grad in gradients.values() {
            norm_sq += grad.mapv(|x| x * x).sum();
        }
        norm_sq.sqrt()
    }

    /// Apply gradient clipping
    fn clip_gradients(
        &self,
        gradients: &mut HashMap<String, Array2<f64>>,
        _rng: Option<&mut StdRng>,
    ) -> TrainResult<()> {
        if let Some(max_val) = self.config.grad_clip {
            match self.config.grad_clip_mode {
                GradClipMode::Value => {
                    // Clip by value
                    for grad in gradients.values_mut() {
                        grad.mapv_inplace(|x| x.clamp(-max_val, max_val));
                    }
                }
                GradClipMode::Norm => {
                    // Clip by norm
                    let total_norm = self.compute_gradient_norm(gradients);
                    if total_norm > max_val {
                        let scale = max_val / (total_norm + self.config.eps);
                        for grad in gradients.values_mut() {
                            grad.mapv_inplace(|x| x * scale);
                        }
                    }
                }
            }
        }
        Ok(())
    }
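
    // Clipping example (illustrative numbers only): a 2x2 gradient filled with
    // 10.0 has L2 norm 20.0; with `grad_clip = 0.1` under `GradClipMode::Norm`,
    // every entry is scaled by 0.1 / (20.0 + eps), i.e. to roughly 0.05, so the
    // clipped gradient has norm close to 0.1.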
}

impl Optimizer for ProdigyOptimizer {
    fn zero_grad(&mut self) {
        // Prodigy doesn't maintain gradients separately, so this is a no-op
    }

    fn get_lr(&self) -> f64 {
        self.config.lr
    }

    fn set_lr(&mut self, lr: f64) {
        self.config.lr = lr;
    }

    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array2<f64>>,
        gradients: &HashMap<String, Array2<f64>>,
    ) -> TrainResult<()> {
        // Increment step counter
        self.step += 1;

        // Save initial parameters on first step
        if self.step == 1 {
            for (name, param) in parameters.iter() {
                self.initial_params.insert(name.clone(), param.clone());
            }
        }

        // Clone gradients if clipping is needed
        let gradients = if self.config.grad_clip.is_some() {
            let mut clipped = HashMap::new();
            for (name, grad) in gradients.iter() {
                clipped.insert(name.clone(), grad.clone());
            }
            self.clip_gradients(&mut clipped, None)?;
            clipped
        } else {
            gradients.clone()
        };

        // Compute gradient norm for D estimation
        let grad_norm = self.compute_gradient_norm(&gradients);

        // Update D estimate
        self.update_d(parameters, grad_norm);

        // Compute effective learning rate
        let effective_lr = self.config.lr * self.d;

        // Bias correction factors (if enabled)
        let bias_correction1 = if self.config.bias_correction {
            1.0 - self.config.beta1.powi(self.step as i32)
        } else {
            1.0
        };
        let bias_correction2 = if self.config.bias_correction {
            1.0 - self.config.beta2.powi(self.step as i32)
        } else {
            1.0
        };

        // Update parameters
        for (name, param) in parameters.iter_mut() {
            let grad = match gradients.get(name) {
                Some(g) => g,
                None => continue,
            };

            // Initialize moments if needed
            let m = self
                .first_moments
                .entry(name.clone())
                .or_insert_with(|| Array2::zeros(grad.raw_dim()));
            let v = self
                .second_moments
                .entry(name.clone())
                .or_insert_with(|| Array2::zeros(grad.raw_dim()));

            // Update biased first moment estimate: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
            *m = &*m * self.config.beta1 + grad * (1.0 - self.config.beta1);

            // Update biased second moment estimate: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
            let grad_sq = grad.mapv(|x| x * x);
            *v = &*v * self.config.beta2 + &grad_sq * (1.0 - self.config.beta2);

            // Compute bias-corrected moments
            let m_hat = m.mapv(|x| x / bias_correction1);
            let v_hat = v.mapv(|x| x / bias_correction2);

            // Compute update: delta = lr * D * m_hat / (sqrt(v_hat) + eps)
            let update = &m_hat / &v_hat.mapv(|x| x.sqrt() + self.config.eps);

            // Apply weight decay (decoupled, like AdamW)
            if self.config.weight_decay > 0.0 {
                param.mapv_inplace(|x| x * (1.0 - effective_lr * self.config.weight_decay));
            }

            // Update parameters: theta_t = theta_{t-1} - lr * D * update
            *param = &*param - &update * effective_lr;
        }

        Ok(())
    }

    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        let mut state = HashMap::new();

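        // Note: only scalar state (step, d, sum_grad_norm) and config values are
        // serialized; the per-parameter moment estimates and the stored initial
        // parameters are not included in this state dict.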
        // Save scalar values
        state.insert("step".to_string(), vec![self.step as f64]);
        state.insert("d".to_string(), vec![self.d]);
        state.insert("sum_grad_norm".to_string(), vec![self.sum_grad_norm]);

        // Save config values
        state.insert("d0".to_string(), vec![self.config.d0]);
        state.insert("d_coef".to_string(), vec![self.config.d_coef]);
        state.insert("lr".to_string(), vec![self.config.lr]);
        state.insert("beta1".to_string(), vec![self.config.beta1]);
        state.insert("beta2".to_string(), vec![self.config.beta2]);
        state.insert("eps".to_string(), vec![self.config.eps]);
        state.insert("weight_decay".to_string(), vec![self.config.weight_decay]);

        state
    }

    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
        // Load scalar values
        if let Some(v) = state.get("step") {
            if !v.is_empty() {
                self.step = v[0] as usize;
            }
        }
        if let Some(v) = state.get("d") {
            if !v.is_empty() {
                self.d = v[0];
            }
        }
        if let Some(v) = state.get("sum_grad_norm") {
            if !v.is_empty() {
                self.sum_grad_norm = v[0];
            }
        }

        // Load config values
        if let Some(v) = state.get("d0") {
            if !v.is_empty() {
                self.config.d0 = v[0];
            }
        }
        if let Some(v) = state.get("d_coef") {
            if !v.is_empty() {
                self.config.d_coef = v[0];
            }
        }
        if let Some(v) = state.get("lr") {
            if !v.is_empty() {
                self.config.lr = v[0];
            }
        }
        if let Some(v) = state.get("beta1") {
            if !v.is_empty() {
                self.config.beta1 = v[0];
            }
        }
        if let Some(v) = state.get("beta2") {
            if !v.is_empty() {
                self.config.beta2 = v[0];
            }
        }
        if let Some(v) = state.get("eps") {
            if !v.is_empty() {
                self.config.eps = v[0];
            }
        }
        if let Some(v) = state.get("weight_decay") {
            if !v.is_empty() {
                self.config.weight_decay = v[0];
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prodigy_config_default() {
        let config = ProdigyConfig::default();
        assert_eq!(config.d0, 1e-6);
        assert_eq!(config.d_coef, 1.0);
        assert_eq!(config.lr, 1.0);
        assert_eq!(config.beta1, 0.9);
        assert_eq!(config.beta2, 0.999);
        assert_eq!(config.eps, 1e-8);
        assert_eq!(config.weight_decay, 0.0);
    }

    #[test]
    fn test_prodigy_config_builder() {
        let config = ProdigyConfig::default()
            .with_d0(1e-5)
            .with_d_coef(2.0)
            .with_lr(0.5)
            .with_beta1(0.95)
            .with_beta2(0.9999)
            .with_eps(1e-7)
            .with_weight_decay(0.01)
            .with_grad_clip(1.0)
            .with_bias_correction(false)
            .with_d_growth_rate(0.1);

        assert_eq!(config.d0, 1e-5);
        assert_eq!(config.d_coef, 2.0);
        assert_eq!(config.lr, 0.5);
        assert_eq!(config.beta1, 0.95);
        assert_eq!(config.beta2, 0.9999);
        assert_eq!(config.eps, 1e-7);
        assert_eq!(config.weight_decay, 0.01);
        assert_eq!(config.grad_clip, Some(1.0));
        assert!(!config.bias_correction);
        assert_eq!(config.d_growth_rate, 0.1);
    }

    #[test]
    fn test_prodigy_initialization() {
        let config = ProdigyConfig::default();
        let optimizer = ProdigyOptimizer::new(config);

        assert_eq!(optimizer.get_step(), 0);
        assert_eq!(optimizer.get_d(), 0.0);
    }

    #[test]
    fn test_prodigy_first_step() {
        let config = ProdigyConfig::default().with_d0(1e-6);
        let mut optimizer = ProdigyOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array2::from_elem((2, 2), 1.0));

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), Array2::from_elem((2, 2), 0.1));

        optimizer.step(&mut params, &grads).unwrap();

        assert_eq!(optimizer.get_step(), 1);
        assert_eq!(optimizer.get_d(), 1e-6); // D initialized to d0
    }

    #[test]
    fn test_prodigy_d_adaptation() {
        let config = ProdigyConfig::default().with_d0(1e-6).with_d_coef(1.0);
        let mut optimizer = ProdigyOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array2::from_elem((2, 2), 1.0));

        // First step
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), Array2::from_elem((2, 2), 0.1));
        optimizer.step(&mut params, &grads).unwrap();

        let d_after_step1 = optimizer.get_d();
        assert_eq!(d_after_step1, 1e-6);

        // Second step - D should adapt
        optimizer.step(&mut params, &grads).unwrap();

        let d_after_step2 = optimizer.get_d();
        assert!(d_after_step2 >= 1e-6); // D should increase or stay same
    }

    #[test]
    fn test_prodigy_parameter_update() {
        let config = ProdigyConfig::default();
        let mut optimizer = ProdigyOptimizer::new(config);

        let mut params = HashMap::new();
        let initial_value = 1.0;
        params.insert("w".to_string(), Array2::from_elem((2, 2), initial_value));

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), Array2::from_elem((2, 2), 0.5));

        optimizer.step(&mut params, &grads).unwrap();

        // Parameters should be updated (decreased since gradient is positive)
        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < initial_value);
    }

    #[test]
    fn test_prodigy_weight_decay() {
        let config = ProdigyConfig::default().with_weight_decay(0.01);
        let mut optimizer = ProdigyOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array2::from_elem((2, 2), 1.0));

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), Array2::from_elem((2, 2), 0.0));

        // With weight decay and zero gradient, parameters should decay
        let initial_sum: f64 = params.get("w").unwrap().sum();
        optimizer.step(&mut params, &grads).unwrap();
        let final_sum: f64 = params.get("w").unwrap().sum();

        assert!(final_sum < initial_sum);
    }
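
    // Sanity-check sketch: with zero gradients and no weight decay, a step
    // should leave parameters untouched (contrast with the weight-decay test
    // above).
    #[test]
    fn test_prodigy_zero_gradient_no_decay() {
        let config = ProdigyConfig::default();
        let mut optimizer = ProdigyOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array2::from_elem((2, 2), 1.0));

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), Array2::from_elem((2, 2), 0.0));

        optimizer.step(&mut params, &grads).unwrap();

        // A zero gradient yields a zero Adam-style update, so the value is exact.
        let w = params.get("w").unwrap();
        assert_eq!(w[[0, 0]], 1.0);
    }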

    #[test]
    fn test_prodigy_gradient_clipping_by_norm() {
        let config = ProdigyConfig::default().with_grad_clip(0.1);
        let mut optimizer = ProdigyOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array2::from_elem((2, 2), 1.0));

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), Array2::from_elem((2, 2), 10.0)); // Large gradient

        // Should not panic, gradients should be clipped
        optimizer.step(&mut params, &grads).unwrap();

        // Parameters should still be updated, but with clipped gradients
        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < 1.0);
    }

    #[test]
    fn test_prodigy_state_dict() {
        let config = ProdigyConfig::default();
        let mut optimizer = ProdigyOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array2::from_elem((2, 2), 1.0));

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), Array2::from_elem((2, 2), 0.1));

        // Take a few steps
        for _ in 0..3 {
            optimizer.step(&mut params, &grads).unwrap();
        }

        let state = optimizer.state_dict();
        assert!(state.contains_key("step"));
        assert!(state.contains_key("d"));
        assert!(state.contains_key("sum_grad_norm"));
    }

    #[test]
    fn test_prodigy_load_state_dict() {
        let config = ProdigyConfig::default();
        let mut optimizer1 = ProdigyOptimizer::new(config.clone());

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array2::from_elem((2, 2), 1.0));

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), Array2::from_elem((2, 2), 0.1));

        // Take a few steps with optimizer1
        for _ in 0..3 {
            optimizer1.step(&mut params, &grads).unwrap();
        }

        let state = optimizer1.state_dict();

        // Create new optimizer and load state
        let mut optimizer2 = ProdigyOptimizer::new(config);
        optimizer2.load_state_dict(state);

        assert_eq!(optimizer1.get_step(), optimizer2.get_step());
        assert_eq!(optimizer1.get_d(), optimizer2.get_d());
    }

    #[test]
    fn test_prodigy_bias_correction() {
        let config_with_bc = ProdigyConfig::default().with_bias_correction(true);
        let config_without_bc = ProdigyConfig::default().with_bias_correction(false);

        let mut opt_with_bc = ProdigyOptimizer::new(config_with_bc);
        let mut opt_without_bc = ProdigyOptimizer::new(config_without_bc);

        let mut params1 = HashMap::new();
        params1.insert("w".to_string(), Array2::from_elem((2, 2), 1.0));

        let mut params2 = params1.clone();

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), Array2::from_elem((2, 2), 0.1));

        opt_with_bc.step(&mut params1, &grads).unwrap();
        opt_without_bc.step(&mut params2, &grads).unwrap();

        // Results should be different due to bias correction
        let w1 = params1.get("w").unwrap();
        let w2 = params2.get("w").unwrap();

        // They should be different (bias correction affects the updates)
        let diff = (w1[[0, 0]] - w2[[0, 0]]).abs();
        assert!(diff > 1e-10);
    }

    #[test]
    fn test_prodigy_d_growth_rate_limit() {
        let config = ProdigyConfig::default()
            .with_d0(1e-6)
            .with_d_growth_rate(0.1); // 10% max growth per step

        let mut optimizer = ProdigyOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array2::from_elem((2, 2), 1.0));

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), Array2::from_elem((2, 2), 1.0)); // Large gradient

        // First step
        optimizer.step(&mut params, &grads).unwrap();
        let d1 = optimizer.get_d();

        // Second step
        optimizer.step(&mut params, &grads).unwrap();
        let d2 = optimizer.get_d();

        // D should not grow more than 10% per step
        if d2 > d1 {
            let growth_ratio = d2 / d1;
            assert!(growth_ratio <= 1.11); // Allow small numerical error
        }
    }
}