Skip to main content

torsh_python/optim/
adam.rs

1//! Adam and AdamW optimizers
2
3use super::base::{create_param_group, extract_parameters, PyOptimizer};
4use crate::{error::PyResult, tensor::PyTensor};
5use parking_lot::RwLock;
6use pyo3::prelude::*;
7use pyo3::types::{PyAny, PyBool};
8use std::collections::HashMap;
9use std::sync::Arc;
10use torsh_optim::{Adam, AdamW, Optimizer};
11
12/// Adam optimizer - Adaptive Moment Estimation
13#[pyclass(name = "Adam", extends = PyOptimizer)]
14pub struct PyAdam {
15    adam: Adam,
16    param_groups: Vec<HashMap<String, Py<PyAny>>>,
17    lr: f32,
18    betas: (f32, f32),
19    eps: f32,
20    weight_decay: f32,
21    amsgrad: bool,
22}
23
24#[pymethods]
25impl PyAdam {
26    #[new]
27    fn new(
28        params: Vec<PyTensor>,
29        lr: Option<f32>,
30        betas: Option<(f32, f32)>,
31        eps: Option<f32>,
32        weight_decay: Option<f32>,
33        amsgrad: Option<bool>,
34    ) -> (Self, PyOptimizer) {
35        let lr = lr.unwrap_or(0.001);
36        let betas = betas.unwrap_or((0.9, 0.999));
37        let eps = eps.unwrap_or(1e-8);
38        let weight_decay = weight_decay.unwrap_or(0.0);
39        let amsgrad = amsgrad.unwrap_or(false);
40
41        // Extract tensor parameters and wrap in Arc<RwLock>
42        let tensor_params =
43            extract_parameters(params.clone()).expect("parameter extraction should succeed");
44        let wrapped_params: Vec<Arc<RwLock<_>>> = tensor_params
45            .into_iter()
46            .map(|tensor| Arc::new(RwLock::new(tensor)))
47            .collect();
48        let adam = Adam::new(
49            wrapped_params,
50            Some(lr),
51            Some(betas),
52            Some(eps),
53            Some(weight_decay),
54            amsgrad,
55        );
56
57        // Create parameter groups
58        let mut param_group_data = HashMap::new();
59        Python::attach(|py| {
60            param_group_data.insert(
61                "betas".to_string(),
62                betas
63                    .into_pyobject(py)
64                    .expect("Python object conversion should succeed")
65                    .into_any()
66                    .unbind(),
67            );
68            param_group_data.insert(
69                "eps".to_string(),
70                eps.into_pyobject(py)
71                    .expect("Python object conversion should succeed")
72                    .into_any()
73                    .unbind(),
74            );
75            param_group_data.insert(
76                "weight_decay".to_string(),
77                weight_decay
78                    .into_pyobject(py)
79                    .expect("Python object conversion should succeed")
80                    .into_any()
81                    .unbind(),
82            );
83            param_group_data.insert(
84                "amsgrad".to_string(),
85                PyBool::new(py, amsgrad).to_owned().into(),
86            );
87        });
88
89        let param_groups = vec![create_param_group(params, lr, param_group_data)
90            .expect("param group creation should succeed")];
91
92        (
93            Self {
94                adam,
95                param_groups,
96                lr,
97                betas,
98                eps,
99                weight_decay,
100                amsgrad,
101            },
102            PyOptimizer {},
103        )
104    }
105
106    /// Perform a single optimization step
107    fn step(&mut self) -> PyResult<()> {
108        self.adam.step().map_err(|e| {
109            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
110                "Optimizer step failed: {}",
111                e
112            ))
113        })?;
114        Ok(())
115    }
116
117    /// Zero out gradients of all parameters
118    fn zero_grad(&mut self, set_to_none: Option<bool>) {
119        let _set_to_none = set_to_none.unwrap_or(false);
120        self.adam.zero_grad();
121    }
122
123    /// Get parameter groups
124    fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
125        // Manual clone since Py<PyAny> doesn't implement Clone
126        Python::attach(|py| {
127            let cloned_groups = self
128                .param_groups
129                .iter()
130                .map(|group| {
131                    group
132                        .iter()
133                        .map(|(k, v)| (k.clone(), v.clone_ref(py)))
134                        .collect()
135                })
136                .collect();
137            Ok(cloned_groups)
138        })
139    }
140
141    /// Get current state
142    fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
143        let mut state = HashMap::new();
144        Python::attach(|py| {
145            state.insert(
146                "step".to_string(),
147                0i64.into_pyobject(py)
148                    .expect("Python object conversion should succeed")
149                    .into_any()
150                    .unbind(),
151            );
152            state.insert(
153                "exp_avg".to_string(),
154                "{}".into_pyobject(py)
155                    .expect("Python object conversion should succeed")
156                    .into_any()
157                    .unbind(),
158            );
159            state.insert(
160                "exp_avg_sq".to_string(),
161                "{}".into_pyobject(py)
162                    .expect("Python object conversion should succeed")
163                    .into_any()
164                    .unbind(),
165            );
166            if self.amsgrad {
167                state.insert(
168                    "max_exp_avg_sq".to_string(),
169                    "{}".into_pyobject(py)
170                        .expect("Python object conversion should succeed")
171                        .into_any()
172                        .unbind(),
173                );
174            }
175        });
176        Ok(state)
177    }
178
179    /// String representation
180    fn __repr__(&self) -> String {
181        format!(
182            "Adam(lr={}, betas={:?}, eps={}, weight_decay={}, amsgrad={})",
183            self.lr, self.betas, self.eps, self.weight_decay, self.amsgrad
184        )
185    }
186
187    /// Get defaults
188    fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
189        let mut defaults = HashMap::new();
190        Python::attach(|py| {
191            defaults.insert(
192                "lr".to_string(),
193                self.lr
194                    .into_pyobject(py)
195                    .expect("Python object conversion should succeed")
196                    .into_any()
197                    .unbind(),
198            );
199            defaults.insert(
200                "betas".to_string(),
201                self.betas
202                    .into_pyobject(py)
203                    .expect("Python object conversion should succeed")
204                    .into_any()
205                    .unbind(),
206            );
207            defaults.insert(
208                "eps".to_string(),
209                self.eps
210                    .into_pyobject(py)
211                    .expect("Python object conversion should succeed")
212                    .into_any()
213                    .unbind(),
214            );
215            defaults.insert(
216                "weight_decay".to_string(),
217                self.weight_decay
218                    .into_pyobject(py)
219                    .expect("Python object conversion should succeed")
220                    .into_any()
221                    .unbind(),
222            );
223            defaults.insert(
224                "amsgrad".to_string(),
225                PyBool::new(py, self.amsgrad).to_owned().into(),
226            );
227        });
228        Ok(defaults)
229    }
230
231    /// Get learning rate
232    #[getter]
233    fn lr(&self) -> f32 {
234        self.lr
235    }
236
237    /// Set learning rate
238    #[setter]
239    fn set_lr(&mut self, lr: f32) {
240        self.lr = lr;
241        Python::attach(|py| {
242            for param_group in &mut self.param_groups {
243                param_group.insert(
244                    "lr".to_string(),
245                    lr.into_pyobject(py)
246                        .expect("Python object conversion should succeed")
247                        .into_any()
248                        .unbind(),
249                );
250            }
251        });
252    }
253
254    /// Get betas
255    #[getter]
256    fn betas(&self) -> (f32, f32) {
257        self.betas
258    }
259
260    /// Get eps
261    #[getter]
262    fn eps(&self) -> f32 {
263        self.eps
264    }
265
266    /// Get weight decay
267    #[getter]
268    fn weight_decay(&self) -> f32 {
269        self.weight_decay
270    }
271
272    /// Get amsgrad flag
273    #[getter]
274    fn amsgrad(&self) -> bool {
275        self.amsgrad
276    }
277}
278
279/// AdamW optimizer - Adam with decoupled weight decay
280#[pyclass(name = "AdamW", extends = PyOptimizer)]
281pub struct PyAdamW {
282    adamw: AdamW,
283    param_groups: Vec<HashMap<String, Py<PyAny>>>,
284    lr: f32,
285    betas: (f32, f32),
286    eps: f32,
287    weight_decay: f32,
288    amsgrad: bool,
289}
290
291#[pymethods]
292impl PyAdamW {
293    #[new]
294    fn new(
295        params: Vec<PyTensor>,
296        lr: Option<f32>,
297        betas: Option<(f32, f32)>,
298        eps: Option<f32>,
299        weight_decay: Option<f32>,
300        amsgrad: Option<bool>,
301    ) -> (Self, PyOptimizer) {
302        let lr = lr.unwrap_or(0.001);
303        let betas = betas.unwrap_or((0.9, 0.999));
304        let eps = eps.unwrap_or(1e-8);
305        let weight_decay = weight_decay.unwrap_or(0.01);
306        let amsgrad = amsgrad.unwrap_or(false);
307
308        // Extract tensor parameters and wrap in Arc<RwLock>
309        let tensor_params =
310            extract_parameters(params.clone()).expect("parameter extraction should succeed");
311        let wrapped_params: Vec<Arc<RwLock<_>>> = tensor_params
312            .into_iter()
313            .map(|tensor| Arc::new(RwLock::new(tensor)))
314            .collect();
315        let adamw = AdamW::new(
316            wrapped_params,
317            Some(lr),
318            Some(betas),
319            Some(eps),
320            Some(weight_decay),
321            amsgrad,
322        );
323
324        // Create parameter groups
325        let mut param_group_data = HashMap::new();
326        Python::attach(|py| {
327            param_group_data.insert(
328                "betas".to_string(),
329                betas
330                    .into_pyobject(py)
331                    .expect("Python object conversion should succeed")
332                    .into_any()
333                    .unbind(),
334            );
335            param_group_data.insert(
336                "eps".to_string(),
337                eps.into_pyobject(py)
338                    .expect("Python object conversion should succeed")
339                    .into_any()
340                    .unbind(),
341            );
342            param_group_data.insert(
343                "weight_decay".to_string(),
344                weight_decay
345                    .into_pyobject(py)
346                    .expect("Python object conversion should succeed")
347                    .into_any()
348                    .unbind(),
349            );
350            param_group_data.insert(
351                "amsgrad".to_string(),
352                PyBool::new(py, amsgrad).to_owned().into(),
353            );
354        });
355
356        let param_groups = vec![create_param_group(params, lr, param_group_data)
357            .expect("param group creation should succeed")];
358
359        (
360            Self {
361                adamw,
362                param_groups,
363                lr,
364                betas,
365                eps,
366                weight_decay,
367                amsgrad,
368            },
369            PyOptimizer {},
370        )
371    }
372
373    /// Perform a single optimization step
374    fn step(&mut self) -> PyResult<()> {
375        self.adamw.step().map_err(|e| {
376            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
377                "Optimizer step failed: {}",
378                e
379            ))
380        })?;
381        Ok(())
382    }
383
384    /// Zero out gradients of all parameters
385    fn zero_grad(&mut self, set_to_none: Option<bool>) {
386        let _set_to_none = set_to_none.unwrap_or(false);
387        self.adamw.zero_grad();
388    }
389
390    /// Get parameter groups
391    fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
392        // Manual clone since Py<PyAny> doesn't implement Clone
393        Python::attach(|py| {
394            let cloned_groups = self
395                .param_groups
396                .iter()
397                .map(|group| {
398                    group
399                        .iter()
400                        .map(|(k, v)| (k.clone(), v.clone_ref(py)))
401                        .collect()
402                })
403                .collect();
404            Ok(cloned_groups)
405        })
406    }
407
408    /// Get current state
409    fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
410        let mut state = HashMap::new();
411        Python::attach(|py| {
412            state.insert(
413                "step".to_string(),
414                0i64.into_pyobject(py)
415                    .expect("Python object conversion should succeed")
416                    .into_any()
417                    .unbind(),
418            );
419            state.insert(
420                "exp_avg".to_string(),
421                "{}".into_pyobject(py)
422                    .expect("Python object conversion should succeed")
423                    .into_any()
424                    .unbind(),
425            );
426            state.insert(
427                "exp_avg_sq".to_string(),
428                "{}".into_pyobject(py)
429                    .expect("Python object conversion should succeed")
430                    .into_any()
431                    .unbind(),
432            );
433            if self.amsgrad {
434                state.insert(
435                    "max_exp_avg_sq".to_string(),
436                    "{}".into_pyobject(py)
437                        .expect("Python object conversion should succeed")
438                        .into_any()
439                        .unbind(),
440                );
441            }
442        });
443        Ok(state)
444    }
445
446    /// String representation
447    fn __repr__(&self) -> String {
448        format!(
449            "AdamW(lr={}, betas={:?}, eps={}, weight_decay={}, amsgrad={})",
450            self.lr, self.betas, self.eps, self.weight_decay, self.amsgrad
451        )
452    }
453
454    /// Get defaults
455    fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
456        let mut defaults = HashMap::new();
457        Python::attach(|py| {
458            defaults.insert(
459                "lr".to_string(),
460                self.lr
461                    .into_pyobject(py)
462                    .expect("Python object conversion should succeed")
463                    .into_any()
464                    .unbind(),
465            );
466            defaults.insert(
467                "betas".to_string(),
468                self.betas
469                    .into_pyobject(py)
470                    .expect("Python object conversion should succeed")
471                    .into_any()
472                    .unbind(),
473            );
474            defaults.insert(
475                "eps".to_string(),
476                self.eps
477                    .into_pyobject(py)
478                    .expect("Python object conversion should succeed")
479                    .into_any()
480                    .unbind(),
481            );
482            defaults.insert(
483                "weight_decay".to_string(),
484                self.weight_decay
485                    .into_pyobject(py)
486                    .expect("Python object conversion should succeed")
487                    .into_any()
488                    .unbind(),
489            );
490            defaults.insert(
491                "amsgrad".to_string(),
492                PyBool::new(py, self.amsgrad).to_owned().into(),
493            );
494        });
495        Ok(defaults)
496    }
497
498    /// Get learning rate
499    #[getter]
500    fn lr(&self) -> f32 {
501        self.lr
502    }
503
504    /// Set learning rate
505    #[setter]
506    fn set_lr(&mut self, lr: f32) {
507        self.lr = lr;
508        Python::attach(|py| {
509            for param_group in &mut self.param_groups {
510                param_group.insert(
511                    "lr".to_string(),
512                    lr.into_pyobject(py)
513                        .expect("Python object conversion should succeed")
514                        .into_any()
515                        .unbind(),
516                );
517            }
518        });
519    }
520
521    /// Get betas
522    #[getter]
523    fn betas(&self) -> (f32, f32) {
524        self.betas
525    }
526
527    /// Get eps
528    #[getter]
529    fn eps(&self) -> f32 {
530        self.eps
531    }
532
533    /// Get weight decay
534    #[getter]
535    fn weight_decay(&self) -> f32 {
536        self.weight_decay
537    }
538
539    /// Get amsgrad flag
540    #[getter]
541    fn amsgrad(&self) -> bool {
542        self.amsgrad
543    }
544}