zenu_optimizer/adamw.rs

use std::{cell::RefCell, collections::HashMap, rc::Rc};

use zenu_autograd::{creator::zeros::zeros_like, Variable};
use zenu_layer::Parameters;
use zenu_matrix::{device::Device, num::Num};

use crate::Optimizer;

/// AdamW: Adam with decoupled weight decay (Loshchilov & Hutter,
/// "Decoupled Weight Decay Regularization"). Keeps per-parameter first- and
/// second-moment estimates, keyed by the same names as the model's
/// `Parameters` maps.
pub struct AdamW<T: Num, D: Device> {
    learning_rate: T,
    beta1: T,
    beta2: T,
    epsilon: T,
    weight_decay: T,
    /// Shared step counter `t`, used for bias correction.
    step: Rc<RefCell<T>>,
    /// First-moment estimates, one buffer per parameter.
    m: HashMap<String, Variable<T, D>>,
    /// Second-moment estimates, one buffer per parameter.
    v: HashMap<String, Variable<T, D>>,
}
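// The update below follows the standard AdamW formulation (decoupled weight
// decay), applied independently to each parameter tensor:
//
//   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
//   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
//   m_hat = m_t / (1 - beta1^t),  v_hat = v_t / (1 - beta2^t)
//   theta_t = theta_{t-1} - lr * wd * theta_{t-1}
//                         - lr * m_hat / (sqrt(v_hat) + eps)
//
// The weight-decay term acts on the raw parameters and never flows through
// the moment estimates.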
impl<T: Num, D: Device, P: Parameters<T, D>> Optimizer<T, D, P> for AdamW<T, D> {
    fn update(&self, parameters: &P) {
        // Advance the shared step counter: t <- t + 1.
        let step = *self.step.borrow() + T::one();
        *self.step.borrow_mut() = step;

        // beta1^t and beta2^t, used for bias correction below.
        let beta1_t = self.beta1.powf(step);
        let beta2_t = self.beta2.powf(step);

        // Only parameters registered as weights receive weight decay.
        let weight_keys: Vec<_> = parameters.weights().keys().cloned().collect();

        // Collect (name, (parameter, gradient)) pairs, skipping any parameter
        // that has no gradient this step.
        let params = parameters
            .parameters()
            .iter()
            .filter_map(|(key, value)| {
                value
                    .get_grad()
                    .map(|grad| (key.clone(), (value.clone(), grad.clone())))
            })
            .collect::<Vec<_>>();
        for (k, (data, grad)) in params {
            let m = self.m.get(&k).expect("AdamW: missing first-moment state");
            let v = self.v.get(&k).expect("AdamW: missing second-moment state");
            let mut m = m.get_as_mut();
            let mut v = v.get_as_mut();
            let grad = grad.get_as_ref();

            // Exponential moving averages of the gradient and its square:
            // m <- beta1 * m + (1 - beta1) * g
            m *= self.beta1;
            m += grad.to_ref() * (T::one() - self.beta1);

            // v <- beta2 * v + (1 - beta2) * g^2
            v *= self.beta2;
            v += grad.to_ref() * grad.to_ref() * (T::one() - self.beta2);

            // Bias-corrected estimates: both averages start at zero, so divide
            // by (1 - beta^t) to undo the initialization bias.
            let m_hat = m.clone() / (T::one() - beta1_t);
            let v_hat = v.clone() / (T::one() - beta2_t);

            // Adam step direction: m_hat / (sqrt(v_hat) + eps).
            let denom = v_hat.sqrt() + self.epsilon;
            let update = m_hat / denom;
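            // Decoupled weight decay: shrink the parameter directly, before
            // the Adam step, and only for entries registered as weights.
            // Keeping the decay out of m and v is what distinguishes AdamW
            // from Adam with an L2 penalty added to the loss.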
            if weight_keys.contains(&k) {
                data.get_as_mut().sub_assign(
                    &(data.get_as_ref() * self.learning_rate * self.weight_decay).to_ref(),
                );
            }

            // theta <- theta - lr * m_hat / (sqrt(v_hat) + eps)
            data.get_as_mut()
                .sub_assign(&(update * self.learning_rate).to_ref());
        }
    }
}

impl<T: Num, D: Device> AdamW<T, D> {
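    /// Builds an AdamW optimizer for `model`, zero-initializing the moment
    /// buffers for every parameter.
    ///
    /// Illustrative call with commonly used hyperparameters (`model` stands
    /// for any value implementing `Parameters`):
    ///
    /// ```ignore
    /// let optimizer = AdamW::new(1e-3, 0.9, 0.999, 1e-8, 1e-2, &model);
    /// ```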
    pub fn new(
        learning_rate: T,
        beta1: T,
        beta2: T,
        epsilon: T,
        weight_decay: T,
        model: &impl Parameters<T, D>,
    ) -> Self {
        // Zero-initialize one moment buffer per parameter, shaped like the
        // parameter itself.
        let m = model
            .parameters()
            .iter()
            .map(|(key, value)| (key.clone(), zeros_like(value)))
            .collect();
        let v = model
            .parameters()
            .iter()
            .map(|(key, value)| (key.clone(), zeros_like(value)))
            .collect();
        Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            step: Rc::new(RefCell::new(T::zero())),
            m,
            v,
        }
    }
}
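// Sketch of how `update` slots into a training step (hypothetical `model`,
// `forward`, and batch iteration; the concrete model and loss APIs come from
// zenu_layer / zenu_autograd):
//
//     let optimizer = AdamW::new(1e-3, 0.9, 0.999, 1e-8, 1e-2, &model);
//     for batch in batches {
//         let loss = forward(&model, &batch); // build the autograd graph
//         loss.backward();                    // populate parameter gradients
//         optimizer.update(&model);           // one AdamW step
//     }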