// tensorlogic_train/optimizers/schedulefree.rs

//! Schedule-free AdamW optimizer.
//!
//! Maintains two copies of every parameter: the "train" weights updated by
//! AdamW, and the "eval" weights, an exponential moving average of the train
//! weights used for evaluation, so no separate learning-rate schedule is needed.

use crate::optimizers::common::{GradClipMode, Optimizer};
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Array2, Zip};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
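
/// Schedule-free AdamW optimizer.
///
/// `step` applies a standard AdamW update to the internal train parameters and
/// then folds the result into an exponential moving average (the eval
/// parameters) controlled by `gamma`. Whichever set is selected via
/// `set_training_mode` is written back into the caller's parameter map at the
/// end of each step.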
#[derive(Debug, Clone)]
pub struct ScheduleFreeAdamW {
    /// Optimizer hyperparameters.
    config: ScheduleFreeConfig,
    /// Parameters updated by the AdamW step (used while training).
    train_params: HashMap<String, Array2<f64>>,
    /// Exponential moving average of the train parameters (used for evaluation).
    eval_params: HashMap<String, Array2<f64>>,
    /// First moment (mean) estimates of the gradients.
    first_moments: HashMap<String, Array2<f64>>,
    /// Second moment (uncentered variance) estimates of the gradients.
    second_moments: HashMap<String, Array2<f64>>,
    /// Number of optimization steps taken so far.
    step: usize,
    /// Whether the caller's parameters are synced to the train or the eval set.
    training_mode: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScheduleFreeConfig {
    /// Learning rate.
    pub lr: f64,
    /// Exponential decay rate for the first moment estimates.
    pub beta1: f64,
    /// Exponential decay rate for the second moment estimates.
    pub beta2: f64,
    /// Decoupled weight decay coefficient.
    pub weight_decay: f64,
    /// Small constant added to the denominator for numerical stability.
    pub eps: f64,
    /// Interpolation weight for the eval-parameter moving average.
    pub gamma: f64,
    /// Number of steps over which `gamma` is ramped up linearly.
    pub warmup_steps: usize,
    /// Optional gradient clipping threshold.
    pub grad_clip: Option<f64>,
    /// Gradient clipping mode (by value or by global norm).
    pub grad_clip_mode: GradClipMode,
}

impl Default for ScheduleFreeConfig {
    fn default() -> Self {
        Self {
            lr: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            weight_decay: 0.01,
            eps: 1e-8,
            gamma: 0.95,
            warmup_steps: 0,
            grad_clip: None,
            grad_clip_mode: GradClipMode::Norm,
        }
    }
}

impl ScheduleFreeConfig {
    /// Create a new configuration with the given learning rate and defaults elsewhere.
    pub fn new(lr: f64) -> Self {
        Self {
            lr,
            ..Default::default()
        }
    }

    /// Set the learning rate.
    pub fn with_lr(mut self, lr: f64) -> Self {
        self.lr = lr;
        self
    }

    /// Set the first-moment decay rate.
    pub fn with_beta1(mut self, beta1: f64) -> Self {
        self.beta1 = beta1;
        self
    }

    /// Set the second-moment decay rate.
    pub fn with_beta2(mut self, beta2: f64) -> Self {
        self.beta2 = beta2;
        self
    }

    /// Set the decoupled weight decay coefficient.
    pub fn with_weight_decay(mut self, weight_decay: f64) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Set the eval-parameter averaging weight.
    pub fn with_gamma(mut self, gamma: f64) -> Self {
        self.gamma = gamma;
        self
    }

    /// Set the number of warmup steps for `gamma`.
    pub fn with_warmup_steps(mut self, steps: usize) -> Self {
        self.warmup_steps = steps;
        self
    }

    /// Enable gradient clipping with the given threshold and mode.
    pub fn with_grad_clip(mut self, threshold: f64, mode: GradClipMode) -> Self {
        self.grad_clip = Some(threshold);
        self.grad_clip_mode = mode;
        self
    }
}
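
// Typical construction (a sketch using only the builder methods defined above):
//
//     let config = ScheduleFreeConfig::new(3e-4)
//         .with_weight_decay(0.01)
//         .with_warmup_steps(500)
//         .with_grad_clip(1.0, GradClipMode::Norm);
//     let optimizer = ScheduleFreeAdamW::new(config);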

impl ScheduleFreeAdamW {
    /// Create a new optimizer with the given configuration.
    pub fn new(config: ScheduleFreeConfig) -> Self {
        Self {
            config,
            train_params: HashMap::new(),
            eval_params: HashMap::new(),
            first_moments: HashMap::new(),
            second_moments: HashMap::new(),
            step: 0,
            training_mode: true,
        }
    }

    /// Switch between training mode (sync the train parameters) and
    /// evaluation mode (sync the averaged eval parameters).
    pub fn set_training_mode(&mut self, training: bool) {
        self.training_mode = training;
    }

    /// Returns `true` if the optimizer is in training mode.
    pub fn is_training(&self) -> bool {
        self.training_mode
    }

    /// Averaged parameters intended for evaluation.
    pub fn get_eval_parameters(&self) -> &HashMap<String, Array2<f64>> {
        &self.eval_params
    }

    /// Raw parameters updated by the AdamW step.
    pub fn get_train_parameters(&self) -> &HashMap<String, Array2<f64>> {
        &self.train_params
    }

    /// Averaging weight for the current step: ramps linearly from 0 to
    /// `gamma` over `warmup_steps`, then stays at `gamma`.
    fn effective_gamma(&self) -> f64 {
        if self.config.warmup_steps == 0 {
            return self.config.gamma;
        }

        if self.step >= self.config.warmup_steps {
            self.config.gamma
        } else {
            self.config.gamma * (self.step as f64 / self.config.warmup_steps as f64)
        }
    }
}
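
// Update rule implemented by `step` below (m_hat, v_hat are the bias-corrected moments):
//
//     m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
//     v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
//     x_t = x_{t-1} - lr * (m_hat_t / (sqrt(v_hat_t) + eps) + weight_decay * x_{t-1})
//     y_t = (1 - gamma_t) * x_t + gamma_t * y_{t-1}
//
// where x are the train parameters, y the eval parameters, and gamma_t is
// `effective_gamma()`.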

impl Optimizer for ScheduleFreeAdamW {
    fn zero_grad(&mut self) {
        // No-op: gradients are supplied to `step` and are never stored here.
    }

    fn get_lr(&self) -> f64 {
        self.config.lr
    }

    fn set_lr(&mut self, lr: f64) {
        self.config.lr = lr;
    }

    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array2<f64>>,
        gradients: &HashMap<String, Array2<f64>>,
    ) -> TrainResult<()> {
        if gradients.is_empty() {
            return Ok(());
        }

        self.step += 1;

        // Lazily initialize optimizer state from the caller's parameters on the first step.
        if self.train_params.is_empty() {
            for (name, param) in parameters.iter() {
                self.train_params.insert(name.clone(), param.clone());
                self.eval_params.insert(name.clone(), param.clone());
                self.first_moments
                    .insert(name.clone(), Array::zeros(param.raw_dim()));
                self.second_moments
                    .insert(name.clone(), Array::zeros(param.raw_dim()));
            }
        }

        let gamma = self.effective_gamma();

        for (name, grad) in gradients.iter() {
            let param = self.train_params.get_mut(name).ok_or_else(|| {
                TrainError::OptimizerError(format!("Parameter {} not found", name))
            })?;

            let m = self.first_moments.get_mut(name).ok_or_else(|| {
                TrainError::OptimizerError(format!("First moment {} not found", name))
            })?;

            let v = self.second_moments.get_mut(name).ok_or_else(|| {
                TrainError::OptimizerError(format!("Second moment {} not found", name))
            })?;

            // Optionally clip the gradient by value or by global L2 norm.
            let grad_clipped = if let Some(threshold) = self.config.grad_clip {
                match self.config.grad_clip_mode {
                    GradClipMode::Value => grad.mapv(|g| g.clamp(-threshold, threshold)),
                    GradClipMode::Norm => {
                        let norm = grad.iter().map(|g| g * g).sum::<f64>().sqrt();
                        if norm > threshold {
                            grad.mapv(|g| g * threshold / norm)
                        } else {
                            grad.clone()
                        }
                    }
                }
            } else {
                grad.clone()
            };

            // Update the biased first and second moment estimates.
            Zip::from(&mut *m).and(&grad_clipped).for_each(|m_val, &g| {
                *m_val = self.config.beta1 * *m_val + (1.0 - self.config.beta1) * g;
            });

            Zip::from(&mut *v).and(&grad_clipped).for_each(|v_val, &g| {
                *v_val = self.config.beta2 * *v_val + (1.0 - self.config.beta2) * g * g;
            });

            // Bias-correction factors for the moment estimates.
            let m_hat_coef = 1.0 / (1.0 - self.config.beta1.powi(self.step as i32));
            let v_hat_coef = 1.0 / (1.0 - self.config.beta2.powi(self.step as i32));

            // AdamW update with decoupled weight decay on the train parameters.
            Zip::from(&mut *param)
                .and(&*m)
                .and(&*v)
                .for_each(|p, &m_val, &v_val| {
                    let m_hat = m_val * m_hat_coef;
                    let v_hat = v_val * v_hat_coef;

                    let adam_update = m_hat / (v_hat.sqrt() + self.config.eps);
                    let weight_decay_update = self.config.weight_decay * *p;

                    *p -= self.config.lr * (adam_update + weight_decay_update);
                });

            let eval_param = self.eval_params.get_mut(name).ok_or_else(|| {
                TrainError::OptimizerError(format!("Eval parameter {} not found", name))
            })?;

            // Fold the updated train parameters into the eval moving average.
            Zip::from(&mut *eval_param).and(&*param).for_each(|y, &x| {
                *y = (1.0 - gamma) * x + gamma * *y;
            });
        }

        // Write the active parameter set (train or eval) back to the caller.
        for (name, param) in parameters.iter_mut() {
            if self.training_mode {
                if let Some(train_param) = self.train_params.get(name) {
                    param.assign(train_param);
                }
            } else if let Some(eval_param) = self.eval_params.get(name) {
                param.assign(eval_param);
            }
        }

        Ok(())
    }

    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        let mut state = HashMap::new();

        // Scalar hyperparameters and bookkeeping.
        state.insert("lr".to_string(), vec![self.config.lr]);
        state.insert("beta1".to_string(), vec![self.config.beta1]);
        state.insert("beta2".to_string(), vec![self.config.beta2]);
        state.insert("weight_decay".to_string(), vec![self.config.weight_decay]);
        state.insert("eps".to_string(), vec![self.config.eps]);
        state.insert("gamma".to_string(), vec![self.config.gamma]);
        state.insert(
            "warmup_steps".to_string(),
            vec![self.config.warmup_steps as f64],
        );
        state.insert("step".to_string(), vec![self.step as f64]);
        state.insert(
            "training_mode".to_string(),
            vec![if self.training_mode { 1.0 } else { 0.0 }],
        );

        // Per-parameter tensors, flattened element-wise.
        for (name, m) in &self.first_moments {
            state.insert(
                format!("first_moment_{}", name),
                m.iter().copied().collect(),
            );
        }

        for (name, v) in &self.second_moments {
            state.insert(
                format!("second_moment_{}", name),
                v.iter().copied().collect(),
            );
        }

        for (name, p) in &self.train_params {
            state.insert(format!("train_param_{}", name), p.iter().copied().collect());
        }

        for (name, p) in &self.eval_params {
            state.insert(format!("eval_param_{}", name), p.iter().copied().collect());
        }

        state
    }

    fn load_state_dict(&mut self, _state: HashMap<String, Vec<f64>>) {
        // Array shapes are not recorded in the flat state dict, so the per-parameter
        // state is not reconstructed here; reset and let `step` re-initialize lazily.
        self.step = 0;
        self.first_moments.clear();
        self.second_moments.clear();
        self.train_params.clear();
        self.eval_params.clear();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_schedulefree_creation() {
        let config = ScheduleFreeConfig::default();
        let optimizer = ScheduleFreeAdamW::new(config);

        assert_eq!(optimizer.get_lr(), 0.001);
        assert!(optimizer.is_training());
    }

    #[test]
    fn test_schedulefree_config_builder() {
        let config = ScheduleFreeConfig::default()
            .with_lr(0.01)
            .with_beta1(0.85)
            .with_beta2(0.995)
            .with_gamma(0.98)
            .with_warmup_steps(1000);

        assert_eq!(config.lr, 0.01);
        assert_eq!(config.beta1, 0.85);
        assert_eq!(config.beta2, 0.995);
        assert_eq!(config.gamma, 0.98);
        assert_eq!(config.warmup_steps, 1000);
    }

    #[test]
    fn test_schedulefree_training_mode() {
        let config = ScheduleFreeConfig::default();
        let mut optimizer = ScheduleFreeAdamW::new(config);

        assert!(optimizer.is_training());

        optimizer.set_training_mode(false);
        assert!(!optimizer.is_training());

        optimizer.set_training_mode(true);
        assert!(optimizer.is_training());
    }

    #[test]
    fn test_schedulefree_step() {
        let config = ScheduleFreeConfig::default().with_lr(0.1);
        let mut optimizer = ScheduleFreeAdamW::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2], [0.3, 0.4]]);

        optimizer.step(&mut params, &grads).unwrap();

        let updated_w = params.get("w").unwrap();
        assert_ne!(updated_w[[0, 0]], 1.0);

        assert_eq!(optimizer.first_moments.len(), 1);
        assert_eq!(optimizer.second_moments.len(), 1);
    }

    #[test]
    fn test_schedulefree_eval_parameters() {
        let config = ScheduleFreeConfig::default().with_lr(0.1).with_gamma(0.5);
        let mut optimizer = ScheduleFreeAdamW::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2]]);

        for _ in 0..5 {
            optimizer.step(&mut params, &grads).unwrap();
        }

        let train_params = optimizer.get_train_parameters();
        let eval_params = optimizer.get_eval_parameters();

        let train_w = train_params.get("w").unwrap();
        let eval_w = eval_params.get("w").unwrap();

        // The averaged eval parameters should lag behind the train parameters.
        assert_ne!(train_w[[0, 0]], eval_w[[0, 0]]);
    }

    #[test]
    fn test_schedulefree_gamma_warmup() {
        let config = ScheduleFreeConfig::default().with_warmup_steps(100);
        let mut optimizer = ScheduleFreeAdamW::new(config);

        // Before any step, the warmup ramp keeps gamma at zero.
        assert_eq!(optimizer.effective_gamma(), 0.0);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1]]);

        for _ in 0..50 {
            optimizer.step(&mut params, &grads).unwrap();
        }

        // Halfway through warmup, gamma should be about half of its final value.
        let gamma_50 = optimizer.effective_gamma();
        let expected_50 = 0.95 * (50.0 / 100.0);
        assert!(
            (gamma_50 - expected_50).abs() < 0.05,
            "gamma_50 = {}, expected ~{}",
            gamma_50,
            expected_50
        );

        for _ in 50..100 {
            optimizer.step(&mut params, &grads).unwrap();
        }

        // After warmup completes, gamma reaches its configured value.
        assert!((optimizer.effective_gamma() - 0.95).abs() < 1e-6);
    }

    #[test]
    fn test_schedulefree_gradient_clipping() {
        let config = ScheduleFreeConfig::default()
            .with_lr(0.1)
            .with_grad_clip(0.5, GradClipMode::Value);

        let mut optimizer = ScheduleFreeAdamW::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[10.0, -10.0]]);

        optimizer.step(&mut params, &grads).unwrap();

        // With the gradients clipped to [-0.5, 0.5], a single step cannot move
        // the weights very far.
        let updated_w = params.get("w").unwrap();
        assert!(updated_w[[0, 0]] > 0.5);
        assert!(updated_w[[0, 1]] < 2.5);
    }

    #[test]
    fn test_schedulefree_weight_decay() {
        let config_no_decay = ScheduleFreeConfig::default()
            .with_lr(0.1)
            .with_weight_decay(0.0);

        let config_with_decay = ScheduleFreeConfig::default()
            .with_lr(0.1)
            .with_weight_decay(0.1);

        let mut opt_no_decay = ScheduleFreeAdamW::new(config_no_decay);
        let mut opt_with_decay = ScheduleFreeAdamW::new(config_with_decay);

        let mut params1 = HashMap::new();
        params1.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut params2 = params1.clone();

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1]]);

        opt_no_decay.step(&mut params1, &grads).unwrap();
        opt_with_decay.step(&mut params2, &grads).unwrap();

        // Decoupled weight decay shrinks the weights relative to the run without decay.
        let w1 = params1.get("w").unwrap();
        let w2 = params2.get("w").unwrap();

        assert!(w2[[0, 0]] < w1[[0, 0]]);
        assert!(w2[[0, 1]] < w1[[0, 1]]);
    }
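
    // A minimal sketch of the train/eval workflow: after switching to evaluation
    // mode, `step` writes the averaged eval weights back into the caller's
    // parameter map instead of the raw train weights. Uses only the APIs defined
    // above; the test name and values are illustrative.
    #[test]
    fn test_schedulefree_eval_mode_sync_sketch() {
        let config = ScheduleFreeConfig::default().with_lr(0.1).with_gamma(0.5);
        let mut optimizer = ScheduleFreeAdamW::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2]]);

        // In training mode the caller's parameters track the train weights exactly.
        for _ in 0..3 {
            optimizer.step(&mut params, &grads).unwrap();
        }
        let train_w = optimizer.get_train_parameters().get("w").unwrap().clone();
        assert_eq!(params.get("w").unwrap()[[0, 0]], train_w[[0, 0]]);

        // In evaluation mode the next step syncs the averaged eval weights instead.
        optimizer.set_training_mode(false);
        optimizer.step(&mut params, &grads).unwrap();
        let eval_w = optimizer.get_eval_parameters().get("w").unwrap();
        assert_eq!(params.get("w").unwrap()[[0, 0]], eval_w[[0, 0]]);
    }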
}