tensorlogic_train/optimizers/
adamp.rs1use crate::optimizer::{GradClipMode, Optimizer, OptimizerConfig};
10use crate::TrainResult;
11use scirs2_core::ndarray::{Array, Ix2};
12use std::collections::HashMap;
13
14#[derive(Clone)]
40pub struct AdamPOptimizer {
41 config: OptimizerConfig,
43 m: HashMap<String, Array<f64, Ix2>>,
45 v: HashMap<String, Array<f64, Ix2>>,
47 t: usize,
49 nesterov: f64,
51 delta: f64,
53 wd_ratio: f64,
55}
56
57impl AdamPOptimizer {
58 pub fn new(config: OptimizerConfig) -> Self {
60 Self {
61 config,
62 m: HashMap::new(),
63 v: HashMap::new(),
64 t: 0,
65 nesterov: 0.9,
66 delta: 0.1,
67 wd_ratio: 1.0,
68 }
69 }
70
71 pub fn with_params(config: OptimizerConfig, nesterov: f64, delta: f64, wd_ratio: f64) -> Self {
79 Self {
80 config,
81 m: HashMap::new(),
82 v: HashMap::new(),
83 t: 0,
84 nesterov,
85 delta,
86 wd_ratio,
87 }
88 }
89
90 fn projection(
94 &self,
95 _param: &Array<f64, Ix2>,
96 grad: &Array<f64, Ix2>,
97 perturb: &Array<f64, Ix2>,
98 delta: f64,
99 wd_ratio: f64,
100 ) -> Array<f64, Ix2> {
101 let grad_norm = grad.iter().map(|&x| x * x).sum::<f64>().sqrt();
103 if grad_norm < 1e-12 {
104 return perturb.clone();
105 }
106
107 let perturb_norm = perturb.iter().map(|&x| x * x).sum::<f64>().sqrt();
109 if perturb_norm < 1e-12 {
110 return perturb.clone();
111 }
112
113 let dot_product: f64 = grad.iter().zip(perturb.iter()).map(|(&g, &p)| g * p).sum();
115 let cosine = dot_product / (grad_norm * perturb_norm + 1e-12);
116
117 if cosine.abs() > delta {
119 let scale = dot_product / (grad_norm * grad_norm + 1e-12);
121 let projection = grad.mapv(|x| x * scale);
122 let mut result = perturb - &projection;
123
124 let result_norm = result.iter().map(|&x| x * x).sum::<f64>().sqrt();
126 if result_norm > 1e-12 {
127 result = result.mapv(|x| x * perturb_norm / result_norm * wd_ratio);
128 }
129
130 result
131 } else {
132 perturb.mapv(|x| x * wd_ratio)
134 }
135 }
136}
137
138impl Optimizer for AdamPOptimizer {
139 fn step(
140 &mut self,
141 parameters: &mut HashMap<String, Array<f64, Ix2>>,
142 gradients: &HashMap<String, Array<f64, Ix2>>,
143 ) -> TrainResult<()> {
144 self.t += 1;
145
146 let beta1 = self.config.beta1;
147 let beta2 = self.config.beta2;
148 let epsilon = self.config.epsilon;
149 let lr = self.config.learning_rate;
150 let weight_decay = self.config.weight_decay;
151
152 for (name, param) in parameters.iter_mut() {
153 let grad = gradients.get(name).ok_or_else(|| {
154 crate::TrainError::OptimizerError(format!("No gradient for parameter {}", name))
155 })?;
156
157 let grad = if let Some(clip_value) = self.config.grad_clip {
159 let mut clipped = grad.clone();
160 match self.config.grad_clip_mode {
161 GradClipMode::Value => {
162 clipped.mapv_inplace(|x| x.max(-clip_value).min(clip_value));
163 }
164 GradClipMode::Norm => {
165 let norm = grad.iter().map(|&x| x * x).sum::<f64>().sqrt();
166 if norm > clip_value {
167 let scale = clip_value / norm;
168 clipped.mapv_inplace(|x| x * scale);
169 }
170 }
171 }
172 clipped
173 } else {
174 grad.clone()
175 };
176
177 let m = self
179 .m
180 .entry(name.clone())
181 .or_insert_with(|| Array::zeros(param.raw_dim()));
182 let v = self
183 .v
184 .entry(name.clone())
185 .or_insert_with(|| Array::zeros(param.raw_dim()));
186
187 *m = m.mapv(|x| x * beta1) + grad.mapv(|x| x * (1.0 - beta1));
189
190 *v = v.mapv(|x| x * beta2) + grad.mapv(|x| x * x * (1.0 - beta2));
192
193 let m_hat = m.mapv(|x| x / (1.0 - beta1.powi(self.t as i32)));
195 let v_hat = v.mapv(|x| x / (1.0 - beta2.powi(self.t as i32)));
196
197 let update = &m_hat / &v_hat.mapv(|x| x.sqrt() + epsilon);
199
200 let perturb = if self.nesterov > 0.0 {
202 let nesterov_m = m.mapv(|x| x * beta1) + grad.mapv(|x| x * (1.0 - beta1));
203 let nesterov_m_hat =
204 nesterov_m.mapv(|x| x / (1.0 - beta1.powi((self.t + 1) as i32)));
205 &nesterov_m_hat / &v_hat.mapv(|x| x.sqrt() + epsilon)
206 } else {
207 update.clone()
208 };
209
210 if weight_decay > 0.0 {
212 let wd_perturb = param.mapv(|x| -x * weight_decay);
213 let projected_wd =
214 self.projection(param, &grad, &wd_perturb, self.delta, self.wd_ratio);
215
216 *param = param.clone() - perturb.mapv(|x| x * lr) + projected_wd;
218 } else {
219 param.scaled_add(-lr, &perturb);
221 }
222 }
223
224 Ok(())
225 }
226
227 fn state_dict(&self) -> HashMap<String, Vec<f64>> {
228 let mut state = HashMap::new();
229
230 state.insert("t".to_string(), vec![self.t as f64]);
232
233 for (name, m_val) in &self.m {
235 state.insert(format!("m_{}", name), m_val.iter().copied().collect());
236 }
237 for (name, v_val) in &self.v {
238 state.insert(format!("v_{}", name), v_val.iter().copied().collect());
239 }
240
241 state
242 }
243
244 fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
245 if let Some(t_vec) = state.get("t") {
247 self.t = t_vec[0] as usize;
248 }
249
250 for (key, value) in &state {
252 if let Some(name) = key.strip_prefix("m_") {
253 if let Some(m) = self.m.get(name) {
254 let shape = m.raw_dim();
255 if let Ok(array) = Array::from_shape_vec(shape, value.clone()) {
256 self.m.insert(name.to_string(), array);
257 }
258 }
259 } else if let Some(name) = key.strip_prefix("v_") {
260 if let Some(v) = self.v.get(name) {
261 let shape = v.raw_dim();
262 if let Ok(array) = Array::from_shape_vec(shape, value.clone()) {
263 self.v.insert(name.to_string(), array);
264 }
265 }
266 }
267 }
268 }
269
270 fn get_lr(&self) -> f64 {
271 self.config.learning_rate
272 }
273
274 fn set_lr(&mut self, lr: f64) {
275 self.config.learning_rate = lr;
276 }
277
278 fn zero_grad(&mut self) {
279 }
282}
283
284impl AdamPOptimizer {
285 pub fn reset(&mut self) {
287 self.m.clear();
288 self.v.clear();
289 self.t = 0;
290 }
291}
292
293#[cfg(test)]
294mod tests {
295 use super::*;
296 use scirs2_core::ndarray::array;
297
298 #[test]
299 fn test_adamp_basic() {
300 let config = OptimizerConfig {
301 learning_rate: 0.01,
302 beta1: 0.9,
303 beta2: 0.999,
304 ..Default::default()
305 };
306
307 let mut optimizer = AdamPOptimizer::new(config);
308
309 let mut parameters = HashMap::new();
310 parameters.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
311
312 let mut gradients = HashMap::new();
313 gradients.insert("w".to_string(), array![[0.1, 0.2], [0.3, 0.4]]);
314
315 optimizer.step(&mut parameters, &gradients).expect("unwrap");
317
318 assert_ne!(parameters["w"][[0, 0]], 1.0);
320 assert_ne!(parameters["w"][[1, 1]], 4.0);
321
322 assert!(parameters["w"][[0, 0]] < 1.0);
324 assert!(parameters["w"][[1, 1]] < 4.0);
325 }
326
327 #[test]
328 fn test_adamp_with_weight_decay() {
329 let config = OptimizerConfig {
330 learning_rate: 0.01,
331 weight_decay: 0.1,
332 ..Default::default()
333 };
334
335 let mut optimizer = AdamPOptimizer::new(config);
336
337 let mut parameters = HashMap::new();
338 parameters.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
339
340 let mut gradients = HashMap::new();
341 gradients.insert("w".to_string(), array![[0.1, 0.2], [0.3, 0.4]]);
342
343 let initial_param = parameters["w"].clone();
344
345 optimizer.step(&mut parameters, &gradients).expect("unwrap");
346
347 assert_ne!(parameters["w"], initial_param);
350 }
351
352 #[test]
353 fn test_adamp_state_dict() {
354 let config = OptimizerConfig {
355 learning_rate: 0.01,
356 ..Default::default()
357 };
358
359 let mut optimizer = AdamPOptimizer::new(config);
360
361 let mut parameters = HashMap::new();
362 parameters.insert("w".to_string(), array![[1.0, 2.0]]);
363
364 let mut gradients = HashMap::new();
365 gradients.insert("w".to_string(), array![[0.1, 0.2]]);
366
367 for _ in 0..5 {
369 optimizer.step(&mut parameters, &gradients).expect("unwrap");
370 }
371
372 let state = optimizer.state_dict();
374 assert!(state.contains_key("t"));
375 assert!(state.contains_key("m_w"));
376 assert!(state.contains_key("v_w"));
377
378 let mut new_optimizer = AdamPOptimizer::new(OptimizerConfig {
380 learning_rate: 0.01,
381 ..Default::default()
382 });
383
384 new_optimizer
386 .step(&mut parameters, &gradients)
387 .expect("unwrap");
388
389 new_optimizer.load_state_dict(state);
391
392 assert_eq!(new_optimizer.t, 5);
393 }
394
395 #[test]
396 fn test_adamp_convergence() {
397 let config = OptimizerConfig {
398 learning_rate: 0.1,
399 ..Default::default()
400 };
401
402 let mut optimizer = AdamPOptimizer::new(config);
403
404 let mut parameters = HashMap::new();
405 parameters.insert("w".to_string(), array![[5.0, 5.0]]);
406
407 for _ in 0..100 {
409 let grad = parameters["w"].mapv(|x| x * 0.1); let mut gradients = HashMap::new();
411 gradients.insert("w".to_string(), grad);
412
413 optimizer.step(&mut parameters, &gradients).expect("unwrap");
414 }
415
416 assert!(parameters["w"][[0, 0]].abs() < 1.0);
418 assert!(parameters["w"][[0, 1]].abs() < 1.0);
419 }
420
421 #[test]
422 fn test_adamp_projection() {
423 let config = OptimizerConfig {
424 learning_rate: 0.01,
425 weight_decay: 0.1,
426 ..Default::default()
427 };
428
429 let optimizer = AdamPOptimizer::with_params(config, 0.9, 0.1, 1.0);
430
431 let param = array![[1.0, 2.0], [3.0, 4.0]];
432 let grad = array![[0.1, 0.2], [0.3, 0.4]];
433 let perturb = array![[-0.1, -0.2], [-0.3, -0.4]];
434
435 let projected = optimizer.projection(¶m, &grad, &perturb, 0.1, 1.0);
436
437 assert_eq!(projected.shape(), perturb.shape());
439 }
440
441 #[test]
442 fn test_adamp_nesterov() {
443 let config = OptimizerConfig {
444 learning_rate: 0.01,
445 ..Default::default()
446 };
447
448 let mut opt_nesterov = AdamPOptimizer::with_params(config.clone(), 0.9, 0.1, 1.0);
450
451 let mut opt_standard = AdamPOptimizer::with_params(config, 0.0, 0.1, 1.0);
453
454 let mut params1 = HashMap::new();
455 params1.insert("w".to_string(), array![[1.0, 2.0]]);
456
457 let mut params2 = params1.clone();
458
459 let mut gradients = HashMap::new();
460 gradients.insert("w".to_string(), array![[0.1, 0.2]]);
461
462 opt_nesterov.step(&mut params1, &gradients).expect("unwrap");
463 opt_standard.step(&mut params2, &gradients).expect("unwrap");
464
465 assert!(
468 params1["w"][[0, 0]] != params2["w"][[0, 0]]
469 || (params1["w"][[0, 0]] - params2["w"][[0, 0]]).abs() < 1e-10
470 );
471 }
472}