// scirs2_io/ml_framework/converters/tensorflow.rs
#![allow(dead_code)]

use std::fs::File;
use std::io::{BufReader, BufWriter, Write};
use std::path::Path;

use scirs2_core::ndarray::{ArrayD, IxDyn};

use crate::error::{IoError, Result};
use crate::ml_framework::converters::MLFrameworkConverter;
use crate::ml_framework::types::{MLFramework, MLModel, MLTensor};

/// Converter between [`MLModel`]/[`MLTensor`] and a JSON document that mirrors
/// the TensorFlow SavedModel / TensorProto layout.
///
/// The converter is stateless, so it is cheap to copy and has a trivial default.
#[derive(Debug, Clone, Copy, Default)]
pub struct TensorFlowConverter;

14impl MLFrameworkConverter for TensorFlowConverter {
15 fn save_model(&self, model: &MLModel, path: &Path) -> Result<()> {
16 let model_dir = path.parent().unwrap_or(Path::new("."));
18 std::fs::create_dir_all(model_dir).map_err(IoError::Io)?;
19
20 let tf_model = serde_json::json!({
21 "saved_model_schema_version": 1,
22 "meta_graphs": [{
23 "meta_info_def": {
24 "meta_graph_version": "v2.0.0",
25 "tensorflow_version": "2.12.0",
26 "tags": ["serve"]
27 },
28 "graph_def": {
29 "versions": { "producer": 1982, "min_consumer": 12 }
30 },
31 "signature_def": {
32 "serving_default": {
33 "inputs": model.metadata.inputshapes,
34 "outputs": model.metadata.outputshapes,
35 "method_name": "tensorflow/serving/predict"
36 }
37 }
38 }],
39 "variables": model.weights.iter().map(|(name, tensor)| {
40 serde_json::json!({
41 "name": name,
42 "shape": tensor.metadata.shape,
43 "dtype": format!("{:?}", tensor.metadata.dtype),
44 "data": tensor.data.as_slice().unwrap().to_vec()
45 })
46 }).collect::<Vec<_>>()
47 });
48
49 let file = File::create(path).map_err(IoError::Io)?;
50 serde_json::to_writer_pretty(file, &tf_model)
51 .map_err(|e| IoError::SerializationError(e.to_string()))
52 }
53
54 fn load_model(&self, path: &Path) -> Result<MLModel> {
55 let file = File::open(path).map_err(IoError::Io)?;
56 let tf_model: serde_json::Value = serde_json::from_reader(file)
57 .map_err(|e| IoError::SerializationError(e.to_string()))?;
58
59 let mut model = MLModel::new(MLFramework::TensorFlow);
60
61 if let Some(meta_graphs) = tf_model.get("meta_graphs").and_then(|v| v.as_array()) {
63 if let Some(meta_graph) = meta_graphs.first() {
64 if let Some(signature_def) = meta_graph
65 .get("signature_def")
66 .and_then(|v| v.get("serving_default"))
67 {
68 if let Some(inputs) = signature_def.get("inputs").and_then(|v| v.as_object()) {
69 for (name, input_info) in inputs {
70 if let Some(shape) = input_info.as_array() {
71 let shape_vec: Vec<usize> = shape
72 .iter()
73 .filter_map(|v| v.as_u64().map(|u| u as usize))
74 .collect();
75 model.metadata.inputshapes.insert(name.clone(), shape_vec);
76 }
77 }
78 }
79 }
80 }
81 }
82
83 if let Some(variables) = tf_model.get("variables").and_then(|v| v.as_array()) {
85 for var in variables {
86 if let Some(var_obj) = var.as_object() {
87 if let (Some(name), Some(shape), Some(data)) = (
88 var_obj.get("name").and_then(|v| v.as_str()),
89 var_obj.get("shape").and_then(|v| v.as_array()),
90 var_obj.get("data").and_then(|v| v.as_array()),
91 ) {
92 let shape_vec: Vec<usize> = shape
93 .iter()
94 .filter_map(|v| v.as_u64().map(|u| u as usize))
95 .collect();
96
97 let data_vec: Vec<f32> = data
98 .iter()
99 .filter_map(|v| v.as_f64().map(|f| f as f32))
100 .collect();
101
102 if let Ok(array) = ArrayD::from_shape_vec(IxDyn(&shape_vec), data_vec) {
103 model.weights.insert(
104 name.to_string(),
105 MLTensor::new(array, Some(name.to_string())),
106 );
107 }
108 }
109 }
110 }
111 }
112
113 Ok(model)
114 }
115
116 fn save_tensor(&self, tensor: &MLTensor, path: &Path) -> Result<()> {
117 let tensor_data = serde_json::json!({
118 "tensor": {
119 "dtype": format!("{:?}", tensor.metadata.dtype),
120 "tensorshape": {
121 "dim": tensor.metadata.shape.iter().map(|&d| serde_json::json!({"size": d})).collect::<Vec<_>>()
122 },
123 "tensor_content": tensor.data.as_slice().unwrap()
124 .iter()
125 .flat_map(|f| f.to_le_bytes().to_vec())
126 .collect::<Vec<u8>>()
127 }
128 });
129
130 let file = File::create(path).map_err(IoError::Io)?;
131 serde_json::to_writer_pretty(file, &tensor_data)
132 .map_err(|e| IoError::SerializationError(e.to_string()))
133 }
134
135 fn load_tensor(&self, path: &Path) -> Result<MLTensor> {
136 let file = File::open(path).map_err(IoError::Io)?;
137 let tensor_data: serde_json::Value = serde_json::from_reader(file)
138 .map_err(|e| IoError::SerializationError(e.to_string()))?;
139
140 if let Some(tensor) = tensor_data.get("tensor") {
141 let shape: Vec<usize> = tensor
142 .get("tensorshape")
143 .and_then(|ts| ts.get("dim"))
144 .and_then(|dims| dims.as_array())
145 .map(|dims| {
146 dims.iter()
147 .filter_map(|d| d.get("size").and_then(|s| s.as_u64().map(|u| u as usize)))
148 .collect()
149 })
150 .unwrap_or_default();
151
152 let content = tensor.get("tensor_content").and_then(|c| c.as_array());
154 let data: Vec<f32> = if let Some(content_array) = content {
155 content_array
156 .iter()
157 .filter_map(|v| v.as_f64().map(|f| f as f32))
158 .collect()
159 } else {
160 vec![0.0; shape.iter().product()]
161 };
162
163 let array = ArrayD::from_shape_vec(IxDyn(&shape), data)
164 .map_err(|e| IoError::Other(e.to_string()))?;
165
166 return Ok(MLTensor::new(array, None));
167 }
168
169 Err(IoError::Other(
170 "Invalid TensorFlow tensor format".to_string(),
171 ))
172 }
173}