// trustformers_optim/tensorflow_compat.rs
1//! TensorFlow Optimizer API Compatibility Layer
2//!
3//! This module provides TensorFlow-compatible optimizer interfaces for seamless
4//! integration with TensorFlow-based training workflows. It wraps our native
5//! optimizers to provide the familiar TensorFlow API while maintaining high performance.
6
7use crate::{Adam, AdamW};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11use trustformers_core::errors::{Result, TrustformersError};
12use trustformers_core::traits::Optimizer;
13use trustformers_core::Tensor;
14
/// TensorFlow-compatible optimizer configuration.
///
/// Mirrors the keyword arguments accepted by `tf.keras.optimizers` so that
/// configurations can round-trip through serde (JSON) unchanged.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorFlowOptimizerConfig {
    // Optimizer identifier, e.g. "Adam" or "AdamW".
    pub optimizer_type: String,
    // Base learning rate.
    pub learning_rate: f64,
    // Exponential decay rate for the first-moment estimates.
    pub beta_1: Option<f64>,
    // Exponential decay rate for the second-moment estimates.
    pub beta_2: Option<f64>,
    // Small constant added for numerical stability.
    pub epsilon: Option<f64>,
    // Weight-decay coefficient; None disables decay.
    pub weight_decay: Option<f64>,
    // Per-gradient clipping threshold on the L2 norm.
    pub clipnorm: Option<f64>,
    // Element-wise clipping threshold on gradient values.
    pub clipvalue: Option<f64>,
    // Clipping threshold on the global norm across all gradients.
    pub global_clipnorm: Option<f64>,
    // Whether to keep an exponential moving average of the weights.
    // NOTE(review): stored for API parity but not consumed by the wrappers
    // in this module — confirm before relying on it.
    pub use_ema: Option<bool>,
    // Momentum used for the weight EMA.
    pub ema_momentum: Option<f64>,
    // How often (in steps) EMA values would overwrite the weights.
    pub ema_overwrite_frequency: Option<i32>,
    // Whether JIT/XLA compilation is requested (kept for API parity).
    pub jit_compile: Option<bool>,
    // Optional human-readable optimizer name.
    pub name: Option<String>,
    // Extra, optimizer-specific parameters as raw JSON values.
    pub parameters: HashMap<String, serde_json::Value>,
}
34
35impl Default for TensorFlowOptimizerConfig {
36    fn default() -> Self {
37        Self {
38            optimizer_type: "Adam".to_string(),
39            learning_rate: 0.001,
40            beta_1: Some(0.9),
41            beta_2: Some(0.999),
42            epsilon: Some(1e-7),
43            weight_decay: None,
44            clipnorm: None,
45            clipvalue: None,
46            global_clipnorm: None,
47            use_ema: Some(false),
48            ema_momentum: Some(0.99),
49            ema_overwrite_frequency: None,
50            jit_compile: Some(true),
51            name: None,
52            parameters: HashMap::new(),
53        }
54    }
55}
56
/// TensorFlow-compatible learning rate schedule.
///
/// Implementors map a global training step to a learning rate, mirroring
/// `tf.keras.optimizers.schedules.LearningRateSchedule`.
pub trait TensorFlowLearningRateSchedule: Send + Sync {
    /// Learning rate to use at the given global `step`.
    fn get_lr(&self, step: i64) -> f64;

    /// Serializable configuration of the schedule as a JSON object.
    fn get_config(&self) -> serde_json::Value;
}
65
/// TensorFlow-compatible exponential decay schedule.
///
/// Multiplies the initial learning rate by `decay_rate` once per
/// `decay_steps` steps (smoothly, or stepwise when `staircase` is set).
#[derive(Debug, Clone)]
pub struct TensorFlowExponentialDecay {
    // Learning rate at step 0.
    initial_learning_rate: f64,
    // Number of steps over which one full decay factor is applied.
    decay_steps: i64,
    // Multiplicative decay factor per `decay_steps` steps.
    decay_rate: f64,
    // When true, decay in discrete intervals (integer exponent).
    staircase: bool,
}
74
75impl TensorFlowExponentialDecay {
76    pub fn new(
77        initial_learning_rate: f64,
78        decay_steps: i64,
79        decay_rate: f64,
80        staircase: bool,
81    ) -> Self {
82        Self {
83            initial_learning_rate,
84            decay_steps,
85            decay_rate,
86            staircase,
87        }
88    }
89}
90
91impl TensorFlowLearningRateSchedule for TensorFlowExponentialDecay {
92    fn get_lr(&self, step: i64) -> f64 {
93        let decay_factor = if self.staircase {
94            (step / self.decay_steps) as f64
95        } else {
96            step as f64 / self.decay_steps as f64
97        };
98
99        self.initial_learning_rate * self.decay_rate.powf(decay_factor)
100    }
101
102    fn get_config(&self) -> serde_json::Value {
103        serde_json::json!({
104            "initial_learning_rate": self.initial_learning_rate,
105            "decay_steps": self.decay_steps,
106            "decay_rate": self.decay_rate,
107            "staircase": self.staircase,
108        })
109    }
110}
111
/// TensorFlow-compatible cosine decay schedule.
///
/// Anneals the learning rate from `initial_learning_rate` down toward
/// `alpha * initial_learning_rate` over `decay_steps` steps.
#[derive(Debug, Clone)]
pub struct TensorFlowCosineDecay {
    // Learning rate at step 0.
    initial_learning_rate: f64,
    // Number of steps over which the cosine anneal completes.
    decay_steps: i64,
    // Floor of the decay as a fraction of the initial learning rate.
    alpha: f64,
}
119
120impl TensorFlowCosineDecay {
121    pub fn new(initial_learning_rate: f64, decay_steps: i64, alpha: f64) -> Self {
122        Self {
123            initial_learning_rate,
124            decay_steps,
125            alpha,
126        }
127    }
128}
129
130impl TensorFlowLearningRateSchedule for TensorFlowCosineDecay {
131    fn get_lr(&self, step: i64) -> f64 {
132        let completed_fraction = (step.min(self.decay_steps) as f64) / (self.decay_steps as f64);
133        let cosine_decayed = 0.5 * (1.0 + (std::f64::consts::PI * completed_fraction).cos());
134        let decayed = (1.0 - self.alpha) * cosine_decayed + self.alpha;
135
136        self.initial_learning_rate * decayed
137    }
138
139    fn get_config(&self) -> serde_json::Value {
140        serde_json::json!({
141            "initial_learning_rate": self.initial_learning_rate,
142            "decay_steps": self.decay_steps,
143            "alpha": self.alpha,
144        })
145    }
146}
147
/// TensorFlow-compatible optimizer interface.
///
/// Mirrors the surface of `tf.keras.optimizers.Optimizer`
/// (apply_gradients / minimize / get_config / weight accessors) on top of
/// the native optimizers.
pub trait TensorFlowOptimizer: Send + Sync {
    /// Apply gradients to variables.
    ///
    /// `grads_and_vars` pairs each gradient tensor with the name of the
    /// registered variable it updates. When `global_step` is `None`, the
    /// optimizer advances its internal step counter by one.
    fn apply_gradients(
        &mut self,
        grads_and_vars: &[(Tensor, String)],
        global_step: Option<i64>,
    ) -> Result<()>;

    /// Minimize a loss function.
    ///
    /// Evaluates `loss_fn`, computes gradients for the named variables,
    /// applies them, and returns the evaluated loss.
    fn minimize(
        &mut self,
        loss_fn: Box<dyn Fn() -> Result<Tensor>>,
        var_list: &[String],
        global_step: Option<i64>,
    ) -> Result<Tensor>;

    /// Get the optimizer configuration.
    fn get_config(&self) -> TensorFlowOptimizerConfig;

    /// Names of all registered optimizer variables (state).
    fn variables(&self) -> Vec<String>;

    /// Current values of the optimizer's registered variables.
    fn get_weights(&self) -> Vec<Tensor>;

    /// Replace the registered variables' values; errors if the count does
    /// not match the number of registered variables.
    fn set_weights(&mut self, weights: Vec<Tensor>) -> Result<()>;

    /// Get the current learning rate.
    fn get_learning_rate(&self) -> f64;

    /// Set the learning rate.
    fn set_learning_rate(&mut self, lr: f64) -> Result<()>;

    /// Get the optimizer's display name.
    fn get_name(&self) -> &str;
}
186
/// TensorFlow-compatible Adam optimizer.
///
/// Wraps the native [`Adam`] optimizer behind the TensorFlow-style
/// [`TensorFlowOptimizer`] interface.
pub struct TensorFlowAdam {
    // Native Adam implementation that performs the actual updates (f32).
    inner: Adam,
    // TensorFlow-style configuration mirror, kept in sync with `inner`.
    config: TensorFlowOptimizerConfig,
    // Registered variables, keyed by name.
    variables: Arc<Mutex<HashMap<String, Tensor>>>,
    // Optional schedule; when set, the learning rate is refreshed from it
    // on every `apply_gradients` call.
    lr_schedule: Option<Box<dyn TensorFlowLearningRateSchedule>>,
    // Global training step fed to the schedule.
    global_step: i64,
}
195
196impl TensorFlowAdam {
197    /// Create new TensorFlow-compatible Adam optimizer
198    pub fn new(
199        learning_rate: f64,
200        beta_1: f64,
201        beta_2: f64,
202        epsilon: f64,
203        weight_decay: Option<f64>,
204        clipnorm: Option<f64>,
205        clipvalue: Option<f64>,
206        global_clipnorm: Option<f64>,
207        use_ema: bool,
208        ema_momentum: f64,
209        jit_compile: bool,
210        name: Option<String>,
211    ) -> Result<Self> {
212        let config = TensorFlowOptimizerConfig {
213            optimizer_type: "Adam".to_string(),
214            learning_rate,
215            beta_1: Some(beta_1),
216            beta_2: Some(beta_2),
217            epsilon: Some(epsilon),
218            weight_decay,
219            clipnorm,
220            clipvalue,
221            global_clipnorm,
222            use_ema: Some(use_ema),
223            ema_momentum: Some(ema_momentum),
224            ema_overwrite_frequency: None,
225            jit_compile: Some(jit_compile),
226            name,
227            parameters: HashMap::new(),
228        };
229
230        // optimizer_config is redundant - using config above
231
232        let inner = Adam::new(
233            learning_rate as f32,
234            (beta_1 as f32, beta_2 as f32),
235            epsilon as f32,
236            weight_decay.unwrap_or(0.0) as f32,
237        );
238
239        Ok(Self {
240            inner,
241            config,
242            variables: Arc::new(Mutex::new(HashMap::new())),
243            lr_schedule: None,
244            global_step: 0,
245        })
246    }
247
248    /// Create with default parameters
249    pub fn with_defaults() -> Result<Self> {
250        Self::new(
251            0.001,
252            0.9,
253            0.999,
254            1e-7,
255            None,
256            None,
257            None,
258            None,
259            false,
260            0.99,
261            true,
262            Some("Adam".to_string()),
263        )
264    }
265
266    /// Create TensorFlow Adam optimizer from configuration
267    pub fn from_config(config: TensorFlowOptimizerConfig) -> Result<Self> {
268        Self::new(
269            config.learning_rate,
270            config.beta_1.unwrap_or(0.9),
271            config.beta_2.unwrap_or(0.999),
272            config.epsilon.unwrap_or(1e-7),
273            config.weight_decay,
274            config.clipnorm,
275            config.clipvalue,
276            config.global_clipnorm,
277            config.use_ema.unwrap_or(false),
278            config.ema_momentum.unwrap_or(0.99),
279            config.jit_compile.unwrap_or(true),
280            config.name,
281        )
282    }
283
284    /// Create with learning rate schedule
285    pub fn with_schedule(
286        schedule: Box<dyn TensorFlowLearningRateSchedule>,
287        beta_1: f64,
288        beta_2: f64,
289        epsilon: f64,
290        weight_decay: Option<f64>,
291        clipnorm: Option<f64>,
292        clipvalue: Option<f64>,
293        global_clipnorm: Option<f64>,
294        use_ema: bool,
295        ema_momentum: f64,
296        jit_compile: bool,
297        name: Option<String>,
298    ) -> Result<Self> {
299        let mut optimizer = Self::new(
300            schedule.get_lr(0),
301            beta_1,
302            beta_2,
303            epsilon,
304            weight_decay,
305            clipnorm,
306            clipvalue,
307            global_clipnorm,
308            use_ema,
309            ema_momentum,
310            jit_compile,
311            name,
312        )?;
313
314        optimizer.lr_schedule = Some(schedule);
315        Ok(optimizer)
316    }
317
318    /// Add variable to optimizer
319    pub fn add_variable(&mut self, name: String, var: Tensor) -> Result<()> {
320        let mut variables = self.variables.lock().expect("Mutex lock poisoned");
321        variables.insert(name, var);
322        Ok(())
323    }
324
325    /// Update learning rate based on schedule
326    fn update_learning_rate(&mut self) -> Result<()> {
327        if let Some(ref schedule) = self.lr_schedule {
328            let new_lr = schedule.get_lr(self.global_step);
329            self.config.learning_rate = new_lr;
330
331            // Update inner optimizer learning rate
332            self.inner.set_lr(new_lr as f32);
333        }
334        Ok(())
335    }
336
337    /// Apply gradient clipping
338    fn clip_gradients(&self, gradients: &mut [Tensor]) -> Result<()> {
339        if let Some(clipnorm) = self.config.clipnorm {
340            // Clip by norm
341            for grad in gradients.iter_mut() {
342                let norm = grad.norm()?;
343                if norm > clipnorm as f32 {
344                    grad.mul_scalar((clipnorm as f32) / norm)?;
345                }
346            }
347        }
348
349        if let Some(clipvalue) = self.config.clipvalue {
350            // Clip by value
351            for grad in gradients.iter_mut() {
352                grad.clamp(-clipvalue as f32, clipvalue as f32)?;
353            }
354        }
355
356        if let Some(global_clipnorm) = self.config.global_clipnorm {
357            // Global gradient clipping
358            let global_norm: f64 = gradients
359                .iter()
360                .map(|g| g.norm().unwrap_or(0.0).powi(2) as f64)
361                .sum::<f64>()
362                .sqrt();
363
364            if global_norm > global_clipnorm {
365                let scale = global_clipnorm / global_norm;
366                for grad in gradients.iter_mut() {
367                    grad.mul_scalar(scale as f32)?;
368                }
369            }
370        }
371
372        Ok(())
373    }
374}
375
376impl TensorFlowOptimizer for TensorFlowAdam {
377    fn apply_gradients(
378        &mut self,
379        grads_and_vars: &[(Tensor, String)],
380        global_step: Option<i64>,
381    ) -> Result<()> {
382        if let Some(step) = global_step {
383            self.global_step = step;
384        } else {
385            self.global_step += 1;
386        }
387
388        // Update learning rate if schedule is set
389        self.update_learning_rate()?;
390
391        let mut gradients: Vec<Tensor> = grads_and_vars.iter().map(|(g, _)| g.clone()).collect();
392
393        // Apply gradient clipping
394        self.clip_gradients(&mut gradients)?;
395
396        // Apply gradients using inner optimizer
397        let mut variables = self.variables.lock().expect("Mutex lock poisoned");
398        for (grad, var_name) in grads_and_vars {
399            if let Some(var) = variables.get_mut(var_name) {
400                self.inner.update(var, grad)?;
401            }
402        }
403        self.inner.step();
404
405        Ok(())
406    }
407
408    fn minimize(
409        &mut self,
410        loss_fn: Box<dyn Fn() -> Result<Tensor>>,
411        var_list: &[String],
412        global_step: Option<i64>,
413    ) -> Result<Tensor> {
414        let loss = loss_fn()?;
415
416        // Compute gradients (this would normally be done by automatic differentiation)
417        let mut grads_and_vars = Vec::new();
418        {
419            let mut variables = self.variables.lock().expect("Mutex lock poisoned");
420
421            for var_name in var_list {
422                if let Some(var) = variables.get_mut(var_name) {
423                    // Compute numerical gradient using finite differences
424                    let grad = self.compute_numerical_gradient(loss_fn.as_ref(), var, var_name)?;
425                    grads_and_vars.push((grad, var_name.clone()));
426                }
427            }
428        } // variables lock is dropped here
429
430        self.apply_gradients(&grads_and_vars, global_step)?;
431        Ok(loss)
432    }
433
434    fn get_config(&self) -> TensorFlowOptimizerConfig {
435        self.config.clone()
436    }
437
438    fn variables(&self) -> Vec<String> {
439        let variables = self.variables.lock().expect("Mutex lock poisoned");
440        variables.keys().cloned().collect()
441    }
442
443    fn get_weights(&self) -> Vec<Tensor> {
444        let variables = self.variables.lock().expect("Mutex lock poisoned");
445        variables.values().cloned().collect()
446    }
447
448    fn set_weights(&mut self, weights: Vec<Tensor>) -> Result<()> {
449        let mut variables = self.variables.lock().expect("Mutex lock poisoned");
450        let var_names: Vec<String> = variables.keys().cloned().collect();
451
452        if weights.len() != var_names.len() {
453            return Err(TrustformersError::invalid_argument(
454                "Number of weights must match number of variables".to_string(),
455            ));
456        }
457
458        for (weight, var_name) in weights.into_iter().zip(var_names) {
459            variables.insert(var_name, weight);
460        }
461
462        Ok(())
463    }
464
465    fn get_learning_rate(&self) -> f64 {
466        self.config.learning_rate
467    }
468
469    fn set_learning_rate(&mut self, lr: f64) -> Result<()> {
470        self.config.learning_rate = lr;
471
472        // Update inner optimizer
473        self.inner.set_lr(lr as f32);
474
475        Ok(())
476    }
477
478    fn get_name(&self) -> &str {
479        self.config.name.as_deref().unwrap_or("Adam")
480    }
481}
482
483impl TensorFlowAdam {
484    /// Compute numerical gradient using finite differences
485    fn compute_numerical_gradient(
486        &self,
487        loss_fn: &dyn Fn() -> Result<Tensor>,
488        var: &mut Tensor,
489        _var_name: &str,
490    ) -> Result<Tensor> {
491        const EPSILON: f32 = 1e-4;
492
493        let original_loss = loss_fn()?;
494        #[allow(unused_assignments)]
495        let mut grad = Tensor::zeros(&var.shape())?;
496
497        // Compute gradient for each element using finite differences
498        let var_data = var.data()?;
499        let mut grad_data = vec![0.0; var_data.len()];
500
501        for i in 0..var_data.len() {
502            // Forward difference: f(x + h) - f(x) / h
503            let mut var_plus = var_data.clone();
504            var_plus[i] += EPSILON;
505            *var = Tensor::from_vec(var_plus, &var.shape())?;
506
507            let loss_plus = loss_fn()?;
508            let loss_plus_scalar = loss_plus.data()?[0];
509            let original_loss_scalar = original_loss.data()?[0];
510
511            grad_data[i] = (loss_plus_scalar - original_loss_scalar) / EPSILON;
512
513            // Restore original value
514            let var_original = var_data.clone();
515            *var = Tensor::from_vec(var_original, &var.shape())?;
516        }
517
518        grad = Tensor::from_vec(grad_data, &var.shape())?;
519        Ok(grad)
520    }
521}
522
/// TensorFlow-compatible AdamW optimizer.
///
/// Wraps the native [`AdamW`] optimizer (decoupled weight decay) behind the
/// TensorFlow-style [`TensorFlowOptimizer`] interface.
pub struct TensorFlowAdamW {
    // Native AdamW implementation that performs the actual updates (f32).
    inner: AdamW,
    // TensorFlow-style configuration mirror, kept in sync with `inner`.
    config: TensorFlowOptimizerConfig,
    // Registered variables, keyed by name.
    variables: Arc<Mutex<HashMap<String, Tensor>>>,
    // Optional schedule; when set, the learning rate is refreshed from it
    // on every `apply_gradients` call.
    lr_schedule: Option<Box<dyn TensorFlowLearningRateSchedule>>,
    // Global training step fed to the schedule.
    global_step: i64,
}
531
532impl TensorFlowAdamW {
533    /// Create new TensorFlow-compatible AdamW optimizer
534    pub fn new(
535        learning_rate: f64,
536        beta_1: f64,
537        beta_2: f64,
538        epsilon: f64,
539        weight_decay: f64,
540        clipnorm: Option<f64>,
541        clipvalue: Option<f64>,
542        global_clipnorm: Option<f64>,
543        use_ema: bool,
544        ema_momentum: f64,
545        jit_compile: bool,
546        name: Option<String>,
547    ) -> Result<Self> {
548        let config = TensorFlowOptimizerConfig {
549            optimizer_type: "AdamW".to_string(),
550            learning_rate,
551            beta_1: Some(beta_1),
552            beta_2: Some(beta_2),
553            epsilon: Some(epsilon),
554            weight_decay: Some(weight_decay),
555            clipnorm,
556            clipvalue,
557            global_clipnorm,
558            use_ema: Some(use_ema),
559            ema_momentum: Some(ema_momentum),
560            ema_overwrite_frequency: None,
561            jit_compile: Some(jit_compile),
562            name,
563            parameters: HashMap::new(),
564        };
565
566        let _optimizer_config = TensorFlowOptimizerConfig {
567            learning_rate,
568            beta_1: Some(beta_1),
569            beta_2: Some(beta_2),
570            epsilon: Some(epsilon),
571            weight_decay: Some(weight_decay),
572            ..Default::default()
573        };
574
575        let inner = AdamW::new(
576            learning_rate as f32,
577            (beta_1 as f32, beta_2 as f32),
578            epsilon as f32,
579            weight_decay as f32,
580        );
581
582        Ok(Self {
583            inner,
584            config,
585            variables: Arc::new(Mutex::new(HashMap::new())),
586            lr_schedule: None,
587            global_step: 0,
588        })
589    }
590
591    /// Create with default parameters
592    pub fn with_defaults() -> Result<Self> {
593        Self::new(
594            0.001,
595            0.9,
596            0.999,
597            1e-7,
598            0.01,
599            None,
600            None,
601            None,
602            false,
603            0.99,
604            true,
605            Some("AdamW".to_string()),
606        )
607    }
608
609    /// Create with learning rate schedule
610    pub fn with_schedule(
611        schedule: Box<dyn TensorFlowLearningRateSchedule>,
612        beta_1: f64,
613        beta_2: f64,
614        epsilon: f64,
615        weight_decay: f64,
616        clipnorm: Option<f64>,
617        clipvalue: Option<f64>,
618        global_clipnorm: Option<f64>,
619        use_ema: bool,
620        ema_momentum: f64,
621        jit_compile: bool,
622        name: Option<String>,
623    ) -> Result<Self> {
624        let mut optimizer = Self::new(
625            schedule.get_lr(0),
626            beta_1,
627            beta_2,
628            epsilon,
629            weight_decay,
630            clipnorm,
631            clipvalue,
632            global_clipnorm,
633            use_ema,
634            ema_momentum,
635            jit_compile,
636            name,
637        )?;
638
639        optimizer.lr_schedule = Some(schedule);
640        Ok(optimizer)
641    }
642
643    /// Add variable to optimizer
644    pub fn add_variable(&mut self, name: String, var: Tensor) -> Result<()> {
645        let mut variables = self.variables.lock().expect("Mutex lock poisoned");
646        variables.insert(name, var);
647        Ok(())
648    }
649
650    /// Update learning rate based on schedule
651    fn update_learning_rate(&mut self) -> Result<()> {
652        if let Some(ref schedule) = self.lr_schedule {
653            let new_lr = schedule.get_lr(self.global_step);
654            self.config.learning_rate = new_lr;
655
656            // Update inner optimizer learning rate
657            self.inner.set_lr(new_lr as f32);
658        }
659        Ok(())
660    }
661
662    /// Apply gradient clipping
663    fn clip_gradients(&self, gradients: &mut [Tensor]) -> Result<()> {
664        if let Some(clipnorm) = self.config.clipnorm {
665            // Clip by norm
666            for grad in gradients.iter_mut() {
667                let norm = grad.norm()?;
668                if norm > clipnorm as f32 {
669                    grad.mul_scalar((clipnorm as f32) / norm)?;
670                }
671            }
672        }
673
674        if let Some(clipvalue) = self.config.clipvalue {
675            // Clip by value
676            for grad in gradients.iter_mut() {
677                grad.clamp(-clipvalue as f32, clipvalue as f32)?;
678            }
679        }
680
681        if let Some(global_clipnorm) = self.config.global_clipnorm {
682            // Global gradient clipping
683            let global_norm: f64 = gradients
684                .iter()
685                .map(|g| g.norm().unwrap_or(0.0).powi(2) as f64)
686                .sum::<f64>()
687                .sqrt();
688
689            if global_norm > global_clipnorm {
690                let scale = global_clipnorm / global_norm;
691                for grad in gradients.iter_mut() {
692                    grad.mul_scalar(scale as f32)?;
693                }
694            }
695        }
696
697        Ok(())
698    }
699}
700
701impl TensorFlowOptimizer for TensorFlowAdamW {
702    fn apply_gradients(
703        &mut self,
704        grads_and_vars: &[(Tensor, String)],
705        global_step: Option<i64>,
706    ) -> Result<()> {
707        if let Some(step) = global_step {
708            self.global_step = step;
709        } else {
710            self.global_step += 1;
711        }
712
713        // Update learning rate if schedule is set
714        self.update_learning_rate()?;
715
716        let mut gradients: Vec<Tensor> = grads_and_vars.iter().map(|(g, _)| g.clone()).collect();
717
718        // Apply gradient clipping
719        self.clip_gradients(&mut gradients)?;
720
721        // Apply gradients using inner optimizer
722        let mut variables = self.variables.lock().expect("Mutex lock poisoned");
723        for (grad, var_name) in grads_and_vars {
724            if let Some(var) = variables.get_mut(var_name) {
725                self.inner.update(var, grad)?;
726            }
727        }
728        self.inner.step();
729
730        Ok(())
731    }
732
733    fn minimize(
734        &mut self,
735        loss_fn: Box<dyn Fn() -> Result<Tensor>>,
736        var_list: &[String],
737        global_step: Option<i64>,
738    ) -> Result<Tensor> {
739        let loss = loss_fn()?;
740
741        // Compute gradients (this would normally be done by automatic differentiation)
742        let mut grads_and_vars = Vec::new();
743        {
744            let mut variables = self.variables.lock().expect("Mutex lock poisoned");
745
746            for var_name in var_list {
747                if let Some(var) = variables.get_mut(var_name) {
748                    // Compute numerical gradient using finite differences
749                    let grad = self.compute_numerical_gradient(loss_fn.as_ref(), var, var_name)?;
750                    grads_and_vars.push((grad, var_name.clone()));
751                }
752            }
753        } // variables lock is dropped here
754
755        self.apply_gradients(&grads_and_vars, global_step)?;
756        Ok(loss)
757    }
758
759    fn get_config(&self) -> TensorFlowOptimizerConfig {
760        self.config.clone()
761    }
762
763    fn variables(&self) -> Vec<String> {
764        let variables = self.variables.lock().expect("Mutex lock poisoned");
765        variables.keys().cloned().collect()
766    }
767
768    fn get_weights(&self) -> Vec<Tensor> {
769        let variables = self.variables.lock().expect("Mutex lock poisoned");
770        variables.values().cloned().collect()
771    }
772
773    fn set_weights(&mut self, weights: Vec<Tensor>) -> Result<()> {
774        let mut variables = self.variables.lock().expect("Mutex lock poisoned");
775        let var_names: Vec<String> = variables.keys().cloned().collect();
776
777        if weights.len() != var_names.len() {
778            return Err(TrustformersError::invalid_argument(
779                "Number of weights must match number of variables".to_string(),
780            ));
781        }
782
783        for (weight, var_name) in weights.into_iter().zip(var_names) {
784            variables.insert(var_name, weight);
785        }
786
787        Ok(())
788    }
789
790    fn get_learning_rate(&self) -> f64 {
791        self.config.learning_rate
792    }
793
794    fn set_learning_rate(&mut self, lr: f64) -> Result<()> {
795        self.config.learning_rate = lr;
796
797        // Update inner optimizer
798        self.inner.set_lr(lr as f32);
799
800        Ok(())
801    }
802
803    fn get_name(&self) -> &str {
804        self.config.name.as_deref().unwrap_or("AdamW")
805    }
806}
807
808impl TensorFlowAdamW {
809    /// Compute numerical gradient using finite differences
810    fn compute_numerical_gradient(
811        &self,
812        loss_fn: &dyn Fn() -> Result<Tensor>,
813        var: &mut Tensor,
814        _var_name: &str,
815    ) -> Result<Tensor> {
816        const EPSILON: f32 = 1e-4;
817
818        let original_loss = loss_fn()?;
819        #[allow(unused_assignments)]
820        let mut grad = Tensor::zeros(&var.shape())?;
821
822        // Compute gradient for each element using finite differences
823        let var_data = var.data()?;
824        let mut grad_data = vec![0.0; var_data.len()];
825
826        for i in 0..var_data.len() {
827            // Forward difference: f(x + h) - f(x) / h
828            let mut var_plus = var_data.clone();
829            var_plus[i] += EPSILON;
830            *var = Tensor::from_vec(var_plus, &var.shape())?;
831
832            let loss_plus = loss_fn()?;
833            let loss_plus_scalar = loss_plus.data()?[0];
834            let original_loss_scalar = original_loss.data()?[0];
835
836            grad_data[i] = (loss_plus_scalar - original_loss_scalar) / EPSILON;
837
838            // Restore original value
839            let var_original = var_data.clone();
840            *var = Tensor::from_vec(var_original, &var.shape())?;
841        }
842
843        grad = Tensor::from_vec(grad_data, &var.shape())?;
844        Ok(grad)
845    }
846}
847
/// TensorFlow optimizer factory.
///
/// Convenience constructors mirroring `tf.keras.optimizers` and
/// `tf.keras.optimizers.schedules`; each method delegates directly to the
/// corresponding type's constructor.
pub struct TensorFlowOptimizerFactory;

impl TensorFlowOptimizerFactory {
    /// Create an Adam optimizer; delegates to [`TensorFlowAdam::new`] with
    /// the same arguments.
    pub fn adam(
        learning_rate: f64,
        beta_1: f64,
        beta_2: f64,
        epsilon: f64,
        weight_decay: Option<f64>,
        clipnorm: Option<f64>,
        clipvalue: Option<f64>,
        global_clipnorm: Option<f64>,
        use_ema: bool,
        ema_momentum: f64,
        jit_compile: bool,
        name: Option<String>,
    ) -> Result<TensorFlowAdam> {
        TensorFlowAdam::new(
            learning_rate,
            beta_1,
            beta_2,
            epsilon,
            weight_decay,
            clipnorm,
            clipvalue,
            global_clipnorm,
            use_ema,
            ema_momentum,
            jit_compile,
            name,
        )
    }

    /// Create an AdamW optimizer; delegates to [`TensorFlowAdamW::new`]
    /// with the same arguments (note: `weight_decay` is mandatory here).
    pub fn adamw(
        learning_rate: f64,
        beta_1: f64,
        beta_2: f64,
        epsilon: f64,
        weight_decay: f64,
        clipnorm: Option<f64>,
        clipvalue: Option<f64>,
        global_clipnorm: Option<f64>,
        use_ema: bool,
        ema_momentum: f64,
        jit_compile: bool,
        name: Option<String>,
    ) -> Result<TensorFlowAdamW> {
        TensorFlowAdamW::new(
            learning_rate,
            beta_1,
            beta_2,
            epsilon,
            weight_decay,
            clipnorm,
            clipvalue,
            global_clipnorm,
            use_ema,
            ema_momentum,
            jit_compile,
            name,
        )
    }

    /// Create an exponential decay schedule; delegates to
    /// [`TensorFlowExponentialDecay::new`].
    pub fn exponential_decay(
        initial_learning_rate: f64,
        decay_steps: i64,
        decay_rate: f64,
        staircase: bool,
    ) -> TensorFlowExponentialDecay {
        TensorFlowExponentialDecay::new(initial_learning_rate, decay_steps, decay_rate, staircase)
    }

    /// Create a cosine decay schedule; delegates to
    /// [`TensorFlowCosineDecay::new`].
    pub fn cosine_decay(
        initial_learning_rate: f64,
        decay_steps: i64,
        alpha: f64,
    ) -> TensorFlowCosineDecay {
        TensorFlowCosineDecay::new(initial_learning_rate, decay_steps, alpha)
    }
}
933
#[cfg(test)]
mod tests {
    use super::*;
    use trustformers_core::Tensor;

    // Defaults should match tf.keras' Adam: lr = 1e-3, name "Adam".
    #[test]
    fn test_tensorflow_adam_creation() {
        let optimizer = TensorFlowAdam::with_defaults().unwrap();
        assert_eq!(optimizer.get_learning_rate(), 0.001);
        assert_eq!(optimizer.get_name(), "Adam");
    }

    // Defaults should match tf.keras' AdamW: lr = 1e-3, name "AdamW".
    #[test]
    fn test_tensorflow_adamw_creation() {
        let optimizer = TensorFlowAdamW::with_defaults().unwrap();
        assert_eq!(optimizer.get_learning_rate(), 0.001);
        assert_eq!(optimizer.get_name(), "AdamW");
    }

    // Exponential decay starts at the initial rate and strictly decreases.
    #[test]
    fn test_tensorflow_exponential_decay() {
        let schedule = TensorFlowExponentialDecay::new(0.1, 100, 0.96, false);
        assert_eq!(schedule.get_lr(0), 0.1);
        assert!(schedule.get_lr(100) < 0.1);
    }

    // Cosine decay starts at the initial rate and decreases toward alpha.
    #[test]
    fn test_tensorflow_cosine_decay() {
        let schedule = TensorFlowCosineDecay::new(0.1, 100, 0.0);
        assert_eq!(schedule.get_lr(0), 0.1);
        assert!(schedule.get_lr(50) < 0.1);
        assert!(schedule.get_lr(100) < 0.1);
    }

    // Factory methods forward the `name` argument through to the optimizer.
    #[test]
    fn test_tensorflow_optimizer_factory() {
        let adam = TensorFlowOptimizerFactory::adam(
            0.001,
            0.9,
            0.999,
            1e-7,
            None,
            None,
            None,
            None,
            false,
            0.99,
            true,
            Some("TestAdam".to_string()),
        )
        .unwrap();
        assert_eq!(adam.get_name(), "TestAdam");

        let adamw = TensorFlowOptimizerFactory::adamw(
            0.001,
            0.9,
            0.999,
            1e-7,
            0.01,
            None,
            None,
            None,
            false,
            0.99,
            true,
            Some("TestAdamW".to_string()),
        )
        .unwrap();
        assert_eq!(adamw.get_name(), "TestAdamW");
    }

    // with_schedule should seed the learning rate from the schedule at step 0.
    #[test]
    fn test_learning_rate_schedule_with_optimizer() {
        let schedule = Box::new(TensorFlowExponentialDecay::new(0.1, 100, 0.96, false));
        let optimizer = TensorFlowAdam::with_schedule(
            schedule,
            0.9,
            0.999,
            1e-7,
            None,
            None,
            None,
            None,
            false,
            0.99,
            true,
            Some("ScheduledAdam".to_string()),
        )
        .unwrap();

        assert_eq!(optimizer.get_learning_rate(), 0.1);
    }

    // Registered variables should be listed by name via `variables()`.
    #[test]
    fn test_variable_management() {
        let mut optimizer = TensorFlowAdam::with_defaults().unwrap();

        let var1 = Tensor::zeros(&[10, 10]).unwrap();
        let var2 = Tensor::zeros(&[5, 5]).unwrap();

        optimizer.add_variable("var1".to_string(), var1).unwrap();
        optimizer.add_variable("var2".to_string(), var2).unwrap();

        let variables = optimizer.variables();
        assert_eq!(variables.len(), 2);
        assert!(variables.contains(&"var1".to_string()));
        assert!(variables.contains(&"var2".to_string()));
    }

    // set_learning_rate should be reflected by get_learning_rate.
    #[test]
    fn test_learning_rate_updates() {
        let mut optimizer = TensorFlowAdam::with_defaults().unwrap();
        assert_eq!(optimizer.get_learning_rate(), 0.001);

        optimizer.set_learning_rate(0.01).unwrap();
        assert_eq!(optimizer.get_learning_rate(), 0.01);
    }

    // get_config should expose the hyperparameters the optimizer was built with.
    #[test]
    fn test_config_serialization() {
        let optimizer = TensorFlowAdam::with_defaults().unwrap();
        let config = optimizer.get_config();

        assert_eq!(config.learning_rate, 0.001);
        assert_eq!(config.beta_1, Some(0.9));
        assert_eq!(config.beta_2, Some(0.999));
        assert_eq!(config.epsilon, Some(1e-7));
    }
}