torsh_python/optim/adam.rs

//! Adam and AdamW optimizers
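//!
//! A minimal Python-side usage sketch. The import path (`torsh`) and the
//! training-loop names (`params`, `loss`) are illustrative assumptions, not
//! part of this module:
//!
//! ```python
//! import torsh
//!
//! opt = torsh.optim.Adam(params, lr=1e-3, betas=(0.9, 0.999))
//! opt.zero_grad()
//! loss.backward()
//! opt.step()
//! ```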

use super::base::{create_param_group, extract_parameters, PyOptimizer};
use crate::{error::PyResult, tensor::PyTensor};
use parking_lot::RwLock;
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyBool};
use std::collections::HashMap;
use std::sync::Arc;
use torsh_optim::{Adam, AdamW, Optimizer};

/// Adam optimizer - Adaptive Moment Estimation
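///
/// The update delegated to `torsh_optim::Adam` is assumed to follow the
/// standard Adam rule (Kingma & Ba, 2015). Schematically, for step `t`
/// with gradient `g_t`:
///
/// ```text
/// m_t     = beta1 * m_{t-1} + (1 - beta1) * g_t
/// v_t     = beta2 * v_{t-1} + (1 - beta2) * g_t^2
/// m_hat   = m_t / (1 - beta1^t)
/// v_hat   = v_t / (1 - beta2^t)
/// theta_t = theta_{t-1} - lr * m_hat / (sqrt(v_hat) + eps)
/// ```
///
/// With `amsgrad = true`, `v_hat` is replaced by the running maximum of past
/// `v_hat` values, which keeps the denominator non-decreasing.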
#[pyclass(name = "Adam", extends = PyOptimizer)]
pub struct PyAdam {
    adam: Adam,
    param_groups: Vec<HashMap<String, Py<PyAny>>>,
    lr: f32,
    betas: (f32, f32),
    eps: f32,
    weight_decay: f32,
    amsgrad: bool,
}

#[pymethods]
impl PyAdam {
    #[new]
    #[pyo3(signature = (params, lr=None, betas=None, eps=None, weight_decay=None, amsgrad=None))]
    fn new(
        params: Vec<PyTensor>,
        lr: Option<f32>,
        betas: Option<(f32, f32)>,
        eps: Option<f32>,
        weight_decay: Option<f32>,
        amsgrad: Option<bool>,
    ) -> PyResult<(Self, PyOptimizer)> {
        let lr = lr.unwrap_or(0.001);
        let betas = betas.unwrap_or((0.9, 0.999));
        let eps = eps.unwrap_or(1e-8);
        let weight_decay = weight_decay.unwrap_or(0.0);
        let amsgrad = amsgrad.unwrap_or(false);

        // Extract tensor parameters and wrap each in Arc<RwLock<_>> so the
        // core optimizer can share mutable access with the Python-side tensors.
        let tensor_params = extract_parameters(params.clone())?;
        let wrapped_params: Vec<Arc<RwLock<_>>> = tensor_params
            .into_iter()
            .map(|tensor| Arc::new(RwLock::new(tensor)))
            .collect();
        let adam = Adam::new(
            wrapped_params,
            Some(lr),
            Some(betas),
            Some(eps),
            Some(weight_decay),
            amsgrad,
        );

        // Create parameter groups; `into_pyobject` is infallible for these
        // primitive types, so the unwraps cannot fail.
        let mut param_group_data = HashMap::new();
        Python::attach(|py| {
            param_group_data.insert(
                "betas".to_string(),
                betas.into_pyobject(py).unwrap().into_any().unbind(),
            );
            param_group_data.insert(
                "eps".to_string(),
                eps.into_pyobject(py).unwrap().into_any().unbind(),
            );
            param_group_data.insert(
                "weight_decay".to_string(),
                weight_decay.into_pyobject(py).unwrap().into_any().unbind(),
            );
            param_group_data.insert(
                "amsgrad".to_string(),
                PyBool::new(py, amsgrad).to_owned().into(),
            );
        });

        let param_groups = vec![create_param_group(params, lr, param_group_data)?];

        Ok((
            Self {
                adam,
                param_groups,
                lr,
                betas,
                eps,
                weight_decay,
                amsgrad,
            },
            PyOptimizer {},
        ))
    }

    /// Perform a single optimization step
    fn step(&mut self) -> PyResult<()> {
        self.adam.step().map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                "Optimizer step failed: {}",
                e
            ))
        })?;
        Ok(())
    }

    /// Zero out gradients of all parameters
    ///
    /// `set_to_none` is accepted for API compatibility but is not yet
    /// forwarded to the core optimizer, which always zeroes in place.
    #[pyo3(signature = (set_to_none=None))]
    fn zero_grad(&mut self, set_to_none: Option<bool>) {
        let _set_to_none = set_to_none.unwrap_or(false);
        self.adam.zero_grad();
    }

    /// Get parameter groups
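    ///
    /// Note that unlike PyTorch's attribute-style `optimizer.param_groups`,
    /// this binding exposes the groups as a method. A Python-side sketch,
    /// assuming `create_param_group` stores the learning rate under `"lr"`
    /// (the convention the `lr` setter below also uses):
    ///
    /// ```python
    /// groups = opt.param_groups()
    /// print(groups[0]["lr"])  # per-group learning rate
    /// ```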
    fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
        // Manual clone since Py<PyAny> doesn't implement Clone
        Python::attach(|py| {
            let cloned_groups = self
                .param_groups
                .iter()
                .map(|group| {
                    group
                        .iter()
                        .map(|(k, v)| (k.clone(), v.clone_ref(py)))
                        .collect()
                })
                .collect();
            Ok(cloned_groups)
        })
    }

    /// Get current state
    fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
        // Placeholder state: the real per-parameter buffers (step count,
        // exp_avg, exp_avg_sq) live inside the core optimizer and are not
        // yet exposed, so empty-dict strings are returned for now.
        let mut state = HashMap::new();
        Python::attach(|py| {
            state.insert(
                "step".to_string(),
                0i64.into_pyobject(py).unwrap().into_any().unbind(),
            );
            state.insert(
                "exp_avg".to_string(),
                "{}".into_pyobject(py).unwrap().into_any().unbind(),
            );
            state.insert(
                "exp_avg_sq".to_string(),
                "{}".into_pyobject(py).unwrap().into_any().unbind(),
            );
            if self.amsgrad {
                state.insert(
                    "max_exp_avg_sq".to_string(),
                    "{}".into_pyobject(py).unwrap().into_any().unbind(),
                );
            }
        });
        Ok(state)
    }

    /// String representation
    fn __repr__(&self) -> String {
        format!(
            "Adam(lr={}, betas={:?}, eps={}, weight_decay={}, amsgrad={})",
            self.lr, self.betas, self.eps, self.weight_decay, self.amsgrad
        )
    }

    /// Get defaults
    fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
        let mut defaults = HashMap::new();
        Python::attach(|py| {
            defaults.insert(
                "lr".to_string(),
                self.lr.into_pyobject(py).unwrap().into_any().unbind(),
            );
            defaults.insert(
                "betas".to_string(),
                self.betas.into_pyobject(py).unwrap().into_any().unbind(),
            );
            defaults.insert(
                "eps".to_string(),
                self.eps.into_pyobject(py).unwrap().into_any().unbind(),
            );
            defaults.insert(
                "weight_decay".to_string(),
                self.weight_decay
                    .into_pyobject(py)
                    .unwrap()
                    .into_any()
                    .unbind(),
            );
            defaults.insert(
                "amsgrad".to_string(),
                PyBool::new(py, self.amsgrad).to_owned().into(),
            );
        });
        Ok(defaults)
    }

    /// Get learning rate
    #[getter]
    fn lr(&self) -> f32 {
        self.lr
    }

    /// Set learning rate
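    ///
    /// Assigning to `lr` updates both the cached value and the `"lr"` entry
    /// of every parameter group. A Python-side decay sketch (the loop and
    /// helper names are illustrative, not part of this module):
    ///
    /// ```python
    /// for epoch in range(num_epochs):
    ///     train_one_epoch(model, opt)   # hypothetical helper
    ///     opt.lr = opt.lr * 0.95        # exponential decay per epoch
    /// ```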
    #[setter]
    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
        Python::attach(|py| {
            for param_group in &mut self.param_groups {
                param_group.insert(
                    "lr".to_string(),
                    lr.into_pyobject(py).unwrap().into_any().unbind(),
                );
            }
        });
    }

    /// Get betas
    #[getter]
    fn betas(&self) -> (f32, f32) {
        self.betas
    }

    /// Get eps
    #[getter]
    fn eps(&self) -> f32 {
        self.eps
    }

    /// Get weight decay
    #[getter]
    fn weight_decay(&self) -> f32 {
        self.weight_decay
    }

    /// Get amsgrad flag
    #[getter]
    fn amsgrad(&self) -> bool {
        self.amsgrad
    }
}

/// AdamW optimizer - Adam with decoupled weight decay
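///
/// Unlike coupling weight decay into the gradient as an L2 penalty, AdamW
/// applies the decay directly to the parameters (Loshchilov & Hutter, 2019).
/// Assuming `torsh_optim::AdamW` implements the standard decoupled rule:
///
/// ```text
/// theta_t = theta_{t-1} - lr * (m_hat / (sqrt(v_hat) + eps)
///                               + weight_decay * theta_{t-1})
/// ```
///
/// Note the default `weight_decay` is 0.01 here, matching common AdamW
/// practice, versus 0.0 for `Adam` above.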
#[pyclass(name = "AdamW", extends = PyOptimizer)]
pub struct PyAdamW {
    adamw: AdamW,
    param_groups: Vec<HashMap<String, Py<PyAny>>>,
    lr: f32,
    betas: (f32, f32),
    eps: f32,
    weight_decay: f32,
    amsgrad: bool,
}

#[pymethods]
impl PyAdamW {
    #[new]
    #[pyo3(signature = (params, lr=None, betas=None, eps=None, weight_decay=None, amsgrad=None))]
    fn new(
        params: Vec<PyTensor>,
        lr: Option<f32>,
        betas: Option<(f32, f32)>,
        eps: Option<f32>,
        weight_decay: Option<f32>,
        amsgrad: Option<bool>,
    ) -> PyResult<(Self, PyOptimizer)> {
        let lr = lr.unwrap_or(0.001);
        let betas = betas.unwrap_or((0.9, 0.999));
        let eps = eps.unwrap_or(1e-8);
        let weight_decay = weight_decay.unwrap_or(0.01);
        let amsgrad = amsgrad.unwrap_or(false);

        // Extract tensor parameters and wrap each in Arc<RwLock<_>> so the
        // core optimizer can share mutable access with the Python-side tensors.
        let tensor_params = extract_parameters(params.clone())?;
        let wrapped_params: Vec<Arc<RwLock<_>>> = tensor_params
            .into_iter()
            .map(|tensor| Arc::new(RwLock::new(tensor)))
            .collect();
        let adamw = AdamW::new(
            wrapped_params,
            Some(lr),
            Some(betas),
            Some(eps),
            Some(weight_decay),
            amsgrad,
        );

        // Create parameter groups; `into_pyobject` is infallible for these
        // primitive types, so the unwraps cannot fail.
        let mut param_group_data = HashMap::new();
        Python::attach(|py| {
            param_group_data.insert(
                "betas".to_string(),
                betas.into_pyobject(py).unwrap().into_any().unbind(),
            );
            param_group_data.insert(
                "eps".to_string(),
                eps.into_pyobject(py).unwrap().into_any().unbind(),
            );
            param_group_data.insert(
                "weight_decay".to_string(),
                weight_decay.into_pyobject(py).unwrap().into_any().unbind(),
            );
            param_group_data.insert(
                "amsgrad".to_string(),
                PyBool::new(py, amsgrad).to_owned().into(),
            );
        });

        let param_groups = vec![create_param_group(params, lr, param_group_data)?];

        Ok((
            Self {
                adamw,
                param_groups,
                lr,
                betas,
                eps,
                weight_decay,
                amsgrad,
            },
            PyOptimizer {},
        ))
    }

    /// Perform a single optimization step
    fn step(&mut self) -> PyResult<()> {
        self.adamw.step().map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                "Optimizer step failed: {}",
                e
            ))
        })?;
        Ok(())
    }

    /// Zero out gradients of all parameters
    ///
    /// `set_to_none` is accepted for API compatibility but is not yet
    /// forwarded to the core optimizer, which always zeroes in place.
    #[pyo3(signature = (set_to_none=None))]
    fn zero_grad(&mut self, set_to_none: Option<bool>) {
        let _set_to_none = set_to_none.unwrap_or(false);
        self.adamw.zero_grad();
    }

    /// Get parameter groups
    fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
        // Manual clone since Py<PyAny> doesn't implement Clone
        Python::attach(|py| {
            let cloned_groups = self
                .param_groups
                .iter()
                .map(|group| {
                    group
                        .iter()
                        .map(|(k, v)| (k.clone(), v.clone_ref(py)))
                        .collect()
                })
                .collect();
            Ok(cloned_groups)
        })
    }

    /// Get current state
    fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
        // Placeholder state: the real per-parameter buffers (step count,
        // exp_avg, exp_avg_sq) live inside the core optimizer and are not
        // yet exposed, so empty-dict strings are returned for now.
        let mut state = HashMap::new();
        Python::attach(|py| {
            state.insert(
                "step".to_string(),
                0i64.into_pyobject(py).unwrap().into_any().unbind(),
            );
            state.insert(
                "exp_avg".to_string(),
                "{}".into_pyobject(py).unwrap().into_any().unbind(),
            );
            state.insert(
                "exp_avg_sq".to_string(),
                "{}".into_pyobject(py).unwrap().into_any().unbind(),
            );
            if self.amsgrad {
                state.insert(
                    "max_exp_avg_sq".to_string(),
                    "{}".into_pyobject(py).unwrap().into_any().unbind(),
                );
            }
        });
        Ok(state)
    }

    /// String representation
    fn __repr__(&self) -> String {
        format!(
            "AdamW(lr={}, betas={:?}, eps={}, weight_decay={}, amsgrad={})",
            self.lr, self.betas, self.eps, self.weight_decay, self.amsgrad
        )
    }

    /// Get defaults
    fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
        let mut defaults = HashMap::new();
        Python::attach(|py| {
            defaults.insert(
                "lr".to_string(),
                self.lr.into_pyobject(py).unwrap().into_any().unbind(),
            );
            defaults.insert(
                "betas".to_string(),
                self.betas.into_pyobject(py).unwrap().into_any().unbind(),
            );
            defaults.insert(
                "eps".to_string(),
                self.eps.into_pyobject(py).unwrap().into_any().unbind(),
            );
            defaults.insert(
                "weight_decay".to_string(),
                self.weight_decay
                    .into_pyobject(py)
                    .unwrap()
                    .into_any()
                    .unbind(),
            );
            defaults.insert(
                "amsgrad".to_string(),
                PyBool::new(py, self.amsgrad).to_owned().into(),
            );
        });
        Ok(defaults)
    }

    /// Get learning rate
    #[getter]
    fn lr(&self) -> f32 {
        self.lr
    }

    /// Set learning rate
    #[setter]
    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
        Python::attach(|py| {
            for param_group in &mut self.param_groups {
                param_group.insert(
                    "lr".to_string(),
                    lr.into_pyobject(py).unwrap().into_any().unbind(),
                );
            }
        });
    }

    /// Get betas
    #[getter]
    fn betas(&self) -> (f32, f32) {
        self.betas
    }

    /// Get eps
    #[getter]
    fn eps(&self) -> f32 {
        self.eps
    }

    /// Get weight decay
    #[getter]
    fn weight_decay(&self) -> f32 {
        self.weight_decay
    }

    /// Get amsgrad flag
    #[getter]
    fn amsgrad(&self) -> bool {
        self.amsgrad
    }
}