1use super::core::*;
7use crate::error::TensorError;
8use std::collections::HashMap;
9
/// Plain, serializable snapshot of a graph.
///
/// Produced by `Graph::to_graph_def` and consumed by `Graph::from_graph_def`.
/// With the `serialize` feature enabled it derives serde's `Serialize` /
/// `Deserialize`, which `save_to_file` / `load_from_file` (via `oxicode`) and
/// `to_json` / `from_json` (via `serde_json`) rely on.
// NOTE(review): field order presumably matters for the binary `oxicode`
// encoding — confirm before reordering or inserting fields.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub struct GraphDef {
    /// Node definitions; `Graph::from_graph_def` re-adds them in this order.
    pub nodes: Vec<NodeDef>,
    /// Edge definitions, whose endpoints refer to `NodeDef::id` values in `nodes`.
    pub edges: Vec<EdgeDef>,
    /// Graph version; copied from `Graph::version` and restored into it on load.
    pub version: u64,
}
18
/// Serializable description of a single graph node.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub struct NodeDef {
    /// Id the node had in the source graph. On load it is only used to
    /// resolve edge endpoints; `Graph::from_graph_def` assigns fresh ids.
    pub id: NodeId,
    /// Node name, passed unchanged to `Graph::add_node` on load.
    pub name: String,
    /// Encoded node kind: the operation name for operations,
    /// `"Variable:<dtype Debug>"`, `"Placeholder:<dtype Debug>"`, or `"Constant"`.
    pub op_type: String,
    /// `Debug` text of the node's device. Currently ignored on load
    /// (nodes are recreated on `Device::Cpu`).
    pub device: String,
    /// Node attributes in their serializable form.
    pub attributes: HashMap<String, AttributeValueDef>,
}
29
/// Serializable description of a single graph edge.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub struct EdgeDef {
    /// Id the edge had in the source graph (not reused on load).
    pub id: EdgeId,
    /// Source node, expressed as a `NodeDef::id`.
    pub from_node: NodeId,
    /// Destination node, expressed as a `NodeDef::id`.
    pub to_node: NodeId,
    /// Output slot index on the source node.
    pub from_output: usize,
    /// Input slot index on the destination node.
    pub to_input: usize,
    /// `Debug` text of the edge dtype. Currently ignored on load
    /// (edges are recreated as `DType::Float32`).
    pub dtype: String,
    /// Dimensions of the value carried by the edge.
    pub shape: Vec<usize>,
    /// Whether this is a control edge rather than a data edge.
    pub is_control: bool,
}
43
/// Serializable mirror of `AttributeValue`.
// NOTE(review): variant order presumably matters for the binary `oxicode`
// encoding — append new variants at the end and confirm.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub enum AttributeValueDef {
    String(String),
    Int(i64),
    Float(f64),
    Bool(bool),
    IntList(Vec<i64>),
    FloatList(Vec<f64>),
    /// Shape dimensions.
    Shape(Vec<usize>),
    /// Flat tensor element data only; the tensor's shape is not stored and
    /// the value is restored as a 1-D tensor.
    Tensor(Vec<f32>),
}
57
58impl From<AttributeValue> for AttributeValueDef {
59 fn from(value: AttributeValue) -> Self {
60 match value {
61 AttributeValue::String(s) => AttributeValueDef::String(s),
62 AttributeValue::Int(i) => AttributeValueDef::Int(i),
63 AttributeValue::Float(f) => AttributeValueDef::Float(f),
64 AttributeValue::Bool(b) => AttributeValueDef::Bool(b),
65 AttributeValue::IntList(list) => AttributeValueDef::IntList(list),
66 AttributeValue::FloatList(list) => AttributeValueDef::FloatList(list),
67 AttributeValue::Shape(shape) => AttributeValueDef::Shape(shape.dims().to_vec()),
68 AttributeValue::Tensor(tensor) => {
69 let data = tensor.as_slice().unwrap_or(&[]).to_vec();
71 AttributeValueDef::Tensor(data)
72 }
73 }
74 }
75}
76
77impl TryFrom<AttributeValueDef> for AttributeValue {
78 type Error = TensorError;
79
80 fn try_from(def: AttributeValueDef) -> Result<Self, Self::Error> {
81 match def {
82 AttributeValueDef::String(s) => Ok(AttributeValue::String(s)),
83 AttributeValueDef::Int(i) => Ok(AttributeValue::Int(i)),
84 AttributeValueDef::Float(f) => Ok(AttributeValue::Float(f)),
85 AttributeValueDef::Bool(b) => Ok(AttributeValue::Bool(b)),
86 AttributeValueDef::IntList(list) => Ok(AttributeValue::IntList(list)),
87 AttributeValueDef::FloatList(list) => Ok(AttributeValue::FloatList(list)),
88 AttributeValueDef::Shape(dims) => {
89 Ok(AttributeValue::Shape(crate::shape::Shape::new(dims)))
90 }
91 AttributeValueDef::Tensor(data) => {
92 use crate::tensor::Tensor;
95 let shape = vec![data.len()];
96 let tensor = Tensor::from_vec(data, &shape)?;
97 Ok(AttributeValue::Tensor(tensor))
98 }
99 }
100 }
101}
102
103impl Graph {
104 pub fn to_graph_def(&self) -> GraphDef {
106 let nodes = self
107 .nodes
108 .values()
109 .map(|node| NodeDef {
110 id: node.id,
111 name: node.name.clone(),
112 op_type: match &node.op_type {
113 NodeType::Operation(op) => op.clone(),
114 NodeType::Variable { dtype, shape: _ } => format!("Variable:{:?}", dtype),
115 NodeType::Placeholder { dtype, shape: _ } => format!("Placeholder:{:?}", dtype),
116 NodeType::Constant => "Constant".to_string(),
117 },
118 device: format!("{:?}", node.device),
119 attributes: node
120 .attributes
121 .iter()
122 .map(|(k, v)| (k.clone(), v.clone().into()))
123 .collect(),
124 })
125 .collect();
126
127 let edges = self
128 .edges
129 .values()
130 .map(|edge| EdgeDef {
131 id: edge.id,
132 from_node: edge.from_node,
133 to_node: edge.to_node,
134 from_output: edge.from_output,
135 to_input: edge.to_input,
136 dtype: format!("{:?}", edge.dtype),
137 shape: edge.shape.dims().to_vec(),
138 is_control: edge.is_control,
139 })
140 .collect();
141
142 GraphDef {
143 nodes,
144 edges,
145 version: self.version,
146 }
147 }
148
149 pub fn from_graph_def(graph_def: &GraphDef) -> Result<Self, TensorError> {
151 let mut graph = Graph::new();
152 let mut id_mapping: HashMap<NodeId, NodeId> = HashMap::new();
153
154 for node_def in &graph_def.nodes {
156 let op_type = if node_def.op_type.starts_with("Variable:") {
157 NodeType::Variable {
158 dtype: crate::dtype::DType::Float32, shape: crate::shape::Shape::new(vec![]),
160 }
161 } else if node_def.op_type.starts_with("Placeholder:") {
162 NodeType::Placeholder {
163 dtype: crate::dtype::DType::Float32, shape: crate::shape::Shape::new(vec![]),
165 }
166 } else if node_def.op_type == "Constant" {
167 NodeType::Constant
168 } else {
169 NodeType::Operation(node_def.op_type.clone())
170 };
171
172 let device = crate::device::Device::Cpu; let attributes: Result<HashMap<String, AttributeValue>, TensorError> = node_def
175 .attributes
176 .iter()
177 .map(|(k, v)| Ok((k.clone(), v.clone().try_into()?)))
178 .collect();
179
180 let new_id = graph.add_node(node_def.name.clone(), op_type, device, attributes?)?;
181
182 id_mapping.insert(node_def.id, new_id);
183 }
184
185 for edge_def in &graph_def.edges {
187 let from_node = *id_mapping.get(&edge_def.from_node).ok_or_else(|| {
188 TensorError::invalid_argument(format!(
189 "Node {} not found in mapping",
190 edge_def.from_node
191 ))
192 })?;
193
194 let to_node = *id_mapping.get(&edge_def.to_node).ok_or_else(|| {
195 TensorError::invalid_argument(format!(
196 "Node {} not found in mapping",
197 edge_def.to_node
198 ))
199 })?;
200
201 graph.add_edge(
202 from_node,
203 to_node,
204 edge_def.from_output,
205 edge_def.to_input,
206 crate::dtype::DType::Float32, crate::shape::Shape::new(edge_def.shape.clone()),
208 edge_def.is_control,
209 )?;
210 }
211
212 graph.version = graph_def.version;
213 Ok(graph)
214 }
215
216 #[cfg(feature = "serialize")]
218 pub fn save_to_file<P: AsRef<std::path::Path>>(&self, path: P) -> Result<(), TensorError> {
219 let graph_def = self.to_graph_def();
220 let serialized = oxicode::serde::encode_to_vec(&graph_def, oxicode::config::standard())
221 .map_err(|e| TensorError::invalid_argument(format!("Serialization failed: {}", e)))?;
222
223 std::fs::write(path, serialized)
224 .map_err(|e| TensorError::invalid_argument(format!("Failed to write file: {}", e)))?;
225
226 Ok(())
227 }
228
229 #[cfg(feature = "serialize")]
231 pub fn load_from_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self, TensorError> {
232 let data = std::fs::read(path)
233 .map_err(|e| TensorError::invalid_argument(format!("Failed to read file: {}", e)))?;
234
235 let graph_def: GraphDef =
236 oxicode::serde::decode_owned_from_slice(&data, oxicode::config::standard())
237 .map_err(|e| {
238 TensorError::invalid_argument(format!("Deserialization failed: {}", e))
239 })?
240 .0;
241
242 Self::from_graph_def(&graph_def)
243 }
244
245 #[cfg(feature = "serialize")]
247 pub fn to_json(&self) -> Result<String, TensorError> {
248 let graph_def = self.to_graph_def();
249 serde_json::to_string_pretty(&graph_def)
250 .map_err(|e| TensorError::invalid_argument(format!("JSON serialization failed: {}", e)))
251 }
252
253 #[cfg(feature = "serialize")]
255 pub fn from_json(json: &str) -> Result<Self, TensorError> {
256 let graph_def: GraphDef = serde_json::from_str(json).map_err(|e| {
257 TensorError::invalid_argument(format!("JSON deserialization failed: {}", e))
258 })?;
259
260 Self::from_graph_def(&graph_def)
261 }
262}