1use zyx::Tensor;
5use zyx_derive::Module;
6
/// AdamW optimizer: Adam with decoupled weight decay
/// (Loshchilov & Hutter, 2019), with optional AMSGrad.
///
/// The moment buffers `m`, `v` (and `vm` when `amsgrad` is set) are
/// allocated lazily on the first `update` call, one entry per parameter,
/// indexed by the parameter's position in the iterator passed to `update`.
#[derive(Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct AdamW {
 /// Step size (Adam's alpha). Default 0.001.
 pub learning_rate: f32,
 /// Exponential decay rates for the first and second moment
 /// estimates (beta1, beta2). Default (0.9, 0.999).
 pub betas: (f32, f32),
 /// Small constant added to the denominator for numerical
 /// stability. Default 1e-8.
 pub eps: f32,
 /// Decoupled weight decay coefficient; 0.0 disables decay.
 pub weight_decay: f32,
 /// When true, use the AMSGrad variant (running elementwise max
 /// of the second-moment estimate).
 pub amsgrad: bool,
 /// First-moment (mean of gradients) buffer, one tensor per parameter.
 pub m: Vec<Tensor>,
 /// Second-moment (uncentered variance) buffer, one tensor per parameter.
 pub v: Vec<Tensor>,
 /// AMSGrad running maximum of the bias-corrected second moment;
 /// only populated when `amsgrad` is true.
 pub vm: Vec<Tensor>,
 /// Global step counter used for bias correction; starts at 0 and is
 /// incremented once per `update` call.
 pub t: usize,
}
30
31impl Default for AdamW {
32 fn default() -> Self {
33 Self {
34 learning_rate: 0.001,
35 betas: (0.9, 0.999),
36 eps: 1e-8,
37 weight_decay: 0.0,
38 amsgrad: false,
39 m: Vec::new(),
40 v: Vec::new(),
41 vm: Vec::new(),
42 t: 0,
43 }
44 }
45}
46
47impl AdamW {
48 pub fn update<'a>(
52 &mut self,
53 parameters: impl IntoIterator<Item = &'a mut Tensor>,
54 gradients: impl IntoIterator<Item = Option<Tensor>>,
55 ) {
56 use zyx::Scalar;
57 self.t += 1;
58 for (i, (param, grad)) in parameters.into_iter().zip(gradients).enumerate() {
59 let Some(grad) = grad else {
60 if self.m.len() <= i {
62 self.m.push(Tensor::zeros_like(&*param));
63 self.v.push(Tensor::zeros_like(&*param));
64 if self.amsgrad {
65 self.vm.push(Tensor::zeros_like(&*param));
66 }
67 }
68 continue;
69 };
70
71 if let Some(m) = self.m.get_mut(i) {
73 *m = &*m * self.betas.0 + &grad * (1.0 - self.betas.0);
74 } else {
75 self.m.push(&grad * (1.0 - self.betas.0));
76 }
77
78 if let Some(v) = self.v.get_mut(i) {
80 *v = &*v * self.betas.1 + &grad * &grad * (1.0 - self.betas.1);
81 } else {
82 self.v.push(&grad * &grad * (1.0 - self.betas.1));
83 }
84
85 let mh = &self.m[i] / (1.0 - self.betas.0.pow(self.t as f32));
87 let vh = &self.v[i] / (1.0 - self.betas.1.pow(self.t as f32));
88
89 if self.amsgrad {
90 if let Some(vm) = self.vm.get_mut(i) {
91 *vm = vm.cmplt(&vh).unwrap().where_(vh, &*vm).unwrap();
92 } else {
93 self.vm.push(vh);
94 }
95 *param = (&*param - mh / ((self.vm[i].sqrt() + self.eps) * self.learning_rate))
97 .cast(param.dtype());
98 } else {
99 *param = (&*param - mh / ((vh.sqrt() + self.eps) * self.learning_rate))
101 .cast(param.dtype());
102 }
103
104 if self.weight_decay != 0.0 {
106 *param = &*param * (1.0 - self.learning_rate * self.weight_decay);
107 }
108 }
109 }
110}