1use crate::ModelError;
11use std::collections::HashMap;
12use std::path::Path;
13use yscv_tensor::Tensor;
14
/// Element types accepted in a safetensors header `dtype` field by this loader.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SafeTensorDType {
    /// 32-bit float, header spelling `"F32"`.
    F32,
    /// 16-bit half-precision float, header spelling `"F16"`.
    F16,
    /// 16-bit bfloat16, header spelling `"BF16"`.
    BF16,
    /// 32-bit signed integer, header spelling `"I32"`.
    I32,
    /// 64-bit signed integer, header spelling `"I64"`.
    I64,
    /// 8-bit unsigned integer, header spelling `"U8"`.
    U8,
    /// One-byte boolean, header spelling `"BOOL"`.
    Bool,
}
28
29impl SafeTensorDType {
30 fn element_size(self) -> usize {
32 match self {
33 Self::F32 | Self::I32 => 4,
34 Self::F16 | Self::BF16 => 2,
35 Self::I64 => 8,
36 Self::U8 | Self::Bool => 1,
37 }
38 }
39
40 fn from_str(s: &str) -> Result<Self, ModelError> {
41 match s {
42 "F32" => Ok(Self::F32),
43 "F16" => Ok(Self::F16),
44 "BF16" => Ok(Self::BF16),
45 "I32" => Ok(Self::I32),
46 "I64" => Ok(Self::I64),
47 "U8" => Ok(Self::U8),
48 "BOOL" => Ok(Self::Bool),
49 other => Err(ModelError::SafeTensorsParse {
50 message: format!("unsupported dtype: {other}"),
51 }),
52 }
53 }
54}
55
56#[derive(Debug, Clone)]
58pub struct TensorInfo {
59 pub dtype: SafeTensorDType,
60 pub shape: Vec<usize>,
61 pub data_offsets: (usize, usize),
62}
63
64pub struct SafeTensorFile {
66 tensors: HashMap<String, TensorInfo>,
68 data: Vec<u8>,
70}
71
72impl SafeTensorFile {
73 pub fn from_file(path: &Path) -> Result<Self, ModelError> {
75 let bytes = std::fs::read(path).map_err(|e| ModelError::SafeTensorsIo {
76 path: path.display().to_string(),
77 message: e.to_string(),
78 })?;
79 Self::from_bytes(&bytes)
80 }
81
82 pub fn from_bytes(bytes: &[u8]) -> Result<Self, ModelError> {
84 if bytes.len() < 8 {
85 return Err(ModelError::SafeTensorsParse {
86 message: "file too small: missing header length".into(),
87 });
88 }
89
90 let header_len = u64::from_le_bytes([
91 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
92 ]) as usize;
93
94 let header_end =
95 8usize
96 .checked_add(header_len)
97 .ok_or_else(|| ModelError::SafeTensorsParse {
98 message: "header length overflow".into(),
99 })?;
100
101 if bytes.len() < header_end {
102 return Err(ModelError::SafeTensorsParse {
103 message: format!(
104 "file too small for header: need {} bytes, have {}",
105 header_end,
106 bytes.len()
107 ),
108 });
109 }
110
111 let header_str = std::str::from_utf8(&bytes[8..header_end]).map_err(|e| {
112 ModelError::SafeTensorsParse {
113 message: format!("header is not valid UTF-8: {e}"),
114 }
115 })?;
116
117 let header_map: serde_json::Map<String, serde_json::Value> =
118 serde_json::from_str(header_str).map_err(|e| ModelError::SafeTensorsParse {
119 message: format!("invalid JSON header: {e}"),
120 })?;
121
122 let mut tensors = HashMap::new();
123 for (name, value) in &header_map {
124 if name == "__metadata__" {
125 continue;
126 }
127 let obj = value
128 .as_object()
129 .ok_or_else(|| ModelError::SafeTensorsParse {
130 message: format!("tensor entry '{name}' is not a JSON object"),
131 })?;
132
133 let dtype_str = obj.get("dtype").and_then(|v| v.as_str()).ok_or_else(|| {
134 ModelError::SafeTensorsParse {
135 message: format!("tensor '{name}' missing 'dtype' string"),
136 }
137 })?;
138
139 let dtype = SafeTensorDType::from_str(dtype_str)?;
140
141 let shape_arr = obj.get("shape").and_then(|v| v.as_array()).ok_or_else(|| {
142 ModelError::SafeTensorsParse {
143 message: format!("tensor '{name}' missing 'shape' array"),
144 }
145 })?;
146
147 let shape: Vec<usize> = shape_arr
148 .iter()
149 .map(|v| {
150 v.as_u64()
151 .map(|n| n as usize)
152 .ok_or_else(|| ModelError::SafeTensorsParse {
153 message: format!("tensor '{name}' shape contains non-integer"),
154 })
155 })
156 .collect::<Result<_, _>>()?;
157
158 let offsets_arr = obj
159 .get("data_offsets")
160 .and_then(|v| v.as_array())
161 .ok_or_else(|| ModelError::SafeTensorsParse {
162 message: format!("tensor '{name}' missing 'data_offsets' array"),
163 })?;
164
165 if offsets_arr.len() != 2 {
166 return Err(ModelError::SafeTensorsParse {
167 message: format!("tensor '{name}' data_offsets must have exactly 2 elements"),
168 });
169 }
170
171 let start = offsets_arr[0]
172 .as_u64()
173 .ok_or_else(|| ModelError::SafeTensorsParse {
174 message: format!("tensor '{name}' data_offsets[0] is not an integer"),
175 })? as usize;
176 let end = offsets_arr[1]
177 .as_u64()
178 .ok_or_else(|| ModelError::SafeTensorsParse {
179 message: format!("tensor '{name}' data_offsets[1] is not an integer"),
180 })? as usize;
181
182 tensors.insert(
183 name.clone(),
184 TensorInfo {
185 dtype,
186 shape,
187 data_offsets: (start, end),
188 },
189 );
190 }
191
192 let data = bytes[header_end..].to_vec();
193
194 Ok(Self { tensors, data })
195 }
196
197 pub fn tensor_names(&self) -> Vec<&str> {
199 self.tensors.keys().map(|s| s.as_str()).collect()
200 }
201
202 pub fn tensor_info(&self, name: &str) -> Option<TensorInfo> {
204 self.tensors.get(name).cloned()
205 }
206
207 pub fn load_tensor(&self, name: &str) -> Result<Tensor, ModelError> {
212 let info = self
213 .tensors
214 .get(name)
215 .ok_or_else(|| ModelError::SafeTensorsParse {
216 message: format!("tensor '{name}' not found"),
217 })?;
218
219 let (start, end) = info.data_offsets;
220 if end > self.data.len() || start > end {
221 return Err(ModelError::SafeTensorsParse {
222 message: format!(
223 "tensor '{name}' data_offsets [{start}, {end}) out of bounds (data len = {})",
224 self.data.len()
225 ),
226 });
227 }
228
229 let raw = &self.data[start..end];
230 let elem_size = info.dtype.element_size();
231 let expected_elements: usize = info.shape.iter().copied().product();
232 let expected_bytes = expected_elements * elem_size;
233
234 if raw.len() != expected_bytes {
235 return Err(ModelError::SafeTensorsParse {
236 message: format!(
237 "tensor '{name}' expected {expected_bytes} bytes, got {}",
238 raw.len()
239 ),
240 });
241 }
242
243 match info.dtype {
244 SafeTensorDType::F32 => {
245 let f32_data: Vec<f32> = raw
246 .chunks_exact(4)
247 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
248 .collect();
249 Ok(Tensor::from_vec(info.shape.clone(), f32_data)?)
250 }
251 SafeTensorDType::F16 => {
252 let u16_data: Vec<u16> = raw
254 .chunks_exact(2)
255 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
256 .collect();
257 let f16_tensor = Tensor::from_f16(info.shape.clone(), u16_data)?;
258 Ok(f16_tensor.to_dtype(yscv_tensor::DType::F32))
259 }
260 SafeTensorDType::BF16 => {
261 let u16_data: Vec<u16> = raw
262 .chunks_exact(2)
263 .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
264 .collect();
265 let bf16_tensor = Tensor::from_bf16(info.shape.clone(), u16_data)?;
266 Ok(bf16_tensor.to_dtype(yscv_tensor::DType::F32))
267 }
268 SafeTensorDType::I32 => {
269 let f32_data: Vec<f32> = raw
270 .chunks_exact(4)
271 .map(|chunk| {
272 i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]) as f32
273 })
274 .collect();
275 Ok(Tensor::from_vec(info.shape.clone(), f32_data)?)
276 }
277 SafeTensorDType::I64 => {
278 let f32_data: Vec<f32> = raw
279 .chunks_exact(8)
280 .map(|chunk| {
281 i64::from_le_bytes([
282 chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
283 chunk[7],
284 ]) as f32
285 })
286 .collect();
287 Ok(Tensor::from_vec(info.shape.clone(), f32_data)?)
288 }
289 SafeTensorDType::U8 | SafeTensorDType::Bool => {
290 let f32_data: Vec<f32> = raw.iter().map(|&b| b as f32).collect();
291 Ok(Tensor::from_vec(info.shape.clone(), f32_data)?)
292 }
293 }
294 }
295}
296
297pub fn load_state_dict(path: &Path) -> Result<HashMap<String, Tensor>, ModelError> {
301 let file = SafeTensorFile::from_file(path)?;
302 let mut map = HashMap::new();
303 for name in file.tensor_names() {
304 let name_owned = name.to_string();
305 let tensor = file.load_tensor(name)?;
306 map.insert(name_owned, tensor);
307 }
308 Ok(map)
309}