1use std::io::{self, BufReader, BufWriter, Read, Seek, SeekFrom, Write};
38use std::path::Path;
39
40use ndarray::{Array, IxDyn};
41
/// Eight-byte signature that opens every `.scirs2` container.
const MAGIC: &[u8; 8] = b"SCIRS2\0\0";
/// Major format version written by this library; headers with a larger major are rejected.
const VERSION_MAJOR: u16 = 0;
/// Minor format version written by this library.
const VERSION_MINOR: u16 = 3;
/// Fixed size of the on-disk header in bytes; the payload begins immediately after it.
const HEADER_SIZE: usize = 64;

// Header field offsets. Each one is the previous field's offset plus that
// field's width, so the layout reads top to bottom:
//   magic (8) | major (u16 LE) | minor (u16 LE) | payload type (u8) |
//   compression (u8) | checksum (u32 LE) | payload length (u64 LE) | zero padding
const OFFSET_MAGIC: usize = 0;
const OFFSET_VERSION_MAJOR: usize = OFFSET_MAGIC + 8; // 8
const OFFSET_VERSION_MINOR: usize = OFFSET_VERSION_MAJOR + 2; // 10
const OFFSET_PAYLOAD_TYPE: usize = OFFSET_VERSION_MINOR + 2; // 12
const OFFSET_COMPRESSION: usize = OFFSET_PAYLOAD_TYPE + 1; // 13
const OFFSET_CHECKSUM: usize = OFFSET_COMPRESSION + 1; // 14
const OFFSET_PAYLOAD_LENGTH: usize = OFFSET_CHECKSUM + 4; // 18
/// Kind of data carried by a `.scirs2` payload, stored as a single header byte.
///
/// The explicit discriminants are the on-disk encoding and must stay stable.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PayloadType {
    /// An encoded n-dimensional array; the only kind decoded by `load_array`.
    Array = 0,
    /// Tag for model payloads; this module stores the bytes without interpreting them.
    Model = 1,
    /// Tag for statistics payloads; this module stores the bytes without interpreting them.
    Stats = 2,
    /// Application-defined bytes.
    Custom = 3,
}
72
73impl PayloadType {
74 fn from_u8(v: u8) -> Result<Self, SerializationError> {
75 match v {
76 0 => Ok(Self::Array),
77 1 => Ok(Self::Model),
78 2 => Ok(Self::Stats),
79 3 => Ok(Self::Custom),
80 other => Err(SerializationError::UnknownPayloadType(other)),
81 }
82 }
83}
84
/// Codec applied to the payload bytes, stored as a single header byte.
///
/// The explicit discriminants are the on-disk encoding and must stay stable.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionType {
    /// Payload stored verbatim.
    None = 0,
    /// LZ4 compression (requires the `serialization` feature).
    Lz4 = 1,
    /// Zstandard compression (requires the `serialization` feature).
    Zstd = 2,
}
98
99impl CompressionType {
100 fn from_u8(v: u8) -> Result<Self, SerializationError> {
101 match v {
102 0 => Ok(Self::None),
103 1 => Ok(Self::Lz4),
104 2 => Ok(Self::Zstd),
105 other => Err(SerializationError::Compression(format!(
106 "unknown compression type byte: {}",
107 other
108 ))),
109 }
110 }
111}
112
113#[derive(Debug, Clone)]
117pub struct Scirs2Header {
118 pub version: (u16, u16),
120 pub payload_type: PayloadType,
122 pub compression: CompressionType,
124 pub checksum: u32,
126 pub payload_length: u64,
128}
129
130impl Scirs2Header {
131 fn to_bytes(&self) -> [u8; HEADER_SIZE] {
133 let mut buf = [0u8; HEADER_SIZE];
134 buf[OFFSET_MAGIC..OFFSET_MAGIC + 8].copy_from_slice(MAGIC);
135 buf[OFFSET_VERSION_MAJOR..OFFSET_VERSION_MAJOR + 2]
136 .copy_from_slice(&self.version.0.to_le_bytes());
137 buf[OFFSET_VERSION_MINOR..OFFSET_VERSION_MINOR + 2]
138 .copy_from_slice(&self.version.1.to_le_bytes());
139 buf[OFFSET_PAYLOAD_TYPE] = self.payload_type as u8;
140 buf[OFFSET_COMPRESSION] = self.compression as u8;
141 buf[OFFSET_CHECKSUM..OFFSET_CHECKSUM + 4].copy_from_slice(&self.checksum.to_le_bytes());
142 buf[OFFSET_PAYLOAD_LENGTH..OFFSET_PAYLOAD_LENGTH + 8]
143 .copy_from_slice(&self.payload_length.to_le_bytes());
144 buf
146 }
147
148 fn from_bytes(buf: &[u8; HEADER_SIZE]) -> Result<Self, SerializationError> {
150 if &buf[OFFSET_MAGIC..OFFSET_MAGIC + 8] != MAGIC.as_slice() {
152 return Err(SerializationError::InvalidMagic);
153 }
154
155 let major = u16::from_le_bytes([buf[OFFSET_VERSION_MAJOR], buf[OFFSET_VERSION_MAJOR + 1]]);
156 let minor = u16::from_le_bytes([buf[OFFSET_VERSION_MINOR], buf[OFFSET_VERSION_MINOR + 1]]);
157
158 if major > VERSION_MAJOR {
160 return Err(SerializationError::UnsupportedVersion(major, minor));
161 }
162
163 let payload_type = PayloadType::from_u8(buf[OFFSET_PAYLOAD_TYPE])?;
164 let compression = CompressionType::from_u8(buf[OFFSET_COMPRESSION])?;
165
166 let checksum = u32::from_le_bytes([
167 buf[OFFSET_CHECKSUM],
168 buf[OFFSET_CHECKSUM + 1],
169 buf[OFFSET_CHECKSUM + 2],
170 buf[OFFSET_CHECKSUM + 3],
171 ]);
172
173 let pl_bytes: [u8; 8] = buf[OFFSET_PAYLOAD_LENGTH..OFFSET_PAYLOAD_LENGTH + 8]
175 .try_into()
176 .map_err(|_| {
177 SerializationError::Io(io::Error::new(
178 io::ErrorKind::InvalidData,
179 "internal: slice length invariant violated reading payload_length",
180 ))
181 })?;
182 let payload_length = u64::from_le_bytes(pl_bytes);
183
184 Ok(Self {
185 version: (major, minor),
186 payload_type,
187 compression,
188 checksum,
189 payload_length,
190 })
191 }
192}
193
/// Writes a `.scirs2` container (64-byte header followed by the payload) to `W`.
pub struct Scirs2Writer<W: Write> {
    /// Destination stream; header and payload are written to it back to back.
    inner: W,
}
216
217impl<W: Write> Scirs2Writer<W> {
218 pub fn new(writer: W) -> Self {
220 Self { inner: writer }
221 }
222
223 pub fn write_payload(
228 &mut self,
229 payload_type: PayloadType,
230 payload: &[u8],
231 compression: CompressionType,
232 ) -> Result<(), SerializationError> {
233 let checksum = crc32fast::hash(payload);
234 let stored = compress_payload(payload, compression)?;
235
236 let header = Scirs2Header {
237 version: (VERSION_MAJOR, VERSION_MINOR),
238 payload_type,
239 compression,
240 checksum,
241 payload_length: stored.len() as u64,
242 };
243
244 self.inner.write_all(&header.to_bytes())?;
245 self.inner.write_all(&stored)?;
246 Ok(())
247 }
248}
249
250pub struct Scirs2Reader<R: Read + Seek> {
270 inner: R,
271 pub header: Scirs2Header,
273}
274
275impl<R: Read + Seek> Scirs2Reader<R> {
276 pub fn new(mut reader: R) -> Result<Self, SerializationError> {
282 let mut buf = [0u8; HEADER_SIZE];
283 reader.read_exact(&mut buf)?;
284 let header = Scirs2Header::from_bytes(&buf)?;
285 Ok(Self {
286 inner: reader,
287 header,
288 })
289 }
290
291 pub fn read_payload(&mut self) -> Result<Vec<u8>, SerializationError> {
296 self.inner.seek(SeekFrom::Start(HEADER_SIZE as u64))?;
297
298 let len = self.header.payload_length as usize;
299 let mut stored = vec![0u8; len];
300 self.inner.read_exact(&mut stored)?;
301
302 decompress_payload(&stored, self.header.compression, len)
303 }
304
305 pub fn verify_checksum(&mut self) -> Result<bool, SerializationError> {
310 let payload = self.read_payload()?;
311 let computed = crc32fast::hash(&payload);
312 Ok(computed == self.header.checksum)
313 }
314}
315
316fn compress_payload(
320 data: &[u8],
321 compression: CompressionType,
322) -> Result<Vec<u8>, SerializationError> {
323 match compression {
324 CompressionType::None => Ok(data.to_vec()),
325
326 CompressionType::Lz4 => {
327 #[cfg(feature = "serialization")]
328 {
329 oxiarc_lz4::compress(data)
330 .map_err(|e| SerializationError::Compression(format!("LZ4 compress: {}", e)))
331 }
332 #[cfg(not(feature = "serialization"))]
333 {
334 let _ = data;
335 Err(SerializationError::Compression(
336 "LZ4 compression requires the `serialization` feature".to_string(),
337 ))
338 }
339 }
340
341 CompressionType::Zstd => {
342 #[cfg(feature = "serialization")]
343 {
344 oxiarc_zstd::compress(data)
345 .map_err(|e| SerializationError::Compression(format!("Zstd compress: {}", e)))
346 }
347 #[cfg(not(feature = "serialization"))]
348 {
349 let _ = data;
350 Err(SerializationError::Compression(
351 "Zstd compression requires the `serialization` feature".to_string(),
352 ))
353 }
354 }
355 }
356}
357
358fn decompress_payload(
363 data: &[u8],
364 compression: CompressionType,
365 stored_len: usize,
366) -> Result<Vec<u8>, SerializationError> {
367 match compression {
368 CompressionType::None => Ok(data.to_vec()),
369
370 CompressionType::Lz4 => {
371 #[cfg(feature = "serialization")]
372 {
373 let max_output = stored_len.saturating_mul(8).max(4096);
377 oxiarc_lz4::decompress(data, max_output)
378 .map_err(|e| SerializationError::Compression(format!("LZ4 decompress: {}", e)))
379 }
380 #[cfg(not(feature = "serialization"))]
381 {
382 let _ = (data, stored_len);
383 Err(SerializationError::Compression(
384 "LZ4 decompression requires the `serialization` feature".to_string(),
385 ))
386 }
387 }
388
389 CompressionType::Zstd => {
390 #[cfg(feature = "serialization")]
391 {
392 let _ = stored_len;
393 oxiarc_zstd::decompress(data)
394 .map_err(|e| SerializationError::Compression(format!("Zstd decompress: {}", e)))
395 }
396 #[cfg(not(feature = "serialization"))]
397 {
398 let _ = (data, stored_len);
399 Err(SerializationError::Compression(
400 "Zstd decompression requires the `serialization` feature".to_string(),
401 ))
402 }
403 }
404 }
405}
406
/// Fixed-width numeric element type that can be stored in an array payload.
///
/// Implementors encode values as little-endian bytes and identify themselves
/// with a one-byte dtype tag written at the start of the encoded array.
pub trait ArrayElement: Copy + 'static {
    /// One-byte tag identifying this element type in an encoded array.
    fn dtype_id() -> u8;

    /// Number of bytes one encoded element occupies.
    fn element_size() -> usize;

    /// Decode `n` elements from consecutive little-endian values in `bytes`.
    fn from_le_bytes_slice(bytes: &[u8], n: usize) -> Vec<Self>;

    /// Encode `slice` as consecutive little-endian values.
    fn to_le_bytes_vec(slice: &[Self]) -> Vec<u8>;
}
434
/// Implement [`ArrayElement`] for a primitive numeric type.
///
/// Arguments, in order:
/// 1. the type;
/// 2. its dtype tag;
/// 3. its width in bytes;
/// 4. a zeroed byte array of that width, used as the decode scratch buffer.
macro_rules! impl_array_element {
    ($ty:ty, $id:expr, $size:expr, $arr:expr) => {
        impl ArrayElement for $ty {
            fn dtype_id() -> u8 {
                $id
            }

            fn element_size() -> usize {
                $size
            }

            /// Decode up to `n` elements from consecutive element-width chunks of `bytes`.
            ///
            /// A trailing partial chunk is ignored. If `bytes` holds fewer than `n`
            /// whole elements, only the available ones are returned, so callers can
            /// detect the shortfall from the returned length instead of panicking.
            fn from_le_bytes_slice(bytes: &[u8], n: usize) -> Vec<Self> {
                bytes
                    .chunks_exact($size)
                    .take(n)
                    .map(|chunk| {
                        // The type annotation keeps the compile-time check that the
                        // scratch array width matches the element size.
                        let mut raw: [u8; $size] = $arr;
                        raw.copy_from_slice(chunk);
                        <$ty>::from_le_bytes(raw)
                    })
                    .collect()
            }

            fn to_le_bytes_vec(slice: &[Self]) -> Vec<u8> {
                let mut out = Vec::with_capacity(slice.len() * $size);
                for value in slice {
                    out.extend_from_slice(&value.to_le_bytes());
                }
                out
            }
        }
    };
}
464
465impl_array_element!(f32, 1, 4, [0u8; 4]);
466impl_array_element!(f64, 2, 8, [0u8; 8]);
467impl_array_element!(i32, 3, 4, [0u8; 4]);
468impl_array_element!(i64, 4, 8, [0u8; 8]);
469impl_array_element!(u32, 5, 4, [0u8; 4]);
470impl_array_element!(u64, 6, 8, [0u8; 8]);
471
472fn encode_array<F: ArrayElement>(array: &Array<F, IxDyn>) -> Vec<u8> {
478 let shape = array.shape();
479 let ndim = shape.len();
480
481 let header_bytes = 2 + ndim * 8;
482 let data_bytes = array.len() * F::element_size();
483 let mut buf = Vec::with_capacity(header_bytes + data_bytes);
484
485 buf.push(F::dtype_id());
486 buf.push(ndim as u8);
487
488 for &dim in shape {
489 buf.extend_from_slice(&(dim as u64).to_le_bytes());
490 }
491
492 let data: Vec<F> = array.iter().copied().collect();
494 buf.extend_from_slice(&F::to_le_bytes_vec(&data));
495
496 buf
497}
498
499fn decode_array<F: ArrayElement>(payload: &[u8]) -> Result<Array<F, IxDyn>, SerializationError> {
501 if payload.len() < 2 {
502 return Err(SerializationError::Io(io::Error::new(
503 io::ErrorKind::UnexpectedEof,
504 "payload too short to contain array header (need at least 2 bytes)",
505 )));
506 }
507
508 let actual_dtype = payload[0];
509 let expected_dtype = F::dtype_id();
510 if actual_dtype != expected_dtype {
511 return Err(SerializationError::TypeMismatch {
512 expected: expected_dtype,
513 actual: actual_dtype,
514 });
515 }
516
517 let ndim = payload[1] as usize;
518 let shape_end = 2 + ndim * 8;
519
520 if payload.len() < shape_end {
521 return Err(SerializationError::Io(io::Error::new(
522 io::ErrorKind::UnexpectedEof,
523 format!(
524 "payload too short to read shape: need {} bytes for {} dims, have {}",
525 shape_end,
526 ndim,
527 payload.len()
528 ),
529 )));
530 }
531
532 let mut shape = Vec::with_capacity(ndim);
533 for i in 0..ndim {
534 let offset = 2 + i * 8;
535 let dim_bytes: [u8; 8] = payload[offset..offset + 8].try_into().map_err(|_| {
536 SerializationError::Io(io::Error::new(
537 io::ErrorKind::InvalidData,
538 format!("internal: failed to read dim {} from payload", i),
539 ))
540 })?;
541 shape.push(u64::from_le_bytes(dim_bytes) as usize);
542 }
543
544 let n_elements: usize = shape.iter().product();
545 let data_bytes = n_elements * F::element_size();
546
547 if payload.len() < shape_end + data_bytes {
548 return Err(SerializationError::Io(io::Error::new(
549 io::ErrorKind::UnexpectedEof,
550 format!(
551 "payload too short for array data: need {} bytes, have {}",
552 shape_end + data_bytes,
553 payload.len()
554 ),
555 )));
556 }
557
558 let elements = F::from_le_bytes_slice(&payload[shape_end..shape_end + data_bytes], n_elements);
559
560 Array::from_shape_vec(IxDyn(&shape), elements).map_err(|e| {
561 SerializationError::Io(io::Error::new(
562 io::ErrorKind::InvalidData,
563 format!("shape/data mismatch during array reconstruction: {}", e),
564 ))
565 })
566}
567
568pub fn save_array<F: ArrayElement>(
596 path: &Path,
597 array: &Array<F, IxDyn>,
598 compression: CompressionType,
599) -> Result<(), SerializationError> {
600 let file = std::fs::File::create(path)?;
601 let writer = BufWriter::new(file);
602 let mut scirs2 = Scirs2Writer::new(writer);
603 let payload = encode_array(array);
604 scirs2.write_payload(PayloadType::Array, &payload, compression)
605}
606
607pub fn load_array<F: ArrayElement>(path: &Path) -> Result<Array<F, IxDyn>, SerializationError> {
621 let file = std::fs::File::open(path)?;
622 let reader = BufReader::new(file);
623 let mut scirs2 = Scirs2Reader::new(reader)?;
624
625 if scirs2.header.payload_type != PayloadType::Array {
626 return Err(SerializationError::Io(io::Error::new(
627 io::ErrorKind::InvalidData,
628 format!(
629 "expected Array payload type (0), found {:?} ({})",
630 scirs2.header.payload_type, scirs2.header.payload_type as u8
631 ),
632 )));
633 }
634
635 let payload = scirs2.read_payload()?;
636 decode_array::<F>(&payload)
637}
638
639#[derive(Debug, thiserror::Error)]
643pub enum SerializationError {
644 #[error("IO error: {0}")]
646 Io(#[from] std::io::Error),
647
648 #[error("Invalid magic bytes — not a valid .scirs2 file")]
650 InvalidMagic,
651
652 #[error(
654 "Unsupported version {0}.{1} (this library supports up to {major}.x)",
655 major = VERSION_MAJOR
656 )]
657 UnsupportedVersion(u16, u16),
658
659 #[error("Checksum mismatch — file may be corrupted")]
661 ChecksumMismatch,
662
663 #[error("Compression error: {0}")]
665 Compression(String),
666
667 #[error("Unknown payload type: {0}")]
669 UnknownPayloadType(u8),
670
671 #[error("Type mismatch: expected dtype {expected}, found {actual}")]
673 TypeMismatch {
674 expected: u8,
676 actual: u8,
678 },
679}
680
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2, Array3};
    use std::io::Cursor;
    use std::path::PathBuf;

    /// A path in the system temp directory.
    ///
    /// The name is prefixed with this process id, so concurrent test runs cannot
    /// clobber each other's files. The file is removed on drop, even if the
    /// owning test panics before reaching the end.
    struct TempFile(PathBuf);

    impl TempFile {
        fn new(name: &str) -> Self {
            Self(std::env::temp_dir().join(format!("{}_{}", std::process::id(), name)))
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: some tests never create the file.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    // ---------------------------------------------------------------- header

    #[test]
    fn test_header_roundtrip_all_fields() {
        let original = Scirs2Header {
            version: (0, 3),
            payload_type: PayloadType::Array,
            compression: CompressionType::None,
            checksum: 0xDEAD_BEEF,
            payload_length: 1_234_567_890,
        };
        let bytes = original.to_bytes();
        assert_eq!(bytes.len(), HEADER_SIZE);

        let parsed = Scirs2Header::from_bytes(&bytes).expect("header parse failed");
        assert_eq!(parsed.version, original.version);
        assert_eq!(parsed.payload_type, original.payload_type);
        assert_eq!(parsed.compression, original.compression);
        assert_eq!(parsed.checksum, original.checksum);
        assert_eq!(parsed.payload_length, original.payload_length);
    }

    #[test]
    fn test_header_reserved_bytes_are_zero() {
        let header = Scirs2Header {
            version: (0, 3),
            payload_type: PayloadType::Custom,
            compression: CompressionType::None,
            checksum: 42,
            payload_length: 8,
        };
        let bytes = header.to_bytes();
        // Everything after the last field (the u64 payload length) is reserved.
        let reserved_start = OFFSET_PAYLOAD_LENGTH + 8;
        for (offset, &byte) in bytes.iter().enumerate().skip(reserved_start) {
            assert_eq!(byte, 0, "reserved byte {} should be zero", offset);
        }
    }

    #[test]
    fn test_invalid_magic_rejected() {
        let mut buf = [0u8; HEADER_SIZE];
        buf[0..8].copy_from_slice(b"BADMAGIC");
        assert!(
            matches!(
                Scirs2Header::from_bytes(&buf),
                Err(SerializationError::InvalidMagic)
            ),
            "should reject non-SCIRS2 magic"
        );
    }

    #[test]
    fn test_future_major_version_rejected() {
        let header = Scirs2Header {
            version: (255, 0),
            payload_type: PayloadType::Custom,
            compression: CompressionType::None,
            checksum: 0,
            payload_length: 0,
        };
        let bytes = header.to_bytes();
        assert!(
            matches!(
                Scirs2Header::from_bytes(&bytes),
                Err(SerializationError::UnsupportedVersion(255, 0))
            ),
            "should reject future major version"
        );
    }

    #[test]
    fn test_unknown_payload_type_rejected() {
        let header = Scirs2Header {
            version: (0, 3),
            payload_type: PayloadType::Custom,
            compression: CompressionType::None,
            checksum: 0,
            payload_length: 0,
        };
        let mut bytes = header.to_bytes();
        bytes[OFFSET_PAYLOAD_TYPE] = 99;
        let result = Scirs2Header::from_bytes(&bytes);
        assert!(
            matches!(result, Err(SerializationError::UnknownPayloadType(99))),
            "should return UnknownPayloadType(99) for unknown payload type byte"
        );
    }

    #[test]
    fn test_payload_type_from_u8_all_variants() {
        assert!(matches!(PayloadType::from_u8(0), Ok(PayloadType::Array)));
        assert!(matches!(PayloadType::from_u8(1), Ok(PayloadType::Model)));
        assert!(matches!(PayloadType::from_u8(2), Ok(PayloadType::Stats)));
        assert!(matches!(PayloadType::from_u8(3), Ok(PayloadType::Custom)));
        assert!(matches!(
            PayloadType::from_u8(4),
            Err(SerializationError::UnknownPayloadType(4))
        ));
    }

    // ------------------------------------------------------- writer / reader

    #[test]
    fn test_custom_payload_no_compression_roundtrip() {
        let payload = b"the quick brown fox jumps over the lazy dog";
        let mut buf = Vec::new();
        {
            let mut writer = Scirs2Writer::new(&mut buf);
            writer
                .write_payload(PayloadType::Custom, payload, CompressionType::None)
                .expect("write_payload failed");
        }

        let cursor = Cursor::new(&buf);
        let mut reader = Scirs2Reader::new(cursor).expect("Scirs2Reader::new failed");
        assert_eq!(reader.header.payload_type, PayloadType::Custom);
        assert_eq!(reader.header.compression, CompressionType::None);
        assert_eq!(reader.header.payload_length, payload.len() as u64);

        let out = reader.read_payload().expect("read_payload failed");
        assert_eq!(out.as_slice(), payload.as_slice());
    }

    #[test]
    fn test_empty_payload_roundtrip() {
        let payload: &[u8] = b"";
        let mut buf = Vec::new();
        {
            let mut writer = Scirs2Writer::new(&mut buf);
            writer
                .write_payload(PayloadType::Stats, payload, CompressionType::None)
                .expect("write empty payload failed");
        }
        let cursor = Cursor::new(&buf);
        let mut reader = Scirs2Reader::new(cursor).expect("reader init failed");
        let out = reader.read_payload().expect("read empty payload failed");
        assert!(out.is_empty());
    }

    #[test]
    fn test_verify_checksum_passes_for_intact_data() {
        let payload = b"integrity check payload 0xDEADBEEF";
        let mut buf = Vec::new();
        {
            let mut writer = Scirs2Writer::new(&mut buf);
            writer
                .write_payload(PayloadType::Stats, payload, CompressionType::None)
                .expect("write failed");
        }
        let cursor = Cursor::new(&buf);
        let mut reader = Scirs2Reader::new(cursor).expect("reader init failed");
        assert!(
            reader.verify_checksum().expect("checksum check failed"),
            "checksum should pass for intact data"
        );
    }

    #[test]
    fn test_verify_checksum_fails_on_bit_flip() {
        let payload = b"data that will be corrupted in transit";
        let mut buf = Vec::new();
        {
            let mut writer = Scirs2Writer::new(&mut buf);
            writer
                .write_payload(PayloadType::Custom, payload, CompressionType::None)
                .expect("write failed");
        }

        // Flip every bit of the last payload byte.
        let last = buf.len() - 1;
        buf[last] ^= 0xFF;

        let cursor = Cursor::new(&buf);
        let mut reader = Scirs2Reader::new(cursor).expect("reader init (corrupted) failed");
        assert!(
            !reader.verify_checksum().expect("checksum check errored"),
            "checksum should fail after bit flip"
        );
    }

    #[test]
    fn test_version_fields_in_file() {
        let payload = b"version test";
        let mut buf = Vec::new();
        let mut writer = Scirs2Writer::new(&mut buf);
        writer
            .write_payload(PayloadType::Custom, payload, CompressionType::None)
            .expect("write failed");

        let cursor = Cursor::new(&buf);
        let reader = Scirs2Reader::new(cursor).expect("reader failed");
        assert_eq!(reader.header.version, (VERSION_MAJOR, VERSION_MINOR));
    }

    // -------------------------------------------------- array encode/decode

    #[test]
    fn test_encode_decode_f32_1d() {
        let original =
            Array1::<f32>::from_vec(vec![1.0, 2.5, -3.0, f32::MAX, f32::MIN_POSITIVE]).into_dyn();
        let encoded = encode_array(&original);
        let decoded = decode_array::<f32>(&encoded).expect("f32 decode failed");
        assert_eq!(original, decoded, "f32 1d roundtrip mismatch");
    }

    #[test]
    fn test_encode_decode_f64_2d() {
        let original = Array2::<f64>::from_shape_vec(
            (4, 6),
            (0..24)
                .map(|i| i as f64 * std::f64::consts::PI / 12.0)
                .collect(),
        )
        .expect("shape error")
        .into_dyn();

        let encoded = encode_array(&original);
        let decoded = decode_array::<f64>(&encoded).expect("f64 2d decode failed");
        assert_eq!(original, decoded, "f64 2d roundtrip mismatch");
    }

    #[test]
    fn test_encode_decode_i32_3d() {
        let original =
            Array3::<i32>::from_shape_vec((2, 3, 4), (0..24).map(|i| i as i32 - 12).collect())
                .expect("shape error")
                .into_dyn();

        let encoded = encode_array(&original);
        let decoded = decode_array::<i32>(&encoded).expect("i32 3d decode failed");
        assert_eq!(original, decoded, "i32 3d roundtrip mismatch");
    }

    #[test]
    fn test_encode_decode_i64_1d() {
        let original = Array1::<i64>::from_vec(vec![i64::MIN, -1, 0, 1, i64::MAX]).into_dyn();
        let encoded = encode_array(&original);
        let decoded = decode_array::<i64>(&encoded).expect("i64 decode failed");
        assert_eq!(original, decoded, "i64 roundtrip mismatch");
    }

    #[test]
    fn test_encode_decode_u32() {
        let original = Array1::<u32>::from_vec(vec![0, 1, u32::MAX / 2, u32::MAX]).into_dyn();
        let encoded = encode_array(&original);
        let decoded = decode_array::<u32>(&encoded).expect("u32 decode failed");
        assert_eq!(original, decoded, "u32 roundtrip mismatch");
    }

    #[test]
    fn test_encode_decode_u64() {
        let original = Array1::<u64>::from_vec(vec![0, 1, u64::MAX / 2, u64::MAX]).into_dyn();
        let encoded = encode_array(&original);
        let decoded = decode_array::<u64>(&encoded).expect("u64 decode failed");
        assert_eq!(original, decoded, "u64 roundtrip mismatch");
    }

    #[test]
    fn test_dtype_mismatch_error() {
        let original = Array1::<f32>::from_vec(vec![1.0, 2.0, 3.0]).into_dyn();
        let encoded = encode_array(&original);
        let result = decode_array::<f64>(&encoded);
        assert!(
            matches!(
                result,
                Err(SerializationError::TypeMismatch {
                    expected: 2,
                    actual: 1
                })
            ),
            "expected TypeMismatch error"
        );
    }

    #[test]
    fn test_encode_zero_dimensional_array() {
        let original = Array::<f64, IxDyn>::from_elem(IxDyn(&[]), 42.0);
        let encoded = encode_array(&original);
        let decoded = decode_array::<f64>(&encoded).expect("0d decode failed");
        assert_eq!(original, decoded, "0d array roundtrip mismatch");
    }

    // ------------------------------------------------------- file save/load

    #[test]
    fn test_save_load_f32_no_compression() {
        let tmp = TempFile::new("scirs2_test_f32_nocomp.scirs2");
        let original =
            Array2::<f32>::from_shape_vec((8, 8), (0..64).map(|i| i as f32 * 0.5 - 16.0).collect())
                .expect("shape error")
                .into_dyn();

        save_array(tmp.path(), &original, CompressionType::None).expect("save_array failed");
        let loaded = load_array::<f32>(tmp.path()).expect("load_array failed");

        assert_eq!(original, loaded, "f32 save/load mismatch");
    }

    #[test]
    fn test_save_load_f64_no_compression() {
        let tmp = TempFile::new("scirs2_test_f64_nocomp.scirs2");
        let original = Array1::<f64>::linspace(0.0, 1.0, 500).into_dyn();

        save_array(tmp.path(), &original, CompressionType::None).expect("save_array f64 failed");
        let loaded = load_array::<f64>(tmp.path()).expect("load_array f64 failed");

        assert_eq!(original, loaded, "f64 save/load mismatch");
    }

    #[test]
    fn test_save_load_empty_array() {
        let tmp = TempFile::new("scirs2_test_empty.scirs2");
        let original = Array1::<f64>::from_vec(vec![]).into_dyn();

        save_array(tmp.path(), &original, CompressionType::None).expect("save empty failed");
        let loaded = load_array::<f64>(tmp.path()).expect("load empty failed");

        assert_eq!(original, loaded, "empty array roundtrip mismatch");
    }

    #[test]
    fn test_save_load_large_f64_array() {
        let tmp = TempFile::new("scirs2_test_large_f64.scirs2");
        let n = 100_000usize;
        let original =
            Array1::<f64>::from_iter((0..n).map(|i| (i as f64 / n as f64).sin())).into_dyn();

        save_array(tmp.path(), &original, CompressionType::None).expect("save large failed");
        let loaded = load_array::<f64>(tmp.path()).expect("load large failed");

        assert_eq!(
            original.shape(),
            loaded.shape(),
            "shape mismatch for large array"
        );
        for (a, b) in original.iter().zip(loaded.iter()) {
            assert_eq!(
                a.to_bits(),
                b.to_bits(),
                "element mismatch in large array (bit-exact)"
            );
        }
    }

    #[test]
    fn test_save_load_3d_i32_array() {
        let tmp = TempFile::new("scirs2_test_3d_i32.scirs2");
        let original =
            Array3::<i32>::from_shape_fn((5, 6, 7), |(i, j, k)| (i * 100 + j * 10 + k) as i32)
                .into_dyn();

        save_array(tmp.path(), &original, CompressionType::None).expect("save 3d i32 failed");
        let loaded = load_array::<i32>(tmp.path()).expect("load 3d i32 failed");

        assert_eq!(original, loaded, "3d i32 save/load mismatch");
    }

    #[test]
    fn test_wrong_payload_type_error() {
        let tmp = TempFile::new("scirs2_test_wrong_type.scirs2");
        {
            let file = std::fs::File::create(tmp.path()).expect("create failed");
            let mut writer = Scirs2Writer::new(BufWriter::new(file));
            writer
                .write_payload(
                    PayloadType::Custom,
                    b"definitely not an array",
                    CompressionType::None,
                )
                .expect("write failed");
        }
        let result = load_array::<f64>(tmp.path());
        assert!(
            result.is_err(),
            "load_array should fail when payload type is not Array"
        );
    }

    #[test]
    fn test_file_not_found_error() {
        let result = load_array::<f32>(Path::new("/nonexistent/path/does_not_exist.scirs2"));
        assert!(
            matches!(result, Err(SerializationError::Io(_))),
            "should return Io error for missing file"
        );
    }

    #[test]
    fn test_checksum_is_stored_in_file() {
        let tmp = TempFile::new("scirs2_test_checksum_stored.scirs2");
        let original = Array1::<f64>::from_vec(vec![1.0, 2.0, 3.0]).into_dyn();
        save_array(tmp.path(), &original, CompressionType::None).expect("save failed");

        let file = std::fs::File::open(tmp.path()).expect("open failed");
        let mut reader = Scirs2Reader::new(BufReader::new(file)).expect("reader failed");
        let ok = reader.verify_checksum().expect("checksum check failed");
        assert!(ok, "checksum should pass for freshly saved file");
    }

    // ----------------------------------------------------------- compression

    #[test]
    fn test_lz4_compression_roundtrip() {
        let tmp = TempFile::new("scirs2_test_lz4.scirs2");
        let original = Array1::<f32>::from_elem(1000, 1.23456_f32).into_dyn();

        match save_array(tmp.path(), &original, CompressionType::Lz4) {
            Ok(()) => {
                let loaded = load_array::<f32>(tmp.path()).expect("load lz4 failed");
                assert_eq!(original, loaded, "lz4 roundtrip mismatch");
            }
            Err(SerializationError::Compression(_)) => {
                // Codec not compiled in (no `serialization` feature).
                eprintln!("LZ4 not available, skipping lz4 test");
            }
            Err(e) => panic!("unexpected error during lz4 test: {}", e),
        }
    }

    #[test]
    fn test_zstd_compression_roundtrip() {
        let tmp = TempFile::new("scirs2_test_zstd.scirs2");
        let original = Array2::<f64>::zeros((100, 100)).into_dyn();

        match save_array(tmp.path(), &original, CompressionType::Zstd) {
            Ok(()) => {
                let loaded = load_array::<f64>(tmp.path()).expect("load zstd failed");
                assert_eq!(original, loaded, "zstd roundtrip mismatch");
            }
            Err(SerializationError::Compression(_)) => {
                // Codec not compiled in (no `serialization` feature).
                eprintln!("Zstd not available, skipping zstd test");
            }
            Err(e) => panic!("unexpected error during zstd test: {}", e),
        }
    }
}