//! Adam (adaptive moment estimation) optimizer.
use alloc::vec::Vec;
use zyx_core::backend::Backend;
use zyx_core::tensor::Tensor;
/// # Adaptive momentum estimation optimizer
///
/// Holds the optimizer hyperparameters together with the per-parameter
/// running statistics (`m`, `v`, `vm`) that are populated lazily on the
/// first call to `update` and indexed by parameter position thereafter.
pub struct Adam<B: Backend> {
/// learning rate (default: 1e-3)
pub learning_rate: f32,
/// coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
pub betas: (f32, f32),
/// term added to the denominator to improve numerical stability (default: 1e-8)
pub eps: f32,
/// weight decay (L2 penalty) (default: 0)
pub weight_decay: f32,
/// whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond (default: false)
pub amsgrad: bool,
/// maximize the objective with respect to the params, instead of minimizing (default: false)
pub maximize: bool,
/// first moment estimates (exponential moving average of gradients), one per parameter
pub m: Vec<Tensor<B>>,
/// second moment estimates (exponential moving average of squared gradients), one per parameter
pub v: Vec<Tensor<B>>,
/// maximum of the bias-corrected second moments, one per parameter; only used when `amsgrad` is true
pub vm: Vec<Tensor<B>>,
/// number of optimization steps taken so far, used for bias correction
pub t: usize,
}
impl<B: Backend> Default for Adam<B> {
/// Creates an Adam optimizer with the standard defaults
/// (lr = 1e-3, betas = (0.9, 0.999), eps = 1e-8, no weight decay,
/// plain Adam minimizing the objective) and empty moment buffers.
fn default() -> Self {
// Hyperparameters first, then the (initially empty) running state.
Self {
learning_rate: 1e-3,
betas: (0.9, 0.999),
eps: 1e-8,
weight_decay: 0.0,
amsgrad: false,
maximize: false,
// Moment buffers are filled lazily on the first `update` call.
m: Vec::new(),
v: Vec::new(),
vm: Vec::new(),
// No steps taken yet.
t: 0,
}
}
}
impl<B: Backend> Adam<B> {
/// Updates parameters with gradients using one Adam step.
///
/// Number of parameters must be the same as number of gradients.
/// Gradients can be None, those are simply skipped (their moment slots
/// keep their previous values, since moments are indexed by position).
///
/// # Panics
/// Panics if the number of parameters differs from the number of gradients.
pub fn update<'a>(
&mut self,
parameters: impl IntoIterator<Item = &'a mut Tensor<B>>,
gradients: impl IntoIterator<Item = Option<Tensor<B>>>,
) where
B: 'a,
{
let params: Vec<&mut Tensor<B>> = parameters.into_iter().collect();
let grads: Vec<Option<Tensor<B>>> = gradients.into_iter().collect();
assert_eq!(
params.len(),
grads.len(),
"Number of parameters != number of gradients."
);
// Bias correction requires t >= 1: with t == 0, beta.powi(0) == 1 and
// the corrections below would divide by 1 - 1 == 0.
self.t += 1;
for (i, (param, grad)) in params.into_iter().zip(grads).enumerate() {
if let Some(mut grad) = grad {
if self.maximize {
grad = -grad;
}
if self.weight_decay != 0.0 {
grad = grad + &*param * self.weight_decay;
}
// First moment: m = beta1 * m + (1 - beta1) * grad
if let Some(m) = self.m.get_mut(i) {
*m = &*m * self.betas.0 + &grad * (1.0 - self.betas.0);
} else {
self.m.push(&grad * (1.0 - self.betas.0));
}
// Second moment: v = beta2 * v + (1 - beta2) * grad^2
// (fixed: this previously indexed self.m, which overwrote the
// first moment and left self.v stale after its initial push)
if let Some(v) = self.v.get_mut(i) {
*v = &*v * self.betas.1 + &grad * &grad * (1.0 - self.betas.1);
} else {
self.v.push(&grad * &grad * (1.0 - self.betas.1));
}
// Bias-corrected moment estimates.
let mh = &self.m[i] / (1.0 - self.betas.0.powi(self.t as i32));
let vh = &self.v[i] / (1.0 - self.betas.1.powi(self.t as i32));
if self.amsgrad {
// AMSGrad keeps the elementwise maximum of all vh seen so far.
if let Some(vm) = self.vm.get_mut(i) {
*vm = vm.cmplt(&vh).where_(vh, &*vm);
} else {
self.vm.push(vh);
}
// theta <- theta - lr * mh / (sqrt(vm) + eps)
// (fixed: the step previously divided by learning_rate instead
// of multiplying, scaling the update by 1/lr^2 of the intent)
// Cast since learning_rate is f32, but parameters can have different precision.
// Can this cast be somehow avoided? Is it better to always work with original dtype?
*param = (&*param - mh * self.learning_rate / (self.vm[i].sqrt() + self.eps))
.cast(param.dtype());
} else {
// theta <- theta - lr * mh / (sqrt(vh) + eps)
*param = (&*param - mh * self.learning_rate / (vh.sqrt() + self.eps))
.cast(param.dtype());
}
}
}
}
}