scirs2_io/ml_framework/converters/
onnx.rs1#![allow(dead_code)]
3
use std::fs::File;
use std::io::{BufReader, BufWriter, Write};
use std::path::Path;

use scirs2_core::ndarray::{ArrayD, IxDyn};

use crate::error::{IoError, Result};
use crate::ml_framework::converters::MLFrameworkConverter;
use crate::ml_framework::types::{MLFramework, MLModel, MLTensor};
10
/// Stateless converter that reads and writes models/tensors in an ONNX-flavoured
/// JSON layout (see the `MLFrameworkConverter` impl below).
///
/// It carries no configuration, so it is cheap to copy and can be constructed with
/// either `ONNXConverter` or `ONNXConverter::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ONNXConverter;
13
14impl MLFrameworkConverter for ONNXConverter {
15 fn save_model(&self, model: &MLModel, path: &Path) -> Result<()> {
16 let onnx_model = serde_json::json!({
18 "format": "onnx",
19 "version": "1.0",
20 "graph": {
21 "name": model.metadata.model_name,
22 "inputs": model.metadata.inputshapes,
23 "outputs": model.metadata.outputshapes,
24 "initializers": model.weights.iter().map(|(name, tensor)| {
25 serde_json::json!({
26 "name": name,
27 "shape": tensor.metadata.shape,
28 "dtype": tensor.metadata.dtype,
29 })
30 }).collect::<Vec<_>>(),
31 },
32 "metadata": model.metadata,
33 });
34
35 let file = File::create(path).map_err(IoError::Io)?;
36 serde_json::to_writer_pretty(file, &onnx_model)
37 .map_err(|e| IoError::SerializationError(e.to_string()))
38 }
39
40 fn load_model(&self, path: &Path) -> Result<MLModel> {
41 let file = File::open(path).map_err(IoError::Io)?;
42 let onnx_model: serde_json::Value = serde_json::from_reader(file)
43 .map_err(|e| IoError::SerializationError(e.to_string()))?;
44
45 let mut model = MLModel::new(MLFramework::ONNX);
46
47 if let Some(graph) = onnx_model.get("graph") {
49 if let Some(name) = graph.get("name").and_then(|v| v.as_str()) {
50 model.metadata.model_name = Some(name.to_string());
51 }
52
53 if let Some(inputs) = graph.get("inputs").and_then(|v| v.as_object()) {
55 for (name, shape_val) in inputs {
56 if let Some(shape) = shape_val.as_array() {
57 let shape_vec: Vec<usize> = shape
58 .iter()
59 .filter_map(|v| v.as_u64().map(|u| u as usize))
60 .collect();
61 model.metadata.inputshapes.insert(name.clone(), shape_vec);
62 }
63 }
64 }
65
66 if let Some(outputs) = graph.get("outputs").and_then(|v| v.as_object()) {
67 for (name, shape_val) in outputs {
68 if let Some(shape) = shape_val.as_array() {
69 let shape_vec: Vec<usize> = shape
70 .iter()
71 .filter_map(|v| v.as_u64().map(|u| u as usize))
72 .collect();
73 model.metadata.outputshapes.insert(name.clone(), shape_vec);
74 }
75 }
76 }
77
78 if let Some(initializers) = graph.get("initializers").and_then(|v| v.as_array()) {
80 for init in initializers {
81 if let Some(init_obj) = init.as_object() {
82 if let (Some(name), Some(shape), Some(_dtype)) = (
83 init_obj.get("name").and_then(|v| v.as_str()),
84 init_obj.get("shape").and_then(|v| v.as_array()),
85 init_obj.get("dtype"),
86 ) {
87 let shape_vec: Vec<usize> = shape
88 .iter()
89 .filter_map(|v| v.as_u64().map(|u| u as usize))
90 .collect();
91
92 let data = if let Some(data_array) =
94 init_obj.get("data").and_then(|v| v.as_array())
95 {
96 data_array
98 .iter()
99 .filter_map(|v| v.as_f64().map(|f| f as f32))
100 .collect::<Vec<f32>>()
101 } else {
102 let total_elements: usize = shape_vec.iter().product();
104 vec![0.0f32; total_elements]
105 };
106
107 if let Ok(array) = ArrayD::from_shape_vec(IxDyn(&shape_vec), data) {
108 model.weights.insert(
109 name.to_string(),
110 MLTensor::new(array, Some(name.to_string())),
111 );
112 }
113 }
114 }
115 }
116 }
117 }
118
119 Ok(model)
120 }
121
122 fn save_tensor(&self, tensor: &MLTensor, path: &Path) -> Result<()> {
123 let tensor_data = serde_json::json!({
124 "name": tensor.metadata.name,
125 "shape": tensor.metadata.shape,
126 "dtype": "float32",
127 "data": tensor.data.as_slice().unwrap().to_vec(),
128 });
129
130 let file = File::create(path).map_err(IoError::Io)?;
131 serde_json::to_writer_pretty(file, &tensor_data)
132 .map_err(|e| IoError::SerializationError(e.to_string()))
133 }
134
135 fn load_tensor(&self, path: &Path) -> Result<MLTensor> {
136 let file = File::open(path).map_err(IoError::Io)?;
137 let tensor_data: serde_json::Value = serde_json::from_reader(file)
138 .map_err(|e| IoError::SerializationError(e.to_string()))?;
139
140 let shape: Vec<usize> = serde_json::from_value(tensor_data["shape"].clone())
141 .map_err(|e| IoError::SerializationError(e.to_string()))?;
142
143 let data: Vec<f32> = serde_json::from_value(tensor_data["data"].clone())
144 .map_err(|e| IoError::SerializationError(e.to_string()))?;
145
146 let name = tensor_data
147 .get("name")
148 .and_then(|v| v.as_str())
149 .map(|s| s.to_string());
150
151 let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
152 .map_err(|e| IoError::Other(e.to_string()))?;
153
154 Ok(MLTensor::new(array, name))
155 }
156}