use std::{
    cell::RefCell,
    collections::{HashMap, HashSet},
    rc::Rc,
};

use zenu_autograd::{creator::zeros::zeros_like, Variable};
use zenu_layer::Parameters;
use zenu_matrix::{device::Device, num::Num};

use crate::Optimizer;
8
9pub struct AdamW<T: Num, D: Device> {
10 learning_rate: T,
11 beta1: T,
12 beta2: T,
13 epsilon: T,
14 weight_decay: T,
15 step: Rc<RefCell<T>>,
16 m: HashMap<String, Variable<T, D>>,
17 v: HashMap<String, Variable<T, D>>,
18}
19
20impl<T: Num, D: Device, P: Parameters<T, D>> Optimizer<T, D, P> for AdamW<T, D> {
21 fn update(&self, parameters: &P) {
22 let step = *self.step.borrow() + T::one();
23 *self.step.borrow_mut() = step;
24
25 let beta1_t = self.beta1.powf(step);
26 let beta2_t = self.beta2.powf(step);
27
28 let weight_keys: Vec<_> = parameters.weights().keys().cloned().collect();
29
30 let params = parameters
31 .parameters()
32 .iter()
33 .filter_map(|(key, value)| {
34 value
35 .get_grad()
36 .map(|grad| (key.clone(), (value.clone(), grad.clone())))
37 })
38 .collect::<Vec<_>>();
39
40 for (k, (data, grad)) in params {
41 let m = self.m.get(&k).unwrap();
42 let v = self.v.get(&k).unwrap();
43 let mut m = m.get_as_mut();
44 let mut v = v.get_as_mut();
45 let grad = grad.get_as_ref();
46
47 m *= self.beta1;
49 m += grad.to_ref() * (T::one() - self.beta1);
50
51 v *= self.beta2;
52 v += grad.to_ref() * grad.to_ref() * (T::one() - self.beta2);
53
54 let m_hat = m.clone() / (T::one() - beta1_t);
55 let v_hat = v.clone() / (T::one() - beta2_t);
56
57 let denom = v_hat.sqrt() + self.epsilon;
58 let step_size = self.learning_rate;
59 let update = m_hat / denom;
60
61 if weight_keys.contains(&k) {
62 data.get_as_mut().sub_assign(
63 &(data.get_as_ref() * self.learning_rate * self.weight_decay).to_ref(),
64 );
65 }
66
67 data.get_as_mut().sub_assign(&(update * step_size).to_ref());
68 }
69 }
70}
71impl<T: Num, D: Device> AdamW<T, D> {
72 pub fn new(
73 learning_rate: T,
74 beta1: T,
75 beta2: T,
76 epsilon: T,
77 weight_decay: T,
78 model: &impl Parameters<T, D>,
79 ) -> Self {
80 let m = model
81 .parameters()
82 .iter()
83 .map(|(key, value)| (key.clone(), zeros_like(value)))
84 .collect();
85 let v = model
86 .parameters()
87 .iter()
88 .map(|(key, value)| (key.clone(), zeros_like(value)))
89 .collect();
90 Self {
91 learning_rate,
92 beta1,
93 beta2,
94 epsilon,
95 weight_decay,
96 step: Rc::new(RefCell::new(T::zero())),
97 m,
98 v,
99 }
100 }
101}