Skip to main content

torsh_python/optim/
sgd.rs

1//! SGD (Stochastic Gradient Descent) optimizer
2
3use super::base::{create_param_group, PyOptimizer};
4use crate::{error::PyResult, py_result, tensor::PyTensor};
5use pyo3::prelude::*;
6use pyo3::types::{PyAny, PyBool};
7use std::collections::HashMap;
8use torsh_tensor::Tensor;
9
10/// SGD optimizer - Stochastic Gradient Descent
11#[pyclass(name = "SGD", extends = PyOptimizer)]
12pub struct PySGD {
13    parameters: Vec<Tensor<f32>>,
14    momentum_buffers: Vec<Option<Tensor<f32>>>,
15    param_groups: Vec<HashMap<String, Py<PyAny>>>,
16    lr: f32,
17    momentum: f32,
18    dampening: f32,
19    weight_decay: f32,
20    nesterov: bool,
21}
22
23#[pymethods]
24impl PySGD {
25    #[new]
26    fn new(
27        params: Vec<PyTensor>,
28        lr: f32,
29        momentum: Option<f32>,
30        dampening: Option<f32>,
31        weight_decay: Option<f32>,
32        nesterov: Option<bool>,
33    ) -> PyResult<(Self, PyOptimizer)> {
34        let momentum = momentum.unwrap_or(0.0);
35        let dampening = dampening.unwrap_or(0.0);
36        let weight_decay = weight_decay.unwrap_or(0.0);
37        let nesterov = nesterov.unwrap_or(false);
38
39        // Extract tensor parameters
40        let parameters: Vec<Tensor<f32>> = params.iter().map(|p| p.tensor.clone()).collect();
41        let momentum_buffers = vec![None; parameters.len()];
42
43        // Create parameter groups
44        let mut param_group_data = HashMap::new();
45        Python::attach(|py| {
46            param_group_data.insert(
47                "momentum".to_string(),
48                momentum
49                    .into_pyobject(py)
50                    .expect("Python object conversion should succeed")
51                    .into_any()
52                    .unbind(),
53            );
54            param_group_data.insert(
55                "dampening".to_string(),
56                dampening
57                    .into_pyobject(py)
58                    .expect("Python object conversion should succeed")
59                    .into_any()
60                    .unbind(),
61            );
62            param_group_data.insert(
63                "weight_decay".to_string(),
64                weight_decay
65                    .into_pyobject(py)
66                    .expect("Python object conversion should succeed")
67                    .into_any()
68                    .unbind(),
69            );
70            param_group_data.insert(
71                "nesterov".to_string(),
72                PyBool::new(py, nesterov).to_owned().into(),
73            );
74        });
75
76        let param_groups = vec![create_param_group(params, lr, param_group_data)?];
77
78        Ok((
79            Self {
80                parameters,
81                momentum_buffers,
82                param_groups,
83                lr,
84                momentum,
85                dampening,
86                weight_decay,
87                nesterov,
88            },
89            PyOptimizer {},
90        ))
91    }
92
93    /// Perform a single optimization step
94    fn step(&mut self) -> PyResult<()> {
95        for (i, param) in self.parameters.iter_mut().enumerate() {
96            if let Some(grad) = param.grad() {
97                let mut d_p = grad.clone();
98
99                // Apply weight decay if specified
100                if self.weight_decay != 0.0 {
101                    let weight_decay_term = py_result!(param.mul_scalar(self.weight_decay))?;
102                    d_p = py_result!(d_p.add(&weight_decay_term))?;
103                }
104
105                // Apply momentum if specified
106                if self.momentum != 0.0 {
107                    if let Some(ref mut buf) = self.momentum_buffers[i] {
108                        // buf = momentum * buf + d_p
109                        let momentum_buf = py_result!(buf.mul_scalar(self.momentum))?;
110                        *buf = py_result!(momentum_buf.add(&d_p))?;
111
112                        if self.nesterov {
113                            let momentum_term = py_result!(buf.mul_scalar(self.momentum))?;
114                            d_p = py_result!(d_p.add(&momentum_term))?;
115                        } else {
116                            d_p = buf.clone();
117                        }
118                    } else {
119                        // Initialize momentum buffer
120                        self.momentum_buffers[i] = Some(d_p.clone());
121                        if self.nesterov {
122                            let momentum_term = py_result!(d_p.mul_scalar(self.momentum))?;
123                            d_p = py_result!(d_p.add(&momentum_term))?;
124                        }
125                    }
126                }
127
128                // Update parameter: param = param - lr * d_p
129                let update = py_result!(d_p.mul_scalar(self.lr))?;
130                *param = py_result!(param.sub(&update))?;
131            }
132        }
133        Ok(())
134    }
135
136    /// Zero out gradients of all parameters
137    fn zero_grad(&mut self, set_to_none: Option<bool>) {
138        let _set_to_none = set_to_none.unwrap_or(false);
139        for param in &mut self.parameters {
140            let _ = param.zero_grad();
141        }
142    }
143
144    /// Get parameter groups
145    fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
146        // Manual clone since Py<PyAny> doesn't implement Clone
147        Python::attach(|py| {
148            let cloned_groups = self
149                .param_groups
150                .iter()
151                .map(|group| {
152                    group
153                        .iter()
154                        .map(|(k, v)| (k.clone(), v.clone_ref(py)))
155                        .collect()
156                })
157                .collect();
158            Ok(cloned_groups)
159        })
160    }
161
162    /// Get current state
163    fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
164        // For SGD, state includes momentum buffers
165        let mut state = HashMap::new();
166        Python::attach(|py| {
167            if self.momentum != 0.0 {
168                state.insert(
169                    "momentum_buffer".to_string(),
170                    "{}".into_pyobject(py)
171                        .expect("Python object conversion should succeed")
172                        .into_any()
173                        .unbind(),
174                );
175            }
176        });
177        Ok(state)
178    }
179
180    /// Get state dictionary
181    fn state_dict(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
182        let mut state_dict = HashMap::new();
183        Python::attach(|py| {
184            state_dict.insert(
185                "state".to_string(),
186                self.state()
187                    .expect("Python object conversion should succeed")
188                    .into_pyobject(py)
189                    .expect("Python object conversion should succeed")
190                    .into_any()
191                    .unbind(),
192            );
193            let param_groups_clone = self
194                .param_groups
195                .iter()
196                .map(|group| {
197                    group
198                        .iter()
199                        .map(|(k, v)| (k.clone(), v.clone_ref(py)))
200                        .collect::<HashMap<String, Py<PyAny>>>()
201                })
202                .collect::<Vec<_>>();
203            state_dict.insert(
204                "param_groups".to_string(),
205                param_groups_clone
206                    .into_pyobject(py)
207                    .expect("Python object conversion should succeed")
208                    .into_any()
209                    .unbind(),
210            );
211        });
212        Ok(state_dict)
213    }
214
215    /// Load state dictionary
216    fn load_state_dict(&mut self, state_dict: HashMap<String, Py<PyAny>>) -> PyResult<()> {
217        // Implementation for loading state dict
218        let _state_dict = state_dict;
219        Ok(())
220    }
221
222    /// Add a new parameter group
223    fn add_param_group(&mut self, mut param_group: HashMap<String, Py<PyAny>>) -> PyResult<()> {
224        // Set default values if not provided
225        Python::attach(|py| {
226            if !param_group.contains_key("lr") {
227                param_group.insert(
228                    "lr".to_string(),
229                    self.lr
230                        .into_pyobject(py)
231                        .expect("Python object conversion should succeed")
232                        .into_any()
233                        .unbind(),
234                );
235            }
236            if !param_group.contains_key("momentum") {
237                param_group.insert(
238                    "momentum".to_string(),
239                    self.momentum
240                        .into_pyobject(py)
241                        .expect("Python object conversion should succeed")
242                        .into_any()
243                        .unbind(),
244                );
245            }
246            if !param_group.contains_key("dampening") {
247                param_group.insert(
248                    "dampening".to_string(),
249                    self.dampening
250                        .into_pyobject(py)
251                        .expect("Python object conversion should succeed")
252                        .into_any()
253                        .unbind(),
254                );
255            }
256            if !param_group.contains_key("weight_decay") {
257                param_group.insert(
258                    "weight_decay".to_string(),
259                    self.weight_decay
260                        .into_pyobject(py)
261                        .expect("Python object conversion should succeed")
262                        .into_any()
263                        .unbind(),
264                );
265            }
266            if !param_group.contains_key("nesterov") {
267                param_group.insert(
268                    "nesterov".to_string(),
269                    PyBool::new(py, self.nesterov).to_owned().into(),
270                );
271            }
272        });
273
274        self.param_groups.push(param_group);
275        Ok(())
276    }
277
278    /// String representation
279    fn __repr__(&self) -> String {
280        format!(
281            "SGD(lr={}, momentum={}, dampening={}, weight_decay={}, nesterov={})",
282            self.lr, self.momentum, self.dampening, self.weight_decay, self.nesterov
283        )
284    }
285
286    /// Get defaults (default hyperparameters)
287    fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
288        let mut defaults = HashMap::new();
289        Python::attach(|py| {
290            defaults.insert(
291                "lr".to_string(),
292                self.lr
293                    .into_pyobject(py)
294                    .expect("Python object conversion should succeed")
295                    .into_any()
296                    .unbind(),
297            );
298            defaults.insert(
299                "momentum".to_string(),
300                self.momentum
301                    .into_pyobject(py)
302                    .expect("Python object conversion should succeed")
303                    .into_any()
304                    .unbind(),
305            );
306            defaults.insert(
307                "dampening".to_string(),
308                self.dampening
309                    .into_pyobject(py)
310                    .expect("Python object conversion should succeed")
311                    .into_any()
312                    .unbind(),
313            );
314            defaults.insert(
315                "weight_decay".to_string(),
316                self.weight_decay
317                    .into_pyobject(py)
318                    .expect("Python object conversion should succeed")
319                    .into_any()
320                    .unbind(),
321            );
322            defaults.insert(
323                "nesterov".to_string(),
324                PyBool::new(py, self.nesterov).to_owned().into(),
325            );
326        });
327        Ok(defaults)
328    }
329
330    /// Get learning rate
331    #[getter]
332    fn lr(&self) -> f32 {
333        self.lr
334    }
335
336    /// Set learning rate
337    #[setter]
338    fn set_lr(&mut self, lr: f32) {
339        self.lr = lr;
340        // Update all parameter groups
341        Python::attach(|py| {
342            for param_group in &mut self.param_groups {
343                param_group.insert(
344                    "lr".to_string(),
345                    lr.into_pyobject(py)
346                        .expect("Python object conversion should succeed")
347                        .into_any()
348                        .unbind(),
349                );
350            }
351        });
352    }
353
354    /// Get momentum
355    #[getter]
356    fn momentum(&self) -> f32 {
357        self.momentum
358    }
359
360    /// Get dampening
361    #[getter]
362    fn dampening(&self) -> f32 {
363        self.dampening
364    }
365
366    /// Get weight decay
367    #[getter]
368    fn weight_decay(&self) -> f32 {
369        self.weight_decay
370    }
371
372    /// Get nesterov flag
373    #[getter]
374    fn nesterov(&self) -> bool {
375        self.nesterov
376    }
377}