1use memmap2::{Mmap, MmapMut, MmapOptions};
47use ndarray::{Array, ArrayView, ArrayViewMut, IxDyn};
48use std::fs::OpenOptions;
49use std::io::{Seek, SeekFrom, Write};
50use std::path::{Path, PathBuf};
51
/// Four-byte signature written at offset 0 of every array file; `decode_header`
/// rejects any file that does not start with it.
const MAGIC: &[u8; 4] = b"MMAP";
/// Header layout version stored at byte 4; files carrying any other value are
/// rejected with `MmapError::VersionMismatch`.
const FORMAT_VERSION: u8 = 1;
/// Fixed header length in bytes: 16 bytes of fixed fields (magic, version,
/// dtype id, rank, element count) plus six 8-byte axis extents. Element data
/// starts immediately after it, at an offset that is a multiple of every
/// supported element size.
const HEADER_SIZE: usize = 64;
62
63pub trait MmapElement: Copy + bytemuck::Pod + bytemuck::Zeroable + 'static {
75 fn dtype_id() -> u8;
77 fn element_size() -> usize;
79}
80
81impl MmapElement for f32 {
82 fn dtype_id() -> u8 {
83 1
84 }
85 fn element_size() -> usize {
86 4
87 }
88}
89
90impl MmapElement for f64 {
91 fn dtype_id() -> u8 {
92 2
93 }
94 fn element_size() -> usize {
95 8
96 }
97}
98
99impl MmapElement for i32 {
100 fn dtype_id() -> u8 {
101 3
102 }
103 fn element_size() -> usize {
104 4
105 }
106}
107
108impl MmapElement for i64 {
109 fn dtype_id() -> u8 {
110 4
111 }
112 fn element_size() -> usize {
113 8
114 }
115}
116
/// Backing storage of an `MmapArray`, one variant per open mode.
///
/// Each mapping covers the whole file: the fixed-size header followed by the
/// element data.
enum MmapStorage<F: MmapElement> {
    /// Read-only mapping (built by `open_read_only`); `view_mut` refuses it.
    ReadOnly {
        mmap: Mmap,
        // Ties the element type to the variant without storing an `F`.
        _phantom: std::marker::PhantomData<F>,
    },
    /// Writable mapping (built by `create`, `open_read_write`, and the test-only
    /// `create_anonymous`); mutations are written back to the file by `flush`.
    ReadWrite {
        mmap: MmapMut,
        _phantom: std::marker::PhantomData<F>,
    },
    /// Read-only mapping plus a lazily made private copy (built by `open_cow`).
    CopyOnWrite {
        mmap: Mmap,
        // `None` until the first `view_mut`, which copies the mapped elements
        // into this Vec; from then on `view`/`view_mut` use the copy and the
        // file itself is never modified.
        cow_data: Option<Vec<F>>,
    },
}
140
/// An n-dimensional array of `F` stored in a memory-mapped file.
///
/// The file holds a fixed-size header (magic, format version, dtype id, rank,
/// element count, per-axis extents) followed by the elements in row-major (C)
/// order.
pub struct MmapArray<F: MmapElement> {
    /// The mapping, in the mode chosen at create/open time.
    storage: MmapStorage<F>,
    /// Extent of each axis, as recorded in the header.
    shape: Vec<usize>,
    /// Row-major strides measured in elements (not bytes), derived from `shape`
    /// by `c_strides`.
    strides: Vec<usize>,
    /// Path the array was created/opened from (`"<anonymous>"` for the
    /// test-only temporary arrays).
    file_path: PathBuf,
}
160
161impl<F: MmapElement> std::fmt::Debug for MmapArray<F> {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 let mode = match &self.storage {
164 MmapStorage::ReadOnly { .. } => "ReadOnly",
165 MmapStorage::ReadWrite { .. } => "ReadWrite",
166 MmapStorage::CopyOnWrite { cow_data, .. } => {
167 if cow_data.is_some() {
168 "CopyOnWrite(dirty)"
169 } else {
170 "CopyOnWrite(clean)"
171 }
172 }
173 };
174 f.debug_struct("MmapArray")
175 .field("mode", &mode)
176 .field("shape", &self.shape)
177 .field("strides", &self.strides)
178 .field("file_path", &self.file_path)
179 .finish()
180 }
181}
182
183fn encode_header(buf: &mut [u8; HEADER_SIZE], dtype_id: u8, shape: &[usize]) {
193 let ndim = shape.len();
194 assert!(
195 ndim <= 6,
196 "MmapArray supports at most 6 dimensions (header limit)"
197 );
198
199 buf.fill(0);
200
201 buf[0..4].copy_from_slice(MAGIC);
203 buf[4] = FORMAT_VERSION;
205 buf[5] = dtype_id;
207 let ndim_u16 = ndim as u16;
209 buf[6..8].copy_from_slice(&ndim_u16.to_le_bytes());
210 let total: u64 = shape.iter().product::<usize>() as u64;
212 buf[8..16].copy_from_slice(&total.to_le_bytes());
213 for (i, &dim) in shape.iter().enumerate() {
215 let off = 16 + i * 8;
216 buf[off..off + 8].copy_from_slice(&(dim as u64).to_le_bytes());
217 }
218}
219
220fn decode_header(buf: &[u8; HEADER_SIZE]) -> Result<(u8, Vec<usize>), MmapError> {
222 if &buf[0..4] != MAGIC {
223 return Err(MmapError::InvalidMagic);
224 }
225 let version = buf[4];
226 if version != FORMAT_VERSION {
227 return Err(MmapError::VersionMismatch(version));
228 }
229 let dtype_id = buf[5];
230 let ndim = u16::from_le_bytes([buf[6], buf[7]]) as usize;
231 let _total_elements = u64::from_le_bytes(buf[8..16].try_into().map_err(|_| {
232 MmapError::Io(std::io::Error::new(
233 std::io::ErrorKind::InvalidData,
234 "header truncated at total_elements",
235 ))
236 })?);
237
238 let mut shape = Vec::with_capacity(ndim);
239 for i in 0..ndim {
240 let off = 16 + i * 8;
241 if off + 8 > HEADER_SIZE {
242 return Err(MmapError::Io(std::io::Error::new(
243 std::io::ErrorKind::InvalidData,
244 format!("header too small for ndim={}", ndim),
245 )));
246 }
247 let dim = u64::from_le_bytes(buf[off..off + 8].try_into().map_err(|_| {
248 MmapError::Io(std::io::Error::new(
249 std::io::ErrorKind::InvalidData,
250 "header truncated in shape",
251 ))
252 })?);
253 shape.push(dim as usize);
254 }
255
256 Ok((dtype_id, shape))
257}
258
/// Row-major (C-order) strides for `shape`, measured in elements.
///
/// The last axis has stride 1. Every other axis has the stride of the axis
/// after it multiplied by that axis' extent. An empty shape yields no strides.
fn c_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    // Walk from the innermost axis outward, folding each extent into the
    // stride of the axis just before it.
    for (axis, &extent) in shape.iter().enumerate().skip(1).rev() {
        strides[axis - 1] = strides[axis] * extent;
    }
    strides
}
268
/// Number of elements described by `shape`.
///
/// The empty shape (a 0-d scalar) counts as one element, and any zero extent
/// gives zero.
fn total_elements(shape: &[usize]) -> usize {
    shape.iter().fold(1, |count, &extent| count * extent)
}
273
274impl<F: MmapElement> MmapArray<F> {
279 pub fn create(path: &Path, data: &Array<F, IxDyn>) -> Result<Self, MmapError> {
294 if !data.is_standard_layout() {
296 return Err(MmapError::NonContiguous);
297 }
298
299 let shape = data.shape().to_vec();
300 let n_elems = total_elements(&shape);
301 let data_bytes = n_elems * F::element_size();
302 let file_size = HEADER_SIZE + data_bytes;
303
304 let file = OpenOptions::new()
306 .read(true)
307 .write(true)
308 .create(true)
309 .truncate(true)
310 .open(path)?;
311 file.set_len(file_size as u64)?;
312
313 {
315 use std::io::BufWriter;
316 let mut writer = BufWriter::new(&file);
317 let mut header = [0u8; HEADER_SIZE];
318 encode_header(&mut header, F::dtype_id(), &shape);
319 writer.write_all(&header)?;
320 let raw: &[F] = data.as_slice().ok_or(MmapError::NonContiguous)?;
322 writer.write_all(bytemuck::cast_slice(raw))?;
323 writer.flush()?;
324 }
325
326 let mmap = unsafe { MmapOptions::new().map_mut(&file)? };
329
330 let strides = c_strides(&shape);
331 Ok(Self {
332 storage: MmapStorage::ReadWrite {
333 mmap,
334 _phantom: std::marker::PhantomData,
335 },
336 shape,
337 strides,
338 file_path: path.to_path_buf(),
339 })
340 }
341
342 pub fn open_read_only(path: &Path) -> Result<Self, MmapError> {
347 let file = OpenOptions::new().read(true).open(path)?;
348 let mmap = unsafe { MmapOptions::new().map(&file)? };
349 Self::from_readonly_mmap(mmap, path)
350 }
351
352 pub fn open_read_write(path: &Path) -> Result<Self, MmapError> {
356 let file = OpenOptions::new().read(true).write(true).open(path)?;
357 let mmap = unsafe { MmapOptions::new().map_mut(&file)? };
358
359 if mmap.len() < HEADER_SIZE {
361 return Err(MmapError::Io(std::io::Error::new(
362 std::io::ErrorKind::UnexpectedEof,
363 "file too small to contain header",
364 )));
365 }
366 let header_bytes: &[u8; HEADER_SIZE] = mmap[..HEADER_SIZE].try_into().map_err(|_| {
367 MmapError::Io(std::io::Error::new(
368 std::io::ErrorKind::InvalidData,
369 "could not read header",
370 ))
371 })?;
372 let (dtype_id, shape) = decode_header(header_bytes)?;
373 if dtype_id != F::dtype_id() {
374 return Err(MmapError::DtypeMismatch {
375 expected: F::dtype_id(),
376 actual: dtype_id,
377 });
378 }
379
380 let strides = c_strides(&shape);
381 Ok(Self {
382 storage: MmapStorage::ReadWrite {
383 mmap,
384 _phantom: std::marker::PhantomData,
385 },
386 shape,
387 strides,
388 file_path: path.to_path_buf(),
389 })
390 }
391
392 pub fn open_cow(path: &Path) -> Result<Self, MmapError> {
400 let file = OpenOptions::new().read(true).open(path)?;
401 let mmap = unsafe { MmapOptions::new().map(&file)? };
402 let (_, shape, strides) = Self::parse_readonly_mmap(&mmap, path)?;
403 Ok(Self {
404 storage: MmapStorage::CopyOnWrite {
405 mmap,
406 cow_data: None,
407 },
408 shape,
409 strides,
410 file_path: path.to_path_buf(),
411 })
412 }
413
414 pub fn view(&self) -> Result<ArrayView<'_, F, IxDyn>, MmapError> {
424 match &self.storage {
425 MmapStorage::ReadOnly { mmap, .. } => {
426 let data_slice = &mmap[HEADER_SIZE..];
427 let elems: &[F] = bytemuck::cast_slice(data_slice);
428 let ix = IxDyn(self.shape.as_slice());
429 let view = unsafe { ArrayView::from_shape_ptr(ix, elems.as_ptr()) };
432 Ok(view)
433 }
434 MmapStorage::ReadWrite { mmap, .. } => {
435 let data_slice = &mmap[HEADER_SIZE..];
436 let elems: &[F] = bytemuck::cast_slice(data_slice);
437 let ix = IxDyn(self.shape.as_slice());
438 let view = unsafe { ArrayView::from_shape_ptr(ix, elems.as_ptr()) };
440 Ok(view)
441 }
442 MmapStorage::CopyOnWrite { mmap, cow_data } => {
443 match cow_data {
444 None => {
445 let data_slice = &mmap[HEADER_SIZE..];
447 let elems: &[F] = bytemuck::cast_slice(data_slice);
448 let ix = IxDyn(self.shape.as_slice());
449 let view = unsafe { ArrayView::from_shape_ptr(ix, elems.as_ptr()) };
451 Ok(view)
452 }
453 Some(data) => {
454 let ix = IxDyn(self.shape.as_slice());
455 let view = unsafe { ArrayView::from_shape_ptr(ix, data.as_ptr()) };
457 Ok(view)
458 }
459 }
460 }
461 }
462 }
463
464 pub fn view_mut(&mut self) -> Result<ArrayViewMut<'_, F, IxDyn>, MmapError> {
470 match &mut self.storage {
471 MmapStorage::ReadOnly { .. } => Err(MmapError::ReadOnly),
472 MmapStorage::ReadWrite { mmap, .. } => {
473 let data_slice = &mut mmap[HEADER_SIZE..];
474 let elems: &mut [F] = bytemuck::cast_slice_mut(data_slice);
475 let ix = IxDyn(self.shape.as_slice());
476 let view = unsafe { ArrayViewMut::from_shape_ptr(ix, elems.as_mut_ptr()) };
478 Ok(view)
479 }
480 MmapStorage::CopyOnWrite { mmap, cow_data } => {
481 if cow_data.is_none() {
483 let data_slice = &mmap[HEADER_SIZE..];
484 let elems: &[F] = bytemuck::cast_slice(data_slice);
485 *cow_data = Some(elems.to_vec());
486 }
487 let data = cow_data.as_mut().ok_or_else(|| {
488 MmapError::Io(std::io::Error::other(
489 "COW data unexpectedly None after initialization",
490 ))
491 })?;
492 let ix = IxDyn(self.shape.as_slice());
493 let view = unsafe { ArrayViewMut::from_shape_ptr(ix, data.as_mut_ptr()) };
495 Ok(view)
496 }
497 }
498 }
499
500 pub fn flush(&self) -> Result<(), MmapError> {
505 match &self.storage {
506 MmapStorage::ReadWrite { mmap, .. } => {
507 mmap.flush()?;
508 Ok(())
509 }
510 _ => Ok(()),
511 }
512 }
513
514 pub fn shape(&self) -> &[usize] {
516 &self.shape
517 }
518
519 pub fn strides(&self) -> &[usize] {
521 &self.strides
522 }
523
524 pub fn len(&self) -> usize {
526 total_elements(&self.shape)
527 }
528
529 pub fn is_empty(&self) -> bool {
531 self.len() == 0
532 }
533
534 pub fn file_path(&self) -> &Path {
536 &self.file_path
537 }
538
539 pub fn to_owned_array(&self) -> Result<Array<F, IxDyn>, MmapError> {
543 let view = self.view()?;
544 Ok(view.to_owned())
545 }
546
547 fn from_readonly_mmap(mmap: Mmap, path: &Path) -> Result<Self, MmapError> {
553 let (_, shape, strides) = Self::parse_readonly_mmap(&mmap, path)?;
554 Ok(Self {
555 storage: MmapStorage::ReadOnly {
556 mmap,
557 _phantom: std::marker::PhantomData,
558 },
559 shape,
560 strides,
561 file_path: path.to_path_buf(),
562 })
563 }
564
565 fn parse_readonly_mmap(
569 mmap: &Mmap,
570 _path: &Path,
571 ) -> Result<(u8, Vec<usize>, Vec<usize>), MmapError> {
572 if mmap.len() < HEADER_SIZE {
573 return Err(MmapError::Io(std::io::Error::new(
574 std::io::ErrorKind::UnexpectedEof,
575 "file too small to contain header",
576 )));
577 }
578 let header_bytes: &[u8; HEADER_SIZE] = mmap[..HEADER_SIZE].try_into().map_err(|_| {
579 MmapError::Io(std::io::Error::new(
580 std::io::ErrorKind::InvalidData,
581 "could not read header slice",
582 ))
583 })?;
584 let (dtype_id, shape) = decode_header(header_bytes)?;
585 if dtype_id != F::dtype_id() {
586 return Err(MmapError::DtypeMismatch {
587 expected: F::dtype_id(),
588 actual: dtype_id,
589 });
590 }
591 let strides = c_strides(&shape);
592 Ok((dtype_id, shape, strides))
593 }
594
595 #[cfg(test)]
603 fn create_anonymous(shape: &[usize]) -> Result<Self, MmapError> {
604 use tempfile::tempfile;
605
606 let n_elems = total_elements(shape);
607 let file_size = HEADER_SIZE + n_elems * F::element_size();
608
609 let file = tempfile()?;
610 file.set_len(file_size as u64)?;
611
612 {
614 use std::io::BufWriter;
615 let mut writer = BufWriter::new(&file);
616 let mut header = [0u8; HEADER_SIZE];
617 encode_header(&mut header, F::dtype_id(), shape);
618 writer.write_all(&header)?;
619 writer.flush()?;
621 }
622
623 let mmap = unsafe { MmapOptions::new().map_mut(&file)? };
624 let strides = c_strides(shape);
625
626 Ok(Self {
628 storage: MmapStorage::ReadWrite {
629 mmap,
630 _phantom: std::marker::PhantomData,
631 },
632 shape: shape.to_vec(),
633 strides,
634 file_path: PathBuf::from("<anonymous>"),
635 })
636 }
637}
638
639#[derive(Debug, thiserror::Error)]
645pub enum MmapError {
646 #[error("IO error: {0}")]
648 Io(#[from] std::io::Error),
649
650 #[error("Invalid magic bytes — not a valid mmap array file")]
652 InvalidMagic,
653
654 #[error("Version mismatch: expected 1, got {0}")]
656 VersionMismatch(u8),
657
658 #[error("Dtype mismatch: expected dtype_id {expected}, got {actual}")]
660 DtypeMismatch { expected: u8, actual: u8 },
661
662 #[error("Shape mismatch: expected {expected:?}, got {actual:?}")]
664 ShapeMismatch {
665 expected: Vec<usize>,
666 actual: Vec<usize>,
667 },
668
669 #[error("Array is not contiguous (non-contiguous layouts not supported)")]
671 NonContiguous,
672
673 #[error("Array is read-only")]
675 ReadOnly,
676}
677
// Tests for the header codec, the stride/dtype helpers, and `MmapArray`
// create/open/view round trips in each storage mode.
// File-backed tests use a distinct fixed file name under `std::env::temp_dir()`
// and delete it at the end with a best-effort `remove_file`. A test that fails
// part-way leaves its file behind; the next run overwrites it because `create`
// truncates.
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{ArrayD, IxDyn};

    // Row-major array whose elements are 0, 1, 2, ... so position is recoverable
    // from value.
    fn make_f32_array(shape: &[usize]) -> ArrayD<f32> {
        let n = shape.iter().product::<usize>();
        let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
        ArrayD::from_shape_vec(IxDyn(shape), data).expect("shape mismatch in test helper")
    }

    // f64 counterpart of `make_f32_array`.
    fn make_f64_array(shape: &[usize]) -> ArrayD<f64> {
        let n = shape.iter().product::<usize>();
        let data: Vec<f64> = (0..n).map(|i| i as f64).collect();
        ArrayD::from_shape_vec(IxDyn(shape), data).expect("shape mismatch in test helper")
    }

    // --- header codec -------------------------------------------------------

    #[test]
    fn test_header_round_trip_1d() {
        let mut buf = [0u8; HEADER_SIZE];
        encode_header(&mut buf, 1, &[100]);
        let (dtype_id, shape) = decode_header(&buf).expect("decode failed");
        assert_eq!(dtype_id, 1);
        assert_eq!(shape, vec![100usize]);
    }

    #[test]
    fn test_header_round_trip_2d() {
        let mut buf = [0u8; HEADER_SIZE];
        encode_header(&mut buf, 2, &[3, 4]);
        let (dtype_id, shape) = decode_header(&buf).expect("decode failed");
        assert_eq!(dtype_id, 2);
        assert_eq!(shape, vec![3, 4]);
    }

    // Six axes is the maximum the header has room for.
    #[test]
    fn test_header_round_trip_6d() {
        let mut buf = [0u8; HEADER_SIZE];
        encode_header(&mut buf, 4, &[2, 3, 4, 5, 6, 7]);
        let (dtype_id, shape) = decode_header(&buf).expect("decode failed");
        assert_eq!(dtype_id, 4);
        assert_eq!(shape, vec![2, 3, 4, 5, 6, 7]);
    }

    #[test]
    fn test_bad_magic() {
        let mut buf = [0u8; HEADER_SIZE];
        buf[0..4].copy_from_slice(b"NOPE");
        let err = decode_header(&buf).expect_err("should fail");
        assert!(matches!(err, MmapError::InvalidMagic));
    }

    #[test]
    fn test_bad_version() {
        let mut buf = [0u8; HEADER_SIZE];
        encode_header(&mut buf, 1, &[10]);
        // Byte 4 holds the format version; corrupt it after a valid encode.
        buf[4] = 99;
        let err = decode_header(&buf).expect_err("should fail");
        assert!(matches!(err, MmapError::VersionMismatch(99)));
    }

    // --- stride / dtype helpers ---------------------------------------------

    #[test]
    fn test_c_strides_1d() {
        assert_eq!(c_strides(&[5]), vec![1]);
    }

    #[test]
    fn test_c_strides_2d() {
        assert_eq!(c_strides(&[3, 4]), vec![4, 1]);
    }

    #[test]
    fn test_c_strides_3d() {
        assert_eq!(c_strides(&[2, 3, 4]), vec![12, 4, 1]);
    }

    // The header's dtype check relies on every element type having its own id.
    #[test]
    fn test_dtype_ids_are_distinct() {
        let ids = [
            f32::dtype_id(),
            f64::dtype_id(),
            i32::dtype_id(),
            i64::dtype_id(),
        ];
        for i in 0..ids.len() {
            for j in (i + 1)..ids.len() {
                assert_ne!(ids[i], ids[j], "dtype IDs must be unique");
            }
        }
    }

    #[test]
    fn test_element_sizes() {
        assert_eq!(f32::element_size(), 4);
        assert_eq!(f64::element_size(), 8);
        assert_eq!(i32::element_size(), 4);
        assert_eq!(i64::element_size(), 8);
    }

    // --- file round trips ---------------------------------------------------

    // Create, drop, then reopen read-only and compare every element.
    #[test]
    fn test_create_and_read_only_f32() {
        let dir = std::env::temp_dir();
        let path = dir.join("test_mmap_create_ro_f32.mmap");

        let original = make_f32_array(&[4, 5]);
        {
            let arr = MmapArray::<f32>::create(&path, &original).expect("create failed");
            assert_eq!(arr.shape(), &[4, 5]);
            assert_eq!(arr.len(), 20);
            assert!(!arr.is_empty());
        }

        {
            let arr = MmapArray::<f32>::open_read_only(&path).expect("open_read_only failed");
            assert_eq!(arr.shape(), &[4, 5]);
            let view = arr.view().expect("view failed");
            for (a, b) in view.iter().zip(original.iter()) {
                // Values are small integers, exactly representable, so a tight
                // tolerance is enough.
                assert!((a - b).abs() < f32::EPSILON, "element mismatch: {a} vs {b}");
            }
        }

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_create_and_read_only_f64() {
        let dir = std::env::temp_dir();
        let path = dir.join("test_mmap_create_ro_f64.mmap");

        let original = make_f64_array(&[3, 3]);
        {
            let _arr = MmapArray::<f64>::create(&path, &original).expect("create failed");
        }

        let arr = MmapArray::<f64>::open_read_only(&path).expect("open_read_only failed");
        let owned = arr.to_owned_array().expect("to_owned failed");
        assert_eq!(owned.shape(), &[3, 3]);
        for (a, b) in owned.iter().zip(original.iter()) {
            assert!((a - b).abs() < f64::EPSILON);
        }

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_create_and_read_only_i32() {
        let dir = std::env::temp_dir();
        let path = dir.join("test_mmap_create_ro_i32.mmap");

        let n = 6usize;
        let data: Vec<i32> = (0..n as i32).collect();
        let original = ArrayD::from_shape_vec(IxDyn(&[2, 3]), data).expect("shape mismatch");
        {
            let _arr = MmapArray::<i32>::create(&path, &original).expect("create failed");
        }

        let arr = MmapArray::<i32>::open_read_only(&path).expect("open_read_only failed");
        let view = arr.view().expect("view failed");
        for (a, b) in view.iter().zip(original.iter()) {
            assert_eq!(a, b);
        }

        let _ = std::fs::remove_file(&path);
    }

    // --- mutation -------------------------------------------------------------

    // Mutations through a freshly created array must survive flush + reopen.
    #[test]
    fn test_read_write_mutation() {
        let dir = std::env::temp_dir();
        let path = dir.join("test_mmap_rw.mmap");

        let original = make_f64_array(&[5]);
        {
            let mut arr = MmapArray::<f64>::create(&path, &original).expect("create failed");
            {
                let mut view = arr.view_mut().expect("view_mut failed");
                view.iter_mut().for_each(|x| *x *= 2.0);
            }
            arr.flush().expect("flush failed");
        }

        // Reopen through a separate read-only mapping to see what reached the file.
        let arr = MmapArray::<f64>::open_read_only(&path).expect("open_read_only failed");
        let view = arr.view().expect("view failed");
        for (i, &val) in view.iter().enumerate() {
            let expected = (i as f64) * 2.0;
            assert!(
                (val - expected).abs() < f64::EPSILON,
                "element {i}: got {val}, expected {expected}"
            );
        }

        let _ = std::fs::remove_file(&path);
    }

    // Same as above, but via `open_read_write` on an existing file.
    #[test]
    fn test_open_read_write_then_mutate() {
        let dir = std::env::temp_dir();
        let path = dir.join("test_mmap_open_rw.mmap");

        let original = make_f32_array(&[3, 3]);
        {
            let _arr = MmapArray::<f32>::create(&path, &original).expect("create failed");
        }

        {
            let mut arr = MmapArray::<f32>::open_read_write(&path).expect("open_read_write failed");
            {
                let mut view = arr.view_mut().expect("view_mut failed");
                view.iter_mut().for_each(|x| *x += 100.0);
            }
            arr.flush().expect("flush failed");
        }

        let arr = MmapArray::<f32>::open_read_only(&path).expect("open_read_only failed");
        let view = arr.view().expect("view failed");
        for (i, &val) in view.iter().enumerate() {
            let expected = i as f32 + 100.0;
            assert!(
                (val - expected).abs() < f32::EPSILON,
                "element {i}: got {val}, expected {expected}"
            );
        }

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_read_only_rejects_view_mut() {
        let dir = std::env::temp_dir();
        let path = dir.join("test_mmap_ro_no_mut.mmap");

        let original = make_f32_array(&[2, 2]);
        {
            let _arr = MmapArray::<f32>::create(&path, &original).expect("create failed");
        }

        let mut arr = MmapArray::<f32>::open_read_only(&path).expect("open_read_only failed");
        let err = arr.view_mut().expect_err("should return ReadOnly error");
        assert!(matches!(err, MmapError::ReadOnly));

        let _ = std::fs::remove_file(&path);
    }

    // --- copy-on-write --------------------------------------------------------

    // Before any `view_mut`, a COW view reads straight from the mapping.
    #[test]
    fn test_cow_no_copy_before_write() {
        let dir = std::env::temp_dir();
        let path = dir.join("test_mmap_cow_read.mmap");

        let original = make_f64_array(&[4]);
        {
            let _arr = MmapArray::<f64>::create(&path, &original).expect("create failed");
        }

        let arr = MmapArray::<f64>::open_cow(&path).expect("open_cow failed");
        let view = arr.view().expect("view failed");
        for (a, b) in view.iter().zip(original.iter()) {
            assert!((a - b).abs() < f64::EPSILON);
        }

        let _ = std::fs::remove_file(&path);
    }

    // COW writes must be visible through `view` but must never reach the file,
    // even after `flush`.
    #[test]
    fn test_cow_mutates_in_ram_not_file() {
        let dir = std::env::temp_dir();
        let path = dir.join("test_mmap_cow_mutate.mmap");

        let original = make_f32_array(&[6]);
        {
            let _arr = MmapArray::<f32>::create(&path, &original).expect("create failed");
        }

        {
            let mut arr = MmapArray::<f32>::open_cow(&path).expect("open_cow failed");
            {
                let mut view = arr.view_mut().expect("view_mut failed");
                view.iter_mut().for_each(|x| *x = -1.0);
            }

            // The in-memory copy reflects the writes.
            let view = arr.view().expect("view failed");
            for &val in view.iter() {
                assert!(
                    (val - (-1.0f32)).abs() < f32::EPSILON,
                    "COW in-memory data wrong: {val}"
                );
            }
            // Flushing a COW array writes nothing back.
            arr.flush().expect("flush failed");
        }

        // The file still holds the original 0, 1, 2, ... values.
        let arr_check = MmapArray::<f32>::open_read_only(&path).expect("open_read_only failed");
        let view_check = arr_check.view().expect("view failed");
        for (i, &val) in view_check.iter().enumerate() {
            let expected = i as f32;
            assert!(
                (val - expected).abs() < f32::EPSILON,
                "file was modified when COW was used: element {i}: got {val}"
            );
        }

        let _ = std::fs::remove_file(&path);
    }

    // --- validation -----------------------------------------------------------

    // An f32 file (dtype id 1) opened as f64 (dtype id 2) must be rejected.
    #[test]
    fn test_dtype_mismatch_on_open() {
        let dir = std::env::temp_dir();
        let path = dir.join("test_mmap_dtype_mismatch.mmap");

        let original = make_f32_array(&[8]);
        {
            let _arr = MmapArray::<f32>::create(&path, &original).expect("create failed");
        }

        let err = MmapArray::<f64>::open_read_only(&path).expect_err("should be DtypeMismatch");
        assert!(
            matches!(
                err,
                MmapError::DtypeMismatch {
                    expected: 2,
                    actual: 1
                }
            ),
            "unexpected error: {err:?}"
        );

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_noncontiguous_rejected() {
        use ndarray::ShapeBuilder;

        let dir = std::env::temp_dir();
        let path = dir.join("test_mmap_noncontiguous.mmap");

        // `.f()` requests column-major (Fortran) layout, which is not standard layout.
        let fortran: ArrayD<f64> = ndarray::Array::zeros(IxDyn(&[3, 4]).f());
        assert!(
            !fortran.is_standard_layout(),
            "test precondition: Fortran array must be non-standard-layout"
        );

        let err = MmapArray::<f64>::create(&path, &fortran)
            .expect_err("create() should reject Fortran-order array");
        assert!(
            matches!(err, MmapError::NonContiguous),
            "expected NonContiguous, got: {err:?}"
        );

        let _ = std::fs::remove_file(&path);
    }

    // --- anonymous (temp-file) arrays -----------------------------------------

    // A fresh anonymous array is zero-filled.
    #[test]
    fn test_anonymous_create() {
        let arr = MmapArray::<f32>::create_anonymous(&[8, 8]).expect("anonymous create failed");
        assert_eq!(arr.shape(), &[8, 8]);
        assert_eq!(arr.len(), 64);
        let view = arr.view().expect("view failed");
        for &val in view.iter() {
            assert_eq!(val, 0.0f32);
        }
    }

    #[test]
    fn test_anonymous_mutation() {
        let mut arr = MmapArray::<i64>::create_anonymous(&[4]).expect("anonymous create failed");
        {
            let mut view = arr.view_mut().expect("view_mut failed");
            for (i, x) in view.iter_mut().enumerate() {
                *x = i as i64 * 10;
            }
        }
        let view = arr.view().expect("view failed");
        for (i, &val) in view.iter().enumerate() {
            assert_eq!(val, i as i64 * 10, "element {i} wrong");
        }
    }

    // --- misc -----------------------------------------------------------------

    #[test]
    fn test_to_owned_array() {
        let dir = std::env::temp_dir();
        let path = dir.join("test_mmap_to_owned.mmap");

        let original = make_f64_array(&[2, 3, 4]);
        {
            let _arr = MmapArray::<f64>::create(&path, &original).expect("create failed");
        }

        let arr = MmapArray::<f64>::open_read_only(&path).expect("open failed");
        let owned = arr.to_owned_array().expect("to_owned failed");
        assert_eq!(owned.shape(), original.shape());
        for (a, b) in owned.iter().zip(original.iter()) {
            assert!((a - b).abs() < f64::EPSILON);
        }

        let _ = std::fs::remove_file(&path);
    }

    // A zero-length axis produces a header-only file.
    #[test]
    fn test_empty_array() {
        let dir = std::env::temp_dir();
        let path = dir.join("test_mmap_empty.mmap");

        let empty: ArrayD<f32> = ArrayD::zeros(IxDyn(&[0]));
        let arr = MmapArray::<f32>::create(&path, &empty).expect("create failed");
        assert!(arr.is_empty());
        assert_eq!(arr.len(), 0);
        assert_eq!(arr.shape(), &[0]);

        let _ = std::fs::remove_file(&path);
    }
}