
tensorlogic_train/optimizers/schedulefree.rs

//! Schedule-Free Optimizers - No Learning Rate Schedule Needed!
//!
//! Implementation of schedule-free learning from "The Road Less Scheduled" (Defazio et al., 2024).
//!
//! # Key Innovation
//!
//! Training deep networks traditionally requires a carefully tuned learning rate schedule.
//! Schedule-free optimizers eliminate this requirement by maintaining two parameter sequences:
//! - **x_t**: Training sequence (used for gradient computation)
//! - **y_t**: Evaluation sequence (used for inference)
//!
//! The evaluation sequence y_t is an interpolation between the current and past training
//! parameters, providing implicit scheduling without manual tuning.
//!
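//! A minimal numeric sketch of that interpolation (illustrative values only): with
//! γ = 0.9, x_t = 1.0, and y_{t-1} = 2.0,
//!
//! ```text
//! y_t = (1 - γ) * x_t + γ * y_{t-1}
//!     = 0.1 * 1.0 + 0.9 * 2.0
//!     = 1.9
//! ```
//!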
//! # Benefits
//!
//! 1. **No schedule tuning required** - Just set a constant learning rate
//! 2. **Better generalization** - Averaging provides implicit regularization
//! 3. **Strong convergence** - Matches or exceeds well-tuned schedules in the paper's experiments
//! 4. **Simpler hyperparameter search** - No schedule shape or decay horizon to tune
//!
//! # References
//!
//! - Defazio, A., Mishchenko, K., & Orabona, F. (2024).
//!   "The Road Less Scheduled". arXiv:2405.15682

use crate::optimizers::common::{GradClipMode, Optimizer};
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Array2, Zip};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Schedule-free AdamW optimizer.
///
/// Maintains both a training sequence (x) and an evaluation sequence (y).
/// During training, gradients are computed w.r.t. x_t.
/// During evaluation, use y_t for better generalization.
///
/// # Algorithm
///
/// ```text
/// # Training step (moments are bias-corrected to m_hat_t, v_hat_t before use):
/// m_t = β₁ * m_{t-1} + (1 - β₁) * g_t
/// v_t = β₂ * v_{t-1} + (1 - β₂) * g_t²
/// x_t = x_{t-1} - lr * (m_hat_t / (√v_hat_t + ε) + λ * x_{t-1})  # decoupled weight decay
///
/// # Evaluation sequence (exponential moving average):
/// y_t = (1 - γ) * x_t + γ * y_{t-1}
///
/// # At test time, use y_t instead of x_t
/// ```
///
/// # Example
///
/// ```no_run
/// use tensorlogic_train::{ScheduleFreeAdamW, ScheduleFreeConfig};
/// use scirs2_core::ndarray::Array2;
/// use std::collections::HashMap;
///
/// let config = ScheduleFreeConfig::default()
///     .with_lr(0.001)
///     .with_warmup_steps(1000);
///
/// let mut optimizer = ScheduleFreeAdamW::new(config);
///
/// // Training mode - use training parameters
/// optimizer.set_training_mode(true);
///
/// // ... compute gradients ...
///
/// // Evaluation mode - switch to averaged parameters
/// optimizer.set_training_mode(false);
/// let eval_params = optimizer.get_eval_parameters();
/// ```
#[derive(Debug, Clone)]
pub struct ScheduleFreeAdamW {
    /// Configuration
    config: ScheduleFreeConfig,
    /// Training parameters (x_t)
    train_params: HashMap<String, Array2<f64>>,
    /// Evaluation parameters (y_t) - exponential moving average
    eval_params: HashMap<String, Array2<f64>>,
    /// First moment estimates
    first_moments: HashMap<String, Array2<f64>>,
    /// Second moment estimates
    second_moments: HashMap<String, Array2<f64>>,
    /// Current step number
    step: usize,
    /// Training mode flag (true = use train_params, false = use eval_params)
    training_mode: bool,
}

/// Configuration for schedule-free optimizers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScheduleFreeConfig {
    /// Learning rate (constant, no schedule needed!)
    pub lr: f64,
    /// Beta1 for first moment (default: 0.9)
    pub beta1: f64,
    /// Beta2 for second moment (default: 0.999)
    pub beta2: f64,
    /// Weight decay coefficient (default: 0.01)
    pub weight_decay: f64,
    /// Epsilon for numerical stability (default: 1e-8)
    pub eps: f64,
    /// Averaging coefficient γ for the evaluation sequence (default: 0.95).
    /// Higher values = more smoothing, better generalization
    pub gamma: f64,
    /// Number of warmup steps (default: 0).
    /// During warmup, gamma increases linearly from 0 to the target value
    pub warmup_steps: usize,
    /// Gradient clipping threshold (None = no clipping)
    pub grad_clip: Option<f64>,
    /// Gradient clipping mode
    pub grad_clip_mode: GradClipMode,
}

impl Default for ScheduleFreeConfig {
    fn default() -> Self {
        Self {
            lr: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            weight_decay: 0.01,
            eps: 1e-8,
            gamma: 0.95,
            warmup_steps: 0,
            grad_clip: None,
            grad_clip_mode: GradClipMode::Norm,
        }
    }
}

impl ScheduleFreeConfig {
    /// Create a new configuration with a custom learning rate.
    pub fn new(lr: f64) -> Self {
        Self {
            lr,
            ..Default::default()
        }
    }

    /// Set learning rate (builder pattern).
    pub fn with_lr(mut self, lr: f64) -> Self {
        self.lr = lr;
        self
    }

    /// Set beta1 (builder pattern).
    pub fn with_beta1(mut self, beta1: f64) -> Self {
        self.beta1 = beta1;
        self
    }

    /// Set beta2 (builder pattern).
    pub fn with_beta2(mut self, beta2: f64) -> Self {
        self.beta2 = beta2;
        self
    }

    /// Set weight decay (builder pattern).
    pub fn with_weight_decay(mut self, weight_decay: f64) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Set gamma (averaging coefficient) (builder pattern).
    pub fn with_gamma(mut self, gamma: f64) -> Self {
        self.gamma = gamma;
        self
    }

    /// Set warmup steps (builder pattern).
    pub fn with_warmup_steps(mut self, steps: usize) -> Self {
        self.warmup_steps = steps;
        self
    }

    /// Set gradient clipping (builder pattern).
    pub fn with_grad_clip(mut self, threshold: f64, mode: GradClipMode) -> Self {
        self.grad_clip = Some(threshold);
        self.grad_clip_mode = mode;
        self
    }
}
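
// Illustrative builder usage (a sketch only; assumes `GradClipMode` is in scope as
// imported at the top of this file):
//
// let cfg = ScheduleFreeConfig::new(3e-4)
//     .with_weight_decay(0.05)
//     .with_warmup_steps(500)
//     .with_grad_clip(1.0, GradClipMode::Norm);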

impl ScheduleFreeAdamW {
    /// Create a new Schedule-Free AdamW optimizer.
    pub fn new(config: ScheduleFreeConfig) -> Self {
        Self {
            config,
            train_params: HashMap::new(),
            eval_params: HashMap::new(),
            first_moments: HashMap::new(),
            second_moments: HashMap::new(),
            step: 0,
            training_mode: true,
        }
    }

    /// Set training mode.
    ///
    /// When training_mode = true, use train_params (x_t) for gradients.
    /// When training_mode = false, use eval_params (y_t) for evaluation.
    pub fn set_training_mode(&mut self, training: bool) {
        self.training_mode = training;
    }

    /// Get current training mode.
    pub fn is_training(&self) -> bool {
        self.training_mode
    }

    /// Get evaluation parameters (for inference).
    pub fn get_eval_parameters(&self) -> &HashMap<String, Array2<f64>> {
        &self.eval_params
    }

    /// Get training parameters (for gradient computation).
    pub fn get_train_parameters(&self) -> &HashMap<String, Array2<f64>> {
        &self.train_params
    }

    /// Compute effective gamma with warmup.
    fn effective_gamma(&self) -> f64 {
        if self.config.warmup_steps == 0 {
            return self.config.gamma;
        }

        if self.step >= self.config.warmup_steps {
            self.config.gamma
        } else {
            // Linear warmup: gamma goes from 0 to target value
            self.config.gamma * (self.step as f64 / self.config.warmup_steps as f64)
        }
    }
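
    // Illustrative warmup schedule (assuming the default gamma = 0.95 and
    // warmup_steps = 100):
    //   step   0 -> effective gamma = 0.000
    //   step  50 -> effective gamma = 0.475
    //   step 100 -> effective gamma = 0.950 (constant from here on)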
}

impl Optimizer for ScheduleFreeAdamW {
    fn zero_grad(&mut self) {
        // This optimizer does not store gradients internally, so zero_grad is a no-op
    }

    fn get_lr(&self) -> f64 {
        self.config.lr
    }

    fn set_lr(&mut self, lr: f64) {
        self.config.lr = lr;
    }

    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array2<f64>>,
        gradients: &HashMap<String, Array2<f64>>,
    ) -> TrainResult<()> {
        if gradients.is_empty() {
            return Ok(());
        }

        self.step += 1;

        // Lazily initialize state on the first step: both sequences start from the
        // incoming parameters, and the moment estimates start at zero.
        if self.train_params.is_empty() {
            for (name, param) in parameters.iter() {
                self.train_params.insert(name.clone(), param.clone());
                self.eval_params.insert(name.clone(), param.clone());
                self.first_moments
                    .insert(name.clone(), Array::zeros(param.raw_dim()));
                self.second_moments
                    .insert(name.clone(), Array::zeros(param.raw_dim()));
            }
        }

        let gamma = self.effective_gamma();

        // Update each parameter
        for (name, grad) in gradients.iter() {
            let param = self.train_params.get_mut(name).ok_or_else(|| {
                TrainError::OptimizerError(format!("Parameter {} not found", name))
            })?;

            let m = self.first_moments.get_mut(name).ok_or_else(|| {
                TrainError::OptimizerError(format!("First moment {} not found", name))
            })?;

            let v = self.second_moments.get_mut(name).ok_or_else(|| {
                TrainError::OptimizerError(format!("Second moment {} not found", name))
            })?;

            // Apply gradient clipping if configured
            let grad_clipped = if let Some(threshold) = self.config.grad_clip {
                match self.config.grad_clip_mode {
                    GradClipMode::Value => grad.mapv(|g| g.clamp(-threshold, threshold)),
                    GradClipMode::Norm => {
                        let norm = grad.iter().map(|g| g * g).sum::<f64>().sqrt();
                        if norm > threshold {
                            grad.mapv(|g| g * threshold / norm)
                        } else {
                            grad.clone()
                        }
                    }
                }
            } else {
                grad.clone()
            };
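            // Note: norm-mode clipping above uses the L2 norm of this parameter's
            // gradient tensor alone (per-tensor clipping, not a global norm across
            // all parameters).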

            // Update biased first moment estimate: m_t = β₁ * m_{t-1} + (1 - β₁) * g_t
            Zip::from(&mut *m).and(&grad_clipped).for_each(|m_val, &g| {
                *m_val = self.config.beta1 * *m_val + (1.0 - self.config.beta1) * g;
            });

            // Update biased second moment estimate: v_t = β₂ * v_{t-1} + (1 - β₂) * g_t²
            Zip::from(&mut *v).and(&grad_clipped).for_each(|v_val, &g| {
                *v_val = self.config.beta2 * *v_val + (1.0 - self.config.beta2) * g * g;
            });

            // Bias correction
            let m_hat_coef = 1.0 / (1.0 - self.config.beta1.powi(self.step as i32));
            let v_hat_coef = 1.0 / (1.0 - self.config.beta2.powi(self.step as i32));

            // Update training parameters with AdamW update:
            // x_t = x_{t-1} - lr * (m_hat / (√v_hat + ε) + λ * x_{t-1})
            Zip::from(&mut *param)
                .and(&*m)
                .and(&*v)
                .for_each(|p, &m_val, &v_val| {
                    let m_hat = m_val * m_hat_coef;
                    let v_hat = v_val * v_hat_coef;

                    // AdamW-style update
                    let adam_update = m_hat / (v_hat.sqrt() + self.config.eps);
                    let weight_decay_update = self.config.weight_decay * *p;

                    *p -= self.config.lr * (adam_update + weight_decay_update);
                });
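            // The weight-decay term above is decoupled: it is applied directly to the
            // parameter (scaled by lr) rather than folded into the gradient, matching
            // AdamW rather than classic L2 regularization.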

            // Update evaluation parameters: y_t = (1 - γ) * x_t + γ * y_{t-1}
            let eval_param = self.eval_params.get_mut(name).ok_or_else(|| {
                TrainError::OptimizerError(format!("Eval parameter {} not found", name))
            })?;

            Zip::from(&mut *eval_param).and(&*param).for_each(|y, &x| {
                *y = (1.0 - gamma) * x + gamma * *y;
            });
        }

        // Copy the appropriate sequence back into the caller's parameters based on mode
        for (name, param) in parameters.iter_mut() {
            if self.training_mode {
                // Use training parameters
                if let Some(train_param) = self.train_params.get(name) {
                    param.assign(train_param);
                }
            } else {
                // Use evaluation parameters
                if let Some(eval_param) = self.eval_params.get(name) {
                    param.assign(eval_param);
                }
            }
        }

        Ok(())
    }

    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        let mut state = HashMap::new();

        // Save configuration
        state.insert("lr".to_string(), vec![self.config.lr]);
        state.insert("beta1".to_string(), vec![self.config.beta1]);
        state.insert("beta2".to_string(), vec![self.config.beta2]);
        state.insert("weight_decay".to_string(), vec![self.config.weight_decay]);
        state.insert("eps".to_string(), vec![self.config.eps]);
        state.insert("gamma".to_string(), vec![self.config.gamma]);
        state.insert(
            "warmup_steps".to_string(),
            vec![self.config.warmup_steps as f64],
        );
        state.insert("step".to_string(), vec![self.step as f64]);
        state.insert(
            "training_mode".to_string(),
            vec![if self.training_mode { 1.0 } else { 0.0 }],
        );

        // Save moments and parameters (flattened; array shapes are not stored)
        for (name, m) in &self.first_moments {
            state.insert(
                format!("first_moment_{}", name),
                m.iter().copied().collect(),
            );
        }

        for (name, v) in &self.second_moments {
            state.insert(
                format!("second_moment_{}", name),
                v.iter().copied().collect(),
            );
        }

        for (name, p) in &self.train_params {
            state.insert(format!("train_param_{}", name), p.iter().copied().collect());
        }

        for (name, p) in &self.eval_params {
            state.insert(format!("eval_param_{}", name), p.iter().copied().collect());
        }

        state
    }

    fn load_state_dict(&mut self, _state: HashMap<String, Vec<f64>>) {
        // Simplified: the flat Vec<f64> state does not carry array shapes, so this
        // implementation just resets to the initial state. A full implementation
        // would deserialize every entry back into the optimizer's fields.
        self.step = 0;
        self.first_moments.clear();
        self.second_moments.clear();
        self.train_params.clear();
        self.eval_params.clear();
    }
}
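
// Illustrative end-to-end usage (a sketch only; `compute_gradients` is a hypothetical
// helper standing in for whatever produces gradients in your training loop):
//
// let mut optimizer = ScheduleFreeAdamW::new(ScheduleFreeConfig::new(1e-3));
// let mut params: HashMap<String, Array2<f64>> = HashMap::new();
// params.insert("w".to_string(), Array2::zeros((2, 2)));
//
// for _step in 0..1000 {
//     optimizer.set_training_mode(true);
//     let grads = compute_gradients(&params); // hypothetical
//     optimizer.step(&mut params, &grads)?;
// }
//
// // Switch to the averaged (evaluation) sequence for inference.
// optimizer.set_training_mode(false);
// let eval_params = optimizer.get_eval_parameters();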

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_schedulefree_creation() {
        let config = ScheduleFreeConfig::default();
        let optimizer = ScheduleFreeAdamW::new(config);

        assert_eq!(optimizer.get_lr(), 0.001);
        assert!(optimizer.is_training());
    }

    #[test]
    fn test_schedulefree_config_builder() {
        let config = ScheduleFreeConfig::default()
            .with_lr(0.01)
            .with_beta1(0.85)
            .with_beta2(0.995)
            .with_gamma(0.98)
            .with_warmup_steps(1000);

        assert_eq!(config.lr, 0.01);
        assert_eq!(config.beta1, 0.85);
        assert_eq!(config.beta2, 0.995);
        assert_eq!(config.gamma, 0.98);
        assert_eq!(config.warmup_steps, 1000);
    }

    #[test]
    fn test_schedulefree_training_mode() {
        let config = ScheduleFreeConfig::default();
        let mut optimizer = ScheduleFreeAdamW::new(config);

        assert!(optimizer.is_training());

        optimizer.set_training_mode(false);
        assert!(!optimizer.is_training());

        optimizer.set_training_mode(true);
        assert!(optimizer.is_training());
    }

    #[test]
    fn test_schedulefree_step() {
        let config = ScheduleFreeConfig::default().with_lr(0.1);
        let mut optimizer = ScheduleFreeAdamW::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2], [0.3, 0.4]]);

        // Take a step
        optimizer.step(&mut params, &grads).unwrap();

        // Parameters should have been updated
        let updated_w = params.get("w").unwrap();
        assert_ne!(updated_w[[0, 0]], 1.0);

        // Should have created moments
        assert_eq!(optimizer.first_moments.len(), 1);
        assert_eq!(optimizer.second_moments.len(), 1);
    }

    #[test]
    fn test_schedulefree_eval_parameters() {
        let config = ScheduleFreeConfig::default().with_lr(0.1).with_gamma(0.5);
        let mut optimizer = ScheduleFreeAdamW::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2]]);

        // Take multiple steps
        for _ in 0..5 {
            optimizer.step(&mut params, &grads).unwrap();
        }

        // Eval parameters should be different from training parameters
        let train_params = optimizer.get_train_parameters();
        let eval_params = optimizer.get_eval_parameters();

        let train_w = train_params.get("w").unwrap();
        let eval_w = eval_params.get("w").unwrap();

        // They should be different due to averaging
        assert_ne!(train_w[[0, 0]], eval_w[[0, 0]]);
    }

    #[test]
    fn test_schedulefree_gamma_warmup() {
        let config = ScheduleFreeConfig::default().with_warmup_steps(100);
        let mut optimizer = ScheduleFreeAdamW::new(config);

        // At step 0, effective gamma should be 0
        assert_eq!(optimizer.effective_gamma(), 0.0);

        // Initialize and take steps
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1]]);

        for _ in 0..50 {
            optimizer.step(&mut params, &grads).unwrap();
        }

        // At step 50, effective gamma should be approximately halfway
        let gamma_50 = optimizer.effective_gamma();
        let expected_50 = 0.95 * (50.0 / 100.0);
        assert!(
            (gamma_50 - expected_50).abs() < 0.05,
            "gamma_50 = {}, expected ~{}",
            gamma_50,
            expected_50
        );

        for _ in 50..100 {
            optimizer.step(&mut params, &grads).unwrap();
        }

        // At step 100, effective gamma should be full value
        assert!((optimizer.effective_gamma() - 0.95).abs() < 1e-6);
    }

    #[test]
    fn test_schedulefree_gradient_clipping() {
        let config = ScheduleFreeConfig::default()
            .with_lr(0.1)
            .with_grad_clip(0.5, GradClipMode::Value);

        let mut optimizer = ScheduleFreeAdamW::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        // Large gradients that should be clipped
        grads.insert("w".to_string(), array![[10.0, -10.0]]);

        optimizer.step(&mut params, &grads).unwrap();

        // The clipped, Adam-normalized update should stay small (roughly lr in
        // magnitude), so the parameters move only slightly
        let updated_w = params.get("w").unwrap();
        assert!(updated_w[[0, 0]] > 0.5); // Not too much decrease
        assert!(updated_w[[0, 1]] < 2.5); // Not too much increase
    }

    #[test]
    fn test_schedulefree_weight_decay() {
        let config_no_decay = ScheduleFreeConfig::default()
            .with_lr(0.1)
            .with_weight_decay(0.0);

        let config_with_decay = ScheduleFreeConfig::default()
            .with_lr(0.1)
            .with_weight_decay(0.1);

        let mut opt_no_decay = ScheduleFreeAdamW::new(config_no_decay);
        let mut opt_with_decay = ScheduleFreeAdamW::new(config_with_decay);

        let mut params1 = HashMap::new();
        params1.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut params2 = params1.clone();

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1]]);

        opt_no_decay.step(&mut params1, &grads).unwrap();
        opt_with_decay.step(&mut params2, &grads).unwrap();

        // With weight decay, parameters should shrink more
        let w1 = params1.get("w").unwrap();
        let w2 = params2.get("w").unwrap();

        assert!(w2[[0, 0]] < w1[[0, 0]]);
        assert!(w2[[0, 1]] < w1[[0, 1]]);
    }
}