zyx_optim/sgd.rs

// Copyright (C) 2025 zk4x
// SPDX-License-Identifier: LGPL-3.0-only

use zyx::Tensor;
use zyx_derive::Module;

/// # Stochastic gradient descent optimizer
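///
/// With learning rate `lr`, momentum `mu`, dampening `tau`, and weight decay
/// `lambda`, one call to [`SGD::update`] performs, for each parameter `p`
/// with gradient `g`, the following (a plain-text sketch of the update rule
/// implemented below):
///
/// ```text
/// g <- g + lambda * p                 (weight decay)
/// b <- mu * b + (1 - tau) * g         (momentum buffer; b = g on the first step)
/// g <- g + mu * b   if nesterov, else   g <- b
/// p <- p - lr * g   (or p <- p + lr * g when maximizing)
/// ```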
#[derive(Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct SGD {
    /// learning rate (default: 0.001)
    pub learning_rate: f32,
    /// momentum factor (default: 0.0)
    pub momentum: f32,
    /// weight decay (L2 penalty) (default: 0.0)
    pub weight_decay: f32,
    /// dampening for momentum (default: 0.0)
    pub dampening: f32,
    /// enables Nesterov momentum (default: false)
    pub nesterov: bool,
    /// maximize the objective with respect to the params instead of minimizing it (default: false)
    pub maximize: bool,
    /// momentum buffers; starts empty and is initialized on demand
    pub bias: Vec<Tensor>,
}

impl Default for SGD {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            momentum: 0.0,
            weight_decay: 0.0,
            dampening: 0.0,
            nesterov: false,
            maximize: false,
            bias: Vec::new(),
        }
    }
}
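
// A `SGD::default()` instance is plain gradient descent with learning rate
// 0.001; individual fields can be overridden with struct update syntax,
// e.g. `SGD { learning_rate: 0.01, momentum: 0.9, ..SGD::default() }`.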
impl SGD {
    /// Updates parameters with gradients.
    /// The number of parameters must equal the number of gradients.
    /// A gradient may be `None`; the corresponding parameter is simply skipped.
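    ///
    /// # Example
    ///
    /// A minimal sketch. How the gradients are produced is up to the caller
    /// (e.g. zyx's autograd); the tensor constructor below is illustrative,
    /// not a verified zyx API:
    ///
    /// ```ignore
    /// let mut params: Vec<Tensor> = vec![Tensor::from([1.0f32, 2.0])];
    /// // One Option<Tensor> per parameter, in the same order.
    /// let grads: Vec<Option<Tensor>> = vec![Some(Tensor::from([0.5f32, 0.5]))];
    /// let mut opt = SGD { learning_rate: 0.01, momentum: 0.9, ..SGD::default() };
    /// opt.update(params.iter_mut(), grads);
    /// ```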
    pub fn update<'a>(
        &mut self,
        parameters: impl IntoIterator<Item = &'a mut Tensor>,
        gradients: impl IntoIterator<Item = Option<Tensor>>,
    ) {
        let params: Vec<&mut Tensor> = parameters.into_iter().collect();
        let grads: Vec<Option<Tensor>> = gradients.into_iter().collect();

        assert_eq!(
            params.len(),
            grads.len(),
            "Number of parameters != number of gradients."
        );

        // `bias_i` indexes `self.bias` separately from the parameter index,
        // so that parameters whose gradient is `None` do not shift the
        // momentum buffers of later parameters (indexing by the enumerated
        // parameter position would go out of bounds in that case).
        let mut bias_i = 0;
        for (param, grad) in params.into_iter().zip(grads) {
            if let Some(mut grad) = grad {
                if self.weight_decay != 0.0 {
                    grad = grad + param.clone() * self.weight_decay;
                }
                if self.momentum != 0.0 {
                    if let Some(bias) = self.bias.get_mut(bias_i) {
                        *bias =
                            bias.clone() * self.momentum + grad.clone() * (1.0 - self.dampening);
                    } else {
                        // First update for this parameter: initialize the
                        // buffer with the raw gradient (no dampening applied).
                        self.bias.push(grad.clone());
                    }
                    if self.nesterov {
                        grad = grad + self.bias[bias_i].clone() * self.momentum;
                    } else {
                        grad = self.bias[bias_i].clone();
                    }
                    bias_i += 1;
                }
                if self.maximize {
                    // Cast since learning_rate is f32, but parameters can have different precision.
                    // Can this cast be somehow avoided? Is it better to always work with the original dtype?
                    *param = (&*param + grad * self.learning_rate).cast(param.dtype());
                } else {
                    *param = (&*param - grad * self.learning_rate).cast(param.dtype());
                }
            }
        }
    }
}