1use super::base::{create_param_group, PyOptimizer};
4use crate::{error::PyResult, py_result, tensor::PyTensor};
5use pyo3::prelude::*;
6use pyo3::types::{PyAny, PyBool};
7use std::collections::HashMap;
8use torsh_tensor::Tensor;
9
10#[pyclass(name = "SGD", extends = PyOptimizer)]
12pub struct PySGD {
13 parameters: Vec<Tensor<f32>>,
14 momentum_buffers: Vec<Option<Tensor<f32>>>,
15 param_groups: Vec<HashMap<String, Py<PyAny>>>,
16 lr: f32,
17 momentum: f32,
18 dampening: f32,
19 weight_decay: f32,
20 nesterov: bool,
21}
22
23#[pymethods]
24impl PySGD {
25 #[new]
26 fn new(
27 params: Vec<PyTensor>,
28 lr: f32,
29 momentum: Option<f32>,
30 dampening: Option<f32>,
31 weight_decay: Option<f32>,
32 nesterov: Option<bool>,
33 ) -> PyResult<(Self, PyOptimizer)> {
34 let momentum = momentum.unwrap_or(0.0);
35 let dampening = dampening.unwrap_or(0.0);
36 let weight_decay = weight_decay.unwrap_or(0.0);
37 let nesterov = nesterov.unwrap_or(false);
38
39 let parameters: Vec<Tensor<f32>> = params.iter().map(|p| p.tensor.clone()).collect();
41 let momentum_buffers = vec![None; parameters.len()];
42
43 let mut param_group_data = HashMap::new();
45 Python::attach(|py| {
46 param_group_data.insert(
47 "momentum".to_string(),
48 momentum.into_pyobject(py).unwrap().into_any().unbind(),
49 );
50 param_group_data.insert(
51 "dampening".to_string(),
52 dampening.into_pyobject(py).unwrap().into_any().unbind(),
53 );
54 param_group_data.insert(
55 "weight_decay".to_string(),
56 weight_decay.into_pyobject(py).unwrap().into_any().unbind(),
57 );
58 param_group_data.insert(
59 "nesterov".to_string(),
60 PyBool::new(py, nesterov).to_owned().into(),
61 );
62 });
63
64 let param_groups = vec![create_param_group(params, lr, param_group_data)?];
65
66 Ok((
67 Self {
68 parameters,
69 momentum_buffers,
70 param_groups,
71 lr,
72 momentum,
73 dampening,
74 weight_decay,
75 nesterov,
76 },
77 PyOptimizer {},
78 ))
79 }
80
81 fn step(&mut self) -> PyResult<()> {
83 for (i, param) in self.parameters.iter_mut().enumerate() {
84 if let Some(grad) = param.grad() {
85 let mut d_p = grad.clone();
86
87 if self.weight_decay != 0.0 {
89 let weight_decay_term = py_result!(param.mul_scalar(self.weight_decay))?;
90 d_p = py_result!(d_p.add(&weight_decay_term))?;
91 }
92
93 if self.momentum != 0.0 {
95 if let Some(ref mut buf) = self.momentum_buffers[i] {
96 let momentum_buf = py_result!(buf.mul_scalar(self.momentum))?;
98 *buf = py_result!(momentum_buf.add(&d_p))?;
99
100 if self.nesterov {
101 let momentum_term = py_result!(buf.mul_scalar(self.momentum))?;
102 d_p = py_result!(d_p.add(&momentum_term))?;
103 } else {
104 d_p = buf.clone();
105 }
106 } else {
107 self.momentum_buffers[i] = Some(d_p.clone());
109 if self.nesterov {
110 let momentum_term = py_result!(d_p.mul_scalar(self.momentum))?;
111 d_p = py_result!(d_p.add(&momentum_term))?;
112 }
113 }
114 }
115
116 let update = py_result!(d_p.mul_scalar(self.lr))?;
118 *param = py_result!(param.sub(&update))?;
119 }
120 }
121 Ok(())
122 }
123
124 fn zero_grad(&mut self, set_to_none: Option<bool>) {
126 let _set_to_none = set_to_none.unwrap_or(false);
127 for param in &mut self.parameters {
128 let _ = param.zero_grad();
129 }
130 }
131
132 fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
134 Python::attach(|py| {
136 let cloned_groups = self
137 .param_groups
138 .iter()
139 .map(|group| {
140 group
141 .iter()
142 .map(|(k, v)| (k.clone(), v.clone_ref(py)))
143 .collect()
144 })
145 .collect();
146 Ok(cloned_groups)
147 })
148 }
149
150 fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
152 let mut state = HashMap::new();
154 Python::attach(|py| {
155 if self.momentum != 0.0 {
156 state.insert(
157 "momentum_buffer".to_string(),
158 "{}".into_pyobject(py).unwrap().into_any().unbind(),
159 );
160 }
161 });
162 Ok(state)
163 }
164
165 fn state_dict(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
167 let mut state_dict = HashMap::new();
168 Python::attach(|py| {
169 state_dict.insert(
170 "state".to_string(),
171 self.state()
172 .unwrap()
173 .into_pyobject(py)
174 .unwrap()
175 .into_any()
176 .unbind(),
177 );
178 let param_groups_clone = self
179 .param_groups
180 .iter()
181 .map(|group| {
182 group
183 .iter()
184 .map(|(k, v)| (k.clone(), v.clone_ref(py)))
185 .collect::<HashMap<String, Py<PyAny>>>()
186 })
187 .collect::<Vec<_>>();
188 state_dict.insert(
189 "param_groups".to_string(),
190 param_groups_clone
191 .into_pyobject(py)
192 .unwrap()
193 .into_any()
194 .unbind(),
195 );
196 });
197 Ok(state_dict)
198 }
199
200 fn load_state_dict(&mut self, state_dict: HashMap<String, Py<PyAny>>) -> PyResult<()> {
202 let _state_dict = state_dict;
204 Ok(())
205 }
206
207 fn add_param_group(&mut self, mut param_group: HashMap<String, Py<PyAny>>) -> PyResult<()> {
209 Python::attach(|py| {
211 if !param_group.contains_key("lr") {
212 param_group.insert(
213 "lr".to_string(),
214 self.lr.into_pyobject(py).unwrap().into_any().unbind(),
215 );
216 }
217 if !param_group.contains_key("momentum") {
218 param_group.insert(
219 "momentum".to_string(),
220 self.momentum.into_pyobject(py).unwrap().into_any().unbind(),
221 );
222 }
223 if !param_group.contains_key("dampening") {
224 param_group.insert(
225 "dampening".to_string(),
226 self.dampening
227 .into_pyobject(py)
228 .unwrap()
229 .into_any()
230 .unbind(),
231 );
232 }
233 if !param_group.contains_key("weight_decay") {
234 param_group.insert(
235 "weight_decay".to_string(),
236 self.weight_decay
237 .into_pyobject(py)
238 .unwrap()
239 .into_any()
240 .unbind(),
241 );
242 }
243 if !param_group.contains_key("nesterov") {
244 param_group.insert(
245 "nesterov".to_string(),
246 PyBool::new(py, self.nesterov).to_owned().into(),
247 );
248 }
249 });
250
251 self.param_groups.push(param_group);
252 Ok(())
253 }
254
255 fn __repr__(&self) -> String {
257 format!(
258 "SGD(lr={}, momentum={}, dampening={}, weight_decay={}, nesterov={})",
259 self.lr, self.momentum, self.dampening, self.weight_decay, self.nesterov
260 )
261 }
262
263 fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
265 let mut defaults = HashMap::new();
266 Python::attach(|py| {
267 defaults.insert(
268 "lr".to_string(),
269 self.lr.into_pyobject(py).unwrap().into_any().unbind(),
270 );
271 defaults.insert(
272 "momentum".to_string(),
273 self.momentum.into_pyobject(py).unwrap().into_any().unbind(),
274 );
275 defaults.insert(
276 "dampening".to_string(),
277 self.dampening
278 .into_pyobject(py)
279 .unwrap()
280 .into_any()
281 .unbind(),
282 );
283 defaults.insert(
284 "weight_decay".to_string(),
285 self.weight_decay
286 .into_pyobject(py)
287 .unwrap()
288 .into_any()
289 .unbind(),
290 );
291 defaults.insert(
292 "nesterov".to_string(),
293 PyBool::new(py, self.nesterov).to_owned().into(),
294 );
295 });
296 Ok(defaults)
297 }
298
299 #[getter]
301 fn lr(&self) -> f32 {
302 self.lr
303 }
304
305 #[setter]
307 fn set_lr(&mut self, lr: f32) {
308 self.lr = lr;
309 Python::attach(|py| {
311 for param_group in &mut self.param_groups {
312 param_group.insert(
313 "lr".to_string(),
314 lr.into_pyobject(py).unwrap().into_any().unbind(),
315 );
316 }
317 });
318 }
319
320 #[getter]
322 fn momentum(&self) -> f32 {
323 self.momentum
324 }
325
326 #[getter]
328 fn dampening(&self) -> f32 {
329 self.dampening
330 }
331
332 #[getter]
334 fn weight_decay(&self) -> f32 {
335 self.weight_decay
336 }
337
338 #[getter]
340 fn nesterov(&self) -> bool {
341 self.nesterov
342 }
343}