scirs2_io/ml_framework/converters/
huggingface.rs1#![allow(dead_code)]
3
use crate::error::{IoError, Result};
use crate::ml_framework::converters::{MLFrameworkConverter, SafeTensorsConverter};
use crate::ml_framework::types::{MLFramework, MLModel, MLTensor};
use std::fs::File;
use std::io::{BufReader, BufWriter, Write};
use std::path::Path;
9
/// Converter for a Hugging Face-style file pair: a JSON config file next to a
/// SafeTensors weights file.
///
/// The converter carries no state, so it is trivially copyable and can be
/// obtained with `HuggingFaceConverter` or `HuggingFaceConverter::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct HuggingFaceConverter;
12
13impl MLFrameworkConverter for HuggingFaceConverter {
14 fn save_model(&self, model: &MLModel, path: &Path) -> Result<()> {
15 let config_path = path.with_extension("json");
17 let weights_path = path.with_extension("safetensors");
18
19 let config = serde_json::json!({
21 "architectures": [model.metadata.architecture],
22 "model_type": "custom",
23 "torch_dtype": "float32",
24 "_name_or_path": model.metadata.model_name,
25 "transformers_version": "4.30.0",
26 "config": model.config
27 });
28
29 let config_file = File::create(&config_path).map_err(IoError::Io)?;
30 serde_json::to_writer_pretty(config_file, &config)
31 .map_err(|e| IoError::SerializationError(e.to_string()))?;
32
33 let safetensors_converter = SafeTensorsConverter;
35 safetensors_converter.save_model(model, &weights_path)
36 }
37
38 fn load_model(&self, path: &Path) -> Result<MLModel> {
39 let config_path = path.with_extension("json");
40 let weights_path = path.with_extension("safetensors");
41
42 let config_file = File::open(&config_path).map_err(IoError::Io)?;
44 let config: serde_json::Value = serde_json::from_reader(config_file)
45 .map_err(|e| IoError::SerializationError(e.to_string()))?;
46
47 let safetensors_converter = SafeTensorsConverter;
49 let mut model = safetensors_converter.load_model(&weights_path)?;
50
51 model.metadata.framework = "HuggingFace".to_string();
53 if let Some(name) = config.get("_name_or_path").and_then(|v| v.as_str()) {
54 model.metadata.model_name = Some(name.to_string());
55 }
56 if let Some(arch) = config
57 .get("architectures")
58 .and_then(|v| v.as_array())
59 .and_then(|a| a.first())
60 .and_then(|v| v.as_str())
61 {
62 model.metadata.architecture = Some(arch.to_string());
63 }
64 if let Some(hf_config) = config.get("config") {
65 model.config = serde_json::from_value(hf_config.clone())
66 .map_err(|e| IoError::SerializationError(e.to_string()))?;
67 }
68
69 Ok(model)
70 }
71
72 fn save_tensor(&self, tensor: &MLTensor, path: &Path) -> Result<()> {
73 let safetensors_converter = SafeTensorsConverter;
75 safetensors_converter.save_tensor(tensor, path)
76 }
77
78 fn load_tensor(&self, path: &Path) -> Result<MLTensor> {
79 let safetensors_converter = SafeTensorsConverter;
81 safetensors_converter.load_tensor(path)
82 }
83}