1use crate::error::RusTorchError;
5use crate::tensor::Tensor;
6use num_traits::Float;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fmt;
10use std::path::Path;
11
/// Errors produced while saving or loading serialized RusTorch objects.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare errors directly
/// (every payload is a plain `String`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SerializationError {
    /// An underlying I/O operation failed; holds the rendered error message
    /// (the `From<std::io::Error>` impl stores `error.to_string()`).
    IoError(String),
    /// The data does not have the expected layout (e.g. wrong file magic or a
    /// graph node referencing a non-existent input).
    FormatError(String),
    /// The stored format version is not one this code accepts.
    VersionError { expected: String, found: String },
    /// A required field was absent from the serialized data.
    MissingField(String),
    /// A value had a different type than the one requested.
    TypeMismatch { expected: String, found: String },
    /// The data is damaged (presumably e.g. a checksum mismatch — confirm at call sites).
    CorruptionError(String),
    /// The requested operation is not supported by this serializer.
    UnsupportedOperation(String),
}
31
32impl fmt::Display for SerializationError {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 match self {
35 SerializationError::IoError(msg) => write!(f, "I/O error: {}", msg),
36 SerializationError::FormatError(msg) => write!(f, "Format error: {}", msg),
37 SerializationError::VersionError { expected, found } => {
38 write!(
39 f,
40 "Version mismatch: expected {}, found {}",
41 expected, found
42 )
43 }
44 SerializationError::MissingField(field) => {
45 write!(f, "Missing required field: {}", field)
46 }
47 SerializationError::TypeMismatch { expected, found } => {
48 write!(f, "Type mismatch: expected {}, found {}", expected, found)
49 }
50 SerializationError::CorruptionError(msg) => write!(f, "Data corruption: {}", msg),
51 SerializationError::UnsupportedOperation(msg) => {
52 write!(f, "Unsupported operation: {}", msg)
53 }
54 }
55 }
56}
57
// Lets `SerializationError` be used wherever a `std::error::Error` is expected
// (e.g. `Box<dyn Error>`). The message comes from the `Display` impl; no `source()`
// chain is exposed because the default (returning `None`) is kept.
impl std::error::Error for SerializationError {}
59
60impl From<std::io::Error> for SerializationError {
61 fn from(error: std::io::Error) -> Self {
62 SerializationError::IoError(error.to_string())
63 }
64}
65
66impl From<SerializationError> for RusTorchError {
67 fn from(error: SerializationError) -> Self {
68 RusTorchError::SerializationError {
69 operation: "serialization".to_string(),
70 message: error.to_string(),
71 }
72 }
73}
74
/// Shorthand for a `Result` whose error type is [`SerializationError`].
pub type SerializationResult<T> = Result<T, SerializationError>;
76
77pub trait Saveable {
80 fn save_binary(&self) -> SerializationResult<Vec<u8>>;
83
84 fn type_id(&self) -> &'static str;
87
88 fn version(&self) -> String {
91 "1.0.0".to_string()
92 }
93
94 fn metadata(&self) -> HashMap<String, String> {
97 HashMap::new()
98 }
99}
100
101pub trait Loadable: Sized {
104 fn load_binary(data: &[u8]) -> SerializationResult<Self>;
107
108 fn expected_type_id() -> &'static str;
111
112 fn validate_version(version: &str) -> SerializationResult<()> {
115 if version.starts_with("1.") {
116 Ok(())
117 } else {
118 Err(SerializationError::VersionError {
119 expected: "1.x".to_string(),
120 found: version.to_string(),
121 })
122 }
123 }
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct FileHeader {
130 pub magic: [u8; 8], pub version: String, pub object_type: String, pub metadata: HashMap<String, String>, pub checksum: u64, }
136
137impl FileHeader {
138 pub fn new(object_type: String, metadata: HashMap<String, String>) -> Self {
141 Self {
142 magic: *b"RUSTORCH",
143 version: "1.0.0".to_string(),
144 object_type,
145 metadata,
146 checksum: 0, }
148 }
149
150 pub fn validate(&self) -> SerializationResult<()> {
153 if self.magic != *b"RUSTORCH" {
154 return Err(SerializationError::FormatError(
155 "Invalid file magic".to_string(),
156 ));
157 }
158
159 if !self.version.starts_with("1.") {
160 return Err(SerializationError::VersionError {
161 expected: "1.x".to_string(),
162 found: self.version.clone(),
163 });
164 }
165
166 Ok(())
167 }
168}
169
/// Serializable description of a single tensor.
///
/// NOTE(review): field order is part of the serialized layout for order-sensitive
/// serde formats, so do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorMetadata {
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// Element type as a string (naming scheme not defined here — TODO confirm, e.g. "f32").
    pub dtype: String,
    /// Device the tensor lived on, as a string (naming scheme — TODO confirm).
    pub device: String,
    /// Whether the tensor tracks gradients.
    pub requires_grad: bool,
    /// Presumably the byte offset of the tensor's raw data within the payload — confirm with the writer.
    pub data_offset: u64,
    /// Presumably the length in bytes of the tensor's raw data — confirm with the writer.
    pub data_size: u64,
}
181
/// Serializable description of a model and its tensors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Model type identifier.
    pub model_type: String,
    /// Parameter tensors, keyed by a string (presumably the parameter name).
    pub parameters: HashMap<String, TensorMetadata>,
    /// Buffer tensors, keyed by a string (presumably the buffer name).
    pub buffers: HashMap<String, TensorMetadata>,
    /// Free-form configuration as string key/value pairs.
    pub config: HashMap<String, String>,
    /// Presumably whether the model was in training mode when saved — TODO confirm.
    pub training_state: bool,
}
192
/// One operation in a [`ComputationGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    /// Node identifier (presumably its index in `ComputationGraph::nodes`, which is how
    /// `ComputationGraph::validate` interprets `inputs` — confirm).
    pub id: usize,
    /// Operation kind as a string (e.g. `"add"`).
    pub op_type: String,
    /// Ids of the nodes this node consumes; `ComputationGraph::validate` requires each
    /// to be smaller than the graph's node count.
    pub inputs: Vec<usize>,
    /// Output ids. These are not checked by `ComputationGraph::validate` (meaning not
    /// defined here — TODO confirm).
    pub outputs: Vec<usize>,
    /// Operation-specific attributes as string key/value pairs.
    pub attributes: HashMap<String, String>,
}
203
/// Serializable computation graph over tensors with element type `T`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComputationGraph<T: Float> {
    /// Graph nodes. `ComputationGraph::validate` treats the ids in each node's
    /// `inputs` as indices into this vector.
    pub nodes: Vec<GraphNode>,
    /// Graph-level inputs (presumably tensor names — TODO confirm).
    pub inputs: Vec<String>,
    /// Graph-level outputs (presumably tensor names — TODO confirm).
    pub outputs: Vec<String>,
    /// Constant tensors keyed by a string. `#[serde(skip)]` excludes them from serde
    /// (de)serialization: they are not written, and they deserialize as an empty map.
    /// NOTE(review): persist them separately if they must survive a round trip.
    #[serde(skip)]
    pub constants: HashMap<String, Tensor<T>>,
}
214
215impl<T: Float> ComputationGraph<T> {
216 pub fn new() -> Self {
219 Self {
220 nodes: Vec::new(),
221 inputs: Vec::new(),
222 outputs: Vec::new(),
223 constants: HashMap::new(),
224 }
225 }
226
227 pub fn add_node(&mut self, node: GraphNode) -> usize {
230 let id = self.nodes.len();
231 self.nodes.push(node);
232 id
233 }
234
235 pub fn validate(&self) -> SerializationResult<()> {
238 for node in &self.nodes {
240 for &input_id in &node.inputs {
241 if input_id >= self.nodes.len() {
242 return Err(SerializationError::FormatError(format!(
243 "Invalid input node ID: {}",
244 input_id
245 )));
246 }
247 }
248 }
249 Ok(())
250 }
251}
252
/// Computes a 64-bit CRC of `data`.
///
/// The parameters are those of CRC-64/XZ:
/// - the reflected ECMA-182 polynomial `0xC96C_5795_D787_0F42`;
/// - an all-ones initial register;
/// - least-significant-bit-first processing;
/// - an all-ones final XOR.
///
/// An empty slice therefore yields `0`. Bits are processed one at a time (no lookup table).
pub fn compute_checksum(data: &[u8]) -> u64 {
    const POLY: u64 = 0xC96C_5795_D787_0F42;
    const ALL_ONES: u64 = u64::MAX;

    let register = data.iter().fold(ALL_ONES, |acc, &byte| {
        let mut reg = acc ^ u64::from(byte);
        for _ in 0..8 {
            // Branch-free "if the low bit is set, shift and XOR in POLY":
            // `wrapping_neg` maps 1 to all-ones and leaves 0 as 0.
            let mask = (reg & 1).wrapping_neg();
            reg = (reg >> 1) ^ (POLY & mask);
        }
        reg
    });

    register ^ ALL_ONES
}
270
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_file_header_creation() {
        let metadata = HashMap::new();
        let header = FileHeader::new("tensor".to_string(), metadata);

        assert_eq!(header.magic, *b"RUSTORCH");
        assert_eq!(header.version, "1.0.0");
        assert_eq!(header.object_type, "tensor");
        assert!(header.metadata.is_empty());
        assert_eq!(header.checksum, 0);
    }

    #[test]
    fn test_file_header_validation() {
        let metadata = HashMap::new();
        let mut header = FileHeader::new("tensor".to_string(), metadata);

        // A freshly built header is valid.
        assert!(header.validate().is_ok());

        // Corrupted magic bytes must be rejected as a format error.
        header.magic = *b"INVALID ";
        assert!(matches!(
            header.validate(),
            Err(SerializationError::FormatError(_))
        ));
    }

    #[test]
    fn test_file_header_version_mismatch() {
        let mut header = FileHeader::new("tensor".to_string(), HashMap::new());
        header.version = "2.0.0".to_string();

        match header.validate() {
            Err(SerializationError::VersionError { expected, found }) => {
                assert_eq!(expected, "1.x");
                assert_eq!(found, "2.0.0");
            }
            other => panic!("expected a version error, got {:?}", other),
        }
    }

    #[test]
    fn test_serialization_error_display() {
        let err = SerializationError::MissingField("weights".to_string());
        assert_eq!(err.to_string(), "Missing required field: weights");

        let err = SerializationError::VersionError {
            expected: "1.x".to_string(),
            found: "2.0".to_string(),
        };
        assert_eq!(err.to_string(), "Version mismatch: expected 1.x, found 2.0");
    }

    #[test]
    fn test_serialization_error_conversion() {
        let io_error = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let ser_error: SerializationError = io_error.into();
        assert_eq!(ser_error.to_string(), "I/O error: file not found");

        let rust_error: RusTorchError = ser_error.into();
        match rust_error {
            RusTorchError::SerializationError { operation, message } => {
                assert_eq!(operation, "serialization");
                assert_eq!(message, "I/O error: file not found");
            }
            _ => panic!("Expected SerializationError"),
        }
    }

    #[test]
    fn test_computation_graph() {
        let mut graph: ComputationGraph<f32> = ComputationGraph::new();

        let node = GraphNode {
            id: 0,
            op_type: "add".to_string(),
            inputs: vec![],
            outputs: vec![0],
            attributes: HashMap::new(),
        };

        let id = graph.add_node(node);
        assert_eq!(id, 0);
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_computation_graph_rejects_dangling_input() {
        let mut graph: ComputationGraph<f32> = ComputationGraph::new();
        graph.add_node(GraphNode {
            id: 0,
            op_type: "relu".to_string(),
            inputs: vec![3],
            outputs: vec![],
            attributes: HashMap::new(),
        });

        assert!(matches!(
            graph.validate(),
            Err(SerializationError::FormatError(_))
        ));
    }

    #[test]
    fn test_checksum_computation() {
        let data = b"test data";
        let checksum1 = compute_checksum(data);
        let checksum2 = compute_checksum(data);

        // Deterministic for identical input.
        assert_eq!(checksum1, checksum2);

        // Sensitive to content changes.
        let different_data = b"different test data";
        let checksum3 = compute_checksum(different_data);
        assert_ne!(checksum1, checksum3);
    }

    #[test]
    fn test_checksum_known_vector() {
        // Standard CRC-64/XZ check value for the ASCII string "123456789";
        // pins the exact algorithm, not just determinism.
        assert_eq!(compute_checksum(b"123456789"), 0x995D_C9BB_DF19_39FA);
        assert_eq!(compute_checksum(b""), 0);
    }
}