// scirs2_io/ml_framework/converters/coreml.rs
#![allow(dead_code)]

use crate::error::{IoError, Result};
use crate::ml_framework::converters::MLFrameworkConverter;
use crate::ml_framework::types::{MLFramework, MLModel, MLTensor};
use scirs2_core::ndarray::{ArrayD, IxDyn};
use std::fs::File;
use std::io::{BufReader, BufWriter, Write};
use std::path::Path;

/// Converter for the CoreML-flavoured JSON representation of models and tensors.
///
/// The type carries no state; all behaviour lives in its
/// [`MLFrameworkConverter`] implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct CoreMLConverter;

14impl MLFrameworkConverter for CoreMLConverter {
15 fn save_model(&self, model: &MLModel, path: &Path) -> Result<()> {
16 let coreml_model = serde_json::json!({
18 "format": "coreml",
19 "specificationVersion": 5,
20 "description": {
21 "metadata": {
22 "userDefined": model.metadata.parameters,
23 "author": "SciRS2",
24 "license": "MIT",
25 "shortDescription": model.metadata.model_name.clone().unwrap_or_default()
26 },
27 "input": model.metadata.inputshapes.iter().map(|(name, shape)| {
28 serde_json::json!({
29 "name": name,
30 "type": {
31 "multiArrayType": {
32 "shape": shape,
33 "dataType": "FLOAT32"
34 }
35 }
36 })
37 }).collect::<Vec<_>>(),
38 "output": model.metadata.outputshapes.iter().map(|(name, shape)| {
39 serde_json::json!({
40 "name": name,
41 "type": {
42 "multiArrayType": {
43 "shape": shape,
44 "dataType": "FLOAT32"
45 }
46 }
47 })
48 }).collect::<Vec<_>>()
49 },
50 "neuralNetwork": {
51 "layers": [],
52 "preprocessing": []
53 },
54 "weights": model.weights.iter().map(|(name, tensor)| {
55 (name.clone(), serde_json::json!({
56 "shape": tensor.metadata.shape,
57 "floatValue": tensor.data.as_slice().unwrap().to_vec()
58 }))
59 }).collect::<serde_json::Map<String, serde_json::Value>>()
60 });
61
62 let file = File::create(path).map_err(IoError::Io)?;
63 serde_json::to_writer_pretty(file, &coreml_model)
64 .map_err(|e| IoError::SerializationError(e.to_string()))
65 }
66
67 fn load_model(&self, path: &Path) -> Result<MLModel> {
68 let file = File::open(path).map_err(IoError::Io)?;
69 let coreml_model: serde_json::Value = serde_json::from_reader(file)
70 .map_err(|e| IoError::SerializationError(e.to_string()))?;
71
72 let mut model = MLModel::new(MLFramework::CoreML);
73
74 if let Some(description) = coreml_model.get("description") {
76 if let Some(metadata) = description.get("metadata") {
77 if let Some(short_desc) = metadata.get("shortDescription").and_then(|v| v.as_str())
78 {
79 model.metadata.model_name = Some(short_desc.to_string());
80 }
81 }
82
83 if let Some(inputs) = description.get("input").and_then(|v| v.as_array()) {
85 for input in inputs {
86 if let Some(input_obj) = input.as_object() {
87 if let (Some(name), Some(shape)) = (
88 input_obj.get("name").and_then(|v| v.as_str()),
89 input_obj
90 .get("type")
91 .and_then(|t| t.get("multiArrayType"))
92 .and_then(|mat| mat.get("shape"))
93 .and_then(|s| s.as_array()),
94 ) {
95 let shape_vec: Vec<usize> = shape
96 .iter()
97 .filter_map(|v| v.as_u64().map(|u| u as usize))
98 .collect();
99 model
100 .metadata
101 .inputshapes
102 .insert(name.to_string(), shape_vec);
103 }
104 }
105 }
106 }
107
108 if let Some(outputs) = description.get("output").and_then(|v| v.as_array()) {
110 for output in outputs {
111 if let Some(output_obj) = output.as_object() {
112 if let (Some(name), Some(shape)) = (
113 output_obj.get("name").and_then(|v| v.as_str()),
114 output_obj
115 .get("type")
116 .and_then(|t| t.get("multiArrayType"))
117 .and_then(|mat| mat.get("shape"))
118 .and_then(|s| s.as_array()),
119 ) {
120 let shape_vec: Vec<usize> = shape
121 .iter()
122 .filter_map(|v| v.as_u64().map(|u| u as usize))
123 .collect();
124 model
125 .metadata
126 .outputshapes
127 .insert(name.to_string(), shape_vec);
128 }
129 }
130 }
131 }
132 }
133
134 if let Some(weights) = coreml_model.get("weights").and_then(|v| v.as_object()) {
136 for (name, weight_data) in weights {
137 let shape: Vec<usize> = serde_json::from_value(weight_data["shape"].clone())
138 .map_err(|e| IoError::SerializationError(e.to_string()))?;
139
140 let data: Vec<f32> = serde_json::from_value(weight_data["floatValue"].clone())
141 .map_err(|e| IoError::SerializationError(e.to_string()))?;
142
143 let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
144 .map_err(|e| IoError::Other(e.to_string()))?;
145
146 model
147 .weights
148 .insert(name.clone(), MLTensor::new(array, Some(name.clone())));
149 }
150 }
151
152 Ok(model)
153 }
154
155 fn save_tensor(&self, tensor: &MLTensor, path: &Path) -> Result<()> {
156 let tensor_data = serde_json::json!({
157 "coreml_multiarray": {
158 "shape": tensor.metadata.shape,
159 "dataType": "FLOAT32",
160 "floatValue": tensor.data.as_slice().unwrap().to_vec()
161 }
162 });
163
164 let file = File::create(path).map_err(IoError::Io)?;
165 serde_json::to_writer_pretty(file, &tensor_data)
166 .map_err(|e| IoError::SerializationError(e.to_string()))
167 }
168
169 fn load_tensor(&self, path: &Path) -> Result<MLTensor> {
170 let file = File::open(path).map_err(IoError::Io)?;
171 let tensor_data: serde_json::Value = serde_json::from_reader(file)
172 .map_err(|e| IoError::SerializationError(e.to_string()))?;
173
174 if let Some(multiarray) = tensor_data.get("coreml_multiarray") {
175 let shape: Vec<usize> = serde_json::from_value(multiarray["shape"].clone())
176 .map_err(|e| IoError::SerializationError(e.to_string()))?;
177
178 let data: Vec<f32> = serde_json::from_value(multiarray["floatValue"].clone())
179 .map_err(|e| IoError::SerializationError(e.to_string()))?;
180
181 let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
182 .map_err(|e| IoError::Other(e.to_string()))?;
183
184 return Ok(MLTensor::new(array, None));
185 }
186
187 Err(IoError::Other("Invalid CoreML tensor format".to_string()))
188 }
189}