zyx_optim/adam.rs

use alloc::vec::Vec;
use zyx_core::backend::Backend;
use zyx_core::tensor::Tensor;

/// # Adaptive moment estimation optimizer
pub struct Adam<B: Backend> {
    /// learning rate (default: 1e-3)
    pub learning_rate: f32,
    /// coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
    pub betas: (f32, f32),
    /// term added to the denominator to improve numerical stability (default: 1e-8)
    pub eps: f32,
    /// weight decay (L2 penalty) (default: 0)
    pub weight_decay: f32,
    /// whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond (default: false)
    pub amsgrad: bool,
    /// maximize the objective with respect to the params, instead of minimizing (default: false)
    pub maximize: bool,
    /// first moment estimates (running averages of gradients), one per parameter
    pub m: Vec<Tensor<B>>,
    /// second moment estimates (running averages of squared gradients), one per parameter
    pub v: Vec<Tensor<B>>,
    /// maximum of bias-corrected second moment estimates, used only when amsgrad is true
    pub vm: Vec<Tensor<B>>,
    /// number of update steps taken so far
    pub t: usize,
}
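
// For reference, the `update` method below implements the standard Adam step
// (Kingma & Ba, "Adam: A Method for Stochastic Optimization") with bias
// correction and optional AMSGrad:
//
//   m_t  = beta1 * m_{t-1} + (1 - beta1) * g_t
//   v_t  = beta2 * v_{t-1} + (1 - beta2) * g_t^2
//   mh_t = m_t / (1 - beta1^t)
//   vh_t = v_t / (1 - beta2^t)
//   param_t = param_{t-1} - learning_rate * mh_t / (sqrt(vh_t) + eps)
//
// With amsgrad enabled, sqrt(vh_t) is replaced by sqrt(max(vh_1, ..., vh_t)),
// which is what the `vm` buffer tracks.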

impl<B: Backend> Default for Adam<B> {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            betas: (0.9, 0.999),
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            maximize: false,
            m: Vec::new(),
            v: Vec::new(),
            vm: Vec::new(),
            t: 0,
        }
    }
}

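// A minimal construction sketch (hypothetical usage, not part of this file):
// any field left out falls back to the defaults above via struct update syntax.
//
// let mut optim = Adam {
//     learning_rate: 1e-4,
//     weight_decay: 1e-2,
//     ..Adam::default()
// };
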
impl<B: Backend> Adam<B> {
    /// Updates parameters with gradients.
    /// The number of parameters must be the same as the number of gradients.
    /// Gradients can be None; those parameters are simply skipped.
    pub fn update<'a>(
        &mut self,
        parameters: impl IntoIterator<Item = &'a mut Tensor<B>>,
        gradients: impl IntoIterator<Item = Option<Tensor<B>>>,
    ) where
        B: 'a,
    {
        let params: Vec<&mut Tensor<B>> = parameters.into_iter().collect();
        let grads: Vec<Option<Tensor<B>>> = gradients.into_iter().collect();

        assert_eq!(
            params.len(),
            grads.len(),
            "Number of parameters != number of gradients."
        );

        // One call to update is one optimization step. t must start at 1 here,
        // otherwise the bias correction below would divide by 1 - beta^0 = 0.
        self.t += 1;

        for (i, (param, grad)) in params.into_iter().zip(grads).enumerate() {
            if let Some(mut grad) = grad {
                if self.maximize {
                    grad = -grad;
                }
                if self.weight_decay != 0.0 {
                    grad = grad + &*param * self.weight_decay;
                }
                // First moment estimate, lazily initialized on the first step.
                if let Some(m) = self.m.get_mut(i) {
                    *m = &*m * self.betas.0 + &grad * (1.0 - self.betas.0);
                } else {
                    self.m.push(&grad * (1.0 - self.betas.0));
                }
                // Second moment estimate, lazily initialized on the first step.
                if let Some(v) = self.v.get_mut(i) {
                    *v = &*v * self.betas.1 + &grad * &grad * (1.0 - self.betas.1);
                } else {
                    self.v.push(&grad * &grad * (1.0 - self.betas.1));
                }
                // Bias-corrected moment estimates.
                let mh = &self.m[i] / (1.0 - self.betas.0.powi(self.t as i32));
                let vh = &self.v[i] / (1.0 - self.betas.1.powi(self.t as i32));
                if self.amsgrad {
                    // Keep the elementwise maximum of all bias-corrected second moments seen so far.
                    if let Some(vm) = self.vm.get_mut(i) {
                        *vm = vm.cmplt(&vh).where_(vh, &*vm);
                    } else {
                        self.vm.push(vh);
                    }
                    // Cast since learning_rate is f32, but parameters can have different precision.
                    // Can this cast be somehow avoided? Is it better to always work with original dtype?
                    *param = (&*param - mh * self.learning_rate / (self.vm[i].sqrt() + self.eps))
                        .cast(param.dtype());
                } else {
                    *param = (&*param - mh * self.learning_rate / (vh.sqrt() + self.eps))
                        .cast(param.dtype());
                }
            }
        }
    }
}
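
// Usage sketch for `update` (hypothetical surrounding code; how `params` and
// `grads` are produced depends on the rest of the zyx API and is assumed here,
// not shown by this file):
//
// let mut optim = Adam::default();
// // params: Vec<Tensor<B>>, grads: Vec<Option<Tensor<B>>>, same length
// optim.update(params.iter_mut(), grads);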