trustformers_optim/onnx_export.rs

//! ONNX Optimizer Export
//!
//! This module provides functionality to export optimizer configurations and states
//! to ONNX format, enabling deployment and optimization in ONNX Runtime environments.
//! Models are represented as plain serde-serializable structs and written to disk as
//! pretty-printed JSON.
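//!
//! A minimal end-to-end sketch (marked `ignore`, so it is not compiled as a
//! doctest; it assumes this module is reachable at `trustformers_optim::onnx_export`,
//! and the output path is illustrative):
//!
//! ```ignore
//! use trustformers_optim::onnx_export::{utils, ONNXOptimizerExporter};
//!
//! let exporter = ONNXOptimizerExporter::new();
//! // Typical Adam hyperparameters: lr, beta1, beta2, epsilon, weight decay.
//! let model = exporter.export_adam(0.001, 0.9, 0.999, 1e-8, 0.01)?;
//! utils::validate_model(&model)?;
//! exporter.save_model(&model, "adam_optimizer.onnx.json")?;
//! ```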

use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;

/// Convert an `f32` hyperparameter into a JSON number attribute value.
///
/// Panics when the value is not finite.
fn f32_value(name: &str, value: f32) -> Value {
    Value::Number(
        serde_json::Number::from_f64(value as f64)
            .unwrap_or_else(|| panic!("Invalid {}: not a finite number", name)),
    )
}

/// ONNX Graph Node representation for optimizer operations
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ONNXNode {
    pub name: String,
    pub op_type: String,
    pub inputs: Vec<String>,
    pub outputs: Vec<String>,
    pub attributes: HashMap<String, Value>,
}

/// ONNX Graph representation for optimizer
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ONNXGraph {
    pub name: String,
    pub nodes: Vec<ONNXNode>,
    pub inputs: Vec<String>,
    pub outputs: Vec<String>,
    pub initializers: HashMap<String, Vec<f32>>,
}

/// ONNX Model representation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ONNXModel {
    pub ir_version: i64,
    pub producer_name: String,
    pub producer_version: String,
    pub domain: String,
    pub model_version: i64,
    pub graph: ONNXGraph,
}

/// Optimizer configuration for ONNX export
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizerConfig {
    pub optimizer_type: String,
    pub learning_rate: f32,
    pub parameters: HashMap<String, Value>,
}

/// ONNX Export configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ONNXExportConfig {
    pub model_name: String,
    pub opset_version: i64,
    pub export_params: bool,
    pub export_raw_ir: bool,
    pub keep_initializers_as_inputs: bool,
    pub custom_opsets: HashMap<String, i64>,
    pub verbose: bool,
}

impl Default for ONNXExportConfig {
    fn default() -> Self {
        Self {
            model_name: "TrustformeRS_Optimizer".to_string(),
            opset_version: 17,
            export_params: true,
            export_raw_ir: false,
            keep_initializers_as_inputs: false,
            custom_opsets: HashMap::new(),
            verbose: false,
        }
    }
}

/// ONNX Optimizer metadata for export
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ONNXOptimizerMetadata {
    pub optimizer_type: String,
    pub version: String,
    pub hyperparameters: HashMap<String, Value>,
    pub state_variables: Vec<String>,
    pub export_timestamp: String,
    pub framework_version: String,
}

impl Default for ONNXOptimizerMetadata {
    fn default() -> Self {
        Self {
            optimizer_type: "Adam".to_string(),
            version: "1.0".to_string(),
            hyperparameters: HashMap::new(),
            state_variables: Vec::new(),
            // Static placeholder; callers should set the actual export time.
            export_timestamp: "2025-07-22T00:00:00Z".to_string(),
            framework_version: "0.1.0".to_string(),
        }
    }
}

/// ONNX Optimizer Exporter
pub struct ONNXOptimizerExporter {
    producer_name: String,
    producer_version: String,
}

impl ONNXOptimizerExporter {
    /// Create a new ONNX optimizer exporter
    pub fn new() -> Self {
        Self {
            producer_name: "TrustformeRS".to_string(),
            producer_version: "1.0.0".to_string(),
        }
    }

    /// Export Adam optimizer to ONNX format
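    ///
    /// A usage sketch (marked `ignore`; the hyperparameter values are the
    /// common Adam defaults and purely illustrative):
    ///
    /// ```ignore
    /// let exporter = ONNXOptimizerExporter::new();
    /// let model = exporter.export_adam(0.001, 0.9, 0.999, 1e-8, 0.01)?;
    /// assert_eq!(model.graph.nodes[0].op_type, "Adam");
    /// ```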
    pub fn export_adam(
        &self,
        learning_rate: f32,
        beta1: f32,
        beta2: f32,
        epsilon: f32,
        weight_decay: f32,
    ) -> Result<ONNXModel> {
        // Expose every hyperparameter as a named initializer.
        let mut initializers = HashMap::new();
        initializers.insert("learning_rate".to_string(), vec![learning_rate]);
        initializers.insert("beta1".to_string(), vec![beta1]);
        initializers.insert("beta2".to_string(), vec![beta2]);
        initializers.insert("epsilon".to_string(), vec![epsilon]);
        initializers.insert("weight_decay".to_string(), vec![weight_decay]);

        // Create the Adam optimizer node, with attributes named after the
        // corresponding node inputs.
        let mut adam_attrs = HashMap::new();
        adam_attrs.insert(
            "learning_rate".to_string(),
            f32_value("learning_rate", learning_rate),
        );
        adam_attrs.insert("beta1".to_string(), f32_value("beta1", beta1));
        adam_attrs.insert("beta2".to_string(), f32_value("beta2", beta2));
        adam_attrs.insert("epsilon".to_string(), f32_value("epsilon", epsilon));
        adam_attrs.insert(
            "weight_decay".to_string(),
            f32_value("weight_decay", weight_decay),
        );

        let adam_node = ONNXNode {
            name: "adam_optimizer".to_string(),
            op_type: "Adam".to_string(),
            inputs: vec![
                "gradients".to_string(),
                "learning_rate".to_string(),
                "beta1".to_string(),
                "beta2".to_string(),
                "epsilon".to_string(),
                "weight_decay".to_string(),
            ],
            outputs: vec!["updated_parameters".to_string()],
            attributes: adam_attrs,
        };

        let graph = ONNXGraph {
            name: "adam_optimizer_graph".to_string(),
            nodes: vec![adam_node],
            inputs: vec!["gradients".to_string()],
            outputs: vec!["updated_parameters".to_string()],
            initializers,
        };

        Ok(ONNXModel {
            ir_version: 7,
            producer_name: self.producer_name.clone(),
            producer_version: self.producer_version.clone(),
            domain: "ai.onnx".to_string(),
            model_version: 1,
            graph,
        })
    }

    /// Export SGD optimizer to ONNX format
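    ///
    /// A usage sketch (marked `ignore`; values are illustrative, with
    /// Nesterov momentum enabled):
    ///
    /// ```ignore
    /// let exporter = ONNXOptimizerExporter::new();
    /// let model = exporter.export_sgd(0.01, 0.9, 1e-4, true)?;
    /// assert_eq!(model.graph.nodes[0].op_type, "SGD");
    /// ```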
    pub fn export_sgd(
        &self,
        learning_rate: f32,
        momentum: f32,
        weight_decay: f32,
        nesterov: bool,
    ) -> Result<ONNXModel> {
        // Expose every hyperparameter as a named initializer.
        let mut initializers = HashMap::new();
        initializers.insert("learning_rate".to_string(), vec![learning_rate]);
        initializers.insert("momentum".to_string(), vec![momentum]);
        initializers.insert("weight_decay".to_string(), vec![weight_decay]);

        // Create the SGD optimizer node.
        let mut sgd_attrs = HashMap::new();
        sgd_attrs.insert(
            "learning_rate".to_string(),
            f32_value("learning_rate", learning_rate),
        );
        sgd_attrs.insert("momentum".to_string(), f32_value("momentum", momentum));
        sgd_attrs.insert(
            "weight_decay".to_string(),
            f32_value("weight_decay", weight_decay),
        );
        sgd_attrs.insert("nesterov".to_string(), Value::Bool(nesterov));

        let sgd_node = ONNXNode {
            name: "sgd_optimizer".to_string(),
            op_type: "SGD".to_string(),
            inputs: vec![
                "gradients".to_string(),
                "learning_rate".to_string(),
                "momentum".to_string(),
                "weight_decay".to_string(),
            ],
            outputs: vec!["updated_parameters".to_string()],
            attributes: sgd_attrs,
        };

        let graph = ONNXGraph {
            name: "sgd_optimizer_graph".to_string(),
            nodes: vec![sgd_node],
            inputs: vec!["gradients".to_string()],
            outputs: vec!["updated_parameters".to_string()],
            initializers,
        };

        Ok(ONNXModel {
            ir_version: 7,
            producer_name: self.producer_name.clone(),
            producer_version: self.producer_version.clone(),
            domain: "ai.onnx".to_string(),
            model_version: 1,
            graph,
        })
    }

    /// Export AdamW optimizer to ONNX format
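    ///
    /// A usage sketch (marked `ignore`; values are illustrative):
    ///
    /// ```ignore
    /// let exporter = ONNXOptimizerExporter::new();
    /// let model = exporter.export_adamw(0.001, 0.9, 0.999, 1e-8, 0.01)?;
    /// assert_eq!(model.graph.nodes[0].op_type, "AdamW");
    /// ```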
    pub fn export_adamw(
        &self,
        learning_rate: f32,
        beta1: f32,
        beta2: f32,
        epsilon: f32,
        weight_decay: f32,
    ) -> Result<ONNXModel> {
        // Expose every hyperparameter as a named initializer.
        let mut initializers = HashMap::new();
        initializers.insert("learning_rate".to_string(), vec![learning_rate]);
        initializers.insert("beta1".to_string(), vec![beta1]);
        initializers.insert("beta2".to_string(), vec![beta2]);
        initializers.insert("epsilon".to_string(), vec![epsilon]);
        initializers.insert("weight_decay".to_string(), vec![weight_decay]);

        // Create the AdamW optimizer node, with attributes named after the
        // corresponding node inputs.
        let mut adamw_attrs = HashMap::new();
        adamw_attrs.insert(
            "learning_rate".to_string(),
            f32_value("learning_rate", learning_rate),
        );
        adamw_attrs.insert("beta1".to_string(), f32_value("beta1", beta1));
        adamw_attrs.insert("beta2".to_string(), f32_value("beta2", beta2));
        adamw_attrs.insert("epsilon".to_string(), f32_value("epsilon", epsilon));
        adamw_attrs.insert(
            "weight_decay".to_string(),
            f32_value("weight_decay", weight_decay),
        );

        let adamw_node = ONNXNode {
            name: "adamw_optimizer".to_string(),
            op_type: "AdamW".to_string(),
            inputs: vec![
                "gradients".to_string(),
                "learning_rate".to_string(),
                "beta1".to_string(),
                "beta2".to_string(),
                "epsilon".to_string(),
                "weight_decay".to_string(),
            ],
            outputs: vec!["updated_parameters".to_string()],
            attributes: adamw_attrs,
        };

        let graph = ONNXGraph {
            name: "adamw_optimizer_graph".to_string(),
            nodes: vec![adamw_node],
            inputs: vec!["gradients".to_string()],
            outputs: vec!["updated_parameters".to_string()],
            initializers,
        };

        Ok(ONNXModel {
            ir_version: 7,
            producer_name: self.producer_name.clone(),
            producer_version: self.producer_version.clone(),
            domain: "ai.onnx".to_string(),
            model_version: 1,
            graph,
        })
    }

    /// Export optimizer configuration to JSON format for ONNX metadata
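    ///
    /// A usage sketch (marked `ignore`):
    ///
    /// ```ignore
    /// let exporter = ONNXOptimizerExporter::new();
    /// let config = exporter.create_adam_config(0.001, 0.9, 0.999, 1e-8, 0.01);
    /// let json = exporter.export_config(&config)?;
    /// assert!(json.contains("Adam"));
    /// ```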
    pub fn export_config(&self, config: &OptimizerConfig) -> Result<String> {
        serde_json::to_string_pretty(config)
            .map_err(|e| anyhow!("Failed to serialize optimizer config: {}", e))
    }

    /// Save an ONNX model to a file as pretty-printed JSON
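    ///
    /// Note that this writes the JSON representation of the model rather than
    /// an ONNX protobuf. A usage sketch (marked `ignore`; the path is
    /// illustrative):
    ///
    /// ```ignore
    /// let exporter = ONNXOptimizerExporter::new();
    /// let model = exporter.export_adam(0.001, 0.9, 0.999, 1e-8, 0.01)?;
    /// exporter.save_model(&model, "/tmp/adam_optimizer.onnx.json")?;
    /// ```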
    pub fn save_model(&self, model: &ONNXModel, path: &str) -> Result<()> {
        let json = serde_json::to_string_pretty(model)
            .map_err(|e| anyhow!("Failed to serialize ONNX model: {}", e))?;

        std::fs::write(path, json)
            .map_err(|e| anyhow!("Failed to write ONNX model to file: {}", e))?;

        Ok(())
    }

    /// Create an Adam optimizer configuration
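    ///
    /// A usage sketch (marked `ignore`):
    ///
    /// ```ignore
    /// let config = ONNXOptimizerExporter::new()
    ///     .create_adam_config(0.001, 0.9, 0.999, 1e-8, 0.01);
    /// assert_eq!(config.optimizer_type, "Adam");
    /// assert_eq!(config.learning_rate, 0.001);
    /// ```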
    pub fn create_adam_config(
        &self,
        learning_rate: f32,
        beta1: f32,
        beta2: f32,
        epsilon: f32,
        weight_decay: f32,
    ) -> OptimizerConfig {
        let mut parameters = HashMap::new();
        parameters.insert("beta1".to_string(), f32_value("beta1", beta1));
        parameters.insert("beta2".to_string(), f32_value("beta2", beta2));
        parameters.insert("epsilon".to_string(), f32_value("epsilon", epsilon));
        parameters.insert(
            "weight_decay".to_string(),
            f32_value("weight_decay", weight_decay),
        );

        OptimizerConfig {
            optimizer_type: "Adam".to_string(),
            learning_rate,
            parameters,
        }
    }

    /// Create an SGD optimizer configuration
    pub fn create_sgd_config(
        &self,
        learning_rate: f32,
        momentum: f32,
        weight_decay: f32,
        nesterov: bool,
    ) -> OptimizerConfig {
        let mut parameters = HashMap::new();
        parameters.insert("momentum".to_string(), f32_value("momentum", momentum));
        parameters.insert(
            "weight_decay".to_string(),
            f32_value("weight_decay", weight_decay),
        );
        parameters.insert("nesterov".to_string(), Value::Bool(nesterov));

        OptimizerConfig {
            optimizer_type: "SGD".to_string(),
            learning_rate,
            parameters,
        }
    }

    /// Create an AdamW optimizer configuration
    pub fn create_adamw_config(
        &self,
        learning_rate: f32,
        beta1: f32,
        beta2: f32,
        epsilon: f32,
        weight_decay: f32,
    ) -> OptimizerConfig {
        let mut parameters = HashMap::new();
        parameters.insert("beta1".to_string(), f32_value("beta1", beta1));
        parameters.insert("beta2".to_string(), f32_value("beta2", beta2));
        parameters.insert("epsilon".to_string(), f32_value("epsilon", epsilon));
        parameters.insert(
            "weight_decay".to_string(),
            f32_value("weight_decay", weight_decay),
        );

        OptimizerConfig {
            optimizer_type: "AdamW".to_string(),
            learning_rate,
            parameters,
        }
    }
}

impl Default for ONNXOptimizerExporter {
    fn default() -> Self {
        Self::new()
    }
}

/// Utility functions for ONNX export
pub mod utils {
    use super::*;

    /// Validate ONNX model structure
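    ///
    /// A usage sketch (marked `ignore`):
    ///
    /// ```ignore
    /// let model = ONNXOptimizerExporter::new()
    ///     .export_sgd(0.01, 0.9, 1e-4, false)?;
    /// assert!(validate_model(&model).is_ok());
    /// ```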
    pub fn validate_model(model: &ONNXModel) -> Result<()> {
        if model.graph.nodes.is_empty() {
            return Err(anyhow!("ONNX model must have at least one node"));
        }

        if model.graph.inputs.is_empty() {
            return Err(anyhow!("ONNX model must have at least one input"));
        }

        if model.graph.outputs.is_empty() {
            return Err(anyhow!("ONNX model must have at least one output"));
        }

        // Validate node connections: every node input must be a graph input,
        // an initializer, or the output of another node.
        for node in &model.graph.nodes {
            for input in &node.inputs {
                if !model.graph.inputs.contains(input)
                    && !model.graph.initializers.contains_key(input)
                {
                    let is_node_output =
                        model.graph.nodes.iter().any(|n| n.outputs.contains(input));

                    if !is_node_output {
                        return Err(anyhow!("Node input '{}' is not connected", input));
                    }
                }
            }
        }

        Ok(())
    }

    /// Create ONNX model with learning rate scheduler
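    ///
    /// A usage sketch (marked `ignore`; `"ExponentialDecay"` and `decay_rate`
    /// follow the scheduler test below, and any scheduler op type is accepted):
    ///
    /// ```ignore
    /// let base = ONNXOptimizerExporter::new()
    ///     .export_adam(0.001, 0.9, 0.999, 1e-8, 0.01)?;
    /// let mut params = HashMap::new();
    /// params.insert("decay_rate".to_string(), 0.95);
    /// let scheduled = create_with_scheduler(base, "ExponentialDecay", params)?;
    /// assert_eq!(scheduled.graph.nodes[0].op_type, "ExponentialDecay");
    /// ```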
    pub fn create_with_scheduler(
        optimizer_model: ONNXModel,
        schedule_type: &str,
        schedule_params: HashMap<String, f32>,
    ) -> Result<ONNXModel> {
        let mut model = optimizer_model;

        // Build the scheduler node from the supplied parameters.
        let mut scheduler_attrs = HashMap::new();
        for (key, value) in schedule_params {
            let attr = f32_value(&key, value);
            scheduler_attrs.insert(key, attr);
        }

        let scheduler_node = ONNXNode {
            name: "lr_scheduler".to_string(),
            op_type: schedule_type.to_string(),
            inputs: vec!["step".to_string()],
            outputs: vec!["scheduled_learning_rate".to_string()],
            attributes: scheduler_attrs,
        };

        model.graph.nodes.insert(0, scheduler_node);
        model.graph.inputs.push("step".to_string());

        // Update the optimizer node to use the scheduled learning rate.
        if let Some(optimizer_node) = model
            .graph
            .nodes
            .iter_mut()
            .find(|n| n.op_type == "Adam" || n.op_type == "SGD" || n.op_type == "AdamW")
        {
            if let Some(lr_input_idx) =
                optimizer_node.inputs.iter().position(|i| i == "learning_rate")
            {
                optimizer_node.inputs[lr_input_idx] = "scheduled_learning_rate".to_string();
            }
        }

        Ok(model)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_onnx_adam_export() {
        let exporter = ONNXOptimizerExporter::new();
        let model = exporter.export_adam(0.001, 0.9, 0.999, 1e-8, 0.01).unwrap();

        assert_eq!(model.graph.name, "adam_optimizer_graph");
        assert_eq!(model.graph.nodes.len(), 1);
        assert_eq!(model.graph.nodes[0].op_type, "Adam");

        utils::validate_model(&model).unwrap();
    }

    #[test]
    fn test_onnx_sgd_export() {
        let exporter = ONNXOptimizerExporter::new();
        let model = exporter.export_sgd(0.01, 0.9, 1e-4, true).unwrap();

        assert_eq!(model.graph.name, "sgd_optimizer_graph");
        assert_eq!(model.graph.nodes.len(), 1);
        assert_eq!(model.graph.nodes[0].op_type, "SGD");

        utils::validate_model(&model).unwrap();
    }

    #[test]
    fn test_onnx_adamw_export() {
        let exporter = ONNXOptimizerExporter::new();
        let model = exporter.export_adamw(0.001, 0.9, 0.999, 1e-8, 0.01).unwrap();

        assert_eq!(model.graph.name, "adamw_optimizer_graph");
        assert_eq!(model.graph.nodes.len(), 1);
        assert_eq!(model.graph.nodes[0].op_type, "AdamW");

        utils::validate_model(&model).unwrap();
    }

    #[test]
    fn test_config_creation() {
        let exporter = ONNXOptimizerExporter::new();

        let adam_config = exporter.create_adam_config(0.001, 0.9, 0.999, 1e-8, 0.01);
        assert_eq!(adam_config.optimizer_type, "Adam");
        assert_eq!(adam_config.learning_rate, 0.001);

        let sgd_config = exporter.create_sgd_config(0.01, 0.9, 1e-4, true);
        assert_eq!(sgd_config.optimizer_type, "SGD");
        assert_eq!(sgd_config.learning_rate, 0.01);
    }

    #[test]
    fn test_config_serialization() {
        let exporter = ONNXOptimizerExporter::new();
        let config = exporter.create_adam_config(0.001, 0.9, 0.999, 1e-8, 0.01);

        let json = exporter.export_config(&config).unwrap();
        assert!(json.contains("Adam"));
        assert!(json.contains("0.001"));
    }

    #[test]
    fn test_model_validation() {
        let exporter = ONNXOptimizerExporter::new();
        let model = exporter.export_adam(0.001, 0.9, 0.999, 1e-8, 0.01).unwrap();

        // Should pass validation
        utils::validate_model(&model).unwrap();

        // Test invalid model
        let mut invalid_model = model.clone();
        invalid_model.graph.nodes.clear();
        assert!(utils::validate_model(&invalid_model).is_err());
    }

    #[test]
    fn test_scheduler_integration() {
        let exporter = ONNXOptimizerExporter::new();
        let base_model = exporter.export_adam(0.001, 0.9, 0.999, 1e-8, 0.01).unwrap();

        let mut schedule_params = HashMap::new();
        schedule_params.insert("decay_rate".to_string(), 0.95);

        let model_with_scheduler =
            utils::create_with_scheduler(base_model, "ExponentialDecay", schedule_params).unwrap();

        assert_eq!(model_with_scheduler.graph.nodes.len(), 2);
        assert_eq!(
            model_with_scheduler.graph.nodes[0].op_type,
            "ExponentialDecay"
        );

        utils::validate_model(&model_with_scheduler).unwrap();
    }
}