Skip to main content

scirs2_core/memory/
mmap_array.rs

1//! # Memory-Mapped NDArray Wrapper
2//!
3//! This module provides a zero-copy, file-backed ndarray with Copy-on-Write (COW) semantics.
4//!
5//! ## Overview
6//!
7//! [`MmapArray<F>`] wraps a file on disk as an ndarray, enabling:
8//!
//! - **Zero-copy reads**: Elements are read straight from the mapped file pages (the OS page cache) without being copied into a separate heap buffer.
10//! - **Read-write mmap**: Mutations are written directly to the file via mmap.
//! - **Copy-on-Write**: Reads are zero-copy until the first mutable access, which copies the data
//!   into a private heap buffer; all later reads and writes use that buffer and are not written
//!   back to the file (`flush` is a no-op in this mode).
13//!
14//! ## File Format
15//!
16//! The binary file starts with a 64-byte header (little-endian):
17//!
18//! ```text
19//! Offset  Size  Field
20//! ------  ----  -----
21//!  0..4    4    Magic bytes: b"MMAP"
22//!  4       1    Version: 1
23//!  5       1    dtype_id (1=f32, 2=f64, 3=i32, 4=i64)
24//!  6..8    2    ndim (u16, little-endian)
25//!  8..16   8    total_elements (u64, little-endian)
26//! 16..16+8*ndim  8 per dim  shape dimensions (u64 each, little-endian)
//! ... zero-padding to byte 64 (so at most 6 shape dimensions fit)
28//! 64..    data  Raw element bytes (F, little-endian, row-major / C order)
29//! ```
30//!
31//! ## Example
32//!
33//! ```rust,no_run
34//! # #[cfg(feature = "mmap")]
35//! # {
36//! use scirs2_core::memory::mmap_array::{MmapArray, MmapError};
37//! use ndarray::ArrayD;
38//!
39//! let tmp = std::env::temp_dir().join("example.mmap");
40//! let data = ArrayD::<f32>::zeros(ndarray::IxDyn(&[4, 8]));
41//! let arr = MmapArray::<f32>::create(&tmp, &data).expect("should succeed");
42//! assert_eq!(arr.shape(), &[4, 8]);
43//! # }
44//! ```
45
46use memmap2::{Mmap, MmapMut, MmapOptions};
47use ndarray::{Array, ArrayView, ArrayViewMut, IxDyn};
48use std::fs::OpenOptions;
49use std::io::{Seek, SeekFrom, Write};
50use std::path::{Path, PathBuf};
51
52// ---------------------------------------------------------------------------
53// Header constants
54// ---------------------------------------------------------------------------
55
/// Four-byte signature that opens every array file written by this module.
const MAGIC: &[u8; 4] = &[b'M', b'M', b'A', b'P'];
/// On-disk format revision this module writes; it is also the only one it reads back.
const FORMAT_VERSION: u8 = 0x01;
/// Length of the fixed header in bytes; the element data starts exactly here.
const HEADER_SIZE: usize = 64;
62
63// ---------------------------------------------------------------------------
64// MmapElement trait
65// ---------------------------------------------------------------------------
66
67/// Types that can be stored in a memory-mapped file.
68///
69/// Implementors must be:
70/// - `Copy` — trivially duplicable, no heap ownership
71/// - `bytemuck::Pod` — safe for byte-level reinterpretation
72/// - `bytemuck::Zeroable` — a zero-initialized value is valid
73/// - `'static` — no borrowed references
74pub trait MmapElement: Copy + bytemuck::Pod + bytemuck::Zeroable + 'static {
75    /// Unique byte tag written to the file header identifying this type.
76    fn dtype_id() -> u8;
77    /// Size of a single element in bytes.
78    fn element_size() -> usize;
79}
80
81impl MmapElement for f32 {
82    fn dtype_id() -> u8 {
83        1
84    }
85    fn element_size() -> usize {
86        4
87    }
88}
89
90impl MmapElement for f64 {
91    fn dtype_id() -> u8 {
92        2
93    }
94    fn element_size() -> usize {
95        8
96    }
97}
98
99impl MmapElement for i32 {
100    fn dtype_id() -> u8 {
101        3
102    }
103    fn element_size() -> usize {
104        4
105    }
106}
107
108impl MmapElement for i64 {
109    fn dtype_id() -> u8 {
110        4
111    }
112    fn element_size() -> usize {
113        8
114    }
115}
116
117// ---------------------------------------------------------------------------
118// Internal storage enum
119// ---------------------------------------------------------------------------
120
121/// The underlying memory storage — one of three modes.
122enum MmapStorage<F: MmapElement> {
123    /// Read-only mapping backed directly by the file.
124    ReadOnly {
125        mmap: Mmap,
126        _phantom: std::marker::PhantomData<F>,
127    },
128    /// Read-write mapping backed directly by the file.
129    ReadWrite {
130        mmap: MmapMut,
131        _phantom: std::marker::PhantomData<F>,
132    },
133    /// Copy-on-write: reads zero-copy from `mmap`; first write populates `cow_data`.
134    CopyOnWrite {
135        mmap: Mmap,
136        /// `None` until the first write triggers the copy.
137        cow_data: Option<Vec<F>>,
138    },
139}
140
141// ---------------------------------------------------------------------------
142// MmapArray
143// ---------------------------------------------------------------------------
144
145/// A zero-copy, file-backed ndarray with optional Copy-on-Write semantics.
146///
147/// The array is stored in a flat binary file with a 64-byte header.  Array
148/// data is accessed via `memmap2`, so the OS manages paging automatically.
149///
150/// # Type Parameter
151///
152/// `F` must implement [`MmapElement`] (currently: `f32`, `f64`, `i32`, `i64`).
153pub struct MmapArray<F: MmapElement> {
154    storage: MmapStorage<F>,
155    shape: Vec<usize>,
156    /// C-order strides in element counts (not bytes).
157    strides: Vec<usize>,
158    file_path: PathBuf,
159}
160
161impl<F: MmapElement> std::fmt::Debug for MmapArray<F> {
162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163        let mode = match &self.storage {
164            MmapStorage::ReadOnly { .. } => "ReadOnly",
165            MmapStorage::ReadWrite { .. } => "ReadWrite",
166            MmapStorage::CopyOnWrite { cow_data, .. } => {
167                if cow_data.is_some() {
168                    "CopyOnWrite(dirty)"
169                } else {
170                    "CopyOnWrite(clean)"
171                }
172            }
173        };
174        f.debug_struct("MmapArray")
175            .field("mode", &mode)
176            .field("shape", &self.shape)
177            .field("strides", &self.strides)
178            .field("file_path", &self.file_path)
179            .finish()
180    }
181}
182
183// ---------------------------------------------------------------------------
184// Header encode / decode
185// ---------------------------------------------------------------------------
186
187/// Encode a 64-byte header into `buf`.
188///
189/// # Panics
190///
191/// Panics if `ndim` > 6 (the header only has room for 6 shape dimensions).
192fn encode_header(buf: &mut [u8; HEADER_SIZE], dtype_id: u8, shape: &[usize]) {
193    let ndim = shape.len();
194    assert!(
195        ndim <= 6,
196        "MmapArray supports at most 6 dimensions (header limit)"
197    );
198
199    buf.fill(0);
200
201    // Magic
202    buf[0..4].copy_from_slice(MAGIC);
203    // Version
204    buf[4] = FORMAT_VERSION;
205    // dtype_id
206    buf[5] = dtype_id;
207    // ndim (u16 LE)
208    let ndim_u16 = ndim as u16;
209    buf[6..8].copy_from_slice(&ndim_u16.to_le_bytes());
210    // total_elements (u64 LE)
211    let total: u64 = shape.iter().product::<usize>() as u64;
212    buf[8..16].copy_from_slice(&total.to_le_bytes());
213    // shape dimensions
214    for (i, &dim) in shape.iter().enumerate() {
215        let off = 16 + i * 8;
216        buf[off..off + 8].copy_from_slice(&(dim as u64).to_le_bytes());
217    }
218}
219
220/// Decode a 64-byte header, returning `(dtype_id, shape)`.
221fn decode_header(buf: &[u8; HEADER_SIZE]) -> Result<(u8, Vec<usize>), MmapError> {
222    if &buf[0..4] != MAGIC {
223        return Err(MmapError::InvalidMagic);
224    }
225    let version = buf[4];
226    if version != FORMAT_VERSION {
227        return Err(MmapError::VersionMismatch(version));
228    }
229    let dtype_id = buf[5];
230    let ndim = u16::from_le_bytes([buf[6], buf[7]]) as usize;
231    let _total_elements = u64::from_le_bytes(buf[8..16].try_into().map_err(|_| {
232        MmapError::Io(std::io::Error::new(
233            std::io::ErrorKind::InvalidData,
234            "header truncated at total_elements",
235        ))
236    })?);
237
238    let mut shape = Vec::with_capacity(ndim);
239    for i in 0..ndim {
240        let off = 16 + i * 8;
241        if off + 8 > HEADER_SIZE {
242            return Err(MmapError::Io(std::io::Error::new(
243                std::io::ErrorKind::InvalidData,
244                format!("header too small for ndim={}", ndim),
245            )));
246        }
247        let dim = u64::from_le_bytes(buf[off..off + 8].try_into().map_err(|_| {
248            MmapError::Io(std::io::Error::new(
249                std::io::ErrorKind::InvalidData,
250                "header truncated in shape",
251            ))
252        })?);
253        shape.push(dim as usize);
254    }
255
256    Ok((dtype_id, shape))
257}
258
/// Row-major (C-order) strides for `shape`, counted in elements rather than bytes.
///
/// The innermost axis always has stride 1. Each outer axis's stride is the
/// product of every extent to its right. An empty shape yields no strides.
fn c_strides(shape: &[usize]) -> Vec<usize> {
    let Some((_, inner)) = shape.split_first() else {
        return Vec::new();
    };

    // Walk the inner extents from the last axis outwards, accumulating products.
    let mut running = 1usize;
    let mut strides: Vec<usize> = std::iter::once(1)
        .chain(inner.iter().rev().map(|&extent| {
            running *= extent;
            running
        }))
        .collect();
    strides.reverse();
    strides
}
268
/// Number of elements held by an array of the given `shape`.
///
/// An empty (rank-0) shape holds a single element. Any zero extent makes the
/// count zero.
fn total_elements(shape: &[usize]) -> usize {
    shape.iter().fold(1, |count, &extent| count * extent)
}
273
274// ---------------------------------------------------------------------------
275// impl MmapArray
276// ---------------------------------------------------------------------------
277
278impl<F: MmapElement> MmapArray<F> {
279    // -----------------------------------------------------------------------
280    // Constructors
281    // -----------------------------------------------------------------------
282
283    /// Create a new memory-mapped file from an existing ndarray and return a
284    /// **read-write** [`MmapArray`] backed by that file.
285    ///
286    /// The input array must be in standard C-order (contiguous) layout.
287    /// A 64-byte header is written first, followed immediately by the raw
288    /// element bytes.
289    ///
290    /// # Errors
291    ///
292    /// Returns [`MmapError::NonContiguous`] if `data` is not a contiguous array.
293    pub fn create(path: &Path, data: &Array<F, IxDyn>) -> Result<Self, MmapError> {
294        // Require a contiguous layout so we can do a single slice copy.
295        if !data.is_standard_layout() {
296            return Err(MmapError::NonContiguous);
297        }
298
299        let shape = data.shape().to_vec();
300        let n_elems = total_elements(&shape);
301        let data_bytes = n_elems * F::element_size();
302        let file_size = HEADER_SIZE + data_bytes;
303
304        // Open / create file, set length.
305        let file = OpenOptions::new()
306            .read(true)
307            .write(true)
308            .create(true)
309            .truncate(true)
310            .open(path)?;
311        file.set_len(file_size as u64)?;
312
313        // Write header via a plain write (simpler than mmap for small header).
314        {
315            use std::io::BufWriter;
316            let mut writer = BufWriter::new(&file);
317            let mut header = [0u8; HEADER_SIZE];
318            encode_header(&mut header, F::dtype_id(), &shape);
319            writer.write_all(&header)?;
320            // Write element bytes from the array's raw slice.
321            let raw: &[F] = data.as_slice().ok_or(MmapError::NonContiguous)?;
322            writer.write_all(bytemuck::cast_slice(raw))?;
323            writer.flush()?;
324        }
325
326        // Now open a read-write mmap over the entire file (including header).
327        // We use the whole file mapping and expose only the data region.
328        let mmap = unsafe { MmapOptions::new().map_mut(&file)? };
329
330        let strides = c_strides(&shape);
331        Ok(Self {
332            storage: MmapStorage::ReadWrite {
333                mmap,
334                _phantom: std::marker::PhantomData,
335            },
336            shape,
337            strides,
338            file_path: path.to_path_buf(),
339        })
340    }
341
342    /// Open an existing `.mmap` file in **read-only** mode.
343    ///
344    /// The returned array cannot be mutated; calls to [`view_mut`](Self::view_mut)
345    /// will return [`MmapError::ReadOnly`].
346    pub fn open_read_only(path: &Path) -> Result<Self, MmapError> {
347        let file = OpenOptions::new().read(true).open(path)?;
348        let mmap = unsafe { MmapOptions::new().map(&file)? };
349        Self::from_readonly_mmap(mmap, path)
350    }
351
352    /// Open an existing `.mmap` file in **read-write** mode.
353    ///
354    /// Mutations are written directly to the file (no buffering).
355    pub fn open_read_write(path: &Path) -> Result<Self, MmapError> {
356        let file = OpenOptions::new().read(true).write(true).open(path)?;
357        let mmap = unsafe { MmapOptions::new().map_mut(&file)? };
358
359        // Read and validate header from the mapping itself.
360        if mmap.len() < HEADER_SIZE {
361            return Err(MmapError::Io(std::io::Error::new(
362                std::io::ErrorKind::UnexpectedEof,
363                "file too small to contain header",
364            )));
365        }
366        let header_bytes: &[u8; HEADER_SIZE] = mmap[..HEADER_SIZE].try_into().map_err(|_| {
367            MmapError::Io(std::io::Error::new(
368                std::io::ErrorKind::InvalidData,
369                "could not read header",
370            ))
371        })?;
372        let (dtype_id, shape) = decode_header(header_bytes)?;
373        if dtype_id != F::dtype_id() {
374            return Err(MmapError::DtypeMismatch {
375                expected: F::dtype_id(),
376                actual: dtype_id,
377            });
378        }
379
380        let strides = c_strides(&shape);
381        Ok(Self {
382            storage: MmapStorage::ReadWrite {
383                mmap,
384                _phantom: std::marker::PhantomData,
385            },
386            shape,
387            strides,
388            file_path: path.to_path_buf(),
389        })
390    }
391
392    /// Open an existing `.mmap` file in **Copy-on-Write** mode.
393    ///
394    /// Reads are served zero-copy from the OS page cache.  The first call to
395    /// [`view_mut`](Self::view_mut) triggers a full copy of the data into RAM,
396    /// after which all mutations happen in-memory.  The original file is never
397    /// modified unless you later call [`flush`](Self::flush) (which, in COW mode,
398    /// is a no-op since there is no writable mapping to flush).
399    pub fn open_cow(path: &Path) -> Result<Self, MmapError> {
400        let file = OpenOptions::new().read(true).open(path)?;
401        let mmap = unsafe { MmapOptions::new().map(&file)? };
402        let (_, shape, strides) = Self::parse_readonly_mmap(&mmap, path)?;
403        Ok(Self {
404            storage: MmapStorage::CopyOnWrite {
405                mmap,
406                cow_data: None,
407            },
408            shape,
409            strides,
410            file_path: path.to_path_buf(),
411        })
412    }
413
414    // -----------------------------------------------------------------------
415    // Accessors
416    // -----------------------------------------------------------------------
417
418    /// Return a read-only ndarray view.
419    ///
420    /// For read-only and read-write modes, the view is zero-copy (backed by the mmap).
421    /// For COW mode before any write, the view is also zero-copy.
422    /// For COW mode after the first write, the view is backed by the in-RAM copy.
423    pub fn view(&self) -> Result<ArrayView<'_, F, IxDyn>, MmapError> {
424        match &self.storage {
425            MmapStorage::ReadOnly { mmap, .. } => {
426                let data_slice = &mmap[HEADER_SIZE..];
427                let elems: &[F] = bytemuck::cast_slice(data_slice);
428                let ix = IxDyn(self.shape.as_slice());
429                // SAFETY: We own the mmap for the lifetime of `&self`, the
430                // pointer is valid, and the length matches `self.shape`.
431                let view = unsafe { ArrayView::from_shape_ptr(ix, elems.as_ptr()) };
432                Ok(view)
433            }
434            MmapStorage::ReadWrite { mmap, .. } => {
435                let data_slice = &mmap[HEADER_SIZE..];
436                let elems: &[F] = bytemuck::cast_slice(data_slice);
437                let ix = IxDyn(self.shape.as_slice());
438                // SAFETY: Same as above.
439                let view = unsafe { ArrayView::from_shape_ptr(ix, elems.as_ptr()) };
440                Ok(view)
441            }
442            MmapStorage::CopyOnWrite { mmap, cow_data } => {
443                match cow_data {
444                    None => {
445                        // No write yet — serve directly from the mmap.
446                        let data_slice = &mmap[HEADER_SIZE..];
447                        let elems: &[F] = bytemuck::cast_slice(data_slice);
448                        let ix = IxDyn(self.shape.as_slice());
449                        // SAFETY: Mmap is valid for `&self` lifetime.
450                        let view = unsafe { ArrayView::from_shape_ptr(ix, elems.as_ptr()) };
451                        Ok(view)
452                    }
453                    Some(data) => {
454                        let ix = IxDyn(self.shape.as_slice());
455                        // SAFETY: `data` is a Vec owned by `self`, pointer valid for `&self`.
456                        let view = unsafe { ArrayView::from_shape_ptr(ix, data.as_ptr()) };
457                        Ok(view)
458                    }
459                }
460            }
461        }
462    }
463
464    /// Return a mutable ndarray view.
465    ///
466    /// - **ReadOnly**: always returns [`MmapError::ReadOnly`].
467    /// - **ReadWrite**: the view is zero-copy and writes go directly to the file.
468    /// - **COW**: triggers a copy of the mmap data into RAM on the first call.
469    pub fn view_mut(&mut self) -> Result<ArrayViewMut<'_, F, IxDyn>, MmapError> {
470        match &mut self.storage {
471            MmapStorage::ReadOnly { .. } => Err(MmapError::ReadOnly),
472            MmapStorage::ReadWrite { mmap, .. } => {
473                let data_slice = &mut mmap[HEADER_SIZE..];
474                let elems: &mut [F] = bytemuck::cast_slice_mut(data_slice);
475                let ix = IxDyn(self.shape.as_slice());
476                // SAFETY: We have exclusive access via `&mut self`.
477                let view = unsafe { ArrayViewMut::from_shape_ptr(ix, elems.as_mut_ptr()) };
478                Ok(view)
479            }
480            MmapStorage::CopyOnWrite { mmap, cow_data } => {
481                // Trigger COW fault if this is the first write.
482                if cow_data.is_none() {
483                    let data_slice = &mmap[HEADER_SIZE..];
484                    let elems: &[F] = bytemuck::cast_slice(data_slice);
485                    *cow_data = Some(elems.to_vec());
486                }
487                let data = cow_data.as_mut().ok_or_else(|| {
488                    MmapError::Io(std::io::Error::other(
489                        "COW data unexpectedly None after initialization",
490                    ))
491                })?;
492                let ix = IxDyn(self.shape.as_slice());
493                // SAFETY: We have exclusive access via `&mut self`.
494                let view = unsafe { ArrayViewMut::from_shape_ptr(ix, data.as_mut_ptr()) };
495                Ok(view)
496            }
497        }
498    }
499
500    /// Flush changes to disk.
501    ///
502    /// - **ReadWrite**: calls `MmapMut::flush()`.
503    /// - **ReadOnly** / **COW**: no-op (returns `Ok(())`).
504    pub fn flush(&self) -> Result<(), MmapError> {
505        match &self.storage {
506            MmapStorage::ReadWrite { mmap, .. } => {
507                mmap.flush()?;
508                Ok(())
509            }
510            _ => Ok(()),
511        }
512    }
513
514    /// Return the shape of the array.
515    pub fn shape(&self) -> &[usize] {
516        &self.shape
517    }
518
519    /// Return the C-order strides (in element counts).
520    pub fn strides(&self) -> &[usize] {
521        &self.strides
522    }
523
524    /// Return the total number of elements.
525    pub fn len(&self) -> usize {
526        total_elements(&self.shape)
527    }
528
529    /// Return `true` if the array has zero elements.
530    pub fn is_empty(&self) -> bool {
531        self.len() == 0
532    }
533
534    /// Return the file path backing this array.
535    pub fn file_path(&self) -> &Path {
536        &self.file_path
537    }
538
539    /// Copy all elements into a heap-allocated [`Array<F, IxDyn>`].
540    ///
541    /// This always copies; it is equivalent to `.view()?.to_owned()`.
542    pub fn to_owned_array(&self) -> Result<Array<F, IxDyn>, MmapError> {
543        let view = self.view()?;
544        Ok(view.to_owned())
545    }
546
547    // -----------------------------------------------------------------------
548    // Internal helpers
549    // -----------------------------------------------------------------------
550
551    /// Construct a `ReadOnly` variant from an already-mapped region.
552    fn from_readonly_mmap(mmap: Mmap, path: &Path) -> Result<Self, MmapError> {
553        let (_, shape, strides) = Self::parse_readonly_mmap(&mmap, path)?;
554        Ok(Self {
555            storage: MmapStorage::ReadOnly {
556                mmap,
557                _phantom: std::marker::PhantomData,
558            },
559            shape,
560            strides,
561            file_path: path.to_path_buf(),
562        })
563    }
564
565    /// Validate and parse the header from a read-only mapping.
566    ///
567    /// Returns `(dtype_id, shape, strides)`.
568    fn parse_readonly_mmap(
569        mmap: &Mmap,
570        _path: &Path,
571    ) -> Result<(u8, Vec<usize>, Vec<usize>), MmapError> {
572        if mmap.len() < HEADER_SIZE {
573            return Err(MmapError::Io(std::io::Error::new(
574                std::io::ErrorKind::UnexpectedEof,
575                "file too small to contain header",
576            )));
577        }
578        let header_bytes: &[u8; HEADER_SIZE] = mmap[..HEADER_SIZE].try_into().map_err(|_| {
579            MmapError::Io(std::io::Error::new(
580                std::io::ErrorKind::InvalidData,
581                "could not read header slice",
582            ))
583        })?;
584        let (dtype_id, shape) = decode_header(header_bytes)?;
585        if dtype_id != F::dtype_id() {
586            return Err(MmapError::DtypeMismatch {
587                expected: F::dtype_id(),
588                actual: dtype_id,
589            });
590        }
591        let strides = c_strides(&shape);
592        Ok((dtype_id, shape, strides))
593    }
594
595    // -----------------------------------------------------------------------
596    // Test-only helpers
597    // -----------------------------------------------------------------------
598
599    /// Create an anonymous (non-file-backed) mmap for testing purposes.
600    ///
601    /// This uses an anonymous mapping backed by a temporary file.
602    #[cfg(test)]
603    fn create_anonymous(shape: &[usize]) -> Result<Self, MmapError> {
604        use tempfile::tempfile;
605
606        let n_elems = total_elements(shape);
607        let file_size = HEADER_SIZE + n_elems * F::element_size();
608
609        let file = tempfile()?;
610        file.set_len(file_size as u64)?;
611
612        // Write header.
613        {
614            use std::io::BufWriter;
615            let mut writer = BufWriter::new(&file);
616            let mut header = [0u8; HEADER_SIZE];
617            encode_header(&mut header, F::dtype_id(), shape);
618            writer.write_all(&header)?;
619            // Data is zero-initialized by `set_len`.
620            writer.flush()?;
621        }
622
623        let mmap = unsafe { MmapOptions::new().map_mut(&file)? };
624        let strides = c_strides(shape);
625
626        // We store a dummy path since this is anonymous.
627        Ok(Self {
628            storage: MmapStorage::ReadWrite {
629                mmap,
630                _phantom: std::marker::PhantomData,
631            },
632            shape: shape.to_vec(),
633            strides,
634            file_path: PathBuf::from("<anonymous>"),
635        })
636    }
637}
638
639// ---------------------------------------------------------------------------
640// MmapError
641// ---------------------------------------------------------------------------
642
643/// Errors that can occur when creating or accessing a [`MmapArray`].
644#[derive(Debug, thiserror::Error)]
645pub enum MmapError {
646    /// An I/O error occurred (file open, read, write, flush, etc.).
647    #[error("IO error: {0}")]
648    Io(#[from] std::io::Error),
649
650    /// The file does not start with the expected magic bytes `b"MMAP"`.
651    #[error("Invalid magic bytes — not a valid mmap array file")]
652    InvalidMagic,
653
654    /// The file header reports a format version other than 1.
655    #[error("Version mismatch: expected 1, got {0}")]
656    VersionMismatch(u8),
657
658    /// The dtype tag in the file does not match the requested element type.
659    #[error("Dtype mismatch: expected dtype_id {expected}, got {actual}")]
660    DtypeMismatch { expected: u8, actual: u8 },
661
662    /// The stored shape does not match the shape provided by the caller.
663    #[error("Shape mismatch: expected {expected:?}, got {actual:?}")]
664    ShapeMismatch {
665        expected: Vec<usize>,
666        actual: Vec<usize>,
667    },
668
669    /// The source ndarray is not in contiguous C-order layout.
670    #[error("Array is not contiguous (non-contiguous layouts not supported)")]
671    NonContiguous,
672
673    /// The mapping is read-only and a mutable view was requested.
674    #[error("Array is read-only")]
675    ReadOnly,
676}
677
678// ---------------------------------------------------------------------------
679// Unit tests
680// ---------------------------------------------------------------------------
681
682#[cfg(test)]
683mod tests {
684    use super::*;
685    use ndarray::{ArrayD, IxDyn};
686
687    /// Build a small f32 array filled with ascending values.
688    fn make_f32_array(shape: &[usize]) -> ArrayD<f32> {
689        let n = shape.iter().product::<usize>();
690        let data: Vec<f32> = (0..n).map(|i| i as f32).collect();
691        ArrayD::from_shape_vec(IxDyn(shape), data).expect("shape mismatch in test helper")
692    }
693
694    /// Build a small f64 array filled with ascending values.
695    fn make_f64_array(shape: &[usize]) -> ArrayD<f64> {
696        let n = shape.iter().product::<usize>();
697        let data: Vec<f64> = (0..n).map(|i| i as f64).collect();
698        ArrayD::from_shape_vec(IxDyn(shape), data).expect("shape mismatch in test helper")
699    }
700
701    // -----------------------------------------------------------------------
702    // Header round-trip tests
703    // -----------------------------------------------------------------------
704
705    #[test]
706    fn test_header_round_trip_1d() {
707        let mut buf = [0u8; HEADER_SIZE];
708        encode_header(&mut buf, 1, &[100]);
709        let (dtype_id, shape) = decode_header(&buf).expect("decode failed");
710        assert_eq!(dtype_id, 1);
711        assert_eq!(shape, vec![100usize]);
712    }
713
714    #[test]
715    fn test_header_round_trip_2d() {
716        let mut buf = [0u8; HEADER_SIZE];
717        encode_header(&mut buf, 2, &[3, 4]);
718        let (dtype_id, shape) = decode_header(&buf).expect("decode failed");
719        assert_eq!(dtype_id, 2);
720        assert_eq!(shape, vec![3, 4]);
721    }
722
723    #[test]
724    fn test_header_round_trip_6d() {
725        let mut buf = [0u8; HEADER_SIZE];
726        encode_header(&mut buf, 4, &[2, 3, 4, 5, 6, 7]);
727        let (dtype_id, shape) = decode_header(&buf).expect("decode failed");
728        assert_eq!(dtype_id, 4);
729        assert_eq!(shape, vec![2, 3, 4, 5, 6, 7]);
730    }
731
732    #[test]
733    fn test_bad_magic() {
734        let mut buf = [0u8; HEADER_SIZE];
735        buf[0..4].copy_from_slice(b"NOPE");
736        let err = decode_header(&buf).expect_err("should fail");
737        assert!(matches!(err, MmapError::InvalidMagic));
738    }
739
740    #[test]
741    fn test_bad_version() {
742        let mut buf = [0u8; HEADER_SIZE];
743        encode_header(&mut buf, 1, &[10]);
744        buf[4] = 99; // corrupt version
745        let err = decode_header(&buf).expect_err("should fail");
746        assert!(matches!(err, MmapError::VersionMismatch(99)));
747    }
748
749    // -----------------------------------------------------------------------
750    // c_strides
751    // -----------------------------------------------------------------------
752
753    #[test]
754    fn test_c_strides_1d() {
755        assert_eq!(c_strides(&[5]), vec![1]);
756    }
757
758    #[test]
759    fn test_c_strides_2d() {
760        assert_eq!(c_strides(&[3, 4]), vec![4, 1]);
761    }
762
763    #[test]
764    fn test_c_strides_3d() {
765        assert_eq!(c_strides(&[2, 3, 4]), vec![12, 4, 1]);
766    }
767
768    // -----------------------------------------------------------------------
769    // MmapElement implementations
770    // -----------------------------------------------------------------------
771
772    #[test]
773    fn test_dtype_ids_are_distinct() {
774        let ids = [
775            f32::dtype_id(),
776            f64::dtype_id(),
777            i32::dtype_id(),
778            i64::dtype_id(),
779        ];
780        for i in 0..ids.len() {
781            for j in (i + 1)..ids.len() {
782                assert_ne!(ids[i], ids[j], "dtype IDs must be unique");
783            }
784        }
785    }
786
787    #[test]
788    fn test_element_sizes() {
789        assert_eq!(f32::element_size(), 4);
790        assert_eq!(f64::element_size(), 8);
791        assert_eq!(i32::element_size(), 4);
792        assert_eq!(i64::element_size(), 8);
793    }
794
795    // -----------------------------------------------------------------------
796    // create + open_read_only
797    // -----------------------------------------------------------------------
798
799    #[test]
800    fn test_create_and_read_only_f32() {
801        let dir = std::env::temp_dir();
802        let path = dir.join("test_mmap_create_ro_f32.mmap");
803
804        let original = make_f32_array(&[4, 5]);
805        {
806            let arr = MmapArray::<f32>::create(&path, &original).expect("create failed");
807            assert_eq!(arr.shape(), &[4, 5]);
808            assert_eq!(arr.len(), 20);
809            assert!(!arr.is_empty());
810        }
811
812        {
813            let arr = MmapArray::<f32>::open_read_only(&path).expect("open_read_only failed");
814            assert_eq!(arr.shape(), &[4, 5]);
815            let view = arr.view().expect("view failed");
816            // Check that values match
817            for (a, b) in view.iter().zip(original.iter()) {
818                assert!((a - b).abs() < f32::EPSILON, "element mismatch: {a} vs {b}");
819            }
820        }
821
822        let _ = std::fs::remove_file(&path);
823    }
824
825    #[test]
826    fn test_create_and_read_only_f64() {
827        let dir = std::env::temp_dir();
828        let path = dir.join("test_mmap_create_ro_f64.mmap");
829
830        let original = make_f64_array(&[3, 3]);
831        {
832            let _arr = MmapArray::<f64>::create(&path, &original).expect("create failed");
833        }
834
835        let arr = MmapArray::<f64>::open_read_only(&path).expect("open_read_only failed");
836        let owned = arr.to_owned_array().expect("to_owned failed");
837        assert_eq!(owned.shape(), &[3, 3]);
838        for (a, b) in owned.iter().zip(original.iter()) {
839            assert!((a - b).abs() < f64::EPSILON);
840        }
841
842        let _ = std::fs::remove_file(&path);
843    }
844
845    #[test]
846    fn test_create_and_read_only_i32() {
847        let dir = std::env::temp_dir();
848        let path = dir.join("test_mmap_create_ro_i32.mmap");
849
850        let n = 6usize;
851        let data: Vec<i32> = (0..n as i32).collect();
852        let original = ArrayD::from_shape_vec(IxDyn(&[2, 3]), data).expect("shape mismatch");
853        {
854            let _arr = MmapArray::<i32>::create(&path, &original).expect("create failed");
855        }
856
857        let arr = MmapArray::<i32>::open_read_only(&path).expect("open_read_only failed");
858        let view = arr.view().expect("view failed");
859        for (a, b) in view.iter().zip(original.iter()) {
860            assert_eq!(a, b);
861        }
862
863        let _ = std::fs::remove_file(&path);
864    }
865
866    // -----------------------------------------------------------------------
867    // read-write mode
868    // -----------------------------------------------------------------------
869
870    #[test]
871    fn test_read_write_mutation() {
872        let dir = std::env::temp_dir();
873        let path = dir.join("test_mmap_rw.mmap");
874
875        let original = make_f64_array(&[5]);
876        {
877            let mut arr = MmapArray::<f64>::create(&path, &original).expect("create failed");
878            {
879                let mut view = arr.view_mut().expect("view_mut failed");
880                // Double every element in-place.
881                view.iter_mut().for_each(|x| *x *= 2.0);
882            }
883            arr.flush().expect("flush failed");
884        }
885
886        // Re-open and verify changes persisted.
887        let arr = MmapArray::<f64>::open_read_only(&path).expect("open_read_only failed");
888        let view = arr.view().expect("view failed");
889        for (i, &val) in view.iter().enumerate() {
890            let expected = (i as f64) * 2.0;
891            assert!(
892                (val - expected).abs() < f64::EPSILON,
893                "element {i}: got {val}, expected {expected}"
894            );
895        }
896
897        let _ = std::fs::remove_file(&path);
898    }
899
900    #[test]
901    fn test_open_read_write_then_mutate() {
902        let dir = std::env::temp_dir();
903        let path = dir.join("test_mmap_open_rw.mmap");
904
905        let original = make_f32_array(&[3, 3]);
906        {
907            let _arr = MmapArray::<f32>::create(&path, &original).expect("create failed");
908        }
909
910        {
911            let mut arr = MmapArray::<f32>::open_read_write(&path).expect("open_read_write failed");
912            {
913                let mut view = arr.view_mut().expect("view_mut failed");
914                view.iter_mut().for_each(|x| *x += 100.0);
915            }
916            arr.flush().expect("flush failed");
917        }
918
919        let arr = MmapArray::<f32>::open_read_only(&path).expect("open_read_only failed");
920        let view = arr.view().expect("view failed");
921        for (i, &val) in view.iter().enumerate() {
922            let expected = i as f32 + 100.0;
923            assert!(
924                (val - expected).abs() < f32::EPSILON,
925                "element {i}: got {val}, expected {expected}"
926            );
927        }
928
929        let _ = std::fs::remove_file(&path);
930    }
931
932    // -----------------------------------------------------------------------
933    // read-only rejects mutation
934    // -----------------------------------------------------------------------
935
936    #[test]
937    fn test_read_only_rejects_view_mut() {
938        let dir = std::env::temp_dir();
939        let path = dir.join("test_mmap_ro_no_mut.mmap");
940
941        let original = make_f32_array(&[2, 2]);
942        {
943            let _arr = MmapArray::<f32>::create(&path, &original).expect("create failed");
944        }
945
946        let mut arr = MmapArray::<f32>::open_read_only(&path).expect("open_read_only failed");
947        let err = arr.view_mut().expect_err("should return ReadOnly error");
948        assert!(matches!(err, MmapError::ReadOnly));
949
950        let _ = std::fs::remove_file(&path);
951    }
952
953    // -----------------------------------------------------------------------
954    // Copy-on-Write semantics
955    // -----------------------------------------------------------------------
956
957    #[test]
958    fn test_cow_no_copy_before_write() {
959        let dir = std::env::temp_dir();
960        let path = dir.join("test_mmap_cow_read.mmap");
961
962        let original = make_f64_array(&[4]);
963        {
964            let _arr = MmapArray::<f64>::create(&path, &original).expect("create failed");
965        }
966
967        let arr = MmapArray::<f64>::open_cow(&path).expect("open_cow failed");
968        // Before any write, cow_data should be None (verified by reading successfully).
969        let view = arr.view().expect("view failed");
970        for (a, b) in view.iter().zip(original.iter()) {
971            assert!((a - b).abs() < f64::EPSILON);
972        }
973
974        let _ = std::fs::remove_file(&path);
975    }
976
977    #[test]
978    fn test_cow_mutates_in_ram_not_file() {
979        let dir = std::env::temp_dir();
980        let path = dir.join("test_mmap_cow_mutate.mmap");
981
982        let original = make_f32_array(&[6]);
983        {
984            let _arr = MmapArray::<f32>::create(&path, &original).expect("create failed");
985        }
986
987        {
988            let mut arr = MmapArray::<f32>::open_cow(&path).expect("open_cow failed");
989            {
990                let mut view = arr.view_mut().expect("view_mut failed");
991                // Write should trigger COW fault and copy to RAM.
992                view.iter_mut().for_each(|x| *x = -1.0);
993            }
994
995            // In-memory view should reflect mutation.
996            let view = arr.view().expect("view failed");
997            for &val in view.iter() {
998                assert!(
999                    (val - (-1.0f32)).abs() < f32::EPSILON,
1000                    "COW in-memory data wrong: {val}"
1001                );
1002            }
1003            // flush is a no-op in COW mode; should not error.
1004            arr.flush().expect("flush failed");
1005        }
1006
1007        // The original file must be UNCHANGED because we used COW.
1008        let arr_check = MmapArray::<f32>::open_read_only(&path).expect("open_read_only failed");
1009        let view_check = arr_check.view().expect("view failed");
1010        for (i, &val) in view_check.iter().enumerate() {
1011            let expected = i as f32;
1012            assert!(
1013                (val - expected).abs() < f32::EPSILON,
1014                "file was modified when COW was used: element {i}: got {val}"
1015            );
1016        }
1017
1018        let _ = std::fs::remove_file(&path);
1019    }
1020
1021    // -----------------------------------------------------------------------
1022    // dtype mismatch detection
1023    // -----------------------------------------------------------------------
1024
1025    #[test]
1026    fn test_dtype_mismatch_on_open() {
1027        let dir = std::env::temp_dir();
1028        let path = dir.join("test_mmap_dtype_mismatch.mmap");
1029
1030        let original = make_f32_array(&[8]);
1031        {
1032            let _arr = MmapArray::<f32>::create(&path, &original).expect("create failed");
1033        }
1034
1035        // Try to open as f64 — must fail with DtypeMismatch.
1036        let err = MmapArray::<f64>::open_read_only(&path).expect_err("should be DtypeMismatch");
1037        assert!(
1038            matches!(
1039                err,
1040                MmapError::DtypeMismatch {
1041                    expected: 2,
1042                    actual: 1
1043                }
1044            ),
1045            "unexpected error: {err:?}"
1046        );
1047
1048        let _ = std::fs::remove_file(&path);
1049    }
1050
1051    // -----------------------------------------------------------------------
1052    // Non-contiguous array rejection
1053    // -----------------------------------------------------------------------
1054
1055    #[test]
1056    fn test_noncontiguous_rejected() {
1057        use ndarray::ShapeBuilder;
1058
1059        let dir = std::env::temp_dir();
1060        let path = dir.join("test_mmap_noncontiguous.mmap");
1061
1062        // Create a Fortran-order (column-major) array, which is NOT in standard
1063        // C (row-major) layout.  `create()` must reject it with NonContiguous.
1064        let fortran: ArrayD<f64> = ndarray::Array::zeros(IxDyn(&[3, 4]).f());
1065        assert!(
1066            !fortran.is_standard_layout(),
1067            "test precondition: Fortran array must be non-standard-layout"
1068        );
1069
1070        let err = MmapArray::<f64>::create(&path, &fortran)
1071            .expect_err("create() should reject Fortran-order array");
1072        assert!(
1073            matches!(err, MmapError::NonContiguous),
1074            "expected NonContiguous, got: {err:?}"
1075        );
1076
1077        let _ = std::fs::remove_file(&path);
1078    }
1079
1080    // -----------------------------------------------------------------------
1081    // Anonymous mmap (test-only helper)
1082    // -----------------------------------------------------------------------
1083
1084    #[test]
1085    fn test_anonymous_create() {
1086        let arr = MmapArray::<f32>::create_anonymous(&[8, 8]).expect("anonymous create failed");
1087        assert_eq!(arr.shape(), &[8, 8]);
1088        assert_eq!(arr.len(), 64);
1089        let view = arr.view().expect("view failed");
1090        for &val in view.iter() {
1091            assert_eq!(val, 0.0f32);
1092        }
1093    }
1094
1095    #[test]
1096    fn test_anonymous_mutation() {
1097        let mut arr = MmapArray::<i64>::create_anonymous(&[4]).expect("anonymous create failed");
1098        {
1099            let mut view = arr.view_mut().expect("view_mut failed");
1100            for (i, x) in view.iter_mut().enumerate() {
1101                *x = i as i64 * 10;
1102            }
1103        }
1104        let view = arr.view().expect("view failed");
1105        for (i, &val) in view.iter().enumerate() {
1106            assert_eq!(val, i as i64 * 10, "element {i} wrong");
1107        }
1108    }
1109
1110    // -----------------------------------------------------------------------
1111    // to_owned_array
1112    // -----------------------------------------------------------------------
1113
1114    #[test]
1115    fn test_to_owned_array() {
1116        let dir = std::env::temp_dir();
1117        let path = dir.join("test_mmap_to_owned.mmap");
1118
1119        let original = make_f64_array(&[2, 3, 4]);
1120        {
1121            let _arr = MmapArray::<f64>::create(&path, &original).expect("create failed");
1122        }
1123
1124        let arr = MmapArray::<f64>::open_read_only(&path).expect("open failed");
1125        let owned = arr.to_owned_array().expect("to_owned failed");
1126        assert_eq!(owned.shape(), original.shape());
1127        for (a, b) in owned.iter().zip(original.iter()) {
1128            assert!((a - b).abs() < f64::EPSILON);
1129        }
1130
1131        let _ = std::fs::remove_file(&path);
1132    }
1133
1134    // -----------------------------------------------------------------------
1135    // is_empty
1136    // -----------------------------------------------------------------------
1137
1138    #[test]
1139    fn test_empty_array() {
1140        let dir = std::env::temp_dir();
1141        let path = dir.join("test_mmap_empty.mmap");
1142
1143        // A 0-element array (shape [0]).
1144        let empty: ArrayD<f32> = ArrayD::zeros(IxDyn(&[0]));
1145        let arr = MmapArray::<f32>::create(&path, &empty).expect("create failed");
1146        assert!(arr.is_empty());
1147        assert_eq!(arr.len(), 0);
1148        assert_eq!(arr.shape(), &[0]);
1149
1150        let _ = std::fs::remove_file(&path);
1151    }
1152}