//! PyTorch-compatible optimizer wrappers: parameter groups, `state_dict`-style
//! serialization, and Adam / AdamW / SGD optimizers backed by this crate's
//! implementations.

use crate::traits::StatefulOptimizer;
use crate::{Adam, AdamW, LRScheduler, OptimizerState, SGD};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::traits::Optimizer;
use trustformers_core::Tensor;

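/// A parameter group in the style of PyTorch's `optimizer.param_groups`.
///
/// `params` stores parameter names rather than tensor handles, and
/// hyperparameters that only apply to some optimizers (e.g. `momentum`,
/// `betas`, `eps`) are optional so the same struct serves Adam, AdamW, and SGD.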
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyTorchParamGroup {
    pub params: Vec<String>,
    pub lr: f64,
    pub weight_decay: f64,
    pub momentum: Option<f64>,
    pub dampening: Option<f64>,
    pub eps: Option<f64>,
    pub betas: Option<(f64, f64)>,
    pub alpha: Option<f64>,
    pub amsgrad: Option<bool>,
    pub maximize: Option<bool>,
    pub foreach: Option<bool>,
    pub differentiable: Option<bool>,
}

impl Default for PyTorchParamGroup {
    fn default() -> Self {
        Self {
            params: Vec::new(),
            lr: 0.001,
            weight_decay: 0.0,
            momentum: None,
            dampening: None,
            eps: Some(1e-8),
            betas: Some((0.9, 0.999)),
            alpha: None,
            amsgrad: Some(false),
            maximize: Some(false),
            foreach: None,
            differentiable: Some(false),
        }
    }
}

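/// Serializable optimizer state shaped like PyTorch's `state_dict()`:
/// a `state` map of JSON values plus the list of parameter groups.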
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyTorchOptimizerState {
    pub state: HashMap<String, serde_json::Value>,
    pub param_groups: Vec<PyTorchParamGroup>,
}

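/// Flat configuration for constructing a PyTorch-style optimizer; any extra
/// optimizer-specific settings travel in the `parameters` JSON map.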
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyTorchOptimizerConfig {
    pub optimizer_type: String,
    pub learning_rate: f64,
    pub betas: (f64, f64),
    pub epsilon: f64,
    pub weight_decay: f64,
    pub amsgrad: bool,
    pub maximize: bool,
    pub parameters: HashMap<String, serde_json::Value>,
}

impl Default for PyTorchOptimizerConfig {
    fn default() -> Self {
        Self {
            optimizer_type: "Adam".to_string(),
            learning_rate: 1e-3,
            betas: (0.9, 0.999),
            epsilon: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            maximize: false,
            parameters: HashMap::new(),
        }
    }
}

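/// PyTorch-compatible optimizer interface: parameter groups, `state_dict` /
/// `load_state_dict` round-tripping, `step` with an optional loss closure,
/// `zero_grad`, `add_param_group`, and the optimizer's default hyperparameters.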
pub trait PyTorchOptimizer: Send + Sync {
    fn param_groups(&self) -> &[PyTorchParamGroup];

    fn param_groups_mut(&mut self) -> &mut [PyTorchParamGroup];

    fn state_dict(&self) -> PyTorchOptimizerState;

    fn load_state_dict(&mut self, state: PyTorchOptimizerState) -> Result<()>;

    fn step(&mut self, closure: Option<Box<dyn Fn() -> f64>>) -> Result<Option<f64>>;

    fn zero_grad(&mut self, set_to_none: bool) -> Result<()>;

    fn add_param_group(&mut self, param_group: PyTorchParamGroup) -> Result<()>;

    fn defaults(&self) -> PyTorchParamGroup;
}

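/// Adam optimizer with a PyTorch-style interface, backed by this crate's
/// [`Adam`] implementation. Parameters and gradients are registered by name.
///
/// A minimal usage sketch (mirroring the flow exercised in the tests below;
/// assumes a function returning `Result`):
///
/// ```ignore
/// let mut opt = PyTorchAdam::from_params(vec![
///     ("weight".to_string(), Tensor::zeros(&[10, 10])?),
/// ])?;
/// opt.register_param("weight".to_string(), Tensor::zeros(&[10, 10])?)?;
/// opt.set_grad("weight".to_string(), Tensor::ones(&[10, 10])?)?;
/// opt.step(None)?;      // applies the Adam update to registered params
/// opt.zero_grad(true)?; // clears stored gradients
/// ```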
#[derive(Debug)]
pub struct PyTorchAdam {
    inner: Adam,
    param_groups: Vec<PyTorchParamGroup>,
    parameters: Arc<Mutex<HashMap<String, Tensor>>>,
    gradients: Arc<Mutex<HashMap<String, Tensor>>>,
}

impl PyTorchAdam {
    pub fn new(
        params: Vec<PyTorchParamGroup>,
        lr: f64,
        betas: (f64, f64),
        eps: f64,
        weight_decay: f64,
        _amsgrad: bool,
    ) -> Result<Self> {
        let inner = Adam::new(
            lr as f32,
            (betas.0 as f32, betas.1 as f32),
            eps as f32,
            weight_decay as f32,
        );

        Ok(Self {
            inner,
            param_groups: params,
            parameters: Arc::new(Mutex::new(HashMap::new())),
            gradients: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    pub fn from_params(params: impl IntoIterator<Item = (String, Tensor)>) -> Result<Self> {
        let param_group = PyTorchParamGroup {
            params: params.into_iter().map(|(name, _)| name).collect(),
            ..Default::default()
        };

        Self::new(vec![param_group], 0.001, (0.9, 0.999), 1e-8, 0.0, false)
    }

    pub fn from_config(config: PyTorchOptimizerConfig) -> Result<Self> {
        let param_group = PyTorchParamGroup {
            params: config.parameters.keys().cloned().collect(),
            lr: config.learning_rate,
            weight_decay: config.weight_decay,
            eps: Some(config.epsilon),
            betas: Some(config.betas),
            amsgrad: Some(config.amsgrad),
            maximize: Some(config.maximize),
            ..Default::default()
        };

        Self::new(
            vec![param_group],
            config.learning_rate,
            config.betas,
            config.epsilon,
            config.weight_decay,
            config.amsgrad,
        )
    }

    pub fn from_cross_framework_config(
        config: crate::cross_framework::PyTorchOptimizerConfig,
    ) -> Result<Self> {
        let betas = if let Some(betas_val) = config.parameters.get("betas") {
            if let Some(arr) = betas_val.as_array() {
                (
                    arr.first().and_then(|v| v.as_f64()).unwrap_or(0.9),
                    arr.get(1).and_then(|v| v.as_f64()).unwrap_or(0.999),
                )
            } else {
                (0.9, 0.999)
            }
        } else {
            (0.9, 0.999)
        };

        let epsilon = config.parameters.get("epsilon").and_then(|v| v.as_f64()).unwrap_or(1e-8);

        let weight_decay =
            config.parameters.get("weight_decay").and_then(|v| v.as_f64()).unwrap_or(0.0);

        let amsgrad = config.parameters.get("amsgrad").and_then(|v| v.as_bool()).unwrap_or(false);

        let param_group = PyTorchParamGroup {
            params: Vec::new(),
            lr: config.learning_rate as f64,
            weight_decay,
            eps: Some(epsilon),
            betas: Some(betas),
            amsgrad: Some(amsgrad),
            maximize: Some(false),
            ..Default::default()
        };

        Self::new(
            vec![param_group],
            config.learning_rate as f64,
            betas,
            epsilon,
            weight_decay,
            amsgrad,
        )
    }

    pub fn register_param(&mut self, name: String, param: Tensor) -> Result<()> {
        let mut params = self
            .parameters
            .lock()
            .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
        params.insert(name, param);
        Ok(())
    }

    pub fn set_grad(&mut self, name: String, grad: Tensor) -> Result<()> {
        let mut grads = self
            .gradients
            .lock()
            .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
        grads.insert(name, grad);
        Ok(())
    }

    fn load_optimizer_state(&mut self, optimizer_state: OptimizerState) -> Result<()> {
        for (param_name, momentum_data) in optimizer_state.momentum {
            let momentum_tensor = Tensor::new(momentum_data)?;
            let mut params = self
                .parameters
                .lock()
                .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
            if !params.contains_key(&param_name) {
                params.insert(param_name.clone(), momentum_tensor.clone());
            }
        }

        for (param_name, variance_data) in optimizer_state.variance {
            let variance_tensor = Tensor::new(variance_data)?;
            let mut params = self
                .parameters
                .lock()
                .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
            if !params.contains_key(&param_name) {
                params.insert(param_name.clone(), variance_tensor.clone());
            }
        }

        Ok(())
    }
}

impl PyTorchOptimizer for PyTorchAdam {
    fn param_groups(&self) -> &[PyTorchParamGroup] {
        &self.param_groups
    }

    fn param_groups_mut(&mut self) -> &mut [PyTorchParamGroup] {
        &mut self.param_groups
    }

    fn state_dict(&self) -> PyTorchOptimizerState {
        let state = self.inner.state();
        let state_json = serde_json::to_value(state).unwrap_or_default();

        PyTorchOptimizerState {
            state: [(String::from("adam_state"), state_json)].into(),
            param_groups: self.param_groups.clone(),
        }
    }

    fn load_state_dict(&mut self, state: PyTorchOptimizerState) -> Result<()> {
        self.param_groups = state.param_groups;

        if let Some(adam_state) = state.state.get("adam_state") {
            if let Ok(optimizer_state) =
                serde_json::from_value::<OptimizerState>(adam_state.clone())
            {
                self.load_optimizer_state(optimizer_state)?;
            }
        }

        Ok(())
    }

    fn step(&mut self, closure: Option<Box<dyn Fn() -> f64>>) -> Result<Option<f64>> {
        let loss = closure.map(|closure_fn| closure_fn());

        for group in &self.param_groups {
            for param_name in &group.params {
                let param_copy = {
                    let params = self.parameters.lock().map_err(|_| {
                        TrustformersError::runtime_error("Mutex lock poisoned".into())
                    })?;
                    params.get(param_name).cloned()
                };
                let grad_copy = {
                    let grads = self.gradients.lock().map_err(|_| {
                        TrustformersError::runtime_error("Mutex lock poisoned".into())
                    })?;
                    grads.get(param_name).cloned()
                };

                if let (Some(mut param), Some(grad)) = (param_copy, grad_copy) {
                    self.inner.update(&mut param, &grad)?;

                    let mut params = self.parameters.lock().map_err(|_| {
                        TrustformersError::runtime_error("Mutex lock poisoned".into())
                    })?;
                    params.insert(param_name.clone(), param);
                }
            }
        }

        Ok(loss)
    }

    fn zero_grad(&mut self, _set_to_none: bool) -> Result<()> {
        let mut grads = self
            .gradients
            .lock()
            .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
        grads.clear();
        Ok(())
    }

    fn add_param_group(&mut self, param_group: PyTorchParamGroup) -> Result<()> {
        self.param_groups.push(param_group);
        Ok(())
    }

    fn defaults(&self) -> PyTorchParamGroup {
        PyTorchParamGroup {
            lr: 0.001,
            betas: Some((0.9, 0.999)),
            eps: Some(1e-8),
            weight_decay: 0.0,
            amsgrad: Some(false),
            ..Default::default()
        }
    }
}

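/// AdamW optimizer (decoupled weight decay) with a PyTorch-style interface,
/// backed by this crate's [`AdamW`] implementation. Defaults to a weight
/// decay of 0.01, matching `torch.optim.AdamW`.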
#[derive(Debug)]
pub struct PyTorchAdamW {
    inner: AdamW,
    param_groups: Vec<PyTorchParamGroup>,
    parameters: Arc<Mutex<HashMap<String, Tensor>>>,
    gradients: Arc<Mutex<HashMap<String, Tensor>>>,
}

impl PyTorchAdamW {
    pub fn new(
        params: Vec<PyTorchParamGroup>,
        lr: f64,
        betas: (f64, f64),
        eps: f64,
        weight_decay: f64,
        _amsgrad: bool,
    ) -> Result<Self> {
        let inner = AdamW::new(
            lr as f32,
            (betas.0 as f32, betas.1 as f32),
            eps as f32,
            weight_decay as f32,
        );

        Ok(Self {
            inner,
            param_groups: params,
            parameters: Arc::new(Mutex::new(HashMap::new())),
            gradients: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    pub fn from_params(params: impl IntoIterator<Item = (String, Tensor)>) -> Result<Self> {
        let param_group = PyTorchParamGroup {
            params: params.into_iter().map(|(name, _)| name).collect(),
            ..Default::default()
        };

        Self::new(vec![param_group], 0.001, (0.9, 0.999), 1e-8, 0.01, false)
    }

    pub fn register_param(&mut self, name: String, param: Tensor) -> Result<()> {
        let mut params = self
            .parameters
            .lock()
            .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
        params.insert(name, param);
        Ok(())
    }

    pub fn set_grad(&mut self, name: String, grad: Tensor) -> Result<()> {
        let mut grads = self
            .gradients
            .lock()
            .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
        grads.insert(name, grad);
        Ok(())
    }

    fn load_optimizer_state(&mut self, optimizer_state: OptimizerState) -> Result<()> {
        for (param_name, momentum_data) in optimizer_state.momentum {
            let momentum_tensor = Tensor::new(momentum_data)?;
            let mut params = self
                .parameters
                .lock()
                .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
            if !params.contains_key(&param_name) {
                params.insert(param_name.clone(), momentum_tensor.clone());
            }
        }

        for (param_name, variance_data) in optimizer_state.variance {
            let variance_tensor = Tensor::new(variance_data)?;
            let mut params = self
                .parameters
                .lock()
                .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
            if !params.contains_key(&param_name) {
                params.insert(param_name.clone(), variance_tensor.clone());
            }
        }

        Ok(())
    }
}

impl PyTorchOptimizer for PyTorchAdamW {
    fn param_groups(&self) -> &[PyTorchParamGroup] {
        &self.param_groups
    }

    fn param_groups_mut(&mut self) -> &mut [PyTorchParamGroup] {
        &mut self.param_groups
    }

    fn state_dict(&self) -> PyTorchOptimizerState {
        let state = self.inner.state();
        let state_json = serde_json::to_value(state).unwrap_or_default();

        PyTorchOptimizerState {
            state: [(String::from("adamw_state"), state_json)].into(),
            param_groups: self.param_groups.clone(),
        }
    }

    fn load_state_dict(&mut self, state: PyTorchOptimizerState) -> Result<()> {
        self.param_groups = state.param_groups;

        if let Some(adamw_state) = state.state.get("adamw_state") {
            if let Ok(optimizer_state) =
                serde_json::from_value::<OptimizerState>(adamw_state.clone())
            {
                self.load_optimizer_state(optimizer_state)?;
            }
        }

        Ok(())
    }

    fn step(&mut self, closure: Option<Box<dyn Fn() -> f64>>) -> Result<Option<f64>> {
        let loss = closure.map(|closure_fn| closure_fn());

        for group in &self.param_groups {
            for param_name in &group.params {
                let param_copy = {
                    let params = self.parameters.lock().map_err(|_| {
                        TrustformersError::runtime_error("Mutex lock poisoned".into())
                    })?;
                    params.get(param_name).cloned()
                };
                let grad_copy = {
                    let grads = self.gradients.lock().map_err(|_| {
                        TrustformersError::runtime_error("Mutex lock poisoned".into())
                    })?;
                    grads.get(param_name).cloned()
                };

                if let (Some(mut param), Some(grad)) = (param_copy, grad_copy) {
                    self.inner.update(&mut param, &grad)?;

                    let mut params = self.parameters.lock().map_err(|_| {
                        TrustformersError::runtime_error("Mutex lock poisoned".into())
                    })?;
                    params.insert(param_name.clone(), param);
                }
            }
        }

        Ok(loss)
    }

    fn zero_grad(&mut self, _set_to_none: bool) -> Result<()> {
        let mut grads = self
            .gradients
            .lock()
            .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
        grads.clear();
        Ok(())
    }

    fn add_param_group(&mut self, param_group: PyTorchParamGroup) -> Result<()> {
        self.param_groups.push(param_group);
        Ok(())
    }

    fn defaults(&self) -> PyTorchParamGroup {
        PyTorchParamGroup {
            lr: 0.001,
            betas: Some((0.9, 0.999)),
            eps: Some(1e-8),
            weight_decay: 0.01,
            amsgrad: Some(false),
            ..Default::default()
        }
    }
}

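/// SGD optimizer with a PyTorch-style interface, backed by this crate's
/// [`SGD`] implementation (momentum, dampening, weight decay, Nesterov).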
#[derive(Debug)]
pub struct PyTorchSGD {
    inner: SGD,
    param_groups: Vec<PyTorchParamGroup>,
    parameters: Arc<Mutex<HashMap<String, Tensor>>>,
    gradients: Arc<Mutex<HashMap<String, Tensor>>>,
}

impl PyTorchSGD {
    pub fn new(
        params: Vec<PyTorchParamGroup>,
        lr: f64,
        momentum: f64,
        dampening: f64,
        weight_decay: f64,
        nesterov: bool,
    ) -> Result<Self> {
        let config = crate::sgd::SGDConfig {
            lr: lr as f32,
            momentum: momentum as f32,
            dampening: dampening as f32,
            weight_decay: weight_decay as f32,
            nesterov,
        };

        let inner = SGD::from_config(config);

        Ok(Self {
            inner,
            param_groups: params,
            parameters: Arc::new(Mutex::new(HashMap::new())),
            gradients: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    pub fn from_params(params: impl IntoIterator<Item = (String, Tensor)>) -> Result<Self> {
        let param_group = PyTorchParamGroup {
            params: params.into_iter().map(|(name, _)| name).collect(),
            lr: 0.01,
            momentum: Some(0.0),
            dampening: Some(0.0),
            weight_decay: 0.0,
            ..Default::default()
        };

        Self::new(vec![param_group], 0.01, 0.0, 0.0, 0.0, false)
    }

    pub fn register_param(&mut self, name: String, param: Tensor) -> Result<()> {
        let mut params = self
            .parameters
            .lock()
            .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
        params.insert(name, param);
        Ok(())
    }

    pub fn set_grad(&mut self, name: String, grad: Tensor) -> Result<()> {
        let mut grads = self
            .gradients
            .lock()
            .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
        grads.insert(name, grad);
        Ok(())
    }

    fn load_optimizer_state(&mut self, optimizer_state: OptimizerState) -> Result<()> {
        for (param_name, momentum_data) in optimizer_state.momentum {
            let momentum_tensor = Tensor::new(momentum_data)?;
            let mut params = self
                .parameters
                .lock()
                .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
            if !params.contains_key(&param_name) {
                params.insert(param_name.clone(), momentum_tensor.clone());
            }
        }

        for (param_name, variance_data) in optimizer_state.variance {
            let variance_tensor = Tensor::new(variance_data)?;
            let mut params = self
                .parameters
                .lock()
                .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
            if !params.contains_key(&param_name) {
                params.insert(param_name.clone(), variance_tensor.clone());
            }
        }

        Ok(())
    }
}

impl PyTorchOptimizer for PyTorchSGD {
    fn param_groups(&self) -> &[PyTorchParamGroup] {
        &self.param_groups
    }

    fn param_groups_mut(&mut self) -> &mut [PyTorchParamGroup] {
        &mut self.param_groups
    }

    fn state_dict(&self) -> PyTorchOptimizerState {
        let state = self.inner.state();
        let state_json = serde_json::to_value(state).unwrap_or_default();

        PyTorchOptimizerState {
            state: [(String::from("sgd_state"), state_json)].into(),
            param_groups: self.param_groups.clone(),
        }
    }

    fn load_state_dict(&mut self, state: PyTorchOptimizerState) -> Result<()> {
        self.param_groups = state.param_groups;

        if let Some(sgd_state) = state.state.get("sgd_state") {
            if let Ok(optimizer_state) = serde_json::from_value::<OptimizerState>(sgd_state.clone())
            {
                self.load_optimizer_state(optimizer_state)?;
            }
        }

        Ok(())
    }

    fn step(&mut self, closure: Option<Box<dyn Fn() -> f64>>) -> Result<Option<f64>> {
        let loss = closure.map(|closure_fn| closure_fn());

        for group in &self.param_groups {
            for param_name in &group.params {
                let param_copy = {
                    let params = self.parameters.lock().map_err(|_| {
                        TrustformersError::runtime_error("Mutex lock poisoned".into())
                    })?;
                    params.get(param_name).cloned()
                };
                let grad_copy = {
                    let grads = self.gradients.lock().map_err(|_| {
                        TrustformersError::runtime_error("Mutex lock poisoned".into())
                    })?;
                    grads.get(param_name).cloned()
                };

                if let (Some(mut param), Some(grad)) = (param_copy, grad_copy) {
                    self.inner.update(&mut param, &grad)?;

                    let mut params = self.parameters.lock().map_err(|_| {
                        TrustformersError::runtime_error("Mutex lock poisoned".into())
                    })?;
                    params.insert(param_name.clone(), param);
                }
            }
        }

        Ok(loss)
    }

    fn zero_grad(&mut self, _set_to_none: bool) -> Result<()> {
        let mut grads = self
            .gradients
            .lock()
            .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
        grads.clear();
        Ok(())
    }

    fn add_param_group(&mut self, param_group: PyTorchParamGroup) -> Result<()> {
        self.param_groups.push(param_group);
        Ok(())
    }

    fn defaults(&self) -> PyTorchParamGroup {
        PyTorchParamGroup {
            lr: 0.01,
            momentum: Some(0.0),
            dampening: Some(0.0),
            weight_decay: 0.0,
            ..Default::default()
        }
    }
}

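/// Convenience constructors that build PyTorch-style optimizers directly from
/// `(name, tensor)` parameter iterators.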
pub struct PyTorchOptimizerFactory;

impl PyTorchOptimizerFactory {
    pub fn adam(
        params: impl IntoIterator<Item = (String, Tensor)>,
        lr: f64,
        betas: (f64, f64),
        eps: f64,
        weight_decay: f64,
        amsgrad: bool,
    ) -> Result<PyTorchAdam> {
        let param_group = PyTorchParamGroup {
            params: params.into_iter().map(|(name, _)| name).collect(),
            lr,
            betas: Some(betas),
            eps: Some(eps),
            weight_decay,
            amsgrad: Some(amsgrad),
            ..Default::default()
        };

        PyTorchAdam::new(vec![param_group], lr, betas, eps, weight_decay, amsgrad)
    }

    pub fn adamw(
        params: impl IntoIterator<Item = (String, Tensor)>,
        lr: f64,
        betas: (f64, f64),
        eps: f64,
        weight_decay: f64,
        amsgrad: bool,
    ) -> Result<PyTorchAdamW> {
        let param_group = PyTorchParamGroup {
            params: params.into_iter().map(|(name, _)| name).collect(),
            lr,
            betas: Some(betas),
            eps: Some(eps),
            weight_decay,
            amsgrad: Some(amsgrad),
            ..Default::default()
        };

        PyTorchAdamW::new(vec![param_group], lr, betas, eps, weight_decay, amsgrad)
    }

    pub fn sgd(
        params: impl IntoIterator<Item = (String, Tensor)>,
        lr: f64,
        momentum: f64,
        dampening: f64,
        weight_decay: f64,
        nesterov: bool,
    ) -> Result<PyTorchSGD> {
        let param_group = PyTorchParamGroup {
            params: params.into_iter().map(|(name, _)| name).collect(),
            lr,
            momentum: Some(momentum),
            dampening: Some(dampening),
            weight_decay,
            ..Default::default()
        };

        PyTorchSGD::new(
            vec![param_group],
            lr,
            momentum,
            dampening,
            weight_decay,
            nesterov,
        )
    }
}

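/// Learning-rate scheduler with a PyTorch-style `step` / `get_last_lr` /
/// `state_dict` interface. Each `step` queries the inner [`LRScheduler`] and
/// writes the new learning rate into every parameter group of the wrapped
/// optimizer.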
pub struct PyTorchLRScheduler {
    inner_scheduler: Box<dyn LRScheduler>,
    optimizer: Box<dyn PyTorchOptimizer>,
    last_epoch: i64,
}

impl PyTorchLRScheduler {
    pub fn new(optimizer: Box<dyn PyTorchOptimizer>, scheduler: Box<dyn LRScheduler>) -> Self {
        Self {
            inner_scheduler: scheduler,
            optimizer,
            last_epoch: -1,
        }
    }

    pub fn step(&mut self, epoch: Option<i64>) -> Result<()> {
        let current_epoch = epoch.unwrap_or(self.last_epoch + 1);
        self.last_epoch = current_epoch;

        let new_lr = self.inner_scheduler.get_lr(current_epoch as usize);

        for group in self.optimizer.param_groups_mut() {
            group.lr = new_lr as f64;
        }

        Ok(())
    }

    pub fn get_last_lr(&self) -> f64 {
        self.inner_scheduler.get_lr(self.last_epoch.max(0) as usize) as f64
    }

    pub fn state_dict(&self) -> serde_json::Value {
        serde_json::json!({
            "last_epoch": self.last_epoch,
            "scheduler_state": "serialized_state"
        })
    }

    pub fn load_state_dict(&mut self, state: serde_json::Value) -> Result<()> {
        if let Some(epoch) = state.get("last_epoch").and_then(|e| e.as_i64()) {
            self.last_epoch = epoch;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use trustformers_core::Tensor;

    #[test]
    fn test_pytorch_adam_creation() {
        let params = vec![
            ("param1".to_string(), Tensor::zeros(&[10, 10]).unwrap()),
            ("param2".to_string(), Tensor::zeros(&[5, 5]).unwrap()),
        ];

        let optimizer = PyTorchAdam::from_params(params).unwrap();
        assert_eq!(optimizer.param_groups().len(), 1);
        assert_eq!(optimizer.param_groups()[0].params.len(), 2);
    }

    #[test]
    fn test_pytorch_adamw_creation() {
        let params = vec![("param1".to_string(), Tensor::zeros(&[10, 10]).unwrap())];

        let optimizer = PyTorchAdamW::from_params(params).unwrap();
        assert_eq!(optimizer.param_groups().len(), 1);
        assert_eq!(optimizer.defaults().weight_decay, 0.01);
    }

    #[test]
    fn test_pytorch_sgd_creation() {
        let params = vec![("param1".to_string(), Tensor::zeros(&[10, 10]).unwrap())];

        let optimizer = PyTorchSGD::from_params(params).unwrap();
        assert_eq!(optimizer.param_groups().len(), 1);
        assert_eq!(optimizer.defaults().lr, 0.01);
    }

    #[test]
    fn test_pytorch_optimizer_factory() {
        let params = vec![("param1".to_string(), Tensor::zeros(&[10, 10]).unwrap())];

        let adam =
            PyTorchOptimizerFactory::adam(params.clone(), 0.001, (0.9, 0.999), 1e-8, 0.0, false)
                .unwrap();
        assert_eq!(adam.param_groups()[0].lr, 0.001);

        let adamw =
            PyTorchOptimizerFactory::adamw(params.clone(), 0.001, (0.9, 0.999), 1e-8, 0.01, false)
                .unwrap();
        assert_eq!(adamw.param_groups()[0].weight_decay, 0.01);

        let sgd = PyTorchOptimizerFactory::sgd(params, 0.01, 0.9, 0.0, 0.0, false).unwrap();
        assert_eq!(sgd.param_groups()[0].momentum, Some(0.9));
    }

    #[test]
    fn test_param_group_operations() {
        let params = vec![("param1".to_string(), Tensor::zeros(&[10, 10]).unwrap())];

        let mut optimizer = PyTorchAdam::from_params(params).unwrap();

        let new_group = PyTorchParamGroup {
            params: vec!["param2".to_string()],
            lr: 0.002,
            ..Default::default()
        };

        optimizer.add_param_group(new_group).unwrap();
        assert_eq!(optimizer.param_groups().len(), 2);
        assert_eq!(optimizer.param_groups()[1].lr, 0.002);
    }

    #[test]
    fn test_state_dict_operations() {
        let params = vec![("param1".to_string(), Tensor::zeros(&[10, 10]).unwrap())];

        let optimizer = PyTorchAdam::from_params(params).unwrap();
        let state_dict = optimizer.state_dict();

        assert_eq!(state_dict.param_groups.len(), 1);
        assert!(state_dict.state.contains_key("adam_state"));
    }

    #[test]
    fn test_zero_grad() {
        let params = vec![("param1".to_string(), Tensor::zeros(&[10, 10]).unwrap())];

        let mut optimizer = PyTorchAdam::from_params(params).unwrap();
        optimizer
            .set_grad("param1".to_string(), Tensor::ones(&[10, 10]).unwrap())
            .unwrap();

        assert_eq!(
            optimizer.gradients.lock().expect("Mutex lock poisoned").len(),
            1
        );

        optimizer.zero_grad(false).unwrap();
        assert_eq!(
            optimizer.gradients.lock().expect("Mutex lock poisoned").len(),
            0
        );
    }
}