Skip to main content

scirs2_numpy/
masked.rs

//! Masked array support: a NumPy `numpy.ma`-compatible interface.
//!
//! A masked array pairs a data buffer with a boolean mask indicating which
//! elements are valid. Elements where the mask is `true` are considered
//! invalid (masked out) and are excluded from aggregate operations such as
//! [`MaskedArray::mean`] and [`MaskedArray::sum`].
7
8use pyo3::prelude::*;
9
10/// A masked array of `f64` values.
11///
12/// Each element has an associated boolean mask value:
13/// - `false` → element is valid and participates in computations.
14/// - `true`  → element is masked (invalid) and is replaced by `fill_value`
15///   when the filled view is requested.
16#[pyclass(name = "MaskedArray")]
17pub struct MaskedArray {
18    /// Flat data buffer.
19    data: Vec<f64>,
20    /// Flat mask buffer; parallel to `data`.
21    mask: Vec<bool>,
22    /// Logical shape of the array.
23    shape: Vec<usize>,
24    /// Value used to fill masked positions in [`Self::filled`].
25    fill_value: f64,
26}
27
28#[pymethods]
29impl MaskedArray {
30    /// Construct a new masked array.
31    ///
32    /// # Arguments
33    /// * `data`       – flat element buffer; length must equal the product of `shape`.
34    /// * `mask`       – optional flat mask; defaults to all-`false` (nothing masked).
35    /// * `shape`      – logical shape; must have product equal to `data.len()`.
36    /// * `fill_value` – value substituted for masked elements (default: `f64::NAN`).
37    #[new]
38    pub fn new(
39        data: Vec<f64>,
40        mask: Option<Vec<bool>>,
41        shape: Vec<usize>,
42        fill_value: Option<f64>,
43    ) -> PyResult<Self> {
44        let n: usize = shape.iter().product();
45        if data.len() != n {
46            return Err(pyo3::exceptions::PyValueError::new_err(format!(
47                "data length {} does not match shape product {}",
48                data.len(),
49                n
50            )));
51        }
52        let mask = mask.unwrap_or_else(|| vec![false; n]);
53        if mask.len() != n {
54            return Err(pyo3::exceptions::PyValueError::new_err(
55                "mask length does not match shape product",
56            ));
57        }
58        Ok(Self {
59            data,
60            mask,
61            shape,
62            fill_value: fill_value.unwrap_or(f64::NAN),
63        })
64    }
65
66    /// Return the data with masked positions replaced by `fill_value`.
67    pub fn filled(&self) -> Vec<f64> {
68        self.data
69            .iter()
70            .zip(self.mask.iter())
71            .map(|(&d, &m)| if m { self.fill_value } else { d })
72            .collect()
73    }
74
75    /// Return the number of unmasked (valid) elements.
76    pub fn count(&self) -> usize {
77        self.mask.iter().filter(|&&m| !m).count()
78    }
79
80    /// Return the mean of unmasked elements, or `None` if all elements are masked.
81    pub fn mean(&self) -> Option<f64> {
82        let valid: Vec<f64> = self
83            .data
84            .iter()
85            .zip(self.mask.iter())
86            .filter(|(_, &m)| !m)
87            .map(|(&d, _)| d)
88            .collect();
89        if valid.is_empty() {
90            None
91        } else {
92            Some(valid.iter().sum::<f64>() / valid.len() as f64)
93        }
94    }
95
96    /// Return the sum of unmasked elements.
97    pub fn sum(&self) -> f64 {
98        self.data
99            .iter()
100            .zip(self.mask.iter())
101            .filter(|(_, &m)| !m)
102            .map(|(&d, _)| d)
103            .sum()
104    }
105
106    /// Return the logical shape.
107    pub fn shape(&self) -> Vec<usize> {
108        self.shape.clone()
109    }
110
111    /// Return the raw data buffer.
112    pub fn data(&self) -> Vec<f64> {
113        self.data.clone()
114    }
115
116    /// Return the mask buffer.
117    pub fn mask(&self) -> Vec<bool> {
118        self.mask.clone()
119    }
120
121    /// Return the fill value used for masked positions.
122    pub fn fill_value(&self) -> f64 {
123        self.fill_value
124    }
125
126    /// Set the mask flag for a single flat element index.
127    pub fn mask_element(&mut self, idx: usize, masked: bool) -> PyResult<()> {
128        if idx >= self.mask.len() {
129            return Err(pyo3::exceptions::PyIndexError::new_err(
130                "index out of bounds",
131            ));
132        }
133        self.mask[idx] = masked;
134        Ok(())
135    }
136
137    /// Apply an element-wise unary operation to unmasked values.
138    ///
139    /// Masked positions receive `fill_value`.  Supported operations:
140    /// - `"abs"` – absolute value
141    /// - `"sqrt"` – square root
142    /// - `"log"` – natural logarithm
143    pub fn apply_unmasked(&self, op: &str) -> PyResult<Vec<f64>> {
144        let fill = self.fill_value;
145        match op {
146            "abs" => Ok(self
147                .data
148                .iter()
149                .zip(self.mask.iter())
150                .map(|(&d, &m)| if m { fill } else { d.abs() })
151                .collect()),
152            "sqrt" => Ok(self
153                .data
154                .iter()
155                .zip(self.mask.iter())
156                .map(|(&d, &m)| if m { fill } else { d.sqrt() })
157                .collect()),
158            "log" => Ok(self
159                .data
160                .iter()
161                .zip(self.mask.iter())
162                .map(|(&d, &m)| if m { fill } else { d.ln() })
163                .collect()),
164            _ => Err(pyo3::exceptions::PyValueError::new_err(format!(
165                "unknown operation '{op}'; supported: abs, sqrt, log"
166            ))),
167        }
168    }
169}
170
171/// Create a 1-D masked array from parallel data and mask vectors.
172///
173/// This mirrors `numpy.ma.array(data, mask=mask)`.
174#[pyfunction]
175pub fn masked_array(data: Vec<f64>, mask: Vec<bool>) -> PyResult<MaskedArray> {
176    let n = data.len();
177    MaskedArray::new(data, Some(mask), vec![n], None)
178}
179
180/// Create a 1-D masked array with all elements below `threshold` masked.
181///
182/// Mirrors `numpy.ma.masked_less(data, threshold)`.
183#[pyfunction]
184pub fn masked_less(data: Vec<f64>, threshold: f64) -> MaskedArray {
185    let n = data.len();
186    let mask: Vec<bool> = data.iter().map(|&d| d < threshold).collect();
187    MaskedArray {
188        data,
189        mask,
190        shape: vec![n],
191        fill_value: f64::NAN,
192    }
193}
194
195/// Register masked-array classes and functions into a PyO3 module.
196///
197/// Call this from your `#[pymodule]` init function to expose `MaskedArray`,
198/// `masked_array`, and `masked_less`.
199pub fn register_masked_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
200    m.add_class::<MaskedArray>()?;
201    m.add_function(wrap_pyfunction!(masked_array, m)?)?;
202    m.add_function(wrap_pyfunction!(masked_less, m)?)?;
203    Ok(())
204}