Skip to main content

scirs2_io/npy/
types.rs

1//! Types and utilities for NumPy binary format support.
2
3use crate::error::{IoError, Result};
4
/// Six-byte magic prefix of a `.npy` file: the byte `0x93` followed by ASCII `NUMPY`.
pub const NPY_MAGIC: &[u8; 6] = &[0x93, b'N', b'U', b'M', b'P', b'Y'];

/// Major component of the `.npy` format version (1.0).
pub const NPY_MAJOR_VERSION: u8 = 1;

/// Minor component of the `.npy` format version (1.0).
pub const NPY_MINOR_VERSION: u8 = 0;
13
/// Element types supported when reading or writing `.npy` data.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NpyDtype {
    /// Single-precision float, NumPy kind/size `f4`.
    Float32,
    /// Double-precision float, NumPy kind/size `f8`.
    Float64,
    /// Signed 32-bit integer, NumPy kind/size `i4`.
    Int32,
    /// Signed 64-bit integer, NumPy kind/size `i8`.
    Int64,
}
26
27impl NpyDtype {
28    /// Size in bytes of one element
29    pub fn element_size(&self) -> usize {
30        match self {
31            NpyDtype::Float32 => 4,
32            NpyDtype::Float64 => 8,
33            NpyDtype::Int32 => 4,
34            NpyDtype::Int64 => 8,
35        }
36    }
37
38    /// NumPy dtype string for little-endian
39    pub fn npy_str_le(&self) -> &'static str {
40        match self {
41            NpyDtype::Float32 => "<f4",
42            NpyDtype::Float64 => "<f8",
43            NpyDtype::Int32 => "<i4",
44            NpyDtype::Int64 => "<i8",
45        }
46    }
47
48    /// NumPy dtype string for big-endian
49    pub fn npy_str_be(&self) -> &'static str {
50        match self {
51            NpyDtype::Float32 => ">f4",
52            NpyDtype::Float64 => ">f8",
53            NpyDtype::Int32 => ">i4",
54            NpyDtype::Int64 => ">i8",
55        }
56    }
57
58    /// Parse dtype from NumPy descriptor string
59    pub fn from_descr(descr: &str) -> Result<(Self, ByteOrder)> {
60        let descr = descr.trim().trim_matches('\'').trim_matches('"');
61        if descr.len() < 3 {
62            return Err(IoError::FormatError(format!(
63                "Invalid dtype descriptor: '{}'",
64                descr
65            )));
66        }
67
68        let endian_char = descr.as_bytes()[0];
69        let type_char = descr.as_bytes()[1];
70        let size_str = &descr[2..];
71
72        let byte_order = match endian_char {
73            b'<' | b'=' => ByteOrder::LittleEndian,
74            b'>' => ByteOrder::BigEndian,
75            b'|' => ByteOrder::NotApplicable,
76            _ => {
77                return Err(IoError::FormatError(format!(
78                    "Unknown endian prefix: '{}'",
79                    endian_char as char
80                )))
81            }
82        };
83
84        let size: usize = size_str
85            .parse()
86            .map_err(|_| IoError::FormatError(format!("Invalid dtype size: '{}'", size_str)))?;
87
88        let dtype = match (type_char, size) {
89            (b'f', 4) => NpyDtype::Float32,
90            (b'f', 8) => NpyDtype::Float64,
91            (b'i', 4) => NpyDtype::Int32,
92            (b'i', 8) => NpyDtype::Int64,
93            _ => {
94                return Err(IoError::FormatError(format!(
95                    "Unsupported dtype: type='{}', size={}",
96                    type_char as char, size
97                )))
98            }
99        };
100
101        Ok((dtype, byte_order))
102    }
103}
104
/// Byte order of the elements in a `.npy` payload.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ByteOrder {
    /// Least-significant byte first (descriptor prefix `<`).
    LittleEndian,
    /// Most-significant byte first (descriptor prefix `>`).
    BigEndian,
    /// Byte order is irrelevant, e.g. for single-byte types (descriptor prefix `|`).
    NotApplicable,
}
115
116/// Parsed header from a .npy file
117#[derive(Debug, Clone)]
118pub struct NpyHeader {
119    /// Data type
120    pub dtype: NpyDtype,
121    /// Byte order
122    pub byte_order: ByteOrder,
123    /// Whether the data is in Fortran (column-major) order
124    pub fortran_order: bool,
125    /// Shape of the array
126    pub shape: Vec<usize>,
127}
128
129impl NpyHeader {
130    /// Total number of elements
131    pub fn num_elements(&self) -> usize {
132        self.shape.iter().product()
133    }
134
135    /// Serialize as NumPy header dict string
136    pub fn to_header_string(&self) -> String {
137        let descr = if cfg!(target_endian = "little") {
138            self.dtype.npy_str_le()
139        } else {
140            self.dtype.npy_str_be()
141        };
142
143        let fortran_str = if self.fortran_order { "True" } else { "False" };
144
145        let shape_str = if self.shape.len() == 1 {
146            format!("({},)", self.shape[0])
147        } else {
148            let parts: Vec<String> = self.shape.iter().map(|s| s.to_string()).collect();
149            format!("({})", parts.join(", "))
150        };
151
152        format!(
153            "{{'descr': '{}', 'fortran_order': {}, 'shape': {}, }}",
154            descr, fortran_str, shape_str
155        )
156    }
157}
158
/// Typed contents of a `.npy` file: a flat element buffer plus its dimensions.
#[derive(Debug, Clone)]
pub enum NpyArray {
    /// Array of `f32` elements.
    Float32 {
        /// Elements as one flat buffer.
        data: Vec<f32>,
        /// Dimensions of the array.
        shape: Vec<usize>,
    },
    /// Array of `f64` elements.
    Float64 {
        /// Elements as one flat buffer.
        data: Vec<f64>,
        /// Dimensions of the array.
        shape: Vec<usize>,
    },
    /// Array of `i32` elements.
    Int32 {
        /// Elements as one flat buffer.
        data: Vec<i32>,
        /// Dimensions of the array.
        shape: Vec<usize>,
    },
    /// Array of `i64` elements.
    Int64 {
        /// Elements as one flat buffer.
        data: Vec<i64>,
        /// Dimensions of the array.
        shape: Vec<usize>,
    },
}
191
192impl NpyArray {
193    /// Get the shape
194    pub fn shape(&self) -> &[usize] {
195        match self {
196            NpyArray::Float32 { shape, .. } => shape,
197            NpyArray::Float64 { shape, .. } => shape,
198            NpyArray::Int32 { shape, .. } => shape,
199            NpyArray::Int64 { shape, .. } => shape,
200        }
201    }
202
203    /// Get the dtype
204    pub fn dtype(&self) -> NpyDtype {
205        match self {
206            NpyArray::Float32 { .. } => NpyDtype::Float32,
207            NpyArray::Float64 { .. } => NpyDtype::Float64,
208            NpyArray::Int32 { .. } => NpyDtype::Int32,
209            NpyArray::Int64 { .. } => NpyDtype::Int64,
210        }
211    }
212
213    /// Total number of elements
214    pub fn num_elements(&self) -> usize {
215        self.shape().iter().product()
216    }
217
218    /// Try to get f64 data
219    pub fn as_f64(&self) -> Result<&[f64]> {
220        match self {
221            NpyArray::Float64 { data, .. } => Ok(data),
222            _ => Err(IoError::ConversionError(format!(
223                "Array is {:?}, not Float64",
224                self.dtype()
225            ))),
226        }
227    }
228
229    /// Try to get f32 data
230    pub fn as_f32(&self) -> Result<&[f32]> {
231        match self {
232            NpyArray::Float32 { data, .. } => Ok(data),
233            _ => Err(IoError::ConversionError(format!(
234                "Array is {:?}, not Float32",
235                self.dtype()
236            ))),
237        }
238    }
239
240    /// Try to get i32 data
241    pub fn as_i32(&self) -> Result<&[i32]> {
242        match self {
243            NpyArray::Int32 { data, .. } => Ok(data),
244            _ => Err(IoError::ConversionError(format!(
245                "Array is {:?}, not Int32",
246                self.dtype()
247            ))),
248        }
249    }
250
251    /// Try to get i64 data
252    pub fn as_i64(&self) -> Result<&[i64]> {
253        match self {
254            NpyArray::Int64 { data, .. } => Ok(data),
255            _ => Err(IoError::ConversionError(format!(
256                "Array is {:?}, not Int64",
257                self.dtype()
258            ))),
259        }
260    }
261}
262
263/// Parse the header dict from raw header string
264pub fn parse_header_dict(header_str: &str) -> Result<NpyHeader> {
265    let header_str = header_str
266        .trim()
267        .trim_end_matches('\n')
268        .trim_end_matches('\0');
269
270    // Extract 'descr' value
271    let descr = extract_dict_value(header_str, "descr")?;
272    let (dtype, byte_order) = NpyDtype::from_descr(&descr)?;
273
274    // Extract 'fortran_order'
275    let fortran_str = extract_dict_value(header_str, "fortran_order")?;
276    let fortran_order = fortran_str.trim() == "True";
277
278    // Extract 'shape'
279    let shape_str = extract_dict_value(header_str, "shape")?;
280    let shape = parse_shape(&shape_str)?;
281
282    Ok(NpyHeader {
283        dtype,
284        byte_order,
285        fortran_order,
286        shape,
287    })
288}
289
290/// Extract a value from a Python dict string by key
291fn extract_dict_value(dict_str: &str, key: &str) -> Result<String> {
292    let search = format!("'{}': ", key);
293    let pos = dict_str.find(&search).or_else(|| {
294        let alt_search = format!("\"{}\":", key);
295        dict_str.find(&alt_search)
296    });
297
298    let start = match pos {
299        Some(p) => p + search.len(),
300        None => {
301            // Try alternate format
302            let alt = format!("'{}':", key);
303            match dict_str.find(&alt) {
304                Some(p) => p + alt.len(),
305                None => {
306                    return Err(IoError::FormatError(format!(
307                        "Key '{}' not found in header: {}",
308                        key, dict_str
309                    )))
310                }
311            }
312        }
313    };
314
315    let remaining = dict_str[start..].trim_start();
316
317    // Handle different value types
318    if remaining.starts_with('\'') || remaining.starts_with('"') {
319        let quote = remaining.as_bytes()[0];
320        let end = remaining[1..]
321            .find(|c: char| c as u8 == quote)
322            .ok_or_else(|| {
323                IoError::FormatError(format!("Unterminated string for key '{}'", key))
324            })?;
325        Ok(remaining[1..end + 1].to_string())
326    } else if remaining.starts_with('(') {
327        let end = remaining
328            .find(')')
329            .ok_or_else(|| IoError::FormatError(format!("Unterminated tuple for key '{}'", key)))?;
330        Ok(remaining[..end + 1].to_string())
331    } else {
332        // Boolean or other: read until comma or '}'
333        let end = remaining.find([',', '}']).unwrap_or(remaining.len());
334        Ok(remaining[..end].trim().to_string())
335    }
336}
337
338/// Parse a Python tuple shape string like "(3, 4)" or "(5,)"
339fn parse_shape(shape_str: &str) -> Result<Vec<usize>> {
340    let inner = shape_str
341        .trim()
342        .trim_start_matches('(')
343        .trim_end_matches(')');
344
345    if inner.is_empty() {
346        return Ok(vec![]); // scalar
347    }
348
349    let mut shape = Vec::new();
350    for part in inner.split(',') {
351        let part = part.trim();
352        if part.is_empty() {
353            continue;
354        }
355        let dim: usize = part
356            .parse()
357            .map_err(|_| IoError::FormatError(format!("Invalid shape dimension: '{}'", part)))?;
358        shape.push(dim);
359    }
360
361    Ok(shape)
362}