1use alloc::vec::Vec;
2use zyx_core::backend::Backend;
3use zyx_core::tensor::Tensor;
4
/// Stochastic gradient descent optimizer with optional weight decay,
/// (Nesterov) momentum with dampening, and gradient ascent (`maximize`).
pub struct SGD<B: Backend> {
    /// Step size applied to each gradient. Defaults to 0.001.
    pub learning_rate: f32,
    /// Momentum factor; 0.0 disables momentum entirely.
    pub momentum: f32,
    /// L2 penalty factor added to the gradient as `weight_decay * param`; 0.0 disables it.
    pub weight_decay: f32,
    /// Dampening for momentum: new gradients enter the buffer scaled by `1 - dampening`.
    pub dampening: f32,
    /// Use Nesterov momentum (look-ahead step) instead of classical momentum.
    pub nesterov: bool,
    /// Perform gradient ascent instead of descent.
    pub maximize: bool,
    /// Momentum buffers, filled lazily on the first `update` call with momentum enabled.
    /// NOTE(review): buffers are matched to parameters by position, which assumes the
    /// same parameters receive `Some` gradients on every call — TODO confirm with callers.
    pub bias: Vec<Tensor<B>>,
}
22
23impl<B: Backend> Default for SGD<B> {
24 fn default() -> Self {
25 Self {
26 learning_rate: 0.001,
27 momentum: 0.0,
28 weight_decay: 0.0,
29 dampening: 0.0,
30 nesterov: false,
31 maximize: false,
32 bias: Vec::new(),
33 }
34 }
35}
36
37impl<B: Backend> SGD<B> {
38 pub fn update<'a>(
42 &mut self,
43 parameters: impl IntoIterator<Item = &'a mut Tensor<B>>,
44 gradients: impl IntoIterator<Item = Option<Tensor<B>>>,
45 ) where
46 B: 'a,
47 {
48 let params: Vec<&mut Tensor<B>> = parameters.into_iter().collect();
49 let grads: Vec<Option<Tensor<B>>> = gradients.into_iter().collect();
50
51 assert_eq!(
52 params.len(),
53 grads.len(),
54 "Number of parameters != number of gradients."
55 );
56
57 for (i, (param, grad)) in params.into_iter().zip(grads).enumerate() {
58 if let Some(mut grad) = grad {
59 if self.weight_decay != 0.0 {
60 grad = grad + &*param * self.weight_decay;
61 }
62 if self.momentum != 0.0 {
63 if let Some(bias) = self.bias.get_mut(i) {
64 *bias = &*bias * self.momentum + &grad * (1.0 - self.dampening);
65 } else {
66 self.bias.push(grad.clone());
67 }
68 if self.nesterov {
69 grad = grad + &self.bias[i] * self.momentum;
70 } else {
71 grad = self.bias[i].clone();
72 }
73 }
74 if self.maximize {
75 *param = (&*param + grad * self.learning_rate).cast(param.dtype());
78 } else {
79 *param = (&*param - grad * self.learning_rate).cast(param.dtype());
80 }
81 }
82 }
83 }
84}