1use super::base::{create_param_group, extract_parameters, PyOptimizer};
4use crate::{error::PyResult, tensor::PyTensor};
5use parking_lot::RwLock;
6use pyo3::prelude::*;
7use pyo3::types::{PyAny, PyBool};
8use std::collections::HashMap;
9use std::sync::Arc;
10use torsh_optim::{rmsprop::RMSprop, Optimizer};
11
12#[pyclass(name = "RMSprop", extends = PyOptimizer)]
14pub struct PyRMSprop {
15 rmsprop: RMSprop,
16 param_groups: Vec<HashMap<String, Py<PyAny>>>,
17 lr: f32,
18 alpha: f32,
19 eps: f32,
20 weight_decay: f32,
21 momentum: f32,
22 centered: bool,
23}
24
25#[pymethods]
26impl PyRMSprop {
27 #[new]
28 fn new(
29 params: Vec<PyTensor>,
30 lr: Option<f32>,
31 alpha: Option<f32>,
32 eps: Option<f32>,
33 weight_decay: Option<f32>,
34 momentum: Option<f32>,
35 centered: Option<bool>,
36 ) -> (Self, PyOptimizer) {
37 let lr = lr.unwrap_or(0.01);
38 let alpha = alpha.unwrap_or(0.99);
39 let eps = eps.unwrap_or(1e-8);
40 let weight_decay = weight_decay.unwrap_or(0.0);
41 let momentum = momentum.unwrap_or(0.0);
42 let centered = centered.unwrap_or(false);
43
44 let tensor_params = extract_parameters(params.clone()).unwrap();
46 let wrapped_params: Vec<Arc<RwLock<_>>> = tensor_params
47 .into_iter()
48 .map(|tensor| Arc::new(RwLock::new(tensor)))
49 .collect();
50 let rmsprop = RMSprop::new(
51 wrapped_params,
52 Some(lr),
53 Some(alpha),
54 Some(eps),
55 Some(weight_decay),
56 Some(momentum),
57 centered,
58 );
59
60 let mut param_group_data = HashMap::new();
62 Python::attach(|py| {
63 param_group_data.insert(
64 "alpha".to_string(),
65 alpha.into_pyobject(py).unwrap().into_any().unbind(),
66 );
67 param_group_data.insert(
68 "eps".to_string(),
69 eps.into_pyobject(py).unwrap().into_any().unbind(),
70 );
71 param_group_data.insert(
72 "weight_decay".to_string(),
73 weight_decay.into_pyobject(py).unwrap().into_any().unbind(),
74 );
75 param_group_data.insert(
76 "momentum".to_string(),
77 momentum.into_pyobject(py).unwrap().into_any().unbind(),
78 );
79 param_group_data.insert(
80 "centered".to_string(),
81 PyBool::new(py, centered).to_owned().into(),
82 );
83 });
84
85 let param_groups = vec![create_param_group(params, lr, param_group_data).unwrap()];
86
87 (
88 Self {
89 rmsprop,
90 param_groups,
91 lr,
92 alpha,
93 eps,
94 weight_decay,
95 momentum,
96 centered,
97 },
98 PyOptimizer {},
99 )
100 }
101
102 fn step(&mut self) -> PyResult<()> {
104 self.rmsprop.step().map_err(|e| {
105 PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
106 "Optimizer step failed: {}",
107 e
108 ))
109 })?;
110 Ok(())
111 }
112
113 fn zero_grad(&mut self, set_to_none: Option<bool>) {
115 let _set_to_none = set_to_none.unwrap_or(false);
116 self.rmsprop.zero_grad();
117 }
118
119 fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
121 Python::attach(|py| {
123 let cloned_groups = self
124 .param_groups
125 .iter()
126 .map(|group| {
127 group
128 .iter()
129 .map(|(k, v)| (k.clone(), v.clone_ref(py)))
130 .collect()
131 })
132 .collect();
133 Ok(cloned_groups)
134 })
135 }
136
137 fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
139 let mut state = HashMap::new();
140 Python::attach(|py| {
141 state.insert(
142 "step".to_string(),
143 0i64.into_pyobject(py).unwrap().into_any().unbind(),
144 );
145 state.insert(
146 "square_avg".to_string(),
147 "{}".into_pyobject(py).unwrap().into_any().unbind(),
148 );
149 if self.momentum > 0.0 {
150 state.insert(
151 "momentum_buffer".to_string(),
152 "{}".into_pyobject(py).unwrap().into_any().unbind(),
153 );
154 }
155 if self.centered {
156 state.insert(
157 "grad_avg".to_string(),
158 "{}".into_pyobject(py).unwrap().into_any().unbind(),
159 );
160 }
161 });
162 Ok(state)
163 }
164
165 fn __repr__(&self) -> String {
167 format!(
168 "RMSprop(lr={}, alpha={}, eps={}, weight_decay={}, momentum={}, centered={})",
169 self.lr, self.alpha, self.eps, self.weight_decay, self.momentum, self.centered
170 )
171 }
172
173 fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
175 let mut defaults = HashMap::new();
176 Python::attach(|py| {
177 defaults.insert(
178 "lr".to_string(),
179 self.lr.into_pyobject(py).unwrap().into_any().unbind(),
180 );
181 defaults.insert(
182 "alpha".to_string(),
183 self.alpha.into_pyobject(py).unwrap().into_any().unbind(),
184 );
185 defaults.insert(
186 "eps".to_string(),
187 self.eps.into_pyobject(py).unwrap().into_any().unbind(),
188 );
189 defaults.insert(
190 "weight_decay".to_string(),
191 self.weight_decay
192 .into_pyobject(py)
193 .unwrap()
194 .into_any()
195 .unbind(),
196 );
197 defaults.insert(
198 "momentum".to_string(),
199 self.momentum.into_pyobject(py).unwrap().into_any().unbind(),
200 );
201 defaults.insert(
202 "centered".to_string(),
203 PyBool::new(py, self.centered).to_owned().into(),
204 );
205 });
206 Ok(defaults)
207 }
208
209 #[getter]
211 fn lr(&self) -> f32 {
212 self.lr
213 }
214
215 #[setter]
217 fn set_lr(&mut self, lr: f32) {
218 self.lr = lr;
219 Python::attach(|py| {
220 for param_group in &mut self.param_groups {
221 param_group.insert(
222 "lr".to_string(),
223 lr.into_pyobject(py).unwrap().into_any().unbind(),
224 );
225 }
226 });
227 }
228
229 #[getter]
231 fn alpha(&self) -> f32 {
232 self.alpha
233 }
234
235 #[getter]
237 fn eps(&self) -> f32 {
238 self.eps
239 }
240
241 #[getter]
243 fn weight_decay(&self) -> f32 {
244 self.weight_decay
245 }
246
247 #[getter]
249 fn momentum(&self) -> f32 {
250 self.momentum
251 }
252
253 #[getter]
255 fn centered(&self) -> bool {
256 self.centered
257 }
258}