scirs2_io/ml_framework/converters/
huggingface.rs

1//! HuggingFace format converter
2#![allow(dead_code)]
3
4use crate::error::{IoError, Result};
5use crate::ml_framework::converters::{MLFrameworkConverter, SafeTensorsConverter};
6use crate::ml_framework::types::{MLFramework, MLModel, MLTensor};
7use std::fs::File;
8use std::path::Path;
9
/// HuggingFace format converter.
///
/// A stateless unit struct; its [`MLFrameworkConverter`] implementation stores a
/// model as a `.json` config file next to a `.safetensors` weights file.
#[derive(Debug, Clone, Copy, Default)]
pub struct HuggingFaceConverter;
12
13impl MLFrameworkConverter for HuggingFaceConverter {
14    fn save_model(&self, model: &MLModel, path: &Path) -> Result<()> {
15        // HuggingFace models typically use safetensors + config.json
16        let config_path = path.with_extension("json");
17        let weights_path = path.with_extension("safetensors");
18
19        // Save config
20        let config = serde_json::json!({
21            "architectures": [model.metadata.architecture],
22            "model_type": "custom",
23            "torch_dtype": "float32",
24            "_name_or_path": model.metadata.model_name,
25            "transformers_version": "4.30.0",
26            "config": model.config
27        });
28
29        let config_file = File::create(&config_path).map_err(IoError::Io)?;
30        serde_json::to_writer_pretty(config_file, &config)
31            .map_err(|e| IoError::SerializationError(e.to_string()))?;
32
33        // Save weights in SafeTensors format
34        let safetensors_converter = SafeTensorsConverter;
35        safetensors_converter.save_model(model, &weights_path)
36    }
37
38    fn load_model(&self, path: &Path) -> Result<MLModel> {
39        let config_path = path.with_extension("json");
40        let weights_path = path.with_extension("safetensors");
41
42        // Load config
43        let config_file = File::open(&config_path).map_err(IoError::Io)?;
44        let config: serde_json::Value = serde_json::from_reader(config_file)
45            .map_err(|e| IoError::SerializationError(e.to_string()))?;
46
47        // Load weights
48        let safetensors_converter = SafeTensorsConverter;
49        let mut model = safetensors_converter.load_model(&weights_path)?;
50
51        // Update with HuggingFace-specific metadata
52        model.metadata.framework = "HuggingFace".to_string();
53        if let Some(name) = config.get("_name_or_path").and_then(|v| v.as_str()) {
54            model.metadata.model_name = Some(name.to_string());
55        }
56        if let Some(arch) = config
57            .get("architectures")
58            .and_then(|v| v.as_array())
59            .and_then(|a| a.first())
60            .and_then(|v| v.as_str())
61        {
62            model.metadata.architecture = Some(arch.to_string());
63        }
64        if let Some(hf_config) = config.get("config") {
65            model.config = serde_json::from_value(hf_config.clone())
66                .map_err(|e| IoError::SerializationError(e.to_string()))?;
67        }
68
69        Ok(model)
70    }
71
72    fn save_tensor(&self, tensor: &MLTensor, path: &Path) -> Result<()> {
73        // Use SafeTensors format for individual tensors
74        let safetensors_converter = SafeTensorsConverter;
75        safetensors_converter.save_tensor(tensor, path)
76    }
77
78    fn load_tensor(&self, path: &Path) -> Result<MLTensor> {
79        // Use SafeTensors format for individual tensors
80        let safetensors_converter = SafeTensorsConverter;
81        safetensors_converter.load_tensor(path)
82    }
83}