torsh_python/optim/
adagrad.rs1use super::base::{create_param_group, extract_parameters, PyOptimizer};
4use crate::{error::PyResult, tensor::PyTensor};
5use parking_lot::RwLock;
6use pyo3::prelude::*;
7use pyo3::types::PyAny;
8use std::collections::HashMap;
9use std::sync::Arc;
10use torsh_optim::{adagrad::AdaGrad, Optimizer};
11
12#[pyclass(name = "Adagrad", extends = PyOptimizer)]
14pub struct PyAdaGrad {
15 adagrad: AdaGrad,
16 param_groups: Vec<HashMap<String, Py<PyAny>>>,
17 lr: f32,
18 lr_decay: f32,
19 weight_decay: f32,
20 eps: f32,
21}
22
23#[pymethods]
24impl PyAdaGrad {
25 #[new]
26 fn new(
27 params: Vec<PyTensor>,
28 lr: Option<f32>,
29 lr_decay: Option<f32>,
30 weight_decay: Option<f32>,
31 eps: Option<f32>,
32 ) -> (Self, PyOptimizer) {
33 let lr = lr.unwrap_or(0.01);
34 let lr_decay = lr_decay.unwrap_or(0.0);
35 let weight_decay = weight_decay.unwrap_or(0.0);
36 let eps = eps.unwrap_or(1e-10);
37
38 let tensor_params = extract_parameters(params.clone()).unwrap();
40 let wrapped_params: Vec<Arc<RwLock<_>>> = tensor_params
41 .into_iter()
42 .map(|tensor| Arc::new(RwLock::new(tensor)))
43 .collect();
44 let adagrad = AdaGrad::new(
45 wrapped_params,
46 Some(lr),
47 Some(lr_decay),
48 Some(weight_decay),
49 Some(0.0),
50 Some(eps),
51 );
52
53 let mut param_group_data = HashMap::new();
55 Python::attach(|py| {
56 param_group_data.insert(
57 "lr_decay".to_string(),
58 lr_decay.into_pyobject(py).unwrap().into_any().unbind(),
59 );
60 param_group_data.insert(
61 "weight_decay".to_string(),
62 weight_decay.into_pyobject(py).unwrap().into_any().unbind(),
63 );
64 param_group_data.insert(
65 "eps".to_string(),
66 eps.into_pyobject(py).unwrap().into_any().unbind(),
67 );
68 });
69
70 let param_groups = vec![create_param_group(params, lr, param_group_data).unwrap()];
71
72 (
73 Self {
74 adagrad,
75 param_groups,
76 lr,
77 lr_decay,
78 weight_decay,
79 eps,
80 },
81 PyOptimizer {},
82 )
83 }
84
85 fn step(&mut self) -> PyResult<()> {
87 self.adagrad.step().map_err(|e| {
88 PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
89 "Optimizer step failed: {}",
90 e
91 ))
92 })?;
93 Ok(())
94 }
95
96 fn zero_grad(&mut self, set_to_none: Option<bool>) {
98 let _set_to_none = set_to_none.unwrap_or(false);
99 self.adagrad.zero_grad();
100 }
101
102 fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
104 Python::attach(|py| {
106 let cloned_groups = self
107 .param_groups
108 .iter()
109 .map(|group| {
110 group
111 .iter()
112 .map(|(k, v)| (k.clone(), v.clone_ref(py)))
113 .collect()
114 })
115 .collect();
116 Ok(cloned_groups)
117 })
118 }
119
120 fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
122 let mut state = HashMap::new();
123 Python::attach(|py| {
124 state.insert(
125 "step".to_string(),
126 0i64.into_pyobject(py).unwrap().into_any().unbind(),
127 );
128 state.insert(
129 "sum".to_string(),
130 "{}".into_pyobject(py).unwrap().into_any().unbind(),
131 );
132 });
133 Ok(state)
134 }
135
136 fn __repr__(&self) -> String {
138 format!(
139 "Adagrad(lr={}, lr_decay={}, eps={}, weight_decay={})",
140 self.lr, self.lr_decay, self.eps, self.weight_decay
141 )
142 }
143
144 fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
146 let mut defaults = HashMap::new();
147 Python::attach(|py| {
148 defaults.insert(
149 "lr".to_string(),
150 self.lr.into_pyobject(py).unwrap().into_any().unbind(),
151 );
152 defaults.insert(
153 "lr_decay".to_string(),
154 self.lr_decay.into_pyobject(py).unwrap().into_any().unbind(),
155 );
156 defaults.insert(
157 "weight_decay".to_string(),
158 self.weight_decay
159 .into_pyobject(py)
160 .unwrap()
161 .into_any()
162 .unbind(),
163 );
164 defaults.insert(
165 "eps".to_string(),
166 self.eps.into_pyobject(py).unwrap().into_any().unbind(),
167 );
168 });
169 Ok(defaults)
170 }
171
172 #[getter]
174 fn lr(&self) -> f32 {
175 self.lr
176 }
177
178 #[setter]
180 fn set_lr(&mut self, lr: f32) {
181 self.lr = lr;
182 Python::attach(|py| {
183 for param_group in &mut self.param_groups {
184 param_group.insert(
185 "lr".to_string(),
186 lr.into_pyobject(py).unwrap().into_any().unbind(),
187 );
188 }
189 });
190 }
191
192 #[getter]
194 fn lr_decay(&self) -> f32 {
195 self.lr_decay
196 }
197
198 #[getter]
200 fn weight_decay(&self) -> f32 {
201 self.weight_decay
202 }
203
204 #[getter]
206 fn eps(&self) -> f32 {
207 self.eps
208 }
209}