1use crate::{Adam, AdamW};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11use trustformers_core::errors::{Result, TrustformersError};
12use trustformers_core::traits::Optimizer;
13use trustformers_core::Tensor;
14
/// Serializable bag of TensorFlow/Keras-style optimizer hyperparameters,
/// mirroring the keyword arguments of `tf.keras.optimizers.Optimizer`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorFlowOptimizerConfig {
    /// Optimizer family, e.g. "Adam" or "AdamW".
    pub optimizer_type: String,
    /// Base learning rate (a fixed value; schedules are applied separately).
    pub learning_rate: f64,
    /// First-moment (momentum) decay coefficient; Adam's beta_1.
    pub beta_1: Option<f64>,
    /// Second-moment decay coefficient; Adam's beta_2.
    pub beta_2: Option<f64>,
    /// Numerical-stability constant added to update denominators.
    pub epsilon: Option<f64>,
    /// Decoupled weight-decay coefficient, if any.
    pub weight_decay: Option<f64>,
    /// Per-gradient L2-norm clipping threshold.
    pub clipnorm: Option<f64>,
    /// Element-wise gradient value clipping threshold.
    pub clipvalue: Option<f64>,
    /// Clipping threshold on the global norm across all gradients.
    pub global_clipnorm: Option<f64>,
    /// Whether to keep an exponential moving average of the weights.
    /// NOTE(review): stored but not acted upon in this file — confirm intent.
    pub use_ema: Option<bool>,
    /// Momentum used for the weight EMA.
    pub ema_momentum: Option<f64>,
    /// How often (in steps) to overwrite weights with their EMA.
    pub ema_overwrite_frequency: Option<i32>,
    /// Whether XLA JIT compilation is requested (metadata only here).
    pub jit_compile: Option<bool>,
    /// Optional human-readable optimizer name.
    pub name: Option<String>,
    /// Extra optimizer-specific parameters not covered by the fields above.
    pub parameters: HashMap<String, serde_json::Value>,
}
34
35impl Default for TensorFlowOptimizerConfig {
36 fn default() -> Self {
37 Self {
38 optimizer_type: "Adam".to_string(),
39 learning_rate: 0.001,
40 beta_1: Some(0.9),
41 beta_2: Some(0.999),
42 epsilon: Some(1e-7),
43 weight_decay: None,
44 clipnorm: None,
45 clipvalue: None,
46 global_clipnorm: None,
47 use_ema: Some(false),
48 ema_momentum: Some(0.99),
49 ema_overwrite_frequency: None,
50 jit_compile: Some(true),
51 name: None,
52 parameters: HashMap::new(),
53 }
54 }
55}
56
/// A learning-rate schedule in the style of
/// `tf.keras.optimizers.schedules.LearningRateSchedule`.
pub trait TensorFlowLearningRateSchedule: Send + Sync {
    /// Returns the learning rate to use at the given global `step`.
    fn get_lr(&self, step: i64) -> f64;

    /// Returns the schedule's hyperparameters as JSON (Keras `get_config`).
    fn get_config(&self) -> serde_json::Value;
}
65
/// Exponential learning-rate decay, mirroring
/// `tf.keras.optimizers.schedules.ExponentialDecay`.
#[derive(Debug, Clone)]
pub struct TensorFlowExponentialDecay {
    /// Learning rate at step 0.
    initial_learning_rate: f64,
    /// Number of steps over which one full decay factor is applied.
    decay_steps: i64,
    /// Multiplicative decay factor applied every `decay_steps` steps.
    decay_rate: f64,
    /// If true, decay in discrete intervals (integer exponent) rather than
    /// continuously.
    staircase: bool,
}
74
75impl TensorFlowExponentialDecay {
76 pub fn new(
77 initial_learning_rate: f64,
78 decay_steps: i64,
79 decay_rate: f64,
80 staircase: bool,
81 ) -> Self {
82 Self {
83 initial_learning_rate,
84 decay_steps,
85 decay_rate,
86 staircase,
87 }
88 }
89}
90
91impl TensorFlowLearningRateSchedule for TensorFlowExponentialDecay {
92 fn get_lr(&self, step: i64) -> f64 {
93 let decay_factor = if self.staircase {
94 (step / self.decay_steps) as f64
95 } else {
96 step as f64 / self.decay_steps as f64
97 };
98
99 self.initial_learning_rate * self.decay_rate.powf(decay_factor)
100 }
101
102 fn get_config(&self) -> serde_json::Value {
103 serde_json::json!({
104 "initial_learning_rate": self.initial_learning_rate,
105 "decay_steps": self.decay_steps,
106 "decay_rate": self.decay_rate,
107 "staircase": self.staircase,
108 })
109 }
110}
111
/// Cosine learning-rate decay, mirroring
/// `tf.keras.optimizers.schedules.CosineDecay`.
#[derive(Debug, Clone)]
pub struct TensorFlowCosineDecay {
    /// Learning rate at step 0.
    initial_learning_rate: f64,
    /// Number of steps over which the cosine anneals to its floor.
    decay_steps: i64,
    /// Fraction of the initial rate retained after `decay_steps` (the floor).
    alpha: f64,
}
119
120impl TensorFlowCosineDecay {
121 pub fn new(initial_learning_rate: f64, decay_steps: i64, alpha: f64) -> Self {
122 Self {
123 initial_learning_rate,
124 decay_steps,
125 alpha,
126 }
127 }
128}
129
130impl TensorFlowLearningRateSchedule for TensorFlowCosineDecay {
131 fn get_lr(&self, step: i64) -> f64 {
132 let completed_fraction = (step.min(self.decay_steps) as f64) / (self.decay_steps as f64);
133 let cosine_decayed = 0.5 * (1.0 + (std::f64::consts::PI * completed_fraction).cos());
134 let decayed = (1.0 - self.alpha) * cosine_decayed + self.alpha;
135
136 self.initial_learning_rate * decayed
137 }
138
139 fn get_config(&self) -> serde_json::Value {
140 serde_json::json!({
141 "initial_learning_rate": self.initial_learning_rate,
142 "decay_steps": self.decay_steps,
143 "alpha": self.alpha,
144 })
145 }
146}
147
/// Common interface emulating `tf.keras.optimizers.Optimizer`.
pub trait TensorFlowOptimizer: Send + Sync {
    /// Applies `(gradient, variable-name)` pairs to the registered variables.
    ///
    /// If `global_step` is `Some`, the optimizer's step counter is set to it;
    /// otherwise the counter is incremented by one.
    fn apply_gradients(
        &mut self,
        grads_and_vars: &[(Tensor, String)],
        global_step: Option<i64>,
    ) -> Result<()>;

    /// Evaluates `loss_fn`, computes gradients for the variables named in
    /// `var_list`, and applies them. Returns the loss from before the update.
    fn minimize(
        &mut self,
        loss_fn: Box<dyn Fn() -> Result<Tensor>>,
        var_list: &[String],
        global_step: Option<i64>,
    ) -> Result<Tensor>;

    /// Returns a copy of the optimizer's configuration.
    fn get_config(&self) -> TensorFlowOptimizerConfig;

    /// Names of the variables currently registered with the optimizer.
    fn variables(&self) -> Vec<String>;

    /// Snapshot (clones) of the registered variables' tensors.
    fn get_weights(&self) -> Vec<Tensor>;

    /// Replaces the registered variables with `weights`; errors if the count
    /// does not match the number of registered variables.
    fn set_weights(&mut self, weights: Vec<Tensor>) -> Result<()>;

    /// Current learning rate.
    fn get_learning_rate(&self) -> f64;

    /// Sets the learning rate directly (a schedule, if present, will
    /// overwrite it on the next `apply_gradients`).
    fn set_learning_rate(&mut self, lr: f64) -> Result<()>;

    /// Optimizer display name.
    fn get_name(&self) -> &str;
}
186
/// Adam optimizer with a TensorFlow/Keras-flavoured API, wrapping the
/// native [`Adam`] implementation.
pub struct TensorFlowAdam {
    // Underlying Adam implementation that performs the parameter updates.
    inner: Adam,
    // Keras-style configuration this optimizer was built from.
    config: TensorFlowOptimizerConfig,
    // Registered trainable variables, keyed by name.
    variables: Arc<Mutex<HashMap<String, Tensor>>>,
    // Optional schedule consulted on every `apply_gradients`.
    lr_schedule: Option<Box<dyn TensorFlowLearningRateSchedule>>,
    // Step counter: incremented per apply, or overwritten by the caller.
    global_step: i64,
}
195
196impl TensorFlowAdam {
197 pub fn new(
199 learning_rate: f64,
200 beta_1: f64,
201 beta_2: f64,
202 epsilon: f64,
203 weight_decay: Option<f64>,
204 clipnorm: Option<f64>,
205 clipvalue: Option<f64>,
206 global_clipnorm: Option<f64>,
207 use_ema: bool,
208 ema_momentum: f64,
209 jit_compile: bool,
210 name: Option<String>,
211 ) -> Result<Self> {
212 let config = TensorFlowOptimizerConfig {
213 optimizer_type: "Adam".to_string(),
214 learning_rate,
215 beta_1: Some(beta_1),
216 beta_2: Some(beta_2),
217 epsilon: Some(epsilon),
218 weight_decay,
219 clipnorm,
220 clipvalue,
221 global_clipnorm,
222 use_ema: Some(use_ema),
223 ema_momentum: Some(ema_momentum),
224 ema_overwrite_frequency: None,
225 jit_compile: Some(jit_compile),
226 name,
227 parameters: HashMap::new(),
228 };
229
230 let inner = Adam::new(
233 learning_rate as f32,
234 (beta_1 as f32, beta_2 as f32),
235 epsilon as f32,
236 weight_decay.unwrap_or(0.0) as f32,
237 );
238
239 Ok(Self {
240 inner,
241 config,
242 variables: Arc::new(Mutex::new(HashMap::new())),
243 lr_schedule: None,
244 global_step: 0,
245 })
246 }
247
248 pub fn with_defaults() -> Result<Self> {
250 Self::new(
251 0.001,
252 0.9,
253 0.999,
254 1e-7,
255 None,
256 None,
257 None,
258 None,
259 false,
260 0.99,
261 true,
262 Some("Adam".to_string()),
263 )
264 }
265
266 pub fn from_config(config: TensorFlowOptimizerConfig) -> Result<Self> {
268 Self::new(
269 config.learning_rate,
270 config.beta_1.unwrap_or(0.9),
271 config.beta_2.unwrap_or(0.999),
272 config.epsilon.unwrap_or(1e-7),
273 config.weight_decay,
274 config.clipnorm,
275 config.clipvalue,
276 config.global_clipnorm,
277 config.use_ema.unwrap_or(false),
278 config.ema_momentum.unwrap_or(0.99),
279 config.jit_compile.unwrap_or(true),
280 config.name,
281 )
282 }
283
284 pub fn with_schedule(
286 schedule: Box<dyn TensorFlowLearningRateSchedule>,
287 beta_1: f64,
288 beta_2: f64,
289 epsilon: f64,
290 weight_decay: Option<f64>,
291 clipnorm: Option<f64>,
292 clipvalue: Option<f64>,
293 global_clipnorm: Option<f64>,
294 use_ema: bool,
295 ema_momentum: f64,
296 jit_compile: bool,
297 name: Option<String>,
298 ) -> Result<Self> {
299 let mut optimizer = Self::new(
300 schedule.get_lr(0),
301 beta_1,
302 beta_2,
303 epsilon,
304 weight_decay,
305 clipnorm,
306 clipvalue,
307 global_clipnorm,
308 use_ema,
309 ema_momentum,
310 jit_compile,
311 name,
312 )?;
313
314 optimizer.lr_schedule = Some(schedule);
315 Ok(optimizer)
316 }
317
318 pub fn add_variable(&mut self, name: String, var: Tensor) -> Result<()> {
320 let mut variables = self.variables.lock().expect("Mutex lock poisoned");
321 variables.insert(name, var);
322 Ok(())
323 }
324
325 fn update_learning_rate(&mut self) -> Result<()> {
327 if let Some(ref schedule) = self.lr_schedule {
328 let new_lr = schedule.get_lr(self.global_step);
329 self.config.learning_rate = new_lr;
330
331 self.inner.set_lr(new_lr as f32);
333 }
334 Ok(())
335 }
336
337 fn clip_gradients(&self, gradients: &mut [Tensor]) -> Result<()> {
339 if let Some(clipnorm) = self.config.clipnorm {
340 for grad in gradients.iter_mut() {
342 let norm = grad.norm()?;
343 if norm > clipnorm as f32 {
344 grad.mul_scalar((clipnorm as f32) / norm)?;
345 }
346 }
347 }
348
349 if let Some(clipvalue) = self.config.clipvalue {
350 for grad in gradients.iter_mut() {
352 grad.clamp(-clipvalue as f32, clipvalue as f32)?;
353 }
354 }
355
356 if let Some(global_clipnorm) = self.config.global_clipnorm {
357 let global_norm: f64 = gradients
359 .iter()
360 .map(|g| g.norm().unwrap_or(0.0).powi(2) as f64)
361 .sum::<f64>()
362 .sqrt();
363
364 if global_norm > global_clipnorm {
365 let scale = global_clipnorm / global_norm;
366 for grad in gradients.iter_mut() {
367 grad.mul_scalar(scale as f32)?;
368 }
369 }
370 }
371
372 Ok(())
373 }
374}
375
376impl TensorFlowOptimizer for TensorFlowAdam {
377 fn apply_gradients(
378 &mut self,
379 grads_and_vars: &[(Tensor, String)],
380 global_step: Option<i64>,
381 ) -> Result<()> {
382 if let Some(step) = global_step {
383 self.global_step = step;
384 } else {
385 self.global_step += 1;
386 }
387
388 self.update_learning_rate()?;
390
391 let mut gradients: Vec<Tensor> = grads_and_vars.iter().map(|(g, _)| g.clone()).collect();
392
393 self.clip_gradients(&mut gradients)?;
395
396 let mut variables = self.variables.lock().expect("Mutex lock poisoned");
398 for (grad, var_name) in grads_and_vars {
399 if let Some(var) = variables.get_mut(var_name) {
400 self.inner.update(var, grad)?;
401 }
402 }
403 self.inner.step();
404
405 Ok(())
406 }
407
408 fn minimize(
409 &mut self,
410 loss_fn: Box<dyn Fn() -> Result<Tensor>>,
411 var_list: &[String],
412 global_step: Option<i64>,
413 ) -> Result<Tensor> {
414 let loss = loss_fn()?;
415
416 let mut grads_and_vars = Vec::new();
418 {
419 let mut variables = self.variables.lock().expect("Mutex lock poisoned");
420
421 for var_name in var_list {
422 if let Some(var) = variables.get_mut(var_name) {
423 let grad = self.compute_numerical_gradient(loss_fn.as_ref(), var, var_name)?;
425 grads_and_vars.push((grad, var_name.clone()));
426 }
427 }
428 } self.apply_gradients(&grads_and_vars, global_step)?;
431 Ok(loss)
432 }
433
434 fn get_config(&self) -> TensorFlowOptimizerConfig {
435 self.config.clone()
436 }
437
438 fn variables(&self) -> Vec<String> {
439 let variables = self.variables.lock().expect("Mutex lock poisoned");
440 variables.keys().cloned().collect()
441 }
442
443 fn get_weights(&self) -> Vec<Tensor> {
444 let variables = self.variables.lock().expect("Mutex lock poisoned");
445 variables.values().cloned().collect()
446 }
447
448 fn set_weights(&mut self, weights: Vec<Tensor>) -> Result<()> {
449 let mut variables = self.variables.lock().expect("Mutex lock poisoned");
450 let var_names: Vec<String> = variables.keys().cloned().collect();
451
452 if weights.len() != var_names.len() {
453 return Err(TrustformersError::invalid_argument(
454 "Number of weights must match number of variables".to_string(),
455 ));
456 }
457
458 for (weight, var_name) in weights.into_iter().zip(var_names) {
459 variables.insert(var_name, weight);
460 }
461
462 Ok(())
463 }
464
465 fn get_learning_rate(&self) -> f64 {
466 self.config.learning_rate
467 }
468
469 fn set_learning_rate(&mut self, lr: f64) -> Result<()> {
470 self.config.learning_rate = lr;
471
472 self.inner.set_lr(lr as f32);
474
475 Ok(())
476 }
477
478 fn get_name(&self) -> &str {
479 self.config.name.as_deref().unwrap_or("Adam")
480 }
481}
482
483impl TensorFlowAdam {
484 fn compute_numerical_gradient(
486 &self,
487 loss_fn: &dyn Fn() -> Result<Tensor>,
488 var: &mut Tensor,
489 _var_name: &str,
490 ) -> Result<Tensor> {
491 const EPSILON: f32 = 1e-4;
492
493 let original_loss = loss_fn()?;
494 #[allow(unused_assignments)]
495 let mut grad = Tensor::zeros(&var.shape())?;
496
497 let var_data = var.data()?;
499 let mut grad_data = vec![0.0; var_data.len()];
500
501 for i in 0..var_data.len() {
502 let mut var_plus = var_data.clone();
504 var_plus[i] += EPSILON;
505 *var = Tensor::from_vec(var_plus, &var.shape())?;
506
507 let loss_plus = loss_fn()?;
508 let loss_plus_scalar = loss_plus.data()?[0];
509 let original_loss_scalar = original_loss.data()?[0];
510
511 grad_data[i] = (loss_plus_scalar - original_loss_scalar) / EPSILON;
512
513 let var_original = var_data.clone();
515 *var = Tensor::from_vec(var_original, &var.shape())?;
516 }
517
518 grad = Tensor::from_vec(grad_data, &var.shape())?;
519 Ok(grad)
520 }
521}
522
/// AdamW (decoupled weight decay) optimizer with a TensorFlow/Keras-flavoured
/// API, wrapping the native [`AdamW`] implementation.
pub struct TensorFlowAdamW {
    // Underlying AdamW implementation that performs the parameter updates.
    inner: AdamW,
    // Keras-style configuration this optimizer was built from.
    config: TensorFlowOptimizerConfig,
    // Registered trainable variables, keyed by name.
    variables: Arc<Mutex<HashMap<String, Tensor>>>,
    // Optional schedule consulted on every `apply_gradients`.
    lr_schedule: Option<Box<dyn TensorFlowLearningRateSchedule>>,
    // Step counter: incremented per apply, or overwritten by the caller.
    global_step: i64,
}
531
532impl TensorFlowAdamW {
533 pub fn new(
535 learning_rate: f64,
536 beta_1: f64,
537 beta_2: f64,
538 epsilon: f64,
539 weight_decay: f64,
540 clipnorm: Option<f64>,
541 clipvalue: Option<f64>,
542 global_clipnorm: Option<f64>,
543 use_ema: bool,
544 ema_momentum: f64,
545 jit_compile: bool,
546 name: Option<String>,
547 ) -> Result<Self> {
548 let config = TensorFlowOptimizerConfig {
549 optimizer_type: "AdamW".to_string(),
550 learning_rate,
551 beta_1: Some(beta_1),
552 beta_2: Some(beta_2),
553 epsilon: Some(epsilon),
554 weight_decay: Some(weight_decay),
555 clipnorm,
556 clipvalue,
557 global_clipnorm,
558 use_ema: Some(use_ema),
559 ema_momentum: Some(ema_momentum),
560 ema_overwrite_frequency: None,
561 jit_compile: Some(jit_compile),
562 name,
563 parameters: HashMap::new(),
564 };
565
566 let _optimizer_config = TensorFlowOptimizerConfig {
567 learning_rate,
568 beta_1: Some(beta_1),
569 beta_2: Some(beta_2),
570 epsilon: Some(epsilon),
571 weight_decay: Some(weight_decay),
572 ..Default::default()
573 };
574
575 let inner = AdamW::new(
576 learning_rate as f32,
577 (beta_1 as f32, beta_2 as f32),
578 epsilon as f32,
579 weight_decay as f32,
580 );
581
582 Ok(Self {
583 inner,
584 config,
585 variables: Arc::new(Mutex::new(HashMap::new())),
586 lr_schedule: None,
587 global_step: 0,
588 })
589 }
590
591 pub fn with_defaults() -> Result<Self> {
593 Self::new(
594 0.001,
595 0.9,
596 0.999,
597 1e-7,
598 0.01,
599 None,
600 None,
601 None,
602 false,
603 0.99,
604 true,
605 Some("AdamW".to_string()),
606 )
607 }
608
609 pub fn with_schedule(
611 schedule: Box<dyn TensorFlowLearningRateSchedule>,
612 beta_1: f64,
613 beta_2: f64,
614 epsilon: f64,
615 weight_decay: f64,
616 clipnorm: Option<f64>,
617 clipvalue: Option<f64>,
618 global_clipnorm: Option<f64>,
619 use_ema: bool,
620 ema_momentum: f64,
621 jit_compile: bool,
622 name: Option<String>,
623 ) -> Result<Self> {
624 let mut optimizer = Self::new(
625 schedule.get_lr(0),
626 beta_1,
627 beta_2,
628 epsilon,
629 weight_decay,
630 clipnorm,
631 clipvalue,
632 global_clipnorm,
633 use_ema,
634 ema_momentum,
635 jit_compile,
636 name,
637 )?;
638
639 optimizer.lr_schedule = Some(schedule);
640 Ok(optimizer)
641 }
642
643 pub fn add_variable(&mut self, name: String, var: Tensor) -> Result<()> {
645 let mut variables = self.variables.lock().expect("Mutex lock poisoned");
646 variables.insert(name, var);
647 Ok(())
648 }
649
650 fn update_learning_rate(&mut self) -> Result<()> {
652 if let Some(ref schedule) = self.lr_schedule {
653 let new_lr = schedule.get_lr(self.global_step);
654 self.config.learning_rate = new_lr;
655
656 self.inner.set_lr(new_lr as f32);
658 }
659 Ok(())
660 }
661
662 fn clip_gradients(&self, gradients: &mut [Tensor]) -> Result<()> {
664 if let Some(clipnorm) = self.config.clipnorm {
665 for grad in gradients.iter_mut() {
667 let norm = grad.norm()?;
668 if norm > clipnorm as f32 {
669 grad.mul_scalar((clipnorm as f32) / norm)?;
670 }
671 }
672 }
673
674 if let Some(clipvalue) = self.config.clipvalue {
675 for grad in gradients.iter_mut() {
677 grad.clamp(-clipvalue as f32, clipvalue as f32)?;
678 }
679 }
680
681 if let Some(global_clipnorm) = self.config.global_clipnorm {
682 let global_norm: f64 = gradients
684 .iter()
685 .map(|g| g.norm().unwrap_or(0.0).powi(2) as f64)
686 .sum::<f64>()
687 .sqrt();
688
689 if global_norm > global_clipnorm {
690 let scale = global_clipnorm / global_norm;
691 for grad in gradients.iter_mut() {
692 grad.mul_scalar(scale as f32)?;
693 }
694 }
695 }
696
697 Ok(())
698 }
699}
700
701impl TensorFlowOptimizer for TensorFlowAdamW {
702 fn apply_gradients(
703 &mut self,
704 grads_and_vars: &[(Tensor, String)],
705 global_step: Option<i64>,
706 ) -> Result<()> {
707 if let Some(step) = global_step {
708 self.global_step = step;
709 } else {
710 self.global_step += 1;
711 }
712
713 self.update_learning_rate()?;
715
716 let mut gradients: Vec<Tensor> = grads_and_vars.iter().map(|(g, _)| g.clone()).collect();
717
718 self.clip_gradients(&mut gradients)?;
720
721 let mut variables = self.variables.lock().expect("Mutex lock poisoned");
723 for (grad, var_name) in grads_and_vars {
724 if let Some(var) = variables.get_mut(var_name) {
725 self.inner.update(var, grad)?;
726 }
727 }
728 self.inner.step();
729
730 Ok(())
731 }
732
733 fn minimize(
734 &mut self,
735 loss_fn: Box<dyn Fn() -> Result<Tensor>>,
736 var_list: &[String],
737 global_step: Option<i64>,
738 ) -> Result<Tensor> {
739 let loss = loss_fn()?;
740
741 let mut grads_and_vars = Vec::new();
743 {
744 let mut variables = self.variables.lock().expect("Mutex lock poisoned");
745
746 for var_name in var_list {
747 if let Some(var) = variables.get_mut(var_name) {
748 let grad = self.compute_numerical_gradient(loss_fn.as_ref(), var, var_name)?;
750 grads_and_vars.push((grad, var_name.clone()));
751 }
752 }
753 } self.apply_gradients(&grads_and_vars, global_step)?;
756 Ok(loss)
757 }
758
759 fn get_config(&self) -> TensorFlowOptimizerConfig {
760 self.config.clone()
761 }
762
763 fn variables(&self) -> Vec<String> {
764 let variables = self.variables.lock().expect("Mutex lock poisoned");
765 variables.keys().cloned().collect()
766 }
767
768 fn get_weights(&self) -> Vec<Tensor> {
769 let variables = self.variables.lock().expect("Mutex lock poisoned");
770 variables.values().cloned().collect()
771 }
772
773 fn set_weights(&mut self, weights: Vec<Tensor>) -> Result<()> {
774 let mut variables = self.variables.lock().expect("Mutex lock poisoned");
775 let var_names: Vec<String> = variables.keys().cloned().collect();
776
777 if weights.len() != var_names.len() {
778 return Err(TrustformersError::invalid_argument(
779 "Number of weights must match number of variables".to_string(),
780 ));
781 }
782
783 for (weight, var_name) in weights.into_iter().zip(var_names) {
784 variables.insert(var_name, weight);
785 }
786
787 Ok(())
788 }
789
790 fn get_learning_rate(&self) -> f64 {
791 self.config.learning_rate
792 }
793
794 fn set_learning_rate(&mut self, lr: f64) -> Result<()> {
795 self.config.learning_rate = lr;
796
797 self.inner.set_lr(lr as f32);
799
800 Ok(())
801 }
802
803 fn get_name(&self) -> &str {
804 self.config.name.as_deref().unwrap_or("AdamW")
805 }
806}
807
808impl TensorFlowAdamW {
809 fn compute_numerical_gradient(
811 &self,
812 loss_fn: &dyn Fn() -> Result<Tensor>,
813 var: &mut Tensor,
814 _var_name: &str,
815 ) -> Result<Tensor> {
816 const EPSILON: f32 = 1e-4;
817
818 let original_loss = loss_fn()?;
819 #[allow(unused_assignments)]
820 let mut grad = Tensor::zeros(&var.shape())?;
821
822 let var_data = var.data()?;
824 let mut grad_data = vec![0.0; var_data.len()];
825
826 for i in 0..var_data.len() {
827 let mut var_plus = var_data.clone();
829 var_plus[i] += EPSILON;
830 *var = Tensor::from_vec(var_plus, &var.shape())?;
831
832 let loss_plus = loss_fn()?;
833 let loss_plus_scalar = loss_plus.data()?[0];
834 let original_loss_scalar = original_loss.data()?[0];
835
836 grad_data[i] = (loss_plus_scalar - original_loss_scalar) / EPSILON;
837
838 let var_original = var_data.clone();
840 *var = Tensor::from_vec(var_original, &var.shape())?;
841 }
842
843 grad = Tensor::from_vec(grad_data, &var.shape())?;
844 Ok(grad)
845 }
846}
847
848pub struct TensorFlowOptimizerFactory;
850
impl TensorFlowOptimizerFactory {
    /// Creates a [`TensorFlowAdam`]; arguments mirror `tf.keras.optimizers.Adam`.
    pub fn adam(
        learning_rate: f64,
        beta_1: f64,
        beta_2: f64,
        epsilon: f64,
        weight_decay: Option<f64>,
        clipnorm: Option<f64>,
        clipvalue: Option<f64>,
        global_clipnorm: Option<f64>,
        use_ema: bool,
        ema_momentum: f64,
        jit_compile: bool,
        name: Option<String>,
    ) -> Result<TensorFlowAdam> {
        TensorFlowAdam::new(
            learning_rate,
            beta_1,
            beta_2,
            epsilon,
            weight_decay,
            clipnorm,
            clipvalue,
            global_clipnorm,
            use_ema,
            ema_momentum,
            jit_compile,
            name,
        )
    }

    /// Creates a [`TensorFlowAdamW`]; arguments mirror `tf.keras.optimizers.AdamW`.
    pub fn adamw(
        learning_rate: f64,
        beta_1: f64,
        beta_2: f64,
        epsilon: f64,
        weight_decay: f64,
        clipnorm: Option<f64>,
        clipvalue: Option<f64>,
        global_clipnorm: Option<f64>,
        use_ema: bool,
        ema_momentum: f64,
        jit_compile: bool,
        name: Option<String>,
    ) -> Result<TensorFlowAdamW> {
        TensorFlowAdamW::new(
            learning_rate,
            beta_1,
            beta_2,
            epsilon,
            weight_decay,
            clipnorm,
            clipvalue,
            global_clipnorm,
            use_ema,
            ema_momentum,
            jit_compile,
            name,
        )
    }

    /// Creates an exponential-decay learning-rate schedule.
    pub fn exponential_decay(
        initial_learning_rate: f64,
        decay_steps: i64,
        decay_rate: f64,
        staircase: bool,
    ) -> TensorFlowExponentialDecay {
        TensorFlowExponentialDecay::new(initial_learning_rate, decay_steps, decay_rate, staircase)
    }

    /// Creates a cosine-decay learning-rate schedule.
    pub fn cosine_decay(
        initial_learning_rate: f64,
        decay_steps: i64,
        alpha: f64,
    ) -> TensorFlowCosineDecay {
        TensorFlowCosineDecay::new(initial_learning_rate, decay_steps, alpha)
    }
}
933
#[cfg(test)]
mod tests {
    use super::*;
    use trustformers_core::Tensor;

    // Constructors should record the Keras default hyperparameters and name.
    #[test]
    fn test_tensorflow_adam_creation() {
        let optimizer = TensorFlowAdam::with_defaults().unwrap();
        assert_eq!(optimizer.get_learning_rate(), 0.001);
        assert_eq!(optimizer.get_name(), "Adam");
    }

    #[test]
    fn test_tensorflow_adamw_creation() {
        let optimizer = TensorFlowAdamW::with_defaults().unwrap();
        assert_eq!(optimizer.get_learning_rate(), 0.001);
        assert_eq!(optimizer.get_name(), "AdamW");
    }

    // At step 0 the schedule returns the initial rate; later steps decay it.
    #[test]
    fn test_tensorflow_exponential_decay() {
        let schedule = TensorFlowExponentialDecay::new(0.1, 100, 0.96, false);
        assert_eq!(schedule.get_lr(0), 0.1);
        assert!(schedule.get_lr(100) < 0.1);
    }

    // Cosine decay: full rate at step 0, strictly lower mid-schedule and at
    // the end (alpha = 0.0 drives the floor to zero).
    #[test]
    fn test_tensorflow_cosine_decay() {
        let schedule = TensorFlowCosineDecay::new(0.1, 100, 0.0);
        assert_eq!(schedule.get_lr(0), 0.1);
        assert!(schedule.get_lr(50) < 0.1);
        assert!(schedule.get_lr(100) < 0.1);
    }

    // Factory helpers should forward the `name` argument untouched.
    #[test]
    fn test_tensorflow_optimizer_factory() {
        let adam = TensorFlowOptimizerFactory::adam(
            0.001,
            0.9,
            0.999,
            1e-7,
            None,
            None,
            None,
            None,
            false,
            0.99,
            true,
            Some("TestAdam".to_string()),
        )
        .unwrap();
        assert_eq!(adam.get_name(), "TestAdam");

        let adamw = TensorFlowOptimizerFactory::adamw(
            0.001,
            0.9,
            0.999,
            1e-7,
            0.01,
            None,
            None,
            None,
            false,
            0.99,
            true,
            Some("TestAdamW".to_string()),
        )
        .unwrap();
        assert_eq!(adamw.get_name(), "TestAdamW");
    }

    // `with_schedule` seeds the learning rate from the schedule at step 0.
    #[test]
    fn test_learning_rate_schedule_with_optimizer() {
        let schedule = Box::new(TensorFlowExponentialDecay::new(0.1, 100, 0.96, false));
        let optimizer = TensorFlowAdam::with_schedule(
            schedule,
            0.9,
            0.999,
            1e-7,
            None,
            None,
            None,
            None,
            false,
            0.99,
            true,
            Some("ScheduledAdam".to_string()),
        )
        .unwrap();

        assert_eq!(optimizer.get_learning_rate(), 0.1);
    }

    // Registered variables should be listed by name (order-independent).
    #[test]
    fn test_variable_management() {
        let mut optimizer = TensorFlowAdam::with_defaults().unwrap();

        let var1 = Tensor::zeros(&[10, 10]).unwrap();
        let var2 = Tensor::zeros(&[5, 5]).unwrap();

        optimizer.add_variable("var1".to_string(), var1).unwrap();
        optimizer.add_variable("var2".to_string(), var2).unwrap();

        let variables = optimizer.variables();
        assert_eq!(variables.len(), 2);
        assert!(variables.contains(&"var1".to_string()));
        assert!(variables.contains(&"var2".to_string()));
    }

    // Manual LR override is reflected by the getter.
    #[test]
    fn test_learning_rate_updates() {
        let mut optimizer = TensorFlowAdam::with_defaults().unwrap();
        assert_eq!(optimizer.get_learning_rate(), 0.001);

        optimizer.set_learning_rate(0.01).unwrap();
        assert_eq!(optimizer.get_learning_rate(), 0.01);
    }

    // `get_config` round-trips the hyperparameters the optimizer was built with.
    #[test]
    fn test_config_serialization() {
        let optimizer = TensorFlowAdam::with_defaults().unwrap();
        let config = optimizer.get_config();

        assert_eq!(config.learning_rate, 0.001);
        assert_eq!(config.beta_1, Some(0.9));
        assert_eq!(config.beta_2, Some(0.999));
        assert_eq!(config.epsilon, Some(1e-7));
    }
}