scirs2_numpy/masked.rs
1//! Masked array support: a NumPy `numpy.ma`-compatible interface.
2//!
3//! A masked array pairs a data buffer with a boolean mask indicating which
4//! elements are valid. Elements where the mask is `true` are considered
5//! invalid (masked out) and are excluded from aggregate operations such as
6//! [`MaskedArray::mean`] and [`MaskedArray::sum`].
7
8use pyo3::prelude::*;
9
10/// A masked array of `f64` values.
11///
12/// Each element has an associated boolean mask value:
13/// - `false` → element is valid and participates in computations.
14/// - `true` → element is masked (invalid) and is replaced by `fill_value`
15/// when the filled view is requested.
16#[pyclass(name = "MaskedArray")]
17pub struct MaskedArray {
18 /// Flat data buffer.
19 data: Vec<f64>,
20 /// Flat mask buffer; parallel to `data`.
21 mask: Vec<bool>,
22 /// Logical shape of the array.
23 shape: Vec<usize>,
24 /// Value used to fill masked positions in [`Self::filled`].
25 fill_value: f64,
26}
27
28#[pymethods]
29impl MaskedArray {
30 /// Construct a new masked array.
31 ///
32 /// # Arguments
33 /// * `data` – flat element buffer; length must equal the product of `shape`.
34 /// * `mask` – optional flat mask; defaults to all-`false` (nothing masked).
35 /// * `shape` – logical shape; must have product equal to `data.len()`.
36 /// * `fill_value` – value substituted for masked elements (default: `f64::NAN`).
37 #[new]
38 pub fn new(
39 data: Vec<f64>,
40 mask: Option<Vec<bool>>,
41 shape: Vec<usize>,
42 fill_value: Option<f64>,
43 ) -> PyResult<Self> {
44 let n: usize = shape.iter().product();
45 if data.len() != n {
46 return Err(pyo3::exceptions::PyValueError::new_err(format!(
47 "data length {} does not match shape product {}",
48 data.len(),
49 n
50 )));
51 }
52 let mask = mask.unwrap_or_else(|| vec![false; n]);
53 if mask.len() != n {
54 return Err(pyo3::exceptions::PyValueError::new_err(
55 "mask length does not match shape product",
56 ));
57 }
58 Ok(Self {
59 data,
60 mask,
61 shape,
62 fill_value: fill_value.unwrap_or(f64::NAN),
63 })
64 }
65
66 /// Return the data with masked positions replaced by `fill_value`.
67 pub fn filled(&self) -> Vec<f64> {
68 self.data
69 .iter()
70 .zip(self.mask.iter())
71 .map(|(&d, &m)| if m { self.fill_value } else { d })
72 .collect()
73 }
74
75 /// Return the number of unmasked (valid) elements.
76 pub fn count(&self) -> usize {
77 self.mask.iter().filter(|&&m| !m).count()
78 }
79
80 /// Return the mean of unmasked elements, or `None` if all elements are masked.
81 pub fn mean(&self) -> Option<f64> {
82 let valid: Vec<f64> = self
83 .data
84 .iter()
85 .zip(self.mask.iter())
86 .filter(|(_, &m)| !m)
87 .map(|(&d, _)| d)
88 .collect();
89 if valid.is_empty() {
90 None
91 } else {
92 Some(valid.iter().sum::<f64>() / valid.len() as f64)
93 }
94 }
95
96 /// Return the sum of unmasked elements.
97 pub fn sum(&self) -> f64 {
98 self.data
99 .iter()
100 .zip(self.mask.iter())
101 .filter(|(_, &m)| !m)
102 .map(|(&d, _)| d)
103 .sum()
104 }
105
106 /// Return the logical shape.
107 pub fn shape(&self) -> Vec<usize> {
108 self.shape.clone()
109 }
110
111 /// Return the raw data buffer.
112 pub fn data(&self) -> Vec<f64> {
113 self.data.clone()
114 }
115
116 /// Return the mask buffer.
117 pub fn mask(&self) -> Vec<bool> {
118 self.mask.clone()
119 }
120
121 /// Return the fill value used for masked positions.
122 pub fn fill_value(&self) -> f64 {
123 self.fill_value
124 }
125
126 /// Set the mask flag for a single flat element index.
127 pub fn mask_element(&mut self, idx: usize, masked: bool) -> PyResult<()> {
128 if idx >= self.mask.len() {
129 return Err(pyo3::exceptions::PyIndexError::new_err(
130 "index out of bounds",
131 ));
132 }
133 self.mask[idx] = masked;
134 Ok(())
135 }
136
137 /// Apply an element-wise unary operation to unmasked values.
138 ///
139 /// Masked positions receive `fill_value`. Supported operations:
140 /// - `"abs"` – absolute value
141 /// - `"sqrt"` – square root
142 /// - `"log"` – natural logarithm
143 pub fn apply_unmasked(&self, op: &str) -> PyResult<Vec<f64>> {
144 let fill = self.fill_value;
145 match op {
146 "abs" => Ok(self
147 .data
148 .iter()
149 .zip(self.mask.iter())
150 .map(|(&d, &m)| if m { fill } else { d.abs() })
151 .collect()),
152 "sqrt" => Ok(self
153 .data
154 .iter()
155 .zip(self.mask.iter())
156 .map(|(&d, &m)| if m { fill } else { d.sqrt() })
157 .collect()),
158 "log" => Ok(self
159 .data
160 .iter()
161 .zip(self.mask.iter())
162 .map(|(&d, &m)| if m { fill } else { d.ln() })
163 .collect()),
164 _ => Err(pyo3::exceptions::PyValueError::new_err(format!(
165 "unknown operation '{op}'; supported: abs, sqrt, log"
166 ))),
167 }
168 }
169}
170
171/// Create a 1-D masked array from parallel data and mask vectors.
172///
173/// This mirrors `numpy.ma.array(data, mask=mask)`.
174#[pyfunction]
175pub fn masked_array(data: Vec<f64>, mask: Vec<bool>) -> PyResult<MaskedArray> {
176 let n = data.len();
177 MaskedArray::new(data, Some(mask), vec![n], None)
178}
179
180/// Create a 1-D masked array with all elements below `threshold` masked.
181///
182/// Mirrors `numpy.ma.masked_less(data, threshold)`.
183#[pyfunction]
184pub fn masked_less(data: Vec<f64>, threshold: f64) -> MaskedArray {
185 let n = data.len();
186 let mask: Vec<bool> = data.iter().map(|&d| d < threshold).collect();
187 MaskedArray {
188 data,
189 mask,
190 shape: vec![n],
191 fill_value: f64::NAN,
192 }
193}
194
195/// Register masked-array classes and functions into a PyO3 module.
196///
197/// Call this from your `#[pymodule]` init function to expose `MaskedArray`,
198/// `masked_array`, and `masked_less`.
199pub fn register_masked_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
200 m.add_class::<MaskedArray>()?;
201 m.add_function(wrap_pyfunction!(masked_array, m)?)?;
202 m.add_function(wrap_pyfunction!(masked_less, m)?)?;
203 Ok(())
204}