Skip to main content

torsh_python/optim/
adagrad.rs

1//! AdaGrad optimizer
2
3use super::base::{create_param_group, extract_parameters, PyOptimizer};
4use crate::{error::PyResult, tensor::PyTensor};
5use parking_lot::RwLock;
6use pyo3::prelude::*;
7use pyo3::types::PyAny;
8use std::collections::HashMap;
9use std::sync::Arc;
10use torsh_optim::{adagrad::AdaGrad, Optimizer};
11
12/// AdaGrad optimizer - Adaptive Gradient Algorithm
13#[pyclass(name = "Adagrad", extends = PyOptimizer)]
14pub struct PyAdaGrad {
15    adagrad: AdaGrad,
16    param_groups: Vec<HashMap<String, Py<PyAny>>>,
17    lr: f32,
18    lr_decay: f32,
19    weight_decay: f32,
20    eps: f32,
21}
22
23#[pymethods]
24impl PyAdaGrad {
25    #[new]
26    fn new(
27        params: Vec<PyTensor>,
28        lr: Option<f32>,
29        lr_decay: Option<f32>,
30        weight_decay: Option<f32>,
31        eps: Option<f32>,
32    ) -> (Self, PyOptimizer) {
33        let lr = lr.unwrap_or(0.01);
34        let lr_decay = lr_decay.unwrap_or(0.0);
35        let weight_decay = weight_decay.unwrap_or(0.0);
36        let eps = eps.unwrap_or(1e-10);
37
38        // Extract tensor parameters and wrap in Arc<RwLock>
39        let tensor_params =
40            extract_parameters(params.clone()).expect("parameter extraction should succeed");
41        let wrapped_params: Vec<Arc<RwLock<_>>> = tensor_params
42            .into_iter()
43            .map(|tensor| Arc::new(RwLock::new(tensor)))
44            .collect();
45        let adagrad = AdaGrad::new(
46            wrapped_params,
47            Some(lr),
48            Some(lr_decay),
49            Some(weight_decay),
50            Some(0.0),
51            Some(eps),
52        );
53
54        // Create parameter groups
55        let mut param_group_data = HashMap::new();
56        Python::attach(|py| {
57            param_group_data.insert(
58                "lr_decay".to_string(),
59                lr_decay
60                    .into_pyobject(py)
61                    .expect("Python object conversion should succeed")
62                    .into_any()
63                    .unbind(),
64            );
65            param_group_data.insert(
66                "weight_decay".to_string(),
67                weight_decay
68                    .into_pyobject(py)
69                    .expect("Python object conversion should succeed")
70                    .into_any()
71                    .unbind(),
72            );
73            param_group_data.insert(
74                "eps".to_string(),
75                eps.into_pyobject(py)
76                    .expect("Python object conversion should succeed")
77                    .into_any()
78                    .unbind(),
79            );
80        });
81
82        let param_groups = vec![create_param_group(params, lr, param_group_data)
83            .expect("param group creation should succeed")];
84
85        (
86            Self {
87                adagrad,
88                param_groups,
89                lr,
90                lr_decay,
91                weight_decay,
92                eps,
93            },
94            PyOptimizer {},
95        )
96    }
97
98    /// Perform a single optimization step
99    fn step(&mut self) -> PyResult<()> {
100        self.adagrad.step().map_err(|e| {
101            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
102                "Optimizer step failed: {}",
103                e
104            ))
105        })?;
106        Ok(())
107    }
108
109    /// Zero out gradients of all parameters
110    fn zero_grad(&mut self, set_to_none: Option<bool>) {
111        let _set_to_none = set_to_none.unwrap_or(false);
112        self.adagrad.zero_grad();
113    }
114
115    /// Get parameter groups
116    fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
117        // Manual clone since Py<PyAny> doesn't implement Clone
118        Python::attach(|py| {
119            let cloned_groups = self
120                .param_groups
121                .iter()
122                .map(|group| {
123                    group
124                        .iter()
125                        .map(|(k, v)| (k.clone(), v.clone_ref(py)))
126                        .collect()
127                })
128                .collect();
129            Ok(cloned_groups)
130        })
131    }
132
133    /// Get current state
134    fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
135        let mut state = HashMap::new();
136        Python::attach(|py| {
137            state.insert(
138                "step".to_string(),
139                0i64.into_pyobject(py)
140                    .expect("Python object conversion should succeed")
141                    .into_any()
142                    .unbind(),
143            );
144            state.insert(
145                "sum".to_string(),
146                "{}".into_pyobject(py)
147                    .expect("Python object conversion should succeed")
148                    .into_any()
149                    .unbind(),
150            );
151        });
152        Ok(state)
153    }
154
155    /// String representation
156    fn __repr__(&self) -> String {
157        format!(
158            "Adagrad(lr={}, lr_decay={}, eps={}, weight_decay={})",
159            self.lr, self.lr_decay, self.eps, self.weight_decay
160        )
161    }
162
163    /// Get defaults
164    fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
165        let mut defaults = HashMap::new();
166        Python::attach(|py| {
167            defaults.insert(
168                "lr".to_string(),
169                self.lr
170                    .into_pyobject(py)
171                    .expect("Python object conversion should succeed")
172                    .into_any()
173                    .unbind(),
174            );
175            defaults.insert(
176                "lr_decay".to_string(),
177                self.lr_decay
178                    .into_pyobject(py)
179                    .expect("Python object conversion should succeed")
180                    .into_any()
181                    .unbind(),
182            );
183            defaults.insert(
184                "weight_decay".to_string(),
185                self.weight_decay
186                    .into_pyobject(py)
187                    .expect("Python object conversion should succeed")
188                    .into_any()
189                    .unbind(),
190            );
191            defaults.insert(
192                "eps".to_string(),
193                self.eps
194                    .into_pyobject(py)
195                    .expect("Python object conversion should succeed")
196                    .into_any()
197                    .unbind(),
198            );
199        });
200        Ok(defaults)
201    }
202
203    /// Get learning rate
204    #[getter]
205    fn lr(&self) -> f32 {
206        self.lr
207    }
208
209    /// Set learning rate
210    #[setter]
211    fn set_lr(&mut self, lr: f32) {
212        self.lr = lr;
213        Python::attach(|py| {
214            for param_group in &mut self.param_groups {
215                param_group.insert(
216                    "lr".to_string(),
217                    lr.into_pyobject(py)
218                        .expect("Python object conversion should succeed")
219                        .into_any()
220                        .unbind(),
221                );
222            }
223        });
224    }
225
226    /// Get learning rate decay
227    #[getter]
228    fn lr_decay(&self) -> f32 {
229        self.lr_decay
230    }
231
232    /// Get weight decay
233    #[getter]
234    fn weight_decay(&self) -> f32 {
235        self.weight_decay
236    }
237
238    /// Get epsilon
239    #[getter]
240    fn eps(&self) -> f32 {
241        self.eps
242    }
243}