use alloc::vec::Vec;
use zyx_core::backend::Backend;
use zyx_core::tensor::Tensor;

/// Adam optimizer from the paper "Adam: A Method for Stochastic Optimization".
pub struct Adam<B: Backend> {
    /// Learning rate (default: 0.001)
    pub learning_rate: f32,
    /// Coefficients used for the running averages of the gradient and its square (default: (0.9, 0.999))
    pub betas: (f32, f32),
    /// Term added to the denominator to improve numerical stability (default: 1e-8)
    pub eps: f32,
    /// Weight decay, i.e. L2 penalty (default: 0.0)
    pub weight_decay: f32,
    /// Whether to use the AMSGrad variant from "On the Convergence of Adam and Beyond" (default: false)
    pub amsgrad: bool,
    /// Maximize the objective with respect to the parameters instead of minimizing (default: false)
    pub maximize: bool,
    /// First moment estimates (running averages of gradients)
    pub m: Vec<Tensor<B>>,
    /// Second moment estimates (running averages of squared gradients)
    pub v: Vec<Tensor<B>>,
    /// Running maximum of bias-corrected second moment estimates, used only by AMSGrad
    pub vm: Vec<Tensor<B>>,
    /// Number of update steps taken so far
    pub t: usize,
}

impl<B: Backend> Default for Adam<B> {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            betas: (0.9, 0.999),
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            maximize: false,
            m: Vec::new(),
            v: Vec::new(),
            vm: Vec::new(),
            t: 0,
        }
    }
}
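
// A minimal construction sketch: any field can be overridden with struct
// update syntax while the rest keeps the defaults above. `CpuBackend` is
// only a stand-in name for whichever zyx backend is in use:
//
//     let optim = Adam::<CpuBackend> {
//         learning_rate: 3e-4,
//         amsgrad: true,
//         ..Adam::default()
//     };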

impl<B: Backend> Adam<B> {
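    /// Updates parameters in place with their gradients, one Adam step.
    ///
    /// For reference, the math this loop implements (Kingma & Ba, 2015),
    /// with g the gradient, (β1, β2) = `betas` and α = `learning_rate`:
    ///
    /// ```text
    /// m_t = β1 * m_{t-1} + (1 - β1) * g_t
    /// v_t = β2 * v_{t-1} + (1 - β2) * g_t²
    /// m̂_t = m_t / (1 - β1^t)
    /// v̂_t = v_t / (1 - β2^t)
    /// θ_t = θ_{t-1} - α * m̂_t / (√v̂_t + ε)
    /// ```
    ///
    /// With `amsgrad`, √v̂_t is replaced by the square root of the running
    /// elementwise maximum of all v̂ seen so far. A `None` gradient leaves
    /// the corresponding parameter unchanged.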
    pub fn update<'a>(
        &mut self,
        parameters: impl IntoIterator<Item = &'a mut Tensor<B>>,
        gradients: impl IntoIterator<Item = Option<Tensor<B>>>,
    ) where
        B: 'a,
    {
        let params: Vec<&mut Tensor<B>> = parameters.into_iter().collect();
        let grads: Vec<Option<Tensor<B>>> = gradients.into_iter().collect();

        assert_eq!(
            params.len(),
            grads.len(),
            "Number of parameters != number of gradients."
        );

        // Advance the step counter before the loop; the bias corrections
        // below divide by 1 - beta^t, which would be zero if t stayed at 0.
        self.t += 1;

        for (i, (param, grad)) in params.into_iter().zip(grads).enumerate() {
            if let Some(mut grad) = grad {
                if self.maximize {
                    grad = -grad;
                }
                if self.weight_decay != 0.0 {
                    grad = grad + &*param * self.weight_decay;
                }
                // Update the biased first moment estimate.
                if let Some(m) = self.m.get_mut(i) {
                    *m = &*m * self.betas.0 + &grad * (1.0 - self.betas.0);
                } else {
                    self.m.push(&grad * (1.0 - self.betas.0));
                }
                // Update the biased second moment estimate.
                if let Some(v) = self.v.get_mut(i) {
                    *v = &*v * self.betas.1 + &grad * &grad * (1.0 - self.betas.1);
                } else {
                    self.v.push(&grad * &grad * (1.0 - self.betas.1));
                }
                // Bias-corrected moment estimates.
                let mh = &self.m[i] / (1.0 - self.betas.0.powi(self.t as i32));
                let vh = &self.v[i] / (1.0 - self.betas.1.powi(self.t as i32));
                if self.amsgrad {
                    // Keep the elementwise maximum of all bias-corrected
                    // second moment estimates seen so far.
                    if let Some(vm) = self.vm.get_mut(i) {
                        *vm = vm.cmplt(&vh).where_(vh, &*vm);
                    } else {
                        self.vm.push(vh);
                    }
                    *param = (&*param
                        - mh * self.learning_rate / (self.vm[i].sqrt() + self.eps))
                        .cast(param.dtype());
                } else {
                    *param = (&*param - mh * self.learning_rate / (vh.sqrt() + self.eps))
                        .cast(param.dtype());
                }
            }
        }
    }
}
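
// A hedged sketch of the training loop shape this API expects. How the
// gradients are produced is an assumption here: `backward` stands in for
// whatever autograd entry point the surrounding crate provides, and is
// not defined in this file.
//
//     let mut optim = Adam::default();
//     for _ in 0..steps {
//         let grads = loss.backward(&params); // hypothetical autograd call
//         optim.update(params.iter_mut(), grads);
//     }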