tch_plus/tensor/
safetensors.rs1use crate::nn::VarStore;
6use crate::{Kind, TchError, Tensor};
7
8use std::convert::{TryFrom, TryInto};
9use std::path::Path;
10
11use safetensors::tensor::{Dtype, SafeTensors, TensorView, View};
12
13impl TryFrom<Kind> for Dtype {
14 type Error = TchError;
15 fn try_from(kind: Kind) -> Result<Self, Self::Error> {
16 let dtype = match kind {
17 Kind::Bool => Dtype::BOOL,
18 Kind::Uint8 => Dtype::U8,
19 Kind::Int8 => Dtype::I8,
20 Kind::Int16 => Dtype::I16,
21 Kind::Int => Dtype::I32,
22 Kind::Int64 => Dtype::I64,
23 Kind::BFloat16 => Dtype::BF16,
24 Kind::Half => Dtype::F16,
25 Kind::Float => Dtype::F32,
26 Kind::Double => Dtype::F64,
27 kind => return Err(TchError::Convert(format!("unsupported kind ({kind:?})"))),
28 };
29 Ok(dtype)
30 }
31}
32
33impl TryFrom<Dtype> for Kind {
34 type Error = TchError;
35 fn try_from(dtype: Dtype) -> Result<Self, Self::Error> {
36 let kind = match dtype {
37 Dtype::BOOL => Kind::Bool,
38 Dtype::U8 => Kind::Uint8,
39 Dtype::I8 => Kind::Int8,
40 Dtype::I16 => Kind::Int16,
41 Dtype::I32 => Kind::Int,
42 Dtype::I64 => Kind::Int64,
43 Dtype::BF16 => Kind::BFloat16,
44 Dtype::F16 => Kind::Half,
45 Dtype::F32 => Kind::Float,
46 Dtype::F64 => Kind::Double,
47 dtype => return Err(TchError::Convert(format!("unsupported dtype {dtype:?}"))),
48 };
49 Ok(kind)
50 }
51}
52
53impl<'a> TryFrom<TensorView<'a>> for Tensor {
54 type Error = TchError;
55 fn try_from(view: TensorView<'a>) -> Result<Self, Self::Error> {
56 let size: Vec<i64> = view.shape().iter().map(|&x| x as i64).collect();
57 let kind: Kind = view.dtype().try_into()?;
58 Tensor::f_from_data_size(view.data(), &size, kind)
59 }
60}
61
62struct SafeView<'a> {
63 tensor: &'a Tensor,
64 shape: Vec<usize>,
65 dtype: Dtype,
66}
67
68impl<'a> TryFrom<&'a Tensor> for SafeView<'a> {
69 type Error = TchError;
70
71 fn try_from(tensor: &'a Tensor) -> Result<Self, Self::Error> {
72 if tensor.is_sparse() {
73 return Err(TchError::Convert("Cannot save sparse tensors".to_string()));
74 }
75
76 if !tensor.is_contiguous() {
77 return Err(TchError::Convert("Cannot save non contiguous tensors".to_string()));
78 }
79
80 let dtype = tensor.kind().try_into()?;
81 let shape = tensor.size().iter().map(|&x| x as usize).collect();
82 Ok(Self { tensor, shape, dtype })
83 }
84}
85
86impl View for SafeView<'_> {
87 fn dtype(&self) -> Dtype {
88 self.dtype
89 }
90 fn shape(&self) -> &[usize] {
91 &self.shape
92 }
93
94 fn data(&self) -> std::borrow::Cow<[u8]> {
95 let mut data = vec![0; self.data_len()];
96 let numel = self.tensor.numel();
97 self.tensor.f_copy_data_u8(&mut data, numel).unwrap();
98 data.into()
99 }
100
101 fn data_len(&self) -> usize {
102 self.tensor.numel() * self.tensor.kind().elt_size_in_bytes()
103 }
104}
105
106fn wrap_err<P: AsRef<Path>>(path: P, err: safetensors::SafeTensorError) -> TchError {
107 TchError::SafeTensorError { path: path.as_ref().to_string_lossy().to_string(), err }
108}
109
110impl crate::Tensor {
111 pub fn read_safetensors<T: AsRef<Path>>(path: T) -> Result<Vec<(String, Tensor)>, TchError> {
113 let file = std::fs::read(&path).map_err(|e| wrap_err(&path, e.into()))?;
114 let safetensors = SafeTensors::deserialize(&file).map_err(|e| wrap_err(&path, e))?;
115 safetensors.tensors().into_iter().map(|(name, view)| Ok((name, view.try_into()?))).collect()
116 }
117
118 pub fn write_safetensors<S: AsRef<str>, T: AsRef<Tensor>, P: AsRef<Path>>(
120 tensors: &[(S, T)],
121 path: P,
122 ) -> Result<(), TchError> {
123 let views = tensors
124 .iter()
125 .map(|(name, tensor)| {
126 Ok::<(&str, SafeView), TchError>((name.as_ref(), tensor.as_ref().try_into()?))
127 })
128 .collect::<Result<Vec<_>, _>>()?;
129 safetensors::tensor::serialize_to_file(views, &None, path.as_ref())
130 .map_err(|e| wrap_err(path, e))?;
131 Ok(())
132 }
133}
134
135impl VarStore {
136 pub fn read_safetensors<T: AsRef<Path>>(&self, path: T) -> Result<(), TchError> {
138 let file = std::fs::read(&path).map_err(|e| wrap_err(&path, e.into()))?;
139 let safetensors = SafeTensors::deserialize(&file).map_err(|e| wrap_err(&path, e))?;
140 for (name, tensor) in self.variables_.lock().unwrap().named_variables.iter_mut() {
141 let view = safetensors.tensor(name).map_err(|e| wrap_err(&path, e))?;
142 let data: Tensor = view.try_into()?;
143 tensor.f_copy_(&data)?
144 }
145 Ok(())
146 }
147
148 pub fn fill_safetensors<P: AsRef<Path>>(&self, path: P) -> Result<(), TchError> {
149 for (name, tensor) in Tensor::read_safetensors(path)? {
150 if let Some(s) = self.variables_.lock().unwrap().named_variables.get_mut(&name) {
151 s.f_copy_(&tensor)?
152 }
153 }
154 Ok(())
155 }
156}
157
#[cfg(test)]
mod tests {
    use std::convert::TryInto;

    use crate::Kind;
    use safetensors::Dtype;

    /// Checks every pairing handled by the `Kind` <-> `Dtype` conversions, in both directions.
    #[test]
    fn parse() {
        let pairs = [
            (Kind::Bool, Dtype::BOOL),
            (Kind::Uint8, Dtype::U8),
            (Kind::Int8, Dtype::I8),
            (Kind::Int16, Dtype::I16),
            (Kind::Int, Dtype::I32),
            (Kind::Int64, Dtype::I64),
            (Kind::BFloat16, Dtype::BF16),
            (Kind::Half, Dtype::F16),
            (Kind::Float, Dtype::F32),
            (Kind::Double, Dtype::F64),
        ];
        for (kind, dtype) in pairs {
            assert_eq!(TryInto::<Dtype>::try_into(kind).unwrap(), dtype);
            assert_eq!(TryInto::<Kind>::try_into(dtype).unwrap(), kind);
        }
    }

    /// Kinds/dtypes without a mapping must be rejected rather than silently converted.
    #[test]
    fn unsupported_conversions_fail() {
        assert!(TryInto::<Dtype>::try_into(Kind::ComplexFloat).is_err());
        assert!(TryInto::<Kind>::try_into(Dtype::U32).is_err());
    }
}