tch_plus/tensor/
npy.rs

1//! Numpy support for tensors.
2//!
3//! This module implements the support for reading and writing `.npy` and `.npz` files. The file
4//! format spec can be found at:
5//! <https://docs.scipy.org/doc/numpy-1.14.2/neps/npy-format.html>.
6use crate::{Kind, TchError, Tensor};
7use std::collections::HashMap;
8use std::fs::File;
9use std::io::{BufReader, Read, Write};
10use std::path::Path;
11
/// Magic prefix that opens every npy stream: the byte `0x93` followed by `NUMPY`.
const NPY_MAGIC_STRING: &[u8] = &[0x93, b'N', b'U', b'M', b'P', b'Y'];
/// Suffix carried by the entry names of the arrays stored in an npz archive.
const NPY_SUFFIX: &str = ".npy";
14
15fn read_header<R: Read>(reader: &mut R) -> Result<String, TchError> {
16    let mut magic_string = vec![0u8; NPY_MAGIC_STRING.len()];
17    reader.read_exact(&mut magic_string)?;
18    if magic_string != NPY_MAGIC_STRING {
19        return Err(TchError::FileFormat("magic string mismatch".to_string()));
20    }
21    let mut version = [0u8; 2];
22    reader.read_exact(&mut version)?;
23    let header_len_len = match version[0] {
24        1 => 2,
25        2 => 4,
26        otherwise => return Err(TchError::FileFormat(format!("unsupported version {otherwise}"))),
27    };
28    let mut header_len = vec![0u8; header_len_len];
29    reader.read_exact(&mut header_len)?;
30    let header_len = header_len.iter().rev().fold(0_usize, |acc, &v| 256 * acc + v as usize);
31    let mut header = vec![0u8; header_len];
32    reader.read_exact(&mut header)?;
33    Ok(String::from_utf8_lossy(&header).to_string())
34}
35
36#[derive(Debug, PartialEq)]
37struct Header {
38    descr: Kind,
39    fortran_order: bool,
40    shape: Vec<i64>,
41}
42
43impl Header {
44    fn to_string(&self) -> Result<String, TchError> {
45        let fortran_order = if self.fortran_order { "True" } else { "False" };
46        let mut shape = self.shape.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(",");
47        let descr = match self.descr {
48            Kind::Half => "f2",
49            Kind::Float => "f4",
50            Kind::Double => "f8",
51            Kind::Int => "i4",
52            Kind::Int64 => "i8",
53            Kind::Int16 => "i2",
54            Kind::Int8 => "i1",
55            Kind::Uint8 => "u1",
56            descr => return Err(TchError::FileFormat(format!("unsupported kind {descr:?}"))),
57        };
58        if !shape.is_empty() {
59            shape.push(',')
60        }
61        Ok(format!(
62            "{{'descr': '<{descr}', 'fortran_order': {fortran_order}, 'shape': ({shape}), }}"
63        ))
64    }
65
66    // Hacky parser for the npy header, a typical example would be:
67    // {'descr': '<f8', 'fortran_order': False, 'shape': (128,), }
68    fn parse(header: &str) -> Result<Header, TchError> {
69        let header =
70            header.trim_matches(|c: char| c == '{' || c == '}' || c == ',' || c.is_whitespace());
71
72        let mut parts: Vec<String> = vec![];
73        let mut start_index = 0usize;
74        let mut cnt_parenthesis = 0i64;
75        for (index, c) in header.chars().enumerate() {
76            match c {
77                '(' => cnt_parenthesis += 1,
78                ')' => cnt_parenthesis -= 1,
79                ',' => {
80                    if cnt_parenthesis == 0 {
81                        parts.push(header[start_index..index].to_owned());
82                        start_index = index + 1;
83                    }
84                }
85                _ => {}
86            }
87        }
88        parts.push(header[start_index..].to_owned());
89        let mut part_map: HashMap<String, String> = HashMap::new();
90        for part in parts.iter() {
91            let part = part.trim();
92            if !part.is_empty() {
93                match part.split(':').collect::<Vec<_>>().as_slice() {
94                    [key, value] => {
95                        let key = key.trim_matches(|c: char| c == '\'' || c.is_whitespace());
96                        let value = value.trim_matches(|c: char| c == '\'' || c.is_whitespace());
97                        let _ = part_map.insert(key.to_owned(), value.to_owned());
98                    }
99                    _ => {
100                        return Err(TchError::FileFormat(format!(
101                            "unable to parse header {header}"
102                        )))
103                    }
104                }
105            }
106        }
107        let fortran_order = match part_map.get("fortran_order") {
108            None => false,
109            Some(fortran_order) => match fortran_order.as_ref() {
110                "False" => false,
111                "True" => true,
112                _ => {
113                    return Err(TchError::FileFormat(format!(
114                        "unknown fortran_order {fortran_order}"
115                    )))
116                }
117            },
118        };
119        let descr = match part_map.get("descr") {
120            None => return Err(TchError::FileFormat("no descr in header".to_string())),
121            Some(descr) => {
122                if descr.is_empty() {
123                    return Err(TchError::FileFormat("empty descr".to_string()));
124                }
125                if descr.starts_with('>') {
126                    return Err(TchError::FileFormat(format!("little-endian descr {descr}")));
127                }
128                // the only supported types in tensor are:
129                //     float64, float32, float16,
130                //     complex64, complex128,
131                //     int64, int32, int16, int8,
132                //     uint8, and bool.
133                match descr.trim_matches(|c: char| c == '=' || c == '<' || c == '|') {
134                    "e" | "f2" => Kind::Half,
135                    "f" | "f4" => Kind::Float,
136                    "d" | "f8" => Kind::Double,
137                    "i" | "i4" => Kind::Int,
138                    "q" | "i8" => Kind::Int64,
139                    "h" | "i2" => Kind::Int16,
140                    "b" | "i1" => Kind::Int8,
141                    "B" | "u1" => Kind::Uint8,
142                    "?" | "b1" => Kind::Bool,
143                    "F" | "F4" | "c8" => Kind::ComplexFloat,
144                    "D" | "F8" | "c16" => Kind::ComplexDouble,
145                    descr => {
146                        return Err(TchError::FileFormat(format!("unrecognized descr {descr}")))
147                    }
148                }
149            }
150        };
151        let shape = match part_map.get("shape") {
152            None => return Err(TchError::FileFormat("no shape in header".to_string())),
153            Some(shape) => {
154                let shape = shape.trim_matches(|c: char| c == '(' || c == ')' || c == ',');
155                if shape.is_empty() {
156                    vec![]
157                } else {
158                    shape
159                        .split(',')
160                        .map(|v| v.trim().parse::<i64>())
161                        .collect::<Result<Vec<_>, _>>()?
162                }
163            }
164        };
165        Ok(Header { descr, fortran_order, shape })
166    }
167}
168
169impl crate::Tensor {
170    /// Reads a npy file and return the stored tensor.
171    pub fn read_npy<T: AsRef<Path>>(path: T) -> Result<Tensor, TchError> {
172        let mut reader = File::open(path.as_ref())?;
173        let header = read_header(&mut reader)?;
174        let header = Header::parse(&header)?;
175        if header.fortran_order {
176            return Err(TchError::FileFormat("fortran order not supported".to_string()));
177        }
178        let mut data: Vec<u8> = vec![];
179        reader.read_to_end(&mut data)?;
180        Tensor::f_from_data_size(&data, &header.shape, header.descr)
181    }
182
183    /// Reads a npz file and returns some named tensors.
184    pub fn read_npz<T: AsRef<Path>>(path: T) -> Result<Vec<(String, Tensor)>, TchError> {
185        let zip_reader = BufReader::new(File::open(path.as_ref())?);
186        let mut zip = zip::ZipArchive::new(zip_reader)?;
187        let mut result = vec![];
188        for i in 0..zip.len() {
189            let mut reader = zip.by_index(i).unwrap();
190            let name = {
191                let name = reader.name();
192                name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
193            };
194            let header = read_header(&mut reader)?;
195            let header = Header::parse(&header)?;
196            if header.fortran_order {
197                return Err(TchError::FileFormat("fortran order not supported".to_string()));
198            }
199            let mut data: Vec<u8> = vec![];
200            reader.read_to_end(&mut data)?;
201            let tensor = Tensor::f_from_data_size(&data, &header.shape, header.descr)?;
202            result.push((name, tensor))
203        }
204        Ok(result)
205    }
206
207    fn write<T: Write>(&self, f: &mut T) -> Result<(), TchError> {
208        f.write_all(NPY_MAGIC_STRING)?;
209        f.write_all(&[1u8, 0u8])?;
210        let kind = self.f_kind()?;
211        let header = Header { descr: kind, fortran_order: false, shape: self.size() };
212        let mut header = header.to_string()?;
213        let pad = 16 - (NPY_MAGIC_STRING.len() + 5 + header.len()) % 16;
214        for _ in 0..pad % 16 {
215            header.push(' ')
216        }
217        header.push('\n');
218        f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
219        f.write_all(header.as_bytes())?;
220        let numel = self.numel();
221        let mut content = vec![0u8; numel * kind.elt_size_in_bytes()];
222        self.f_copy_data_u8(&mut content, numel)?;
223        f.write_all(&content)?;
224        Ok(())
225    }
226
227    /// Writes a tensor in the npy format so that it can be read using python.
228    pub fn write_npy<T: AsRef<Path>>(&self, path: T) -> Result<(), TchError> {
229        let mut f = File::create(path.as_ref())?;
230        self.write(&mut f)
231    }
232
233    pub fn write_npz<S: AsRef<str>, T: AsRef<Tensor>, P: AsRef<Path>>(
234        ts: &[(S, T)],
235        path: P,
236    ) -> Result<(), TchError> {
237        let mut zip = zip::ZipWriter::new(File::create(path.as_ref())?);
238        let options =
239            zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);
240
241        for (name, tensor) in ts.iter() {
242            zip.start_file(format!("{}.npy", name.as_ref()), options)?;
243            tensor.as_ref().write(&mut zip)?
244        }
245        Ok(())
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::Header;
    use crate::Kind;

    /// Parses headers in the form numpy writes them, then checks that the header
    /// renders back to its canonical text, including the empty `()` shape.
    #[test]
    fn parse() {
        let one_dim =
            Header::parse("{'descr': '<f8', 'fortran_order': False, 'shape': (128,), }").unwrap();
        let expected_one_dim =
            Header { descr: Kind::Double, fortran_order: false, shape: vec![128] };
        assert_eq!(one_dim, expected_one_dim);

        let fortran =
            Header::parse("{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128), }")
                .unwrap();
        let expected_fortran =
            Header { descr: Kind::Float, fortran_order: true, shape: vec![256, 1, 128] };
        assert_eq!(fortran, expected_fortran);
        let rendered = fortran.to_string().unwrap();
        assert_eq!(rendered, "{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128,), }");

        let scalar = Header { descr: Kind::Int64, fortran_order: false, shape: vec![] };
        let rendered = scalar.to_string().unwrap();
        assert_eq!(rendered, "{'descr': '<i8', 'fortran_order': False, 'shape': (), }");
    }
}