torsh_python/optim/
rmsprop.rs

1//! RMSprop optimizer
2
3use super::base::{create_param_group, extract_parameters, PyOptimizer};
4use crate::{error::PyResult, tensor::PyTensor};
5use parking_lot::RwLock;
6use pyo3::prelude::*;
7use pyo3::types::{PyAny, PyBool};
8use std::collections::HashMap;
9use std::sync::Arc;
10use torsh_optim::{rmsprop::RMSprop, Optimizer};
11
12/// RMSprop optimizer - Root Mean Square Propagation
13#[pyclass(name = "RMSprop", extends = PyOptimizer)]
14pub struct PyRMSprop {
15    rmsprop: RMSprop,
16    param_groups: Vec<HashMap<String, Py<PyAny>>>,
17    lr: f32,
18    alpha: f32,
19    eps: f32,
20    weight_decay: f32,
21    momentum: f32,
22    centered: bool,
23}
24
25#[pymethods]
26impl PyRMSprop {
27    #[new]
28    fn new(
29        params: Vec<PyTensor>,
30        lr: Option<f32>,
31        alpha: Option<f32>,
32        eps: Option<f32>,
33        weight_decay: Option<f32>,
34        momentum: Option<f32>,
35        centered: Option<bool>,
36    ) -> (Self, PyOptimizer) {
37        let lr = lr.unwrap_or(0.01);
38        let alpha = alpha.unwrap_or(0.99);
39        let eps = eps.unwrap_or(1e-8);
40        let weight_decay = weight_decay.unwrap_or(0.0);
41        let momentum = momentum.unwrap_or(0.0);
42        let centered = centered.unwrap_or(false);
43
44        // Extract tensor parameters and wrap in Arc<RwLock>
45        let tensor_params = extract_parameters(params.clone()).unwrap();
46        let wrapped_params: Vec<Arc<RwLock<_>>> = tensor_params
47            .into_iter()
48            .map(|tensor| Arc::new(RwLock::new(tensor)))
49            .collect();
50        let rmsprop = RMSprop::new(
51            wrapped_params,
52            Some(lr),
53            Some(alpha),
54            Some(eps),
55            Some(weight_decay),
56            Some(momentum),
57            centered,
58        );
59
60        // Create parameter groups
61        let mut param_group_data = HashMap::new();
62        Python::attach(|py| {
63            param_group_data.insert(
64                "alpha".to_string(),
65                alpha.into_pyobject(py).unwrap().into_any().unbind(),
66            );
67            param_group_data.insert(
68                "eps".to_string(),
69                eps.into_pyobject(py).unwrap().into_any().unbind(),
70            );
71            param_group_data.insert(
72                "weight_decay".to_string(),
73                weight_decay.into_pyobject(py).unwrap().into_any().unbind(),
74            );
75            param_group_data.insert(
76                "momentum".to_string(),
77                momentum.into_pyobject(py).unwrap().into_any().unbind(),
78            );
79            param_group_data.insert(
80                "centered".to_string(),
81                PyBool::new(py, centered).to_owned().into(),
82            );
83        });
84
85        let param_groups = vec![create_param_group(params, lr, param_group_data).unwrap()];
86
87        (
88            Self {
89                rmsprop,
90                param_groups,
91                lr,
92                alpha,
93                eps,
94                weight_decay,
95                momentum,
96                centered,
97            },
98            PyOptimizer {},
99        )
100    }
101
102    /// Perform a single optimization step
103    fn step(&mut self) -> PyResult<()> {
104        self.rmsprop.step().map_err(|e| {
105            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
106                "Optimizer step failed: {}",
107                e
108            ))
109        })?;
110        Ok(())
111    }
112
113    /// Zero out gradients of all parameters
114    fn zero_grad(&mut self, set_to_none: Option<bool>) {
115        let _set_to_none = set_to_none.unwrap_or(false);
116        self.rmsprop.zero_grad();
117    }
118
119    /// Get parameter groups
120    fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
121        // Manual clone since Py<PyAny> doesn't implement Clone
122        Python::attach(|py| {
123            let cloned_groups = self
124                .param_groups
125                .iter()
126                .map(|group| {
127                    group
128                        .iter()
129                        .map(|(k, v)| (k.clone(), v.clone_ref(py)))
130                        .collect()
131                })
132                .collect();
133            Ok(cloned_groups)
134        })
135    }
136
137    /// Get current state
138    fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
139        let mut state = HashMap::new();
140        Python::attach(|py| {
141            state.insert(
142                "step".to_string(),
143                0i64.into_pyobject(py).unwrap().into_any().unbind(),
144            );
145            state.insert(
146                "square_avg".to_string(),
147                "{}".into_pyobject(py).unwrap().into_any().unbind(),
148            );
149            if self.momentum > 0.0 {
150                state.insert(
151                    "momentum_buffer".to_string(),
152                    "{}".into_pyobject(py).unwrap().into_any().unbind(),
153                );
154            }
155            if self.centered {
156                state.insert(
157                    "grad_avg".to_string(),
158                    "{}".into_pyobject(py).unwrap().into_any().unbind(),
159                );
160            }
161        });
162        Ok(state)
163    }
164
165    /// String representation
166    fn __repr__(&self) -> String {
167        format!(
168            "RMSprop(lr={}, alpha={}, eps={}, weight_decay={}, momentum={}, centered={})",
169            self.lr, self.alpha, self.eps, self.weight_decay, self.momentum, self.centered
170        )
171    }
172
173    /// Get defaults
174    fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
175        let mut defaults = HashMap::new();
176        Python::attach(|py| {
177            defaults.insert(
178                "lr".to_string(),
179                self.lr.into_pyobject(py).unwrap().into_any().unbind(),
180            );
181            defaults.insert(
182                "alpha".to_string(),
183                self.alpha.into_pyobject(py).unwrap().into_any().unbind(),
184            );
185            defaults.insert(
186                "eps".to_string(),
187                self.eps.into_pyobject(py).unwrap().into_any().unbind(),
188            );
189            defaults.insert(
190                "weight_decay".to_string(),
191                self.weight_decay
192                    .into_pyobject(py)
193                    .unwrap()
194                    .into_any()
195                    .unbind(),
196            );
197            defaults.insert(
198                "momentum".to_string(),
199                self.momentum.into_pyobject(py).unwrap().into_any().unbind(),
200            );
201            defaults.insert(
202                "centered".to_string(),
203                PyBool::new(py, self.centered).to_owned().into(),
204            );
205        });
206        Ok(defaults)
207    }
208
209    /// Get learning rate
210    #[getter]
211    fn lr(&self) -> f32 {
212        self.lr
213    }
214
215    /// Set learning rate
216    #[setter]
217    fn set_lr(&mut self, lr: f32) {
218        self.lr = lr;
219        Python::attach(|py| {
220            for param_group in &mut self.param_groups {
221                param_group.insert(
222                    "lr".to_string(),
223                    lr.into_pyobject(py).unwrap().into_any().unbind(),
224                );
225            }
226        });
227    }
228
229    /// Get alpha (smoothing constant)
230    #[getter]
231    fn alpha(&self) -> f32 {
232        self.alpha
233    }
234
235    /// Get epsilon
236    #[getter]
237    fn eps(&self) -> f32 {
238        self.eps
239    }
240
241    /// Get weight decay
242    #[getter]
243    fn weight_decay(&self) -> f32 {
244        self.weight_decay
245    }
246
247    /// Get momentum
248    #[getter]
249    fn momentum(&self) -> f32 {
250        self.momentum
251    }
252
253    /// Get centered flag
254    #[getter]
255    fn centered(&self) -> bool {
256        self.centered
257    }
258}