
tensorlogic_train/optimizers/adamp.rs

//! AdamP optimizer (Heo et al., ICLR 2021).
//!
//! AdamP combines Adam's adaptive learning rates with projection-based weight
//! decay, which tends to improve generalization.
//!
//! Reference: Heo et al., "AdamP: Slowing Down the Slowdown for Momentum Optimizers
//! on Scale-invariant Weights" (ICLR 2021).

use crate::optimizer::{GradClipMode, Optimizer, OptimizerConfig};
use crate::TrainResult;
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;

/// AdamP optimizer with projection-based weight decay.
///
/// AdamP addresses the issue that momentum-based updates grow the norm of
/// scale-invariant weights (e.g., weights followed by batch normalization or
/// layer normalization), which effectively slows down training. This
/// implementation uses projection to apply weight decay only in the direction
/// orthogonal to the gradient.
///
/// # Key Features
/// - Projection-based weight decay
/// - Maintains Adam's adaptive learning rates
/// - Better generalization on deep networks
/// - Particularly effective with normalization layers
///
/// # Example
/// ```
/// use tensorlogic_train::{AdamPOptimizer, OptimizerConfig};
///
/// let config = OptimizerConfig {
///     learning_rate: 0.001,
///     weight_decay: 0.01,  // weight decay strength
///     ..Default::default()
/// };
///
/// let optimizer = AdamPOptimizer::new(config);
/// ```
#[derive(Clone)]
pub struct AdamPOptimizer {
    /// Optimizer configuration.
    config: OptimizerConfig,
    /// First moment (mean) estimates.
    m: HashMap<String, Array<f64, Ix2>>,
    /// Second moment (variance) estimates.
    v: HashMap<String, Array<f64, Ix2>>,
    /// Time step for bias correction.
    t: usize,
    /// Nesterov momentum switch: values > 0.0 enable Nesterov-style momentum (default: 0.9).
    nesterov: f64,
    /// Cosine-similarity threshold for the projection (default: 0.1).
    delta: f64,
    /// Scaling applied to the weight-decay perturbation (default: 1.0).
    wd_ratio: f64,
}

impl AdamPOptimizer {
    /// Create a new AdamP optimizer with default parameters.
    pub fn new(config: OptimizerConfig) -> Self {
        Self {
            config,
            m: HashMap::new(),
            v: HashMap::new(),
            t: 0,
            nesterov: 0.9,
            delta: 0.1,
            wd_ratio: 1.0,
        }
    }

    /// Create AdamP with custom hyperparameters.
    ///
    /// # Arguments
    /// * `config` - Base optimizer configuration
    /// * `nesterov` - Enables Nesterov-style momentum when > 0.0 (default: 0.9)
    /// * `delta` - Cosine-similarity threshold for the projection (default: 0.1)
    /// * `wd_ratio` - Scaling applied to the weight-decay perturbation (default: 1.0)
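    ///
    /// # Example
    /// A brief sketch of constructing AdamP with explicit hyperparameters; the
    /// values shown match the defaults used by [`AdamPOptimizer::new`] and are
    /// illustrative only.
    /// ```
    /// use tensorlogic_train::{AdamPOptimizer, OptimizerConfig};
    ///
    /// let config = OptimizerConfig {
    ///     learning_rate: 0.001,
    ///     ..Default::default()
    /// };
    ///
    /// // nesterov = 0.9, delta = 0.1, wd_ratio = 1.0
    /// let optimizer = AdamPOptimizer::with_params(config, 0.9, 0.1, 1.0);
    /// ```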
    pub fn with_params(config: OptimizerConfig, nesterov: f64, delta: f64, wd_ratio: f64) -> Self {
        Self {
            config,
            m: HashMap::new(),
            v: HashMap::new(),
            t: 0,
            nesterov,
            delta,
            wd_ratio,
        }
    }

    /// Compute projection of weight decay.
    ///
    /// Projects weight decay into the space orthogonal to the gradient direction.
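    ///
    /// Concretely (a sketch of what the body below computes): for gradient `g`
    /// and perturbation `p`, when `|cos(g, p)| > delta` the orthogonal component
    /// `p_perp = p - (<g, p> / ||g||^2) * g` is rescaled to norm `||p||` and
    /// multiplied by `wd_ratio`; otherwise `wd_ratio * p` is returned. Near-zero
    /// `g` or `p` norms return `p` unchanged.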
    fn projection(
        &self,
        _param: &Array<f64, Ix2>,
        grad: &Array<f64, Ix2>,
        perturb: &Array<f64, Ix2>,
        delta: f64,
        wd_ratio: f64,
    ) -> Array<f64, Ix2> {
        // Compute gradient norm
        let grad_norm = grad.iter().map(|&x| x * x).sum::<f64>().sqrt();
        if grad_norm < 1e-12 {
            return perturb.clone();
        }

        // Compute perturbation norm
        let perturb_norm = perturb.iter().map(|&x| x * x).sum::<f64>().sqrt();
        if perturb_norm < 1e-12 {
            return perturb.clone();
        }

        // Compute cosine similarity
        let dot_product: f64 = grad.iter().zip(perturb.iter()).map(|(&g, &p)| g * p).sum();
        let cosine = dot_product / (grad_norm * perturb_norm + 1e-12);

        // If cosine similarity is high, project perturbation
        if cosine.abs() > delta {
            // Project perturbation orthogonal to gradient
            let scale = dot_product / (grad_norm * grad_norm + 1e-12);
            let projection = grad.mapv(|x| x * scale);
            let mut result = perturb - &projection;

            // Rescale to the original perturbation norm and apply wd_ratio
            let result_norm = result.iter().map(|&x| x * x).sum::<f64>().sqrt();
            if result_norm > 1e-12 {
                result = result.mapv(|x| x * perturb_norm / result_norm * wd_ratio);
            }

            result
        } else {
            // Use original perturbation
            perturb.mapv(|x| x * wd_ratio)
        }
    }
}

impl Optimizer for AdamPOptimizer {
    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()> {
        self.t += 1;

        let beta1 = self.config.beta1;
        let beta2 = self.config.beta2;
        let epsilon = self.config.epsilon;
        let lr = self.config.learning_rate;
        let weight_decay = self.config.weight_decay;

        for (name, param) in parameters.iter_mut() {
            let grad = gradients.get(name).ok_or_else(|| {
                crate::TrainError::OptimizerError(format!("No gradient for parameter {}", name))
            })?;

            // Apply gradient clipping if configured
            let grad = if let Some(clip_value) = self.config.grad_clip {
                let mut clipped = grad.clone();
                match self.config.grad_clip_mode {
                    GradClipMode::Value => {
                        clipped.mapv_inplace(|x| x.max(-clip_value).min(clip_value));
                    }
                    GradClipMode::Norm => {
                        let norm = grad.iter().map(|&x| x * x).sum::<f64>().sqrt();
                        if norm > clip_value {
                            let scale = clip_value / norm;
                            clipped.mapv_inplace(|x| x * scale);
                        }
                    }
                }
                clipped
            } else {
                grad.clone()
            };

            // Initialize moment estimates if needed
            let m = self
                .m
                .entry(name.clone())
                .or_insert_with(|| Array::zeros(param.raw_dim()));
            let v = self
                .v
                .entry(name.clone())
                .or_insert_with(|| Array::zeros(param.raw_dim()));

            // Update biased first moment estimate
            *m = m.mapv(|x| x * beta1) + grad.mapv(|x| x * (1.0 - beta1));

            // Update biased second moment estimate
            *v = v.mapv(|x| x * beta2) + grad.mapv(|x| x * x * (1.0 - beta2));

            // Compute bias-corrected estimates
            let m_hat = m.mapv(|x| x / (1.0 - beta1.powi(self.t as i32)));
            let v_hat = v.mapv(|x| x / (1.0 - beta2.powi(self.t as i32)));

            // Compute adaptive update
            let update = &m_hat / &v_hat.mapv(|x| x.sqrt() + epsilon);

            // Nesterov momentum
            let perturb = if self.nesterov > 0.0 {
                let nesterov_m = m.mapv(|x| x * beta1) + grad.mapv(|x| x * (1.0 - beta1));
                let nesterov_m_hat =
                    nesterov_m.mapv(|x| x / (1.0 - beta1.powi((self.t + 1) as i32)));
                &nesterov_m_hat / &v_hat.mapv(|x| x.sqrt() + epsilon)
            } else {
                update.clone()
            };

            // Apply projection-based weight decay
            if weight_decay > 0.0 {
                let wd_perturb = param.mapv(|x| -x * weight_decay);
                let projected_wd =
                    self.projection(param, &grad, &wd_perturb, self.delta, self.wd_ratio);

                // Update parameters with projected weight decay
                *param = param.clone() - perturb.mapv(|x| x * lr) + projected_wd;
            } else {
                // Update parameters without weight decay
                param.scaled_add(-lr, &perturb);
            }
        }

        Ok(())
    }

    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        let mut state = HashMap::new();

        // Save timestep
        state.insert("t".to_string(), vec![self.t as f64]);

        // Save moment estimates
        for (name, m_val) in &self.m {
            state.insert(format!("m_{}", name), m_val.iter().copied().collect());
        }
        for (name, v_val) in &self.v {
            state.insert(format!("v_{}", name), v_val.iter().copied().collect());
        }

        state
    }

    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
        // Load timestep
        if let Some(t_vec) = state.get("t") {
            self.t = t_vec[0] as usize;
        }

        // Load moment estimates
        for (key, value) in &state {
            if let Some(name) = key.strip_prefix("m_") {
                if let Some(m) = self.m.get(name) {
                    let shape = m.raw_dim();
                    if let Ok(array) = Array::from_shape_vec(shape, value.clone()) {
                        self.m.insert(name.to_string(), array);
                    }
                }
            } else if let Some(name) = key.strip_prefix("v_") {
                if let Some(v) = self.v.get(name) {
                    let shape = v.raw_dim();
                    if let Ok(array) = Array::from_shape_vec(shape, value.clone()) {
                        self.v.insert(name.to_string(), array);
                    }
                }
            }
        }
    }

    fn get_lr(&self) -> f64 {
        self.config.learning_rate
    }

    fn set_lr(&mut self, lr: f64) {
        self.config.learning_rate = lr;
    }

    fn zero_grad(&mut self) {
        // This optimizer does not store gradients, so there is nothing to zero.
        // Momentum state persists across steps; call `reset` to clear it explicitly.
    }
}

impl AdamPOptimizer {
    /// Reset optimizer state (clear momentum and timestep).
    pub fn reset(&mut self) {
        self.m.clear();
        self.v.clear();
        self.t = 0;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_adamp_basic() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            beta1: 0.9,
            beta2: 0.999,
            ..Default::default()
        };

        let mut optimizer = AdamPOptimizer::new(config);

        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), array![[0.1, 0.2], [0.3, 0.4]]);

        // First step
        optimizer.step(&mut parameters, &gradients).unwrap();

        // Parameters should have changed
        assert_ne!(parameters["w"][[0, 0]], 1.0);
        assert_ne!(parameters["w"][[1, 1]], 4.0);

        // Check that parameters decreased (gradient descent)
        assert!(parameters["w"][[0, 0]] < 1.0);
        assert!(parameters["w"][[1, 1]] < 4.0);
    }
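
    // Sketch of a test exercising the gradient-clipping path in `step`. It
    // assumes the `grad_clip` / `grad_clip_mode` fields read there can be set
    // directly on `OptimizerConfig`; the clip threshold of 1.0 is illustrative.
    #[test]
    fn test_adamp_grad_clip_norm() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            grad_clip: Some(1.0),
            grad_clip_mode: GradClipMode::Norm,
            ..Default::default()
        };

        let mut optimizer = AdamPOptimizer::new(config);

        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), array![[1.0, 2.0]]);

        // Deliberately large gradient so the norm clip takes effect.
        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), array![[100.0, 200.0]]);

        optimizer.step(&mut parameters, &gradients).unwrap();

        // The update should still be finite and move the parameters downhill.
        assert!(parameters["w"][[0, 0]].is_finite());
        assert!(parameters["w"][[0, 0]] < 1.0);
        assert!(parameters["w"][[0, 1]] < 2.0);
    }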

    #[test]
    fn test_adamp_with_weight_decay() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            weight_decay: 0.1,
            ..Default::default()
        };

        let mut optimizer = AdamPOptimizer::new(config);

        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), array![[0.1, 0.2], [0.3, 0.4]]);

        let initial_param = parameters["w"].clone();

        optimizer.step(&mut parameters, &gradients).unwrap();

        // The step exercises the projection-based weight decay path and should
        // still update the parameters
        assert_ne!(parameters["w"], initial_param);
    }

    #[test]
    fn test_adamp_state_dict() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };

        let mut optimizer = AdamPOptimizer::new(config);

        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), array![[0.1, 0.2]]);

        // Take a few steps
        for _ in 0..5 {
            optimizer.step(&mut parameters, &gradients).unwrap();
        }

        // Save state
        let state = optimizer.state_dict();
        assert!(state.contains_key("t"));
        assert!(state.contains_key("m_w"));
        assert!(state.contains_key("v_w"));

        // Create new optimizer and load state
        let mut new_optimizer = AdamPOptimizer::new(OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        });

        // Initialize with dummy step to create state
        new_optimizer.step(&mut parameters, &gradients).unwrap();

        // Load state
        new_optimizer.load_state_dict(state);

        assert_eq!(new_optimizer.t, 5);
    }

    #[test]
    fn test_adamp_convergence() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };

        let mut optimizer = AdamPOptimizer::new(config);

        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), array![[5.0, 5.0]]);

        // Minimize 0.05 * ||w||^2: the gradient is 0.1 * w, so each step moves w toward the origin
        for _ in 0..100 {
            let grad = parameters["w"].mapv(|x| x * 0.1); // Gradient proportional to distance
            let mut gradients = HashMap::new();
            gradients.insert("w".to_string(), grad);

            optimizer.step(&mut parameters, &gradients).unwrap();
        }

        // Should converge toward zero
        assert!(parameters["w"][[0, 0]].abs() < 1.0);
        assert!(parameters["w"][[0, 1]].abs() < 1.0);
    }

    #[test]
    fn test_adamp_projection() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            weight_decay: 0.1,
            ..Default::default()
        };

        let optimizer = AdamPOptimizer::with_params(config, 0.9, 0.1, 1.0);

        let param = array![[1.0, 2.0], [3.0, 4.0]];
        let grad = array![[0.1, 0.2], [0.3, 0.4]];
        let perturb = array![[-0.1, -0.2], [-0.3, -0.4]];

        let projected = optimizer.projection(&param, &grad, &perturb, 0.1, 1.0);

        // The projected perturbation keeps the original shape
        assert_eq!(projected.shape(), perturb.shape());
    }
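
    // Sketch of a test for the `reset` helper defined above: a step populates
    // the internal timestep and moment maps, and `reset` clears them again.
    #[test]
    fn test_adamp_reset() {
        let mut optimizer = AdamPOptimizer::new(OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        });

        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), array![[0.1, 0.2]]);

        optimizer.step(&mut parameters, &gradients).unwrap();
        assert_eq!(optimizer.t, 1);
        assert!(optimizer.m.contains_key("w"));
        assert!(optimizer.v.contains_key("w"));

        optimizer.reset();
        assert_eq!(optimizer.t, 0);
        assert!(optimizer.m.is_empty());
        assert!(optimizer.v.is_empty());
    }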

    #[test]
    fn test_adamp_nesterov() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };

        // With Nesterov
        let mut opt_nesterov = AdamPOptimizer::with_params(config.clone(), 0.9, 0.1, 1.0);

        // Without Nesterov
        let mut opt_standard = AdamPOptimizer::with_params(config, 0.0, 0.1, 1.0);

        let mut params1 = HashMap::new();
        params1.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut params2 = params1.clone();

        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), array![[0.1, 0.2]]);

        opt_nesterov.step(&mut params1, &gradients).unwrap();
        opt_standard.step(&mut params2, &gradients).unwrap();

        // On the very first step the Nesterov and standard updates coincide
        // after bias correction, so check that both variants take the same
        // finite descent step rather than asserting a difference.
        assert!(params1["w"][[0, 0]] < 1.0);
        assert!(params2["w"][[0, 0]] < 1.0);
        assert!((params1["w"][[0, 0]] - params2["w"][[0, 0]]).abs() < 1e-8);
    }
}