torsh_python/optim/
adagrad.rs1use super::base::{create_param_group, extract_parameters, PyOptimizer};
4use crate::{error::PyResult, tensor::PyTensor};
5use parking_lot::RwLock;
6use pyo3::prelude::*;
7use pyo3::types::PyAny;
8use std::collections::HashMap;
9use std::sync::Arc;
10use torsh_optim::{adagrad::AdaGrad, Optimizer};
11
12#[pyclass(name = "Adagrad", extends = PyOptimizer)]
14pub struct PyAdaGrad {
15 adagrad: AdaGrad,
16 param_groups: Vec<HashMap<String, Py<PyAny>>>,
17 lr: f32,
18 lr_decay: f32,
19 weight_decay: f32,
20 eps: f32,
21}
22
23#[pymethods]
24impl PyAdaGrad {
25 #[new]
26 fn new(
27 params: Vec<PyTensor>,
28 lr: Option<f32>,
29 lr_decay: Option<f32>,
30 weight_decay: Option<f32>,
31 eps: Option<f32>,
32 ) -> (Self, PyOptimizer) {
33 let lr = lr.unwrap_or(0.01);
34 let lr_decay = lr_decay.unwrap_or(0.0);
35 let weight_decay = weight_decay.unwrap_or(0.0);
36 let eps = eps.unwrap_or(1e-10);
37
38 let tensor_params =
40 extract_parameters(params.clone()).expect("parameter extraction should succeed");
41 let wrapped_params: Vec<Arc<RwLock<_>>> = tensor_params
42 .into_iter()
43 .map(|tensor| Arc::new(RwLock::new(tensor)))
44 .collect();
45 let adagrad = AdaGrad::new(
46 wrapped_params,
47 Some(lr),
48 Some(lr_decay),
49 Some(weight_decay),
50 Some(0.0),
51 Some(eps),
52 );
53
54 let mut param_group_data = HashMap::new();
56 Python::attach(|py| {
57 param_group_data.insert(
58 "lr_decay".to_string(),
59 lr_decay
60 .into_pyobject(py)
61 .expect("Python object conversion should succeed")
62 .into_any()
63 .unbind(),
64 );
65 param_group_data.insert(
66 "weight_decay".to_string(),
67 weight_decay
68 .into_pyobject(py)
69 .expect("Python object conversion should succeed")
70 .into_any()
71 .unbind(),
72 );
73 param_group_data.insert(
74 "eps".to_string(),
75 eps.into_pyobject(py)
76 .expect("Python object conversion should succeed")
77 .into_any()
78 .unbind(),
79 );
80 });
81
82 let param_groups = vec![create_param_group(params, lr, param_group_data)
83 .expect("param group creation should succeed")];
84
85 (
86 Self {
87 adagrad,
88 param_groups,
89 lr,
90 lr_decay,
91 weight_decay,
92 eps,
93 },
94 PyOptimizer {},
95 )
96 }
97
98 fn step(&mut self) -> PyResult<()> {
100 self.adagrad.step().map_err(|e| {
101 PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
102 "Optimizer step failed: {}",
103 e
104 ))
105 })?;
106 Ok(())
107 }
108
109 fn zero_grad(&mut self, set_to_none: Option<bool>) {
111 let _set_to_none = set_to_none.unwrap_or(false);
112 self.adagrad.zero_grad();
113 }
114
115 fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
117 Python::attach(|py| {
119 let cloned_groups = self
120 .param_groups
121 .iter()
122 .map(|group| {
123 group
124 .iter()
125 .map(|(k, v)| (k.clone(), v.clone_ref(py)))
126 .collect()
127 })
128 .collect();
129 Ok(cloned_groups)
130 })
131 }
132
133 fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
135 let mut state = HashMap::new();
136 Python::attach(|py| {
137 state.insert(
138 "step".to_string(),
139 0i64.into_pyobject(py)
140 .expect("Python object conversion should succeed")
141 .into_any()
142 .unbind(),
143 );
144 state.insert(
145 "sum".to_string(),
146 "{}".into_pyobject(py)
147 .expect("Python object conversion should succeed")
148 .into_any()
149 .unbind(),
150 );
151 });
152 Ok(state)
153 }
154
155 fn __repr__(&self) -> String {
157 format!(
158 "Adagrad(lr={}, lr_decay={}, eps={}, weight_decay={})",
159 self.lr, self.lr_decay, self.eps, self.weight_decay
160 )
161 }
162
163 fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
165 let mut defaults = HashMap::new();
166 Python::attach(|py| {
167 defaults.insert(
168 "lr".to_string(),
169 self.lr
170 .into_pyobject(py)
171 .expect("Python object conversion should succeed")
172 .into_any()
173 .unbind(),
174 );
175 defaults.insert(
176 "lr_decay".to_string(),
177 self.lr_decay
178 .into_pyobject(py)
179 .expect("Python object conversion should succeed")
180 .into_any()
181 .unbind(),
182 );
183 defaults.insert(
184 "weight_decay".to_string(),
185 self.weight_decay
186 .into_pyobject(py)
187 .expect("Python object conversion should succeed")
188 .into_any()
189 .unbind(),
190 );
191 defaults.insert(
192 "eps".to_string(),
193 self.eps
194 .into_pyobject(py)
195 .expect("Python object conversion should succeed")
196 .into_any()
197 .unbind(),
198 );
199 });
200 Ok(defaults)
201 }
202
203 #[getter]
205 fn lr(&self) -> f32 {
206 self.lr
207 }
208
209 #[setter]
211 fn set_lr(&mut self, lr: f32) {
212 self.lr = lr;
213 Python::attach(|py| {
214 for param_group in &mut self.param_groups {
215 param_group.insert(
216 "lr".to_string(),
217 lr.into_pyobject(py)
218 .expect("Python object conversion should succeed")
219 .into_any()
220 .unbind(),
221 );
222 }
223 });
224 }
225
226 #[getter]
228 fn lr_decay(&self) -> f32 {
229 self.lr_decay
230 }
231
232 #[getter]
234 fn weight_decay(&self) -> f32 {
235 self.weight_decay
236 }
237
238 #[getter]
240 fn eps(&self) -> f32 {
241 self.eps
242 }
243}