Skip to main content

yscv_model/
safetensors.rs

1//! SafeTensors file parser and pretrained weight loader.
2//!
3//! The [SafeTensors](https://huggingface.co/docs/safetensors/) format stores
4//! tensors in a simple binary layout:
5//!
6//! 1. 8 bytes — little-endian `u64` header length
7//! 2. N bytes — UTF-8 JSON header mapping tensor names to metadata
8//! 3. Remaining bytes — contiguous raw tensor data
9
10use crate::ModelError;
11use std::collections::HashMap;
12use std::path::Path;
13use yscv_tensor::Tensor;
14
15// ── Public types ────────────────────────────────────────────────────
16
/// Supported element types in a SafeTensors file.
///
/// Each variant corresponds to one `dtype` string accepted in the JSON header.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SafeTensorDType {
    /// 32-bit float, header string `"F32"` (4 bytes per element).
    F32,
    /// 16-bit half-precision float, header string `"F16"` (2 bytes per element).
    F16,
    /// 16-bit bfloat16, header string `"BF16"` (2 bytes per element).
    BF16,
    /// 32-bit signed integer, header string `"I32"` (4 bytes per element).
    I32,
    /// 64-bit signed integer, header string `"I64"` (8 bytes per element).
    I64,
    /// 8-bit unsigned integer, header string `"U8"` (1 byte per element).
    U8,
    /// Boolean, header string `"BOOL"` (1 byte per element).
    Bool,
}
28
29impl SafeTensorDType {
30    /// Number of bytes per element.
31    fn element_size(self) -> usize {
32        match self {
33            Self::F32 | Self::I32 => 4,
34            Self::F16 | Self::BF16 => 2,
35            Self::I64 => 8,
36            Self::U8 | Self::Bool => 1,
37        }
38    }
39
40    fn from_str(s: &str) -> Result<Self, ModelError> {
41        match s {
42            "F32" => Ok(Self::F32),
43            "F16" => Ok(Self::F16),
44            "BF16" => Ok(Self::BF16),
45            "I32" => Ok(Self::I32),
46            "I64" => Ok(Self::I64),
47            "U8" => Ok(Self::U8),
48            "BOOL" => Ok(Self::Bool),
49            other => Err(ModelError::SafeTensorsParse {
50                message: format!("unsupported dtype: {other}"),
51            }),
52        }
53    }
54}
55
56/// Per-tensor metadata extracted from the SafeTensors JSON header.
57#[derive(Debug, Clone)]
58pub struct TensorInfo {
59    pub dtype: SafeTensorDType,
60    pub shape: Vec<usize>,
61    pub data_offsets: (usize, usize),
62}
63
64/// A parsed SafeTensors file backed by an in-memory byte buffer.
65pub struct SafeTensorFile {
66    /// Parsed tensor metadata, keyed by name.
67    tensors: HashMap<String, TensorInfo>,
68    /// The raw data section (everything after the JSON header).
69    data: Vec<u8>,
70}
71
72impl SafeTensorFile {
73    /// Parse a SafeTensors file from disk.
74    pub fn from_file(path: &Path) -> Result<Self, ModelError> {
75        let bytes = std::fs::read(path).map_err(|e| ModelError::SafeTensorsIo {
76            path: path.display().to_string(),
77            message: e.to_string(),
78        })?;
79        Self::from_bytes(&bytes)
80    }
81
82    /// Parse a SafeTensors file from an in-memory byte slice.
83    pub fn from_bytes(bytes: &[u8]) -> Result<Self, ModelError> {
84        if bytes.len() < 8 {
85            return Err(ModelError::SafeTensorsParse {
86                message: "file too small: missing header length".into(),
87            });
88        }
89
90        let header_len = u64::from_le_bytes([
91            bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
92        ]) as usize;
93
94        let header_end =
95            8usize
96                .checked_add(header_len)
97                .ok_or_else(|| ModelError::SafeTensorsParse {
98                    message: "header length overflow".into(),
99                })?;
100
101        if bytes.len() < header_end {
102            return Err(ModelError::SafeTensorsParse {
103                message: format!(
104                    "file too small for header: need {} bytes, have {}",
105                    header_end,
106                    bytes.len()
107                ),
108            });
109        }
110
111        let header_str = std::str::from_utf8(&bytes[8..header_end]).map_err(|e| {
112            ModelError::SafeTensorsParse {
113                message: format!("header is not valid UTF-8: {e}"),
114            }
115        })?;
116
117        let header_map: serde_json::Map<String, serde_json::Value> =
118            serde_json::from_str(header_str).map_err(|e| ModelError::SafeTensorsParse {
119                message: format!("invalid JSON header: {e}"),
120            })?;
121
122        let mut tensors = HashMap::new();
123        for (name, value) in &header_map {
124            if name == "__metadata__" {
125                continue;
126            }
127            let obj = value
128                .as_object()
129                .ok_or_else(|| ModelError::SafeTensorsParse {
130                    message: format!("tensor entry '{name}' is not a JSON object"),
131                })?;
132
133            let dtype_str = obj.get("dtype").and_then(|v| v.as_str()).ok_or_else(|| {
134                ModelError::SafeTensorsParse {
135                    message: format!("tensor '{name}' missing 'dtype' string"),
136                }
137            })?;
138
139            let dtype = SafeTensorDType::from_str(dtype_str)?;
140
141            let shape_arr = obj.get("shape").and_then(|v| v.as_array()).ok_or_else(|| {
142                ModelError::SafeTensorsParse {
143                    message: format!("tensor '{name}' missing 'shape' array"),
144                }
145            })?;
146
147            let shape: Vec<usize> = shape_arr
148                .iter()
149                .map(|v| {
150                    v.as_u64()
151                        .map(|n| n as usize)
152                        .ok_or_else(|| ModelError::SafeTensorsParse {
153                            message: format!("tensor '{name}' shape contains non-integer"),
154                        })
155                })
156                .collect::<Result<_, _>>()?;
157
158            let offsets_arr = obj
159                .get("data_offsets")
160                .and_then(|v| v.as_array())
161                .ok_or_else(|| ModelError::SafeTensorsParse {
162                    message: format!("tensor '{name}' missing 'data_offsets' array"),
163                })?;
164
165            if offsets_arr.len() != 2 {
166                return Err(ModelError::SafeTensorsParse {
167                    message: format!("tensor '{name}' data_offsets must have exactly 2 elements"),
168                });
169            }
170
171            let start = offsets_arr[0]
172                .as_u64()
173                .ok_or_else(|| ModelError::SafeTensorsParse {
174                    message: format!("tensor '{name}' data_offsets[0] is not an integer"),
175                })? as usize;
176            let end = offsets_arr[1]
177                .as_u64()
178                .ok_or_else(|| ModelError::SafeTensorsParse {
179                    message: format!("tensor '{name}' data_offsets[1] is not an integer"),
180                })? as usize;
181
182            tensors.insert(
183                name.clone(),
184                TensorInfo {
185                    dtype,
186                    shape,
187                    data_offsets: (start, end),
188                },
189            );
190        }
191
192        let data = bytes[header_end..].to_vec();
193
194        Ok(Self { tensors, data })
195    }
196
197    /// Returns a list of all tensor names in the file.
198    pub fn tensor_names(&self) -> Vec<&str> {
199        self.tensors.keys().map(|s| s.as_str()).collect()
200    }
201
202    /// Returns metadata for a tensor by name.
203    pub fn tensor_info(&self, name: &str) -> Option<TensorInfo> {
204        self.tensors.get(name).cloned()
205    }
206
207    /// Load a single tensor by name, converting to F32 if necessary.
208    ///
209    /// F16 and BF16 data are converted to F32. I32, I64, U8, and Bool are
210    /// converted to F32 by casting each element.
211    pub fn load_tensor(&self, name: &str) -> Result<Tensor, ModelError> {
212        let info = self
213            .tensors
214            .get(name)
215            .ok_or_else(|| ModelError::SafeTensorsParse {
216                message: format!("tensor '{name}' not found"),
217            })?;
218
219        let (start, end) = info.data_offsets;
220        if end > self.data.len() || start > end {
221            return Err(ModelError::SafeTensorsParse {
222                message: format!(
223                    "tensor '{name}' data_offsets [{start}, {end}) out of bounds (data len = {})",
224                    self.data.len()
225                ),
226            });
227        }
228
229        let raw = &self.data[start..end];
230        let elem_size = info.dtype.element_size();
231        let expected_elements: usize = info.shape.iter().copied().product();
232        let expected_bytes = expected_elements * elem_size;
233
234        if raw.len() != expected_bytes {
235            return Err(ModelError::SafeTensorsParse {
236                message: format!(
237                    "tensor '{name}' expected {expected_bytes} bytes, got {}",
238                    raw.len()
239                ),
240            });
241        }
242
243        match info.dtype {
244            SafeTensorDType::F32 => {
245                let f32_data: Vec<f32> = raw
246                    .chunks_exact(4)
247                    .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
248                    .collect();
249                Ok(Tensor::from_vec(info.shape.clone(), f32_data)?)
250            }
251            SafeTensorDType::F16 => {
252                // Load as u16 bit patterns, then use Tensor's built-in to_dtype for F16->F32
253                let u16_data: Vec<u16> = raw
254                    .chunks_exact(2)
255                    .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
256                    .collect();
257                let f16_tensor = Tensor::from_f16(info.shape.clone(), u16_data)?;
258                Ok(f16_tensor.to_dtype(yscv_tensor::DType::F32))
259            }
260            SafeTensorDType::BF16 => {
261                let u16_data: Vec<u16> = raw
262                    .chunks_exact(2)
263                    .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]]))
264                    .collect();
265                let bf16_tensor = Tensor::from_bf16(info.shape.clone(), u16_data)?;
266                Ok(bf16_tensor.to_dtype(yscv_tensor::DType::F32))
267            }
268            SafeTensorDType::I32 => {
269                let f32_data: Vec<f32> = raw
270                    .chunks_exact(4)
271                    .map(|chunk| {
272                        i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]) as f32
273                    })
274                    .collect();
275                Ok(Tensor::from_vec(info.shape.clone(), f32_data)?)
276            }
277            SafeTensorDType::I64 => {
278                let f32_data: Vec<f32> = raw
279                    .chunks_exact(8)
280                    .map(|chunk| {
281                        i64::from_le_bytes([
282                            chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
283                            chunk[7],
284                        ]) as f32
285                    })
286                    .collect();
287                Ok(Tensor::from_vec(info.shape.clone(), f32_data)?)
288            }
289            SafeTensorDType::U8 | SafeTensorDType::Bool => {
290                let f32_data: Vec<f32> = raw.iter().map(|&b| b as f32).collect();
291                Ok(Tensor::from_vec(info.shape.clone(), f32_data)?)
292            }
293        }
294    }
295}
296
297/// Load all tensors from a SafeTensors file into a name-to-tensor map.
298///
299/// All tensors are converted to F32.
300pub fn load_state_dict(path: &Path) -> Result<HashMap<String, Tensor>, ModelError> {
301    let file = SafeTensorFile::from_file(path)?;
302    let mut map = HashMap::new();
303    for name in file.tensor_names() {
304        let name_owned = name.to_string();
305        let tensor = file.load_tensor(name)?;
306        map.insert(name_owned, tensor);
307    }
308    Ok(map)
309}