1use std::{cell::RefCell, collections::HashMap, rc::Rc};
2
3use zenu_autograd::{creator::zeros::zeros_like, Variable};
4use zenu_layer::Parameters;
5use zenu_matrix::{device::Device, num::Num};
6
7use crate::Optimizer;
8
9pub struct Adam<T: Num, D: Device> {
10 learning_rate: T,
11 beta1: T,
12 beta2: T,
13 epsilon: T,
14 step: Rc<RefCell<usize>>,
15 pub m: HashMap<String, Variable<T, D>>,
16 pub v: HashMap<String, Variable<T, D>>,
17}
18
19impl<T: Num, D: Device, P: Parameters<T, D>> Optimizer<T, D, P> for Adam<T, D> {
20 fn update(&self, parameters: &P) {
21 *self.step.borrow_mut() += 1;
22 let step = T::from_usize(*self.step.borrow());
23
24 let beta1_t = self.beta1.powf(step);
25 let beta2_t = self.beta2.powf(step);
26
27 let parameters = parameters
28 .parameters()
29 .iter()
30 .filter_map(|(key, value)| {
31 value
32 .get_grad()
33 .map(|grad| (key.clone(), (value.clone(), grad.clone())))
34 })
35 .collect::<Vec<_>>();
36
37 for (k, (data, grad)) in ¶meters {
38 let v = self.v.get(k).unwrap();
39 let m = self.m.get(k).unwrap();
40 let mut v = v.get_as_mut();
41 let mut m = m.get_as_mut();
42 let grad = grad.get_as_ref();
43
44 m *= self.beta1;
45 m += grad.to_ref() * (T::one() - self.beta1);
46
47 v *= self.beta2;
48 v += grad.to_ref() * grad.to_ref() * (T::one() - self.beta2);
49
50 let m_hat = m.clone() / (T::one() - beta1_t);
51 let v_hat = v.clone() / (T::one() - beta2_t);
52
53 let m_v_hat = m_hat / (v_hat.sqrt() + self.epsilon);
54 let lr_mv_hat = m_v_hat * self.learning_rate;
55
56 data.get_as_mut().sub_assign(&lr_mv_hat.to_ref());
57 }
58 }
59}
60
61impl<T: Num, D: Device> Adam<T, D> {
62 pub fn new(
63 learning_rate: T,
64 beta1: T,
65 beta2: T,
66 epsilon: T,
67 model: &impl Parameters<T, D>,
68 ) -> Self {
69 let m = model
70 .parameters()
71 .iter()
72 .map(|(key, value)| (key.clone(), zeros_like(value)))
73 .collect();
74 let v = model
75 .parameters()
76 .iter()
77 .map(|(key, value)| (key.clone(), zeros_like(value)))
78 .collect();
79 Self {
80 learning_rate,
81 beta1,
82 beta2,
83 epsilon,
84 step: Rc::new(RefCell::new(0)),
85 m,
86 v,
87 }
88 }
89}