1use crate::error::{IoError, Result};
4
/// Magic prefix that opens every `.npy` file: the byte `0x93` followed by ASCII `NUMPY`.
pub const NPY_MAGIC: &[u8; 6] = &[0x93, b'N', b'U', b'M', b'P', b'Y'];

/// Major format version number associated with this module.
pub const NPY_MAJOR_VERSION: u8 = 1;

/// Minor format version number associated with this module.
pub const NPY_MINOR_VERSION: u8 = 0;
13
/// Element types this module can read from / write to a `.npy` file.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NpyDtype {
    /// 4-byte floating point (`f4`).
    Float32,
    /// 8-byte floating point (`f8`).
    Float64,
    /// 4-byte signed integer (`i4`).
    Int32,
    /// 8-byte signed integer (`i8`).
    Int64,
}
26
27impl NpyDtype {
28 pub fn element_size(&self) -> usize {
30 match self {
31 NpyDtype::Float32 => 4,
32 NpyDtype::Float64 => 8,
33 NpyDtype::Int32 => 4,
34 NpyDtype::Int64 => 8,
35 }
36 }
37
38 pub fn npy_str_le(&self) -> &'static str {
40 match self {
41 NpyDtype::Float32 => "<f4",
42 NpyDtype::Float64 => "<f8",
43 NpyDtype::Int32 => "<i4",
44 NpyDtype::Int64 => "<i8",
45 }
46 }
47
48 pub fn npy_str_be(&self) -> &'static str {
50 match self {
51 NpyDtype::Float32 => ">f4",
52 NpyDtype::Float64 => ">f8",
53 NpyDtype::Int32 => ">i4",
54 NpyDtype::Int64 => ">i8",
55 }
56 }
57
58 pub fn from_descr(descr: &str) -> Result<(Self, ByteOrder)> {
60 let descr = descr.trim().trim_matches('\'').trim_matches('"');
61 if descr.len() < 3 {
62 return Err(IoError::FormatError(format!(
63 "Invalid dtype descriptor: '{}'",
64 descr
65 )));
66 }
67
68 let endian_char = descr.as_bytes()[0];
69 let type_char = descr.as_bytes()[1];
70 let size_str = &descr[2..];
71
72 let byte_order = match endian_char {
73 b'<' | b'=' => ByteOrder::LittleEndian,
74 b'>' => ByteOrder::BigEndian,
75 b'|' => ByteOrder::NotApplicable,
76 _ => {
77 return Err(IoError::FormatError(format!(
78 "Unknown endian prefix: '{}'",
79 endian_char as char
80 )))
81 }
82 };
83
84 let size: usize = size_str
85 .parse()
86 .map_err(|_| IoError::FormatError(format!("Invalid dtype size: '{}'", size_str)))?;
87
88 let dtype = match (type_char, size) {
89 (b'f', 4) => NpyDtype::Float32,
90 (b'f', 8) => NpyDtype::Float64,
91 (b'i', 4) => NpyDtype::Int32,
92 (b'i', 8) => NpyDtype::Int64,
93 _ => {
94 return Err(IoError::FormatError(format!(
95 "Unsupported dtype: type='{}', size={}",
96 type_char as char, size
97 )))
98 }
99 };
100
101 Ok((dtype, byte_order))
102 }
103}
104
/// Byte order taken from the first character of a NumPy `descr` string.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ByteOrder {
    /// Least-significant byte first (`<`, and `=` as parsed by `from_descr`).
    LittleEndian,
    /// Most-significant byte first (`>`).
    BigEndian,
    /// Byte order is irrelevant for this descriptor (`|`).
    NotApplicable,
}
115
116#[derive(Debug, Clone)]
118pub struct NpyHeader {
119 pub dtype: NpyDtype,
121 pub byte_order: ByteOrder,
123 pub fortran_order: bool,
125 pub shape: Vec<usize>,
127}
128
129impl NpyHeader {
130 pub fn num_elements(&self) -> usize {
132 self.shape.iter().product()
133 }
134
135 pub fn to_header_string(&self) -> String {
137 let descr = if cfg!(target_endian = "little") {
138 self.dtype.npy_str_le()
139 } else {
140 self.dtype.npy_str_be()
141 };
142
143 let fortran_str = if self.fortran_order { "True" } else { "False" };
144
145 let shape_str = if self.shape.len() == 1 {
146 format!("({},)", self.shape[0])
147 } else {
148 let parts: Vec<String> = self.shape.iter().map(|s| s.to_string()).collect();
149 format!("({})", parts.join(", "))
150 };
151
152 format!(
153 "{{'descr': '{}', 'fortran_order': {}, 'shape': {}, }}",
154 descr, fortran_str, shape_str
155 )
156 }
157}
158
/// An in-memory array loaded from (or destined for) a `.npy` file, tagged by element type.
#[derive(Clone, Debug)]
pub enum NpyArray {
    /// `f32` elements.
    Float32 {
        /// Flat element buffer.
        data: Vec<f32>,
        /// Array dimensions.
        shape: Vec<usize>,
    },
    /// `f64` elements.
    Float64 {
        /// Flat element buffer.
        data: Vec<f64>,
        /// Array dimensions.
        shape: Vec<usize>,
    },
    /// `i32` elements.
    Int32 {
        /// Flat element buffer.
        data: Vec<i32>,
        /// Array dimensions.
        shape: Vec<usize>,
    },
    /// `i64` elements.
    Int64 {
        /// Flat element buffer.
        data: Vec<i64>,
        /// Array dimensions.
        shape: Vec<usize>,
    },
}
191
192impl NpyArray {
193 pub fn shape(&self) -> &[usize] {
195 match self {
196 NpyArray::Float32 { shape, .. } => shape,
197 NpyArray::Float64 { shape, .. } => shape,
198 NpyArray::Int32 { shape, .. } => shape,
199 NpyArray::Int64 { shape, .. } => shape,
200 }
201 }
202
203 pub fn dtype(&self) -> NpyDtype {
205 match self {
206 NpyArray::Float32 { .. } => NpyDtype::Float32,
207 NpyArray::Float64 { .. } => NpyDtype::Float64,
208 NpyArray::Int32 { .. } => NpyDtype::Int32,
209 NpyArray::Int64 { .. } => NpyDtype::Int64,
210 }
211 }
212
213 pub fn num_elements(&self) -> usize {
215 self.shape().iter().product()
216 }
217
218 pub fn as_f64(&self) -> Result<&[f64]> {
220 match self {
221 NpyArray::Float64 { data, .. } => Ok(data),
222 _ => Err(IoError::ConversionError(format!(
223 "Array is {:?}, not Float64",
224 self.dtype()
225 ))),
226 }
227 }
228
229 pub fn as_f32(&self) -> Result<&[f32]> {
231 match self {
232 NpyArray::Float32 { data, .. } => Ok(data),
233 _ => Err(IoError::ConversionError(format!(
234 "Array is {:?}, not Float32",
235 self.dtype()
236 ))),
237 }
238 }
239
240 pub fn as_i32(&self) -> Result<&[i32]> {
242 match self {
243 NpyArray::Int32 { data, .. } => Ok(data),
244 _ => Err(IoError::ConversionError(format!(
245 "Array is {:?}, not Int32",
246 self.dtype()
247 ))),
248 }
249 }
250
251 pub fn as_i64(&self) -> Result<&[i64]> {
253 match self {
254 NpyArray::Int64 { data, .. } => Ok(data),
255 _ => Err(IoError::ConversionError(format!(
256 "Array is {:?}, not Int64",
257 self.dtype()
258 ))),
259 }
260 }
261}
262
263pub fn parse_header_dict(header_str: &str) -> Result<NpyHeader> {
265 let header_str = header_str
266 .trim()
267 .trim_end_matches('\n')
268 .trim_end_matches('\0');
269
270 let descr = extract_dict_value(header_str, "descr")?;
272 let (dtype, byte_order) = NpyDtype::from_descr(&descr)?;
273
274 let fortran_str = extract_dict_value(header_str, "fortran_order")?;
276 let fortran_order = fortran_str.trim() == "True";
277
278 let shape_str = extract_dict_value(header_str, "shape")?;
280 let shape = parse_shape(&shape_str)?;
281
282 Ok(NpyHeader {
283 dtype,
284 byte_order,
285 fortran_order,
286 shape,
287 })
288}
289
290fn extract_dict_value(dict_str: &str, key: &str) -> Result<String> {
292 let search = format!("'{}': ", key);
293 let pos = dict_str.find(&search).or_else(|| {
294 let alt_search = format!("\"{}\":", key);
295 dict_str.find(&alt_search)
296 });
297
298 let start = match pos {
299 Some(p) => p + search.len(),
300 None => {
301 let alt = format!("'{}':", key);
303 match dict_str.find(&alt) {
304 Some(p) => p + alt.len(),
305 None => {
306 return Err(IoError::FormatError(format!(
307 "Key '{}' not found in header: {}",
308 key, dict_str
309 )))
310 }
311 }
312 }
313 };
314
315 let remaining = dict_str[start..].trim_start();
316
317 if remaining.starts_with('\'') || remaining.starts_with('"') {
319 let quote = remaining.as_bytes()[0];
320 let end = remaining[1..]
321 .find(|c: char| c as u8 == quote)
322 .ok_or_else(|| {
323 IoError::FormatError(format!("Unterminated string for key '{}'", key))
324 })?;
325 Ok(remaining[1..end + 1].to_string())
326 } else if remaining.starts_with('(') {
327 let end = remaining
328 .find(')')
329 .ok_or_else(|| IoError::FormatError(format!("Unterminated tuple for key '{}'", key)))?;
330 Ok(remaining[..end + 1].to_string())
331 } else {
332 let end = remaining.find([',', '}']).unwrap_or(remaining.len());
334 Ok(remaining[..end].trim().to_string())
335 }
336}
337
338fn parse_shape(shape_str: &str) -> Result<Vec<usize>> {
340 let inner = shape_str
341 .trim()
342 .trim_start_matches('(')
343 .trim_end_matches(')');
344
345 if inner.is_empty() {
346 return Ok(vec![]); }
348
349 let mut shape = Vec::new();
350 for part in inner.split(',') {
351 let part = part.trim();
352 if part.is_empty() {
353 continue;
354 }
355 let dim: usize = part
356 .parse()
357 .map_err(|_| IoError::FormatError(format!("Invalid shape dimension: '{}'", part)))?;
358 shape.push(dim);
359 }
360
361 Ok(shape)
362}