//! Weight serialization for `yscv_model` (`weights.rs`): a lightweight,
//! safetensors-like binary format for named f32 tensors.

1use std::collections::HashMap;
2use std::path::Path;
3
4use yscv_tensor::Tensor;
5
6use crate::ModelError;
7
/// Lightweight safetensors-compatible weight file format.
///
/// File format: JSON header (length-prefixed) + raw f32 data.
/// Header: `{ "tensor_name": { "shape": [d0, d1, ...], "offset": N, "length": M }, ... }`
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct TensorMeta {
    // Element count per axis; passed straight to `Tensor::from_vec` on load.
    shape: Vec<usize>,
    // Byte offset of this tensor's data within the raw data region
    // (relative to the end of the JSON header, not the file start).
    offset: usize,
    // Size of this tensor's data in bytes (4 bytes per f32 element).
    length: usize,
}
18
19/// Saves a named set of tensors to a binary file (safetensors-like format).
20///
21/// Format: [8 bytes: header_len as u64 LE] [header JSON bytes] [raw f32 data].
22pub fn save_weights(path: &Path, tensors: &HashMap<String, Tensor>) -> Result<(), ModelError> {
23    let mut meta_map: HashMap<String, TensorMeta> = HashMap::new();
24    let mut raw_data: Vec<u8> = Vec::new();
25
26    let mut names: Vec<&String> = tensors.keys().collect();
27    names.sort();
28
29    for name in &names {
30        let t = &tensors[*name];
31        let offset = raw_data.len();
32        let bytes = f32_slice_to_bytes(t.data());
33        let byte_len = bytes.len();
34        raw_data.extend_from_slice(&bytes);
35        meta_map.insert(
36            (*name).clone(),
37            TensorMeta {
38                shape: t.shape().to_vec(),
39                offset,
40                length: byte_len,
41            },
42        );
43    }
44
45    let header_json =
46        serde_json::to_string(&meta_map).map_err(|e| ModelError::CheckpointSerialization {
47            message: e.to_string(),
48        })?;
49    let header_bytes = header_json.as_bytes();
50    let header_len = header_bytes.len() as u64;
51
52    let mut file_data = Vec::new();
53    file_data.extend_from_slice(&header_len.to_le_bytes());
54    file_data.extend_from_slice(header_bytes);
55    file_data.extend_from_slice(&raw_data);
56
57    std::fs::write(path, &file_data).map_err(|e| ModelError::DatasetLoadIo {
58        path: path.display().to_string(),
59        message: e.to_string(),
60    })
61}
62
63/// Loads named tensors from a binary weight file.
64pub fn load_weights(path: &Path) -> Result<HashMap<String, Tensor>, ModelError> {
65    let file_data = std::fs::read(path).map_err(|e| ModelError::DatasetLoadIo {
66        path: path.display().to_string(),
67        message: e.to_string(),
68    })?;
69
70    if file_data.len() < 8 {
71        return Err(ModelError::CheckpointSerialization {
72            message: "weight file too small".into(),
73        });
74    }
75
76    let header_len = u64::from_le_bytes(file_data[..8].try_into().expect("8-byte slice")) as usize;
77    if file_data.len() < 8 + header_len {
78        return Err(ModelError::CheckpointSerialization {
79            message: "weight file header truncated".into(),
80        });
81    }
82
83    let header_str = std::str::from_utf8(&file_data[8..8 + header_len]).map_err(|e| {
84        ModelError::CheckpointSerialization {
85            message: e.to_string(),
86        }
87    })?;
88    let meta_map: HashMap<String, TensorMeta> =
89        serde_json::from_str(header_str).map_err(|e| ModelError::CheckpointSerialization {
90            message: e.to_string(),
91        })?;
92
93    let data_start = 8 + header_len;
94    let raw = &file_data[data_start..];
95
96    let mut tensors = HashMap::new();
97    for (name, meta) in &meta_map {
98        if meta.offset + meta.length > raw.len() {
99            return Err(ModelError::CheckpointSerialization {
100                message: format!("tensor '{name}' data out of bounds"),
101            });
102        }
103        let bytes = &raw[meta.offset..meta.offset + meta.length];
104        let f32_data = bytes_to_f32_vec(bytes);
105        let t = Tensor::from_vec(meta.shape.clone(), f32_data)?;
106        tensors.insert(name.clone(), t);
107    }
108
109    Ok(tensors)
110}
111
112/// Lists tensor names and shapes from a weight file without loading data.
113pub fn inspect_weights(path: &Path) -> Result<HashMap<String, Vec<usize>>, ModelError> {
114    let file_data = std::fs::read(path).map_err(|e| ModelError::DatasetLoadIo {
115        path: path.display().to_string(),
116        message: e.to_string(),
117    })?;
118    if file_data.len() < 8 {
119        return Err(ModelError::CheckpointSerialization {
120            message: "weight file too small".into(),
121        });
122    }
123    let header_len = u64::from_le_bytes(file_data[..8].try_into().expect("8-byte slice")) as usize;
124    let header_str = std::str::from_utf8(&file_data[8..8 + header_len]).map_err(|e| {
125        ModelError::CheckpointSerialization {
126            message: e.to_string(),
127        }
128    })?;
129    let meta_map: HashMap<String, TensorMeta> =
130        serde_json::from_str(header_str).map_err(|e| ModelError::CheckpointSerialization {
131            message: e.to_string(),
132        })?;
133    Ok(meta_map.into_iter().map(|(k, v)| (k, v.shape)).collect())
134}
135
/// Serializes a slice of f32 values to their little-endian byte representation.
fn f32_slice_to_bytes(data: &[f32]) -> Vec<u8> {
    // Each f32 contributes its 4 LE bytes, in element order.
    data.iter().flat_map(|v| v.to_le_bytes()).collect()
}
143
/// Deserializes little-endian bytes back into f32 values.
///
/// # Panics
///
/// Panics if the slice length is not a multiple of 4 — an internal
/// invariant; callers validate lengths before reaching this helper.
fn bytes_to_f32_vec(data: &[u8]) -> Vec<f32> {
    assert!(
        data.len().is_multiple_of(4),
        "byte slice length must be multiple of 4"
    );
    let mut values = Vec::with_capacity(data.len() / 4);
    for chunk in data.chunks_exact(4) {
        let bytes: [u8; 4] = chunk.try_into().expect("chunks_exact yields 4-byte chunks");
        values.push(f32::from_le_bytes(bytes));
    }
    values
}