//! `zyx_optim/sgd.rs` — stochastic gradient descent optimizer.

use alloc::vec::Vec;
use zyx_core::backend::Backend;
use zyx_core::tensor::Tensor;

5/// # Stochastic gradient descent optimizer
6pub struct SGD<B: Backend> {
7    /// learning rate (default: 0.001)
8    pub learning_rate: f32,
9    /// momentum factor (default: 0.0)
10    pub momentum: f32,
11    /// weight decay (L2 penalty) (default: 0.0)
12    pub weight_decay: f32,
13    /// dampening for momentum (default: 0.0)
14    pub dampening: f32,
15    /// enables Nesterov momentum (default: false)
16    pub nesterov: bool,
17    /// maximize the objective with respect to the params, instead of minimizing (default: false)
18    pub maximize: bool,
19    /// stores momentum, starts empty and will be initialized on demand
20    pub bias: Vec<Tensor<B>>,
21}
22
23impl<B: Backend> Default for SGD<B> {
24    fn default() -> Self {
25        Self {
26            learning_rate: 0.001,
27            momentum: 0.0,
28            weight_decay: 0.0,
29            dampening: 0.0,
30            nesterov: false,
31            maximize: false,
32            bias: Vec::new(),
33        }
34    }
35}
36
37impl<B: Backend> SGD<B> {
38    /// Updates parameters with gradients.
39    /// Number of parameters must be the same as number of gradients.
40    /// Gradients can be None, those are simply skipped.
41    pub fn update<'a>(
42        &mut self,
43        parameters: impl IntoIterator<Item = &'a mut Tensor<B>>,
44        gradients: impl IntoIterator<Item = Option<Tensor<B>>>,
45    ) where
46        B: 'a,
47    {
48        let params: Vec<&mut Tensor<B>> = parameters.into_iter().collect();
49        let grads: Vec<Option<Tensor<B>>> = gradients.into_iter().collect();
50
51        assert_eq!(
52            params.len(),
53            grads.len(),
54            "Number of parameters != number of gradients."
55        );
56
57        for (i, (param, grad)) in params.into_iter().zip(grads).enumerate() {
58            if let Some(mut grad) = grad {
59                if self.weight_decay != 0.0 {
60                    grad = grad + &*param * self.weight_decay;
61                }
62                if self.momentum != 0.0 {
63                    if let Some(bias) = self.bias.get_mut(i) {
64                        *bias = &*bias * self.momentum + &grad * (1.0 - self.dampening);
65                    } else {
66                        self.bias.push(grad.clone());
67                    }
68                    if self.nesterov {
69                        grad = grad + &self.bias[i] * self.momentum;
70                    } else {
71                        grad = self.bias[i].clone();
72                    }
73                }
74                if self.maximize {
75                    // Cast since learning_rate is f32, but parameters can have different precision.
76                    // Can this cast be somehow avoided? Is it better to always work with original dtype?
77                    *param = (&*param + grad * self.learning_rate).cast(param.dtype());
78                } else {
79                    *param = (&*param - grad * self.learning_rate).cast(param.dtype());
80                }
81            }
82        }
83    }
84}