use super::base::{create_param_group, extract_parameters, PyOptimizer};
use crate::{error::PyResult, tensor::PyTensor};
use parking_lot::RwLock;
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyBool, PyDict};
use std::collections::HashMap;
use std::sync::Arc;
use torsh_optim::{Adam, AdamW, Optimizer};

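/// Python-facing Adam optimizer, backed by `torsh_optim::Adam`.
///
/// A sketch of the intended Python usage (the `torsh.optim` module path and
/// the `model.parameters()` helper are assumptions, not confirmed by this
/// file):
///
/// ```python
/// opt = torsh.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
/// for batch in data:
///     opt.zero_grad()
///     loss = model(batch).sum()
///     loss.backward()
///     opt.step()
/// ```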
#[pyclass(name = "Adam", extends = PyOptimizer)]
pub struct PyAdam {
    adam: Adam,
    param_groups: Vec<HashMap<String, Py<PyAny>>>,
    lr: f32,
    betas: (f32, f32),
    eps: f32,
    weight_decay: f32,
    amsgrad: bool,
}

#[pymethods]
impl PyAdam {
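    /// Builds an Adam optimizer over `params`. Unset arguments fall back to
    /// the PyTorch-style defaults: lr=0.001, betas=(0.9, 0.999), eps=1e-8,
    /// weight_decay=0.0, amsgrad=False.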
    #[new]
    #[pyo3(signature = (params, lr=None, betas=None, eps=None, weight_decay=None, amsgrad=None))]
    fn new(
        params: Vec<PyTensor>,
        lr: Option<f32>,
        betas: Option<(f32, f32)>,
        eps: Option<f32>,
        weight_decay: Option<f32>,
        amsgrad: Option<bool>,
    ) -> PyResult<(Self, PyOptimizer)> {
        let lr = lr.unwrap_or(0.001);
        let betas = betas.unwrap_or((0.9, 0.999));
        let eps = eps.unwrap_or(1e-8);
        let weight_decay = weight_decay.unwrap_or(0.0);
        let amsgrad = amsgrad.unwrap_or(false);

        // Propagate extraction failures as Python exceptions instead of panicking.
        let tensor_params = extract_parameters(params.clone())?;
        let wrapped_params: Vec<Arc<RwLock<_>>> = tensor_params
            .into_iter()
            .map(|tensor| Arc::new(RwLock::new(tensor)))
            .collect();
        let adam = Adam::new(
            wrapped_params,
            Some(lr),
            Some(betas),
            Some(eps),
            Some(weight_decay),
            amsgrad,
        );

        let mut param_group_data = HashMap::new();
        Python::attach(|py| {
            param_group_data.insert(
                "betas".to_string(),
                betas.into_pyobject(py).unwrap().into_any().unbind(),
            );
            param_group_data.insert(
                "eps".to_string(),
                eps.into_pyobject(py).unwrap().into_any().unbind(),
            );
            param_group_data.insert(
                "weight_decay".to_string(),
                weight_decay.into_pyobject(py).unwrap().into_any().unbind(),
            );
            param_group_data.insert(
                "amsgrad".to_string(),
                PyBool::new(py, amsgrad).to_owned().into(),
            );
        });

        let param_groups = vec![create_param_group(params, lr, param_group_data)?];

        Ok((
            Self {
                adam,
                param_groups,
                lr,
                betas,
                eps,
                weight_decay,
                amsgrad,
            },
            PyOptimizer {},
        ))
    }

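    /// Performs a single optimization step, mapping backend errors to
    /// Python `RuntimeError`s.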
    fn step(&mut self) -> PyResult<()> {
        self.adam.step().map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                "Optimizer step failed: {}",
                e
            ))
        })?;
        Ok(())
    }

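    /// Resets the gradients of all optimized parameters.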
    #[pyo3(signature = (set_to_none=None))]
    fn zero_grad(&mut self, set_to_none: Option<bool>) {
        // Accepted for PyTorch API compatibility; the underlying optimizer
        // does not yet distinguish zeroing from clearing gradients.
        let _set_to_none = set_to_none.unwrap_or(false);
        self.adam.zero_grad();
    }

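    /// Returns a copy of the parameter groups, mirroring
    /// `torch.optim.Optimizer.param_groups`.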
    fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
        Python::attach(|py| {
            let cloned_groups = self
                .param_groups
                .iter()
                .map(|group| {
                    group
                        .iter()
                        .map(|(k, v)| (k.clone(), v.clone_ref(py)))
                        .collect()
                })
                .collect();
            Ok(cloned_groups)
        })
    }

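    /// Returns a skeleton of the optimizer state. The moment buffers are
    /// currently placeholders, not the live values held by the backend.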
    fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
        let mut state = HashMap::new();
        Python::attach(|py| {
            state.insert(
                "step".to_string(),
                0i64.into_pyobject(py).unwrap().into_any().unbind(),
            );
            // Empty dicts stand in for the per-parameter moment buffers until
            // the underlying optimizer exposes its internal state.
            state.insert(
                "exp_avg".to_string(),
                PyDict::new(py).into_any().unbind(),
            );
            state.insert(
                "exp_avg_sq".to_string(),
                PyDict::new(py).into_any().unbind(),
            );
            if self.amsgrad {
                state.insert(
                    "max_exp_avg_sq".to_string(),
                    PyDict::new(py).into_any().unbind(),
                );
            }
        });
        Ok(state)
    }

    fn __repr__(&self) -> String {
        format!(
            "Adam(lr={}, betas={:?}, eps={}, weight_decay={}, amsgrad={})",
            self.lr, self.betas, self.eps, self.weight_decay, self.amsgrad
        )
    }

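    /// Returns the default hyperparameters this optimizer was built with.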
    fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
        let mut defaults = HashMap::new();
        Python::attach(|py| {
            defaults.insert(
                "lr".to_string(),
                self.lr.into_pyobject(py).unwrap().into_any().unbind(),
            );
            defaults.insert(
                "betas".to_string(),
                self.betas.into_pyobject(py).unwrap().into_any().unbind(),
            );
            defaults.insert(
                "eps".to_string(),
                self.eps.into_pyobject(py).unwrap().into_any().unbind(),
            );
            defaults.insert(
                "weight_decay".to_string(),
                self.weight_decay
                    .into_pyobject(py)
                    .unwrap()
                    .into_any()
                    .unbind(),
            );
            defaults.insert(
                "amsgrad".to_string(),
                PyBool::new(py, self.amsgrad).to_owned().into(),
            );
        });
        Ok(defaults)
    }

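    /// Learning rate, kept in sync with every parameter group by the setter.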
    #[getter]
    fn lr(&self) -> f32 {
        self.lr
    }

    #[setter]
    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
        Python::attach(|py| {
            for param_group in &mut self.param_groups {
                param_group.insert(
                    "lr".to_string(),
                    lr.into_pyobject(py).unwrap().into_any().unbind(),
                );
            }
        });
    }

    #[getter]
    fn betas(&self) -> (f32, f32) {
        self.betas
    }

    #[getter]
    fn eps(&self) -> f32 {
        self.eps
    }

    #[getter]
    fn weight_decay(&self) -> f32 {
        self.weight_decay
    }

    #[getter]
    fn amsgrad(&self) -> bool {
        self.amsgrad
    }
}

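/// Python-facing AdamW optimizer, backed by `torsh_optim::AdamW`. AdamW's
/// defining feature is weight decay decoupled from the gradient update, and
/// its `weight_decay` default here is 0.01 rather than Adam's 0.0.
///
/// A sketch of the intended Python usage (the `torsh.optim` module path is
/// an assumption, not confirmed by this file):
///
/// ```python
/// opt = torsh.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
/// ```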
#[pyclass(name = "AdamW", extends = PyOptimizer)]
pub struct PyAdamW {
    adamw: AdamW,
    param_groups: Vec<HashMap<String, Py<PyAny>>>,
    lr: f32,
    betas: (f32, f32),
    eps: f32,
    weight_decay: f32,
    amsgrad: bool,
}

#[pymethods]
impl PyAdamW {
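    /// Builds an AdamW optimizer over `params`. Unset arguments fall back to
    /// lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01,
    /// amsgrad=False.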
    #[new]
    #[pyo3(signature = (params, lr=None, betas=None, eps=None, weight_decay=None, amsgrad=None))]
    fn new(
        params: Vec<PyTensor>,
        lr: Option<f32>,
        betas: Option<(f32, f32)>,
        eps: Option<f32>,
        weight_decay: Option<f32>,
        amsgrad: Option<bool>,
    ) -> PyResult<(Self, PyOptimizer)> {
        let lr = lr.unwrap_or(0.001);
        let betas = betas.unwrap_or((0.9, 0.999));
        let eps = eps.unwrap_or(1e-8);
        let weight_decay = weight_decay.unwrap_or(0.01);
        let amsgrad = amsgrad.unwrap_or(false);

        // Propagate extraction failures as Python exceptions instead of panicking.
        let tensor_params = extract_parameters(params.clone())?;
        let wrapped_params: Vec<Arc<RwLock<_>>> = tensor_params
            .into_iter()
            .map(|tensor| Arc::new(RwLock::new(tensor)))
            .collect();
        let adamw = AdamW::new(
            wrapped_params,
            Some(lr),
            Some(betas),
            Some(eps),
            Some(weight_decay),
            amsgrad,
        );

        let mut param_group_data = HashMap::new();
        Python::attach(|py| {
            param_group_data.insert(
                "betas".to_string(),
                betas.into_pyobject(py).unwrap().into_any().unbind(),
            );
            param_group_data.insert(
                "eps".to_string(),
                eps.into_pyobject(py).unwrap().into_any().unbind(),
            );
            param_group_data.insert(
                "weight_decay".to_string(),
                weight_decay.into_pyobject(py).unwrap().into_any().unbind(),
            );
            param_group_data.insert(
                "amsgrad".to_string(),
                PyBool::new(py, amsgrad).to_owned().into(),
            );
        });

        let param_groups = vec![create_param_group(params, lr, param_group_data)?];

        Ok((
            Self {
                adamw,
                param_groups,
                lr,
                betas,
                eps,
                weight_decay,
                amsgrad,
            },
            PyOptimizer {},
        ))
    }

    fn step(&mut self) -> PyResult<()> {
        self.adamw.step().map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                "Optimizer step failed: {}",
                e
            ))
        })?;
        Ok(())
    }

    #[pyo3(signature = (set_to_none=None))]
    fn zero_grad(&mut self, set_to_none: Option<bool>) {
        // Accepted for PyTorch API compatibility; currently ignored.
        let _set_to_none = set_to_none.unwrap_or(false);
        self.adamw.zero_grad();
    }

    fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
        Python::attach(|py| {
            let cloned_groups = self
                .param_groups
                .iter()
                .map(|group| {
                    group
                        .iter()
                        .map(|(k, v)| (k.clone(), v.clone_ref(py)))
                        .collect()
                })
                .collect();
            Ok(cloned_groups)
        })
    }

    fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
        let mut state = HashMap::new();
        Python::attach(|py| {
            state.insert(
                "step".to_string(),
                0i64.into_pyobject(py).unwrap().into_any().unbind(),
            );
            // Empty dicts stand in for the per-parameter moment buffers until
            // the underlying optimizer exposes its internal state.
            state.insert(
                "exp_avg".to_string(),
                PyDict::new(py).into_any().unbind(),
            );
            state.insert(
                "exp_avg_sq".to_string(),
                PyDict::new(py).into_any().unbind(),
            );
            if self.amsgrad {
                state.insert(
                    "max_exp_avg_sq".to_string(),
                    PyDict::new(py).into_any().unbind(),
                );
            }
        });
        Ok(state)
    }

    fn __repr__(&self) -> String {
        format!(
            "AdamW(lr={}, betas={:?}, eps={}, weight_decay={}, amsgrad={})",
            self.lr, self.betas, self.eps, self.weight_decay, self.amsgrad
        )
    }

    fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
        let mut defaults = HashMap::new();
        Python::attach(|py| {
            defaults.insert(
                "lr".to_string(),
                self.lr.into_pyobject(py).unwrap().into_any().unbind(),
            );
            defaults.insert(
                "betas".to_string(),
                self.betas.into_pyobject(py).unwrap().into_any().unbind(),
            );
            defaults.insert(
                "eps".to_string(),
                self.eps.into_pyobject(py).unwrap().into_any().unbind(),
            );
            defaults.insert(
                "weight_decay".to_string(),
                self.weight_decay
                    .into_pyobject(py)
                    .unwrap()
                    .into_any()
                    .unbind(),
            );
            defaults.insert(
                "amsgrad".to_string(),
                PyBool::new(py, self.amsgrad).to_owned().into(),
            );
        });
        Ok(defaults)
    }

    #[getter]
    fn lr(&self) -> f32 {
        self.lr
    }

    #[setter]
    fn set_lr(&mut self, lr: f32) {
        self.lr = lr;
        Python::attach(|py| {
            for param_group in &mut self.param_groups {
                param_group.insert(
                    "lr".to_string(),
                    lr.into_pyobject(py).unwrap().into_any().unbind(),
                );
            }
        });
    }

    #[getter]
    fn betas(&self) -> (f32, f32) {
        self.betas
    }

    #[getter]
    fn eps(&self) -> f32 {
        self.eps
    }

    #[getter]
    fn weight_decay(&self) -> f32 {
        self.weight_decay
    }

    #[getter]
    fn amsgrad(&self) -> bool {
        self.amsgrad
    }
}