1use crate::{Kind, TchError, Tensor};
7use std::collections::HashMap;
8use std::fs::File;
9use std::io::{BufReader, Read, Write};
10use std::path::Path;
11
/// Byte sequence that opens every `.npy` payload; it is the first thing
/// written when saving and the first thing checked when reading a header.
const NPY_MAGIC_STRING: &[u8] = b"\x93NUMPY";
/// Suffix removed from archive entry names when listing the tensors of a
/// `.npz` file.
const NPY_SUFFIX: &str = ".npy";
14
15fn read_header<R: Read>(reader: &mut R) -> Result<String, TchError> {
16 let mut magic_string = vec![0u8; NPY_MAGIC_STRING.len()];
17 reader.read_exact(&mut magic_string)?;
18 if magic_string != NPY_MAGIC_STRING {
19 return Err(TchError::FileFormat("magic string mismatch".to_string()));
20 }
21 let mut version = [0u8; 2];
22 reader.read_exact(&mut version)?;
23 let header_len_len = match version[0] {
24 1 => 2,
25 2 => 4,
26 otherwise => return Err(TchError::FileFormat(format!("unsupported version {otherwise}"))),
27 };
28 let mut header_len = vec![0u8; header_len_len];
29 reader.read_exact(&mut header_len)?;
30 let header_len = header_len.iter().rev().fold(0_usize, |acc, &v| 256 * acc + v as usize);
31 let mut header = vec![0u8; header_len];
32 reader.read_exact(&mut header)?;
33 Ok(String::from_utf8_lossy(&header).to_string())
34}
35
36#[derive(Debug, PartialEq)]
37struct Header {
38 descr: Kind,
39 fortran_order: bool,
40 shape: Vec<i64>,
41}
42
43impl Header {
44 fn to_string(&self) -> Result<String, TchError> {
45 let fortran_order = if self.fortran_order { "True" } else { "False" };
46 let mut shape = self.shape.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(",");
47 let descr = match self.descr {
48 Kind::Half => "f2",
49 Kind::Float => "f4",
50 Kind::Double => "f8",
51 Kind::Int => "i4",
52 Kind::Int64 => "i8",
53 Kind::Int16 => "i2",
54 Kind::Int8 => "i1",
55 Kind::Uint8 => "u1",
56 descr => return Err(TchError::FileFormat(format!("unsupported kind {descr:?}"))),
57 };
58 if !shape.is_empty() {
59 shape.push(',')
60 }
61 Ok(format!(
62 "{{'descr': '<{descr}', 'fortran_order': {fortran_order}, 'shape': ({shape}), }}"
63 ))
64 }
65
66 fn parse(header: &str) -> Result<Header, TchError> {
69 let header =
70 header.trim_matches(|c: char| c == '{' || c == '}' || c == ',' || c.is_whitespace());
71
72 let mut parts: Vec<String> = vec![];
73 let mut start_index = 0usize;
74 let mut cnt_parenthesis = 0i64;
75 for (index, c) in header.chars().enumerate() {
76 match c {
77 '(' => cnt_parenthesis += 1,
78 ')' => cnt_parenthesis -= 1,
79 ',' => {
80 if cnt_parenthesis == 0 {
81 parts.push(header[start_index..index].to_owned());
82 start_index = index + 1;
83 }
84 }
85 _ => {}
86 }
87 }
88 parts.push(header[start_index..].to_owned());
89 let mut part_map: HashMap<String, String> = HashMap::new();
90 for part in parts.iter() {
91 let part = part.trim();
92 if !part.is_empty() {
93 match part.split(':').collect::<Vec<_>>().as_slice() {
94 [key, value] => {
95 let key = key.trim_matches(|c: char| c == '\'' || c.is_whitespace());
96 let value = value.trim_matches(|c: char| c == '\'' || c.is_whitespace());
97 let _ = part_map.insert(key.to_owned(), value.to_owned());
98 }
99 _ => {
100 return Err(TchError::FileFormat(format!(
101 "unable to parse header {header}"
102 )))
103 }
104 }
105 }
106 }
107 let fortran_order = match part_map.get("fortran_order") {
108 None => false,
109 Some(fortran_order) => match fortran_order.as_ref() {
110 "False" => false,
111 "True" => true,
112 _ => {
113 return Err(TchError::FileFormat(format!(
114 "unknown fortran_order {fortran_order}"
115 )))
116 }
117 },
118 };
119 let descr = match part_map.get("descr") {
120 None => return Err(TchError::FileFormat("no descr in header".to_string())),
121 Some(descr) => {
122 if descr.is_empty() {
123 return Err(TchError::FileFormat("empty descr".to_string()));
124 }
125 if descr.starts_with('>') {
126 return Err(TchError::FileFormat(format!("little-endian descr {descr}")));
127 }
128 match descr.trim_matches(|c: char| c == '=' || c == '<' || c == '|') {
134 "e" | "f2" => Kind::Half,
135 "f" | "f4" => Kind::Float,
136 "d" | "f8" => Kind::Double,
137 "i" | "i4" => Kind::Int,
138 "q" | "i8" => Kind::Int64,
139 "h" | "i2" => Kind::Int16,
140 "b" | "i1" => Kind::Int8,
141 "B" | "u1" => Kind::Uint8,
142 "?" | "b1" => Kind::Bool,
143 "F" | "F4" | "c8" => Kind::ComplexFloat,
144 "D" | "F8" | "c16" => Kind::ComplexDouble,
145 descr => {
146 return Err(TchError::FileFormat(format!("unrecognized descr {descr}")))
147 }
148 }
149 }
150 };
151 let shape = match part_map.get("shape") {
152 None => return Err(TchError::FileFormat("no shape in header".to_string())),
153 Some(shape) => {
154 let shape = shape.trim_matches(|c: char| c == '(' || c == ')' || c == ',');
155 if shape.is_empty() {
156 vec![]
157 } else {
158 shape
159 .split(',')
160 .map(|v| v.trim().parse::<i64>())
161 .collect::<Result<Vec<_>, _>>()?
162 }
163 }
164 };
165 Ok(Header { descr, fortran_order, shape })
166 }
167}
168
169impl crate::Tensor {
170 pub fn read_npy<T: AsRef<Path>>(path: T) -> Result<Tensor, TchError> {
172 let mut reader = File::open(path.as_ref())?;
173 let header = read_header(&mut reader)?;
174 let header = Header::parse(&header)?;
175 if header.fortran_order {
176 return Err(TchError::FileFormat("fortran order not supported".to_string()));
177 }
178 let mut data: Vec<u8> = vec![];
179 reader.read_to_end(&mut data)?;
180 Tensor::f_from_data_size(&data, &header.shape, header.descr)
181 }
182
183 pub fn read_npz<T: AsRef<Path>>(path: T) -> Result<Vec<(String, Tensor)>, TchError> {
185 let zip_reader = BufReader::new(File::open(path.as_ref())?);
186 let mut zip = zip::ZipArchive::new(zip_reader)?;
187 let mut result = vec![];
188 for i in 0..zip.len() {
189 let mut reader = zip.by_index(i).unwrap();
190 let name = {
191 let name = reader.name();
192 name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
193 };
194 let header = read_header(&mut reader)?;
195 let header = Header::parse(&header)?;
196 if header.fortran_order {
197 return Err(TchError::FileFormat("fortran order not supported".to_string()));
198 }
199 let mut data: Vec<u8> = vec![];
200 reader.read_to_end(&mut data)?;
201 let tensor = Tensor::f_from_data_size(&data, &header.shape, header.descr)?;
202 result.push((name, tensor))
203 }
204 Ok(result)
205 }
206
207 fn write<T: Write>(&self, f: &mut T) -> Result<(), TchError> {
208 f.write_all(NPY_MAGIC_STRING)?;
209 f.write_all(&[1u8, 0u8])?;
210 let kind = self.f_kind()?;
211 let header = Header { descr: kind, fortran_order: false, shape: self.size() };
212 let mut header = header.to_string()?;
213 let pad = 16 - (NPY_MAGIC_STRING.len() + 5 + header.len()) % 16;
214 for _ in 0..pad % 16 {
215 header.push(' ')
216 }
217 header.push('\n');
218 f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
219 f.write_all(header.as_bytes())?;
220 let numel = self.numel();
221 let mut content = vec![0u8; numel * kind.elt_size_in_bytes()];
222 self.f_copy_data_u8(&mut content, numel)?;
223 f.write_all(&content)?;
224 Ok(())
225 }
226
227 pub fn write_npy<T: AsRef<Path>>(&self, path: T) -> Result<(), TchError> {
229 let mut f = File::create(path.as_ref())?;
230 self.write(&mut f)
231 }
232
233 pub fn write_npz<S: AsRef<str>, T: AsRef<Tensor>, P: AsRef<Path>>(
234 ts: &[(S, T)],
235 path: P,
236 ) -> Result<(), TchError> {
237 let mut zip = zip::ZipWriter::new(File::create(path.as_ref())?);
238 let options =
239 zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);
240
241 for (name, tensor) in ts.iter() {
242 zip.start_file(format!("{}.npy", name.as_ref()), options)?;
243 tensor.as_ref().write(&mut zip)?
244 }
245 Ok(())
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::Header;
    use crate::Kind;

    /// Checks header parsing and rendering against fixed sample strings.
    #[test]
    fn parse() {
        // One-dimensional array of doubles stored in C order.
        let doubles_1d = "{'descr': '<f8', 'fortran_order': False, 'shape': (128,), }";
        let expected = Header { descr: Kind::Double, fortran_order: false, shape: vec![128] };
        assert_eq!(Header::parse(doubles_1d).unwrap(), expected);

        // Three-dimensional array of floats flagged as Fortran ordered.
        let floats_3d = "{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128), }";
        let parsed = Header::parse(floats_3d).unwrap();
        let expected = Header { descr: Kind::Float, fortran_order: true, shape: vec![256, 1, 128] };
        assert_eq!(parsed, expected);

        // Rendering appends a comma after every dimension.
        let rendered = parsed.to_string().unwrap();
        assert_eq!(rendered, "{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128,), }");

        // A scalar has an empty shape tuple.
        let scalar = Header { descr: Kind::Int64, fortran_order: false, shape: vec![] };
        let rendered = scalar.to_string().unwrap();
        assert_eq!(rendered, "{'descr': '<i8', 'fortran_order': False, 'shape': (), }");
    }
}