torsh_python/optim/adagrad.rs

//! AdaGrad optimizer
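//!
//! A minimal usage sketch from the Python side. The module path and the
//! `params`/`loss` objects below are stand-ins, not confirmed by this file;
//! only the `Adagrad` constructor arguments and defaults come from the code:
//!
//! ```python
//! import torsh
//!
//! opt = torsh.optim.Adagrad(params, lr=0.01, lr_decay=0.0,
//!                           weight_decay=0.0, eps=1e-10)
//! loss.backward()   # populate gradients (stand-in)
//! opt.step()        # apply one AdaGrad update
//! opt.zero_grad()   # clear gradients for the next iteration
//! ```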

use super::base::{create_param_group, extract_parameters, PyOptimizer};
use crate::{error::PyResult, tensor::PyTensor};
use parking_lot::RwLock;
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyDict};
use std::collections::HashMap;
use std::sync::Arc;
use torsh_optim::{adagrad::AdaGrad, Optimizer};

/// AdaGrad optimizer - Adaptive Gradient Algorithm
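///
/// Accumulates the per-parameter sum of squared gradients and divides each
/// step by its square root, so frequently-updated parameters take smaller
/// steps. Sketch of the standard rule (PyTorch-style formulation, assumed
/// to match `torsh_optim`; with `weight_decay`, `g_t` is the gradient plus
/// `weight_decay * theta`):
///
/// ```text
/// sum_t   = sum_{t-1} + g_t^2
/// lr_t    = lr / (1 + (t - 1) * lr_decay)
/// theta_t = theta_{t-1} - lr_t * g_t / (sqrt(sum_t) + eps)
/// ```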
#[pyclass(name = "Adagrad", extends = PyOptimizer)]
pub struct PyAdaGrad {
    /// Underlying Rust optimizer that owns the parameter state
    adagrad: AdaGrad,
    /// Python-visible parameter groups (PyTorch-style dicts)
    param_groups: Vec<HashMap<String, Py<PyAny>>>,
    lr: f32,
    lr_decay: f32,
    weight_decay: f32,
    eps: f32,
}

#[pymethods]
impl PyAdaGrad {
    #[new]
    #[pyo3(signature = (params, lr=None, lr_decay=None, weight_decay=None, eps=None))]
    fn new(
        params: Vec<PyTensor>,
        lr: Option<f32>,
        lr_decay: Option<f32>,
        weight_decay: Option<f32>,
        eps: Option<f32>,
    ) -> PyResult<(Self, PyOptimizer)> {
        let lr = lr.unwrap_or(0.01);
        let lr_decay = lr_decay.unwrap_or(0.0);
        let weight_decay = weight_decay.unwrap_or(0.0);
        let eps = eps.unwrap_or(1e-10);

        // Extract tensor parameters and wrap them in Arc<RwLock> so the
        // optimizer can share and mutate them
        let tensor_params = extract_parameters(params.clone())?;
        let wrapped_params: Vec<Arc<RwLock<_>>> = tensor_params
            .into_iter()
            .map(|tensor| Arc::new(RwLock::new(tensor)))
            .collect();
        // Fifth argument is presumably the initial accumulator value;
        // kept at 0.0 to match the usual AdaGrad default
        let adagrad = AdaGrad::new(
            wrapped_params,
            Some(lr),
            Some(lr_decay),
            Some(weight_decay),
            Some(0.0),
            Some(eps),
        );

        // Create the Python-visible parameter group
        let mut param_group_data = HashMap::new();
        Python::attach(|py| {
            param_group_data.insert(
                "lr_decay".to_string(),
                lr_decay.into_pyobject(py).unwrap().into_any().unbind(),
            );
            param_group_data.insert(
                "weight_decay".to_string(),
                weight_decay.into_pyobject(py).unwrap().into_any().unbind(),
            );
            param_group_data.insert(
                "eps".to_string(),
                eps.into_pyobject(py).unwrap().into_any().unbind(),
            );
        });

        let param_groups = vec![create_param_group(params, lr, param_group_data)?];

        Ok((
            Self {
                adagrad,
                param_groups,
                lr,
                lr_decay,
                weight_decay,
                eps,
            },
            PyOptimizer {},
        ))
    }

    /// Perform a single optimization step
    fn step(&mut self) -> PyResult<()> {
        self.adagrad.step().map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                "Optimizer step failed: {}",
                e
            ))
        })?;
        Ok(())
    }

    /// Zero out gradients of all parameters
    #[pyo3(signature = (set_to_none=None))]
    fn zero_grad(&mut self, set_to_none: Option<bool>) {
        // `set_to_none` is accepted for PyTorch API compatibility but is
        // currently ignored; gradients are always zeroed in place
        let _set_to_none = set_to_none.unwrap_or(false);
        self.adagrad.zero_grad();
    }

    /// Get parameter groups
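    ///
    /// Mirrors PyTorch's layout: a list with one dict of hyperparameters per
    /// group. Python sketch (key names as built in `new` above; `opt` is an
    /// `Adagrad` instance):
    ///
    /// ```python
    /// for group in opt.param_groups():
    ///     print(group["lr"], group["lr_decay"], group["eps"])
    /// ```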
    fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
        // Manual clone via clone_ref, since Py<PyAny> does not implement Clone
        Python::attach(|py| {
            let cloned_groups = self
                .param_groups
                .iter()
                .map(|group| {
                    group
                        .iter()
                        .map(|(k, v)| (k.clone(), v.clone_ref(py)))
                        .collect()
                })
                .collect();
            Ok(cloned_groups)
        })
    }

    /// Get current state
    fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
        // Placeholder: per-parameter state (step count, accumulated sum of
        // squared gradients) is not yet exposed from the Rust optimizer
        let mut state = HashMap::new();
        Python::attach(|py| {
            state.insert(
                "step".to_string(),
                0i64.into_pyobject(py).unwrap().into_any().unbind(),
            );
            state.insert(
                "sum".to_string(),
                PyDict::new(py).into_any().unbind(),
            );
        });
        Ok(state)
    }

    /// String representation
    fn __repr__(&self) -> String {
        format!(
            "Adagrad(lr={}, lr_decay={}, eps={}, weight_decay={})",
            self.lr, self.lr_decay, self.eps, self.weight_decay
        )
    }

    /// Get defaults
    fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
        let mut defaults = HashMap::new();
        Python::attach(|py| {
            defaults.insert(
                "lr".to_string(),
                self.lr.into_pyobject(py).unwrap().into_any().unbind(),
            );
            defaults.insert(
                "lr_decay".to_string(),
                self.lr_decay.into_pyobject(py).unwrap().into_any().unbind(),
            );
            defaults.insert(
                "weight_decay".to_string(),
                self.weight_decay
                    .into_pyobject(py)
                    .unwrap()
                    .into_any()
                    .unbind(),
            );
            defaults.insert(
                "eps".to_string(),
                self.eps.into_pyobject(py).unwrap().into_any().unbind(),
            );
        });
        Ok(defaults)
    }

    /// Get learning rate
    #[getter]
    fn lr(&self) -> f32 {
        self.lr
    }

    /// Set learning rate
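    ///
    /// Writing to `opt.lr` also refreshes the `"lr"` entry in every parameter
    /// group. Illustrative Python sketch (the schedule and `train_one_epoch`
    /// are made up):
    ///
    /// ```python
    /// for epoch in range(10):
    ///     train_one_epoch(opt)          # stand-in training loop
    ///     opt.lr = 0.01 / (1 + epoch)   # simple 1/t decay
    /// ```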
    #[setter]
    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
        // Note: this refreshes the Python-visible param_groups only; the
        // underlying Rust `AdaGrad` keeps the rate it was constructed with
        Python::attach(|py| {
            for param_group in &mut self.param_groups {
                param_group.insert(
                    "lr".to_string(),
                    lr.into_pyobject(py).unwrap().into_any().unbind(),
                );
            }
        });
    }

    /// Get learning rate decay
    #[getter]
    fn lr_decay(&self) -> f32 {
        self.lr_decay
    }

    /// Get weight decay
    #[getter]
    fn weight_decay(&self) -> f32 {
        self.weight_decay
    }

    /// Get epsilon
    #[getter]
    fn eps(&self) -> f32 {
        self.eps
    }
}