zyx_core/
io.rs

1use crate::{backend::Backend, dtype::DType, error::ZyxError, shape::Shape, tensor::Tensor};
2use alloc::{string::String, vec::Vec};
3use core::fmt::Write as CoreFmtWrite;
4use std::fs::File;
5use std::io::{Read, Write};
6use std::path::Path;
7
/// Saving and loading of whole modules (collections of tensors) to and from a file.
///
/// This trait is implemented automatically for all modules that implement
/// `IntoIterator<Item = &mut Tensor>`.
pub trait ModuleIO {
    /// Save self into the file at `path`.
    ///
    /// # Errors
    /// Returns [`ZyxError`] if the tensors could not be written.
    fn save(self, path: impl AsRef<Path>) -> Result<(), ZyxError>;
    /// Load self from the file at `path`.
    ///
    /// # Errors
    /// Returns [`ZyxError`] if the file could not be read or parsed.
    fn load(self, path: impl AsRef<Path>) -> Result<(), ZyxError>;
}
16
17impl<'a, B: Backend + 'a, Tensors: IntoIterator<Item = &'a mut Tensor<B>>> ModuleIO for Tensors {
18    fn save(self, path: impl AsRef<Path>) -> Result<(), ZyxError> {
19        save(self.into_iter().map(|x| &*x), path)
20    }
21
22    fn load(self, path: impl AsRef<Path>) -> Result<(), ZyxError> {
23        let targets: Vec<&mut Tensor<B>> = self.into_iter().collect();
24        let dev = targets[0].backend();
25        let tensors = load(dev, path)?;
26        for (x, y) in targets.into_iter().zip(tensors) {
27            *x = y;
28        }
29        Ok(())
30    }
31}
32
33/// Save all tensors into file.
34/// All parameters must be realized before calling this function, otherwise it will panic.
35/// # Errors
36/// Returns io erorr if there was problem writing file to filesystem.
37pub fn save<'a, B: Backend + 'a>(
38    tensors: impl IntoIterator<Item = &'a Tensor<B>>,
39    path: impl AsRef<Path>,
40) -> Result<(), ZyxError> {
41    let mut f = File::create(path)?;
42    let mut header = String::from("{");
43    let mut begin = 0;
44    let tensors: Vec<&Tensor<B>> = tensors.into_iter().collect();
45    for tensor in &tensors {
46        let dtype = tensor.dtype();
47        //if let Some(label) = tensor.label() {
48        //write!(header, "\"{label}\":{{").unwrap();
49        //} else {
50        write!(header, "\"{}\":{{", tensor.id()).unwrap();
51        //}
52        write!(header, "\"dtype\":\"{}\",", dtype.safetensors()).unwrap();
53        write!(header, "\"shape\":{},", tensor.shape().safetensors()).unwrap();
54        let size = tensor.numel() * dtype.byte_size();
55        write!(header, "\"data_offsets\":[{},{}]", begin, begin + size).unwrap();
56        begin += size;
57        write!(header, "}},").unwrap();
58    }
59    header.pop();
60    write!(header, "}}").unwrap();
61    let header_bytes = header.as_bytes();
62    f.write_all(&(header_bytes.len() as u64).to_le_bytes())?;
63    f.write_all(header_bytes)?;
64    for tensor in tensors {
65        match tensor.dtype() {
66            DType::F32 => {
67                let vec = tensor.to_vec::<f32>()?;
68                let mut bytes: Vec<u8> = Vec::with_capacity(vec.len() * 4);
69                for x in vec {
70                    bytes.extend(x.to_le_bytes());
71                }
72                f.write_all(&bytes)?;
73            }
74            DType::F64 => {
75                let vec = tensor.to_vec::<f64>()?;
76                let mut bytes: Vec<u8> = Vec::with_capacity(vec.len() * 4);
77                for x in vec {
78                    bytes.extend(x.to_le_bytes());
79                }
80                f.write_all(&bytes)?;
81            }
82            DType::I32 => {
83                let vec = tensor.to_vec::<i32>().unwrap();
84                let mut bytes: Vec<u8> = Vec::with_capacity(vec.len() * 4);
85                for x in vec {
86                    bytes.extend(x.to_le_bytes());
87                }
88                f.write_all(&bytes)?;
89            }
90        };
91    }
92    Ok(())
93}
94
95/// Load all parameters from file
96/// # Errors
97/// Returns io error if there was io erorr or parsing error.
98pub fn load<B: Backend>(dev: B, path: impl AsRef<Path>) -> Result<Vec<Tensor<B>>, ZyxError> {
99    let mut f = File::open(path)?;
100    let mut header_len = [0u8; 8];
101    f.read_exact(&mut header_len)?;
102    let mut header = alloc::vec![0u8; usize::try_from(u64::from_le_bytes(header_len)).unwrap()];
103    f.read_exact(&mut header)?;
104    let header = core::str::from_utf8(&header)
105        .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
106    let mut text = alloc::string::String::with_capacity(10);
107    let mut begin_str = false;
108    let mut i = 0;
109    let mut tensors = Vec::new();
110    let mut dtype = DType::F32;
111    let mut shape: Shape = [1].into();
112    for x in header.chars() {
113        if ['"', '[', ']'].contains(&x) {
114            if begin_str {
115                //std::println!("{text}");
116                if i % 7 == 0 {
117                    //params[i / 7].set_label(&text);
118                } else if i % 7 == 2 {
119                    dtype = DType::from_safetensors(&text)?;
120                } else if i % 7 == 4 {
121                    shape = Shape::from_safetensors(&text)?;
122                } else if i % 7 == 6 {
123                    // TODO assert offsets
124                    //std::println!("Offsets: {text}");
125                    let offsets = text
126                        .split(',')
127                        .map(|offset| {
128                            offset.parse::<usize>().map_err(|err| {
129                                ZyxError::ParseError(alloc::format!(
130                                    "Could not parse safetensors offset: {err}"
131                                ))
132                            })
133                        })
134                        .collect::<Result<Vec<usize>, ZyxError>>()?;
135                    //std::println!("Offsets: {offsets:?}");
136                    if offsets[tensors.len() + 1] != shape.numel() * dtype.byte_size() {
137                        return Err(ZyxError::ParseError(
138                            "Safetensors shapes and offsets are incorrect.".into(),
139                        ));
140                    }
141                    let mut buf = alloc::vec![0u8; shape.numel()*dtype.byte_size()];
142                    f.read_exact(&mut buf)?;
143                    tensors.push(match dtype {
144                        DType::F32 => {
145                            let vec: Vec<f32> = buf
146                                .chunks_exact(dtype.byte_size())
147                                .map(|x| f32::from_le_bytes([x[0], x[1], x[2], x[3]]))
148                                .collect();
149                            dev.tensor(vec)?.reshape(&shape)
150                        }
151                        DType::F64 => {
152                            let vec: Vec<f64> = buf
153                                .chunks_exact(dtype.byte_size())
154                                .map(|x| {
155                                    f64::from_le_bytes([
156                                        x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7],
157                                    ])
158                                })
159                                .collect();
160                            dev.tensor(vec)?.reshape(&shape)
161                        }
162                        DType::I32 => {
163                            let vec: Vec<i32> = buf
164                                .chunks_exact(dtype.byte_size())
165                                .map(|x| i32::from_le_bytes([x[0], x[1], x[2], x[3]]))
166                                .collect();
167                            dev.tensor(vec)?.reshape(&shape)
168                        }
169                    });
170                }
171                i += 1;
172                text.clear();
173                begin_str = false;
174            } else {
175                text.clear();
176                begin_str = true;
177            }
178        } else {
179            text.push(x);
180        }
181    }
182    Ok(tensors)
183}