1use super::base::{create_param_group, extract_parameters, PyOptimizer};
4use crate::{error::PyResult, tensor::PyTensor};
5use parking_lot::RwLock;
6use pyo3::prelude::*;
7use pyo3::types::{PyAny, PyBool};
8use std::collections::HashMap;
9use std::sync::Arc;
10use torsh_optim::{rmsprop::RMSprop, Optimizer};
11
/// Python-facing RMSprop optimizer.
///
/// Exposed to Python under the class name `RMSprop` and declared as a subclass
/// of `PyOptimizer`. The hyperparameters are stored twice: once inside the
/// wrapped `RMSprop` (which receives them at construction) and once as plain
/// fields here so they can be reported back without querying the wrapped value.
#[pyclass(name = "RMSprop", extends = PyOptimizer)]
pub struct PyRMSprop {
    /// Wrapped `torsh_optim` optimizer that `step` and `zero_grad` delegate to.
    rmsprop: RMSprop,
    /// Parameter-group dictionaries handed back (as copies) by `param_groups`.
    param_groups: Vec<HashMap<String, Py<PyAny>>>,
    /// Cached learning rate reported by the `lr` getter, `defaults` and `__repr__`.
    lr: f32,
    /// Cached `alpha` hyperparameter as passed to `RMSprop::new`.
    alpha: f32,
    /// Cached `eps` hyperparameter as passed to `RMSprop::new`.
    eps: f32,
    /// Cached `weight_decay` hyperparameter as passed to `RMSprop::new`.
    weight_decay: f32,
    /// Cached `momentum` hyperparameter; `state` adds a `"momentum_buffer"` key
    /// only when this is greater than zero.
    momentum: f32,
    /// Cached `centered` flag; `state` adds a `"grad_avg"` key only when set.
    centered: bool,
}
24
25#[pymethods]
26impl PyRMSprop {
27 #[new]
28 fn new(
29 params: Vec<PyTensor>,
30 lr: Option<f32>,
31 alpha: Option<f32>,
32 eps: Option<f32>,
33 weight_decay: Option<f32>,
34 momentum: Option<f32>,
35 centered: Option<bool>,
36 ) -> (Self, PyOptimizer) {
37 let lr = lr.unwrap_or(0.01);
38 let alpha = alpha.unwrap_or(0.99);
39 let eps = eps.unwrap_or(1e-8);
40 let weight_decay = weight_decay.unwrap_or(0.0);
41 let momentum = momentum.unwrap_or(0.0);
42 let centered = centered.unwrap_or(false);
43
44 let tensor_params =
46 extract_parameters(params.clone()).expect("parameter extraction should succeed");
47 let wrapped_params: Vec<Arc<RwLock<_>>> = tensor_params
48 .into_iter()
49 .map(|tensor| Arc::new(RwLock::new(tensor)))
50 .collect();
51 let rmsprop = RMSprop::new(
52 wrapped_params,
53 Some(lr),
54 Some(alpha),
55 Some(eps),
56 Some(weight_decay),
57 Some(momentum),
58 centered,
59 );
60
61 let mut param_group_data = HashMap::new();
63 Python::attach(|py| {
64 param_group_data.insert(
65 "alpha".to_string(),
66 alpha
67 .into_pyobject(py)
68 .expect("Python object conversion should succeed")
69 .into_any()
70 .unbind(),
71 );
72 param_group_data.insert(
73 "eps".to_string(),
74 eps.into_pyobject(py)
75 .expect("Python object conversion should succeed")
76 .into_any()
77 .unbind(),
78 );
79 param_group_data.insert(
80 "weight_decay".to_string(),
81 weight_decay
82 .into_pyobject(py)
83 .expect("Python object conversion should succeed")
84 .into_any()
85 .unbind(),
86 );
87 param_group_data.insert(
88 "momentum".to_string(),
89 momentum
90 .into_pyobject(py)
91 .expect("Python object conversion should succeed")
92 .into_any()
93 .unbind(),
94 );
95 param_group_data.insert(
96 "centered".to_string(),
97 PyBool::new(py, centered).to_owned().into(),
98 );
99 });
100
101 let param_groups = vec![create_param_group(params, lr, param_group_data)
102 .expect("param group creation should succeed")];
103
104 (
105 Self {
106 rmsprop,
107 param_groups,
108 lr,
109 alpha,
110 eps,
111 weight_decay,
112 momentum,
113 centered,
114 },
115 PyOptimizer {},
116 )
117 }
118
119 fn step(&mut self) -> PyResult<()> {
121 self.rmsprop.step().map_err(|e| {
122 PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
123 "Optimizer step failed: {}",
124 e
125 ))
126 })?;
127 Ok(())
128 }
129
130 fn zero_grad(&mut self, set_to_none: Option<bool>) {
132 let _set_to_none = set_to_none.unwrap_or(false);
133 self.rmsprop.zero_grad();
134 }
135
136 fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
138 Python::attach(|py| {
140 let cloned_groups = self
141 .param_groups
142 .iter()
143 .map(|group| {
144 group
145 .iter()
146 .map(|(k, v)| (k.clone(), v.clone_ref(py)))
147 .collect()
148 })
149 .collect();
150 Ok(cloned_groups)
151 })
152 }
153
154 fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
156 let mut state = HashMap::new();
157 Python::attach(|py| {
158 state.insert(
159 "step".to_string(),
160 0i64.into_pyobject(py)
161 .expect("Python object conversion should succeed")
162 .into_any()
163 .unbind(),
164 );
165 state.insert(
166 "square_avg".to_string(),
167 "{}".into_pyobject(py)
168 .expect("Python object conversion should succeed")
169 .into_any()
170 .unbind(),
171 );
172 if self.momentum > 0.0 {
173 state.insert(
174 "momentum_buffer".to_string(),
175 "{}".into_pyobject(py)
176 .expect("Python object conversion should succeed")
177 .into_any()
178 .unbind(),
179 );
180 }
181 if self.centered {
182 state.insert(
183 "grad_avg".to_string(),
184 "{}".into_pyobject(py)
185 .expect("Python object conversion should succeed")
186 .into_any()
187 .unbind(),
188 );
189 }
190 });
191 Ok(state)
192 }
193
194 fn __repr__(&self) -> String {
196 format!(
197 "RMSprop(lr={}, alpha={}, eps={}, weight_decay={}, momentum={}, centered={})",
198 self.lr, self.alpha, self.eps, self.weight_decay, self.momentum, self.centered
199 )
200 }
201
202 fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
204 let mut defaults = HashMap::new();
205 Python::attach(|py| {
206 defaults.insert(
207 "lr".to_string(),
208 self.lr
209 .into_pyobject(py)
210 .expect("Python object conversion should succeed")
211 .into_any()
212 .unbind(),
213 );
214 defaults.insert(
215 "alpha".to_string(),
216 self.alpha
217 .into_pyobject(py)
218 .expect("Python object conversion should succeed")
219 .into_any()
220 .unbind(),
221 );
222 defaults.insert(
223 "eps".to_string(),
224 self.eps
225 .into_pyobject(py)
226 .expect("Python object conversion should succeed")
227 .into_any()
228 .unbind(),
229 );
230 defaults.insert(
231 "weight_decay".to_string(),
232 self.weight_decay
233 .into_pyobject(py)
234 .expect("Python object conversion should succeed")
235 .into_any()
236 .unbind(),
237 );
238 defaults.insert(
239 "momentum".to_string(),
240 self.momentum
241 .into_pyobject(py)
242 .expect("Python object conversion should succeed")
243 .into_any()
244 .unbind(),
245 );
246 defaults.insert(
247 "centered".to_string(),
248 PyBool::new(py, self.centered).to_owned().into(),
249 );
250 });
251 Ok(defaults)
252 }
253
    /// Learning rate as cached on this wrapper (the `lr` field).
    #[getter]
    fn lr(&self) -> f32 {
        self.lr
    }
259
260 #[setter]
262 fn set_lr(&mut self, lr: f32) {
263 self.lr = lr;
264 Python::attach(|py| {
265 for param_group in &mut self.param_groups {
266 param_group.insert(
267 "lr".to_string(),
268 lr.into_pyobject(py)
269 .expect("Python object conversion should succeed")
270 .into_any()
271 .unbind(),
272 );
273 }
274 });
275 }
276
    /// `alpha` hyperparameter as cached at construction.
    #[getter]
    fn alpha(&self) -> f32 {
        self.alpha
    }
282
    /// `eps` hyperparameter as cached at construction.
    #[getter]
    fn eps(&self) -> f32 {
        self.eps
    }
288
    /// `weight_decay` hyperparameter as cached at construction.
    #[getter]
    fn weight_decay(&self) -> f32 {
        self.weight_decay
    }
294
    /// `momentum` hyperparameter as cached at construction.
    #[getter]
    fn momentum(&self) -> f32 {
        self.momentum
    }
300
    /// `centered` flag as cached at construction.
    #[getter]
    fn centered(&self) -> bool {
        self.centered
    }
306}