//! RMSProp optimizer for `yscv_optim` (yscv_optim/rmsprop.rs).

1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3
4use yscv_autograd::{Graph, NodeId};
5use yscv_tensor::Tensor;
6
7use super::validate::{validate_epsilon, validate_lr, validate_momentum, validate_rmsprop_alpha};
8use super::{LearningRate, OptimError};
9
/// Per-parameter moving-average buffers kept between `RmsProp::step` calls.
#[derive(Debug, Clone)]
struct RmsPropState {
    // Exponential moving average of squared gradients (always updated).
    square_avg: Tensor,
    // Exponential moving average of raw gradients (only read when `centered`).
    grad_avg: Tensor,
    // Velocity buffer (only read when `momentum != 0.0`).
    momentum_buffer: Tensor,
}
16
17impl RmsPropState {
18    fn new(shape: &[usize]) -> Result<Self, OptimError> {
19        Ok(Self {
20            square_avg: Tensor::zeros(shape.to_vec())?,
21            grad_avg: Tensor::zeros(shape.to_vec())?,
22            momentum_buffer: Tensor::zeros(shape.to_vec())?,
23        })
24    }
25
26    fn reset(&mut self, shape: &[usize]) -> Result<(), OptimError> {
27        *self = Self::new(shape)?;
28        Ok(())
29    }
30}
31
/// RMSProp optimizer with optional momentum, weight decay, and centered variance.
#[derive(Debug, Clone)]
pub struct RmsProp {
    // Step size applied to each update.
    lr: f32,
    // Smoothing factor for the squared-gradient moving average.
    alpha: f32,
    // Added to the denominator for numerical stability.
    epsilon: f32,
    // L2 penalty coefficient folded into the gradient.
    weight_decay: f32,
    // Momentum factor; 0.0 disables the momentum buffer entirely.
    momentum: f32,
    // When true, variance is centered by subtracting the squared gradient mean.
    centered: bool,
    // Per-parameter state keyed by the caller-supplied parameter id.
    state: HashMap<u64, RmsPropState>,
}
43
44impl RmsProp {
45    /// Creates RMSProp with required learning rate.
46    pub fn new(lr: f32) -> Result<Self, OptimError> {
47        validate_lr(lr)?;
48        Ok(Self {
49            lr,
50            alpha: 0.99,
51            epsilon: 1e-8,
52            weight_decay: 0.0,
53            momentum: 0.0,
54            centered: false,
55            state: HashMap::new(),
56        })
57    }
58
59    /// Sets RMSProp smoothing factor in `[0, 1)`.
60    pub fn with_alpha(mut self, alpha: f32) -> Result<Self, OptimError> {
61        validate_rmsprop_alpha(alpha)?;
62        self.alpha = alpha;
63        Ok(self)
64    }
65
66    /// Sets epsilon value, must be finite and `> 0`.
67    pub fn with_epsilon(mut self, epsilon: f32) -> Result<Self, OptimError> {
68        validate_epsilon(epsilon)?;
69        self.epsilon = epsilon;
70        Ok(self)
71    }
72
73    /// Sets L2 weight decay factor in `[0, +inf)`.
74    pub fn with_weight_decay(mut self, weight_decay: f32) -> Result<Self, OptimError> {
75        if !weight_decay.is_finite() || weight_decay < 0.0 {
76            return Err(OptimError::InvalidWeightDecay { weight_decay });
77        }
78        self.weight_decay = weight_decay;
79        Ok(self)
80    }
81
82    /// Sets momentum factor in `[0, 1)`.
83    pub fn with_momentum(mut self, momentum: f32) -> Result<Self, OptimError> {
84        validate_momentum(momentum)?;
85        self.momentum = momentum;
86        Ok(self)
87    }
88
89    /// Enables/disables centered RMSProp variant.
90    pub fn with_centered(mut self, centered: bool) -> Self {
91        self.centered = centered;
92        self
93    }
94
95    /// Drops optimizer state (for example when restarting training).
96    pub fn clear_state(&mut self) {
97        self.state.clear();
98    }
99
100    /// Returns current learning rate.
101    pub fn learning_rate(&self) -> f32 {
102        self.lr
103    }
104
105    /// Overrides current learning rate.
106    pub fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
107        validate_lr(lr)?;
108        self.lr = lr;
109        Ok(())
110    }
111
112    /// Applies one update to raw tensor weights.
113    pub fn step(
114        &mut self,
115        parameter_id: u64,
116        weights: &mut Tensor,
117        grad: &Tensor,
118    ) -> Result<(), OptimError> {
119        if weights.shape() != grad.shape() {
120            return Err(OptimError::ShapeMismatch {
121                weights: weights.shape().to_vec(),
122                grad: grad.shape().to_vec(),
123            });
124        }
125
126        let state = match self.state.entry(parameter_id) {
127            Entry::Occupied(entry) => entry.into_mut(),
128            Entry::Vacant(entry) => entry.insert(RmsPropState::new(weights.shape())?),
129        };
130        if state.square_avg.shape() != weights.shape() {
131            state.reset(weights.shape())?;
132        }
133
134        let grad_values = grad.data();
135        let weights_data = weights.data_mut();
136        let square_avg = state.square_avg.data_mut();
137        let grad_avg = state.grad_avg.data_mut();
138        let momentum_buffer = state.momentum_buffer.data_mut();
139
140        let alpha = self.alpha;
141        let one_minus_alpha = 1.0 - self.alpha;
142
143        for index in 0..weights_data.len() {
144            let grad_value = grad_values[index] + self.weight_decay * weights_data[index];
145            square_avg[index] =
146                alpha * square_avg[index] + one_minus_alpha * grad_value * grad_value;
147
148            let avg = if self.centered {
149                grad_avg[index] = alpha * grad_avg[index] + one_minus_alpha * grad_value;
150                (square_avg[index] - grad_avg[index] * grad_avg[index]).max(0.0)
151            } else {
152                square_avg[index]
153            };
154
155            let denom = avg.sqrt() + self.epsilon;
156            let normalized = grad_value / denom;
157            let update = if self.momentum != 0.0 {
158                let next = self.momentum * momentum_buffer[index] + normalized;
159                momentum_buffer[index] = next;
160                next
161            } else {
162                normalized
163            };
164            weights_data[index] -= self.lr * update;
165        }
166
167        Ok(())
168    }
169
170    /// Applies one update to a trainable graph node by its `NodeId`.
171    pub fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), OptimError> {
172        if !graph.requires_grad(node)? {
173            return Ok(());
174        }
175
176        let grad = match graph.grad(node)? {
177            Some(grad) => grad.clone(),
178            None => return Err(OptimError::MissingGradient { node: node.0 }),
179        };
180        let weights = graph.value_mut(node)?;
181        self.step(node.0 as u64, weights, &grad)
182    }
183}
184
185impl LearningRate for RmsProp {
186    fn learning_rate(&self) -> f32 {
187        RmsProp::learning_rate(self)
188    }
189
190    fn set_learning_rate(&mut self, lr: f32) -> Result<(), OptimError> {
191        RmsProp::set_learning_rate(self, lr)
192    }
193}