scirs2_io/ml_framework/converters/pytorch.rs

1//! PyTorch format converter
2#![allow(dead_code)]
3
4use crate::error::{IoError, Result};
5use crate::ml_framework::converters::MLFrameworkConverter;
6use crate::ml_framework::types::{MLFramework, MLModel, MLTensor};
7use crate::ml_framework::utils::{python_dict_to_tensor, tensor_to_python_dict};
8use std::collections::HashMap;
9use std::fs::File;
10use std::path::Path;
11
/// PyTorch format converter.
///
/// Stateless converter implementing `MLFrameworkConverter` for models and
/// tensors stored in a simplified, JSON-based PyTorch-style layout.
#[derive(Debug, Clone, Copy, Default)]
pub struct PyTorchConverter;
14
15impl MLFrameworkConverter for PyTorchConverter {
16    fn save_model(&self, model: &MLModel, path: &Path) -> Result<()> {
17        // Save in a PyTorch-compatible format (simplified)
18        let mut state_dict = HashMap::new();
19
20        for (name, tensor) in &model.weights {
21            state_dict.insert(name.clone(), tensor_to_python_dict(tensor)?);
22        }
23
24        let model_dict = serde_json::json!({
25            "state_dict": state_dict,
26            "metadata": model.metadata,
27            "config": model.config,
28        });
29
30        let file = File::create(path).map_err(IoError::Io)?;
31        serde_json::to_writer_pretty(file, &model_dict)
32            .map_err(|e| IoError::SerializationError(e.to_string()))
33    }
34
35    fn load_model(&self, path: &Path) -> Result<MLModel> {
36        let file = File::open(path).map_err(IoError::Io)?;
37        let model_dict: serde_json::Value = serde_json::from_reader(file)
38            .map_err(|e| IoError::SerializationError(e.to_string()))?;
39
40        let mut model = MLModel::new(MLFramework::PyTorch);
41
42        if let Some(metadata) = model_dict.get("metadata") {
43            model.metadata = serde_json::from_value(metadata.clone())
44                .map_err(|e| IoError::SerializationError(e.to_string()))?;
45        }
46
47        if let Some(config) = model_dict.get("config") {
48            model.config = serde_json::from_value(config.clone())
49                .map_err(|e| IoError::SerializationError(e.to_string()))?;
50        }
51
52        if let Some(state_dict) = model_dict.get("state_dict").and_then(|v| v.as_object()) {
53            for (name, tensor_data) in state_dict {
54                let tensor = python_dict_to_tensor(tensor_data)?;
55                model.weights.insert(name.clone(), tensor);
56            }
57        }
58
59        Ok(model)
60    }
61
62    fn save_tensor(&self, tensor: &MLTensor, path: &Path) -> Result<()> {
63        let tensor_dict = tensor_to_python_dict(tensor)?;
64        let file = File::create(path).map_err(IoError::Io)?;
65        serde_json::to_writer_pretty(file, &tensor_dict)
66            .map_err(|e| IoError::SerializationError(e.to_string()))
67    }
68
69    fn load_tensor(&self, path: &Path) -> Result<MLTensor> {
70        let file = File::open(path).map_err(IoError::Io)?;
71        let tensor_dict: serde_json::Value = serde_json::from_reader(file)
72            .map_err(|e| IoError::SerializationError(e.to_string()))?;
73        python_dict_to_tensor(&tensor_dict)
74    }
75}