1use std::collections::HashMap;
2use std::path::Path;
3
4use yscv_tensor::Tensor;
5
6use crate::ModelError;
7
/// Header entry describing where one tensor's raw f32 payload lives
/// inside a weight file and how to reshape it on load.
///
/// Serialized as JSON in the file header by `save_weights` and parsed
/// back by `load_weights` / `inspect_weights`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct TensorMeta {
    // Tensor dimensions, copied from `Tensor::shape()` at save time.
    shape: Vec<usize>,
    // Byte offset of this tensor's data, relative to the start of the raw
    // data section (i.e. after the JSON header), NOT the start of the file.
    offset: usize,
    // Size of this tensor's data in bytes (4 bytes per f32 element).
    length: usize,
}
18
19pub fn save_weights(path: &Path, tensors: &HashMap<String, Tensor>) -> Result<(), ModelError> {
23 let mut meta_map: HashMap<String, TensorMeta> = HashMap::new();
24 let mut raw_data: Vec<u8> = Vec::new();
25
26 let mut names: Vec<&String> = tensors.keys().collect();
27 names.sort();
28
29 for name in &names {
30 let t = &tensors[*name];
31 let offset = raw_data.len();
32 let bytes = f32_slice_to_bytes(t.data());
33 let byte_len = bytes.len();
34 raw_data.extend_from_slice(&bytes);
35 meta_map.insert(
36 (*name).clone(),
37 TensorMeta {
38 shape: t.shape().to_vec(),
39 offset,
40 length: byte_len,
41 },
42 );
43 }
44
45 let header_json =
46 serde_json::to_string(&meta_map).map_err(|e| ModelError::CheckpointSerialization {
47 message: e.to_string(),
48 })?;
49 let header_bytes = header_json.as_bytes();
50 let header_len = header_bytes.len() as u64;
51
52 let mut file_data = Vec::new();
53 file_data.extend_from_slice(&header_len.to_le_bytes());
54 file_data.extend_from_slice(header_bytes);
55 file_data.extend_from_slice(&raw_data);
56
57 std::fs::write(path, &file_data).map_err(|e| ModelError::DatasetLoadIo {
58 path: path.display().to_string(),
59 message: e.to_string(),
60 })
61}
62
63pub fn load_weights(path: &Path) -> Result<HashMap<String, Tensor>, ModelError> {
65 let file_data = std::fs::read(path).map_err(|e| ModelError::DatasetLoadIo {
66 path: path.display().to_string(),
67 message: e.to_string(),
68 })?;
69
70 if file_data.len() < 8 {
71 return Err(ModelError::CheckpointSerialization {
72 message: "weight file too small".into(),
73 });
74 }
75
76 let header_len = u64::from_le_bytes(file_data[..8].try_into().expect("8-byte slice")) as usize;
77 if file_data.len() < 8 + header_len {
78 return Err(ModelError::CheckpointSerialization {
79 message: "weight file header truncated".into(),
80 });
81 }
82
83 let header_str = std::str::from_utf8(&file_data[8..8 + header_len]).map_err(|e| {
84 ModelError::CheckpointSerialization {
85 message: e.to_string(),
86 }
87 })?;
88 let meta_map: HashMap<String, TensorMeta> =
89 serde_json::from_str(header_str).map_err(|e| ModelError::CheckpointSerialization {
90 message: e.to_string(),
91 })?;
92
93 let data_start = 8 + header_len;
94 let raw = &file_data[data_start..];
95
96 let mut tensors = HashMap::new();
97 for (name, meta) in &meta_map {
98 if meta.offset + meta.length > raw.len() {
99 return Err(ModelError::CheckpointSerialization {
100 message: format!("tensor '{name}' data out of bounds"),
101 });
102 }
103 let bytes = &raw[meta.offset..meta.offset + meta.length];
104 let f32_data = bytes_to_f32_vec(bytes);
105 let t = Tensor::from_vec(meta.shape.clone(), f32_data)?;
106 tensors.insert(name.clone(), t);
107 }
108
109 Ok(tensors)
110}
111
112pub fn inspect_weights(path: &Path) -> Result<HashMap<String, Vec<usize>>, ModelError> {
114 let file_data = std::fs::read(path).map_err(|e| ModelError::DatasetLoadIo {
115 path: path.display().to_string(),
116 message: e.to_string(),
117 })?;
118 if file_data.len() < 8 {
119 return Err(ModelError::CheckpointSerialization {
120 message: "weight file too small".into(),
121 });
122 }
123 let header_len = u64::from_le_bytes(file_data[..8].try_into().expect("8-byte slice")) as usize;
124 let header_str = std::str::from_utf8(&file_data[8..8 + header_len]).map_err(|e| {
125 ModelError::CheckpointSerialization {
126 message: e.to_string(),
127 }
128 })?;
129 let meta_map: HashMap<String, TensorMeta> =
130 serde_json::from_str(header_str).map_err(|e| ModelError::CheckpointSerialization {
131 message: e.to_string(),
132 })?;
133 Ok(meta_map.into_iter().map(|(k, v)| (k, v.shape)).collect())
134}
135
/// Encodes a slice of f32 values as their little-endian byte stream
/// (4 bytes per value, in slice order).
fn f32_slice_to_bytes(data: &[f32]) -> Vec<u8> {
    // `to_le_bytes` yields a [u8; 4] per value; flat_map flattens them into
    // one contiguous byte vector.
    data.iter().flat_map(|v| v.to_le_bytes()).collect()
}
143
/// Decodes a little-endian byte stream back into f32 values
/// (the inverse of `f32_slice_to_bytes`).
///
/// # Panics
/// Panics if `data.len()` is not a multiple of 4; callers are expected to
/// validate lengths before decoding.
fn bytes_to_f32_vec(data: &[u8]) -> Vec<f32> {
    assert!(
        data.len() % 4 == 0,
        "byte slice length must be multiple of 4"
    );
    let mut values = Vec::with_capacity(data.len() / 4);
    for word in data.chunks_exact(4) {
        // chunks_exact guarantees every `word` is exactly 4 bytes long.
        let bytes: [u8; 4] = word.try_into().expect("chunk is 4 bytes");
        values.push(f32::from_le_bytes(bytes));
    }
    values
}