tensorlogic_train/optimizers/adamp.rs

use crate::optimizer::{GradClipMode, Optimizer, OptimizerConfig};
use crate::TrainResult;
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;

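/// AdamP-style optimizer (cf. Heo et al., 2021, "AdamP"): Adam with
/// bias-corrected moments, an optional Nesterov-style lookahead on the first
/// moment, and a projection step that removes the component of the
/// weight-decay perturbation parallel to the gradient whenever the two
/// directions are strongly aligned (|cosine| > `delta`). In either case the
/// perturbation is scaled by `wd_ratio`.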
#[derive(Clone)]
pub struct AdamPOptimizer {
    config: OptimizerConfig,
    /// First-moment estimates, keyed by parameter name.
    m: HashMap<String, Array<f64, Ix2>>,
    /// Second-moment estimates, keyed by parameter name.
    v: HashMap<String, Array<f64, Ix2>>,
    /// Number of steps taken (used for bias correction).
    t: usize,
    /// Nesterov lookahead weight; the lookahead path is used when > 0.
    nesterov: f64,
    /// Cosine-similarity threshold for the projection step.
    delta: f64,
    /// Scale applied to the (possibly projected) weight-decay perturbation.
    wd_ratio: f64,
}

impl AdamPOptimizer {
    /// Create an AdamP optimizer with default projection settings
    /// (`nesterov = 0.9`, `delta = 0.1`, `wd_ratio = 1.0`).
    pub fn new(config: OptimizerConfig) -> Self {
        Self {
            config,
            m: HashMap::new(),
            v: HashMap::new(),
            t: 0,
            nesterov: 0.9,
            delta: 0.1,
            wd_ratio: 1.0,
        }
    }

    /// Create an AdamP optimizer with explicit `nesterov`, `delta`, and
    /// `wd_ratio` settings.
    pub fn with_params(config: OptimizerConfig, nesterov: f64, delta: f64, wd_ratio: f64) -> Self {
        Self {
            config,
            m: HashMap::new(),
            v: HashMap::new(),
            t: 0,
            nesterov,
            delta,
            wd_ratio,
        }
    }

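    /// Project the weight-decay perturbation away from the gradient direction.
    ///
    /// Computes `cos = (g . p) / (||g|| * ||p||)`. If `|cos| > delta`, the
    /// component of `p` along `g` is removed, the result is rescaled back to
    /// `||p||` and multiplied by `wd_ratio`; otherwise `p` is simply scaled
    /// by `wd_ratio`. Near-zero gradient or perturbation norms are passed
    /// through unchanged.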
    fn projection(
        &self,
        _param: &Array<f64, Ix2>,
        grad: &Array<f64, Ix2>,
        perturb: &Array<f64, Ix2>,
        delta: f64,
        wd_ratio: f64,
    ) -> Array<f64, Ix2> {
        let grad_norm = grad.iter().map(|&x| x * x).sum::<f64>().sqrt();
        if grad_norm < 1e-12 {
            return perturb.clone();
        }

        let perturb_norm = perturb.iter().map(|&x| x * x).sum::<f64>().sqrt();
        if perturb_norm < 1e-12 {
            return perturb.clone();
        }

        let dot_product: f64 = grad.iter().zip(perturb.iter()).map(|(&g, &p)| g * p).sum();
        let cosine = dot_product / (grad_norm * perturb_norm + 1e-12);

        if cosine.abs() > delta {
            // Remove the component of the perturbation along the gradient,
            // then restore its original norm and apply `wd_ratio`.
            let scale = dot_product / (grad_norm * grad_norm + 1e-12);
            let projection = grad.mapv(|x| x * scale);
            let mut result = perturb - &projection;

            let result_norm = result.iter().map(|&x| x * x).sum::<f64>().sqrt();
            if result_norm > 1e-12 {
                result = result.mapv(|x| x * perturb_norm / result_norm * wd_ratio);
            }

            result
        } else {
            perturb.mapv(|x| x * wd_ratio)
        }
    }
}

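// Per-parameter update (elementwise), with `t` the current step count:
//
//   m_t   = beta1 * m_{t-1} + (1 - beta1) * g_t
//   v_t   = beta2 * v_{t-1} + (1 - beta2) * g_t^2
//   m_hat = m_t / (1 - beta1^t),   v_hat = v_t / (1 - beta2^t)
//   step  = m_hat / (sqrt(v_hat) + epsilon)
//
// When `nesterov > 0.0`, the perturbation instead uses a one-step lookahead
// of the first moment (bias-corrected with t + 1). When `weight_decay > 0.0`,
// the projected decay term `projection(-weight_decay * param)` is added on
// top of the scaled step; otherwise plain `param -= lr * step` is applied.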
impl Optimizer for AdamPOptimizer {
    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()> {
        self.t += 1;

        let beta1 = self.config.beta1;
        let beta2 = self.config.beta2;
        let epsilon = self.config.epsilon;
        let lr = self.config.learning_rate;
        let weight_decay = self.config.weight_decay;

        for (name, param) in parameters.iter_mut() {
            let grad = gradients.get(name).ok_or_else(|| {
                crate::TrainError::OptimizerError(format!("No gradient for parameter {}", name))
            })?;

            // Optional gradient clipping: clip each element by value, or
            // rescale the whole array when its L2 norm exceeds the limit.
            let grad = if let Some(clip_value) = self.config.grad_clip {
                let mut clipped = grad.clone();
                match self.config.grad_clip_mode {
                    GradClipMode::Value => {
                        clipped.mapv_inplace(|x| x.max(-clip_value).min(clip_value));
                    }
                    GradClipMode::Norm => {
                        let norm = grad.iter().map(|&x| x * x).sum::<f64>().sqrt();
                        if norm > clip_value {
                            let scale = clip_value / norm;
                            clipped.mapv_inplace(|x| x * scale);
                        }
                    }
                }
                clipped
            } else {
                grad.clone()
            };

            let m = self
                .m
                .entry(name.clone())
                .or_insert_with(|| Array::zeros(param.raw_dim()));
            let v = self
                .v
                .entry(name.clone())
                .or_insert_with(|| Array::zeros(param.raw_dim()));

            // Exponential moving averages of the gradient and its square.
            *m = m.mapv(|x| x * beta1) + grad.mapv(|x| x * (1.0 - beta1));
            *v = v.mapv(|x| x * beta2) + grad.mapv(|x| x * x * (1.0 - beta2));

            // Bias-corrected estimates.
            let m_hat = m.mapv(|x| x / (1.0 - beta1.powi(self.t as i32)));
            let v_hat = v.mapv(|x| x / (1.0 - beta2.powi(self.t as i32)));

            let update = &m_hat / &v_hat.mapv(|x| x.sqrt() + epsilon);

            let perturb = if self.nesterov > 0.0 {
                // One-step lookahead of the first moment.
                let nesterov_m = m.mapv(|x| x * beta1) + grad.mapv(|x| x * (1.0 - beta1));
                let nesterov_m_hat =
                    nesterov_m.mapv(|x| x / (1.0 - beta1.powi((self.t + 1) as i32)));
                &nesterov_m_hat / &v_hat.mapv(|x| x.sqrt() + epsilon)
            } else {
                update.clone()
            };

            if weight_decay > 0.0 {
                // Project the decay term away from the gradient direction
                // before adding it to the scaled update.
                let wd_perturb = param.mapv(|x| -x * weight_decay);
                let projected_wd =
                    self.projection(param, &grad, &wd_perturb, self.delta, self.wd_ratio);

                *param = param.clone() - perturb.mapv(|x| x * lr) + projected_wd;
            } else {
                param.scaled_add(-lr, &perturb);
            }
        }

        Ok(())
    }

    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        let mut state = HashMap::new();

        state.insert("t".to_string(), vec![self.t as f64]);

        // Flatten the moment arrays into vectors, prefixed by moment kind.
        for (name, m_val) in &self.m {
            state.insert(format!("m_{}", name), m_val.iter().copied().collect());
        }
        for (name, v_val) in &self.v {
            state.insert(format!("v_{}", name), v_val.iter().copied().collect());
        }

        state
    }

    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
        if let Some(t_vec) = state.get("t") {
            self.t = t_vec[0] as usize;
        }

        // Moment entries are only restored for parameters that already have
        // state, since the stored vectors carry no shape information.
        for (key, value) in &state {
            if let Some(name) = key.strip_prefix("m_") {
                if let Some(m) = self.m.get(name) {
                    let shape = m.raw_dim();
                    if let Ok(array) = Array::from_shape_vec(shape, value.clone()) {
                        self.m.insert(name.to_string(), array);
                    }
                }
            } else if let Some(name) = key.strip_prefix("v_") {
                if let Some(v) = self.v.get(name) {
                    let shape = v.raw_dim();
                    if let Ok(array) = Array::from_shape_vec(shape, value.clone()) {
                        self.v.insert(name.to_string(), array);
                    }
                }
            }
        }
    }

    fn get_lr(&self) -> f64 {
        self.config.learning_rate
    }

    fn set_lr(&mut self, lr: f64) {
        self.config.learning_rate = lr;
    }

    fn zero_grad(&mut self) {
        // Gradients are supplied by the caller each step; there is no
        // internal gradient state to clear.
    }
}

impl AdamPOptimizer {
    /// Clear all accumulated optimizer state (moments and step count).
    pub fn reset(&mut self) {
        self.m.clear();
        self.v.clear();
        self.t = 0;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_adamp_basic() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            beta1: 0.9,
            beta2: 0.999,
            ..Default::default()
        };

        let mut optimizer = AdamPOptimizer::new(config);

        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), array![[0.1, 0.2], [0.3, 0.4]]);

        optimizer.step(&mut parameters, &gradients).unwrap();

        // With positive gradients, parameters should move and decrease.
        assert_ne!(parameters["w"][[0, 0]], 1.0);
        assert_ne!(parameters["w"][[1, 1]], 4.0);

        assert!(parameters["w"][[0, 0]] < 1.0);
        assert!(parameters["w"][[1, 1]] < 4.0);
    }

    #[test]
    fn test_adamp_with_weight_decay() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            weight_decay: 0.1,
            ..Default::default()
        };

        let mut optimizer = AdamPOptimizer::new(config);

        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), array![[0.1, 0.2], [0.3, 0.4]]);

        let initial_param = parameters["w"].clone();

        optimizer.step(&mut parameters, &gradients).unwrap();

        assert_ne!(parameters["w"], initial_param);
    }

    #[test]
    fn test_adamp_state_dict() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };

        let mut optimizer = AdamPOptimizer::new(config);

        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), array![[0.1, 0.2]]);

        for _ in 0..5 {
            optimizer.step(&mut parameters, &gradients).unwrap();
        }

        let state = optimizer.state_dict();
        assert!(state.contains_key("t"));
        assert!(state.contains_key("m_w"));
        assert!(state.contains_key("v_w"));

        // Take one step so the new optimizer has moment entries to restore
        // into, then load the saved state.
        let mut new_optimizer = AdamPOptimizer::new(OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        });

        new_optimizer.step(&mut parameters, &gradients).unwrap();

        new_optimizer.load_state_dict(state);

        assert_eq!(new_optimizer.t, 5);
    }

    #[test]
    fn test_adamp_convergence() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };

        let mut optimizer = AdamPOptimizer::new(config);

        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), array![[5.0, 5.0]]);

        for _ in 0..100 {
            let grad = parameters["w"].mapv(|x| x * 0.1);
            let mut gradients = HashMap::new();
            gradients.insert("w".to_string(), grad);

            optimizer.step(&mut parameters, &gradients).unwrap();
        }

        assert!(parameters["w"][[0, 0]].abs() < 1.0);
        assert!(parameters["w"][[0, 1]].abs() < 1.0);
    }

    #[test]
    fn test_adamp_projection() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            weight_decay: 0.1,
            ..Default::default()
        };

        let optimizer = AdamPOptimizer::with_params(config, 0.9, 0.1, 1.0);

        let param = array![[1.0, 2.0], [3.0, 4.0]];
        let grad = array![[0.1, 0.2], [0.3, 0.4]];
        let perturb = array![[-0.1, -0.2], [-0.3, -0.4]];

        let projected = optimizer.projection(&param, &grad, &perturb, 0.1, 1.0);

        assert_eq!(projected.shape(), perturb.shape());
    }

    #[test]
    fn test_adamp_nesterov() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };

        let mut opt_nesterov = AdamPOptimizer::with_params(config.clone(), 0.9, 0.1, 1.0);

        let mut opt_standard = AdamPOptimizer::with_params(config, 0.0, 0.1, 1.0);

        let mut params1 = HashMap::new();
        params1.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut params2 = params1.clone();

        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), array![[0.1, 0.2]]);

        opt_nesterov.step(&mut params1, &gradients).unwrap();
        opt_standard.step(&mut params2, &gradients).unwrap();

        // Smoke check: both variants must complete a step. The two results
        // may differ (Nesterov lookahead) or be nearly identical, so either
        // branch of the disjunction is acceptable.
        assert!(
            params1["w"][[0, 0]] != params2["w"][[0, 0]]
                || (params1["w"][[0, 0]] - params2["w"][[0, 0]]).abs() < 1e-10
        );
    }
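
    // Added sketch: exercises `reset` and the learning-rate accessors; the
    // direct field accesses rely on this test module living in the same file
    // as the optimizer.
    #[test]
    fn test_adamp_reset_and_lr() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };

        let mut optimizer = AdamPOptimizer::new(config);

        let mut parameters = HashMap::new();
        parameters.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut gradients = HashMap::new();
        gradients.insert("w".to_string(), array![[0.1, 0.2]]);

        optimizer.step(&mut parameters, &gradients).unwrap();
        assert_eq!(optimizer.t, 1);
        assert!(optimizer.m.contains_key("w"));

        optimizer.reset();
        assert_eq!(optimizer.t, 0);
        assert!(optimizer.m.is_empty());
        assert!(optimizer.v.is_empty());

        optimizer.set_lr(0.5);
        assert!((optimizer.get_lr() - 0.5).abs() < 1e-12);
    }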
}