rustorch/serialization/
core.rs

1//! Core serialization traits and error types for Phase 9
2//! フェーズ9用コアシリアライゼーショントレイトとエラータイプ
3
4use crate::error::RusTorchError;
5use crate::tensor::Tensor;
6use num_traits::Float;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fmt;
10use std::path::Path;
11
/// Errors produced while serializing or deserializing RusTorch objects.
///
/// Every variant carries only `String` data, so values can be cloned and
/// compared for equality (useful in tests and when matching on failures).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SerializationError {
    /// File I/O error; holds the message of the underlying error.
    IoError(String),
    /// Format error (invalid file format).
    FormatError(String),
    /// Version incompatibility.
    VersionError {
        /// Version (or version pattern) that was required.
        expected: String,
        /// Version that was actually encountered.
        found: String,
    },
    /// Missing required field; holds the field name.
    MissingField(String),
    /// Type mismatch during deserialization.
    TypeMismatch {
        /// Type that was required.
        expected: String,
        /// Type that was actually encountered.
        found: String,
    },
    /// Corruption detected.
    CorruptionError(String),
    /// Unsupported operation.
    UnsupportedOperation(String),
}

impl fmt::Display for SerializationError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::IoError(detail) => write!(f, "I/O error: {}", detail),
            Self::FormatError(detail) => write!(f, "Format error: {}", detail),
            Self::VersionError { expected, found } => write!(
                f,
                "Version mismatch: expected {}, found {}",
                expected, found
            ),
            Self::MissingField(name) => write!(f, "Missing required field: {}", name),
            Self::TypeMismatch { expected, found } => {
                write!(f, "Type mismatch: expected {}, found {}", expected, found)
            }
            Self::CorruptionError(detail) => write!(f, "Data corruption: {}", detail),
            Self::UnsupportedOperation(detail) => {
                write!(f, "Unsupported operation: {}", detail)
            }
        }
    }
}

// No `source()` override: the I/O variant stores only the message text, so
// there is no underlying error object to expose.
impl std::error::Error for SerializationError {}

impl From<std::io::Error> for SerializationError {
    /// Converts an I/O error into [`SerializationError::IoError`], keeping
    /// only its display text.
    fn from(error: std::io::Error) -> Self {
        Self::IoError(error.to_string())
    }
}
65
66impl From<SerializationError> for RusTorchError {
67    fn from(error: SerializationError) -> Self {
68        RusTorchError::SerializationError {
69            operation: "serialization".to_string(),
70            message: error.to_string(),
71        }
72    }
73}
74
75pub type SerializationResult<T> = Result<T, SerializationError>;
76
77/// Core trait for objects that can be saved
78/// 保存可能オブジェクトのコアトレイト
79pub trait Saveable {
80    /// Save object to binary format
81    /// オブジェクトをバイナリ形式で保存
82    fn save_binary(&self) -> SerializationResult<Vec<u8>>;
83
84    /// Get object type identifier
85    /// オブジェクトタイプ識別子を取得
86    fn type_id(&self) -> &'static str;
87
88    /// Get version information
89    /// バージョン情報を取得
90    fn version(&self) -> String {
91        "1.0.0".to_string()
92    }
93
94    /// Get metadata for object
95    /// オブジェクトのメタデータを取得
96    fn metadata(&self) -> HashMap<String, String> {
97        HashMap::new()
98    }
99}
100
101/// Core trait for objects that can be loaded
102/// 読み込み可能オブジェクトのコアトレイト
103pub trait Loadable: Sized {
104    /// Load object from binary format
105    /// バイナリ形式からオブジェクトを読み込み
106    fn load_binary(data: &[u8]) -> SerializationResult<Self>;
107
108    /// Get expected type identifier
109    /// 期待されるタイプ識別子を取得
110    fn expected_type_id() -> &'static str;
111
112    /// Validate version compatibility
113    /// バージョン互換性を検証
114    fn validate_version(version: &str) -> SerializationResult<()> {
115        if version.starts_with("1.") {
116            Ok(())
117        } else {
118            Err(SerializationError::VersionError {
119                expected: "1.x".to_string(),
120                found: version.to_string(),
121            })
122        }
123    }
124}
125
126/// File header for RusTorch serialization format
127/// RusTorchシリアライゼーション形式用ファイルヘッダー
128#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct FileHeader {
130    pub magic: [u8; 8],                    // "RUSTORCH"
131    pub version: String,                   // Version string
132    pub object_type: String,               // Object type identifier
133    pub metadata: HashMap<String, String>, // Additional metadata
134    pub checksum: u64,                     // Data integrity checksum
135}
136
137impl FileHeader {
138    /// Create new file header
139    /// 新しいファイルヘッダーを作成
140    pub fn new(object_type: String, metadata: HashMap<String, String>) -> Self {
141        Self {
142            magic: *b"RUSTORCH",
143            version: "1.0.0".to_string(),
144            object_type,
145            metadata,
146            checksum: 0, // Will be computed during save
147        }
148    }
149
150    /// Validate header magic and version
151    /// ヘッダーマジックとバージョンを検証
152    pub fn validate(&self) -> SerializationResult<()> {
153        if self.magic != *b"RUSTORCH" {
154            return Err(SerializationError::FormatError(
155                "Invalid file magic".to_string(),
156            ));
157        }
158
159        if !self.version.starts_with("1.") {
160            return Err(SerializationError::VersionError {
161                expected: "1.x".to_string(),
162                found: self.version.clone(),
163            });
164        }
165
166        Ok(())
167    }
168}
169
170/// Tensor serialization metadata
171/// テンソルシリアライゼーションメタデータ
172#[derive(Debug, Clone, Serialize, Deserialize)]
173pub struct TensorMetadata {
174    pub shape: Vec<usize>,
175    pub dtype: String,
176    pub device: String,
177    pub requires_grad: bool,
178    pub data_offset: u64,
179    pub data_size: u64,
180}
181
182/// Model serialization metadata
183/// モデルシリアライゼーションメタデータ
184#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct ModelMetadata {
186    pub model_type: String,
187    pub parameters: HashMap<String, TensorMetadata>,
188    pub buffers: HashMap<String, TensorMetadata>,
189    pub config: HashMap<String, String>,
190    pub training_state: bool,
191}
192
193/// Computation graph node for JIT
194/// JIT用計算グラフノード
195#[derive(Debug, Clone, Serialize, Deserialize)]
196pub struct GraphNode {
197    pub id: usize,
198    pub op_type: String,
199    pub inputs: Vec<usize>,
200    pub outputs: Vec<usize>,
201    pub attributes: HashMap<String, String>,
202}
203
204/// Computation graph for JIT compilation
205/// JITコンパイル用計算グラフ
206#[derive(Debug, Clone, Serialize, Deserialize)]
207pub struct ComputationGraph<T: Float> {
208    pub nodes: Vec<GraphNode>,
209    pub inputs: Vec<String>,
210    pub outputs: Vec<String>,
211    #[serde(skip)]
212    pub constants: HashMap<String, Tensor<T>>,
213}
214
215impl<T: Float> ComputationGraph<T> {
216    /// Create new computation graph
217    /// 新しい計算グラフを作成
218    pub fn new() -> Self {
219        Self {
220            nodes: Vec::new(),
221            inputs: Vec::new(),
222            outputs: Vec::new(),
223            constants: HashMap::new(),
224        }
225    }
226
227    /// Add node to graph
228    /// グラフにノードを追加
229    pub fn add_node(&mut self, node: GraphNode) -> usize {
230        let id = self.nodes.len();
231        self.nodes.push(node);
232        id
233    }
234
235    /// Validate graph structure
236    /// グラフ構造を検証
237    pub fn validate(&self) -> SerializationResult<()> {
238        // Check for cycles, validate connections, etc.
239        for node in &self.nodes {
240            for &input_id in &node.inputs {
241                if input_id >= self.nodes.len() {
242                    return Err(SerializationError::FormatError(format!(
243                        "Invalid input node ID: {}",
244                        input_id
245                    )));
246                }
247            }
248        }
249        Ok(())
250    }
251}
252
/// Computes a 64-bit CRC over `data`.
///
/// Bitwise, least-significant-bit-first CRC with reflected polynomial
/// `0xC96C_5795_D787_0F42`, an all-ones initial register, and a final bitwise
/// inversion. These parameters match the CRC-64/XZ variant. Empty input
/// yields 0.
pub fn compute_checksum(data: &[u8]) -> u64 {
    const REFLECTED_POLY: u64 = 0xC96C_5795_D787_0F42;

    let register = data.iter().fold(u64::MAX, |acc, &byte| {
        // Mix the byte into the low bits, then shift it out one bit at a time.
        (0..8).fold(acc ^ u64::from(byte), |bits, _| {
            let shifted = bits >> 1;
            if bits & 1 == 1 {
                shifted ^ REFLECTED_POLY
            } else {
                shifted
            }
        })
    });

    // XOR with all ones is the same as bitwise NOT.
    !register
}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built header carries the expected magic, version, type,
    /// metadata, and a zero checksum.
    #[test]
    fn test_file_header_creation() {
        let header = FileHeader::new("tensor".to_string(), HashMap::new());

        assert_eq!(header.magic, *b"RUSTORCH");
        assert_eq!(header.version, "1.0.0");
        assert_eq!(header.object_type, "tensor");
        assert!(header.metadata.is_empty());
        assert_eq!(header.checksum, 0);
    }

    /// A valid header passes; a wrong magic is reported as a format error.
    #[test]
    fn test_file_header_validation() {
        let mut header = FileHeader::new("tensor".to_string(), HashMap::new());
        assert!(header.validate().is_ok());

        header.magic = *b"INVALID ";
        match header.validate() {
            Err(SerializationError::FormatError(_)) => (),
            other => panic!("Expected FormatError, got {:?}", other),
        }
    }

    /// A version outside the 1.x line is rejected with both versions reported.
    #[test]
    fn test_file_header_rejects_unsupported_version() {
        let mut header = FileHeader::new("tensor".to_string(), HashMap::new());
        header.version = "2.0.0".to_string();

        match header.validate() {
            Err(SerializationError::VersionError { expected, found }) => {
                assert_eq!(expected, "1.x");
                assert_eq!(found, "2.0.0");
            }
            other => panic!("Expected VersionError, got {:?}", other),
        }
    }

    /// io::Error -> SerializationError -> RusTorchError keeps the error kind.
    #[test]
    fn test_serialization_error_conversion() {
        let io_error = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let ser_error: SerializationError = io_error.into();
        assert_eq!(ser_error.to_string(), "I/O error: file not found");

        let rust_error: RusTorchError = ser_error.into();
        match rust_error {
            RusTorchError::SerializationError { .. } => (),
            _ => panic!("Expected SerializationError"),
        }
    }

    /// Display output for the struct-like variants names both sides.
    #[test]
    fn test_serialization_error_display() {
        let version = SerializationError::VersionError {
            expected: "1.x".to_string(),
            found: "3.1.4".to_string(),
        };
        assert_eq!(
            version.to_string(),
            "Version mismatch: expected 1.x, found 3.1.4"
        );

        let mismatch = SerializationError::TypeMismatch {
            expected: "tensor".to_string(),
            found: "model".to_string(),
        };
        assert_eq!(
            mismatch.to_string(),
            "Type mismatch: expected tensor, found model"
        );
    }

    /// Adding a node returns its index, and a graph with no dangling inputs
    /// validates.
    #[test]
    fn test_computation_graph() {
        let mut graph: ComputationGraph<f32> = ComputationGraph::new();

        let node = GraphNode {
            id: 0,
            op_type: "add".to_string(),
            inputs: vec![],
            outputs: vec![0],
            attributes: HashMap::new(),
        };

        let id = graph.add_node(node);
        assert_eq!(id, 0);
        assert!(graph.validate().is_ok());
    }

    /// An input ID that does not refer to an existing node fails validation.
    #[test]
    fn test_computation_graph_rejects_dangling_input() {
        let mut graph: ComputationGraph<f32> = ComputationGraph::new();
        graph.add_node(GraphNode {
            id: 0,
            op_type: "relu".to_string(),
            inputs: vec![5],
            outputs: vec![],
            attributes: HashMap::new(),
        });

        match graph.validate() {
            Err(SerializationError::FormatError(message)) => {
                assert_eq!(message, "Invalid input node ID: 5");
            }
            other => panic!("Expected FormatError, got {:?}", other),
        }
    }

    /// The checksum is deterministic, input-sensitive, and matches the
    /// standard CRC-64/XZ check value.
    #[test]
    fn test_checksum_computation() {
        let data = b"test data";
        let checksum1 = compute_checksum(data);
        let checksum2 = compute_checksum(data);

        // Same data should produce same checksum
        assert_eq!(checksum1, checksum2);

        // Different data should produce different checksum
        let different_data = b"different test data";
        let checksum3 = compute_checksum(different_data);
        assert_ne!(checksum1, checksum3);

        // Known vectors: empty input and the conventional "123456789" check.
        assert_eq!(compute_checksum(b""), 0);
        assert_eq!(compute_checksum(b"123456789"), 0x995D_C9BB_DF19_39FA);
    }
}