rustorch/serialization/
model_io.rs

1//! Model save/load functionality for Phase 9
2//! フェーズ9用モデル保存・読み込み機能
3
4use super::core::{
5    compute_checksum, ComputationGraph, FileHeader, Loadable, ModelMetadata, Saveable,
6    SerializationError, SerializationResult, TensorMetadata,
7};
8use crate::tensor::Tensor;
9use num_traits::Float;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fs::{File, OpenOptions};
13use std::io::{BufReader, BufWriter, Read, Write};
14use std::path::{Path, PathBuf};
15
16/// Main save function for objects
17/// オブジェクト用メイン保存関数
18pub fn save<P: AsRef<Path>>(obj: &dyn Saveable, path: P) -> SerializationResult<()> {
19    let path = path.as_ref();
20
21    // Create parent directories if they don't exist
22    if let Some(parent) = path.parent() {
23        std::fs::create_dir_all(parent)?;
24    }
25
26    let file = OpenOptions::new()
27        .create(true)
28        .write(true)
29        .truncate(true)
30        .open(path)?;
31    let mut writer = BufWriter::new(file);
32
33    // Write RUSTORCH magic bytes first for format detection
34    writer.write_all(b"RUSTORCH")?;
35
36    // Create header
37    let metadata = obj.metadata();
38    let mut header = FileHeader::new(obj.type_id().to_string(), metadata);
39
40    // Serialize object data
41    let object_data = obj.save_binary()?;
42    header.checksum = compute_checksum(&object_data);
43
44    // Write header
45    let header_data =
46        bincode::serialize(&header).map_err(|e| SerializationError::FormatError(e.to_string()))?;
47    let header_size = header_data.len() as u64;
48
49    writer.write_all(&header_size.to_le_bytes())?;
50    writer.write_all(&header_data)?;
51
52    // Write object data
53    writer.write_all(&object_data)?;
54    writer.flush()?;
55
56    Ok(())
57}
58
59/// Main load function for objects
60/// オブジェクト用メイン読み込み関数
61pub fn load<P: AsRef<Path>, T: Loadable>(path: P) -> SerializationResult<T> {
62    let file = File::open(path.as_ref())?;
63    let mut reader = BufReader::new(file);
64
65    // Read and verify RUSTORCH magic bytes
66    let mut magic = [0u8; 8];
67    reader.read_exact(&mut magic)?;
68    if &magic != b"RUSTORCH" {
69        return Err(SerializationError::FormatError(
70            "Invalid RusTorch file format".to_string(),
71        ));
72    }
73
74    // Read header size
75    let mut header_size_bytes = [0u8; 8];
76    reader.read_exact(&mut header_size_bytes)?;
77    let header_size = u64::from_le_bytes(header_size_bytes);
78
79    // Read header
80    let mut header_data = vec![0u8; header_size as usize];
81    reader.read_exact(&mut header_data)?;
82    let header: FileHeader = bincode::deserialize(&header_data)
83        .map_err(|e| SerializationError::FormatError(e.to_string()))?;
84
85    // Validate header
86    header.validate()?;
87
88    // Check type compatibility
89    if header.object_type != T::expected_type_id() {
90        return Err(SerializationError::TypeMismatch {
91            expected: T::expected_type_id().to_string(),
92            found: header.object_type,
93        });
94    }
95
96    // Validate version
97    T::validate_version(&header.version)?;
98
99    // Read object data
100    let mut object_data = Vec::new();
101    reader.read_to_end(&mut object_data)?;
102
103    // Verify checksum
104    let computed_checksum = compute_checksum(&object_data);
105    if computed_checksum != header.checksum {
106        return Err(SerializationError::CorruptionError(
107            "Checksum mismatch".to_string(),
108        ));
109    }
110
111    // Deserialize object
112    T::load_binary(&object_data)
113}
114
115/// Model state dictionary for PyTorch compatibility
116/// PyTorch互換性用モデル状態辞書
117#[derive(Debug, Clone)]
118pub struct StateDict<T: Float> {
119    pub parameters: HashMap<String, Tensor<T>>,
120    pub buffers: HashMap<String, Tensor<T>>,
121    pub metadata: ModelMetadata,
122}
123
124impl<T: Float + 'static> StateDict<T> {
125    /// Create new state dictionary
126    /// 新しい状態辞書を作成
127    pub fn new() -> Self {
128        Self {
129            parameters: HashMap::new(),
130            buffers: HashMap::new(),
131            metadata: ModelMetadata {
132                model_type: "unknown".to_string(),
133                parameters: HashMap::new(),
134                buffers: HashMap::new(),
135                config: HashMap::new(),
136                training_state: false,
137            },
138        }
139    }
140
141    /// Add parameter to state dict
142    /// 状態辞書にパラメータを追加
143    pub fn add_parameter(&mut self, name: String, tensor: Tensor<T>) {
144        let metadata = TensorMetadata {
145            shape: tensor.shape().to_vec(),
146            dtype: std::any::type_name::<T>().to_string(),
147            device: "cpu".to_string(), // Default to CPU for now
148            requires_grad: true,
149            data_offset: 0, // Will be computed during save
150            data_size: tensor.numel() as u64 * std::mem::size_of::<T>() as u64,
151        };
152
153        self.metadata.parameters.insert(name.clone(), metadata);
154        self.parameters.insert(name, tensor);
155    }
156
157    /// Add buffer to state dict
158    /// 状態辞書にバッファを追加
159    pub fn add_buffer(&mut self, name: String, tensor: Tensor<T>) {
160        let metadata = TensorMetadata {
161            shape: tensor.shape().to_vec(),
162            dtype: std::any::type_name::<T>().to_string(),
163            device: "cpu".to_string(),
164            requires_grad: false,
165            data_offset: 0,
166            data_size: tensor.numel() as u64 * std::mem::size_of::<T>() as u64,
167        };
168
169        self.metadata.buffers.insert(name.clone(), metadata);
170        self.buffers.insert(name, tensor);
171    }
172
173    /// Get parameter by name
174    /// 名前でパラメータを取得
175    pub fn get_parameter(&self, name: &str) -> Option<&Tensor<T>> {
176        self.parameters.get(name)
177    }
178
179    /// Get buffer by name
180    /// 名前でバッファを取得
181    pub fn get_buffer(&self, name: &str) -> Option<&Tensor<T>> {
182        self.buffers.get(name)
183    }
184
185    /// Check if training mode
186    /// トレーニングモードかチェック
187    pub fn is_training(&self) -> bool {
188        self.metadata.training_state
189    }
190
191    /// Set training mode
192    /// トレーニングモードを設定
193    pub fn set_training(&mut self, training: bool) {
194        self.metadata.training_state = training;
195    }
196}
197
198impl<T: Float + 'static> Saveable for StateDict<T> {
199    fn save_binary(&self) -> SerializationResult<Vec<u8>> {
200        let mut buffer = Vec::new();
201
202        // Save metadata first
203        let metadata_json = serde_json::to_string(&self.metadata)
204            .map_err(|e| SerializationError::FormatError(e.to_string()))?;
205        let metadata_bytes = metadata_json.as_bytes();
206        buffer.extend_from_slice(&(metadata_bytes.len() as u64).to_le_bytes());
207        buffer.extend_from_slice(metadata_bytes);
208
209        // Save parameters count and data
210        buffer.extend_from_slice(&(self.parameters.len() as u32).to_le_bytes());
211        for (name, tensor) in &self.parameters {
212            let name_bytes = name.as_bytes();
213            buffer.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
214            buffer.extend_from_slice(name_bytes);
215
216            let tensor_data = tensor.save_binary()?;
217            buffer.extend_from_slice(&(tensor_data.len() as u64).to_le_bytes());
218            buffer.extend_from_slice(&tensor_data);
219        }
220
221        // Save buffers count and data
222        buffer.extend_from_slice(&(self.buffers.len() as u32).to_le_bytes());
223        for (name, tensor) in &self.buffers {
224            let name_bytes = name.as_bytes();
225            buffer.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
226            buffer.extend_from_slice(name_bytes);
227
228            let tensor_data = tensor.save_binary()?;
229            buffer.extend_from_slice(&(tensor_data.len() as u64).to_le_bytes());
230            buffer.extend_from_slice(&tensor_data);
231        }
232
233        Ok(buffer)
234    }
235
236    fn type_id(&self) -> &'static str {
237        "state_dict"
238    }
239
240    fn metadata(&self) -> HashMap<String, String> {
241        let mut meta = HashMap::new();
242        meta.insert("model_type".to_string(), self.metadata.model_type.clone());
243        meta.insert(
244            "num_parameters".to_string(),
245            self.parameters.len().to_string(),
246        );
247        meta.insert("num_buffers".to_string(), self.buffers.len().to_string());
248        meta.insert(
249            "training_state".to_string(),
250            self.metadata.training_state.to_string(),
251        );
252        meta
253    }
254}
255
256impl<T: Float + 'static> Loadable for StateDict<T> {
257    fn load_binary(data: &[u8]) -> SerializationResult<Self> {
258        if data.is_empty() {
259            return Ok(Self::new());
260        }
261
262        let mut offset = 0;
263        let mut state_dict = Self::new();
264
265        // Read metadata
266        if data.len() < offset + 8 {
267            return Ok(state_dict);
268        }
269        let metadata_len =
270            u64::from_le_bytes(data[offset..offset + 8].try_into().map_err(|_| {
271                SerializationError::FormatError("Invalid metadata length".to_string())
272            })?) as usize;
273        offset += 8;
274
275        if data.len() < offset + metadata_len {
276            return Ok(state_dict);
277        }
278        let metadata_str =
279            std::str::from_utf8(&data[offset..offset + metadata_len]).map_err(|_| {
280                SerializationError::FormatError("Invalid metadata encoding".to_string())
281            })?;
282        state_dict.metadata = serde_json::from_str(metadata_str)
283            .map_err(|e| SerializationError::FormatError(e.to_string()))?;
284        offset += metadata_len;
285
286        // Read parameters
287        if data.len() < offset + 4 {
288            return Ok(state_dict);
289        }
290        let params_count =
291            u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
292                SerializationError::FormatError("Invalid parameters count".to_string())
293            })?);
294        offset += 4;
295
296        for _ in 0..params_count {
297            // Read parameter name
298            if data.len() < offset + 4 {
299                break;
300            }
301            let name_len =
302                u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
303                    SerializationError::FormatError("Invalid parameter name length".to_string())
304                })?) as usize;
305            offset += 4;
306
307            if data.len() < offset + name_len {
308                break;
309            }
310            let name =
311                String::from_utf8(data[offset..offset + name_len].to_vec()).map_err(|_| {
312                    SerializationError::FormatError("Invalid parameter name encoding".to_string())
313                })?;
314            offset += name_len;
315
316            // Read tensor data
317            if data.len() < offset + 8 {
318                break;
319            }
320            let tensor_data_len =
321                u64::from_le_bytes(data[offset..offset + 8].try_into().map_err(|_| {
322                    SerializationError::FormatError("Invalid tensor data length".to_string())
323                })?) as usize;
324            offset += 8;
325
326            if data.len() < offset + tensor_data_len {
327                break;
328            }
329            let tensor_data = &data[offset..offset + tensor_data_len];
330            if let Ok(tensor) = Tensor::<T>::load_binary(tensor_data) {
331                state_dict.parameters.insert(name, tensor);
332            }
333            offset += tensor_data_len;
334        }
335
336        // Read buffers (similar to parameters)
337        if data.len() < offset + 4 {
338            return Ok(state_dict);
339        }
340        let buffers_count =
341            u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
342                SerializationError::FormatError("Invalid buffers count".to_string())
343            })?);
344        offset += 4;
345
346        for _ in 0..buffers_count {
347            // Read buffer name
348            if data.len() < offset + 4 {
349                break;
350            }
351            let name_len =
352                u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
353                    SerializationError::FormatError("Invalid buffer name length".to_string())
354                })?) as usize;
355            offset += 4;
356
357            if data.len() < offset + name_len {
358                break;
359            }
360            let name =
361                String::from_utf8(data[offset..offset + name_len].to_vec()).map_err(|_| {
362                    SerializationError::FormatError("Invalid buffer name encoding".to_string())
363                })?;
364            offset += name_len;
365
366            // Read tensor data
367            if data.len() < offset + 8 {
368                break;
369            }
370            let tensor_data_len =
371                u64::from_le_bytes(data[offset..offset + 8].try_into().map_err(|_| {
372                    SerializationError::FormatError("Invalid tensor data length".to_string())
373                })?) as usize;
374            offset += 8;
375
376            if data.len() < offset + tensor_data_len {
377                break;
378            }
379            let tensor_data = &data[offset..offset + tensor_data_len];
380            if let Ok(tensor) = Tensor::<T>::load_binary(tensor_data) {
381                state_dict.buffers.insert(name, tensor);
382            }
383            offset += tensor_data_len;
384        }
385
386        Ok(state_dict)
387    }
388
389    fn expected_type_id() -> &'static str {
390        "state_dict"
391    }
392}
393
394/// Safe tensor format for large models
395/// 大規模モデル用セーフテンソル形式
396#[derive(Debug, Clone)]
397pub struct SafeTensorFormat<T: Float> {
398    pub tensors: HashMap<String, Tensor<T>>,
399    pub metadata: HashMap<String, String>,
400}
401
402impl<T: Float + 'static> SafeTensorFormat<T> {
403    /// Create new safe tensor format
404    /// 新しいセーフテンソル形式を作成
405    pub fn new() -> Self {
406        Self {
407            tensors: HashMap::new(),
408            metadata: HashMap::new(),
409        }
410    }
411
412    /// Add tensor with name
413    /// 名前付きテンソルを追加
414    pub fn add_tensor(&mut self, name: String, tensor: Tensor<T>) {
415        self.tensors.insert(name, tensor);
416    }
417
418    /// Save in safetensors format
419    /// safetensors形式で保存
420    pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> SerializationResult<()> {
421        // Create safetensors-compatible format
422        let mut header_data = HashMap::new();
423
424        for (name, tensor) in &self.tensors {
425            let shape: Vec<usize> = tensor.shape().to_vec();
426            header_data.insert(
427                name.clone(),
428                serde_json::json!({
429                    "dtype": self.get_dtype_string(),
430                    "shape": shape,
431                    "data_offsets": [0, tensor.numel() * std::mem::size_of::<T>()]
432                }),
433            );
434        }
435
436        // Add metadata
437        header_data.insert("__metadata__".to_string(), serde_json::json!(self.metadata));
438
439        let header_json = serde_json::to_string(&header_data)
440            .map_err(|e| SerializationError::FormatError(e.to_string()))?;
441
442        let file = OpenOptions::new()
443            .create(true)
444            .write(true)
445            .truncate(true)
446            .open(path)?;
447        let mut writer = BufWriter::new(file);
448
449        // Write header size and header
450        let header_size = header_json.len() as u64;
451        writer.write_all(&header_size.to_le_bytes())?;
452        writer.write_all(header_json.as_bytes())?;
453
454        // Write tensor data
455        for (_, tensor) in &self.tensors {
456            if let Some(data_slice) = tensor.data.as_slice() {
457                let bytes = unsafe {
458                    std::slice::from_raw_parts(
459                        data_slice.as_ptr() as *const u8,
460                        data_slice.len() * std::mem::size_of::<T>(),
461                    )
462                };
463                writer.write_all(bytes)?;
464            }
465        }
466
467        writer.flush()?;
468        Ok(())
469    }
470
471    fn get_dtype_string(&self) -> String {
472        match std::mem::size_of::<T>() {
473            4 => "F32".to_string(),
474            8 => "F64".to_string(),
475            _ => "UNKNOWN".to_string(),
476        }
477    }
478}
479
480/// Model checkpoint system
481/// モデルチェックポイントシステム
482#[derive(Debug, Clone)]
483pub struct ModelCheckpoint<T: Float> {
484    pub epoch: usize,
485    pub step: usize,
486    pub model_state: StateDict<T>,
487    pub optimizer_state: HashMap<String, Vec<u8>>,
488    pub scheduler_state: HashMap<String, Vec<u8>>,
489    pub metrics: HashMap<String, f64>,
490    pub timestamp: u64,
491}
492
493impl<T: Float + 'static> ModelCheckpoint<T> {
494    /// Create new model checkpoint
495    /// 新しいモデルチェックポイントを作成
496    pub fn new(epoch: usize, step: usize, model_state: StateDict<T>) -> Self {
497        Self {
498            epoch,
499            step,
500            model_state,
501            optimizer_state: HashMap::new(),
502            scheduler_state: HashMap::new(),
503            metrics: HashMap::new(),
504            timestamp: std::time::SystemTime::now()
505                .duration_since(std::time::UNIX_EPOCH)
506                .unwrap_or_default()
507                .as_secs(),
508        }
509    }
510
511    /// Add optimizer state
512    /// オプティマイザー状態を追加
513    pub fn add_optimizer_state(&mut self, name: String, state: Vec<u8>) {
514        self.optimizer_state.insert(name, state);
515    }
516
517    /// Add training metrics
518    /// トレーニングメトリクスを追加
519    pub fn add_metric(&mut self, name: String, value: f64) {
520        self.metrics.insert(name, value);
521    }
522
523    /// Save checkpoint to file
524    /// チェックポイントをファイルに保存
525    pub fn save_checkpoint<P: AsRef<Path>>(&self, path: P) -> SerializationResult<()> {
526        save(self, path)
527    }
528
529    /// Load checkpoint from file
530    /// ファイルからチェックポイントを読み込み
531    pub fn load_checkpoint<P: AsRef<Path>>(path: P) -> SerializationResult<Self> {
532        load(path)
533    }
534}
535
536impl<T: Float + 'static> Saveable for ModelCheckpoint<T> {
537    fn save_binary(&self) -> SerializationResult<Vec<u8>> {
538        let mut buffer = Vec::new();
539
540        // Save epoch and step
541        buffer.extend_from_slice(&(self.epoch as u64).to_le_bytes());
542        buffer.extend_from_slice(&(self.step as u64).to_le_bytes());
543
544        // Save model state
545        let model_state_data = self.model_state.save_binary()?;
546        buffer.extend_from_slice(&(model_state_data.len() as u64).to_le_bytes());
547        buffer.extend_from_slice(&model_state_data);
548
549        // Save optimizer state
550        buffer.extend_from_slice(&(self.optimizer_state.len() as u32).to_le_bytes());
551        for (key, value) in &self.optimizer_state {
552            let key_bytes = key.as_bytes();
553            buffer.extend_from_slice(&(key_bytes.len() as u32).to_le_bytes());
554            buffer.extend_from_slice(key_bytes);
555            buffer.extend_from_slice(&(value.len() as u64).to_le_bytes());
556            buffer.extend_from_slice(value);
557        }
558
559        // Save scheduler state
560        buffer.extend_from_slice(&(self.scheduler_state.len() as u32).to_le_bytes());
561        for (key, value) in &self.scheduler_state {
562            let key_bytes = key.as_bytes();
563            buffer.extend_from_slice(&(key_bytes.len() as u32).to_le_bytes());
564            buffer.extend_from_slice(key_bytes);
565            buffer.extend_from_slice(&(value.len() as u64).to_le_bytes());
566            buffer.extend_from_slice(value);
567        }
568
569        // Save metrics
570        buffer.extend_from_slice(&(self.metrics.len() as u32).to_le_bytes());
571        for (key, value) in &self.metrics {
572            let key_bytes = key.as_bytes();
573            buffer.extend_from_slice(&(key_bytes.len() as u32).to_le_bytes());
574            buffer.extend_from_slice(key_bytes);
575            buffer.extend_from_slice(&value.to_le_bytes());
576        }
577
578        // Save timestamp
579        buffer.extend_from_slice(&self.timestamp.to_le_bytes());
580
581        Ok(buffer)
582    }
583
584    fn type_id(&self) -> &'static str {
585        "model_checkpoint"
586    }
587
588    fn metadata(&self) -> HashMap<String, String> {
589        let mut meta = HashMap::new();
590        meta.insert("epoch".to_string(), self.epoch.to_string());
591        meta.insert("step".to_string(), self.step.to_string());
592        meta.insert("timestamp".to_string(), self.timestamp.to_string());
593        meta.insert(
594            "model_type".to_string(),
595            self.model_state.metadata.model_type.clone(),
596        );
597        meta
598    }
599}
600
601impl<T: Float + 'static> Loadable for ModelCheckpoint<T> {
602    fn load_binary(data: &[u8]) -> SerializationResult<Self> {
603        if data.is_empty() {
604            return Ok(Self::new(0, 0, StateDict::new()));
605        }
606
607        let mut offset = 0;
608        let mut checkpoint = Self::new(0, 0, StateDict::new());
609
610        // Read epoch and step
611        if data.len() < offset + 16 {
612            return Ok(checkpoint);
613        }
614        checkpoint.epoch = u64::from_le_bytes(
615            data[offset..offset + 8]
616                .try_into()
617                .map_err(|_| SerializationError::FormatError("Invalid epoch".to_string()))?,
618        ) as usize;
619        offset += 8;
620
621        checkpoint.step = u64::from_le_bytes(
622            data[offset..offset + 8]
623                .try_into()
624                .map_err(|_| SerializationError::FormatError("Invalid step".to_string()))?,
625        ) as usize;
626        offset += 8;
627
628        // Read model state
629        if data.len() < offset + 8 {
630            return Ok(checkpoint);
631        }
632        let model_state_len =
633            u64::from_le_bytes(data[offset..offset + 8].try_into().map_err(|_| {
634                SerializationError::FormatError("Invalid model state length".to_string())
635            })?) as usize;
636        offset += 8;
637
638        if data.len() < offset + model_state_len {
639            return Ok(checkpoint);
640        }
641        let model_state_data = &data[offset..offset + model_state_len];
642        if let Ok(model_state) = StateDict::<T>::load_binary(model_state_data) {
643            checkpoint.model_state = model_state;
644        }
645        offset += model_state_len;
646
647        // Read optimizer state
648        if data.len() < offset + 4 {
649            return Ok(checkpoint);
650        }
651        let optimizer_count =
652            u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
653                SerializationError::FormatError("Invalid optimizer count".to_string())
654            })?);
655        offset += 4;
656
657        for _ in 0..optimizer_count {
658            // Read key
659            if data.len() < offset + 4 {
660                break;
661            }
662            let key_len =
663                u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
664                    SerializationError::FormatError("Invalid key length".to_string())
665                })?) as usize;
666            offset += 4;
667
668            if data.len() < offset + key_len {
669                break;
670            }
671            let key = String::from_utf8(data[offset..offset + key_len].to_vec())
672                .map_err(|_| SerializationError::FormatError("Invalid key encoding".to_string()))?;
673            offset += key_len;
674
675            // Read value
676            if data.len() < offset + 8 {
677                break;
678            }
679            let value_len =
680                u64::from_le_bytes(data[offset..offset + 8].try_into().map_err(|_| {
681                    SerializationError::FormatError("Invalid value length".to_string())
682                })?) as usize;
683            offset += 8;
684
685            if data.len() < offset + value_len {
686                break;
687            }
688            let value = data[offset..offset + value_len].to_vec();
689            checkpoint.optimizer_state.insert(key, value);
690            offset += value_len;
691        }
692
693        // Read scheduler state (similar pattern)
694        if data.len() < offset + 4 {
695            return Ok(checkpoint);
696        }
697        let scheduler_count =
698            u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
699                SerializationError::FormatError("Invalid scheduler count".to_string())
700            })?);
701        offset += 4;
702
703        for _ in 0..scheduler_count {
704            // Read key
705            if data.len() < offset + 4 {
706                break;
707            }
708            let key_len =
709                u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
710                    SerializationError::FormatError("Invalid key length".to_string())
711                })?) as usize;
712            offset += 4;
713
714            if data.len() < offset + key_len {
715                break;
716            }
717            let key = String::from_utf8(data[offset..offset + key_len].to_vec())
718                .map_err(|_| SerializationError::FormatError("Invalid key encoding".to_string()))?;
719            offset += key_len;
720
721            // Read value
722            if data.len() < offset + 8 {
723                break;
724            }
725            let value_len =
726                u64::from_le_bytes(data[offset..offset + 8].try_into().map_err(|_| {
727                    SerializationError::FormatError("Invalid value length".to_string())
728                })?) as usize;
729            offset += 8;
730
731            if data.len() < offset + value_len {
732                break;
733            }
734            let value = data[offset..offset + value_len].to_vec();
735            checkpoint.scheduler_state.insert(key, value);
736            offset += value_len;
737        }
738
739        // Read metrics
740        if data.len() < offset + 4 {
741            return Ok(checkpoint);
742        }
743        let metrics_count =
744            u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
745                SerializationError::FormatError("Invalid metrics count".to_string())
746            })?);
747        offset += 4;
748
749        for _ in 0..metrics_count {
750            // Read key
751            if data.len() < offset + 4 {
752                break;
753            }
754            let key_len =
755                u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
756                    SerializationError::FormatError("Invalid key length".to_string())
757                })?) as usize;
758            offset += 4;
759
760            if data.len() < offset + key_len {
761                break;
762            }
763            let key = String::from_utf8(data[offset..offset + key_len].to_vec())
764                .map_err(|_| SerializationError::FormatError("Invalid key encoding".to_string()))?;
765            offset += key_len;
766
767            // Read value (f64)
768            if data.len() < offset + 8 {
769                break;
770            }
771            let value = f64::from_le_bytes(data[offset..offset + 8].try_into().map_err(|_| {
772                SerializationError::FormatError("Invalid metric value".to_string())
773            })?);
774            checkpoint.metrics.insert(key, value);
775            offset += 8;
776        }
777
778        // Read timestamp
779        if data.len() >= offset + 8 {
780            checkpoint.timestamp =
781                u64::from_le_bytes(data[offset..offset + 8].try_into().map_err(|_| {
782                    SerializationError::FormatError("Invalid timestamp".to_string())
783                })?);
784        }
785
786        Ok(checkpoint)
787    }
788
789    fn expected_type_id() -> &'static str {
790        "model_checkpoint"
791    }
792}
793
794/// Tensor serialization utilities
795/// テンソルシリアライゼーションユーティリティ
796impl<T: Float + 'static> Saveable for Tensor<T> {
797    fn save_binary(&self) -> SerializationResult<Vec<u8>> {
798        let mut buffer = Vec::new();
799
800        // Serialize shape
801        let shape = self.shape();
802        buffer.extend_from_slice(&(shape.len() as u32).to_le_bytes());
803        for &dim in shape {
804            buffer.extend_from_slice(&(dim as u64).to_le_bytes());
805        }
806
807        // Serialize data
808        if let Some(data_slice) = self.data.as_slice() {
809            let byte_len = data_slice.len() * std::mem::size_of::<T>();
810            buffer.extend_from_slice(&(byte_len as u64).to_le_bytes());
811            let bytes =
812                unsafe { std::slice::from_raw_parts(data_slice.as_ptr() as *const u8, byte_len) };
813            buffer.extend_from_slice(bytes);
814        } else {
815            buffer.extend_from_slice(&(0u64).to_le_bytes());
816        }
817
818        Ok(buffer)
819    }
820
821    fn type_id(&self) -> &'static str {
822        "tensor"
823    }
824
825    fn metadata(&self) -> HashMap<String, String> {
826        self.get_metadata()
827    }
828}
829
830impl<T: Float + 'static> Loadable for Tensor<T> {
831    fn load_binary(data: &[u8]) -> SerializationResult<Self> {
832        let mut cursor = 0;
833
834        if data.len() < 4 {
835            return Err(SerializationError::FormatError(
836                "Insufficient data for tensor shape".to_string(),
837            ));
838        }
839
840        // Read shape length
841        let shape_len = u32::from_le_bytes([
842            data[cursor],
843            data[cursor + 1],
844            data[cursor + 2],
845            data[cursor + 3],
846        ]) as usize;
847        cursor += 4;
848
849        // Read shape
850        let mut shape = Vec::new();
851        for _ in 0..shape_len {
852            if cursor + 8 > data.len() {
853                return Err(SerializationError::FormatError(
854                    "Insufficient data for tensor shape".to_string(),
855                ));
856            }
857            let dim = u64::from_le_bytes([
858                data[cursor],
859                data[cursor + 1],
860                data[cursor + 2],
861                data[cursor + 3],
862                data[cursor + 4],
863                data[cursor + 5],
864                data[cursor + 6],
865                data[cursor + 7],
866            ]) as usize;
867            shape.push(dim);
868            cursor += 8;
869        }
870
871        // Read data length
872        if cursor + 8 > data.len() {
873            return Err(SerializationError::FormatError(
874                "Insufficient data for tensor data length".to_string(),
875            ));
876        }
877        let data_len = u64::from_le_bytes([
878            data[cursor],
879            data[cursor + 1],
880            data[cursor + 2],
881            data[cursor + 3],
882            data[cursor + 4],
883            data[cursor + 5],
884            data[cursor + 6],
885            data[cursor + 7],
886        ]) as usize;
887        cursor += 8;
888
889        // Read tensor data
890        if cursor + data_len > data.len() {
891            return Err(SerializationError::FormatError(
892                "Insufficient data for tensor data".to_string(),
893            ));
894        }
895
896        let expected_elements = shape.iter().product::<usize>();
897        let actual_elements = data_len / std::mem::size_of::<T>();
898
899        if actual_elements != expected_elements {
900            return Err(SerializationError::FormatError(format!(
901                "Shape/data mismatch: shape requires {} elements, data has {}",
902                expected_elements, actual_elements
903            )));
904        }
905
906        // Ensure proper alignment for T
907        let element_size = std::mem::size_of::<T>();
908        let ptr = data[cursor..cursor + data_len].as_ptr();
909
910        // Check alignment
911        if (ptr as usize) % std::mem::align_of::<T>() != 0 {
912            // If not aligned, copy to properly aligned buffer
913            let mut aligned_data = vec![0u8; data_len];
914            aligned_data.copy_from_slice(&data[cursor..cursor + data_len]);
915            let float_data = unsafe {
916                std::slice::from_raw_parts(aligned_data.as_ptr() as *const T, actual_elements)
917            };
918            return Ok(Tensor::from_vec(float_data.to_vec(), shape));
919        }
920
921        let float_data = unsafe { std::slice::from_raw_parts(ptr as *const T, actual_elements) };
922
923        Ok(Tensor::from_vec(float_data.to_vec(), shape))
924    }
925
926    fn expected_type_id() -> &'static str {
927        "tensor"
928    }
929}
930
931/// Model format detection utilities
932/// モデル形式検出ユーティリティ
933pub fn detect_format<P: AsRef<Path>>(path: P) -> SerializationResult<String> {
934    let file = File::open(path.as_ref())?;
935    let mut reader = BufReader::new(file);
936
937    // Read first 16 bytes to detect format
938    let mut magic = [0u8; 16];
939    reader.read_exact(&mut magic)?;
940
941    if &magic[0..8] == b"RUSTORCH" {
942        Ok("rustorch".to_string())
943    } else if &magic[0..4] == b"PKG\x00" {
944        Ok("pickle".to_string())
945    } else if &magic[0..8] == b"safetens" {
946        Ok("safetensors".to_string())
947    } else {
948        Err(SerializationError::FormatError(
949            "Unknown file format".to_string(),
950        ))
951    }
952}
953
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    /// A parameter added to a state dict can be looked up again by name.
    #[test]
    fn test_state_dict_creation() {
        let mut state_dict = StateDict::<f32>::new();

        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        state_dict.add_parameter("weight".to_string(), tensor);

        assert!(state_dict.get_parameter("weight").is_some());
        assert_eq!(state_dict.parameters.len(), 1);
    }

    /// A tensor survives a binary save/load round trip unchanged.
    #[test]
    fn test_tensor_save_load() {
        let tensor = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2]);

        let binary_data = tensor.save_binary().unwrap();
        let loaded_tensor = Tensor::<f32>::load_binary(&binary_data).unwrap();

        assert_eq!(tensor.shape(), loaded_tensor.shape());
        assert_eq!(tensor.data.as_slice(), loaded_tensor.data.as_slice());
    }

    /// Truncated input is reported as an error instead of panicking.
    #[test]
    fn test_tensor_load_rejects_truncated_input() {
        // Too short to even hold the dimension count.
        assert!(Tensor::<f32>::load_binary(&[0u8; 2]).is_err());

        // A valid encoding with its final payload byte removed.
        let tensor = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2]);
        let mut binary_data = tensor.save_binary().unwrap();
        binary_data.pop();
        assert!(Tensor::<f32>::load_binary(&binary_data).is_err());
    }

    /// Each known magic prefix maps to its format name; anything else is an error.
    #[test]
    fn test_format_detection() {
        let cases: [(&str, &[u8], Option<&str>); 4] = [
            ("rustorch", b"RUSTORCH", Some("rustorch")),
            ("pickle", b"PKG\x00", Some("pickle")),
            ("safetensors", b"safetens", Some("safetensors")),
            ("unknown", b"NOTAFORMAT", None),
        ];

        let dir = std::env::temp_dir();
        let pid = std::process::id();

        for &(label, magic, expected) in cases.iter() {
            // Pad to 16 bytes: detection reads that many before comparing.
            let mut contents = magic.to_vec();
            contents.resize(16, 0);

            // Process id + label keep parallel test runs from colliding.
            let path = dir.join(format!("rustorch_detect_format_{}_{}.bin", pid, label));
            std::fs::write(&path, &contents).unwrap();

            let detected = detect_format(&path);
            // Clean up before asserting so a failure does not leak the file.
            std::fs::remove_file(&path).ok();

            match expected {
                Some(name) => assert_eq!(detected.unwrap(), name),
                None => assert!(detected.is_err()),
            }
        }
    }

    /// A checkpoint keeps the epoch, step and model state it was built with.
    #[test]
    fn test_model_checkpoint() {
        let mut state_dict = StateDict::<f32>::new();
        let tensor = Tensor::from_vec(vec![1.0, 2.0], vec![2]);
        state_dict.add_parameter("test_param".to_string(), tensor);

        let checkpoint = ModelCheckpoint::new(5, 100, state_dict);

        assert_eq!(checkpoint.epoch, 5);
        assert_eq!(checkpoint.step, 100);
        assert!(checkpoint.model_state.get_parameter("test_param").is_some());
    }

    /// A tensor added to the safe-tensor container is stored under its name.
    #[test]
    fn test_safe_tensor_format() {
        let mut safe_format = SafeTensorFormat::<f32>::new();
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
        safe_format.add_tensor("test_tensor".to_string(), tensor);

        assert_eq!(safe_format.tensors.len(), 1);
        assert!(safe_format.tensors.contains_key("test_tensor"));
    }
}