// scirs2_core/array/masked_array.rs

//! Implementation of masked arrays for handling missing or invalid data
//!
//! Masked arrays are useful in scientific computing when working with:
//! - Data that contains missing values
//! - Operations that produce invalid results (NaN, Inf)
//! - Operations that should be applied only to a subset of data
//! - Statistical computations that should ignore certain values
//!
//! The implementation is inspired by `NumPy`'s `MaskedArray` and provides similar functionality
//! in a Rust-native way.
use std::cmp::PartialEq;
use std::fmt;
use std::ops::{Add, Div, Mul, Sub};

use ndarray::{Array, ArrayBase, Data, Dimension, Ix1, Zip};
use num_traits::{Float, Zero};
/// Errors produced by masked-array operations.
#[derive(Debug, Clone)]
pub enum ArrayError {
    /// Two shapes that were required to agree (e.g. data and mask) did not.
    ShapeMismatch {
        /// The shape that was required.
        expected: Vec<usize>,
        /// The shape that was actually supplied.
        found: Vec<usize>,
        /// Human-readable description of the failing check.
        msg: String,
    },
    /// An operation was given a value it could not work with.
    ValueError(String),
}

impl std::fmt::Display for ArrayError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ValueError(msg) => write!(f, "Value error: {}", msg),
            Self::ShapeMismatch {
                expected,
                found,
                msg,
            } => write!(
                f,
                "Shape mismatch: expected {:?}, found {:?}: {}",
                expected, found, msg
            ),
        }
    }
}

impl std::error::Error for ArrayError {}
51/// Represents an array with a mask to identify invalid or missing values
52#[derive(Clone)]
53pub struct MaskedArray<A, S, D>
54where
55    A: Clone,
56    S: Data<Elem = A> + Clone + ndarray::RawDataClone,
57    D: Dimension,
58{
59    /// The underlying data array
60    pub data: ArrayBase<S, D>,
61
62    /// The mask array: true = masked (invalid), false = valid
63    pub mask: Array<bool, D>,
64
65    /// The fill value used when creating a filled array
66    pub fill_value: A,
67}
68
/// Marker meaning "no mask"; semantically equivalent to a mask of all `false`.
pub struct NoMask;

/// The global [`NoMask`] value.
pub const NOMASK: NoMask = NoMask {};
75impl<A, S, D> MaskedArray<A, S, D>
76where
77    A: Clone + PartialEq,
78    S: Data<Elem = A> + Clone + ndarray::RawDataClone,
79    D: Dimension,
80{
81    /// Create a new `MaskedArray` from data and mask
82    ///
83    /// # Errors
84    /// Returns `ArrayError::ShapeMismatch` if the mask shape doesn't match the data shape.
85    pub fn new(
86        data: ArrayBase<S, D>,
87        mask: Option<Array<bool, D>>,
88        fill_value: Option<A>,
89    ) -> Result<Self, ArrayError> {
90        let mask = match mask {
91            Some(m) => {
92                // Validate mask shape matches data shape
93                if m.shape() != data.shape() {
94                    return Err(ArrayError::ShapeMismatch {
95                        expected: data.shape().to_vec(),
96                        found: m.shape().to_vec(),
97                        msg: "Mask shape must match data shape".to_string(),
98                    });
99                }
100                m
101            }
102            None => Array::<bool, D>::from_elem(data.raw_dim(), false),
103        };
104
105        // Use provided fill value or create a default
106        let fill_value = fill_value.map_or_else(|| default_fill_value(&data), |v| v);
107
108        Ok(Self {
109            data,
110            mask,
111            fill_value,
112        })
113    }
114
115    /// Get a view of the data with masked values replaced by `fill_value`
116    pub fn value_2(&self, fillvalue: Option<A>) -> Array<A, D>
117    where
118        <D as Dimension>::Pattern: ndarray::NdIndex<D>,
119    {
120        let fill = fillvalue.map_or_else(|| self.fill_value.clone(), |v| v);
121
122        // Create new array with same shape as data
123        let mut result = Array::from_elem(self.data.raw_dim(), fill);
124
125        // Copy unmasked values from original data
126        for (i, val) in self.data.iter().enumerate() {
127            if !*self.mask.iter().nth(i).unwrap_or(&true) {
128                // Only copy if not masked
129                if let Some(v) = result.iter_mut().nth(i) {
130                    *v = val.clone();
131                }
132            }
133        }
134
135        result
136    }
137
138    /// Returns true if the array has at least one masked element
139    pub fn has_masked(&self) -> bool {
140        self.mask.iter().any(|&x| x)
141    }
142
143    /// Returns the count of non-masked elements
144    pub fn count(&self) -> usize {
145        self.mask.iter().filter(|&&x| !x).count()
146    }
147
148    /// Get a copy of the current mask
149    pub fn get_mask(&self) -> Array<bool, D> {
150        self.mask.clone()
151    }
152
153    /// Set a new mask for the array
154    ///
155    /// # Errors
156    /// Returns `ArrayError::ShapeMismatch` if the mask shape doesn't match the data shape.
157    pub fn set_mask(&mut self, mask: Array<bool, D>) -> Result<(), ArrayError> {
158        // Validate mask shape
159        if mask.shape() != self.data.shape() {
160            return Err(ArrayError::ShapeMismatch {
161                expected: self.data.shape().to_vec(),
162                found: mask.shape().to_vec(),
163                msg: "Mask shape must match data shape".to_string(),
164            });
165        }
166
167        self.mask = mask;
168        Ok(())
169    }
170
171    /// Set the fill value for the array
172    pub fn value_3(&mut self, fillvalue: A) {
173        self.fill_value = fillvalue;
174    }
175
176    /// Returns a new array containing only unmasked values
177    pub fn compressed(&self) -> Array<A, Ix1> {
178        // Count non-masked elements
179        let count = self.count();
180
181        // Create output array
182        let mut result = Vec::with_capacity(count);
183
184        // Fill output array with non-masked elements
185        for (i, val) in self.data.iter().enumerate() {
186            if !self.mask.iter().nth(i).unwrap_or(&true) {
187                result.push(val.clone());
188            }
189        }
190
191        // Convert to ndarray
192        Array::from_vec(result)
193    }
194
195    /// Returns the shape of the array
196    pub fn shape(&self) -> &[usize] {
197        self.data.shape()
198    }
199
200    /// Returns the number of dimensions of the array
201    pub fn ndim(&self) -> usize {
202        self.data.ndim()
203    }
204
205    /// Returns the number of elements in the array
206    pub fn size(&self) -> usize {
207        self.data.len()
208    }
209
210    /// Returns a tuple of (data, mask)
211    pub const fn data_and_mask(&self) -> (&ArrayBase<S, D>, &Array<bool, D>) {
212        (&self.data, &self.mask)
213    }
214
215    /// Creates a new masked array with the given mask operation applied
216    #[must_use]
217    pub fn mask_op<F>(&self, op: F) -> Self
218    where
219        F: Fn(&Array<bool, D>) -> Array<bool, D>,
220    {
221        let new_mask = op(&self.mask);
222
223        Self {
224            data: self.data.clone(),
225            mask: new_mask,
226            fill_value: self.fill_value.clone(),
227        }
228    }
229
230    /// Creates a new masked array with a hardened mask (copy)
231    #[must_use]
232    pub fn harden_mask(&self) -> Self {
233        // Create a copy with the same mask
234        Self {
235            data: self.data.clone(),
236            mask: self.mask.clone(),
237            fill_value: self.fill_value.clone(),
238        }
239    }
240
241    /// Create a new masked array with a softened mask (copy)
242    #[must_use]
243    pub fn soften_mask(&self) -> Self {
244        // Create a copy with the same mask
245        Self {
246            data: self.data.clone(),
247            mask: self.mask.clone(),
248            fill_value: self.fill_value.clone(),
249        }
250    }
251
252    /// Create a new masked array where the result of applying the function to each element is masked
253    #[must_use]
254    pub fn mask_where<F>(&self, condition: F) -> Self
255    where
256        F: Fn(&A) -> bool,
257    {
258        // Apply condition to each element of data
259        let new_mask = self.data.mapv(|x| condition(&x));
260
261        // Combine with existing mask
262        let combined_mask = &self.mask | &new_mask;
263
264        Self {
265            data: self.data.clone(),
266            mask: combined_mask,
267            fill_value: self.fill_value.clone(),
268        }
269    }
270
271    /// Create a logical OR of the mask with another mask
272    ///
273    /// # Errors
274    /// Returns `ArrayError::ShapeMismatch` if the mask shapes don't match.
275    pub fn mask_or(&self, othermask: &Array<bool, D>) -> Result<Self, ArrayError> {
276        // Check that shapes match
277        if self.mask.shape() != othermask.shape() {
278            return Err(ArrayError::ShapeMismatch {
279                expected: self.mask.shape().to_vec(),
280                found: othermask.shape().to_vec(),
281                msg: "Mask shapes must match for mask_or operation".to_string(),
282            });
283        }
284
285        // Combine masks
286        let combined_mask = &self.mask | othermask;
287
288        Ok(Self {
289            data: self.data.clone(),
290            mask: combined_mask,
291            fill_value: self.fill_value.clone(),
292        })
293    }
294
295    /// Create a logical AND of the mask with another mask
296    ///
297    /// # Errors
298    /// Returns `ArrayError::ShapeMismatch` if the mask shapes don't match.
299    pub fn mask_and(&self, othermask: &Array<bool, D>) -> Result<Self, ArrayError> {
300        // Check that shapes match
301        if self.mask.shape() != othermask.shape() {
302            return Err(ArrayError::ShapeMismatch {
303                expected: self.mask.shape().to_vec(),
304                found: othermask.shape().to_vec(),
305                msg: "Mask shapes must match for mask_and operation".to_string(),
306            });
307        }
308
309        // Combine masks
310        let combined_mask = &self.mask & othermask;
311
312        Ok(Self {
313            data: self.data.clone(),
314            mask: combined_mask,
315            fill_value: self.fill_value.clone(),
316        })
317    }
318
319    /// Reshape the masked array
320    ///
321    /// # Errors
322    /// Returns `ArrayError::ValueError` if the reshape operation fails.
323    pub fn reshape<E>(
324        &self,
325        shape: E,
326    ) -> Result<MaskedArray<A, ndarray::OwnedRepr<A>, E>, ArrayError>
327    where
328        E: Dimension,
329        D: Dimension,
330    {
331        // Try to reshape the data and mask
332        let reshaped_data = match self.data.clone().into_shape_with_order(shape.clone()) {
333            Ok(d) => d,
334            Err(e) => {
335                return Err(ArrayError::ValueError(format!(
336                    "Failed to reshape data: {e}"
337                )))
338            }
339        };
340
341        let reshaped_mask = match self.mask.clone().into_shape_with_order(shape) {
342            Ok(m) => m,
343            Err(e) => {
344                return Err(ArrayError::ValueError(format!(
345                    "Failed to reshape mask: {e}"
346                )))
347            }
348        };
349
350        Ok(MaskedArray {
351            data: reshaped_data.into_owned(),
352            mask: reshaped_mask,
353            fill_value: self.fill_value.clone(),
354        })
355    }
356
357    /// Convert to a different type
358    ///
359    /// # Errors
360    /// Currently this method doesn't return errors, but the signature is kept for future compatibility.
361    pub fn astype<B>(&self) -> Result<MaskedArray<B, ndarray::OwnedRepr<B>, D>, ArrayError>
362    where
363        A: Into<B> + Clone,
364        B: Clone + PartialEq + 'static,
365    {
366        // Convert each element
367        let converted_data = self.data.mapv(std::convert::Into::into);
368
369        Ok(MaskedArray {
370            data: converted_data,
371            mask: self.mask.clone(),
372            fill_value: self.fill_value.clone().into(),
373        })
374    }
375}
376
377/// Implementation of statistical methods
378impl<A, S, D> MaskedArray<A, S, D>
379where
380    A: Clone + PartialEq + num_traits::NumAssign + num_traits::Zero + num_traits::One + PartialOrd,
381    S: Data<Elem = A> + Clone + ndarray::RawDataClone,
382    D: Dimension,
383{
384    /// Compute the sum of all unmasked elements
385    pub fn sum(&self) -> A {
386        let mut sum = A::zero();
387
388        for (i, val) in self.data.iter().enumerate() {
389            if !*self.mask.iter().nth(i).unwrap_or(&true) {
390                sum += val.clone();
391            }
392        }
393
394        sum
395    }
396
397    /// Compute the product of all unmasked elements
398    pub fn product(&self) -> A {
399        let mut product = A::one();
400
401        for (i, val) in self.data.iter().enumerate() {
402            if !*self.mask.iter().nth(i).unwrap_or(&true) {
403                product *= val.clone();
404            }
405        }
406
407        product
408    }
409
410    /// Find the minimum value among unmasked elements
411    pub fn min(&self) -> Option<A> {
412        let mut min_val = None;
413
414        for (i, val) in self.data.iter().enumerate() {
415            if !*self.mask.iter().nth(i).unwrap_or(&true) {
416                if let Some(ref current_min) = min_val {
417                    if val < current_min {
418                        min_val = Some(val.clone());
419                    }
420                } else {
421                    min_val = Some(val.clone());
422                }
423            }
424        }
425
426        min_val
427    }
428
429    /// Find the maximum value among unmasked elements
430    pub fn max(&self) -> Option<A> {
431        let mut max_val = None;
432
433        for (i, val) in self.data.iter().enumerate() {
434            if !*self.mask.iter().nth(i).unwrap_or(&true) {
435                if let Some(ref current_max) = max_val {
436                    if val > current_max {
437                        max_val = Some(val.clone());
438                    }
439                } else {
440                    max_val = Some(val.clone());
441                }
442            }
443        }
444
445        max_val
446    }
447
448    /// Find the index of the minimum value among unmasked elements
449    pub fn argmin(&self) -> Option<usize> {
450        let mut min_idx = None;
451        let mut min_val = None;
452
453        for (i, val) in self.data.iter().enumerate() {
454            if !*self.mask.iter().nth(i).unwrap_or(&true) {
455                if let Some(ref current_min) = min_val {
456                    if val < current_min {
457                        min_val = Some(val.clone());
458                        min_idx = Some(i);
459                    }
460                } else {
461                    min_val = Some(val.clone());
462                    min_idx = Some(0);
463                }
464            }
465        }
466
467        min_idx
468    }
469
470    /// Find the index of the maximum value among unmasked elements
471    pub fn argmax(&self) -> Option<usize> {
472        let mut max_idx = None;
473        let mut max_val = None;
474
475        for (i, val) in self.data.iter().enumerate() {
476            if !*self.mask.iter().nth(i).unwrap_or(&true) {
477                if let Some(ref current_max) = max_val {
478                    if val > current_max {
479                        max_val = Some(val.clone());
480                        max_idx = Some(i);
481                    }
482                } else {
483                    max_val = Some(val.clone());
484                    max_idx = Some(0);
485                }
486            }
487        }
488
489        max_idx
490    }
491}
492
493/// Implementation of statistical methods for floating point types
494impl<A, S, D> MaskedArray<A, S, D>
495where
496    A: Clone + PartialEq + num_traits::Float + std::iter::Sum<A>,
497    S: Data<Elem = A> + Clone + ndarray::RawDataClone,
498    D: Dimension,
499{
500    /// Compute the mean of all unmasked elements
501    ///
502    /// Returns `None` if there are no unmasked elements or if count conversion fails.
503    pub fn mean(&self) -> Option<A> {
504        let count = self.count();
505
506        if count == 0 {
507            return None;
508        }
509
510        let sum: A = self
511            .data
512            .iter()
513            .enumerate()
514            .filter(|(i, _)| !*self.mask.iter().nth(*i).unwrap_or(&true))
515            .map(|(_, val)| *val)
516            .sum();
517
518        // Safe conversion with proper error handling
519        A::from(count).map(|count_val| sum / count_val)
520    }
521
522    /// Compute the variance of all unmasked elements
523    ///
524    /// Returns `None` if there are insufficient unmasked elements or if count conversion fails.
525    pub fn var(&self, ddof: usize) -> Option<A> {
526        let count = self.count();
527
528        if count <= ddof {
529            return None;
530        }
531
532        // Calculate mean
533        let mean = self.mean()?;
534
535        // Calculate sum of squared differences
536        let sum_sq_diff: A = self
537            .data
538            .iter()
539            .enumerate()
540            .filter(|(i, _)| !*self.mask.iter().nth(*i).unwrap_or(&true))
541            .map(|(_, val)| (*val - mean) * (*val - mean))
542            .sum();
543
544        // Apply degrees of freedom correction with safe conversion
545        A::from(count - ddof).map(|denom| sum_sq_diff / denom)
546    }
547
548    /// Compute the standard deviation of all unmasked elements
549    pub fn std(&self, ddof: usize) -> Option<A> {
550        self.var(ddof).map(num_traits::Float::sqrt)
551    }
552
553    /// Check if all unmasked elements are finite
554    pub fn all_finite(&self) -> bool {
555        self.data
556            .iter()
557            .enumerate()
558            .filter(|(i, _)| !*self.mask.iter().nth(*i).unwrap_or(&true))
559            .all(|(_, val)| val.is_finite())
560    }
561}
562
563/// Helper function to create a default fill value for a given type
564#[allow(dead_code)]
565fn default_fill_value<A, S, D>(data: &ArrayBase<S, D>) -> A
566where
567    A: Clone,
568    S: Data<Elem = A>,
569    D: Dimension,
570{
571    // In a real implementation, this would use type traits to determine
572    // appropriate default values based on the type (like `NumPy` does)
573    data.iter().next().map_or_else(
574        || {
575            // This is a placeholder - in reality you'd need to handle this by type
576            panic!("Cannot determine default fill value for empty array");
577        },
578        std::clone::Clone::clone,
579    )
580}
581
/// Check whether a single value is the "masked" singleton.
///
/// `NumPy` compares against its `masked` constant; there is no such singleton here, so
/// this placeholder always returns `false`.
pub const fn is_masked<A>(_value: &A) -> bool
where
    A: PartialEq,
{
    // The argument is intentionally unused (leading underscore silences the
    // `unused_variables` warning) until a masked singleton exists.
    false
}
592/// Create a masked array with elements equal to a given value masked
593#[allow(dead_code)]
594pub fn masked_equal<A, S, D>(data: ArrayBase<S, D>, value: A) -> MaskedArray<A, S, D>
595where
596    A: Clone + PartialEq,
597    S: Data<Elem = A> + Clone + ndarray::RawDataClone,
598    D: Dimension,
599{
600    // Create a mask indicating where elements equal the value
601    let mask = data.mapv(|x| x == value);
602
603    MaskedArray {
604        data,
605        mask,
606        fill_value: value,
607    }
608}
609
610/// Create a masked array with NaN and infinite values masked
611#[allow(dead_code)]
612pub fn masked_invalid<A, S, D>(data: ArrayBase<S, D>) -> MaskedArray<A, S, D>
613where
614    A: Clone + PartialEq + Float,
615    S: Data<Elem = A> + Clone + ndarray::RawDataClone,
616    D: Dimension,
617{
618    // Create a mask indicating where elements are NaN or infinite
619    let mask = data.mapv(|val| val.is_nan() || val.is_infinite());
620
621    let fill_value = A::nan();
622
623    MaskedArray {
624        data,
625        mask,
626        fill_value,
627    }
628}
629
630/// Create a masked array
631///
632/// # Errors
633/// Returns `ArrayError::ShapeMismatch` if the mask shape doesn't match the data shape.
634#[allow(dead_code)]
635pub fn mask_array<A, S, D>(
636    data: ArrayBase<S, D>,
637    mask: Option<Array<bool, D>>,
638    fill_value: Option<A>,
639) -> Result<MaskedArray<A, S, D>, ArrayError>
640where
641    A: Clone + PartialEq,
642    S: Data<Elem = A> + Clone + ndarray::RawDataClone,
643    D: Dimension,
644{
645    MaskedArray::new(data, mask, fill_value)
646}
647
648/// Create a masked array with values outside a range masked
649#[allow(dead_code)]
650pub fn masked_outside<A, S, D>(
651    data: ArrayBase<S, D>,
652    min_val: &A,
653    max_val: &A,
654) -> MaskedArray<A, S, D>
655where
656    A: Clone + PartialEq + PartialOrd,
657    S: Data<Elem = A> + Clone + ndarray::RawDataClone,
658    D: Dimension,
659{
660    // Create a mask indicating where elements are outside the range
661    let mask = data.mapv(|x| x < *min_val || x > *max_val);
662
663    // Choose a fill value (using min_val as default)
664    let fill_value = min_val.clone();
665
666    MaskedArray {
667        data,
668        mask,
669        fill_value,
670    }
671}
672
673/// Create a masked array with values inside a range masked
674#[allow(dead_code)]
675pub fn masked_inside<A, S, D>(
676    data: ArrayBase<S, D>,
677    min_val: &A,
678    max_val: &A,
679) -> MaskedArray<A, S, D>
680where
681    A: Clone + PartialEq + PartialOrd,
682    S: Data<Elem = A> + Clone + ndarray::RawDataClone,
683    D: Dimension,
684{
685    // Create a mask indicating where elements are inside the range
686    let mask = data.mapv(|x| x >= *min_val && x <= *max_val);
687
688    // Choose a fill value (using min_val as default)
689    let fill_value = min_val.clone();
690
691    MaskedArray {
692        data,
693        mask,
694        fill_value,
695    }
696}
697
698/// Create a masked array with values greater than a given value masked
699#[allow(dead_code)]
700pub fn masked_greater<A, S, D>(data: ArrayBase<S, D>, value: &A) -> MaskedArray<A, S, D>
701where
702    A: Clone + PartialEq + PartialOrd,
703    S: Data<Elem = A> + Clone + ndarray::RawDataClone,
704    D: Dimension,
705{
706    // Create a mask indicating where elements are greater than the value
707    let mask = data.mapv(|x| x > *value);
708
709    // Use the specified value as fill_value
710    let fill_value = value.clone();
711
712    MaskedArray {
713        data,
714        mask,
715        fill_value,
716    }
717}
718
719/// Create a masked array with values less than a given value masked
720#[allow(dead_code)]
721pub fn masked_less<A, S, D>(data: ArrayBase<S, D>, value: &A) -> MaskedArray<A, S, D>
722where
723    A: Clone + PartialEq + PartialOrd,
724    S: Data<Elem = A> + Clone + ndarray::RawDataClone,
725    D: Dimension,
726{
727    // Create a mask indicating where elements are less than the value
728    let mask = data.mapv(|x| x < *value);
729
730    // Use the specified value as fill_value
731    let fill_value = value.clone();
732
733    MaskedArray {
734        data,
735        mask,
736        fill_value,
737    }
738}
739
740/// Create a masked array with values where a condition is true
741#[allow(dead_code)]
742pub fn masked_where<A, S, D, F>(
743    condition: F,
744    data: ArrayBase<S, D>,
745    fill_value: Option<A>,
746) -> MaskedArray<A, S, D>
747where
748    A: Clone + PartialEq,
749    S: Data<Elem = A> + Clone + ndarray::RawDataClone,
750    D: Dimension,
751    F: Fn(&A) -> bool,
752{
753    // Create a mask by applying the condition function to each element
754    let mask = data.map(condition);
755
756    // Use provided fill value or create a default
757    let fill_value = fill_value.map_or_else(|| default_fill_value(&data), |v| v);
758
759    MaskedArray {
760        data,
761        mask,
762        fill_value,
763    }
764}
765
766/// Implementation of Display for `MaskedArray`
767impl<A, S, D> fmt::Display for MaskedArray<A, S, D>
768where
769    A: Clone + PartialEq + fmt::Display,
770    S: Data<Elem = A> + Clone + ndarray::RawDataClone,
771    D: Dimension,
772{
773    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
774        writeln!(f, "MaskedArray(")?;
775
776        writeln!(f, "  data=[")?;
777        for (i, elem) in self.data.iter().enumerate() {
778            if i > 0 && i % 10 == 0 {
779                writeln!(f)?;
780            }
781            if *self.mask.iter().nth(0).unwrap_or(&false) {
782                write!(f, " --,")?;
783            } else {
784                write!(f, " {elem},")?;
785            }
786        }
787        writeln!(f, "\n  ],")?;
788
789        writeln!(f, "  mask=[")?;
790        for (i, &elem) in self.mask.iter().enumerate() {
791            if i > 0 && i % 10 == 0 {
792                writeln!(f)?;
793            }
794            write!(f, " {elem},")?;
795        }
796        writeln!(f, "\n  ],")?;
797
798        writeln!(f, "  fill_value={}", self.fill_value)?;
799        write!(f, ")")
800    }
801}
802
803/// Implementation of Debug for `MaskedArray`
804impl<A, S, D> fmt::Debug for MaskedArray<A, S, D>
805where
806    A: Clone + PartialEq + fmt::Debug,
807    S: Data<Elem = A> + Clone + ndarray::RawDataClone,
808    D: Dimension,
809{
810    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
811        f.debug_struct("MaskedArray")
812            .field("data", &self.data)
813            .field("mask", &self.mask)
814            .field("fill_value", &self.fill_value)
815            .finish()
816    }
817}
818
819// Arithmetic operations for MaskedArray
820// These could be expanded to handle more operations and types
821
822impl<A, S1, S2, D> Add<&MaskedArray<A, S2, D>> for &MaskedArray<A, S1, D>
823where
824    A: Clone + Add<Output = A> + PartialEq,
825    S1: Data<Elem = A> + Clone + ndarray::RawDataClone,
826    S2: Data<Elem = A> + Clone + ndarray::RawDataClone,
827    D: Dimension,
828{
829    type Output = MaskedArray<A, ndarray::OwnedRepr<A>, D>;
830
831    fn add(self, rhs: &MaskedArray<A, S2, D>) -> Self::Output {
832        // Create combined mask: true if either input is masked
833        let combined_mask = &self.mask | &rhs.mask;
834
835        // Create output data by adding input data
836        let data = &self.data + &rhs.data;
837
838        MaskedArray {
839            data,
840            mask: combined_mask,
841            fill_value: self.fill_value.clone(),
842        }
843    }
844}
845
846impl<A, S1, S2, D> Sub<&MaskedArray<A, S2, D>> for &MaskedArray<A, S1, D>
847where
848    A: Clone + Sub<Output = A> + PartialEq,
849    S1: Data<Elem = A> + Clone + ndarray::RawDataClone,
850    S2: Data<Elem = A> + Clone + ndarray::RawDataClone,
851    D: Dimension,
852{
853    type Output = MaskedArray<A, ndarray::OwnedRepr<A>, D>;
854
855    fn sub(self, rhs: &MaskedArray<A, S2, D>) -> Self::Output {
856        // Create combined mask: true if either input is masked
857        let combined_mask = &self.mask | &rhs.mask;
858
859        // Create output data by subtracting input data
860        let data = &self.data - &rhs.data;
861
862        MaskedArray {
863            data,
864            mask: combined_mask,
865            fill_value: self.fill_value.clone(),
866        }
867    }
868}
869
870impl<A, S1, S2, D> Mul<&MaskedArray<A, S2, D>> for &MaskedArray<A, S1, D>
871where
872    A: Clone + Mul<Output = A> + PartialEq,
873    S1: Data<Elem = A> + Clone + ndarray::RawDataClone,
874    S2: Data<Elem = A> + Clone + ndarray::RawDataClone,
875    D: Dimension,
876{
877    type Output = MaskedArray<A, ndarray::OwnedRepr<A>, D>;
878
879    fn mul(self, rhs: &MaskedArray<A, S2, D>) -> Self::Output {
880        // Create combined mask: true if either input is masked
881        let combined_mask = &self.mask | &rhs.mask;
882
883        // Create output data by multiplying input data
884        let data = &self.data * &rhs.data;
885
886        MaskedArray {
887            data,
888            mask: combined_mask,
889            fill_value: self.fill_value.clone(),
890        }
891    }
892}
893
894impl<A, S1, S2, D> Div<&MaskedArray<A, S2, D>> for &MaskedArray<A, S1, D>
895where
896    A: Clone + Div<Output = A> + PartialEq + Zero,
897    S1: Data<Elem = A> + Clone + ndarray::RawDataClone,
898    S2: Data<Elem = A> + Clone + ndarray::RawDataClone,
899    D: Dimension,
900{
901    type Output = MaskedArray<A, ndarray::OwnedRepr<A>, D>;
902
903    fn div(self, rhs: &MaskedArray<A, S2, D>) -> Self::Output {
904        // Create combined mask: true if either input is masked or rhs is zero
905        let mut combined_mask = &self.mask | &rhs.mask;
906
907        // Also mask division by zero
908        let zero = A::zero();
909        let additional_mask = rhs.data.mapv(|x| x == zero);
910
911        // Update combined mask to also mask division by zero
912        combined_mask = combined_mask | additional_mask;
913
914        // Create output data by dividing input data
915        let data = &self.data / &rhs.data;
916
917        MaskedArray {
918            data,
919            mask: combined_mask,
920            fill_value: self.fill_value.clone(),
921        }
922    }
923}
924
925// Add more arithmetic operations as needed
926
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Build a 1-D `f64` masked array with an explicit mask, failing the test on error.
    fn build(
        values: ndarray::Array1<f64>,
        flags: ndarray::Array1<bool>,
        fill: Option<f64>,
    ) -> MaskedArray<f64, ndarray::OwnedRepr<f64>, ndarray::Ix1> {
        MaskedArray::new(values, Some(flags), fill)
            .expect("Failed to create MaskedArray in test")
    }

    #[test]
    fn test_masked_array_creation() {
        let values = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let flags = array![false, true, false, true, false];

        let masked = build(values.clone(), flags.clone(), None);

        assert_eq!(masked.data, values);
        assert_eq!(masked.mask, flags);
        assert_eq!(masked.count(), 3);
        assert!(masked.has_masked());
    }

    #[test]
    fn test_masked_array_filled() {
        let masked = build(
            array![1.0, 2.0, 3.0, 4.0, 5.0],
            array![false, true, false, true, false],
            Some(999.0),
        );

        assert_eq!(masked.value_2(None), array![1.0, 999.0, 3.0, 999.0, 5.0]);
        assert_eq!(masked.value_2(Some(-1.0)), array![1.0, -1.0, 3.0, -1.0, 5.0]);
    }

    #[test]
    fn test_masked_equal() {
        let masked = masked_equal(array![1.0, 2.0, 3.0, 2.0, 5.0], 2.0);

        assert_eq!(masked.mask, array![false, true, false, true, false]);
        assert_eq!(masked.count(), 3);
    }

    #[test]
    fn test_masked_invalid() {
        let masked = masked_invalid(array![1.0, f64::NAN, 3.0, f64::INFINITY, 5.0]);

        // NaN defeats element-wise equality on the data, so check the mask flag by flag.
        let expected = [false, true, false, true, false];
        for (idx, &flag) in expected.iter().enumerate() {
            assert_eq!(masked.mask[idx], flag, "unexpected mask flag at index {idx}");
        }

        assert_eq!(masked.count(), 3);
    }

    #[test]
    fn test_masked_array_arithmetic() {
        let a = build(
            array![1.0, 2.0, 3.0, 4.0, 5.0],
            array![false, true, false, false, false],
            Some(0.0),
        );
        let b = build(
            array![5.0, 4.0, 3.0, 2.0, 1.0],
            array![false, false, false, true, false],
            Some(0.0),
        );
        let either_masked = array![false, true, false, true, false];

        let sum = &a + &b;
        assert_eq!(sum.data, array![6.0, 6.0, 6.0, 6.0, 6.0]);
        assert_eq!(sum.mask, either_masked);

        let difference = &a - &b;
        assert_eq!(difference.data, array![-4.0, -2.0, 0.0, 2.0, 4.0]);
        assert_eq!(difference.mask, either_masked);

        let product = &a * &b;
        assert_eq!(product.data, array![5.0, 8.0, 9.0, 8.0, 5.0]);
        assert_eq!(product.mask, either_masked);

        let quotient = &a / &b;
        assert_eq!(quotient.data, array![0.2, 0.5, 1.0, 2.0, 5.0]);
        assert_eq!(quotient.mask, either_masked);

        // A zero divisor masks the corresponding element.
        let numerators = build(
            array![1.0, 2.0, 3.0],
            array![false, false, false],
            Some(0.0),
        );
        let divisors = build(
            array![1.0, 0.0, 3.0],
            array![false, false, false],
            Some(0.0),
        );
        let guarded = &numerators / &divisors;
        assert_eq!(guarded.mask, array![false, true, false]);
    }

    #[test]
    fn test_compressed() {
        let masked = build(
            array![1.0, 2.0, 3.0, 4.0, 5.0],
            array![false, true, false, true, false],
            Some(0.0),
        );

        assert_eq!(masked.compressed(), array![1.0, 3.0, 5.0]);
    }
}