1use anyhow::{anyhow, Result};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10
/// A single operation node in an ONNX-style computation graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ONNXNode {
    // Node name, unique within its graph (e.g. "adam_optimizer").
    pub name: String,
    // Operator type identifier (e.g. "Adam", "SGD", or a scheduler type).
    pub op_type: String,
    // Names of the tensors this node consumes.
    pub inputs: Vec<String>,
    // Names of the tensors this node produces.
    pub outputs: Vec<String>,
    // Operator attributes as free-form JSON values (numbers, bools, ...).
    pub attributes: HashMap<String, Value>,
}
20
/// An ONNX-style computation graph: nodes plus their I/O wiring and constants.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ONNXGraph {
    // Graph name (e.g. "adam_optimizer_graph").
    pub name: String,
    // Operation nodes; `utils::validate_model` checks their connectivity.
    pub nodes: Vec<ONNXNode>,
    // Names of the graph-level input tensors.
    pub inputs: Vec<String>,
    // Names of the graph-level output tensors.
    pub outputs: Vec<String>,
    // Constant tensors keyed by name, stored as flat f32 data.
    pub initializers: HashMap<String, Vec<f32>>,
}
30
/// Top-level ONNX-style model wrapper: versioning/provenance metadata plus
/// the computation graph. Serialized to JSON by `ONNXOptimizerExporter::save_model`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ONNXModel {
    // ONNX IR version number (this exporter emits 7).
    pub ir_version: i64,
    // Name of the producing framework (e.g. "TrustformeRS").
    pub producer_name: String,
    // Version string of the producing framework.
    pub producer_version: String,
    // Operator set domain (this exporter emits "ai.onnx").
    pub domain: String,
    // Model version counter (this exporter emits 1).
    pub model_version: i64,
    // The actual computation graph.
    pub graph: ONNXGraph,
}
41
/// Serializable optimizer configuration: type, learning rate, and any
/// additional hyperparameters as JSON values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizerConfig {
    // Optimizer identifier (e.g. "Adam", "SGD", "AdamW").
    pub optimizer_type: String,
    // Base learning rate.
    pub learning_rate: f32,
    // Remaining hyperparameters (betas, epsilon, momentum, nesterov, ...).
    pub parameters: HashMap<String, Value>,
}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct ONNXExportConfig {
53 pub model_name: String,
54 pub opset_version: i64,
55 pub export_params: bool,
56 pub export_raw_ir: bool,
57 pub keep_initializers_as_inputs: bool,
58 pub custom_opsets: HashMap<String, i64>,
59 pub verbose: bool,
60}
61
62impl Default for ONNXExportConfig {
63 fn default() -> Self {
64 Self {
65 model_name: "TrustformeRS_Optimizer".to_string(),
66 opset_version: 17,
67 export_params: true,
68 export_raw_ir: false,
69 keep_initializers_as_inputs: false,
70 custom_opsets: HashMap::new(),
71 verbose: false,
72 }
73 }
74}
75
76#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct ONNXOptimizerMetadata {
79 pub optimizer_type: String,
80 pub version: String,
81 pub hyperparameters: HashMap<String, Value>,
82 pub state_variables: Vec<String>,
83 pub export_timestamp: String,
84 pub framework_version: String,
85}
86
87impl Default for ONNXOptimizerMetadata {
88 fn default() -> Self {
89 Self {
90 optimizer_type: "Adam".to_string(),
91 version: "1.0".to_string(),
92 hyperparameters: HashMap::new(),
93 state_variables: Vec::new(),
94 export_timestamp: "2025-07-22T00:00:00Z".to_string(),
95 framework_version: "0.1.0".to_string(),
96 }
97 }
98}
99
/// Builds ONNX-style models for Adam/SGD/AdamW optimizers and serializes
/// them (and their configs) to JSON.
pub struct ONNXOptimizerExporter {
    // Producer name stamped into exported models.
    producer_name: String,
    // Producer version stamped into exported models.
    producer_version: String,
}
105
106impl ONNXOptimizerExporter {
107 pub fn new() -> Self {
109 Self {
110 producer_name: "TrustformeRS".to_string(),
111 producer_version: "1.0.0".to_string(),
112 }
113 }
114
115 pub fn export_adam(
117 &self,
118 learning_rate: f32,
119 beta1: f32,
120 beta2: f32,
121 epsilon: f32,
122 weight_decay: f32,
123 ) -> Result<ONNXModel> {
124 let mut nodes = Vec::new();
125 let mut initializers = HashMap::new();
126
127 initializers.insert("learning_rate".to_string(), vec![learning_rate]);
129 initializers.insert("beta1".to_string(), vec![beta1]);
130 initializers.insert("beta2".to_string(), vec![beta2]);
131 initializers.insert("epsilon".to_string(), vec![epsilon]);
132 initializers.insert("weight_decay".to_string(), vec![weight_decay]);
133
134 let mut adam_attrs = HashMap::new();
136 adam_attrs.insert(
137 "alpha".to_string(),
138 Value::Number(
139 serde_json::Number::from_f64(learning_rate as f64)
140 .expect("Invalid learning_rate: not a finite number"),
141 ),
142 );
143 adam_attrs.insert(
144 "beta".to_string(),
145 Value::Number(
146 serde_json::Number::from_f64(beta1 as f64)
147 .expect("Invalid beta1: not a finite number"),
148 ),
149 );
150 adam_attrs.insert(
151 "beta2".to_string(),
152 Value::Number(
153 serde_json::Number::from_f64(beta2 as f64)
154 .expect("Invalid beta2: not a finite number"),
155 ),
156 );
157 adam_attrs.insert(
158 "epsilon".to_string(),
159 Value::Number(
160 serde_json::Number::from_f64(epsilon as f64)
161 .expect("Invalid epsilon: not a finite number"),
162 ),
163 );
164 adam_attrs.insert(
165 "weight_decay".to_string(),
166 Value::Number(
167 serde_json::Number::from_f64(weight_decay as f64)
168 .expect("Invalid weight_decay: not a finite number"),
169 ),
170 );
171
172 let adam_node = ONNXNode {
173 name: "adam_optimizer".to_string(),
174 op_type: "Adam".to_string(),
175 inputs: vec![
176 "gradients".to_string(),
177 "learning_rate".to_string(),
178 "beta1".to_string(),
179 "beta2".to_string(),
180 "epsilon".to_string(),
181 "weight_decay".to_string(),
182 ],
183 outputs: vec!["updated_parameters".to_string()],
184 attributes: adam_attrs,
185 };
186
187 nodes.push(adam_node);
188
189 let graph = ONNXGraph {
190 name: "adam_optimizer_graph".to_string(),
191 nodes,
192 inputs: vec!["gradients".to_string()],
193 outputs: vec!["updated_parameters".to_string()],
194 initializers,
195 };
196
197 Ok(ONNXModel {
198 ir_version: 7,
199 producer_name: self.producer_name.clone(),
200 producer_version: self.producer_version.clone(),
201 domain: "ai.onnx".to_string(),
202 model_version: 1,
203 graph,
204 })
205 }
206
207 pub fn export_sgd(
209 &self,
210 learning_rate: f32,
211 momentum: f32,
212 weight_decay: f32,
213 nesterov: bool,
214 ) -> Result<ONNXModel> {
215 let mut nodes = Vec::new();
216 let mut initializers = HashMap::new();
217
218 initializers.insert("learning_rate".to_string(), vec![learning_rate]);
220 initializers.insert("momentum".to_string(), vec![momentum]);
221 initializers.insert("weight_decay".to_string(), vec![weight_decay]);
222
223 let mut sgd_attrs = HashMap::new();
225 sgd_attrs.insert(
226 "learning_rate".to_string(),
227 Value::Number(
228 serde_json::Number::from_f64(learning_rate as f64)
229 .expect("Invalid learning_rate: not a finite number"),
230 ),
231 );
232 sgd_attrs.insert(
233 "momentum".to_string(),
234 Value::Number(
235 serde_json::Number::from_f64(momentum as f64)
236 .expect("Invalid momentum: not a finite number"),
237 ),
238 );
239 sgd_attrs.insert(
240 "weight_decay".to_string(),
241 Value::Number(
242 serde_json::Number::from_f64(weight_decay as f64)
243 .expect("Invalid weight_decay: not a finite number"),
244 ),
245 );
246 sgd_attrs.insert("nesterov".to_string(), Value::Bool(nesterov));
247
248 let sgd_node = ONNXNode {
249 name: "sgd_optimizer".to_string(),
250 op_type: "SGD".to_string(),
251 inputs: vec![
252 "gradients".to_string(),
253 "learning_rate".to_string(),
254 "momentum".to_string(),
255 "weight_decay".to_string(),
256 ],
257 outputs: vec!["updated_parameters".to_string()],
258 attributes: sgd_attrs,
259 };
260
261 nodes.push(sgd_node);
262
263 let graph = ONNXGraph {
264 name: "sgd_optimizer_graph".to_string(),
265 nodes,
266 inputs: vec!["gradients".to_string()],
267 outputs: vec!["updated_parameters".to_string()],
268 initializers,
269 };
270
271 Ok(ONNXModel {
272 ir_version: 7,
273 producer_name: self.producer_name.clone(),
274 producer_version: self.producer_version.clone(),
275 domain: "ai.onnx".to_string(),
276 model_version: 1,
277 graph,
278 })
279 }
280
281 pub fn export_adamw(
283 &self,
284 learning_rate: f32,
285 beta1: f32,
286 beta2: f32,
287 epsilon: f32,
288 weight_decay: f32,
289 ) -> Result<ONNXModel> {
290 let mut nodes = Vec::new();
291 let mut initializers = HashMap::new();
292
293 initializers.insert("learning_rate".to_string(), vec![learning_rate]);
295 initializers.insert("beta1".to_string(), vec![beta1]);
296 initializers.insert("beta2".to_string(), vec![beta2]);
297 initializers.insert("epsilon".to_string(), vec![epsilon]);
298 initializers.insert("weight_decay".to_string(), vec![weight_decay]);
299
300 let mut adamw_attrs = HashMap::new();
302 adamw_attrs.insert(
303 "alpha".to_string(),
304 Value::Number(
305 serde_json::Number::from_f64(learning_rate as f64)
306 .expect("Invalid learning_rate: not a finite number"),
307 ),
308 );
309 adamw_attrs.insert(
310 "beta".to_string(),
311 Value::Number(
312 serde_json::Number::from_f64(beta1 as f64)
313 .expect("Invalid beta1: not a finite number"),
314 ),
315 );
316 adamw_attrs.insert(
317 "beta2".to_string(),
318 Value::Number(
319 serde_json::Number::from_f64(beta2 as f64)
320 .expect("Invalid beta2: not a finite number"),
321 ),
322 );
323 adamw_attrs.insert(
324 "epsilon".to_string(),
325 Value::Number(
326 serde_json::Number::from_f64(epsilon as f64)
327 .expect("Invalid epsilon: not a finite number"),
328 ),
329 );
330 adamw_attrs.insert(
331 "weight_decay".to_string(),
332 Value::Number(
333 serde_json::Number::from_f64(weight_decay as f64)
334 .expect("Invalid weight_decay: not a finite number"),
335 ),
336 );
337
338 let adamw_node = ONNXNode {
339 name: "adamw_optimizer".to_string(),
340 op_type: "AdamW".to_string(),
341 inputs: vec![
342 "gradients".to_string(),
343 "learning_rate".to_string(),
344 "beta1".to_string(),
345 "beta2".to_string(),
346 "epsilon".to_string(),
347 "weight_decay".to_string(),
348 ],
349 outputs: vec!["updated_parameters".to_string()],
350 attributes: adamw_attrs,
351 };
352
353 nodes.push(adamw_node);
354
355 let graph = ONNXGraph {
356 name: "adamw_optimizer_graph".to_string(),
357 nodes,
358 inputs: vec!["gradients".to_string()],
359 outputs: vec!["updated_parameters".to_string()],
360 initializers,
361 };
362
363 Ok(ONNXModel {
364 ir_version: 7,
365 producer_name: self.producer_name.clone(),
366 producer_version: self.producer_version.clone(),
367 domain: "ai.onnx".to_string(),
368 model_version: 1,
369 graph,
370 })
371 }
372
373 pub fn export_config(&self, config: &OptimizerConfig) -> Result<String> {
375 serde_json::to_string_pretty(config)
376 .map_err(|e| anyhow!("Failed to serialize optimizer config: {}", e))
377 }
378
379 pub fn save_model(&self, model: &ONNXModel, path: &str) -> Result<()> {
381 let json = serde_json::to_string_pretty(model)
382 .map_err(|e| anyhow!("Failed to serialize ONNX model: {}", e))?;
383
384 std::fs::write(path, json)
385 .map_err(|e| anyhow!("Failed to write ONNX model to file: {}", e))?;
386
387 Ok(())
388 }
389
390 pub fn create_adam_config(
392 &self,
393 learning_rate: f32,
394 beta1: f32,
395 beta2: f32,
396 epsilon: f32,
397 weight_decay: f32,
398 ) -> OptimizerConfig {
399 let mut parameters = HashMap::new();
400 parameters.insert(
401 "beta1".to_string(),
402 Value::Number(
403 serde_json::Number::from_f64(beta1 as f64)
404 .expect("Invalid beta1: not a finite number"),
405 ),
406 );
407 parameters.insert(
408 "beta2".to_string(),
409 Value::Number(
410 serde_json::Number::from_f64(beta2 as f64)
411 .expect("Invalid beta2: not a finite number"),
412 ),
413 );
414 parameters.insert(
415 "epsilon".to_string(),
416 Value::Number(
417 serde_json::Number::from_f64(epsilon as f64)
418 .expect("Invalid epsilon: not a finite number"),
419 ),
420 );
421 parameters.insert(
422 "weight_decay".to_string(),
423 Value::Number(
424 serde_json::Number::from_f64(weight_decay as f64)
425 .expect("Invalid weight_decay: not a finite number"),
426 ),
427 );
428
429 OptimizerConfig {
430 optimizer_type: "Adam".to_string(),
431 learning_rate,
432 parameters,
433 }
434 }
435
436 pub fn create_sgd_config(
437 &self,
438 learning_rate: f32,
439 momentum: f32,
440 weight_decay: f32,
441 nesterov: bool,
442 ) -> OptimizerConfig {
443 let mut parameters = HashMap::new();
444 parameters.insert(
445 "momentum".to_string(),
446 Value::Number(
447 serde_json::Number::from_f64(momentum as f64)
448 .expect("Invalid momentum: not a finite number"),
449 ),
450 );
451 parameters.insert(
452 "weight_decay".to_string(),
453 Value::Number(
454 serde_json::Number::from_f64(weight_decay as f64)
455 .expect("Invalid weight_decay: not a finite number"),
456 ),
457 );
458 parameters.insert("nesterov".to_string(), Value::Bool(nesterov));
459
460 OptimizerConfig {
461 optimizer_type: "SGD".to_string(),
462 learning_rate,
463 parameters,
464 }
465 }
466
467 pub fn create_adamw_config(
468 &self,
469 learning_rate: f32,
470 beta1: f32,
471 beta2: f32,
472 epsilon: f32,
473 weight_decay: f32,
474 ) -> OptimizerConfig {
475 let mut parameters = HashMap::new();
476 parameters.insert(
477 "beta1".to_string(),
478 Value::Number(
479 serde_json::Number::from_f64(beta1 as f64)
480 .expect("Invalid beta1: not a finite number"),
481 ),
482 );
483 parameters.insert(
484 "beta2".to_string(),
485 Value::Number(
486 serde_json::Number::from_f64(beta2 as f64)
487 .expect("Invalid beta2: not a finite number"),
488 ),
489 );
490 parameters.insert(
491 "epsilon".to_string(),
492 Value::Number(
493 serde_json::Number::from_f64(epsilon as f64)
494 .expect("Invalid epsilon: not a finite number"),
495 ),
496 );
497 parameters.insert(
498 "weight_decay".to_string(),
499 Value::Number(
500 serde_json::Number::from_f64(weight_decay as f64)
501 .expect("Invalid weight_decay: not a finite number"),
502 ),
503 );
504
505 OptimizerConfig {
506 optimizer_type: "AdamW".to_string(),
507 learning_rate,
508 parameters,
509 }
510 }
511}
512
513impl Default for ONNXOptimizerExporter {
514 fn default() -> Self {
515 Self::new()
516 }
517}
518
519pub mod utils {
521 use super::*;
522
523 pub fn validate_model(model: &ONNXModel) -> Result<()> {
525 if model.graph.nodes.is_empty() {
526 return Err(anyhow!("ONNX model must have at least one node"));
527 }
528
529 if model.graph.inputs.is_empty() {
530 return Err(anyhow!("ONNX model must have at least one input"));
531 }
532
533 if model.graph.outputs.is_empty() {
534 return Err(anyhow!("ONNX model must have at least one output"));
535 }
536
537 for node in &model.graph.nodes {
539 for input in &node.inputs {
540 if !model.graph.inputs.contains(input)
541 && !model.graph.initializers.contains_key(input)
542 {
543 let is_node_output =
545 model.graph.nodes.iter().any(|n| n.outputs.contains(input));
546
547 if !is_node_output {
548 return Err(anyhow!("Node input '{}' is not connected", input));
549 }
550 }
551 }
552 }
553
554 Ok(())
555 }
556
557 pub fn create_with_scheduler(
559 optimizer_model: ONNXModel,
560 schedule_type: &str,
561 schedule_params: HashMap<String, f32>,
562 ) -> Result<ONNXModel> {
563 let mut model = optimizer_model;
564
565 let mut scheduler_attrs = HashMap::new();
567 for (key, value) in schedule_params {
568 scheduler_attrs.insert(
569 key,
570 Value::Number(
571 serde_json::Number::from_f64(value as f64)
572 .expect("Invalid value: not a finite number"),
573 ),
574 );
575 }
576
577 let scheduler_node = ONNXNode {
578 name: "lr_scheduler".to_string(),
579 op_type: schedule_type.to_string(),
580 inputs: vec!["step".to_string()],
581 outputs: vec!["scheduled_learning_rate".to_string()],
582 attributes: scheduler_attrs,
583 };
584
585 model.graph.nodes.insert(0, scheduler_node);
586 model.graph.inputs.push("step".to_string());
587
588 if let Some(optimizer_node) = model
590 .graph
591 .nodes
592 .iter_mut()
593 .find(|n| n.op_type == "Adam" || n.op_type == "SGD" || n.op_type == "AdamW")
594 {
595 if let Some(lr_input_idx) =
596 optimizer_node.inputs.iter().position(|i| i == "learning_rate")
597 {
598 optimizer_node.inputs[lr_input_idx] = "scheduled_learning_rate".to_string();
599 }
600 }
601
602 Ok(model)
603 }
604}
605
#[cfg(test)]
mod tests {
    use super::*;

    // Export an Adam model and verify graph shape and connectivity.
    #[test]
    fn test_onnx_adam_export() {
        let model = ONNXOptimizerExporter::new()
            .export_adam(0.001, 0.9, 0.999, 1e-8, 0.01)
            .unwrap();

        assert_eq!(model.graph.name, "adam_optimizer_graph");
        assert_eq!(model.graph.nodes.len(), 1);
        assert_eq!(model.graph.nodes[0].op_type, "Adam");
        utils::validate_model(&model).unwrap();
    }

    // Export an SGD model and verify graph shape and connectivity.
    #[test]
    fn test_onnx_sgd_export() {
        let model = ONNXOptimizerExporter::new()
            .export_sgd(0.01, 0.9, 1e-4, true)
            .unwrap();

        assert_eq!(model.graph.name, "sgd_optimizer_graph");
        assert_eq!(model.graph.nodes.len(), 1);
        assert_eq!(model.graph.nodes[0].op_type, "SGD");
        utils::validate_model(&model).unwrap();
    }

    // Export an AdamW model and verify graph shape and connectivity.
    #[test]
    fn test_onnx_adamw_export() {
        let model = ONNXOptimizerExporter::new()
            .export_adamw(0.001, 0.9, 0.999, 1e-8, 0.01)
            .unwrap();

        assert_eq!(model.graph.name, "adamw_optimizer_graph");
        assert_eq!(model.graph.nodes.len(), 1);
        assert_eq!(model.graph.nodes[0].op_type, "AdamW");
        utils::validate_model(&model).unwrap();
    }

    // Config constructors record the optimizer type and learning rate.
    #[test]
    fn test_config_creation() {
        let exporter = ONNXOptimizerExporter::new();

        let adam = exporter.create_adam_config(0.001, 0.9, 0.999, 1e-8, 0.01);
        assert_eq!(adam.optimizer_type, "Adam");
        assert_eq!(adam.learning_rate, 0.001);

        let sgd = exporter.create_sgd_config(0.01, 0.9, 1e-4, true);
        assert_eq!(sgd.optimizer_type, "SGD");
        assert_eq!(sgd.learning_rate, 0.01);
    }

    // Serialized config JSON mentions the type and learning rate.
    #[test]
    fn test_config_serialization() {
        let exporter = ONNXOptimizerExporter::new();
        let config = exporter.create_adam_config(0.001, 0.9, 0.999, 1e-8, 0.01);
        let json = exporter.export_config(&config).unwrap();

        assert!(json.contains("Adam"));
        assert!(json.contains("0.001"));
    }

    // A valid export passes validation; emptying the node list fails it.
    #[test]
    fn test_model_validation() {
        let model = ONNXOptimizerExporter::new()
            .export_adam(0.001, 0.9, 0.999, 1e-8, 0.01)
            .unwrap();
        utils::validate_model(&model).unwrap();

        let mut broken = model.clone();
        broken.graph.nodes.clear();
        assert!(utils::validate_model(&broken).is_err());
    }

    // Attaching a scheduler prepends its node and keeps the model valid.
    #[test]
    fn test_scheduler_integration() {
        let base = ONNXOptimizerExporter::new()
            .export_adam(0.001, 0.9, 0.999, 1e-8, 0.01)
            .unwrap();

        let mut params = HashMap::new();
        params.insert("decay_rate".to_string(), 0.95);

        let scheduled =
            utils::create_with_scheduler(base, "ExponentialDecay", params).unwrap();

        assert_eq!(scheduled.graph.nodes.len(), 2);
        assert_eq!(scheduled.graph.nodes[0].op_type, "ExponentialDecay");
        utils::validate_model(&scheduled).unwrap();
    }
}