Skip to main content

tenflowers_core/graph/
serialization.rs

1//! Graph Serialization and Persistence
2//!
3//! This module provides functionality for serializing and deserializing graphs
4//! for persistence, transfer, and interoperability.
5
6use super::core::*;
7use crate::error::TensorError;
8use std::collections::HashMap;
9
10/// Serializable representation of a graph for persistence and transfer
11#[derive(Clone, Debug)]
12#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
13pub struct GraphDef {
14    pub nodes: Vec<NodeDef>,
15    pub edges: Vec<EdgeDef>,
16    pub version: u64,
17}
18
19/// Serializable representation of a graph node
20#[derive(Clone, Debug)]
21#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
22pub struct NodeDef {
23    pub id: NodeId,
24    pub name: String,
25    pub op_type: String,
26    pub device: String,
27    pub attributes: HashMap<String, AttributeValueDef>,
28}
29
30/// Serializable representation of a graph edge
31#[derive(Clone, Debug)]
32#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
33pub struct EdgeDef {
34    pub id: EdgeId,
35    pub from_node: NodeId,
36    pub to_node: NodeId,
37    pub from_output: usize,
38    pub to_input: usize,
39    pub dtype: String,
40    pub shape: Vec<usize>,
41    pub is_control: bool,
42}
43
/// Serializable mirror of `AttributeValue`, used for node attributes.
///
/// Keep the variant order stable: index-based binary encodings (such as the one
/// used by `Graph::save_to_file`) presumably record the variant position —
/// verify before reordering.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
pub enum AttributeValueDef {
    /// Text value.
    String(String),
    /// Signed 64-bit integer.
    Int(i64),
    /// 64-bit float.
    Float(f64),
    /// Boolean flag.
    Bool(bool),
    /// List of integers.
    IntList(Vec<i64>),
    /// List of floats.
    FloatList(Vec<f64>),
    /// Shape stored as its raw dimension sizes.
    Shape(Vec<usize>),
    /// Tensor contents flattened to `f32` elements. The tensor's shape is not
    /// stored, so converting back yields a 1-D tensor.
    Tensor(Vec<f32>),
}
57
58impl From<AttributeValue> for AttributeValueDef {
59    fn from(value: AttributeValue) -> Self {
60        match value {
61            AttributeValue::String(s) => AttributeValueDef::String(s),
62            AttributeValue::Int(i) => AttributeValueDef::Int(i),
63            AttributeValue::Float(f) => AttributeValueDef::Float(f),
64            AttributeValue::Bool(b) => AttributeValueDef::Bool(b),
65            AttributeValue::IntList(list) => AttributeValueDef::IntList(list),
66            AttributeValue::FloatList(list) => AttributeValueDef::FloatList(list),
67            AttributeValue::Shape(shape) => AttributeValueDef::Shape(shape.dims().to_vec()),
68            AttributeValue::Tensor(tensor) => {
69                // Flatten tensor data for serialization
70                let data = tensor.as_slice().unwrap_or(&[]).to_vec();
71                AttributeValueDef::Tensor(data)
72            }
73        }
74    }
75}
76
77impl TryFrom<AttributeValueDef> for AttributeValue {
78    type Error = TensorError;
79
80    fn try_from(def: AttributeValueDef) -> Result<Self, Self::Error> {
81        match def {
82            AttributeValueDef::String(s) => Ok(AttributeValue::String(s)),
83            AttributeValueDef::Int(i) => Ok(AttributeValue::Int(i)),
84            AttributeValueDef::Float(f) => Ok(AttributeValue::Float(f)),
85            AttributeValueDef::Bool(b) => Ok(AttributeValue::Bool(b)),
86            AttributeValueDef::IntList(list) => Ok(AttributeValue::IntList(list)),
87            AttributeValueDef::FloatList(list) => Ok(AttributeValue::FloatList(list)),
88            AttributeValueDef::Shape(dims) => {
89                Ok(AttributeValue::Shape(crate::shape::Shape::new(dims)))
90            }
91            AttributeValueDef::Tensor(data) => {
92                // Reconstruct tensor from flattened data
93                // Note: This is simplified - in practice we'd need to store shape info
94                use crate::tensor::Tensor;
95                let shape = vec![data.len()];
96                let tensor = Tensor::from_vec(data, &shape)?;
97                Ok(AttributeValue::Tensor(tensor))
98            }
99        }
100    }
101}
102
103impl Graph {
104    /// Convert graph to serializable format
105    pub fn to_graph_def(&self) -> GraphDef {
106        let nodes = self
107            .nodes
108            .values()
109            .map(|node| NodeDef {
110                id: node.id,
111                name: node.name.clone(),
112                op_type: match &node.op_type {
113                    NodeType::Operation(op) => op.clone(),
114                    NodeType::Variable { dtype, shape: _ } => format!("Variable:{:?}", dtype),
115                    NodeType::Placeholder { dtype, shape: _ } => format!("Placeholder:{:?}", dtype),
116                    NodeType::Constant => "Constant".to_string(),
117                },
118                device: format!("{:?}", node.device),
119                attributes: node
120                    .attributes
121                    .iter()
122                    .map(|(k, v)| (k.clone(), v.clone().into()))
123                    .collect(),
124            })
125            .collect();
126
127        let edges = self
128            .edges
129            .values()
130            .map(|edge| EdgeDef {
131                id: edge.id,
132                from_node: edge.from_node,
133                to_node: edge.to_node,
134                from_output: edge.from_output,
135                to_input: edge.to_input,
136                dtype: format!("{:?}", edge.dtype),
137                shape: edge.shape.dims().to_vec(),
138                is_control: edge.is_control,
139            })
140            .collect();
141
142        GraphDef {
143            nodes,
144            edges,
145            version: self.version,
146        }
147    }
148
149    /// Create graph from serializable format
150    pub fn from_graph_def(graph_def: &GraphDef) -> Result<Self, TensorError> {
151        let mut graph = Graph::new();
152        let mut id_mapping: HashMap<NodeId, NodeId> = HashMap::new();
153
154        // Add all nodes
155        for node_def in &graph_def.nodes {
156            let op_type = if node_def.op_type.starts_with("Variable:") {
157                NodeType::Variable {
158                    dtype: crate::dtype::DType::Float32, // Simplified
159                    shape: crate::shape::Shape::new(vec![]),
160                }
161            } else if node_def.op_type.starts_with("Placeholder:") {
162                NodeType::Placeholder {
163                    dtype: crate::dtype::DType::Float32, // Simplified
164                    shape: crate::shape::Shape::new(vec![]),
165                }
166            } else if node_def.op_type == "Constant" {
167                NodeType::Constant
168            } else {
169                NodeType::Operation(node_def.op_type.clone())
170            };
171
172            let device = crate::device::Device::Cpu; // Simplified
173
174            let attributes: Result<HashMap<String, AttributeValue>, TensorError> = node_def
175                .attributes
176                .iter()
177                .map(|(k, v)| Ok((k.clone(), v.clone().try_into()?)))
178                .collect();
179
180            let new_id = graph.add_node(node_def.name.clone(), op_type, device, attributes?)?;
181
182            id_mapping.insert(node_def.id, new_id);
183        }
184
185        // Add all edges
186        for edge_def in &graph_def.edges {
187            let from_node = *id_mapping.get(&edge_def.from_node).ok_or_else(|| {
188                TensorError::invalid_argument(format!(
189                    "Node {} not found in mapping",
190                    edge_def.from_node
191                ))
192            })?;
193
194            let to_node = *id_mapping.get(&edge_def.to_node).ok_or_else(|| {
195                TensorError::invalid_argument(format!(
196                    "Node {} not found in mapping",
197                    edge_def.to_node
198                ))
199            })?;
200
201            graph.add_edge(
202                from_node,
203                to_node,
204                edge_def.from_output,
205                edge_def.to_input,
206                crate::dtype::DType::Float32, // Simplified
207                crate::shape::Shape::new(edge_def.shape.clone()),
208                edge_def.is_control,
209            )?;
210        }
211
212        graph.version = graph_def.version;
213        Ok(graph)
214    }
215
216    /// Save graph to file
217    #[cfg(feature = "serialize")]
218    pub fn save_to_file<P: AsRef<std::path::Path>>(&self, path: P) -> Result<(), TensorError> {
219        let graph_def = self.to_graph_def();
220        let serialized = oxicode::serde::encode_to_vec(&graph_def, oxicode::config::standard())
221            .map_err(|e| TensorError::invalid_argument(format!("Serialization failed: {}", e)))?;
222
223        std::fs::write(path, serialized)
224            .map_err(|e| TensorError::invalid_argument(format!("Failed to write file: {}", e)))?;
225
226        Ok(())
227    }
228
229    /// Load graph from file
230    #[cfg(feature = "serialize")]
231    pub fn load_from_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self, TensorError> {
232        let data = std::fs::read(path)
233            .map_err(|e| TensorError::invalid_argument(format!("Failed to read file: {}", e)))?;
234
235        let graph_def: GraphDef =
236            oxicode::serde::decode_owned_from_slice(&data, oxicode::config::standard())
237                .map_err(|e| {
238                    TensorError::invalid_argument(format!("Deserialization failed: {}", e))
239                })?
240                .0;
241
242        Self::from_graph_def(&graph_def)
243    }
244
245    /// Convert graph to JSON string
246    #[cfg(feature = "serialize")]
247    pub fn to_json(&self) -> Result<String, TensorError> {
248        let graph_def = self.to_graph_def();
249        serde_json::to_string_pretty(&graph_def)
250            .map_err(|e| TensorError::invalid_argument(format!("JSON serialization failed: {}", e)))
251    }
252
253    /// Create graph from JSON string
254    #[cfg(feature = "serialize")]
255    pub fn from_json(json: &str) -> Result<Self, TensorError> {
256        let graph_def: GraphDef = serde_json::from_str(json).map_err(|e| {
257            TensorError::invalid_argument(format!("JSON deserialization failed: {}", e))
258        })?;
259
260        Self::from_graph_def(&graph_def)
261    }
262}