
zyx_optim/rmsprop.rs

// Copyright (C) 2025 zk4x
// SPDX-License-Identifier: LGPL-3.0-only

// This can live in your own code base or inside the zyx_optim crate.
use zyx::Tensor;
use zyx_derive::Module;

/// RMSProp optimizer for adaptive learning rate training.
#[derive(Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct RMSprop {
    /// Step size multiplier
    pub learning_rate: f32,
    /// Smoothing constant; controls how quickly the squared-gradient cache forgets old gradients
    pub alpha: f32,
    /// Small constant added to the denominator to avoid division by zero
    pub eps: f32,
    /// Momentum factor (0.0 disables momentum)
    pub momentum: f32,
    /// If true, use the centered variant, which normalizes by an estimate of the gradient variance
    pub centered: bool,
    /// Weight decay factor (0.0 disables weight decay)
    pub weight_decay: f32,
    /// Step counter, incremented once per call to `update`
    pub t: usize,
    /// Exponential moving average of squared gradients, one entry per parameter
    buffer: Vec<Tensor>,
    /// Momentum buffer, one entry per parameter
    momentum_buf: Vec<Tensor>,
    /// Exponential moving average of gradients, used by the centered variant
    grad_avg: Vec<Tensor>,
}

impl Default for RMSprop {
    fn default() -> Self {
        Self {
            learning_rate: 0.01,
            alpha: 0.99,
            eps: 1e-8,
            momentum: 0.0,
            centered: false,
            weight_decay: 0.0,
            t: 0,
            buffer: Vec::new(),
            momentum_buf: Vec::new(),
            grad_avg: Vec::new(),
        }
    }
}
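
// A minimal construction sketch: the hyperparameters are plain public fields,
// so start from `Default` and override whichever values you need. The values
// below and the `example_rmsprop` helper name are purely illustrative.
#[allow(dead_code)]
fn example_rmsprop() -> RMSprop {
    let mut opt = RMSprop::default();
    opt.learning_rate = 1e-3; // smaller step size than the 0.01 default
    opt.momentum = 0.9; // enable momentum on the preconditioned update
    opt.centered = true; // use the centered variant
    opt
}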

impl RMSprop {
    /// Applies one RMSProp step to `parameters` using `gradients`.
    ///
    /// `parameters` and `gradients` must be in the same order, one gradient per
    /// parameter. A gradient of `None` leaves the corresponding parameter unchanged.
    pub fn update<'a>(
        &mut self,
        parameters: impl IntoIterator<Item = &'a mut Tensor>,
        gradients: impl IntoIterator<Item = Option<Tensor>>,
    ) {
        for (i, (param, grad)) in parameters.into_iter().zip(gradients).enumerate() {
            let Some(grad) = grad else {
                // No gradient this step: lazily initialize state for new parameters and skip
                if self.buffer.len() <= i {
                    self.buffer.push(Tensor::zeros_like(&*param));
                    self.momentum_buf.push(Tensor::zeros_like(&*param));
                    if self.centered {
                        self.grad_avg.push(Tensor::zeros_like(&*param));
                    }
                }
                continue;
            };

            // Lazy init state if missing; start the averages at zero so the EMA update
            // below contributes exactly (1 - alpha) * grad^2 on the first step instead
            // of being applied twice.
            if self.buffer.len() <= i {
                self.buffer.push(Tensor::zeros_like(&*param));
                self.momentum_buf.push(Tensor::zeros_like(&*param));
                if self.centered {
                    self.grad_avg.push(Tensor::zeros_like(&*param));
                }
            }

            // Exponential moving average of squared gradients
            self.buffer[i] = &self.buffer[i] * self.alpha + &grad * &grad * (1.0 - self.alpha);

            let denom = if self.centered {
                // Centered RMSProp: also track the moving average of the gradient and
                // use the variance estimate E[g^2] - E[g]^2 in the denominator
                self.grad_avg[i] = &self.grad_avg[i] * self.alpha + &grad * (1.0 - self.alpha);
                let avg = &self.grad_avg[i];
                // relu clamps tiny negative values caused by floating point error
                (&self.buffer[i] - avg * avg).relu().sqrt() + self.eps
            } else {
                self.buffer[i].sqrt() + self.eps
            };

            let update = &grad / denom * self.learning_rate;

            if self.momentum > 0.0 {
                // Classical momentum applied to the preconditioned update
                self.momentum_buf[i] = &self.momentum_buf[i] * self.momentum + &update;
                *param = &*param - &self.momentum_buf[i];
            } else {
                *param = &*param - update;
            }

            if self.weight_decay > 0.0 {
                // Decoupled weight decay: shrink the parameter directly instead of
                // adding weight_decay * param to the gradient
                *param = &*param * (1.0 - self.learning_rate * self.weight_decay);
            }
        }

        self.t += 1;
    }
}
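
// A hypothetical training-step sketch: only the `update` signature above is taken
// as given. `training_step`, `params`, and `grads` are illustrative names; how the
// gradients (`Vec<Option<Tensor>>`, one entry per parameter, in the same order)
// are produced depends on your model code.
#[allow(dead_code)]
fn training_step(params: &mut [Tensor], grads: Vec<Option<Tensor>>, opt: &mut RMSprop) {
    // Parameters whose gradient is `None` keep their value this step;
    // `update` only lazily initializes their optimizer state.
    opt.update(params.iter_mut(), grads);
}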