1use crate::{backend::Backend, dtype::DType, error::ZyxError, shape::Shape, tensor::Tensor};
2use alloc::{string::String, vec::Vec};
3use core::fmt::Write as CoreFmtWrite;
4use std::fs::File;
5use std::io::{Read, Write};
6use std::path::Path;
7
/// Saving and loading of a collection of tensors to and from a file.
///
/// Both methods consume `self`. The blanket implementation in this file covers
/// every `IntoIterator` that yields `&mut Tensor<B>`.
pub trait ModuleIO {
    /// Writes the tensors of `self` to the file at `path`.
    ///
    /// # Errors
    ///
    /// Returns a [`ZyxError`] if saving fails.
    fn save(self, path: impl AsRef<Path>) -> Result<(), ZyxError>;
    /// Replaces the tensors of `self` with the tensors read from the file at `path`.
    ///
    /// # Errors
    ///
    /// Returns a [`ZyxError`] if loading fails.
    fn load(self, path: impl AsRef<Path>) -> Result<(), ZyxError>;
}
16
17impl<'a, B: Backend + 'a, Tensors: IntoIterator<Item = &'a mut Tensor<B>>> ModuleIO for Tensors {
18 fn save(self, path: impl AsRef<Path>) -> Result<(), ZyxError> {
19 save(self.into_iter().map(|x| &*x), path)
20 }
21
22 fn load(self, path: impl AsRef<Path>) -> Result<(), ZyxError> {
23 let targets: Vec<&mut Tensor<B>> = self.into_iter().collect();
24 let dev = targets[0].backend();
25 let tensors = load(dev, path)?;
26 for (x, y) in targets.into_iter().zip(tensors) {
27 *x = y;
28 }
29 Ok(())
30 }
31}
32
33pub fn save<'a, B: Backend + 'a>(
38 tensors: impl IntoIterator<Item = &'a Tensor<B>>,
39 path: impl AsRef<Path>,
40) -> Result<(), ZyxError> {
41 let mut f = File::create(path)?;
42 let mut header = String::from("{");
43 let mut begin = 0;
44 let tensors: Vec<&Tensor<B>> = tensors.into_iter().collect();
45 for tensor in &tensors {
46 let dtype = tensor.dtype();
47 write!(header, "\"{}\":{{", tensor.id()).unwrap();
51 write!(header, "\"dtype\":\"{}\",", dtype.safetensors()).unwrap();
53 write!(header, "\"shape\":{},", tensor.shape().safetensors()).unwrap();
54 let size = tensor.numel() * dtype.byte_size();
55 write!(header, "\"data_offsets\":[{},{}]", begin, begin + size).unwrap();
56 begin += size;
57 write!(header, "}},").unwrap();
58 }
59 header.pop();
60 write!(header, "}}").unwrap();
61 let header_bytes = header.as_bytes();
62 f.write_all(&(header_bytes.len() as u64).to_le_bytes())?;
63 f.write_all(header_bytes)?;
64 for tensor in tensors {
65 match tensor.dtype() {
66 DType::F32 => {
67 let vec = tensor.to_vec::<f32>()?;
68 let mut bytes: Vec<u8> = Vec::with_capacity(vec.len() * 4);
69 for x in vec {
70 bytes.extend(x.to_le_bytes());
71 }
72 f.write_all(&bytes)?;
73 }
74 DType::F64 => {
75 let vec = tensor.to_vec::<f64>()?;
76 let mut bytes: Vec<u8> = Vec::with_capacity(vec.len() * 4);
77 for x in vec {
78 bytes.extend(x.to_le_bytes());
79 }
80 f.write_all(&bytes)?;
81 }
82 DType::I32 => {
83 let vec = tensor.to_vec::<i32>().unwrap();
84 let mut bytes: Vec<u8> = Vec::with_capacity(vec.len() * 4);
85 for x in vec {
86 bytes.extend(x.to_le_bytes());
87 }
88 f.write_all(&bytes)?;
89 }
90 };
91 }
92 Ok(())
93}
94
95pub fn load<B: Backend>(dev: B, path: impl AsRef<Path>) -> Result<Vec<Tensor<B>>, ZyxError> {
99 let mut f = File::open(path)?;
100 let mut header_len = [0u8; 8];
101 f.read_exact(&mut header_len)?;
102 let mut header = alloc::vec![0u8; usize::try_from(u64::from_le_bytes(header_len)).unwrap()];
103 f.read_exact(&mut header)?;
104 let header = core::str::from_utf8(&header)
105 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
106 let mut text = alloc::string::String::with_capacity(10);
107 let mut begin_str = false;
108 let mut i = 0;
109 let mut tensors = Vec::new();
110 let mut dtype = DType::F32;
111 let mut shape: Shape = [1].into();
112 for x in header.chars() {
113 if ['"', '[', ']'].contains(&x) {
114 if begin_str {
115 if i % 7 == 0 {
117 } else if i % 7 == 2 {
119 dtype = DType::from_safetensors(&text)?;
120 } else if i % 7 == 4 {
121 shape = Shape::from_safetensors(&text)?;
122 } else if i % 7 == 6 {
123 let offsets = text
126 .split(',')
127 .map(|offset| {
128 offset.parse::<usize>().map_err(|err| {
129 ZyxError::ParseError(alloc::format!(
130 "Could not parse safetensors offset: {err}"
131 ))
132 })
133 })
134 .collect::<Result<Vec<usize>, ZyxError>>()?;
135 if offsets[tensors.len() + 1] != shape.numel() * dtype.byte_size() {
137 return Err(ZyxError::ParseError(
138 "Safetensors shapes and offsets are incorrect.".into(),
139 ));
140 }
141 let mut buf = alloc::vec![0u8; shape.numel()*dtype.byte_size()];
142 f.read_exact(&mut buf)?;
143 tensors.push(match dtype {
144 DType::F32 => {
145 let vec: Vec<f32> = buf
146 .chunks_exact(dtype.byte_size())
147 .map(|x| f32::from_le_bytes([x[0], x[1], x[2], x[3]]))
148 .collect();
149 dev.tensor(vec)?.reshape(&shape)
150 }
151 DType::F64 => {
152 let vec: Vec<f64> = buf
153 .chunks_exact(dtype.byte_size())
154 .map(|x| {
155 f64::from_le_bytes([
156 x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7],
157 ])
158 })
159 .collect();
160 dev.tensor(vec)?.reshape(&shape)
161 }
162 DType::I32 => {
163 let vec: Vec<i32> = buf
164 .chunks_exact(dtype.byte_size())
165 .map(|x| i32::from_le_bytes([x[0], x[1], x[2], x[3]]))
166 .collect();
167 dev.tensor(vec)?.reshape(&shape)
168 }
169 });
170 }
171 i += 1;
172 text.clear();
173 begin_str = false;
174 } else {
175 text.clear();
176 begin_str = true;
177 }
178 } else {
179 text.push(x);
180 }
181 }
182 Ok(tensors)
183}