1use crate::error::{NeuralError, Result};
16use oxicode::{config as oxicode_config, serde as oxicode_serde};
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19
20use scirs2_core::ndarray::{Array1, Array2, Array4};
22
/// Element type of an ONNX tensor.
///
/// The explicit discriminants mirror the ONNX `TensorProto.DataType` codes
/// (`FLOAT = 1`, `INT32 = 6`, `INT64 = 7`, `DOUBLE = 11`), so `variant as i32`
/// yields the code an ONNX consumer expects.
///
/// NOTE(review): serde derives encode this enum by variant name (JSON) or, in
/// compact binary formats such as the one used by `OnnxModel::to_bytes`,
/// typically by variant *index* rather than by the discriminant — do not
/// reorder variants; append new ones at the end.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[repr(i32)]
#[derive(Default)]
pub enum OnnxDataType {
    /// 32-bit IEEE float (`FLOAT`). The default, and the type used by the tensor
    /// constructors and exporters in this module.
    #[default]
    Float32 = 1,
    /// 32-bit signed integer (`INT32`).
    Int32 = 6,
    /// 64-bit signed integer (`INT64`).
    Int64 = 7,
    /// 64-bit IEEE float (`DOUBLE`).
    Float64 = 11,
}
43
/// Value of an ONNX node attribute.
///
/// Each variant corresponds to one ONNX attribute kind (`FLOAT`, `INT`,
/// `STRING`, `FLOATS`, `INTS`). Variant order is significant for binary
/// serialization, so new variants should be appended rather than inserted.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum OnnxAttribute {
    /// A single float, e.g. `alpha` on `LeakyRelu` / `Gemm` or `epsilon` on
    /// `BatchNormalization`.
    Float(f32),
    /// A single integer, e.g. `transB` on `Gemm` or `axis` on `Softmax`.
    Int(i64),
    /// A UTF-8 string.
    String(String),
    /// A list of floats.
    Floats(Vec<f32>),
    /// A list of integers, e.g. `strides` / `pads` on `Conv`.
    Ints(Vec<i64>),
}
61
/// A single operator invocation in an ONNX graph.
///
/// Nodes are wired together purely by tensor name: a node consumes the tensors
/// listed in `inputs` (graph inputs, initializers or other nodes' outputs) and
/// produces the tensors listed in `outputs`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnnxNode {
    /// ONNX operator name, e.g. `"Gemm"`, `"Conv"`, `"Relu"`.
    pub op_type: String,
    /// The node's own name; the exporters in this module use
    /// `"<prefix>/<OpType>"`-style names.
    pub name: String,
    /// Names of the input tensors, in the operator's positional order.
    pub inputs: Vec<String>,
    /// Names of the tensors this node produces.
    pub outputs: Vec<String>,
    /// Operator attributes keyed by ONNX attribute name. `HashMap` iteration
    /// order is unspecified, so serialized attribute order may vary between runs.
    pub attributes: HashMap<String, OnnxAttribute>,
}
76
77impl OnnxNode {
78 pub fn new(
80 op_type: impl Into<String>,
81 name: impl Into<String>,
82 inputs: Vec<String>,
83 outputs: Vec<String>,
84 ) -> Self {
85 Self {
86 op_type: op_type.into(),
87 name: name.into(),
88 inputs,
89 outputs,
90 attributes: HashMap::new(),
91 }
92 }
93
94 pub fn with_attr(mut self, key: impl Into<String>, value: OnnxAttribute) -> Self {
96 self.attributes.insert(key.into(), value);
97 self
98 }
99}
100
/// A named constant tensor stored in the graph (an ONNX initializer), such as
/// a layer's weight or bias.
///
/// Elements are stored flattened in whichever of `float_data` / `int64_data`
/// matches `data_type`. The exporters in this module flatten ndarray arrays with
/// `iter()`, i.e. in logical row-major order of `dims`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnnxTensor {
    /// Tensor name; nodes reference it through their `inputs`.
    pub name: String,
    /// Element type of the stored data.
    pub data_type: OnnxDataType,
    /// Shape of the tensor; an empty vector denotes a scalar.
    pub dims: Vec<i64>,
    /// Flattened element data for float tensors.
    pub float_data: Vec<f32>,
    /// Flattened element data for 64-bit integer tensors; left empty by the
    /// float constructors.
    pub int64_data: Vec<i64>,
}
117
118impl OnnxTensor {
119 pub fn from_f64_slice(name: impl Into<String>, dims: Vec<i64>, data: &[f64]) -> Self {
123 Self {
124 name: name.into(),
125 data_type: OnnxDataType::Float32,
126 dims,
127 float_data: data.iter().map(|&v| v as f32).collect(),
128 int64_data: Vec::new(),
129 }
130 }
131
132 pub fn from_f32_slice(name: impl Into<String>, dims: Vec<i64>, data: &[f32]) -> Self {
134 Self {
135 name: name.into(),
136 data_type: OnnxDataType::Float32,
137 dims,
138 float_data: data.to_vec(),
139 int64_data: Vec::new(),
140 }
141 }
142
143 pub fn numel(&self) -> usize {
145 self.dims
146 .iter()
147 .map(|&d| d as usize)
148 .product::<usize>()
149 .max(1)
150 }
151}
152
/// Name, element type and (possibly partially unknown) shape of a graph input
/// or output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnnxValueInfo {
    /// Tensor name. Expected to match a node input (for graph inputs) or a node
    /// output (for graph outputs).
    pub name: String,
    /// Element type of the tensor.
    pub data_type: OnnxDataType,
    /// One entry per axis: `Some(n)` for a fixed extent, `None` for a dynamic
    /// one (e.g. the batch dimension).
    pub shape: Vec<Option<i64>>,
}
164
165impl OnnxValueInfo {
166 pub fn new(name: impl Into<String>, data_type: OnnxDataType, shape: Vec<Option<i64>>) -> Self {
168 Self {
169 name: name.into(),
170 data_type,
171 shape,
172 }
173 }
174}
175
/// The computation graph of a model: operator nodes plus the tensors flowing
/// into and out of it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnnxGraph {
    /// Operator nodes. ONNX consumers expect these in topological order; the
    /// exporters here append them in execution order.
    pub nodes: Vec<OnnxNode>,
    /// Graph inputs fed by the caller at run time.
    pub inputs: Vec<OnnxValueInfo>,
    /// Graph outputs returned to the caller.
    pub outputs: Vec<OnnxValueInfo>,
    /// Constant tensors (weights, biases, statistics) referenced by name from node inputs.
    pub initializers: Vec<OnnxTensor>,
}
188
189impl OnnxGraph {
190 pub fn new() -> Self {
192 Self {
193 nodes: Vec::new(),
194 inputs: Vec::new(),
195 outputs: Vec::new(),
196 initializers: Vec::new(),
197 }
198 }
199}
200
201impl Default for OnnxGraph {
202 fn default() -> Self {
203 Self::new()
204 }
205}
206
/// A complete model: the graph plus the version metadata an ONNX consumer checks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnnxModel {
    /// The computation graph.
    pub graph: OnnxGraph,
    /// Version of the default-domain ONNX operator set the nodes are written against.
    pub opset_version: i64,
    /// ONNX IR (model format) version.
    pub ir_version: i64,
    /// Name of the tool that produced the model.
    pub producer_name: String,
    /// Version number of the model itself.
    pub model_version: i64,
}
221
222impl Default for OnnxModel {
223 fn default() -> Self {
224 Self {
225 graph: OnnxGraph::new(),
226 opset_version: 17,
227 ir_version: 8,
228 producer_name: "scirs2-neural".to_string(),
229 model_version: 1,
230 }
231 }
232}
233
234impl OnnxModel {
235 pub fn new(graph: OnnxGraph) -> Self {
237 Self {
238 graph,
239 ..Default::default()
240 }
241 }
242
243 pub fn to_bytes(&self) -> Result<Vec<u8>> {
253 let cfg = oxicode_config::standard();
254 oxicode_serde::encode_to_vec(self, cfg)
255 .map_err(|e| NeuralError::SerializationError(format!("oxicode encode error: {e}")))
256 }
257
258 pub fn from_bytes(data: &[u8]) -> Result<Self> {
260 let cfg = oxicode_config::standard();
261 oxicode_serde::decode_owned_from_slice(data, cfg)
262 .map(|(model, _)| model)
263 .map_err(|e| NeuralError::DeserializationError(format!("oxicode decode error: {e}")))
264 }
265
266 pub fn to_json(&self) -> Result<String> {
268 serde_json::to_string_pretty(self)
269 .map_err(|e| NeuralError::SerializationError(format!("JSON encode error: {e}")))
270 }
271
272 pub fn from_json(json: &str) -> Result<Self> {
274 serde_json::from_str(json)
275 .map_err(|e| NeuralError::DeserializationError(format!("JSON decode error: {e}")))
276 }
277}
278
279pub trait OnnxExportable {
285 fn to_onnx_nodes(&self, input_name: &str, output_name: &str, prefix: &str) -> Vec<OnnxNode>;
291
292 fn to_onnx_initializers(&self, prefix: &str) -> Vec<OnnxTensor>;
294}
295
296pub fn export_linear(
308 weights: &Array2<f64>,
309 bias: Option<&Array1<f64>>,
310 input_name: &str,
311 output_name: &str,
312 prefix: &str,
313) -> (Vec<OnnxNode>, Vec<OnnxTensor>) {
314 let w_name = format!("{prefix}.weight");
315 let b_name = format!("{prefix}.bias");
316
317 let out_features = weights.nrows() as i64;
318 let in_features = weights.ncols() as i64;
319
320 let w_flat: Vec<f64> = weights.iter().copied().collect();
322 let w_tensor = OnnxTensor::from_f64_slice(&w_name, vec![out_features, in_features], &w_flat);
323
324 let mut node_inputs = vec![input_name.to_string(), w_name.clone()];
325 let mut initializers = vec![w_tensor];
326
327 if let Some(b) = bias {
329 let b_flat: Vec<f64> = b.iter().copied().collect();
330 let b_tensor = OnnxTensor::from_f64_slice(&b_name, vec![out_features], &b_flat);
331 initializers.push(b_tensor);
332 node_inputs.push(b_name.clone());
333 }
334
335 let node = OnnxNode::new(
336 "Gemm",
337 format!("{prefix}/Gemm"),
338 node_inputs,
339 vec![output_name.to_string()],
340 )
341 .with_attr("transB", OnnxAttribute::Int(1))
342 .with_attr("alpha", OnnxAttribute::Float(1.0))
343 .with_attr("beta", OnnxAttribute::Float(1.0));
344
345 (vec![node], initializers)
346}
347
348pub fn export_conv2d(
355 weights: &Array4<f64>,
356 bias: Option<&Array1<f64>>,
357 stride: &[usize],
358 padding: &[usize],
359 input_name: &str,
360 output_name: &str,
361 prefix: &str,
362) -> (Vec<OnnxNode>, Vec<OnnxTensor>) {
363 let w_name = format!("{prefix}.weight");
364 let b_name = format!("{prefix}.bias");
365
366 let shape = weights.shape();
367 let dims: Vec<i64> = shape.iter().map(|&d| d as i64).collect();
368 let w_flat: Vec<f64> = weights.iter().copied().collect();
369 let w_tensor = OnnxTensor::from_f64_slice(&w_name, dims, &w_flat);
370
371 let strides_attr: Vec<i64> = stride.iter().map(|&s| s as i64).collect();
372 let pads_onnx: Vec<i64> = padding
373 .iter()
374 .chain(padding.iter())
375 .map(|&p| p as i64)
376 .collect(); let mut node_inputs = vec![input_name.to_string(), w_name.clone()];
379 let mut initializers = vec![w_tensor];
380
381 if let Some(b) = bias {
382 let out_channels = shape[0] as i64;
383 let b_flat: Vec<f64> = b.iter().copied().collect();
384 let b_tensor = OnnxTensor::from_f64_slice(&b_name, vec![out_channels], &b_flat);
385 initializers.push(b_tensor);
386 node_inputs.push(b_name.clone());
387 }
388
389 let node = OnnxNode::new(
390 "Conv",
391 format!("{prefix}/Conv"),
392 node_inputs,
393 vec![output_name.to_string()],
394 )
395 .with_attr("strides", OnnxAttribute::Ints(strides_attr))
396 .with_attr("pads", OnnxAttribute::Ints(pads_onnx));
397
398 (vec![node], initializers)
399}
400
401pub fn export_activation(kind: &str, input_name: &str, output_name: &str) -> OnnxNode {
408 let (op_type, extra): (&str, Option<(&str, OnnxAttribute)>) = match kind.to_lowercase().as_str()
409 {
410 "relu" => ("Relu", None),
411 "sigmoid" => ("Sigmoid", None),
412 "tanh" => ("Tanh", None),
413 "gelu" => ("Gelu", None),
414 "leaky_relu" => ("LeakyRelu", Some(("alpha", OnnxAttribute::Float(0.01)))),
415 "elu" => ("Elu", Some(("alpha", OnnxAttribute::Float(1.0)))),
416 "selu" => ("Selu", None),
417 "softmax" => ("Softmax", Some(("axis", OnnxAttribute::Int(-1)))),
418 "log_softmax" => ("LogSoftmax", Some(("axis", OnnxAttribute::Int(-1)))),
419 unknown => {
420 let mut node = OnnxNode::new(
421 "Relu",
422 format!("{unknown}/fallback_Relu"),
423 vec![input_name.to_string()],
424 vec![output_name.to_string()],
425 );
426 node.attributes.insert(
427 "_scirs2_unsupported_activation".to_string(),
428 OnnxAttribute::String(unknown.to_string()),
429 );
430 return node;
431 }
432 };
433
434 let mut node = OnnxNode::new(
435 op_type,
436 format!("{input_name}/{op_type}"),
437 vec![input_name.to_string()],
438 vec![output_name.to_string()],
439 );
440
441 if let Some((key, val)) = extra {
442 node.attributes.insert(key.to_string(), val);
443 }
444
445 node
446}
447
448pub fn export_batchnorm(
457 scale: &[f64],
458 bias: &[f64],
459 mean: &[f64],
460 var: &[f64],
461 epsilon: Option<f32>,
462 input_name: &str,
463 output_name: &str,
464 prefix: &str,
465) -> (Vec<OnnxNode>, Vec<OnnxTensor>) {
466 let num_features = scale.len() as i64;
467 let eps = epsilon.unwrap_or(1e-5_f32);
468
469 let scale_name = format!("{prefix}.scale");
470 let bias_name = format!("{prefix}.bias");
471 let mean_name = format!("{prefix}.mean");
472 let var_name = format!("{prefix}.var");
473
474 let initializers = vec![
475 OnnxTensor::from_f64_slice(&scale_name, vec![num_features], scale),
476 OnnxTensor::from_f64_slice(&bias_name, vec![num_features], bias),
477 OnnxTensor::from_f64_slice(&mean_name, vec![num_features], mean),
478 OnnxTensor::from_f64_slice(&var_name, vec![num_features], var),
479 ];
480
481 let node = OnnxNode::new(
482 "BatchNormalization",
483 format!("{prefix}/BatchNormalization"),
484 vec![
485 input_name.to_string(),
486 scale_name,
487 bias_name,
488 mean_name,
489 var_name,
490 ],
491 vec![output_name.to_string()],
492 )
493 .with_attr("epsilon", OnnxAttribute::Float(eps));
494
495 (vec![node], initializers)
496}
497
498pub fn export_sequential(
533 layers: &[(String, Vec<OnnxNode>, Vec<OnnxTensor>)],
534 input_shape: &[Option<i64>],
535) -> OnnxModel {
536 let mut graph = OnnxGraph::new();
537
538 graph.inputs.push(OnnxValueInfo::new(
540 "input_0",
541 OnnxDataType::Float32,
542 input_shape.to_vec(),
543 ));
544
545 let mut last_output = "input_0".to_string();
547 for (layer_name, nodes, inits) in layers {
548 graph.initializers.extend(inits.iter().cloned());
549 for node in nodes {
550 graph.nodes.push(node.clone());
551 }
552 if let Some(last_node) = nodes.last() {
554 if let Some(out) = last_node.outputs.first() {
555 last_output = out.clone();
556 } else {
557 last_output = format!("{layer_name}_out");
558 }
559 }
560 }
561
562 graph.outputs.push(OnnxValueInfo::new(
564 last_output,
565 OnnxDataType::Float32,
566 vec![None],
567 ));
568
569 OnnxModel::new(graph)
570}
571
#[cfg(test)]
// Unit tests for the ONNX export helpers: node construction for each layer kind,
// initializer shapes, model defaults, and the bytes / JSON round-trips.
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2, Array4};

    // ---- Activation export -------------------------------------------------

    /// A supported activation maps to its ONNX op and wires input/output by name.
    #[test]
    fn test_onnx_activation_node_relu() {
        let node = export_activation("relu", "x", "y");
        assert_eq!(node.op_type, "Relu");
        assert_eq!(node.inputs, vec!["x".to_string()]);
        assert_eq!(node.outputs, vec!["y".to_string()]);
    }

    #[test]
    fn test_onnx_activation_node_sigmoid() {
        let node = export_activation("sigmoid", "x", "y");
        assert_eq!(node.op_type, "Sigmoid");
    }

    #[test]
    fn test_onnx_activation_node_tanh() {
        let node = export_activation("tanh", "x", "y");
        assert_eq!(node.op_type, "Tanh");
    }

    /// Softmax must carry an explicit `axis` attribute.
    #[test]
    fn test_onnx_activation_node_softmax() {
        let node = export_activation("softmax", "x", "y");
        assert_eq!(node.op_type, "Softmax");
        assert!(node.attributes.contains_key("axis"));
    }

    /// Unknown activations fall back to a Relu node instead of failing.
    #[test]
    fn test_onnx_activation_node_unknown_fallback() {
        let node = export_activation("crelu_custom", "x", "y");
        assert_eq!(node.op_type, "Relu");
        // The placeholder is tagged so the substitution can be detected later.
        assert!(node
            .attributes
            .contains_key("_scirs2_unsupported_activation"));
    }

    // ---- Linear (Gemm) export ----------------------------------------------

    /// Without bias: one Gemm node and only the weight initializer.
    #[test]
    fn test_onnx_linear_node_no_bias() {
        let w = Array2::<f64>::zeros((4, 8));
        let (nodes, inits) = export_linear(&w, None, "x", "y", "fc");
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].op_type, "Gemm");
        assert_eq!(inits.len(), 1);
        // Weight dims follow the (out_features, in_features) array shape.
        assert_eq!(inits[0].dims, vec![4_i64, 8_i64]);
        assert_eq!(inits[0].float_data.len(), 32);
    }

    /// With bias: the bias becomes a second, 1-D initializer.
    #[test]
    fn test_onnx_linear_node_with_bias() {
        let w = Array2::<f64>::zeros((4, 8));
        let b = Array1::<f64>::zeros(4);
        let (nodes, inits) = export_linear(&w, Some(&b), "x", "y", "fc");
        assert_eq!(nodes.len(), 1);
        assert_eq!(inits.len(), 2);
        // The bias initializer comes after the weight.
        assert_eq!(inits[1].dims, vec![4_i64]);
        assert_eq!(inits[1].float_data.len(), 4);
    }

    /// Weights are stored as (out, in), so Gemm must transpose its B input.
    #[test]
    fn test_onnx_linear_trans_b_attribute() {
        let w = Array2::<f64>::zeros((3, 5));
        let (nodes, _) = export_linear(&w, None, "x", "y", "fc");
        let trans_b = nodes[0].attributes.get("transB").expect("transB attribute");
        assert_eq!(trans_b, &OnnxAttribute::Int(1));
    }

    // ---- Conv2d export -----------------------------------------------------

    /// The kernel initializer keeps the 4-D (out, in, kH, kW) shape.
    #[test]
    fn test_onnx_conv2d_node() {
        let w = Array4::<f64>::zeros((16, 3, 3, 3));
        // stride 1, padding 1 on both spatial axes, no bias
        let (nodes, inits) = export_conv2d(&w, None, &[1, 1], &[1, 1], "x", "y", "conv1");
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].op_type, "Conv");
        assert_eq!(inits.len(), 1);
        assert_eq!(inits[0].dims, vec![16, 3, 3, 3]);
        assert_eq!(inits[0].float_data.len(), 16 * 3 * 3 * 3);
    }

    /// Bias adds a second initializer; strides are forwarded as an Ints attribute.
    #[test]
    fn test_onnx_conv2d_with_bias() {
        let w = Array4::<f64>::zeros((8, 1, 5, 5));
        let b = Array1::<f64>::zeros(8);
        let (nodes, inits) = export_conv2d(&w, Some(&b), &[2, 2], &[0, 0], "x", "y", "conv0");
        assert_eq!(nodes.len(), 1);
        assert_eq!(inits.len(), 2);
        let strides = nodes[0].attributes.get("strides").expect("strides");
        assert_eq!(strides, &OnnxAttribute::Ints(vec![2, 2]));
    }

    // ---- BatchNormalization export -----------------------------------------

    /// All four per-channel parameter tensors are exported with shape [C].
    #[test]
    fn test_onnx_batchnorm_export() {
        let scale = vec![1.0_f64; 32];
        let bias = vec![0.0_f64; 32];
        let mean = vec![0.0_f64; 32];
        let var = vec![1.0_f64; 32];
        let (nodes, inits) = export_batchnorm(&scale, &bias, &mean, &var, None, "x", "y", "bn1");
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].op_type, "BatchNormalization");
        assert_eq!(inits.len(), 4);
        // scale, bias, mean and var are each 1-D with one value per channel.
        for init in &inits {
            assert_eq!(init.dims, vec![32_i64]);
            assert_eq!(init.float_data.len(), 32);
        }
    }

    /// An explicit epsilon overrides the default.
    #[test]
    fn test_onnx_batchnorm_epsilon_attribute() {
        let v = vec![1.0_f64; 4];
        let (nodes, _) = export_batchnorm(&v, &v, &v, &v, Some(1e-3), "x", "y", "bn");
        let eps = nodes[0].attributes.get("epsilon").expect("epsilon attr");
        assert_eq!(eps, &OnnxAttribute::Float(1e-3_f32));
    }

    // ---- Model metadata and serialization ----------------------------------

    /// Default metadata: opset 17, IR version 8, producer "scirs2-neural".
    #[test]
    fn test_onnx_opset_default() {
        let model = OnnxModel::default();
        assert_eq!(model.opset_version, 17);
        assert_eq!(model.ir_version, 8);
        assert_eq!(model.producer_name, "scirs2-neural");
    }

    // Builds a one-layer model (Linear 8 -> 4 with bias) used by the round-trip tests.
    fn build_small_model() -> OnnxModel {
        let w = Array2::<f64>::zeros((4, 8));
        let b = Array1::<f64>::zeros(4);
        let (nodes, inits) = export_linear(&w, Some(&b), "input_0", "output_0", "fc0");
        let mut graph = OnnxGraph::new();
        graph.inputs.push(OnnxValueInfo::new(
            "input_0",
            OnnxDataType::Float32,
            vec![None, Some(8)],
        ));
        graph.outputs.push(OnnxValueInfo::new(
            "output_0",
            OnnxDataType::Float32,
            vec![None, Some(4)],
        ));
        graph.nodes.extend(nodes);
        graph.initializers.extend(inits);
        OnnxModel::new(graph)
    }

    /// Binary (oxicode) encoding preserves the graph structure.
    #[test]
    fn test_onnx_model_roundtrip_bytes() {
        let original = build_small_model();
        let bytes = original.to_bytes().expect("to_bytes failed");
        let restored = OnnxModel::from_bytes(&bytes).expect("from_bytes failed");
        assert_eq!(restored.opset_version, original.opset_version);
        assert_eq!(restored.graph.nodes.len(), original.graph.nodes.len());
        assert_eq!(
            restored.graph.initializers.len(),
            original.graph.initializers.len()
        );
        assert_eq!(
            restored.graph.initializers[0].float_data.len(),
            original.graph.initializers[0].float_data.len()
        );
    }

    /// JSON encoding preserves node op types and input names.
    #[test]
    fn test_onnx_json_roundtrip() {
        let original = build_small_model();
        let json = original.to_json().expect("to_json failed");
        assert!(json.contains("Gemm"));
        let restored = OnnxModel::from_json(&json).expect("from_json failed");
        assert_eq!(restored.graph.nodes[0].op_type, "Gemm");
        assert_eq!(restored.graph.inputs[0].name, "input_0");
    }

    #[test]
    fn test_onnx_json_contains_producer_name() {
        let model = OnnxModel::default();
        let json = model.to_json().expect("to_json");
        assert!(json.contains("scirs2-neural"));
    }

    // ---- Sequential assembly -----------------------------------------------

    /// Linear -> Relu -> Linear: nodes are kept in order and the activation adds
    /// no initializers.
    #[test]
    fn test_onnx_sequential_graph() {
        let w1 = Array2::<f64>::zeros((64, 784));
        let (n1, i1) = export_linear(&w1, None, "input_0", "fc0_out", "fc0");
        let act1 = export_activation("relu", "fc0_out", "act0_out");

        let w2 = Array2::<f64>::zeros((10, 64));
        let (n2, i2) = export_linear(&w2, None, "act0_out", "output_0", "fc1");

        let layers = vec![
            ("fc0".to_string(), n1, i1),
            ("act0".to_string(), vec![act1], vec![]),
            ("fc1".to_string(), n2, i2),
        ];

        let model = export_sequential(&layers, &[None, Some(784)]);
        assert_eq!(model.graph.nodes.len(), 3);
        // Node order follows layer order.
        assert_eq!(model.graph.nodes[0].op_type, "Gemm");
        assert_eq!(model.graph.nodes[1].op_type, "Relu");
        assert_eq!(model.graph.nodes[2].op_type, "Gemm");
        assert_eq!(model.graph.initializers.len(), 2);
        // Only the two weight matrices (no biases); the model keeps the default opset.
        assert_eq!(model.opset_version, 17);
    }

    #[test]
    fn test_onnx_sequential_single_layer() {
        let w = Array2::<f64>::zeros((2, 3));
        let (nodes, inits) = export_linear(&w, None, "input_0", "output_0", "fc");
        let layers = vec![("fc".to_string(), nodes, inits)];
        let model = export_sequential(&layers, &[None, Some(3)]);
        assert_eq!(model.graph.nodes.len(), 1);
        assert_eq!(model.graph.inputs[0].name, "input_0");
    }

    // ---- Small building blocks ---------------------------------------------

    #[test]
    fn test_onnx_tensor_numel() {
        let t = OnnxTensor::from_f64_slice("t", vec![2, 3, 4], &[0.0_f64; 24]);
        assert_eq!(t.numel(), 24);
    }

    #[test]
    fn test_onnx_node_builder_with_attr() {
        let node = OnnxNode::new("Relu", "r", vec!["x".to_string()], vec!["y".to_string()])
            .with_attr("alpha", OnnxAttribute::Float(0.1));
        assert!(node.attributes.contains_key("alpha"));
    }
}