// File: trustformers_optim/cross_framework.rs
1//! Cross-Framework Optimizer Conversion
2//!
3//! This module provides utilities to convert optimizer configurations and states
4//! between different ML frameworks (PyTorch, TensorFlow, JAX, ONNX).
5
6use crate::{
7    onnx_export::OptimizerConfig as ONNXConfig, tensorflow_compat::TensorFlowOptimizerConfig,
8};
9use anyhow::{anyhow, Result};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::collections::HashMap;
13
/// Supported ML frameworks for conversion
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Framework {
    /// PyTorch (`torch.optim` parameter names: `lr`, `betas`, `eps`, `weight_decay`).
    PyTorch,
    /// TensorFlow/Keras (parameter names: `learning_rate`, `beta_1`, `beta_2`, `epsilon`).
    TensorFlow,
    /// JAX/Optax (parameter names: `learning_rate`, `b1`, `b2`, `eps`).
    JAX,
    /// ONNX training operators (attribute names: `alpha`, `beta`, `epsilon`).
    ONNX,
    /// This crate's native optimizer configuration.
    TrustformeRS,
}
23
/// Universal optimizer configuration that can be converted between frameworks
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UniversalOptimizerConfig {
    /// Optimizer name (e.g. "Adam"); copied verbatim across frameworks.
    pub optimizer_type: String,
    /// Learning rate; TensorFlow configs store `f64` and are narrowed to `f32` on import.
    pub learning_rate: f32,
    /// Remaining hyperparameters as JSON values, keyed by the source framework's names.
    pub parameters: HashMap<String, Value>,
    /// Framework the parameter names originate from; selects the mapping table
    /// used when converting to a target framework.
    pub source_framework: Framework,
}
32
/// Universal optimizer state representation
///
/// NOTE(review): not consumed anywhere in this module — state conversion
/// appears to be declared but not yet implemented.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UniversalOptimizerState {
    /// Global optimization step count.
    pub step: i64,
    /// Per-parameter optimizer state as JSON values.
    pub state_dict: HashMap<String, Value>,
    /// Framework the state was captured from.
    pub framework: Framework,
}
40
/// Cross-framework optimizer converter
pub struct CrossFrameworkConverter {
    // Mapping between parameter names across frameworks:
    // (source, target) -> { source_param_name -> target_param_name }.
    // Populated once in `new()`; unmapped names pass through unchanged.
    parameter_mappings: HashMap<(Framework, Framework), HashMap<String, String>>,
}
46
47impl CrossFrameworkConverter {
48    /// Create a new cross-framework converter
49    pub fn new() -> Self {
50        let mut converter = Self {
51            parameter_mappings: HashMap::new(),
52        };
53        converter.initialize_mappings();
54        converter
55    }
56
57    /// Initialize parameter name mappings between frameworks
58    fn initialize_mappings(&mut self) {
59        // PyTorch to TensorFlow mappings
60        let mut pytorch_to_tf = HashMap::new();
61        pytorch_to_tf.insert("lr".to_string(), "learning_rate".to_string());
62        pytorch_to_tf.insert("betas".to_string(), "beta_1_beta_2".to_string());
63        pytorch_to_tf.insert("eps".to_string(), "epsilon".to_string());
64        pytorch_to_tf.insert("weight_decay".to_string(), "weight_decay".to_string());
65        self.parameter_mappings
66            .insert((Framework::PyTorch, Framework::TensorFlow), pytorch_to_tf);
67
68        // TensorFlow to PyTorch mappings
69        let mut tf_to_pytorch = HashMap::new();
70        tf_to_pytorch.insert("learning_rate".to_string(), "lr".to_string());
71        tf_to_pytorch.insert("beta_1".to_string(), "betas[0]".to_string());
72        tf_to_pytorch.insert("beta_2".to_string(), "betas[1]".to_string());
73        tf_to_pytorch.insert("epsilon".to_string(), "eps".to_string());
74        self.parameter_mappings
75            .insert((Framework::TensorFlow, Framework::PyTorch), tf_to_pytorch);
76
77        // JAX to PyTorch mappings
78        let mut jax_to_pytorch = HashMap::new();
79        jax_to_pytorch.insert("learning_rate".to_string(), "lr".to_string());
80        jax_to_pytorch.insert("b1".to_string(), "betas[0]".to_string());
81        jax_to_pytorch.insert("b2".to_string(), "betas[1]".to_string());
82        jax_to_pytorch.insert("eps".to_string(), "eps".to_string());
83        self.parameter_mappings
84            .insert((Framework::JAX, Framework::PyTorch), jax_to_pytorch);
85
86        // PyTorch to JAX mappings
87        let mut pytorch_to_jax = HashMap::new();
88        pytorch_to_jax.insert("lr".to_string(), "learning_rate".to_string());
89        pytorch_to_jax.insert("betas[0]".to_string(), "b1".to_string());
90        pytorch_to_jax.insert("betas[1]".to_string(), "b2".to_string());
91        pytorch_to_jax.insert("eps".to_string(), "eps".to_string());
92        self.parameter_mappings
93            .insert((Framework::PyTorch, Framework::JAX), pytorch_to_jax);
94
95        // ONNX mappings (similar to TensorFlow)
96        let mut onnx_to_pytorch = HashMap::new();
97        onnx_to_pytorch.insert("alpha".to_string(), "lr".to_string());
98        onnx_to_pytorch.insert("beta".to_string(), "betas[0]".to_string());
99        onnx_to_pytorch.insert("beta2".to_string(), "betas[1]".to_string());
100        onnx_to_pytorch.insert("epsilon".to_string(), "eps".to_string());
101        self.parameter_mappings
102            .insert((Framework::ONNX, Framework::PyTorch), onnx_to_pytorch);
103
104        let mut pytorch_to_onnx = HashMap::new();
105        pytorch_to_onnx.insert("lr".to_string(), "alpha".to_string());
106        pytorch_to_onnx.insert("betas[0]".to_string(), "beta".to_string());
107        pytorch_to_onnx.insert("betas[1]".to_string(), "beta2".to_string());
108        pytorch_to_onnx.insert("eps".to_string(), "epsilon".to_string());
109        self.parameter_mappings
110            .insert((Framework::PyTorch, Framework::ONNX), pytorch_to_onnx);
111    }
112
113    /// Convert optimizer configuration to universal format
114    pub fn to_universal(
115        &self,
116        config: &dyn ConfigSource,
117        source_framework: Framework,
118    ) -> Result<UniversalOptimizerConfig> {
119        let (optimizer_type, learning_rate, parameters) = config.extract_config()?;
120
121        Ok(UniversalOptimizerConfig {
122            optimizer_type,
123            learning_rate,
124            parameters,
125            source_framework,
126        })
127    }
128
129    /// Convert from universal format to target framework
130    pub fn from_universal(
131        &self,
132        config: &UniversalOptimizerConfig,
133        target_framework: Framework,
134    ) -> Result<Box<dyn ConfigTarget>> {
135        let mapped_params = self.map_parameters(
136            &config.parameters,
137            config.source_framework,
138            target_framework,
139        )?;
140
141        match target_framework {
142            Framework::PyTorch => {
143                let pytorch_config = PyTorchOptimizerConfig {
144                    optimizer_type: config.optimizer_type.clone(),
145                    learning_rate: config.learning_rate,
146                    parameters: mapped_params,
147                };
148                Ok(Box::new(pytorch_config))
149            },
150            Framework::TensorFlow => {
151                let tf_config = TensorFlowOptimizerConfig {
152                    optimizer_type: config.optimizer_type.clone(),
153                    learning_rate: config.learning_rate as f64,
154                    parameters: mapped_params,
155                    ..Default::default()
156                };
157                Ok(Box::new(tf_config))
158            },
159            Framework::JAX => {
160                let jax_config = JAXOptimizerConfig {
161                    optimizer_type: config.optimizer_type.clone(),
162                    learning_rate: config.learning_rate,
163                    parameters: mapped_params,
164                };
165                Ok(Box::new(jax_config))
166            },
167            Framework::ONNX => {
168                let onnx_config = ONNXConfig {
169                    optimizer_type: config.optimizer_type.clone(),
170                    learning_rate: config.learning_rate,
171                    parameters: mapped_params,
172                };
173                Ok(Box::new(onnx_config))
174            },
175            Framework::TrustformeRS => {
176                let trustformers_config = TrustformeRSOptimizerConfig {
177                    optimizer_type: config.optimizer_type.clone(),
178                    learning_rate: config.learning_rate,
179                    parameters: mapped_params,
180                };
181                Ok(Box::new(trustformers_config))
182            },
183        }
184    }
185
186    /// Direct conversion between frameworks
187    pub fn convert(
188        &self,
189        config: &dyn ConfigSource,
190        source_framework: Framework,
191        target_framework: Framework,
192    ) -> Result<Box<dyn ConfigTarget>> {
193        let universal = self.to_universal(config, source_framework)?;
194        self.from_universal(&universal, target_framework)
195    }
196
197    /// Map parameter names between frameworks
198    fn map_parameters(
199        &self,
200        parameters: &HashMap<String, Value>,
201        source: Framework,
202        target: Framework,
203    ) -> Result<HashMap<String, Value>> {
204        if source == target {
205            return Ok(parameters.clone());
206        }
207
208        let mapping = self.parameter_mappings.get(&(source, target)).ok_or_else(|| {
209            anyhow!(
210                "No parameter mapping found for {:?} to {:?}",
211                source,
212                target
213            )
214        })?;
215
216        let mut mapped_params = HashMap::new();
217
218        for (key, value) in parameters {
219            let mapped_key = mapping.get(key).unwrap_or(key);
220            mapped_params.insert(mapped_key.clone(), value.clone());
221        }
222
223        Ok(mapped_params)
224    }
225
226    /// Convert PyTorch Adam to TensorFlow Adam
227    pub fn pytorch_adam_to_tensorflow(
228        &self,
229        lr: f32,
230        betas: (f32, f32),
231        eps: f32,
232        weight_decay: f32,
233    ) -> Result<TensorFlowOptimizerConfig> {
234        let mut parameters = HashMap::new();
235        parameters.insert(
236            "beta_1".to_string(),
237            Value::Number(
238                serde_json::Number::from_f64(betas.0 as f64)
239                    .ok_or_else(|| anyhow!("Invalid beta_1"))?,
240            ),
241        );
242        parameters.insert(
243            "beta_2".to_string(),
244            Value::Number(
245                serde_json::Number::from_f64(betas.1 as f64)
246                    .ok_or_else(|| anyhow!("Invalid beta_2"))?,
247            ),
248        );
249        parameters.insert(
250            "epsilon".to_string(),
251            Value::Number(
252                serde_json::Number::from_f64(eps as f64)
253                    .ok_or_else(|| anyhow!("Invalid epsilon"))?,
254            ),
255        );
256        parameters.insert(
257            "weight_decay".to_string(),
258            Value::Number(
259                serde_json::Number::from_f64(weight_decay as f64)
260                    .ok_or_else(|| anyhow!("Invalid weight_decay"))?,
261            ),
262        );
263
264        Ok(TensorFlowOptimizerConfig {
265            optimizer_type: "Adam".to_string(),
266            learning_rate: lr as f64,
267            parameters,
268            ..Default::default()
269        })
270    }
271
272    /// Convert TensorFlow Adam to PyTorch Adam
273    pub fn tensorflow_adam_to_pytorch(
274        &self,
275        lr: f32,
276        beta_1: f32,
277        beta_2: f32,
278        epsilon: f32,
279        weight_decay: f32,
280    ) -> Result<PyTorchOptimizerConfig> {
281        let mut parameters = HashMap::new();
282        parameters.insert(
283            "betas".to_string(),
284            Value::Array(vec![
285                Value::Number(
286                    serde_json::Number::from_f64(beta_1 as f64)
287                        .ok_or_else(|| anyhow!("Invalid beta_1"))?,
288                ),
289                Value::Number(
290                    serde_json::Number::from_f64(beta_2 as f64)
291                        .ok_or_else(|| anyhow!("Invalid beta_2"))?,
292                ),
293            ]),
294        );
295        parameters.insert(
296            "eps".to_string(),
297            Value::Number(
298                serde_json::Number::from_f64(epsilon as f64)
299                    .ok_or_else(|| anyhow!("Invalid epsilon"))?,
300            ),
301        );
302        parameters.insert(
303            "weight_decay".to_string(),
304            Value::Number(
305                serde_json::Number::from_f64(weight_decay as f64)
306                    .ok_or_else(|| anyhow!("Invalid weight_decay"))?,
307            ),
308        );
309
310        Ok(PyTorchOptimizerConfig {
311            optimizer_type: "Adam".to_string(),
312            learning_rate: lr,
313            parameters,
314        })
315    }
316
317    /// Convert JAX Adam to PyTorch Adam
318    pub fn jax_adam_to_pytorch(
319        &self,
320        learning_rate: f32,
321        b1: f32,
322        b2: f32,
323        eps: f32,
324    ) -> Result<PyTorchOptimizerConfig> {
325        let mut parameters = HashMap::new();
326        parameters.insert(
327            "betas".to_string(),
328            Value::Array(vec![
329                Value::Number(
330                    serde_json::Number::from_f64(b1 as f64).ok_or_else(|| anyhow!("Invalid b1"))?,
331                ),
332                Value::Number(
333                    serde_json::Number::from_f64(b2 as f64).ok_or_else(|| anyhow!("Invalid b2"))?,
334                ),
335            ]),
336        );
337        parameters.insert(
338            "eps".to_string(),
339            Value::Number(
340                serde_json::Number::from_f64(eps as f64)
341                    .ok_or_else(|| anyhow!("Invalid epsilon"))?,
342            ),
343        );
344
345        Ok(PyTorchOptimizerConfig {
346            optimizer_type: "Adam".to_string(),
347            learning_rate,
348            parameters,
349        })
350    }
351
352    /// Convert ONNX optimizer to PyTorch
353    pub fn onnx_to_pytorch(&self, onnx_config: &ONNXConfig) -> Result<PyTorchOptimizerConfig> {
354        let mapped_params =
355            self.map_parameters(&onnx_config.parameters, Framework::ONNX, Framework::PyTorch)?;
356
357        Ok(PyTorchOptimizerConfig {
358            optimizer_type: onnx_config.optimizer_type.clone(),
359            learning_rate: onnx_config.learning_rate,
360            parameters: mapped_params,
361        })
362    }
363
364    /// Batch convert multiple optimizers
365    pub fn batch_convert(
366        &self,
367        configs: Vec<(&dyn ConfigSource, Framework)>,
368        target_framework: Framework,
369    ) -> Result<Vec<Box<dyn ConfigTarget>>> {
370        let mut results = Vec::new();
371
372        for (config, source_framework) in configs {
373            let converted = self.convert(config, source_framework, target_framework)?;
374            results.push(converted);
375        }
376
377        Ok(results)
378    }
379
380    /// Generate conversion report
381    pub fn generate_conversion_report(&self, source: Framework, target: Framework) -> String {
382        let mapping = self.parameter_mappings.get(&(source, target));
383
384        match mapping {
385            Some(map) => {
386                let mut report = format!("Conversion mapping from {:?} to {:?}:\n", source, target);
387                for (source_param, target_param) in map {
388                    report.push_str(&format!("  {} -> {}\n", source_param, target_param));
389                }
390                report
391            },
392            None => format!(
393                "No conversion mapping available from {:?} to {:?}",
394                source, target
395            ),
396        }
397    }
398}
399
impl Default for CrossFrameworkConverter {
    /// Equivalent to [`CrossFrameworkConverter::new`]: a converter with the
    /// built-in mapping tables already registered.
    fn default() -> Self {
        Self::new()
    }
}
405
/// Trait for extracting configuration from different sources
pub trait ConfigSource {
    /// Decompose the config into `(optimizer_type, learning_rate, parameters)`
    /// for conversion to the universal representation.
    fn extract_config(&self) -> Result<(String, f32, HashMap<String, Value>)>;
}
410
/// Trait for creating configuration targets
///
/// Marker trait with no methods: implemented by every framework-specific
/// config type so `from_universal` can return any of them behind a single
/// `Box<dyn ConfigTarget>`.
pub trait ConfigTarget {}
413
/// PyTorch optimizer configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyTorchOptimizerConfig {
    /// Optimizer name, e.g. "Adam".
    pub optimizer_type: String,
    pub learning_rate: f32,
    /// Hyperparameters keyed by PyTorch names (e.g. "betas", "eps", "weight_decay").
    pub parameters: HashMap<String, Value>,
}

impl ConfigTarget for PyTorchOptimizerConfig {}

impl ConfigSource for PyTorchOptimizerConfig {
    /// Infallible: clones the stored fields into the universal tuple form.
    fn extract_config(&self) -> Result<(String, f32, HashMap<String, Value>)> {
        Ok((
            self.optimizer_type.clone(),
            self.learning_rate,
            self.parameters.clone(),
        ))
    }
}
433
/// JAX optimizer configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JAXOptimizerConfig {
    /// Optimizer name, e.g. "Adam".
    pub optimizer_type: String,
    pub learning_rate: f32,
    /// Hyperparameters keyed by JAX/Optax names (e.g. "b1", "b2", "eps").
    pub parameters: HashMap<String, Value>,
}

impl ConfigTarget for JAXOptimizerConfig {}

impl ConfigSource for JAXOptimizerConfig {
    /// Infallible: clones the stored fields into the universal tuple form.
    fn extract_config(&self) -> Result<(String, f32, HashMap<String, Value>)> {
        Ok((
            self.optimizer_type.clone(),
            self.learning_rate,
            self.parameters.clone(),
        ))
    }
}
453
/// TrustformeRS optimizer configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrustformeRSOptimizerConfig {
    /// Optimizer name, e.g. "Adam".
    pub optimizer_type: String,
    pub learning_rate: f32,
    /// Hyperparameters as JSON values; naming follows whichever framework the
    /// config was converted from.
    pub parameters: HashMap<String, Value>,
}

impl ConfigTarget for TrustformeRSOptimizerConfig {}

impl ConfigSource for TrustformeRSOptimizerConfig {
    /// Infallible: clones the stored fields into the universal tuple form.
    fn extract_config(&self) -> Result<(String, f32, HashMap<String, Value>)> {
        Ok((
            self.optimizer_type.clone(),
            self.learning_rate,
            self.parameters.clone(),
        ))
    }
}
473
// Implement ConfigSource for existing types
impl ConfigSource for crate::tensorflow_compat::TensorFlowOptimizerConfig {
    fn extract_config(&self) -> Result<(String, f32, HashMap<String, Value>)> {
        Ok((
            self.optimizer_type.clone(),
            // NOTE(review): narrows the TF f64 learning rate to f32; precision
            // loss is possible for rates with >~7 significant digits.
            self.learning_rate as f32,
            self.parameters.clone(),
        ))
    }
}

impl ConfigTarget for crate::tensorflow_compat::TensorFlowOptimizerConfig {}

impl ConfigSource for crate::onnx_export::OptimizerConfig {
    /// Infallible: clones the stored fields into the universal tuple form.
    fn extract_config(&self) -> Result<(String, f32, HashMap<String, Value>)> {
        Ok((
            self.optimizer_type.clone(),
            self.learning_rate,
            self.parameters.clone(),
        ))
    }
}

impl ConfigTarget for crate::onnx_export::OptimizerConfig {}
498
499/// Conversion utilities
500pub mod utils {
501    use super::*;
502
503    /// Create a conversion matrix showing all supported conversions
504    pub fn create_conversion_matrix() -> HashMap<(Framework, Framework), bool> {
505        let frameworks = [
506            Framework::PyTorch,
507            Framework::TensorFlow,
508            Framework::JAX,
509            Framework::ONNX,
510            Framework::TrustformeRS,
511        ];
512        let mut matrix = HashMap::new();
513
514        for &source in &frameworks {
515            for &target in &frameworks {
516                // All conversions are supported
517                matrix.insert((source, target), true);
518            }
519        }
520
521        matrix
522    }
523
524    /// Get list of supported frameworks
525    pub fn get_supported_frameworks() -> Vec<Framework> {
526        vec![
527            Framework::PyTorch,
528            Framework::TensorFlow,
529            Framework::JAX,
530            Framework::ONNX,
531            Framework::TrustformeRS,
532        ]
533    }
534
535    /// Validate parameter values during conversion
536    pub fn validate_parameters(
537        optimizer_type: &str,
538        parameters: &HashMap<String, Value>,
539    ) -> Result<()> {
540        match optimizer_type {
541            "Adam" | "AdamW" => {
542                // Validate beta values
543                if let Some(Value::Array(betas)) = parameters.get("betas") {
544                    if betas.len() != 2 {
545                        return Err(anyhow!("Adam betas must be a 2-element array"));
546                    }
547                }
548
549                // Validate learning rate is positive
550                if let Some(Value::Number(lr)) = parameters.get("lr") {
551                    if lr.as_f64().unwrap_or(0.0) <= 0.0 {
552                        return Err(anyhow!("Learning rate must be positive"));
553                    }
554                }
555            },
556            "SGD" => {
557                // Validate momentum
558                if let Some(Value::Number(momentum)) = parameters.get("momentum") {
559                    let momentum_val = momentum.as_f64().unwrap_or(0.0);
560                    if !(0.0..1.0).contains(&momentum_val) {
561                        return Err(anyhow!("Momentum must be in [0, 1)"));
562                    }
563                }
564            },
565            _ => {
566                // Generic validation for unknown optimizers
567                for (key, value) in parameters {
568                    if key.contains("learning_rate") || key.contains("lr") {
569                        if let Value::Number(lr) = value {
570                            if lr.as_f64().unwrap_or(0.0) <= 0.0 {
571                                return Err(anyhow!("Learning rate must be positive"));
572                            }
573                        }
574                    }
575                }
576            },
577        }
578
579        Ok(())
580    }
581}
582
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pytorch_to_tensorflow_conversion() {
        let converter = CrossFrameworkConverter::new();
        let tf = converter
            .pytorch_adam_to_tensorflow(0.001, (0.9, 0.999), 1e-8, 0.01)
            .expect("PyTorch -> TF conversion should succeed");

        assert_eq!(tf.optimizer_type, "Adam");
        assert!((tf.learning_rate - 0.001).abs() < 1e-9);
        assert!(tf.parameters.contains_key("beta_1"));
        assert!(tf.parameters.contains_key("beta_2"));
    }

    #[test]
    fn test_tensorflow_to_pytorch_conversion() {
        let converter = CrossFrameworkConverter::new();
        let torch = converter
            .tensorflow_adam_to_pytorch(0.001, 0.9, 0.999, 1e-8, 0.01)
            .expect("TF -> PyTorch conversion should succeed");

        assert_eq!(torch.optimizer_type, "Adam");
        assert_eq!(torch.learning_rate, 0.001);
        assert!(torch.parameters.contains_key("betas"));
        assert!(torch.parameters.contains_key("eps"));
    }

    #[test]
    fn test_jax_to_pytorch_conversion() {
        let converter = CrossFrameworkConverter::new();
        let torch = converter
            .jax_adam_to_pytorch(0.001, 0.9, 0.999, 1e-8)
            .expect("JAX -> PyTorch conversion should succeed");

        assert_eq!(torch.optimizer_type, "Adam");
        assert_eq!(torch.learning_rate, 0.001);
        assert!(torch.parameters.contains_key("betas"));
    }

    #[test]
    fn test_parameter_mapping() {
        let converter = CrossFrameworkConverter::new();

        let mut source_params = HashMap::new();
        source_params.insert(
            "lr".to_string(),
            Value::Number(serde_json::Number::from_f64(0.001).expect("Invalid constant")),
        );

        let renamed = converter
            .map_parameters(&source_params, Framework::PyTorch, Framework::TensorFlow)
            .expect("mapping table should exist");
        assert!(renamed.contains_key("learning_rate"));
    }

    #[test]
    fn test_universal_conversion() {
        let converter = CrossFrameworkConverter::new();

        let source = PyTorchOptimizerConfig {
            optimizer_type: "Adam".to_string(),
            learning_rate: 0.001,
            parameters: HashMap::new(),
        };

        // Round-trip: PyTorch config -> universal -> TensorFlow target.
        let universal = converter
            .to_universal(&source, Framework::PyTorch)
            .expect("to_universal should succeed");
        assert_eq!(universal.optimizer_type, "Adam");
        assert_eq!(universal.source_framework, Framework::PyTorch);

        let _tf = converter
            .from_universal(&universal, Framework::TensorFlow)
            .expect("from_universal should succeed");
    }

    #[test]
    fn test_conversion_matrix() {
        let matrix = utils::create_conversion_matrix();
        assert!(matrix.get(&(Framework::PyTorch, Framework::TensorFlow)).unwrap());
        assert!(matrix.get(&(Framework::JAX, Framework::ONNX)).unwrap());
    }

    #[test]
    fn test_parameter_validation() {
        let json_num = |v: f64| {
            Value::Number(serde_json::Number::from_f64(v).expect("Invalid constant"))
        };

        let mut params = HashMap::new();
        params.insert("lr".to_string(), json_num(0.001));
        utils::validate_parameters("Adam", &params).unwrap();

        // A non-positive learning rate must be rejected.
        params.insert("lr".to_string(), json_num(-0.001));
        assert!(utils::validate_parameters("Adam", &params).is_err());
    }

    #[test]
    fn test_conversion_report() {
        let converter = CrossFrameworkConverter::new();
        let report =
            converter.generate_conversion_report(Framework::PyTorch, Framework::TensorFlow);

        for needle in ["PyTorch", "TensorFlow", "->"] {
            assert!(report.contains(needle));
        }
    }

    #[test]
    fn test_supported_frameworks() {
        let supported = utils::get_supported_frameworks();
        for framework in [
            Framework::PyTorch,
            Framework::TensorFlow,
            Framework::JAX,
            Framework::ONNX,
            Framework::TrustformeRS,
        ] {
            assert!(supported.contains(&framework));
        }
    }
}