1use super::base::{create_param_group, PyOptimizer};
4use crate::{error::PyResult, py_result, tensor::PyTensor};
5use pyo3::prelude::*;
6use pyo3::types::{PyAny, PyBool};
7use std::collections::HashMap;
8use torsh_tensor::Tensor;
9
10#[pyclass(name = "SGD", extends = PyOptimizer)]
12pub struct PySGD {
13 parameters: Vec<Tensor<f32>>,
14 momentum_buffers: Vec<Option<Tensor<f32>>>,
15 param_groups: Vec<HashMap<String, Py<PyAny>>>,
16 lr: f32,
17 momentum: f32,
18 dampening: f32,
19 weight_decay: f32,
20 nesterov: bool,
21}
22
23#[pymethods]
24impl PySGD {
25 #[new]
26 fn new(
27 params: Vec<PyTensor>,
28 lr: f32,
29 momentum: Option<f32>,
30 dampening: Option<f32>,
31 weight_decay: Option<f32>,
32 nesterov: Option<bool>,
33 ) -> PyResult<(Self, PyOptimizer)> {
34 let momentum = momentum.unwrap_or(0.0);
35 let dampening = dampening.unwrap_or(0.0);
36 let weight_decay = weight_decay.unwrap_or(0.0);
37 let nesterov = nesterov.unwrap_or(false);
38
39 let parameters: Vec<Tensor<f32>> = params.iter().map(|p| p.tensor.clone()).collect();
41 let momentum_buffers = vec![None; parameters.len()];
42
43 let mut param_group_data = HashMap::new();
45 Python::attach(|py| {
46 param_group_data.insert(
47 "momentum".to_string(),
48 momentum
49 .into_pyobject(py)
50 .expect("Python object conversion should succeed")
51 .into_any()
52 .unbind(),
53 );
54 param_group_data.insert(
55 "dampening".to_string(),
56 dampening
57 .into_pyobject(py)
58 .expect("Python object conversion should succeed")
59 .into_any()
60 .unbind(),
61 );
62 param_group_data.insert(
63 "weight_decay".to_string(),
64 weight_decay
65 .into_pyobject(py)
66 .expect("Python object conversion should succeed")
67 .into_any()
68 .unbind(),
69 );
70 param_group_data.insert(
71 "nesterov".to_string(),
72 PyBool::new(py, nesterov).to_owned().into(),
73 );
74 });
75
76 let param_groups = vec![create_param_group(params, lr, param_group_data)?];
77
78 Ok((
79 Self {
80 parameters,
81 momentum_buffers,
82 param_groups,
83 lr,
84 momentum,
85 dampening,
86 weight_decay,
87 nesterov,
88 },
89 PyOptimizer {},
90 ))
91 }
92
93 fn step(&mut self) -> PyResult<()> {
95 for (i, param) in self.parameters.iter_mut().enumerate() {
96 if let Some(grad) = param.grad() {
97 let mut d_p = grad.clone();
98
99 if self.weight_decay != 0.0 {
101 let weight_decay_term = py_result!(param.mul_scalar(self.weight_decay))?;
102 d_p = py_result!(d_p.add(&weight_decay_term))?;
103 }
104
105 if self.momentum != 0.0 {
107 if let Some(ref mut buf) = self.momentum_buffers[i] {
108 let momentum_buf = py_result!(buf.mul_scalar(self.momentum))?;
110 *buf = py_result!(momentum_buf.add(&d_p))?;
111
112 if self.nesterov {
113 let momentum_term = py_result!(buf.mul_scalar(self.momentum))?;
114 d_p = py_result!(d_p.add(&momentum_term))?;
115 } else {
116 d_p = buf.clone();
117 }
118 } else {
119 self.momentum_buffers[i] = Some(d_p.clone());
121 if self.nesterov {
122 let momentum_term = py_result!(d_p.mul_scalar(self.momentum))?;
123 d_p = py_result!(d_p.add(&momentum_term))?;
124 }
125 }
126 }
127
128 let update = py_result!(d_p.mul_scalar(self.lr))?;
130 *param = py_result!(param.sub(&update))?;
131 }
132 }
133 Ok(())
134 }
135
136 fn zero_grad(&mut self, set_to_none: Option<bool>) {
138 let _set_to_none = set_to_none.unwrap_or(false);
139 for param in &mut self.parameters {
140 let _ = param.zero_grad();
141 }
142 }
143
144 fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
146 Python::attach(|py| {
148 let cloned_groups = self
149 .param_groups
150 .iter()
151 .map(|group| {
152 group
153 .iter()
154 .map(|(k, v)| (k.clone(), v.clone_ref(py)))
155 .collect()
156 })
157 .collect();
158 Ok(cloned_groups)
159 })
160 }
161
162 fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
164 let mut state = HashMap::new();
166 Python::attach(|py| {
167 if self.momentum != 0.0 {
168 state.insert(
169 "momentum_buffer".to_string(),
170 "{}".into_pyobject(py)
171 .expect("Python object conversion should succeed")
172 .into_any()
173 .unbind(),
174 );
175 }
176 });
177 Ok(state)
178 }
179
180 fn state_dict(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
182 let mut state_dict = HashMap::new();
183 Python::attach(|py| {
184 state_dict.insert(
185 "state".to_string(),
186 self.state()
187 .expect("Python object conversion should succeed")
188 .into_pyobject(py)
189 .expect("Python object conversion should succeed")
190 .into_any()
191 .unbind(),
192 );
193 let param_groups_clone = self
194 .param_groups
195 .iter()
196 .map(|group| {
197 group
198 .iter()
199 .map(|(k, v)| (k.clone(), v.clone_ref(py)))
200 .collect::<HashMap<String, Py<PyAny>>>()
201 })
202 .collect::<Vec<_>>();
203 state_dict.insert(
204 "param_groups".to_string(),
205 param_groups_clone
206 .into_pyobject(py)
207 .expect("Python object conversion should succeed")
208 .into_any()
209 .unbind(),
210 );
211 });
212 Ok(state_dict)
213 }
214
215 fn load_state_dict(&mut self, state_dict: HashMap<String, Py<PyAny>>) -> PyResult<()> {
217 let _state_dict = state_dict;
219 Ok(())
220 }
221
222 fn add_param_group(&mut self, mut param_group: HashMap<String, Py<PyAny>>) -> PyResult<()> {
224 Python::attach(|py| {
226 if !param_group.contains_key("lr") {
227 param_group.insert(
228 "lr".to_string(),
229 self.lr
230 .into_pyobject(py)
231 .expect("Python object conversion should succeed")
232 .into_any()
233 .unbind(),
234 );
235 }
236 if !param_group.contains_key("momentum") {
237 param_group.insert(
238 "momentum".to_string(),
239 self.momentum
240 .into_pyobject(py)
241 .expect("Python object conversion should succeed")
242 .into_any()
243 .unbind(),
244 );
245 }
246 if !param_group.contains_key("dampening") {
247 param_group.insert(
248 "dampening".to_string(),
249 self.dampening
250 .into_pyobject(py)
251 .expect("Python object conversion should succeed")
252 .into_any()
253 .unbind(),
254 );
255 }
256 if !param_group.contains_key("weight_decay") {
257 param_group.insert(
258 "weight_decay".to_string(),
259 self.weight_decay
260 .into_pyobject(py)
261 .expect("Python object conversion should succeed")
262 .into_any()
263 .unbind(),
264 );
265 }
266 if !param_group.contains_key("nesterov") {
267 param_group.insert(
268 "nesterov".to_string(),
269 PyBool::new(py, self.nesterov).to_owned().into(),
270 );
271 }
272 });
273
274 self.param_groups.push(param_group);
275 Ok(())
276 }
277
278 fn __repr__(&self) -> String {
280 format!(
281 "SGD(lr={}, momentum={}, dampening={}, weight_decay={}, nesterov={})",
282 self.lr, self.momentum, self.dampening, self.weight_decay, self.nesterov
283 )
284 }
285
286 fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
288 let mut defaults = HashMap::new();
289 Python::attach(|py| {
290 defaults.insert(
291 "lr".to_string(),
292 self.lr
293 .into_pyobject(py)
294 .expect("Python object conversion should succeed")
295 .into_any()
296 .unbind(),
297 );
298 defaults.insert(
299 "momentum".to_string(),
300 self.momentum
301 .into_pyobject(py)
302 .expect("Python object conversion should succeed")
303 .into_any()
304 .unbind(),
305 );
306 defaults.insert(
307 "dampening".to_string(),
308 self.dampening
309 .into_pyobject(py)
310 .expect("Python object conversion should succeed")
311 .into_any()
312 .unbind(),
313 );
314 defaults.insert(
315 "weight_decay".to_string(),
316 self.weight_decay
317 .into_pyobject(py)
318 .expect("Python object conversion should succeed")
319 .into_any()
320 .unbind(),
321 );
322 defaults.insert(
323 "nesterov".to_string(),
324 PyBool::new(py, self.nesterov).to_owned().into(),
325 );
326 });
327 Ok(defaults)
328 }
329
330 #[getter]
332 fn lr(&self) -> f32 {
333 self.lr
334 }
335
336 #[setter]
338 fn set_lr(&mut self, lr: f32) {
339 self.lr = lr;
340 Python::attach(|py| {
342 for param_group in &mut self.param_groups {
343 param_group.insert(
344 "lr".to_string(),
345 lr.into_pyobject(py)
346 .expect("Python object conversion should succeed")
347 .into_any()
348 .unbind(),
349 );
350 }
351 });
352 }
353
354 #[getter]
356 fn momentum(&self) -> f32 {
357 self.momentum
358 }
359
360 #[getter]
362 fn dampening(&self) -> f32 {
363 self.dampening
364 }
365
366 #[getter]
368 fn weight_decay(&self) -> f32 {
369 self.weight_decay
370 }
371
372 #[getter]
374 fn nesterov(&self) -> bool {
375 self.nesterov
376 }
377}