use crate::error::TrainResult;
use crate::optimizer::{GradClipMode, Optimizer};
use scirs2_core::ndarray::Array2;
use scirs2_core::random::StdRng;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

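/// Configuration for the Prodigy optimizer.
///
/// Prodigy (Mishchenko & Defazio) adapts the step size by estimating the distance
/// `d` from the initial parameters to the solution. This implementation layers a
/// simplified version of that estimate (distance travelled from the initial
/// parameters divided by accumulated gradient norms) on top of Adam-style moments.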
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProdigyConfig {
    /// Initial estimate of the distance to the solution, used as a floor for `d`.
    pub d0: f64,

    /// Coefficient applied to the distance estimate.
    pub d_coef: f64,

    /// Base learning-rate multiplier; the effective step size is `lr * d`.
    pub lr: f64,

    /// Exponential decay rate for the first-moment estimates.
    pub beta1: f64,

    /// Exponential decay rate for the second-moment estimates.
    pub beta2: f64,

    /// Small constant added to the denominator for numerical stability.
    pub eps: f64,

    /// Decoupled weight-decay coefficient.
    pub weight_decay: f64,

    /// Optional gradient-clipping threshold.
    pub grad_clip: Option<f64>,

    /// Clipping strategy used when `grad_clip` is set.
    pub grad_clip_mode: GradClipMode,

    /// Whether to apply Adam-style bias correction to the moment estimates.
    pub bias_correction: bool,

    /// Maximum relative growth of `d` per step; `f64::INFINITY` disables the limit.
    pub d_growth_rate: f64,
}

impl Default for ProdigyConfig {
    fn default() -> Self {
        Self {
            d0: 1e-6,
            d_coef: 1.0,
            lr: 1.0,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.0,
            grad_clip: None,
            grad_clip_mode: GradClipMode::Norm,
            bias_correction: true,
            d_growth_rate: f64::INFINITY,
        }
    }
}

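// Consuming-builder setters: each takes `self` by value and returns it, so options
// can be chained, e.g. `ProdigyConfig::new().with_lr(1.0).with_weight_decay(0.01)`.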
impl ProdigyConfig {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn with_d0(mut self, d0: f64) -> Self {
        self.d0 = d0;
        self
    }

    pub fn with_d_coef(mut self, d_coef: f64) -> Self {
        self.d_coef = d_coef;
        self
    }

    pub fn with_lr(mut self, lr: f64) -> Self {
        self.lr = lr;
        self
    }

    pub fn with_beta1(mut self, beta1: f64) -> Self {
        self.beta1 = beta1;
        self
    }

    pub fn with_beta2(mut self, beta2: f64) -> Self {
        self.beta2 = beta2;
        self
    }

    pub fn with_eps(mut self, eps: f64) -> Self {
        self.eps = eps;
        self
    }

    pub fn with_weight_decay(mut self, weight_decay: f64) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    pub fn with_grad_clip(mut self, grad_clip: f64) -> Self {
        self.grad_clip = Some(grad_clip);
        self
    }

    pub fn with_grad_clip_mode(mut self, mode: GradClipMode) -> Self {
        self.grad_clip_mode = mode;
        self
    }

    pub fn with_bias_correction(mut self, bias_correction: bool) -> Self {
        self.bias_correction = bias_correction;
        self
    }

    pub fn with_d_growth_rate(mut self, rate: f64) -> Self {
        self.d_growth_rate = rate;
        self
    }
}

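/// Prodigy optimizer with Adam-style moments and an adaptive distance estimate `d`.
///
/// The effective step size is `lr * d`, where `d` starts at `d0` and is re-estimated
/// every step from the distance travelled since the first step and the accumulated
/// gradient norms (see `update_d`).
///
/// Minimal usage sketch (parameter and gradient maps are keyed by the same names):
///
/// ```ignore
/// let mut optimizer = ProdigyOptimizer::new(ProdigyConfig::default());
/// optimizer.step(&mut parameters, &gradients)?;
/// ```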
pub struct ProdigyOptimizer {
    /// Optimizer configuration.
    config: ProdigyConfig,
    /// First-moment (mean) estimates per parameter tensor.
    first_moments: HashMap<String, Array2<f64>>,
    /// Second-moment (uncentered variance) estimates per parameter tensor.
    second_moments: HashMap<String, Array2<f64>>,
    /// Snapshot of the parameters taken on the first step, used for the distance estimate.
    initial_params: HashMap<String, Array2<f64>>,
    /// Number of optimization steps taken so far.
    step: usize,
    /// Current adaptive distance estimate `d`.
    d: f64,
    /// Running sum of gradient norms, the denominator of the `d` estimate.
    sum_grad_norm: f64,
}

impl ProdigyOptimizer {
    /// Create a new optimizer from the given configuration.
    pub fn new(config: ProdigyConfig) -> Self {
        Self {
            config,
            first_moments: HashMap::new(),
            second_moments: HashMap::new(),
            initial_params: HashMap::new(),
            step: 0,
            d: 0.0,
            sum_grad_norm: 0.0,
        }
    }

    /// Current distance estimate `d`.
    pub fn get_d(&self) -> f64 {
        self.d
    }

    /// Number of optimization steps taken.
    pub fn get_step(&self) -> usize {
        self.step
    }

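    /// Euclidean (L2) distance between the current parameters and the snapshot taken
    /// on the first optimization step; parameters without a snapshot are skipped.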
    fn compute_distance(&self, parameters: &HashMap<String, Array2<f64>>) -> f64 {
        let mut distance_sq = 0.0;

        for (name, param) in parameters {
            if let Some(init_param) = self.initial_params.get(name) {
                let diff = param - init_param;
                distance_sq += diff.mapv(|x| x * x).sum();
            }
        }

        distance_sq.sqrt()
    }

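    // Distance-estimate update used below (a simplified variant of the rule from the
    // Prodigy / D-Adaptation line of work; the published algorithms are more involved
    // and use d-weighted accumulators). Here the raw estimate is
    //
    //     d_hat = d_coef * ||x - x_0|| / (sum of gradient norms seen so far)
    //
    // and `d` becomes `d_hat`, floored at `d0` and, when a finite `d_growth_rate` is
    // configured, capped at `previous_d * (1.0 + d_growth_rate)`.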
    fn update_d(&mut self, parameters: &HashMap<String, Array2<f64>>, grad_norm: f64) {
        if self.step == 1 {
            self.d = self.config.d0;
            return;
        }

        self.sum_grad_norm += grad_norm;

        let param_distance = self.compute_distance(parameters);

        if self.sum_grad_norm > 0.0 {
            let d_estimate = self.config.d_coef * param_distance / self.sum_grad_norm;

            if self.config.d_growth_rate.is_finite() {
                let max_d = self.d * (1.0 + self.config.d_growth_rate);
                self.d = d_estimate.min(max_d).max(self.config.d0);
            } else {
                self.d = d_estimate.max(self.config.d0);
            }
        }
    }

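    /// Global L2 norm computed over all gradient tensors.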
    fn compute_gradient_norm(&self, gradients: &HashMap<String, Array2<f64>>) -> f64 {
        let mut norm_sq = 0.0;
        for grad in gradients.values() {
            norm_sq += grad.mapv(|x| x * x).sum();
        }
        norm_sq.sqrt()
    }

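    /// Clip gradients in place according to the configured mode: element-wise value
    /// clipping, or rescaling by the global norm. The RNG argument is currently unused.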
    fn clip_gradients(
        &self,
        gradients: &mut HashMap<String, Array2<f64>>,
        _rng: Option<&mut StdRng>,
    ) -> TrainResult<()> {
        if let Some(max_val) = self.config.grad_clip {
            match self.config.grad_clip_mode {
                GradClipMode::Value => {
                    for grad in gradients.values_mut() {
                        grad.mapv_inplace(|x| x.max(-max_val).min(max_val));
                    }
                }
                GradClipMode::Norm => {
                    let total_norm = self.compute_gradient_norm(gradients);
                    if total_norm > max_val {
                        let scale = max_val / (total_norm + self.config.eps);
                        for grad in gradients.values_mut() {
                            grad.mapv_inplace(|x| x * scale);
                        }
                    }
                }
            }
        }
        Ok(())
    }
}

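// Per-parameter update applied in `step` below: Adam-style moments with the learning
// rate scaled by the current distance estimate `d`, and decoupled (AdamW-style)
// weight decay:
//
//     m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
//     v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
//     x_t = x_{t-1} * (1 - lr * d * weight_decay)
//     x_t = x_t - lr * d * m_hat_t / (sqrt(v_hat_t) + eps)
//
// where m_hat and v_hat are the bias-corrected moments when `bias_correction` is
// enabled, and the raw moments otherwise.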
impl Optimizer for ProdigyOptimizer {
    fn zero_grad(&mut self) {
        // Gradients are passed explicitly to `step`, so there is nothing to reset here.
    }

    fn get_lr(&self) -> f64 {
        self.config.lr
    }

    fn set_lr(&mut self, lr: f64) {
        self.config.lr = lr;
    }

    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array2<f64>>,
        gradients: &HashMap<String, Array2<f64>>,
    ) -> TrainResult<()> {
        self.step += 1;

        // Snapshot the parameters on the first step; they anchor the distance estimate.
        if self.step == 1 {
            for (name, param) in parameters.iter() {
                self.initial_params.insert(name.clone(), param.clone());
            }
        }

        // Clip a copy of the gradients if clipping is configured.
        let gradients = if self.config.grad_clip.is_some() {
            let mut clipped = HashMap::new();
            for (name, grad) in gradients.iter() {
                clipped.insert(name.clone(), grad.clone());
            }
            self.clip_gradients(&mut clipped, None)?;
            clipped
        } else {
            gradients.clone()
        };

        let grad_norm = self.compute_gradient_norm(&gradients);

        // Refresh the distance estimate before computing the effective step size.
        self.update_d(parameters, grad_norm);

        let effective_lr = self.config.lr * self.d;

        let bias_correction1 = if self.config.bias_correction {
            1.0 - self.config.beta1.powi(self.step as i32)
        } else {
            1.0
        };
        let bias_correction2 = if self.config.bias_correction {
            1.0 - self.config.beta2.powi(self.step as i32)
        } else {
            1.0
        };

        for (name, param) in parameters.iter_mut() {
            let grad = match gradients.get(name) {
                Some(g) => g,
                None => continue,
            };

            let m = self
                .first_moments
                .entry(name.clone())
                .or_insert_with(|| Array2::zeros(grad.raw_dim()));
            let v = self
                .second_moments
                .entry(name.clone())
                .or_insert_with(|| Array2::zeros(grad.raw_dim()));

            // Exponential moving averages of the gradient and its square.
            *m = &*m * self.config.beta1 + grad * (1.0 - self.config.beta1);

            let grad_sq = grad.mapv(|x| x * x);
            *v = &*v * self.config.beta2 + &grad_sq * (1.0 - self.config.beta2);

            let m_hat = m.mapv(|x| x / bias_correction1);
            let v_hat = v.mapv(|x| x / bias_correction2);

            let update = &m_hat / &v_hat.mapv(|x| x.sqrt() + self.config.eps);

            // Decoupled (AdamW-style) weight decay.
            if self.config.weight_decay > 0.0 {
                param.mapv_inplace(|x| x * (1.0 - effective_lr * self.config.weight_decay));
            }

            *param = &*param - &update * effective_lr;
        }

        Ok(())
    }

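    // Note: only scalar state (step counter, `d`, gradient-norm sum, and the scalar
    // config fields) is serialized below; the per-parameter moment estimates and the
    // initial-parameter snapshot are not captured by `state_dict`/`load_state_dict`.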
    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        let mut state = HashMap::new();

        state.insert("step".to_string(), vec![self.step as f64]);
        state.insert("d".to_string(), vec![self.d]);
        state.insert("sum_grad_norm".to_string(), vec![self.sum_grad_norm]);

        state.insert("d0".to_string(), vec![self.config.d0]);
        state.insert("d_coef".to_string(), vec![self.config.d_coef]);
        state.insert("lr".to_string(), vec![self.config.lr]);
        state.insert("beta1".to_string(), vec![self.config.beta1]);
        state.insert("beta2".to_string(), vec![self.config.beta2]);
        state.insert("eps".to_string(), vec![self.config.eps]);
        state.insert("weight_decay".to_string(), vec![self.config.weight_decay]);

        state
    }

    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
        if let Some(v) = state.get("step") {
            if !v.is_empty() {
                self.step = v[0] as usize;
            }
        }
        if let Some(v) = state.get("d") {
            if !v.is_empty() {
                self.d = v[0];
            }
        }
        if let Some(v) = state.get("sum_grad_norm") {
            if !v.is_empty() {
                self.sum_grad_norm = v[0];
            }
        }

        if let Some(v) = state.get("d0") {
            if !v.is_empty() {
                self.config.d0 = v[0];
            }
        }
        if let Some(v) = state.get("d_coef") {
            if !v.is_empty() {
                self.config.d_coef = v[0];
            }
        }
        if let Some(v) = state.get("lr") {
            if !v.is_empty() {
                self.config.lr = v[0];
            }
        }
        if let Some(v) = state.get("beta1") {
            if !v.is_empty() {
                self.config.beta1 = v[0];
            }
        }
        if let Some(v) = state.get("beta2") {
            if !v.is_empty() {
                self.config.beta2 = v[0];
            }
        }
        if let Some(v) = state.get("eps") {
            if !v.is_empty() {
                self.config.eps = v[0];
            }
        }
        if let Some(v) = state.get("weight_decay") {
            if !v.is_empty() {
                self.config.weight_decay = v[0];
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prodigy_config_default() {
        let config = ProdigyConfig::default();
        assert_eq!(config.d0, 1e-6);
        assert_eq!(config.d_coef, 1.0);
        assert_eq!(config.lr, 1.0);
        assert_eq!(config.beta1, 0.9);
        assert_eq!(config.beta2, 0.999);
        assert_eq!(config.eps, 1e-8);
        assert_eq!(config.weight_decay, 0.0);
    }

    #[test]
    fn test_prodigy_config_builder() {
        let config = ProdigyConfig::default()
            .with_d0(1e-5)
            .with_d_coef(2.0)
            .with_lr(0.5)
            .with_beta1(0.95)
            .with_beta2(0.9999)
            .with_eps(1e-7)
            .with_weight_decay(0.01)
            .with_grad_clip(1.0)
            .with_bias_correction(false)
            .with_d_growth_rate(0.1);

        assert_eq!(config.d0, 1e-5);
        assert_eq!(config.d_coef, 2.0);
        assert_eq!(config.lr, 0.5);
        assert_eq!(config.beta1, 0.95);
        assert_eq!(config.beta2, 0.9999);
        assert_eq!(config.eps, 1e-7);
        assert_eq!(config.weight_decay, 0.01);
        assert_eq!(config.grad_clip, Some(1.0));
        assert!(!config.bias_correction);
        assert_eq!(config.d_growth_rate, 0.1);
    }

    #[test]
    fn test_prodigy_initialization() {
        let config = ProdigyConfig::default();
        let optimizer = ProdigyOptimizer::new(config);

        assert_eq!(optimizer.get_step(), 0);
        assert_eq!(optimizer.get_d(), 0.0);
    }

    #[test]
    fn test_prodigy_first_step() {
        let config = ProdigyConfig::default().with_d0(1e-6);
        let mut optimizer = ProdigyOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array2::from_elem((2, 2), 1.0));

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), Array2::from_elem((2, 2), 0.1));

        optimizer.step(&mut params, &grads).unwrap();

        assert_eq!(optimizer.get_step(), 1);
        assert_eq!(optimizer.get_d(), 1e-6);
    }

    #[test]
    fn test_prodigy_d_adaptation() {
        let config = ProdigyConfig::default().with_d0(1e-6).with_d_coef(1.0);
        let mut optimizer = ProdigyOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array2::from_elem((2, 2), 1.0));

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), Array2::from_elem((2, 2), 0.1));
        optimizer.step(&mut params, &grads).unwrap();

        let d_after_step1 = optimizer.get_d();
        assert_eq!(d_after_step1, 1e-6);

        optimizer.step(&mut params, &grads).unwrap();

        let d_after_step2 = optimizer.get_d();
        assert!(d_after_step2 >= 1e-6);
    }

    #[test]
    fn test_prodigy_parameter_update() {
        let config = ProdigyConfig::default();
        let mut optimizer = ProdigyOptimizer::new(config);

        let mut params = HashMap::new();
        let initial_value = 1.0;
        params.insert("w".to_string(), Array2::from_elem((2, 2), initial_value));

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), Array2::from_elem((2, 2), 0.5));

        optimizer.step(&mut params, &grads).unwrap();

        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < initial_value);
    }

    #[test]
    fn test_prodigy_weight_decay() {
        let config = ProdigyConfig::default().with_weight_decay(0.01);
        let mut optimizer = ProdigyOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array2::from_elem((2, 2), 1.0));

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), Array2::from_elem((2, 2), 0.0));

        let initial_sum: f64 = params.get("w").unwrap().sum();
        optimizer.step(&mut params, &grads).unwrap();
        let final_sum: f64 = params.get("w").unwrap().sum();

        assert!(final_sum < initial_sum);
    }

    #[test]
    fn test_prodigy_gradient_clipping_by_norm() {
        let config = ProdigyConfig::default().with_grad_clip(0.1);
        let mut optimizer = ProdigyOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array2::from_elem((2, 2), 1.0));

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), Array2::from_elem((2, 2), 10.0));

        optimizer.step(&mut params, &grads).unwrap();

        let w = params.get("w").unwrap();
        assert!(w[[0, 0]] < 1.0);
    }

    #[test]
    fn test_prodigy_state_dict() {
        let config = ProdigyConfig::default();
        let mut optimizer = ProdigyOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array2::from_elem((2, 2), 1.0));

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), Array2::from_elem((2, 2), 0.1));

        for _ in 0..3 {
            optimizer.step(&mut params, &grads).unwrap();
        }

        let state = optimizer.state_dict();
        assert!(state.contains_key("step"));
        assert!(state.contains_key("d"));
        assert!(state.contains_key("sum_grad_norm"));
    }

    #[test]
    fn test_prodigy_load_state_dict() {
        let config = ProdigyConfig::default();
        let mut optimizer1 = ProdigyOptimizer::new(config.clone());

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array2::from_elem((2, 2), 1.0));

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), Array2::from_elem((2, 2), 0.1));

        for _ in 0..3 {
            optimizer1.step(&mut params, &grads).unwrap();
        }

        let state = optimizer1.state_dict();

        let mut optimizer2 = ProdigyOptimizer::new(config);
        optimizer2.load_state_dict(state);

        assert_eq!(optimizer1.get_step(), optimizer2.get_step());
        assert_eq!(optimizer1.get_d(), optimizer2.get_d());
    }

    #[test]
    fn test_prodigy_bias_correction() {
        let config_with_bc = ProdigyConfig::default().with_bias_correction(true);
        let config_without_bc = ProdigyConfig::default().with_bias_correction(false);

        let mut opt_with_bc = ProdigyOptimizer::new(config_with_bc);
        let mut opt_without_bc = ProdigyOptimizer::new(config_without_bc);

        let mut params1 = HashMap::new();
        params1.insert("w".to_string(), Array2::from_elem((2, 2), 1.0));

        let mut params2 = params1.clone();

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), Array2::from_elem((2, 2), 0.1));

        opt_with_bc.step(&mut params1, &grads).unwrap();
        opt_without_bc.step(&mut params2, &grads).unwrap();

        let w1 = params1.get("w").unwrap();
        let w2 = params2.get("w").unwrap();

        let diff = (w1[[0, 0]] - w2[[0, 0]]).abs();
        assert!(diff > 1e-10);
    }

    #[test]
    fn test_prodigy_d_growth_rate_limit() {
        let config = ProdigyConfig::default()
            .with_d0(1e-6)
            .with_d_growth_rate(0.1);
        let mut optimizer = ProdigyOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), Array2::from_elem((2, 2), 1.0));

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), Array2::from_elem((2, 2), 1.0));

        optimizer.step(&mut params, &grads).unwrap();
        let d1 = optimizer.get_d();

        optimizer.step(&mut params, &grads).unwrap();
        let d2 = optimizer.get_d();

        if d2 > d1 {
            let growth_ratio = d2 / d1;
            assert!(growth_ratio <= 1.11);
        }
    }
}