1use super::base::{create_param_group, extract_parameters, PyOptimizer};
4use crate::{error::PyResult, tensor::PyTensor};
5use parking_lot::RwLock;
6use pyo3::prelude::*;
7use pyo3::types::{PyAny, PyBool};
8use std::collections::HashMap;
9use std::sync::Arc;
10use torsh_optim::{Adam, AdamW, Optimizer};
11
12#[pyclass(name = "Adam", extends = PyOptimizer)]
14pub struct PyAdam {
15 adam: Adam,
16 param_groups: Vec<HashMap<String, Py<PyAny>>>,
17 lr: f32,
18 betas: (f32, f32),
19 eps: f32,
20 weight_decay: f32,
21 amsgrad: bool,
22}
23
24#[pymethods]
25impl PyAdam {
26 #[new]
27 fn new(
28 params: Vec<PyTensor>,
29 lr: Option<f32>,
30 betas: Option<(f32, f32)>,
31 eps: Option<f32>,
32 weight_decay: Option<f32>,
33 amsgrad: Option<bool>,
34 ) -> (Self, PyOptimizer) {
35 let lr = lr.unwrap_or(0.001);
36 let betas = betas.unwrap_or((0.9, 0.999));
37 let eps = eps.unwrap_or(1e-8);
38 let weight_decay = weight_decay.unwrap_or(0.0);
39 let amsgrad = amsgrad.unwrap_or(false);
40
41 let tensor_params =
43 extract_parameters(params.clone()).expect("parameter extraction should succeed");
44 let wrapped_params: Vec<Arc<RwLock<_>>> = tensor_params
45 .into_iter()
46 .map(|tensor| Arc::new(RwLock::new(tensor)))
47 .collect();
48 let adam = Adam::new(
49 wrapped_params,
50 Some(lr),
51 Some(betas),
52 Some(eps),
53 Some(weight_decay),
54 amsgrad,
55 );
56
57 let mut param_group_data = HashMap::new();
59 Python::attach(|py| {
60 param_group_data.insert(
61 "betas".to_string(),
62 betas
63 .into_pyobject(py)
64 .expect("Python object conversion should succeed")
65 .into_any()
66 .unbind(),
67 );
68 param_group_data.insert(
69 "eps".to_string(),
70 eps.into_pyobject(py)
71 .expect("Python object conversion should succeed")
72 .into_any()
73 .unbind(),
74 );
75 param_group_data.insert(
76 "weight_decay".to_string(),
77 weight_decay
78 .into_pyobject(py)
79 .expect("Python object conversion should succeed")
80 .into_any()
81 .unbind(),
82 );
83 param_group_data.insert(
84 "amsgrad".to_string(),
85 PyBool::new(py, amsgrad).to_owned().into(),
86 );
87 });
88
89 let param_groups = vec![create_param_group(params, lr, param_group_data)
90 .expect("param group creation should succeed")];
91
92 (
93 Self {
94 adam,
95 param_groups,
96 lr,
97 betas,
98 eps,
99 weight_decay,
100 amsgrad,
101 },
102 PyOptimizer {},
103 )
104 }
105
106 fn step(&mut self) -> PyResult<()> {
108 self.adam.step().map_err(|e| {
109 PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
110 "Optimizer step failed: {}",
111 e
112 ))
113 })?;
114 Ok(())
115 }
116
117 fn zero_grad(&mut self, set_to_none: Option<bool>) {
119 let _set_to_none = set_to_none.unwrap_or(false);
120 self.adam.zero_grad();
121 }
122
123 fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
125 Python::attach(|py| {
127 let cloned_groups = self
128 .param_groups
129 .iter()
130 .map(|group| {
131 group
132 .iter()
133 .map(|(k, v)| (k.clone(), v.clone_ref(py)))
134 .collect()
135 })
136 .collect();
137 Ok(cloned_groups)
138 })
139 }
140
141 fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
143 let mut state = HashMap::new();
144 Python::attach(|py| {
145 state.insert(
146 "step".to_string(),
147 0i64.into_pyobject(py)
148 .expect("Python object conversion should succeed")
149 .into_any()
150 .unbind(),
151 );
152 state.insert(
153 "exp_avg".to_string(),
154 "{}".into_pyobject(py)
155 .expect("Python object conversion should succeed")
156 .into_any()
157 .unbind(),
158 );
159 state.insert(
160 "exp_avg_sq".to_string(),
161 "{}".into_pyobject(py)
162 .expect("Python object conversion should succeed")
163 .into_any()
164 .unbind(),
165 );
166 if self.amsgrad {
167 state.insert(
168 "max_exp_avg_sq".to_string(),
169 "{}".into_pyobject(py)
170 .expect("Python object conversion should succeed")
171 .into_any()
172 .unbind(),
173 );
174 }
175 });
176 Ok(state)
177 }
178
179 fn __repr__(&self) -> String {
181 format!(
182 "Adam(lr={}, betas={:?}, eps={}, weight_decay={}, amsgrad={})",
183 self.lr, self.betas, self.eps, self.weight_decay, self.amsgrad
184 )
185 }
186
187 fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
189 let mut defaults = HashMap::new();
190 Python::attach(|py| {
191 defaults.insert(
192 "lr".to_string(),
193 self.lr
194 .into_pyobject(py)
195 .expect("Python object conversion should succeed")
196 .into_any()
197 .unbind(),
198 );
199 defaults.insert(
200 "betas".to_string(),
201 self.betas
202 .into_pyobject(py)
203 .expect("Python object conversion should succeed")
204 .into_any()
205 .unbind(),
206 );
207 defaults.insert(
208 "eps".to_string(),
209 self.eps
210 .into_pyobject(py)
211 .expect("Python object conversion should succeed")
212 .into_any()
213 .unbind(),
214 );
215 defaults.insert(
216 "weight_decay".to_string(),
217 self.weight_decay
218 .into_pyobject(py)
219 .expect("Python object conversion should succeed")
220 .into_any()
221 .unbind(),
222 );
223 defaults.insert(
224 "amsgrad".to_string(),
225 PyBool::new(py, self.amsgrad).to_owned().into(),
226 );
227 });
228 Ok(defaults)
229 }
230
231 #[getter]
233 fn lr(&self) -> f32 {
234 self.lr
235 }
236
237 #[setter]
239 fn set_lr(&mut self, lr: f32) {
240 self.lr = lr;
241 Python::attach(|py| {
242 for param_group in &mut self.param_groups {
243 param_group.insert(
244 "lr".to_string(),
245 lr.into_pyobject(py)
246 .expect("Python object conversion should succeed")
247 .into_any()
248 .unbind(),
249 );
250 }
251 });
252 }
253
254 #[getter]
256 fn betas(&self) -> (f32, f32) {
257 self.betas
258 }
259
260 #[getter]
262 fn eps(&self) -> f32 {
263 self.eps
264 }
265
266 #[getter]
268 fn weight_decay(&self) -> f32 {
269 self.weight_decay
270 }
271
272 #[getter]
274 fn amsgrad(&self) -> bool {
275 self.amsgrad
276 }
277}
278
279#[pyclass(name = "AdamW", extends = PyOptimizer)]
281pub struct PyAdamW {
282 adamw: AdamW,
283 param_groups: Vec<HashMap<String, Py<PyAny>>>,
284 lr: f32,
285 betas: (f32, f32),
286 eps: f32,
287 weight_decay: f32,
288 amsgrad: bool,
289}
290
291#[pymethods]
292impl PyAdamW {
293 #[new]
294 fn new(
295 params: Vec<PyTensor>,
296 lr: Option<f32>,
297 betas: Option<(f32, f32)>,
298 eps: Option<f32>,
299 weight_decay: Option<f32>,
300 amsgrad: Option<bool>,
301 ) -> (Self, PyOptimizer) {
302 let lr = lr.unwrap_or(0.001);
303 let betas = betas.unwrap_or((0.9, 0.999));
304 let eps = eps.unwrap_or(1e-8);
305 let weight_decay = weight_decay.unwrap_or(0.01);
306 let amsgrad = amsgrad.unwrap_or(false);
307
308 let tensor_params =
310 extract_parameters(params.clone()).expect("parameter extraction should succeed");
311 let wrapped_params: Vec<Arc<RwLock<_>>> = tensor_params
312 .into_iter()
313 .map(|tensor| Arc::new(RwLock::new(tensor)))
314 .collect();
315 let adamw = AdamW::new(
316 wrapped_params,
317 Some(lr),
318 Some(betas),
319 Some(eps),
320 Some(weight_decay),
321 amsgrad,
322 );
323
324 let mut param_group_data = HashMap::new();
326 Python::attach(|py| {
327 param_group_data.insert(
328 "betas".to_string(),
329 betas
330 .into_pyobject(py)
331 .expect("Python object conversion should succeed")
332 .into_any()
333 .unbind(),
334 );
335 param_group_data.insert(
336 "eps".to_string(),
337 eps.into_pyobject(py)
338 .expect("Python object conversion should succeed")
339 .into_any()
340 .unbind(),
341 );
342 param_group_data.insert(
343 "weight_decay".to_string(),
344 weight_decay
345 .into_pyobject(py)
346 .expect("Python object conversion should succeed")
347 .into_any()
348 .unbind(),
349 );
350 param_group_data.insert(
351 "amsgrad".to_string(),
352 PyBool::new(py, amsgrad).to_owned().into(),
353 );
354 });
355
356 let param_groups = vec![create_param_group(params, lr, param_group_data)
357 .expect("param group creation should succeed")];
358
359 (
360 Self {
361 adamw,
362 param_groups,
363 lr,
364 betas,
365 eps,
366 weight_decay,
367 amsgrad,
368 },
369 PyOptimizer {},
370 )
371 }
372
373 fn step(&mut self) -> PyResult<()> {
375 self.adamw.step().map_err(|e| {
376 PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
377 "Optimizer step failed: {}",
378 e
379 ))
380 })?;
381 Ok(())
382 }
383
384 fn zero_grad(&mut self, set_to_none: Option<bool>) {
386 let _set_to_none = set_to_none.unwrap_or(false);
387 self.adamw.zero_grad();
388 }
389
390 fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
392 Python::attach(|py| {
394 let cloned_groups = self
395 .param_groups
396 .iter()
397 .map(|group| {
398 group
399 .iter()
400 .map(|(k, v)| (k.clone(), v.clone_ref(py)))
401 .collect()
402 })
403 .collect();
404 Ok(cloned_groups)
405 })
406 }
407
408 fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
410 let mut state = HashMap::new();
411 Python::attach(|py| {
412 state.insert(
413 "step".to_string(),
414 0i64.into_pyobject(py)
415 .expect("Python object conversion should succeed")
416 .into_any()
417 .unbind(),
418 );
419 state.insert(
420 "exp_avg".to_string(),
421 "{}".into_pyobject(py)
422 .expect("Python object conversion should succeed")
423 .into_any()
424 .unbind(),
425 );
426 state.insert(
427 "exp_avg_sq".to_string(),
428 "{}".into_pyobject(py)
429 .expect("Python object conversion should succeed")
430 .into_any()
431 .unbind(),
432 );
433 if self.amsgrad {
434 state.insert(
435 "max_exp_avg_sq".to_string(),
436 "{}".into_pyobject(py)
437 .expect("Python object conversion should succeed")
438 .into_any()
439 .unbind(),
440 );
441 }
442 });
443 Ok(state)
444 }
445
446 fn __repr__(&self) -> String {
448 format!(
449 "AdamW(lr={}, betas={:?}, eps={}, weight_decay={}, amsgrad={})",
450 self.lr, self.betas, self.eps, self.weight_decay, self.amsgrad
451 )
452 }
453
454 fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
456 let mut defaults = HashMap::new();
457 Python::attach(|py| {
458 defaults.insert(
459 "lr".to_string(),
460 self.lr
461 .into_pyobject(py)
462 .expect("Python object conversion should succeed")
463 .into_any()
464 .unbind(),
465 );
466 defaults.insert(
467 "betas".to_string(),
468 self.betas
469 .into_pyobject(py)
470 .expect("Python object conversion should succeed")
471 .into_any()
472 .unbind(),
473 );
474 defaults.insert(
475 "eps".to_string(),
476 self.eps
477 .into_pyobject(py)
478 .expect("Python object conversion should succeed")
479 .into_any()
480 .unbind(),
481 );
482 defaults.insert(
483 "weight_decay".to_string(),
484 self.weight_decay
485 .into_pyobject(py)
486 .expect("Python object conversion should succeed")
487 .into_any()
488 .unbind(),
489 );
490 defaults.insert(
491 "amsgrad".to_string(),
492 PyBool::new(py, self.amsgrad).to_owned().into(),
493 );
494 });
495 Ok(defaults)
496 }
497
498 #[getter]
500 fn lr(&self) -> f32 {
501 self.lr
502 }
503
504 #[setter]
506 fn set_lr(&mut self, lr: f32) {
507 self.lr = lr;
508 Python::attach(|py| {
509 for param_group in &mut self.param_groups {
510 param_group.insert(
511 "lr".to_string(),
512 lr.into_pyobject(py)
513 .expect("Python object conversion should succeed")
514 .into_any()
515 .unbind(),
516 );
517 }
518 });
519 }
520
521 #[getter]
523 fn betas(&self) -> (f32, f32) {
524 self.betas
525 }
526
527 #[getter]
529 fn eps(&self) -> f32 {
530 self.eps
531 }
532
533 #[getter]
535 fn weight_decay(&self) -> f32 {
536 self.weight_decay
537 }
538
539 #[getter]
541 fn amsgrad(&self) -> bool {
542 self.amsgrad
543 }
544}