1use crate::{
7 onnx_export::OptimizerConfig as ONNXConfig, tensorflow_compat::TensorFlowOptimizerConfig,
8};
9use anyhow::{anyhow, Result};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::collections::HashMap;
13
/// Identifies a machine-learning framework whose optimizer configuration
/// this module can convert to and from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Framework {
    /// PyTorch-style configs (`lr`, `betas`, `eps`, ...).
    PyTorch,
    /// TensorFlow/Keras-style configs (`learning_rate`, `beta_1`, `beta_2`, ...).
    TensorFlow,
    /// JAX/optax-style configs (`learning_rate`, `b1`, `b2`, ...).
    JAX,
    /// ONNX training-operator configs (`alpha`, `beta`, `epsilon`, ...).
    ONNX,
    /// This crate's native optimizer configuration.
    TrustformeRS,
}
23
/// Framework-agnostic snapshot of an optimizer configuration, used as the
/// pivot representation when converting between frameworks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UniversalOptimizerConfig {
    /// Optimizer family name, e.g. "Adam" or "SGD".
    pub optimizer_type: String,
    /// Base learning rate (stored as f32; TensorFlow's f64 rate is narrowed).
    pub learning_rate: f32,
    /// Remaining hyperparameters, still keyed by the source framework's
    /// parameter names until translated by `CrossFrameworkConverter`.
    pub parameters: HashMap<String, Value>,
    /// Framework the configuration was extracted from.
    pub source_framework: Framework,
}
32
/// Framework-agnostic optimizer *state* (as opposed to configuration).
/// NOTE(review): not referenced anywhere in this file — presumably consumed
/// by checkpoint import/export elsewhere in the crate; verify before removal.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UniversalOptimizerState {
    /// Number of optimization steps taken so far.
    pub step: i64,
    /// Per-parameter state tensors/values, keyed by parameter name.
    pub state_dict: HashMap<String, Value>,
    /// Framework the state was captured from.
    pub framework: Framework,
}
40
/// Translates optimizer configurations between frameworks by renaming
/// hyperparameter keys according to per-(source, target) mapping tables.
pub struct CrossFrameworkConverter {
    /// Directed parameter-name mappings: `(source, target)` framework pair
    /// -> map from source parameter name to target parameter name.
    parameter_mappings: HashMap<(Framework, Framework), HashMap<String, String>>,
}
46
47impl CrossFrameworkConverter {
48 pub fn new() -> Self {
50 let mut converter = Self {
51 parameter_mappings: HashMap::new(),
52 };
53 converter.initialize_mappings();
54 converter
55 }
56
57 fn initialize_mappings(&mut self) {
59 let mut pytorch_to_tf = HashMap::new();
61 pytorch_to_tf.insert("lr".to_string(), "learning_rate".to_string());
62 pytorch_to_tf.insert("betas".to_string(), "beta_1_beta_2".to_string());
63 pytorch_to_tf.insert("eps".to_string(), "epsilon".to_string());
64 pytorch_to_tf.insert("weight_decay".to_string(), "weight_decay".to_string());
65 self.parameter_mappings
66 .insert((Framework::PyTorch, Framework::TensorFlow), pytorch_to_tf);
67
68 let mut tf_to_pytorch = HashMap::new();
70 tf_to_pytorch.insert("learning_rate".to_string(), "lr".to_string());
71 tf_to_pytorch.insert("beta_1".to_string(), "betas[0]".to_string());
72 tf_to_pytorch.insert("beta_2".to_string(), "betas[1]".to_string());
73 tf_to_pytorch.insert("epsilon".to_string(), "eps".to_string());
74 self.parameter_mappings
75 .insert((Framework::TensorFlow, Framework::PyTorch), tf_to_pytorch);
76
77 let mut jax_to_pytorch = HashMap::new();
79 jax_to_pytorch.insert("learning_rate".to_string(), "lr".to_string());
80 jax_to_pytorch.insert("b1".to_string(), "betas[0]".to_string());
81 jax_to_pytorch.insert("b2".to_string(), "betas[1]".to_string());
82 jax_to_pytorch.insert("eps".to_string(), "eps".to_string());
83 self.parameter_mappings
84 .insert((Framework::JAX, Framework::PyTorch), jax_to_pytorch);
85
86 let mut pytorch_to_jax = HashMap::new();
88 pytorch_to_jax.insert("lr".to_string(), "learning_rate".to_string());
89 pytorch_to_jax.insert("betas[0]".to_string(), "b1".to_string());
90 pytorch_to_jax.insert("betas[1]".to_string(), "b2".to_string());
91 pytorch_to_jax.insert("eps".to_string(), "eps".to_string());
92 self.parameter_mappings
93 .insert((Framework::PyTorch, Framework::JAX), pytorch_to_jax);
94
95 let mut onnx_to_pytorch = HashMap::new();
97 onnx_to_pytorch.insert("alpha".to_string(), "lr".to_string());
98 onnx_to_pytorch.insert("beta".to_string(), "betas[0]".to_string());
99 onnx_to_pytorch.insert("beta2".to_string(), "betas[1]".to_string());
100 onnx_to_pytorch.insert("epsilon".to_string(), "eps".to_string());
101 self.parameter_mappings
102 .insert((Framework::ONNX, Framework::PyTorch), onnx_to_pytorch);
103
104 let mut pytorch_to_onnx = HashMap::new();
105 pytorch_to_onnx.insert("lr".to_string(), "alpha".to_string());
106 pytorch_to_onnx.insert("betas[0]".to_string(), "beta".to_string());
107 pytorch_to_onnx.insert("betas[1]".to_string(), "beta2".to_string());
108 pytorch_to_onnx.insert("eps".to_string(), "epsilon".to_string());
109 self.parameter_mappings
110 .insert((Framework::PyTorch, Framework::ONNX), pytorch_to_onnx);
111 }
112
113 pub fn to_universal(
115 &self,
116 config: &dyn ConfigSource,
117 source_framework: Framework,
118 ) -> Result<UniversalOptimizerConfig> {
119 let (optimizer_type, learning_rate, parameters) = config.extract_config()?;
120
121 Ok(UniversalOptimizerConfig {
122 optimizer_type,
123 learning_rate,
124 parameters,
125 source_framework,
126 })
127 }
128
129 pub fn from_universal(
131 &self,
132 config: &UniversalOptimizerConfig,
133 target_framework: Framework,
134 ) -> Result<Box<dyn ConfigTarget>> {
135 let mapped_params = self.map_parameters(
136 &config.parameters,
137 config.source_framework,
138 target_framework,
139 )?;
140
141 match target_framework {
142 Framework::PyTorch => {
143 let pytorch_config = PyTorchOptimizerConfig {
144 optimizer_type: config.optimizer_type.clone(),
145 learning_rate: config.learning_rate,
146 parameters: mapped_params,
147 };
148 Ok(Box::new(pytorch_config))
149 },
150 Framework::TensorFlow => {
151 let tf_config = TensorFlowOptimizerConfig {
152 optimizer_type: config.optimizer_type.clone(),
153 learning_rate: config.learning_rate as f64,
154 parameters: mapped_params,
155 ..Default::default()
156 };
157 Ok(Box::new(tf_config))
158 },
159 Framework::JAX => {
160 let jax_config = JAXOptimizerConfig {
161 optimizer_type: config.optimizer_type.clone(),
162 learning_rate: config.learning_rate,
163 parameters: mapped_params,
164 };
165 Ok(Box::new(jax_config))
166 },
167 Framework::ONNX => {
168 let onnx_config = ONNXConfig {
169 optimizer_type: config.optimizer_type.clone(),
170 learning_rate: config.learning_rate,
171 parameters: mapped_params,
172 };
173 Ok(Box::new(onnx_config))
174 },
175 Framework::TrustformeRS => {
176 let trustformers_config = TrustformeRSOptimizerConfig {
177 optimizer_type: config.optimizer_type.clone(),
178 learning_rate: config.learning_rate,
179 parameters: mapped_params,
180 };
181 Ok(Box::new(trustformers_config))
182 },
183 }
184 }
185
186 pub fn convert(
188 &self,
189 config: &dyn ConfigSource,
190 source_framework: Framework,
191 target_framework: Framework,
192 ) -> Result<Box<dyn ConfigTarget>> {
193 let universal = self.to_universal(config, source_framework)?;
194 self.from_universal(&universal, target_framework)
195 }
196
197 fn map_parameters(
199 &self,
200 parameters: &HashMap<String, Value>,
201 source: Framework,
202 target: Framework,
203 ) -> Result<HashMap<String, Value>> {
204 if source == target {
205 return Ok(parameters.clone());
206 }
207
208 let mapping = self.parameter_mappings.get(&(source, target)).ok_or_else(|| {
209 anyhow!(
210 "No parameter mapping found for {:?} to {:?}",
211 source,
212 target
213 )
214 })?;
215
216 let mut mapped_params = HashMap::new();
217
218 for (key, value) in parameters {
219 let mapped_key = mapping.get(key).unwrap_or(key);
220 mapped_params.insert(mapped_key.clone(), value.clone());
221 }
222
223 Ok(mapped_params)
224 }
225
226 pub fn pytorch_adam_to_tensorflow(
228 &self,
229 lr: f32,
230 betas: (f32, f32),
231 eps: f32,
232 weight_decay: f32,
233 ) -> Result<TensorFlowOptimizerConfig> {
234 let mut parameters = HashMap::new();
235 parameters.insert(
236 "beta_1".to_string(),
237 Value::Number(
238 serde_json::Number::from_f64(betas.0 as f64)
239 .ok_or_else(|| anyhow!("Invalid beta_1"))?,
240 ),
241 );
242 parameters.insert(
243 "beta_2".to_string(),
244 Value::Number(
245 serde_json::Number::from_f64(betas.1 as f64)
246 .ok_or_else(|| anyhow!("Invalid beta_2"))?,
247 ),
248 );
249 parameters.insert(
250 "epsilon".to_string(),
251 Value::Number(
252 serde_json::Number::from_f64(eps as f64)
253 .ok_or_else(|| anyhow!("Invalid epsilon"))?,
254 ),
255 );
256 parameters.insert(
257 "weight_decay".to_string(),
258 Value::Number(
259 serde_json::Number::from_f64(weight_decay as f64)
260 .ok_or_else(|| anyhow!("Invalid weight_decay"))?,
261 ),
262 );
263
264 Ok(TensorFlowOptimizerConfig {
265 optimizer_type: "Adam".to_string(),
266 learning_rate: lr as f64,
267 parameters,
268 ..Default::default()
269 })
270 }
271
272 pub fn tensorflow_adam_to_pytorch(
274 &self,
275 lr: f32,
276 beta_1: f32,
277 beta_2: f32,
278 epsilon: f32,
279 weight_decay: f32,
280 ) -> Result<PyTorchOptimizerConfig> {
281 let mut parameters = HashMap::new();
282 parameters.insert(
283 "betas".to_string(),
284 Value::Array(vec![
285 Value::Number(
286 serde_json::Number::from_f64(beta_1 as f64)
287 .ok_or_else(|| anyhow!("Invalid beta_1"))?,
288 ),
289 Value::Number(
290 serde_json::Number::from_f64(beta_2 as f64)
291 .ok_or_else(|| anyhow!("Invalid beta_2"))?,
292 ),
293 ]),
294 );
295 parameters.insert(
296 "eps".to_string(),
297 Value::Number(
298 serde_json::Number::from_f64(epsilon as f64)
299 .ok_or_else(|| anyhow!("Invalid epsilon"))?,
300 ),
301 );
302 parameters.insert(
303 "weight_decay".to_string(),
304 Value::Number(
305 serde_json::Number::from_f64(weight_decay as f64)
306 .ok_or_else(|| anyhow!("Invalid weight_decay"))?,
307 ),
308 );
309
310 Ok(PyTorchOptimizerConfig {
311 optimizer_type: "Adam".to_string(),
312 learning_rate: lr,
313 parameters,
314 })
315 }
316
317 pub fn jax_adam_to_pytorch(
319 &self,
320 learning_rate: f32,
321 b1: f32,
322 b2: f32,
323 eps: f32,
324 ) -> Result<PyTorchOptimizerConfig> {
325 let mut parameters = HashMap::new();
326 parameters.insert(
327 "betas".to_string(),
328 Value::Array(vec![
329 Value::Number(
330 serde_json::Number::from_f64(b1 as f64).ok_or_else(|| anyhow!("Invalid b1"))?,
331 ),
332 Value::Number(
333 serde_json::Number::from_f64(b2 as f64).ok_or_else(|| anyhow!("Invalid b2"))?,
334 ),
335 ]),
336 );
337 parameters.insert(
338 "eps".to_string(),
339 Value::Number(
340 serde_json::Number::from_f64(eps as f64)
341 .ok_or_else(|| anyhow!("Invalid epsilon"))?,
342 ),
343 );
344
345 Ok(PyTorchOptimizerConfig {
346 optimizer_type: "Adam".to_string(),
347 learning_rate,
348 parameters,
349 })
350 }
351
352 pub fn onnx_to_pytorch(&self, onnx_config: &ONNXConfig) -> Result<PyTorchOptimizerConfig> {
354 let mapped_params =
355 self.map_parameters(&onnx_config.parameters, Framework::ONNX, Framework::PyTorch)?;
356
357 Ok(PyTorchOptimizerConfig {
358 optimizer_type: onnx_config.optimizer_type.clone(),
359 learning_rate: onnx_config.learning_rate,
360 parameters: mapped_params,
361 })
362 }
363
364 pub fn batch_convert(
366 &self,
367 configs: Vec<(&dyn ConfigSource, Framework)>,
368 target_framework: Framework,
369 ) -> Result<Vec<Box<dyn ConfigTarget>>> {
370 let mut results = Vec::new();
371
372 for (config, source_framework) in configs {
373 let converted = self.convert(config, source_framework, target_framework)?;
374 results.push(converted);
375 }
376
377 Ok(results)
378 }
379
380 pub fn generate_conversion_report(&self, source: Framework, target: Framework) -> String {
382 let mapping = self.parameter_mappings.get(&(source, target));
383
384 match mapping {
385 Some(map) => {
386 let mut report = format!("Conversion mapping from {:?} to {:?}:\n", source, target);
387 for (source_param, target_param) in map {
388 report.push_str(&format!(" {} -> {}\n", source_param, target_param));
389 }
390 report
391 },
392 None => format!(
393 "No conversion mapping available from {:?} to {:?}",
394 source, target
395 ),
396 }
397 }
398}
399
400impl Default for CrossFrameworkConverter {
401 fn default() -> Self {
402 Self::new()
403 }
404}
405
/// Source side of a conversion: any config that can report its optimizer
/// type, learning rate and remaining hyperparameters.
pub trait ConfigSource {
    /// Returns `(optimizer_type, learning_rate, parameters)`.
    fn extract_config(&self) -> Result<(String, f32, HashMap<String, Value>)>;
}
410
/// Marker for config types that `CrossFrameworkConverter::from_universal`
/// can produce.
pub trait ConfigTarget {}
413
/// PyTorch-style optimizer configuration (`lr`, `betas`, `eps`, ...).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyTorchOptimizerConfig {
    /// Optimizer family name, e.g. "Adam".
    pub optimizer_type: String,
    /// Learning rate (`lr` in PyTorch terms).
    pub learning_rate: f32,
    /// Remaining hyperparameters keyed by PyTorch parameter names.
    pub parameters: HashMap<String, Value>,
}
421
// Marker: PyTorch configs can be produced as conversion targets.
impl ConfigTarget for PyTorchOptimizerConfig {}
423
424impl ConfigSource for PyTorchOptimizerConfig {
425 fn extract_config(&self) -> Result<(String, f32, HashMap<String, Value>)> {
426 Ok((
427 self.optimizer_type.clone(),
428 self.learning_rate,
429 self.parameters.clone(),
430 ))
431 }
432}
433
/// JAX/optax-style optimizer configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JAXOptimizerConfig {
    /// Optimizer family name, e.g. "Adam".
    pub optimizer_type: String,
    /// Learning rate.
    pub learning_rate: f32,
    /// Remaining hyperparameters keyed by JAX parameter names (`b1`, `b2`, ...).
    pub parameters: HashMap<String, Value>,
}
441
// Marker: JAX configs can be produced as conversion targets.
impl ConfigTarget for JAXOptimizerConfig {}
443
444impl ConfigSource for JAXOptimizerConfig {
445 fn extract_config(&self) -> Result<(String, f32, HashMap<String, Value>)> {
446 Ok((
447 self.optimizer_type.clone(),
448 self.learning_rate,
449 self.parameters.clone(),
450 ))
451 }
452}
453
/// Native TrustformeRS optimizer configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrustformeRSOptimizerConfig {
    /// Optimizer family name, e.g. "Adam".
    pub optimizer_type: String,
    /// Learning rate.
    pub learning_rate: f32,
    /// Remaining hyperparameters.
    pub parameters: HashMap<String, Value>,
}
461
// Marker: native TrustformeRS configs can be produced as conversion targets.
impl ConfigTarget for TrustformeRSOptimizerConfig {}
463
464impl ConfigSource for TrustformeRSOptimizerConfig {
465 fn extract_config(&self) -> Result<(String, f32, HashMap<String, Value>)> {
466 Ok((
467 self.optimizer_type.clone(),
468 self.learning_rate,
469 self.parameters.clone(),
470 ))
471 }
472}
473
474impl ConfigSource for crate::tensorflow_compat::TensorFlowOptimizerConfig {
476 fn extract_config(&self) -> Result<(String, f32, HashMap<String, Value>)> {
477 Ok((
478 self.optimizer_type.clone(),
479 self.learning_rate as f32,
480 self.parameters.clone(),
481 ))
482 }
483}
484
// Marker: TensorFlow configs can be produced as conversion targets.
impl ConfigTarget for crate::tensorflow_compat::TensorFlowOptimizerConfig {}
486
487impl ConfigSource for crate::onnx_export::OptimizerConfig {
488 fn extract_config(&self) -> Result<(String, f32, HashMap<String, Value>)> {
489 Ok((
490 self.optimizer_type.clone(),
491 self.learning_rate,
492 self.parameters.clone(),
493 ))
494 }
495}
496
// Marker: ONNX configs can be produced as conversion targets.
impl ConfigTarget for crate::onnx_export::OptimizerConfig {}
498
499pub mod utils {
501 use super::*;
502
503 pub fn create_conversion_matrix() -> HashMap<(Framework, Framework), bool> {
505 let frameworks = [
506 Framework::PyTorch,
507 Framework::TensorFlow,
508 Framework::JAX,
509 Framework::ONNX,
510 Framework::TrustformeRS,
511 ];
512 let mut matrix = HashMap::new();
513
514 for &source in &frameworks {
515 for &target in &frameworks {
516 matrix.insert((source, target), true);
518 }
519 }
520
521 matrix
522 }
523
524 pub fn get_supported_frameworks() -> Vec<Framework> {
526 vec![
527 Framework::PyTorch,
528 Framework::TensorFlow,
529 Framework::JAX,
530 Framework::ONNX,
531 Framework::TrustformeRS,
532 ]
533 }
534
535 pub fn validate_parameters(
537 optimizer_type: &str,
538 parameters: &HashMap<String, Value>,
539 ) -> Result<()> {
540 match optimizer_type {
541 "Adam" | "AdamW" => {
542 if let Some(Value::Array(betas)) = parameters.get("betas") {
544 if betas.len() != 2 {
545 return Err(anyhow!("Adam betas must be a 2-element array"));
546 }
547 }
548
549 if let Some(Value::Number(lr)) = parameters.get("lr") {
551 if lr.as_f64().unwrap_or(0.0) <= 0.0 {
552 return Err(anyhow!("Learning rate must be positive"));
553 }
554 }
555 },
556 "SGD" => {
557 if let Some(Value::Number(momentum)) = parameters.get("momentum") {
559 let momentum_val = momentum.as_f64().unwrap_or(0.0);
560 if !(0.0..1.0).contains(&momentum_val) {
561 return Err(anyhow!("Momentum must be in [0, 1)"));
562 }
563 }
564 },
565 _ => {
566 for (key, value) in parameters {
568 if key.contains("learning_rate") || key.contains("lr") {
569 if let Value::Number(lr) = value {
570 if lr.as_f64().unwrap_or(0.0) <= 0.0 {
571 return Err(anyhow!("Learning rate must be positive"));
572 }
573 }
574 }
575 }
576 },
577 }
578
579 Ok(())
580 }
581}
582
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a JSON number from a finite literal; panics otherwise.
    fn json_num(x: f64) -> Value {
        Value::Number(serde_json::Number::from_f64(x).expect("Invalid constant"))
    }

    #[test]
    fn test_pytorch_to_tensorflow_conversion() {
        let cv = CrossFrameworkConverter::new();
        let tf = cv
            .pytorch_adam_to_tensorflow(0.001, (0.9, 0.999), 1e-8, 0.01)
            .unwrap();

        assert_eq!(tf.optimizer_type, "Adam");
        assert!((tf.learning_rate - 0.001).abs() < 1e-9);
        assert!(tf.parameters.contains_key("beta_1"));
        assert!(tf.parameters.contains_key("beta_2"));
    }

    #[test]
    fn test_tensorflow_to_pytorch_conversion() {
        let cv = CrossFrameworkConverter::new();
        let torch = cv
            .tensorflow_adam_to_pytorch(0.001, 0.9, 0.999, 1e-8, 0.01)
            .unwrap();

        assert_eq!(torch.optimizer_type, "Adam");
        assert_eq!(torch.learning_rate, 0.001);
        assert!(torch.parameters.contains_key("betas"));
        assert!(torch.parameters.contains_key("eps"));
    }

    #[test]
    fn test_jax_to_pytorch_conversion() {
        let cv = CrossFrameworkConverter::new();
        let torch = cv.jax_adam_to_pytorch(0.001, 0.9, 0.999, 1e-8).unwrap();

        assert_eq!(torch.optimizer_type, "Adam");
        assert_eq!(torch.learning_rate, 0.001);
        assert!(torch.parameters.contains_key("betas"));
    }

    #[test]
    fn test_parameter_mapping() {
        let cv = CrossFrameworkConverter::new();
        let params = HashMap::from([("lr".to_string(), json_num(0.001))]);

        let mapped = cv
            .map_parameters(&params, Framework::PyTorch, Framework::TensorFlow)
            .unwrap();
        assert!(mapped.contains_key("learning_rate"));
    }

    #[test]
    fn test_universal_conversion() {
        let cv = CrossFrameworkConverter::new();
        let src = PyTorchOptimizerConfig {
            optimizer_type: "Adam".to_string(),
            learning_rate: 0.001,
            parameters: HashMap::new(),
        };

        let universal = cv.to_universal(&src, Framework::PyTorch).unwrap();
        assert_eq!(universal.optimizer_type, "Adam");
        assert_eq!(universal.source_framework, Framework::PyTorch);

        // Round-tripping into TensorFlow must succeed.
        let _tf = cv.from_universal(&universal, Framework::TensorFlow).unwrap();
    }

    #[test]
    fn test_conversion_matrix() {
        let matrix = utils::create_conversion_matrix();
        assert!(matrix.get(&(Framework::PyTorch, Framework::TensorFlow)).unwrap());
        assert!(matrix.get(&(Framework::JAX, Framework::ONNX)).unwrap());
    }

    #[test]
    fn test_parameter_validation() {
        let mut params = HashMap::new();
        params.insert("lr".to_string(), json_num(0.001));
        utils::validate_parameters("Adam", &params).unwrap();

        // A non-positive learning rate must be rejected.
        params.insert("lr".to_string(), json_num(-0.001));
        assert!(utils::validate_parameters("Adam", &params).is_err());
    }

    #[test]
    fn test_conversion_report() {
        let cv = CrossFrameworkConverter::new();
        let report = cv.generate_conversion_report(Framework::PyTorch, Framework::TensorFlow);

        assert!(report.contains("PyTorch"));
        assert!(report.contains("TensorFlow"));
        assert!(report.contains("->"));
    }

    #[test]
    fn test_supported_frameworks() {
        let supported = utils::get_supported_frameworks();
        for fw in [
            Framework::PyTorch,
            Framework::TensorFlow,
            Framework::JAX,
            Framework::ONNX,
            Framework::TrustformeRS,
        ] {
            assert!(supported.contains(&fw));
        }
    }
}