Skip to main content

tensorlogic_train/optimizers/
rmsprop.rs

1//! RMSprop optimizer (Root Mean Square Propagation).
2//!
3//! RMSprop divides the learning rate by an exponentially decaying average
4//! of squared gradients. It's effective for non-stationary objectives.
5//!
6//! Reference: Tieleman & Hinton, "Lecture 6.5-rmsprop", COURSERA: Neural networks for machine learning
7
8use super::common::{compute_gradient_norm, GradClipMode, Optimizer, OptimizerConfig};
9use crate::{TrainError, TrainResult};
10use scirs2_core::ndarray::{Array, Ix2};
11use std::collections::HashMap;
12
13/// RMSprop optimizer (Root Mean Square Propagation).
14#[derive(Debug)]
15pub struct RMSpropOptimizer {
16    config: OptimizerConfig,
17    /// Moving average of squared gradients.
18    v: HashMap<String, Array<f64, Ix2>>,
19}
20
21impl RMSpropOptimizer {
22    /// Create a new RMSprop optimizer.
23    pub fn new(config: OptimizerConfig) -> Self {
24        Self {
25            config,
26            v: HashMap::new(),
27        }
28    }
29
30    /// Apply gradient clipping if configured.
31    fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
32        if let Some(clip_value) = self.config.grad_clip {
33            match self.config.grad_clip_mode {
34                GradClipMode::Value => {
35                    for grad in gradients.values_mut() {
36                        grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
37                    }
38                }
39                GradClipMode::Norm => {
40                    let total_norm = compute_gradient_norm(gradients);
41                    if total_norm > clip_value {
42                        let scale = clip_value / total_norm;
43                        for grad in gradients.values_mut() {
44                            grad.mapv_inplace(|g| g * scale);
45                        }
46                    }
47                }
48            }
49        }
50    }
51}
52
53impl Optimizer for RMSpropOptimizer {
54    fn step(
55        &mut self,
56        parameters: &mut HashMap<String, Array<f64, Ix2>>,
57        gradients: &HashMap<String, Array<f64, Ix2>>,
58    ) -> TrainResult<()> {
59        let mut clipped_gradients = gradients.clone();
60        self.clip_gradients(&mut clipped_gradients);
61        let lr = self.config.learning_rate;
62        let alpha = self.config.beta2;
63        let eps = self.config.epsilon;
64        for (name, param) in parameters.iter_mut() {
65            let grad = clipped_gradients.get(name).ok_or_else(|| {
66                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
67            })?;
68            if !self.v.contains_key(name) {
69                self.v.insert(name.clone(), Array::zeros(param.raw_dim()));
70            }
71            let v = self.v.get_mut(name).unwrap();
72            let grad_squared = grad.mapv(|g| g * g);
73            *v = &*v * alpha + &(grad_squared * (1.0 - alpha));
74            let update = grad / &v.mapv(|v_val| v_val.sqrt() + eps);
75            *param = &*param - &(update * lr);
76        }
77        Ok(())
78    }
79
80    fn zero_grad(&mut self) {}
81
82    fn get_lr(&self) -> f64 {
83        self.config.learning_rate
84    }
85
86    fn set_lr(&mut self, lr: f64) {
87        self.config.learning_rate = lr;
88    }
89
90    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
91        let mut state = HashMap::new();
92        for (name, v_val) in &self.v {
93            state.insert(format!("v_{}", name), v_val.iter().copied().collect());
94        }
95        state
96    }
97
98    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
99        for (key, values) in state {
100            if let Some(name) = key.strip_prefix("v_") {
101                if let Some(v) = self.v.get(name) {
102                    let shape = v.raw_dim();
103                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
104                        self.v.insert(name.to_string(), arr);
105                    }
106                }
107            }
108        }
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Config with clipping disabled, so updates follow the plain RMSprop rule.
    fn unclipped_config(learning_rate: f64) -> OptimizerConfig {
        OptimizerConfig {
            learning_rate,
            grad_clip: None,
            ..Default::default()
        }
    }

    #[test]
    fn test_rmsprop_optimizer() {
        let config = unclipped_config(0.01);
        let (lr, alpha, eps) = (config.learning_rate, config.beta2, config.epsilon);
        let mut optimizer = RMSpropOptimizer::new(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1], [0.1, 0.1]]);
        optimizer.step(&mut params, &grads).unwrap();

        // First step from zero state: v = (1 - alpha) * g^2.
        let g: f64 = 0.1;
        let expected_delta = g / ((g * g * (1.0 - alpha)).sqrt() + eps) * lr;
        assert!(expected_delta > 0.0);

        let original = [[1.0, 2.0], [3.0, 4.0]];
        let w = &params["w"];
        for ((i, j), &value) in w.indexed_iter() {
            let expected = original[i][j] - expected_delta;
            assert!(
                (value - expected).abs() < 1e-12,
                "w[{}, {}] = {}, expected {}",
                i,
                j,
                value,
                expected
            );
        }
    }

    #[test]
    fn test_rmsprop_missing_gradient_is_error() {
        let mut optimizer = RMSpropOptimizer::new(unclipped_config(0.01));
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);
        let grads = HashMap::new();
        assert!(optimizer.step(&mut params, &grads).is_err());
    }

    #[test]
    fn test_rmsprop_state_dict_round_trip() {
        let mut optimizer = RMSpropOptimizer::new(unclipped_config(0.01));
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, -0.2], [0.3, -0.4]]);

        optimizer.step(&mut params, &grads).unwrap();
        let saved = optimizer.state_dict();
        assert_eq!(saved["v_w"].len(), 4);

        // Advance the state, then restore the snapshot.
        optimizer.step(&mut params, &grads).unwrap();
        optimizer.load_state_dict(saved.clone());
        assert_eq!(optimizer.state_dict(), saved);
    }
}
132}