Skip to main content

scirs2_core/serialization/
mod.rs

1//! Cross-language serialization protocol for SciRS2.
2//!
3//! Provides a versioned binary format (`.scirs2` files) that can be
4//! read by Python bindings, WASM modules, and native Rust code.
5//!
6//! # Format
7//!
8//! Header: 64 bytes, little-endian
9//! - Magic: `b"SCIRS2\0\0"` (8 bytes)
10//! - Version: u16 major, u16 minor (4 bytes)
11//! - Payload type: u8 (1 byte): 0=array, 1=model, 2=stats, 3=custom
12//! - Compression: u8 (1 byte): 0=none, 1=lz4, 2=zstd
13//! - Checksum: u32 CRC32 of **uncompressed** payload (4 bytes)
14//! - Payload length: u64 (8 bytes) — bytes actually stored on disk
15//! - Reserved: 38 bytes of zeros
16//!
17//! Payload (variable length, may be compressed):
18//! - For arrays: `dtype(u8)` + `ndim(u8)` + `shape(u64 each, little-endian)` + raw element bytes (little-endian)
19//! - For models: JSON config + raw parameter bytes
20//! - For stats: key-value pairs with typed values
21//! - For custom: raw bytes (caller-defined format)
22//!
23//! # Example
24//!
25//! ```no_run
26//! use scirs2_core::serialization::{save_array, load_array, CompressionType};
27//! use ndarray::Array2;
28//! use std::path::Path;
29//!
30//! let data = Array2::<f32>::ones((3, 4)).into_dyn();
31//! let path = Path::new("/tmp/test.scirs2");
32//! save_array(path, &data, CompressionType::None).expect("should succeed");
33//! let loaded = load_array::<f32>(path).expect("should succeed");
34//! assert_eq!(data, loaded);
35//! ```
36
37use std::io::{self, BufReader, BufWriter, Read, Seek, SeekFrom, Write};
38use std::path::Path;
39
40use ndarray::{Array, IxDyn};
41
// File identification and format version written into every header.
const MAGIC: &[u8; 8] = b"SCIRS2\0\0";
const VERSION_MAJOR: u16 = 0;
const VERSION_MINOR: u16 = 3;
const HEADER_SIZE: usize = 64;

// Byte offsets within the 64-byte header. Each field starts where the previous one ends,
// so the layout is spelled out by the field widths (8, 2, 2, 1, 1, 4, 8).
const OFFSET_MAGIC: usize = 0;
const OFFSET_VERSION_MAJOR: usize = OFFSET_MAGIC + 8; // = 8
const OFFSET_VERSION_MINOR: usize = OFFSET_VERSION_MAJOR + 2; // = 10
const OFFSET_PAYLOAD_TYPE: usize = OFFSET_VERSION_MINOR + 2; // = 12
const OFFSET_COMPRESSION: usize = OFFSET_PAYLOAD_TYPE + 1; // = 13
const OFFSET_CHECKSUM: usize = OFFSET_COMPRESSION + 1; // = 14
const OFFSET_PAYLOAD_LENGTH: usize = OFFSET_CHECKSUM + 4; // = 18
// bytes 26..64 are reserved (38 bytes, must be zero)
56
57// ─── PayloadType ──────────────────────────────────────────────────────────────
58
59/// Identifies the kind of data stored in the `.scirs2` payload.
60#[derive(Debug, Clone, Copy, PartialEq, Eq)]
61#[repr(u8)]
62pub enum PayloadType {
63    /// N-dimensional array with dtype prefix and shape encoding.
64    Array = 0,
65    /// Model: JSON config followed by raw parameter bytes.
66    Model = 1,
67    /// Statistics: key-value pairs with typed values.
68    Stats = 2,
69    /// Custom: raw bytes with caller-defined semantics.
70    Custom = 3,
71}
72
73impl PayloadType {
74    fn from_u8(v: u8) -> Result<Self, SerializationError> {
75        match v {
76            0 => Ok(Self::Array),
77            1 => Ok(Self::Model),
78            2 => Ok(Self::Stats),
79            3 => Ok(Self::Custom),
80            other => Err(SerializationError::UnknownPayloadType(other)),
81        }
82    }
83}
84
85// ─── CompressionType ──────────────────────────────────────────────────────────
86
87/// Compression algorithm applied to the payload bytes.
88#[derive(Debug, Clone, Copy, PartialEq, Eq)]
89#[repr(u8)]
90pub enum CompressionType {
91    /// No compression; payload stored verbatim.
92    None = 0,
93    /// LZ4 frame compression — very fast, moderate ratio.
94    Lz4 = 1,
95    /// Zstandard compression — moderate speed, excellent ratio.
96    Zstd = 2,
97}
98
99impl CompressionType {
100    fn from_u8(v: u8) -> Result<Self, SerializationError> {
101        match v {
102            0 => Ok(Self::None),
103            1 => Ok(Self::Lz4),
104            2 => Ok(Self::Zstd),
105            other => Err(SerializationError::Compression(format!(
106                "unknown compression type byte: {}",
107                other
108            ))),
109        }
110    }
111}
112
113// ─── Header ───────────────────────────────────────────────────────────────────
114
115/// Parsed `.scirs2` file header (64 bytes).
116#[derive(Debug, Clone)]
117pub struct Scirs2Header {
118    /// `(major, minor)` format version.
119    pub version: (u16, u16),
120    /// Kind of data in the payload.
121    pub payload_type: PayloadType,
122    /// How the payload is compressed on disk.
123    pub compression: CompressionType,
124    /// CRC32 of the **uncompressed** payload.
125    pub checksum: u32,
126    /// Byte count stored on disk (after optional compression).
127    pub payload_length: u64,
128}
129
130impl Scirs2Header {
131    /// Serialize this header into a fixed 64-byte array.
132    fn to_bytes(&self) -> [u8; HEADER_SIZE] {
133        let mut buf = [0u8; HEADER_SIZE];
134        buf[OFFSET_MAGIC..OFFSET_MAGIC + 8].copy_from_slice(MAGIC);
135        buf[OFFSET_VERSION_MAJOR..OFFSET_VERSION_MAJOR + 2]
136            .copy_from_slice(&self.version.0.to_le_bytes());
137        buf[OFFSET_VERSION_MINOR..OFFSET_VERSION_MINOR + 2]
138            .copy_from_slice(&self.version.1.to_le_bytes());
139        buf[OFFSET_PAYLOAD_TYPE] = self.payload_type as u8;
140        buf[OFFSET_COMPRESSION] = self.compression as u8;
141        buf[OFFSET_CHECKSUM..OFFSET_CHECKSUM + 4].copy_from_slice(&self.checksum.to_le_bytes());
142        buf[OFFSET_PAYLOAD_LENGTH..OFFSET_PAYLOAD_LENGTH + 8]
143            .copy_from_slice(&self.payload_length.to_le_bytes());
144        // bytes 26..64 remain zero (reserved)
145        buf
146    }
147
148    /// Parse a 64-byte buffer into a `Scirs2Header`.
149    fn from_bytes(buf: &[u8; HEADER_SIZE]) -> Result<Self, SerializationError> {
150        // Validate magic
151        if &buf[OFFSET_MAGIC..OFFSET_MAGIC + 8] != MAGIC.as_slice() {
152            return Err(SerializationError::InvalidMagic);
153        }
154
155        let major = u16::from_le_bytes([buf[OFFSET_VERSION_MAJOR], buf[OFFSET_VERSION_MAJOR + 1]]);
156        let minor = u16::from_le_bytes([buf[OFFSET_VERSION_MINOR], buf[OFFSET_VERSION_MINOR + 1]]);
157
158        // Forward-compatibility: reject files written by a future major version
159        if major > VERSION_MAJOR {
160            return Err(SerializationError::UnsupportedVersion(major, minor));
161        }
162
163        let payload_type = PayloadType::from_u8(buf[OFFSET_PAYLOAD_TYPE])?;
164        let compression = CompressionType::from_u8(buf[OFFSET_COMPRESSION])?;
165
166        let checksum = u32::from_le_bytes([
167            buf[OFFSET_CHECKSUM],
168            buf[OFFSET_CHECKSUM + 1],
169            buf[OFFSET_CHECKSUM + 2],
170            buf[OFFSET_CHECKSUM + 3],
171        ]);
172
173        // SAFETY: OFFSET_PAYLOAD_LENGTH..OFFSET_PAYLOAD_LENGTH+8 is always within [0,64)
174        let pl_bytes: [u8; 8] = buf[OFFSET_PAYLOAD_LENGTH..OFFSET_PAYLOAD_LENGTH + 8]
175            .try_into()
176            .map_err(|_| {
177                SerializationError::Io(io::Error::new(
178                    io::ErrorKind::InvalidData,
179                    "internal: slice length invariant violated reading payload_length",
180                ))
181            })?;
182        let payload_length = u64::from_le_bytes(pl_bytes);
183
184        Ok(Self {
185            version: (major, minor),
186            payload_type,
187            compression,
188            checksum,
189            payload_length,
190        })
191    }
192}
193
194// ─── Scirs2Writer ─────────────────────────────────────────────────────────────
195
196/// Low-level writer for `.scirs2` files.
197///
198/// Writes exactly one payload per instance. For multiple payloads
199/// use separate writers.
200///
201/// # Example
202///
203/// ```no_run
204/// use scirs2_core::serialization::{Scirs2Writer, PayloadType, CompressionType};
205/// use std::fs::File;
206///
207/// let file = File::create("/tmp/out.scirs2").expect("should succeed");
208/// let mut writer = Scirs2Writer::new(file);
209/// writer
210///     .write_payload(PayloadType::Custom, b"hello scirs2", CompressionType::None)
211///     .expect("should succeed");
212/// ```
213pub struct Scirs2Writer<W: Write> {
214    inner: W,
215}
216
217impl<W: Write> Scirs2Writer<W> {
218    /// Wrap an existing [`Write`] implementor.
219    pub fn new(writer: W) -> Self {
220        Self { inner: writer }
221    }
222
223    /// Compress (if requested) and write a single payload to the underlying writer.
224    ///
225    /// The CRC32 checksum is computed over the **uncompressed** `payload` bytes.
226    /// The stored bytes (in the file, after the header) may be compressed.
227    pub fn write_payload(
228        &mut self,
229        payload_type: PayloadType,
230        payload: &[u8],
231        compression: CompressionType,
232    ) -> Result<(), SerializationError> {
233        let checksum = crc32fast::hash(payload);
234        let stored = compress_payload(payload, compression)?;
235
236        let header = Scirs2Header {
237            version: (VERSION_MAJOR, VERSION_MINOR),
238            payload_type,
239            compression,
240            checksum,
241            payload_length: stored.len() as u64,
242        };
243
244        self.inner.write_all(&header.to_bytes())?;
245        self.inner.write_all(&stored)?;
246        Ok(())
247    }
248}
249
250// ─── Scirs2Reader ─────────────────────────────────────────────────────────────
251
252/// Low-level reader for `.scirs2` files.
253///
254/// The header is parsed eagerly on construction; the payload bytes are read
255/// lazily on demand.
256///
257/// # Example
258///
259/// ```no_run
260/// use scirs2_core::serialization::Scirs2Reader;
261/// use std::fs::File;
262/// use std::io::BufReader;
263///
264/// let file = BufReader::new(File::open("/tmp/out.scirs2").expect("should succeed"));
265/// let mut reader = Scirs2Reader::new(file).expect("should succeed");
266/// println!("payload type = {:?}", reader.header.payload_type);
267/// let bytes = reader.read_payload().expect("should succeed");
268/// ```
269pub struct Scirs2Reader<R: Read + Seek> {
270    inner: R,
271    /// Header parsed from the beginning of the file.
272    pub header: Scirs2Header,
273}
274
275impl<R: Read + Seek> Scirs2Reader<R> {
276    /// Open a `.scirs2` reader, validating and parsing the header immediately.
277    ///
278    /// Returns [`SerializationError::InvalidMagic`] when the file is not a
279    /// valid `.scirs2` file, or [`SerializationError::UnsupportedVersion`] when
280    /// the format major version is newer than this library.
281    pub fn new(mut reader: R) -> Result<Self, SerializationError> {
282        let mut buf = [0u8; HEADER_SIZE];
283        reader.read_exact(&mut buf)?;
284        let header = Scirs2Header::from_bytes(&buf)?;
285        Ok(Self {
286            inner: reader,
287            header,
288        })
289    }
290
291    /// Read and decompress the payload, returning the raw (uncompressed) bytes.
292    ///
293    /// This method seeks back to the start of the payload each time it is
294    /// called, so repeated calls are safe.
295    pub fn read_payload(&mut self) -> Result<Vec<u8>, SerializationError> {
296        self.inner.seek(SeekFrom::Start(HEADER_SIZE as u64))?;
297
298        let len = self.header.payload_length as usize;
299        let mut stored = vec![0u8; len];
300        self.inner.read_exact(&mut stored)?;
301
302        decompress_payload(&stored, self.header.compression, len)
303    }
304
305    /// Read the payload and verify its CRC32 against the header checksum.
306    ///
307    /// Returns `Ok(true)` if the checksum matches, `Ok(false)` otherwise.
308    /// Returns `Err` on I/O or decompression failure.
309    pub fn verify_checksum(&mut self) -> Result<bool, SerializationError> {
310        let payload = self.read_payload()?;
311        let computed = crc32fast::hash(&payload);
312        Ok(computed == self.header.checksum)
313    }
314}
315
316// ─── Compression helpers ──────────────────────────────────────────────────────
317
318/// Compress `data` using the requested algorithm.
319fn compress_payload(
320    data: &[u8],
321    compression: CompressionType,
322) -> Result<Vec<u8>, SerializationError> {
323    match compression {
324        CompressionType::None => Ok(data.to_vec()),
325
326        CompressionType::Lz4 => {
327            #[cfg(feature = "serialization")]
328            {
329                oxiarc_lz4::compress(data)
330                    .map_err(|e| SerializationError::Compression(format!("LZ4 compress: {}", e)))
331            }
332            #[cfg(not(feature = "serialization"))]
333            {
334                let _ = data;
335                Err(SerializationError::Compression(
336                    "LZ4 compression requires the `serialization` feature".to_string(),
337                ))
338            }
339        }
340
341        CompressionType::Zstd => {
342            #[cfg(feature = "serialization")]
343            {
344                oxiarc_zstd::compress(data)
345                    .map_err(|e| SerializationError::Compression(format!("Zstd compress: {}", e)))
346            }
347            #[cfg(not(feature = "serialization"))]
348            {
349                let _ = data;
350                Err(SerializationError::Compression(
351                    "Zstd compression requires the `serialization` feature".to_string(),
352                ))
353            }
354        }
355    }
356}
357
358/// Decompress `data` using the stored compression type.
359///
360/// `stored_len` is used as a hint for decompressors that require an output
361/// size hint (LZ4 frame decompressor uses `stored_len * 4` as the upper bound).
362fn decompress_payload(
363    data: &[u8],
364    compression: CompressionType,
365    stored_len: usize,
366) -> Result<Vec<u8>, SerializationError> {
367    match compression {
368        CompressionType::None => Ok(data.to_vec()),
369
370        CompressionType::Lz4 => {
371            #[cfg(feature = "serialization")]
372            {
373                // LZ4 frame decompression: use 4× the stored size as an upper bound.
374                // For highly compressible data this might need to be larger; in practice
375                // scientific arrays have at most ~8× decompression ratios.
376                let max_output = stored_len.saturating_mul(8).max(4096);
377                oxiarc_lz4::decompress(data, max_output)
378                    .map_err(|e| SerializationError::Compression(format!("LZ4 decompress: {}", e)))
379            }
380            #[cfg(not(feature = "serialization"))]
381            {
382                let _ = (data, stored_len);
383                Err(SerializationError::Compression(
384                    "LZ4 decompression requires the `serialization` feature".to_string(),
385                ))
386            }
387        }
388
389        CompressionType::Zstd => {
390            #[cfg(feature = "serialization")]
391            {
392                let _ = stored_len;
393                oxiarc_zstd::decompress(data)
394                    .map_err(|e| SerializationError::Compression(format!("Zstd decompress: {}", e)))
395            }
396            #[cfg(not(feature = "serialization"))]
397            {
398                let _ = (data, stored_len);
399                Err(SerializationError::Compression(
400                    "Zstd decompression requires the `serialization` feature".to_string(),
401                ))
402            }
403        }
404    }
405}
406
407// ─── ArrayElement trait ───────────────────────────────────────────────────────
408
/// Element type supported by the `.scirs2` array payload.
///
/// Each concrete type carries a stable 1-byte `dtype_id` embedded in the file,
/// enabling typed deserialization and cross-language interoperability.
///
/// # Dtype IDs
///
/// | ID | Type |
/// |----|------|
/// | 1  | f32  |
/// | 2  | f64  |
/// | 3  | i32  |
/// | 4  | i64  |
/// | 5  | u32  |
/// | 6  | u64  |
pub trait ArrayElement: Copy + 'static {
    /// Stable 1-byte dtype identifier embedded in the binary format.
    fn dtype_id() -> u8;
    /// Size in bytes of one element.
    fn element_size() -> usize;
    /// Deserialize up to `n` elements from a little-endian byte slice.
    ///
    /// Bytes are consumed in `element_size()`-sized chunks. If `bytes` holds fewer than
    /// `n` complete elements, only the complete ones are returned, and a trailing
    /// partial element is ignored.
    fn from_le_bytes_slice(bytes: &[u8], n: usize) -> Vec<Self>;
    /// Serialize a slice of elements to little-endian bytes.
    fn to_le_bytes_vec(slice: &[Self]) -> Vec<u8>;
}

/// Implement `ArrayElement` for a primitive numeric type with the given dtype id.
///
/// The element size is taken from `size_of::<$ty>()`, so it cannot drift from the type.
macro_rules! impl_array_element {
    ($ty:ty, $id:expr) => {
        impl ArrayElement for $ty {
            fn dtype_id() -> u8 {
                $id
            }

            fn element_size() -> usize {
                ::std::mem::size_of::<$ty>()
            }

            fn from_le_bytes_slice(bytes: &[u8], n: usize) -> Vec<Self> {
                const SIZE: usize = ::std::mem::size_of::<$ty>();
                // `chunks_exact` guarantees every chunk is exactly SIZE bytes, so the copy
                // below cannot fail, and a short input yields fewer elements, not a panic.
                bytes
                    .chunks_exact(SIZE)
                    .take(n)
                    .map(|chunk| {
                        let mut raw = [0u8; SIZE];
                        raw.copy_from_slice(chunk);
                        <$ty>::from_le_bytes(raw)
                    })
                    .collect()
            }

            fn to_le_bytes_vec(slice: &[Self]) -> Vec<u8> {
                let mut out = Vec::with_capacity(slice.len() * ::std::mem::size_of::<$ty>());
                for value in slice {
                    out.extend_from_slice(&value.to_le_bytes());
                }
                out
            }
        }
    };
}

impl_array_element!(f32, 1);
impl_array_element!(f64, 2);
impl_array_element!(i32, 3);
impl_array_element!(i64, 4);
impl_array_element!(u32, 5);
impl_array_element!(u64, 6);
471
472// ─── Array encoding / decoding ────────────────────────────────────────────────
473
474/// Encode an ndarray into the `.scirs2` array payload format.
475///
476/// Layout: `dtype_id(u8)` | `ndim(u8)` | `dim_0(u64le)` | … | `dim_{n-1}(u64le)` | `data(le bytes)`
477fn encode_array<F: ArrayElement>(array: &Array<F, IxDyn>) -> Vec<u8> {
478    let shape = array.shape();
479    let ndim = shape.len();
480
481    let header_bytes = 2 + ndim * 8;
482    let data_bytes = array.len() * F::element_size();
483    let mut buf = Vec::with_capacity(header_bytes + data_bytes);
484
485    buf.push(F::dtype_id());
486    buf.push(ndim as u8);
487
488    for &dim in shape {
489        buf.extend_from_slice(&(dim as u64).to_le_bytes());
490    }
491
492    // Collect in C-contiguous (row-major) iteration order
493    let data: Vec<F> = array.iter().copied().collect();
494    buf.extend_from_slice(&F::to_le_bytes_vec(&data));
495
496    buf
497}
498
499/// Decode an ndarray from the `.scirs2` array payload format.
500fn decode_array<F: ArrayElement>(payload: &[u8]) -> Result<Array<F, IxDyn>, SerializationError> {
501    if payload.len() < 2 {
502        return Err(SerializationError::Io(io::Error::new(
503            io::ErrorKind::UnexpectedEof,
504            "payload too short to contain array header (need at least 2 bytes)",
505        )));
506    }
507
508    let actual_dtype = payload[0];
509    let expected_dtype = F::dtype_id();
510    if actual_dtype != expected_dtype {
511        return Err(SerializationError::TypeMismatch {
512            expected: expected_dtype,
513            actual: actual_dtype,
514        });
515    }
516
517    let ndim = payload[1] as usize;
518    let shape_end = 2 + ndim * 8;
519
520    if payload.len() < shape_end {
521        return Err(SerializationError::Io(io::Error::new(
522            io::ErrorKind::UnexpectedEof,
523            format!(
524                "payload too short to read shape: need {} bytes for {} dims, have {}",
525                shape_end,
526                ndim,
527                payload.len()
528            ),
529        )));
530    }
531
532    let mut shape = Vec::with_capacity(ndim);
533    for i in 0..ndim {
534        let offset = 2 + i * 8;
535        let dim_bytes: [u8; 8] = payload[offset..offset + 8].try_into().map_err(|_| {
536            SerializationError::Io(io::Error::new(
537                io::ErrorKind::InvalidData,
538                format!("internal: failed to read dim {} from payload", i),
539            ))
540        })?;
541        shape.push(u64::from_le_bytes(dim_bytes) as usize);
542    }
543
544    let n_elements: usize = shape.iter().product();
545    let data_bytes = n_elements * F::element_size();
546
547    if payload.len() < shape_end + data_bytes {
548        return Err(SerializationError::Io(io::Error::new(
549            io::ErrorKind::UnexpectedEof,
550            format!(
551                "payload too short for array data: need {} bytes, have {}",
552                shape_end + data_bytes,
553                payload.len()
554            ),
555        )));
556    }
557
558    let elements = F::from_le_bytes_slice(&payload[shape_end..shape_end + data_bytes], n_elements);
559
560    Array::from_shape_vec(IxDyn(&shape), elements).map_err(|e| {
561        SerializationError::Io(io::Error::new(
562            io::ErrorKind::InvalidData,
563            format!("shape/data mismatch during array reconstruction: {}", e),
564        ))
565    })
566}
567
568// ─── Public convenience API ───────────────────────────────────────────────────
569
570/// Save an n-dimensional array to a `.scirs2` file.
571///
572/// The file is created (or truncated) at `path`. The element type `F` is
573/// embedded in the payload so [`load_array`] can verify type safety on load.
574///
575/// # Arguments
576///
577/// * `path` — Destination file path.
578/// * `array` — The array to serialize.
579/// * `compression` — Compression algorithm applied to the payload.
580///
581/// # Errors
582///
583/// Returns [`SerializationError`] on I/O failures or if the chosen compression
584/// algorithm is unavailable in this build.
585///
586/// # Example
587///
588/// ```no_run
589/// use scirs2_core::serialization::{save_array, CompressionType};
590/// use ndarray::Array2;
591///
592/// let data = Array2::<f64>::eye(4).into_dyn();
593/// save_array(std::path::Path::new("/tmp/eye4.scirs2"), &data, CompressionType::None).expect("should succeed");
594/// ```
595pub fn save_array<F: ArrayElement>(
596    path: &Path,
597    array: &Array<F, IxDyn>,
598    compression: CompressionType,
599) -> Result<(), SerializationError> {
600    let file = std::fs::File::create(path)?;
601    let writer = BufWriter::new(file);
602    let mut scirs2 = Scirs2Writer::new(writer);
603    let payload = encode_array(array);
604    scirs2.write_payload(PayloadType::Array, &payload, compression)
605}
606
607/// Load an n-dimensional array from a `.scirs2` file.
608///
609/// The element type `F` is checked against the dtype stored in the file;
610/// a [`SerializationError::TypeMismatch`] is returned if they differ.
611///
612/// # Example
613///
614/// ```no_run
615/// use scirs2_core::serialization::load_array;
616///
617/// let arr = load_array::<f64>(std::path::Path::new("/tmp/eye4.scirs2")).expect("should succeed");
618/// println!("shape: {:?}", arr.shape());
619/// ```
620pub fn load_array<F: ArrayElement>(path: &Path) -> Result<Array<F, IxDyn>, SerializationError> {
621    let file = std::fs::File::open(path)?;
622    let reader = BufReader::new(file);
623    let mut scirs2 = Scirs2Reader::new(reader)?;
624
625    if scirs2.header.payload_type != PayloadType::Array {
626        return Err(SerializationError::Io(io::Error::new(
627            io::ErrorKind::InvalidData,
628            format!(
629                "expected Array payload type (0), found {:?} ({})",
630                scirs2.header.payload_type, scirs2.header.payload_type as u8
631            ),
632        )));
633    }
634
635    let payload = scirs2.read_payload()?;
636    decode_array::<F>(&payload)
637}
638
639// ─── Error type ───────────────────────────────────────────────────────────────
640
641/// Errors that can occur during `.scirs2` serialization or deserialization.
642#[derive(Debug, thiserror::Error)]
643pub enum SerializationError {
644    /// Underlying I/O failure (file not found, permission denied, etc.).
645    #[error("IO error: {0}")]
646    Io(#[from] std::io::Error),
647
648    /// The file does not start with the expected `b"SCIRS2\0\0"` magic bytes.
649    #[error("Invalid magic bytes — not a valid .scirs2 file")]
650    InvalidMagic,
651
652    /// The file was written by a newer major version of this library.
653    #[error(
654        "Unsupported version {0}.{1} (this library supports up to {major}.x)",
655        major = VERSION_MAJOR
656    )]
657    UnsupportedVersion(u16, u16),
658
659    /// CRC32 of the decompressed payload does not match the stored checksum.
660    #[error("Checksum mismatch — file may be corrupted")]
661    ChecksumMismatch,
662
663    /// A compression or decompression operation failed.
664    #[error("Compression error: {0}")]
665    Compression(String),
666
667    /// The payload type byte is not one of the defined [`PayloadType`] variants.
668    #[error("Unknown payload type: {0}")]
669    UnknownPayloadType(u8),
670
671    /// The dtype stored in the file differs from the type `F` requested by the caller.
672    #[error("Type mismatch: expected dtype {expected}, found {actual}")]
673    TypeMismatch {
674        /// Dtype ID the caller requested.
675        expected: u8,
676        /// Dtype ID found in the file.
677        actual: u8,
678    },
679}
680
681// ─── Tests ────────────────────────────────────────────────────────────────────
682
683#[cfg(test)]
684mod tests {
685    use super::*;
686    use ndarray::{Array1, Array2, Array3};
687    use std::io::Cursor;
688
689    // ── header roundtrip ──────────────────────────────────────────────────────
690
691    #[test]
692    fn test_header_roundtrip_all_fields() {
693        let original = Scirs2Header {
694            version: (0, 3),
695            payload_type: PayloadType::Array,
696            compression: CompressionType::None,
697            checksum: 0xDEAD_BEEF,
698            payload_length: 1_234_567_890,
699        };
700        let bytes = original.to_bytes();
701        assert_eq!(bytes.len(), HEADER_SIZE);
702
703        let parsed = Scirs2Header::from_bytes(&bytes).expect("header parse failed");
704        assert_eq!(parsed.version, original.version);
705        assert_eq!(parsed.payload_type, original.payload_type);
706        assert_eq!(parsed.compression, original.compression);
707        assert_eq!(parsed.checksum, original.checksum);
708        assert_eq!(parsed.payload_length, original.payload_length);
709    }
710
711    #[test]
712    fn test_header_reserved_bytes_are_zero() {
713        let header = Scirs2Header {
714            version: (0, 3),
715            payload_type: PayloadType::Custom,
716            compression: CompressionType::None,
717            checksum: 42,
718            payload_length: 8,
719        };
720        let bytes = header.to_bytes();
721        // Reserved bytes: 26..64
722        for i in 26..64 {
723            assert_eq!(bytes[i], 0, "reserved byte {} should be zero", i);
724        }
725    }
726
727    #[test]
728    fn test_invalid_magic_rejected() {
729        let mut buf = [0u8; HEADER_SIZE];
730        buf[0..8].copy_from_slice(b"BADMAGIC");
731        assert!(
732            matches!(
733                Scirs2Header::from_bytes(&buf),
734                Err(SerializationError::InvalidMagic)
735            ),
736            "should reject non-SCIRS2 magic"
737        );
738    }
739
740    #[test]
741    fn test_future_major_version_rejected() {
742        let header = Scirs2Header {
743            version: (255, 0),
744            payload_type: PayloadType::Custom,
745            compression: CompressionType::None,
746            checksum: 0,
747            payload_length: 0,
748        };
749        let bytes = header.to_bytes();
750        assert!(
751            matches!(
752                Scirs2Header::from_bytes(&bytes),
753                Err(SerializationError::UnsupportedVersion(255, 0))
754            ),
755            "should reject future major version"
756        );
757    }
758
759    #[test]
760    fn test_unknown_payload_type_rejected() {
761        let header = Scirs2Header {
762            version: (0, 3),
763            payload_type: PayloadType::Custom,
764            compression: CompressionType::None,
765            checksum: 0,
766            payload_length: 0,
767        };
768        let mut bytes = header.to_bytes();
769        bytes[OFFSET_PAYLOAD_TYPE] = 99; // unknown type — not in PayloadType enum
770        let result = Scirs2Header::from_bytes(&bytes);
771        assert!(
772            matches!(result, Err(SerializationError::UnknownPayloadType(99))),
773            "should return UnknownPayloadType(99) for unknown payload type byte"
774        );
775    }
776
777    #[test]
778    fn test_payload_type_from_u8_all_variants() {
779        assert!(matches!(PayloadType::from_u8(0), Ok(PayloadType::Array)));
780        assert!(matches!(PayloadType::from_u8(1), Ok(PayloadType::Model)));
781        assert!(matches!(PayloadType::from_u8(2), Ok(PayloadType::Stats)));
782        assert!(matches!(PayloadType::from_u8(3), Ok(PayloadType::Custom)));
783        assert!(matches!(
784            PayloadType::from_u8(4),
785            Err(SerializationError::UnknownPayloadType(4))
786        ));
787    }
788
789    // ── writer / reader roundtrip ─────────────────────────────────────────────
790
791    #[test]
792    fn test_custom_payload_no_compression_roundtrip() {
793        let payload = b"the quick brown fox jumps over the lazy dog";
794        let mut buf = Vec::new();
795        {
796            let mut writer = Scirs2Writer::new(&mut buf);
797            writer
798                .write_payload(PayloadType::Custom, payload, CompressionType::None)
799                .expect("write_payload failed");
800        }
801
802        let cursor = Cursor::new(&buf);
803        let mut reader = Scirs2Reader::new(cursor).expect("Scirs2Reader::new failed");
804        assert_eq!(reader.header.payload_type, PayloadType::Custom);
805        assert_eq!(reader.header.compression, CompressionType::None);
806        assert_eq!(reader.header.payload_length, payload.len() as u64);
807
808        let out = reader.read_payload().expect("read_payload failed");
809        assert_eq!(out.as_slice(), payload.as_slice());
810    }
811
812    #[test]
813    fn test_empty_payload_roundtrip() {
814        let payload: &[u8] = b"";
815        let mut buf = Vec::new();
816        {
817            let mut writer = Scirs2Writer::new(&mut buf);
818            writer
819                .write_payload(PayloadType::Stats, payload, CompressionType::None)
820                .expect("write empty payload failed");
821        }
822        let cursor = Cursor::new(&buf);
823        let mut reader = Scirs2Reader::new(cursor).expect("reader init failed");
824        let out = reader.read_payload().expect("read empty payload failed");
825        assert!(out.is_empty());
826    }
827
828    #[test]
829    fn test_verify_checksum_passes_for_intact_data() {
830        let payload = b"integrity check payload 0xDEADBEEF";
831        let mut buf = Vec::new();
832        {
833            let mut writer = Scirs2Writer::new(&mut buf);
834            writer
835                .write_payload(PayloadType::Stats, payload, CompressionType::None)
836                .expect("write failed");
837        }
838        let cursor = Cursor::new(&buf);
839        let mut reader = Scirs2Reader::new(cursor).expect("reader init failed");
840        assert!(
841            reader.verify_checksum().expect("checksum check failed"),
842            "checksum should pass for intact data"
843        );
844    }
845
846    #[test]
847    fn test_verify_checksum_fails_on_bit_flip() {
848        let payload = b"data that will be corrupted in transit";
849        let mut buf = Vec::new();
850        {
851            let mut writer = Scirs2Writer::new(&mut buf);
852            writer
853                .write_payload(PayloadType::Custom, payload, CompressionType::None)
854                .expect("write failed");
855        }
856
857        // Flip the last byte of the payload section
858        let last = buf.len() - 1;
859        buf[last] ^= 0xFF;
860
861        let cursor = Cursor::new(&buf);
862        let mut reader = Scirs2Reader::new(cursor).expect("reader init (corrupted) failed");
863        assert!(
864            !reader.verify_checksum().expect("checksum check errored"),
865            "checksum should fail after bit flip"
866        );
867    }
868
869    #[test]
870    fn test_version_fields_in_file() {
871        let payload = b"version test";
872        let mut buf = Vec::new();
873        let mut writer = Scirs2Writer::new(&mut buf);
874        writer
875            .write_payload(PayloadType::Custom, payload, CompressionType::None)
876            .expect("write failed");
877
878        let cursor = Cursor::new(&buf);
879        let reader = Scirs2Reader::new(cursor).expect("reader failed");
880        assert_eq!(reader.header.version, (VERSION_MAJOR, VERSION_MINOR));
881    }
882
883    // ── array encode / decode (in-memory) ────────────────────────────────────
884
885    #[test]
886    fn test_encode_decode_f32_1d() {
887        let original =
888            Array1::<f32>::from_vec(vec![1.0, 2.5, -3.0, f32::MAX, f32::MIN_POSITIVE]).into_dyn();
889        let encoded = encode_array(&original);
890        let decoded = decode_array::<f32>(&encoded).expect("f32 decode failed");
891        assert_eq!(original, decoded, "f32 1d roundtrip mismatch");
892    }
893
894    #[test]
895    fn test_encode_decode_f64_2d() {
896        let original = Array2::<f64>::from_shape_vec(
897            (4, 6),
898            (0..24)
899                .map(|i| i as f64 * std::f64::consts::PI / 12.0)
900                .collect(),
901        )
902        .expect("shape error")
903        .into_dyn();
904
905        let encoded = encode_array(&original);
906        let decoded = decode_array::<f64>(&encoded).expect("f64 2d decode failed");
907        assert_eq!(original, decoded, "f64 2d roundtrip mismatch");
908    }
909
910    #[test]
911    fn test_encode_decode_i32_3d() {
912        let original =
913            Array3::<i32>::from_shape_vec((2, 3, 4), (0..24).map(|i| i as i32 - 12).collect())
914                .expect("shape error")
915                .into_dyn();
916
917        let encoded = encode_array(&original);
918        let decoded = decode_array::<i32>(&encoded).expect("i32 3d decode failed");
919        assert_eq!(original, decoded, "i32 3d roundtrip mismatch");
920    }
921
922    #[test]
923    fn test_encode_decode_i64_1d() {
924        let original = Array1::<i64>::from_vec(vec![i64::MIN, -1, 0, 1, i64::MAX]).into_dyn();
925        let encoded = encode_array(&original);
926        let decoded = decode_array::<i64>(&encoded).expect("i64 decode failed");
927        assert_eq!(original, decoded, "i64 roundtrip mismatch");
928    }
929
930    #[test]
931    fn test_encode_decode_u32() {
932        let original = Array1::<u32>::from_vec(vec![0, 1, u32::MAX / 2, u32::MAX]).into_dyn();
933        let encoded = encode_array(&original);
934        let decoded = decode_array::<u32>(&encoded).expect("u32 decode failed");
935        assert_eq!(original, decoded, "u32 roundtrip mismatch");
936    }
937
938    #[test]
939    fn test_encode_decode_u64() {
940        let original = Array1::<u64>::from_vec(vec![0, 1, u64::MAX / 2, u64::MAX]).into_dyn();
941        let encoded = encode_array(&original);
942        let decoded = decode_array::<u64>(&encoded).expect("u64 decode failed");
943        assert_eq!(original, decoded, "u64 roundtrip mismatch");
944    }
945
946    #[test]
947    fn test_dtype_mismatch_error() {
948        let original = Array1::<f32>::from_vec(vec![1.0, 2.0, 3.0]).into_dyn();
949        let encoded = encode_array(&original); // dtype_id = 1 (f32)
950                                               // Try to decode as f64 (dtype_id = 2)
951        let result = decode_array::<f64>(&encoded);
952        assert!(
953            matches!(
954                result,
955                Err(SerializationError::TypeMismatch {
956                    expected: 2,
957                    actual: 1
958                })
959            ),
960            "expected TypeMismatch error"
961        );
962    }
963
964    #[test]
965    fn test_encode_zero_dimensional_array() {
966        // 0-dimensional array (scalar)
967        let original = Array::<f64, IxDyn>::from_elem(IxDyn(&[]), 42.0);
968        let encoded = encode_array(&original);
969        let decoded = decode_array::<f64>(&encoded).expect("0d decode failed");
970        assert_eq!(original, decoded, "0d array roundtrip mismatch");
971    }
972
973    // ── save_array / load_array (file I/O) ────────────────────────────────────
974
975    #[test]
976    fn test_save_load_f32_no_compression() {
977        let tmp = std::env::temp_dir().join("scirs2_test_f32_nocomp.scirs2");
978        let original =
979            Array2::<f32>::from_shape_vec((8, 8), (0..64).map(|i| i as f32 * 0.5 - 16.0).collect())
980                .expect("shape error")
981                .into_dyn();
982
983        save_array(&tmp, &original, CompressionType::None).expect("save_array failed");
984        let loaded = load_array::<f32>(&tmp).expect("load_array failed");
985
986        assert_eq!(original, loaded, "f32 save/load mismatch");
987        std::fs::remove_file(&tmp).ok();
988    }
989
990    #[test]
991    fn test_save_load_f64_no_compression() {
992        let tmp = std::env::temp_dir().join("scirs2_test_f64_nocomp.scirs2");
993        let original = Array1::<f64>::linspace(0.0, 1.0, 500).into_dyn();
994
995        save_array(&tmp, &original, CompressionType::None).expect("save_array f64 failed");
996        let loaded = load_array::<f64>(&tmp).expect("load_array f64 failed");
997
998        assert_eq!(original, loaded, "f64 save/load mismatch");
999        std::fs::remove_file(&tmp).ok();
1000    }
1001
1002    #[test]
1003    fn test_save_load_empty_array() {
1004        let tmp = std::env::temp_dir().join("scirs2_test_empty.scirs2");
1005        let original = Array1::<f64>::from_vec(vec![]).into_dyn();
1006
1007        save_array(&tmp, &original, CompressionType::None).expect("save empty failed");
1008        let loaded = load_array::<f64>(&tmp).expect("load empty failed");
1009
1010        assert_eq!(original, loaded, "empty array roundtrip mismatch");
1011        std::fs::remove_file(&tmp).ok();
1012    }
1013
1014    #[test]
1015    fn test_save_load_large_f64_array() {
1016        let tmp = std::env::temp_dir().join("scirs2_test_large_f64.scirs2");
1017        let n = 100_000usize;
1018        let original =
1019            Array1::<f64>::from_iter((0..n).map(|i| (i as f64 / n as f64).sin())).into_dyn();
1020
1021        save_array(&tmp, &original, CompressionType::None).expect("save large failed");
1022        let loaded = load_array::<f64>(&tmp).expect("load large failed");
1023
1024        assert_eq!(
1025            original.shape(),
1026            loaded.shape(),
1027            "shape mismatch for large array"
1028        );
1029        for (a, b) in original.iter().zip(loaded.iter()) {
1030            assert_eq!(
1031                a.to_bits(),
1032                b.to_bits(),
1033                "element mismatch in large array (bit-exact)"
1034            );
1035        }
1036        std::fs::remove_file(&tmp).ok();
1037    }
1038
1039    #[test]
1040    fn test_save_load_3d_i32_array() {
1041        let tmp = std::env::temp_dir().join("scirs2_test_3d_i32.scirs2");
1042        let original =
1043            Array3::<i32>::from_shape_fn((5, 6, 7), |(i, j, k)| (i * 100 + j * 10 + k) as i32)
1044                .into_dyn();
1045
1046        save_array(&tmp, &original, CompressionType::None).expect("save 3d i32 failed");
1047        let loaded = load_array::<i32>(&tmp).expect("load 3d i32 failed");
1048
1049        assert_eq!(original, loaded, "3d i32 save/load mismatch");
1050        std::fs::remove_file(&tmp).ok();
1051    }
1052
1053    #[test]
1054    fn test_wrong_payload_type_error() {
1055        let tmp = std::env::temp_dir().join("scirs2_test_wrong_type.scirs2");
1056        {
1057            let file = std::fs::File::create(&tmp).expect("create failed");
1058            let mut writer = Scirs2Writer::new(BufWriter::new(file));
1059            writer
1060                .write_payload(
1061                    PayloadType::Custom,
1062                    b"definitely not an array",
1063                    CompressionType::None,
1064                )
1065                .expect("write failed");
1066        }
1067        let result = load_array::<f64>(&tmp);
1068        assert!(
1069            result.is_err(),
1070            "load_array should fail when payload type is not Array"
1071        );
1072        std::fs::remove_file(&tmp).ok();
1073    }
1074
1075    #[test]
1076    fn test_file_not_found_error() {
1077        let result = load_array::<f32>(Path::new("/nonexistent/path/does_not_exist.scirs2"));
1078        assert!(
1079            matches!(result, Err(SerializationError::Io(_))),
1080            "should return Io error for missing file"
1081        );
1082    }
1083
1084    #[test]
1085    fn test_checksum_is_stored_in_file() {
1086        let tmp = std::env::temp_dir().join("scirs2_test_checksum_stored.scirs2");
1087        let original = Array1::<f64>::from_vec(vec![1.0, 2.0, 3.0]).into_dyn();
1088        save_array(&tmp, &original, CompressionType::None).expect("save failed");
1089
1090        // Read and verify through Scirs2Reader
1091        let file = std::fs::File::open(&tmp).expect("open failed");
1092        let mut reader = Scirs2Reader::new(BufReader::new(file)).expect("reader failed");
1093        let ok = reader.verify_checksum().expect("checksum check failed");
1094        assert!(ok, "checksum should pass for freshly saved file");
1095
1096        std::fs::remove_file(&tmp).ok();
1097    }
1098
1099    #[test]
1100    fn test_lz4_compression_roundtrip() {
1101        let tmp = std::env::temp_dir().join("scirs2_test_lz4.scirs2");
1102        // Highly compressible data: a constant array
1103        let original = Array1::<f32>::from_elem(1000, 1.23456_f32).into_dyn();
1104
1105        let result = save_array(&tmp, &original, CompressionType::Lz4);
1106        match result {
1107            Ok(()) => {
1108                let loaded = load_array::<f32>(&tmp).expect("load lz4 failed");
1109                assert_eq!(original, loaded, "lz4 roundtrip mismatch");
1110            }
1111            Err(SerializationError::Compression(_)) => {
1112                // LZ4 not available in this configuration — skip
1113                eprintln!("LZ4 not available, skipping lz4 test");
1114            }
1115            Err(e) => panic!("unexpected error during lz4 test: {}", e),
1116        }
1117        std::fs::remove_file(&tmp).ok();
1118    }
1119
1120    #[test]
1121    fn test_zstd_compression_roundtrip() {
1122        let tmp = std::env::temp_dir().join("scirs2_test_zstd.scirs2");
1123        // Highly compressible data
1124        let original = Array2::<f64>::zeros((100, 100)).into_dyn();
1125
1126        let result = save_array(&tmp, &original, CompressionType::Zstd);
1127        match result {
1128            Ok(()) => {
1129                let loaded = load_array::<f64>(&tmp).expect("load zstd failed");
1130                assert_eq!(original, loaded, "zstd roundtrip mismatch");
1131            }
1132            Err(SerializationError::Compression(_)) => {
1133                eprintln!("Zstd not available, skipping zstd test");
1134            }
1135            Err(e) => panic!("unexpected error during zstd test: {}", e),
1136        }
1137        std::fs::remove_file(&tmp).ok();
1138    }
1139}