torsh_python/optim/
sgd.rs

1//! SGD (Stochastic Gradient Descent) optimizer
2
3use super::base::{create_param_group, PyOptimizer};
4use crate::{error::PyResult, py_result, tensor::PyTensor};
5use pyo3::prelude::*;
6use pyo3::types::{PyAny, PyBool};
7use std::collections::HashMap;
8use torsh_tensor::Tensor;
9
10/// SGD optimizer - Stochastic Gradient Descent
11#[pyclass(name = "SGD", extends = PyOptimizer)]
12pub struct PySGD {
13    parameters: Vec<Tensor<f32>>,
14    momentum_buffers: Vec<Option<Tensor<f32>>>,
15    param_groups: Vec<HashMap<String, Py<PyAny>>>,
16    lr: f32,
17    momentum: f32,
18    dampening: f32,
19    weight_decay: f32,
20    nesterov: bool,
21}
22
23#[pymethods]
24impl PySGD {
25    #[new]
26    fn new(
27        params: Vec<PyTensor>,
28        lr: f32,
29        momentum: Option<f32>,
30        dampening: Option<f32>,
31        weight_decay: Option<f32>,
32        nesterov: Option<bool>,
33    ) -> PyResult<(Self, PyOptimizer)> {
34        let momentum = momentum.unwrap_or(0.0);
35        let dampening = dampening.unwrap_or(0.0);
36        let weight_decay = weight_decay.unwrap_or(0.0);
37        let nesterov = nesterov.unwrap_or(false);
38
39        // Extract tensor parameters
40        let parameters: Vec<Tensor<f32>> = params.iter().map(|p| p.tensor.clone()).collect();
41        let momentum_buffers = vec![None; parameters.len()];
42
43        // Create parameter groups
44        let mut param_group_data = HashMap::new();
45        Python::attach(|py| {
46            param_group_data.insert(
47                "momentum".to_string(),
48                momentum.into_pyobject(py).unwrap().into_any().unbind(),
49            );
50            param_group_data.insert(
51                "dampening".to_string(),
52                dampening.into_pyobject(py).unwrap().into_any().unbind(),
53            );
54            param_group_data.insert(
55                "weight_decay".to_string(),
56                weight_decay.into_pyobject(py).unwrap().into_any().unbind(),
57            );
58            param_group_data.insert(
59                "nesterov".to_string(),
60                PyBool::new(py, nesterov).to_owned().into(),
61            );
62        });
63
64        let param_groups = vec![create_param_group(params, lr, param_group_data)?];
65
66        Ok((
67            Self {
68                parameters,
69                momentum_buffers,
70                param_groups,
71                lr,
72                momentum,
73                dampening,
74                weight_decay,
75                nesterov,
76            },
77            PyOptimizer {},
78        ))
79    }
80
81    /// Perform a single optimization step
82    fn step(&mut self) -> PyResult<()> {
83        for (i, param) in self.parameters.iter_mut().enumerate() {
84            if let Some(grad) = param.grad() {
85                let mut d_p = grad.clone();
86
87                // Apply weight decay if specified
88                if self.weight_decay != 0.0 {
89                    let weight_decay_term = py_result!(param.mul_scalar(self.weight_decay))?;
90                    d_p = py_result!(d_p.add(&weight_decay_term))?;
91                }
92
93                // Apply momentum if specified
94                if self.momentum != 0.0 {
95                    if let Some(ref mut buf) = self.momentum_buffers[i] {
96                        // buf = momentum * buf + d_p
97                        let momentum_buf = py_result!(buf.mul_scalar(self.momentum))?;
98                        *buf = py_result!(momentum_buf.add(&d_p))?;
99
100                        if self.nesterov {
101                            let momentum_term = py_result!(buf.mul_scalar(self.momentum))?;
102                            d_p = py_result!(d_p.add(&momentum_term))?;
103                        } else {
104                            d_p = buf.clone();
105                        }
106                    } else {
107                        // Initialize momentum buffer
108                        self.momentum_buffers[i] = Some(d_p.clone());
109                        if self.nesterov {
110                            let momentum_term = py_result!(d_p.mul_scalar(self.momentum))?;
111                            d_p = py_result!(d_p.add(&momentum_term))?;
112                        }
113                    }
114                }
115
116                // Update parameter: param = param - lr * d_p
117                let update = py_result!(d_p.mul_scalar(self.lr))?;
118                *param = py_result!(param.sub(&update))?;
119            }
120        }
121        Ok(())
122    }
123
124    /// Zero out gradients of all parameters
125    fn zero_grad(&mut self, set_to_none: Option<bool>) {
126        let _set_to_none = set_to_none.unwrap_or(false);
127        for param in &mut self.parameters {
128            let _ = param.zero_grad();
129        }
130    }
131
132    /// Get parameter groups
133    fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
134        // Manual clone since Py<PyAny> doesn't implement Clone
135        Python::attach(|py| {
136            let cloned_groups = self
137                .param_groups
138                .iter()
139                .map(|group| {
140                    group
141                        .iter()
142                        .map(|(k, v)| (k.clone(), v.clone_ref(py)))
143                        .collect()
144                })
145                .collect();
146            Ok(cloned_groups)
147        })
148    }
149
150    /// Get current state
151    fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
152        // For SGD, state includes momentum buffers
153        let mut state = HashMap::new();
154        Python::attach(|py| {
155            if self.momentum != 0.0 {
156                state.insert(
157                    "momentum_buffer".to_string(),
158                    "{}".into_pyobject(py).unwrap().into_any().unbind(),
159                );
160            }
161        });
162        Ok(state)
163    }
164
165    /// Get state dictionary
166    fn state_dict(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
167        let mut state_dict = HashMap::new();
168        Python::attach(|py| {
169            state_dict.insert(
170                "state".to_string(),
171                self.state()
172                    .unwrap()
173                    .into_pyobject(py)
174                    .unwrap()
175                    .into_any()
176                    .unbind(),
177            );
178            let param_groups_clone = self
179                .param_groups
180                .iter()
181                .map(|group| {
182                    group
183                        .iter()
184                        .map(|(k, v)| (k.clone(), v.clone_ref(py)))
185                        .collect::<HashMap<String, Py<PyAny>>>()
186                })
187                .collect::<Vec<_>>();
188            state_dict.insert(
189                "param_groups".to_string(),
190                param_groups_clone
191                    .into_pyobject(py)
192                    .unwrap()
193                    .into_any()
194                    .unbind(),
195            );
196        });
197        Ok(state_dict)
198    }
199
200    /// Load state dictionary
201    fn load_state_dict(&mut self, state_dict: HashMap<String, Py<PyAny>>) -> PyResult<()> {
202        // Implementation for loading state dict
203        let _state_dict = state_dict;
204        Ok(())
205    }
206
207    /// Add a new parameter group
208    fn add_param_group(&mut self, mut param_group: HashMap<String, Py<PyAny>>) -> PyResult<()> {
209        // Set default values if not provided
210        Python::attach(|py| {
211            if !param_group.contains_key("lr") {
212                param_group.insert(
213                    "lr".to_string(),
214                    self.lr.into_pyobject(py).unwrap().into_any().unbind(),
215                );
216            }
217            if !param_group.contains_key("momentum") {
218                param_group.insert(
219                    "momentum".to_string(),
220                    self.momentum.into_pyobject(py).unwrap().into_any().unbind(),
221                );
222            }
223            if !param_group.contains_key("dampening") {
224                param_group.insert(
225                    "dampening".to_string(),
226                    self.dampening
227                        .into_pyobject(py)
228                        .unwrap()
229                        .into_any()
230                        .unbind(),
231                );
232            }
233            if !param_group.contains_key("weight_decay") {
234                param_group.insert(
235                    "weight_decay".to_string(),
236                    self.weight_decay
237                        .into_pyobject(py)
238                        .unwrap()
239                        .into_any()
240                        .unbind(),
241                );
242            }
243            if !param_group.contains_key("nesterov") {
244                param_group.insert(
245                    "nesterov".to_string(),
246                    PyBool::new(py, self.nesterov).to_owned().into(),
247                );
248            }
249        });
250
251        self.param_groups.push(param_group);
252        Ok(())
253    }
254
255    /// String representation
256    fn __repr__(&self) -> String {
257        format!(
258            "SGD(lr={}, momentum={}, dampening={}, weight_decay={}, nesterov={})",
259            self.lr, self.momentum, self.dampening, self.weight_decay, self.nesterov
260        )
261    }
262
263    /// Get defaults (default hyperparameters)
264    fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
265        let mut defaults = HashMap::new();
266        Python::attach(|py| {
267            defaults.insert(
268                "lr".to_string(),
269                self.lr.into_pyobject(py).unwrap().into_any().unbind(),
270            );
271            defaults.insert(
272                "momentum".to_string(),
273                self.momentum.into_pyobject(py).unwrap().into_any().unbind(),
274            );
275            defaults.insert(
276                "dampening".to_string(),
277                self.dampening
278                    .into_pyobject(py)
279                    .unwrap()
280                    .into_any()
281                    .unbind(),
282            );
283            defaults.insert(
284                "weight_decay".to_string(),
285                self.weight_decay
286                    .into_pyobject(py)
287                    .unwrap()
288                    .into_any()
289                    .unbind(),
290            );
291            defaults.insert(
292                "nesterov".to_string(),
293                PyBool::new(py, self.nesterov).to_owned().into(),
294            );
295        });
296        Ok(defaults)
297    }
298
299    /// Get learning rate
300    #[getter]
301    fn lr(&self) -> f32 {
302        self.lr
303    }
304
305    /// Set learning rate
306    #[setter]
307    fn set_lr(&mut self, lr: f32) {
308        self.lr = lr;
309        // Update all parameter groups
310        Python::attach(|py| {
311            for param_group in &mut self.param_groups {
312                param_group.insert(
313                    "lr".to_string(),
314                    lr.into_pyobject(py).unwrap().into_any().unbind(),
315                );
316            }
317        });
318    }
319
320    /// Get momentum
321    #[getter]
322    fn momentum(&self) -> f32 {
323        self.momentum
324    }
325
326    /// Get dampening
327    #[getter]
328    fn dampening(&self) -> f32 {
329        self.dampening
330    }
331
332    /// Get weight decay
333    #[getter]
334    fn weight_decay(&self) -> f32 {
335        self.weight_decay
336    }
337
338    /// Get nesterov flag
339    #[getter]
340    fn nesterov(&self) -> bool {
341        self.nesterov
342    }
343}