tch_plus/tensor/
safetensors.rs

1//! Safetensors support for tensors.
2//!
3//! This module implements reading and writing tensors in the `.safetensors` format.
4//! <https://github.com/huggingface/safetensors>
5use crate::nn::VarStore;
6use crate::{Kind, TchError, Tensor};
7
8use std::convert::{TryFrom, TryInto};
9use std::path::Path;
10
11use safetensors::tensor::{Dtype, SafeTensors, TensorView, View};
12
13impl TryFrom<Kind> for Dtype {
14    type Error = TchError;
15    fn try_from(kind: Kind) -> Result<Self, Self::Error> {
16        let dtype = match kind {
17            Kind::Bool => Dtype::BOOL,
18            Kind::Uint8 => Dtype::U8,
19            Kind::Int8 => Dtype::I8,
20            Kind::Int16 => Dtype::I16,
21            Kind::Int => Dtype::I32,
22            Kind::Int64 => Dtype::I64,
23            Kind::BFloat16 => Dtype::BF16,
24            Kind::Half => Dtype::F16,
25            Kind::Float => Dtype::F32,
26            Kind::Double => Dtype::F64,
27            kind => return Err(TchError::Convert(format!("unsupported kind ({kind:?})"))),
28        };
29        Ok(dtype)
30    }
31}
32
33impl TryFrom<Dtype> for Kind {
34    type Error = TchError;
35    fn try_from(dtype: Dtype) -> Result<Self, Self::Error> {
36        let kind = match dtype {
37            Dtype::BOOL => Kind::Bool,
38            Dtype::U8 => Kind::Uint8,
39            Dtype::I8 => Kind::Int8,
40            Dtype::I16 => Kind::Int16,
41            Dtype::I32 => Kind::Int,
42            Dtype::I64 => Kind::Int64,
43            Dtype::BF16 => Kind::BFloat16,
44            Dtype::F16 => Kind::Half,
45            Dtype::F32 => Kind::Float,
46            Dtype::F64 => Kind::Double,
47            dtype => return Err(TchError::Convert(format!("unsupported dtype {dtype:?}"))),
48        };
49        Ok(kind)
50    }
51}
52
53impl<'a> TryFrom<TensorView<'a>> for Tensor {
54    type Error = TchError;
55    fn try_from(view: TensorView<'a>) -> Result<Self, Self::Error> {
56        let size: Vec<i64> = view.shape().iter().map(|&x| x as i64).collect();
57        let kind: Kind = view.dtype().try_into()?;
58        Tensor::f_from_data_size(view.data(), &size, kind)
59    }
60}
61
62struct SafeView<'a> {
63    tensor: &'a Tensor,
64    shape: Vec<usize>,
65    dtype: Dtype,
66}
67
68impl<'a> TryFrom<&'a Tensor> for SafeView<'a> {
69    type Error = TchError;
70
71    fn try_from(tensor: &'a Tensor) -> Result<Self, Self::Error> {
72        if tensor.is_sparse() {
73            return Err(TchError::Convert("Cannot save sparse tensors".to_string()));
74        }
75
76        if !tensor.is_contiguous() {
77            return Err(TchError::Convert("Cannot save non contiguous tensors".to_string()));
78        }
79
80        let dtype = tensor.kind().try_into()?;
81        let shape = tensor.size().iter().map(|&x| x as usize).collect();
82        Ok(Self { tensor, shape, dtype })
83    }
84}
85
86impl View for SafeView<'_> {
87    fn dtype(&self) -> Dtype {
88        self.dtype
89    }
90    fn shape(&self) -> &[usize] {
91        &self.shape
92    }
93
94    fn data(&self) -> std::borrow::Cow<[u8]> {
95        let mut data = vec![0; self.data_len()];
96        let numel = self.tensor.numel();
97        self.tensor.f_copy_data_u8(&mut data, numel).unwrap();
98        data.into()
99    }
100
101    fn data_len(&self) -> usize {
102        self.tensor.numel() * self.tensor.kind().elt_size_in_bytes()
103    }
104}
105
106fn wrap_err<P: AsRef<Path>>(path: P, err: safetensors::SafeTensorError) -> TchError {
107    TchError::SafeTensorError { path: path.as_ref().to_string_lossy().to_string(), err }
108}
109
110impl crate::Tensor {
111    /// Reads a safetensors file and returns some named tensors.
112    pub fn read_safetensors<T: AsRef<Path>>(path: T) -> Result<Vec<(String, Tensor)>, TchError> {
113        let file = std::fs::read(&path).map_err(|e| wrap_err(&path, e.into()))?;
114        let safetensors = SafeTensors::deserialize(&file).map_err(|e| wrap_err(&path, e))?;
115        safetensors.tensors().into_iter().map(|(name, view)| Ok((name, view.try_into()?))).collect()
116    }
117
118    /// Writes a tensor in the safetensors format.
119    pub fn write_safetensors<S: AsRef<str>, T: AsRef<Tensor>, P: AsRef<Path>>(
120        tensors: &[(S, T)],
121        path: P,
122    ) -> Result<(), TchError> {
123        let views = tensors
124            .iter()
125            .map(|(name, tensor)| {
126                Ok::<(&str, SafeView), TchError>((name.as_ref(), tensor.as_ref().try_into()?))
127            })
128            .collect::<Result<Vec<_>, _>>()?;
129        safetensors::tensor::serialize_to_file(views, &None, path.as_ref())
130            .map_err(|e| wrap_err(path, e))?;
131        Ok(())
132    }
133}
134
135impl VarStore {
136    /// Read data from safe tensor file, missing tensors will raise a error.
137    pub fn read_safetensors<T: AsRef<Path>>(&self, path: T) -> Result<(), TchError> {
138        let file = std::fs::read(&path).map_err(|e| wrap_err(&path, e.into()))?;
139        let safetensors = SafeTensors::deserialize(&file).map_err(|e| wrap_err(&path, e))?;
140        for (name, tensor) in self.variables_.lock().unwrap().named_variables.iter_mut() {
141            let view = safetensors.tensor(name).map_err(|e| wrap_err(&path, e))?;
142            let data: Tensor = view.try_into()?;
143            tensor.f_copy_(&data)?
144        }
145        Ok(())
146    }
147
148    pub fn fill_safetensors<P: AsRef<Path>>(&self, path: P) -> Result<(), TchError> {
149        for (name, tensor) in Tensor::read_safetensors(path)? {
150            if let Some(s) = self.variables_.lock().unwrap().named_variables.get_mut(&name) {
151                s.f_copy_(&tensor)?
152            }
153        }
154        Ok(())
155    }
156}
157
#[cfg(test)]
mod tests {
    use std::convert::TryInto;

    use crate::Kind;
    use safetensors::Dtype;

    /// Converts a kind to a dtype, panicking when the conversion fails.
    fn dtype_of(kind: Kind) -> Dtype {
        kind.try_into().unwrap()
    }

    /// Converts a dtype to a kind, panicking when the conversion fails.
    fn kind_of(dtype: Dtype) -> Kind {
        dtype.try_into().unwrap()
    }

    #[test]
    fn parse() {
        // From Kind to Dtype
        let kind_cases = [
            (Kind::Double, Dtype::F64),
            (Kind::Float, Dtype::F32),
            (Kind::Half, Dtype::F16),
            (Kind::Int8, Dtype::I8),
            (Kind::Uint8, Dtype::U8),
        ];
        for (kind, expected) in kind_cases {
            assert_eq!(dtype_of(kind), expected);
        }

        // From Dtype to Kind
        let dtype_cases = [
            (Dtype::F64, Kind::Double),
            (Dtype::F32, Kind::Float),
            (Dtype::F16, Kind::Half),
            (Dtype::I8, Kind::Int8),
            (Dtype::U8, Kind::Uint8),
        ];
        for (dtype, expected) in dtype_cases {
            assert_eq!(kind_of(dtype), expected);
        }
    }
}