Skip to main content

torsh_python/optim/
rmsprop.rs

1//! RMSprop optimizer
2
3use super::base::{create_param_group, extract_parameters, PyOptimizer};
4use crate::{error::PyResult, tensor::PyTensor};
5use parking_lot::RwLock;
6use pyo3::prelude::*;
7use pyo3::types::{PyAny, PyBool};
8use std::collections::HashMap;
9use std::sync::Arc;
10use torsh_optim::{rmsprop::RMSprop, Optimizer};
11
12/// RMSprop optimizer - Root Mean Square Propagation
13#[pyclass(name = "RMSprop", extends = PyOptimizer)]
14pub struct PyRMSprop {
15    rmsprop: RMSprop,
16    param_groups: Vec<HashMap<String, Py<PyAny>>>,
17    lr: f32,
18    alpha: f32,
19    eps: f32,
20    weight_decay: f32,
21    momentum: f32,
22    centered: bool,
23}
24
25#[pymethods]
26impl PyRMSprop {
27    #[new]
28    fn new(
29        params: Vec<PyTensor>,
30        lr: Option<f32>,
31        alpha: Option<f32>,
32        eps: Option<f32>,
33        weight_decay: Option<f32>,
34        momentum: Option<f32>,
35        centered: Option<bool>,
36    ) -> (Self, PyOptimizer) {
37        let lr = lr.unwrap_or(0.01);
38        let alpha = alpha.unwrap_or(0.99);
39        let eps = eps.unwrap_or(1e-8);
40        let weight_decay = weight_decay.unwrap_or(0.0);
41        let momentum = momentum.unwrap_or(0.0);
42        let centered = centered.unwrap_or(false);
43
44        // Extract tensor parameters and wrap in Arc<RwLock>
45        let tensor_params =
46            extract_parameters(params.clone()).expect("parameter extraction should succeed");
47        let wrapped_params: Vec<Arc<RwLock<_>>> = tensor_params
48            .into_iter()
49            .map(|tensor| Arc::new(RwLock::new(tensor)))
50            .collect();
51        let rmsprop = RMSprop::new(
52            wrapped_params,
53            Some(lr),
54            Some(alpha),
55            Some(eps),
56            Some(weight_decay),
57            Some(momentum),
58            centered,
59        );
60
61        // Create parameter groups
62        let mut param_group_data = HashMap::new();
63        Python::attach(|py| {
64            param_group_data.insert(
65                "alpha".to_string(),
66                alpha
67                    .into_pyobject(py)
68                    .expect("Python object conversion should succeed")
69                    .into_any()
70                    .unbind(),
71            );
72            param_group_data.insert(
73                "eps".to_string(),
74                eps.into_pyobject(py)
75                    .expect("Python object conversion should succeed")
76                    .into_any()
77                    .unbind(),
78            );
79            param_group_data.insert(
80                "weight_decay".to_string(),
81                weight_decay
82                    .into_pyobject(py)
83                    .expect("Python object conversion should succeed")
84                    .into_any()
85                    .unbind(),
86            );
87            param_group_data.insert(
88                "momentum".to_string(),
89                momentum
90                    .into_pyobject(py)
91                    .expect("Python object conversion should succeed")
92                    .into_any()
93                    .unbind(),
94            );
95            param_group_data.insert(
96                "centered".to_string(),
97                PyBool::new(py, centered).to_owned().into(),
98            );
99        });
100
101        let param_groups = vec![create_param_group(params, lr, param_group_data)
102            .expect("param group creation should succeed")];
103
104        (
105            Self {
106                rmsprop,
107                param_groups,
108                lr,
109                alpha,
110                eps,
111                weight_decay,
112                momentum,
113                centered,
114            },
115            PyOptimizer {},
116        )
117    }
118
119    /// Perform a single optimization step
120    fn step(&mut self) -> PyResult<()> {
121        self.rmsprop.step().map_err(|e| {
122            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
123                "Optimizer step failed: {}",
124                e
125            ))
126        })?;
127        Ok(())
128    }
129
130    /// Zero out gradients of all parameters
131    fn zero_grad(&mut self, set_to_none: Option<bool>) {
132        let _set_to_none = set_to_none.unwrap_or(false);
133        self.rmsprop.zero_grad();
134    }
135
136    /// Get parameter groups
137    fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
138        // Manual clone since Py<PyAny> doesn't implement Clone
139        Python::attach(|py| {
140            let cloned_groups = self
141                .param_groups
142                .iter()
143                .map(|group| {
144                    group
145                        .iter()
146                        .map(|(k, v)| (k.clone(), v.clone_ref(py)))
147                        .collect()
148                })
149                .collect();
150            Ok(cloned_groups)
151        })
152    }
153
154    /// Get current state
155    fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
156        let mut state = HashMap::new();
157        Python::attach(|py| {
158            state.insert(
159                "step".to_string(),
160                0i64.into_pyobject(py)
161                    .expect("Python object conversion should succeed")
162                    .into_any()
163                    .unbind(),
164            );
165            state.insert(
166                "square_avg".to_string(),
167                "{}".into_pyobject(py)
168                    .expect("Python object conversion should succeed")
169                    .into_any()
170                    .unbind(),
171            );
172            if self.momentum > 0.0 {
173                state.insert(
174                    "momentum_buffer".to_string(),
175                    "{}".into_pyobject(py)
176                        .expect("Python object conversion should succeed")
177                        .into_any()
178                        .unbind(),
179                );
180            }
181            if self.centered {
182                state.insert(
183                    "grad_avg".to_string(),
184                    "{}".into_pyobject(py)
185                        .expect("Python object conversion should succeed")
186                        .into_any()
187                        .unbind(),
188                );
189            }
190        });
191        Ok(state)
192    }
193
194    /// String representation
195    fn __repr__(&self) -> String {
196        format!(
197            "RMSprop(lr={}, alpha={}, eps={}, weight_decay={}, momentum={}, centered={})",
198            self.lr, self.alpha, self.eps, self.weight_decay, self.momentum, self.centered
199        )
200    }
201
202    /// Get defaults
203    fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
204        let mut defaults = HashMap::new();
205        Python::attach(|py| {
206            defaults.insert(
207                "lr".to_string(),
208                self.lr
209                    .into_pyobject(py)
210                    .expect("Python object conversion should succeed")
211                    .into_any()
212                    .unbind(),
213            );
214            defaults.insert(
215                "alpha".to_string(),
216                self.alpha
217                    .into_pyobject(py)
218                    .expect("Python object conversion should succeed")
219                    .into_any()
220                    .unbind(),
221            );
222            defaults.insert(
223                "eps".to_string(),
224                self.eps
225                    .into_pyobject(py)
226                    .expect("Python object conversion should succeed")
227                    .into_any()
228                    .unbind(),
229            );
230            defaults.insert(
231                "weight_decay".to_string(),
232                self.weight_decay
233                    .into_pyobject(py)
234                    .expect("Python object conversion should succeed")
235                    .into_any()
236                    .unbind(),
237            );
238            defaults.insert(
239                "momentum".to_string(),
240                self.momentum
241                    .into_pyobject(py)
242                    .expect("Python object conversion should succeed")
243                    .into_any()
244                    .unbind(),
245            );
246            defaults.insert(
247                "centered".to_string(),
248                PyBool::new(py, self.centered).to_owned().into(),
249            );
250        });
251        Ok(defaults)
252    }
253
254    /// Get learning rate
255    #[getter]
256    fn lr(&self) -> f32 {
257        self.lr
258    }
259
260    /// Set learning rate
261    #[setter]
262    fn set_lr(&mut self, lr: f32) {
263        self.lr = lr;
264        Python::attach(|py| {
265            for param_group in &mut self.param_groups {
266                param_group.insert(
267                    "lr".to_string(),
268                    lr.into_pyobject(py)
269                        .expect("Python object conversion should succeed")
270                        .into_any()
271                        .unbind(),
272                );
273            }
274        });
275    }
276
277    /// Get alpha (smoothing constant)
278    #[getter]
279    fn alpha(&self) -> f32 {
280        self.alpha
281    }
282
283    /// Get epsilon
284    #[getter]
285    fn eps(&self) -> f32 {
286        self.eps
287    }
288
289    /// Get weight decay
290    #[getter]
291    fn weight_decay(&self) -> f32 {
292        self.weight_decay
293    }
294
295    /// Get momentum
296    #[getter]
297    fn momentum(&self) -> f32 {
298        self.momentum
299    }
300
301    /// Get centered flag
302    #[getter]
303    fn centered(&self) -> bool {
304        self.centered
305    }
306}