vortex_compute/mask/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Compute function for masking the validity of vectors.
5
6use std::ops::BitAnd;
7
8use vortex_dtype::{NativeDecimalType, NativePType};
9use vortex_mask::Mask;
10use vortex_vector::binaryview::{BinaryViewType, BinaryViewVector};
11use vortex_vector::bool::BoolVector;
12use vortex_vector::decimal::{DVector, DecimalVector};
13use vortex_vector::fixed_size_list::FixedSizeListVector;
14use vortex_vector::listview::ListViewVector;
15use vortex_vector::null::NullVector;
16use vortex_vector::primitive::{PVector, PrimitiveVector};
17use vortex_vector::struct_::StructVector;
18use vortex_vector::{Vector, match_each_dvector, match_each_pvector, match_each_vector};
19
20/// Trait for masking the validity of an array or vector.
21pub trait MaskValidity {
22    /// Masks the validity of the object using the provided mask.
23    ///
24    /// The output has its validity intersected with the given mask, resulting in a new validity
25    /// with equal or fewer valid entries.
26    fn mask_validity(self, mask: &Mask) -> Self;
27}
28
29impl MaskValidity for Vector {
30    fn mask_validity(self, mask: &Mask) -> Self {
31        match_each_vector!(self, |v| { MaskValidity::mask_validity(v, mask).into() })
32    }
33}
34
35impl MaskValidity for NullVector {
36    fn mask_validity(self, _mask: &Mask) -> Self {
37        // Null vectors have no validity to mask; they are always fully null.
38        self
39    }
40}
41
42impl MaskValidity for BoolVector {
43    fn mask_validity(self, mask: &Mask) -> Self {
44        let (bits, validity) = self.into_parts();
45        // SAFETY: we are preserving the original bits buffer and only modifying the validity.
46        unsafe { Self::new_unchecked(bits, validity.bitand(mask)) }
47    }
48}
49
50impl MaskValidity for PrimitiveVector {
51    fn mask_validity(self, mask: &Mask) -> Self {
52        match_each_pvector!(self, |v| { MaskValidity::mask_validity(v, mask).into() })
53    }
54}
55
56impl<T: NativePType> MaskValidity for PVector<T> {
57    fn mask_validity(self, mask: &Mask) -> Self {
58        let (data, validity) = self.into_parts();
59        // SAFETY: we are preserving the original data buffer and only modifying the validity.
60        unsafe { Self::new_unchecked(data, validity.bitand(mask)) }
61    }
62}
63
64impl MaskValidity for DecimalVector {
65    fn mask_validity(self, mask: &Mask) -> Self {
66        match_each_dvector!(self, |v| { MaskValidity::mask_validity(v, mask).into() })
67    }
68}
69
70impl<D: NativeDecimalType> MaskValidity for DVector<D> {
71    fn mask_validity(self, mask: &Mask) -> Self {
72        let (ps, elements, validity) = self.into_parts();
73        // SAFETY: we are preserving the original elements buffer and only modifying the validity.
74        unsafe { Self::new_unchecked(ps, elements, validity.bitand(mask)) }
75    }
76}
77
78impl<T: BinaryViewType> MaskValidity for BinaryViewVector<T> {
79    fn mask_validity(self, mask: &Mask) -> Self {
80        let (views, buffers, validity) = self.into_parts();
81        // SAFETY: we are preserving the original views and buffers, only modifying the validity.
82        unsafe { Self::new_unchecked(views, buffers, validity.bitand(mask)) }
83    }
84}
85
86impl MaskValidity for ListViewVector {
87    fn mask_validity(self, mask: &Mask) -> Self {
88        let (elements, offsets, sizes, validity) = self.into_parts();
89        // SAFETY: we are preserving the original elements and `list_size`, only modifying the
90        // validity.
91        unsafe { Self::new_unchecked(elements, offsets, sizes, validity.bitand(mask)) }
92    }
93}
94
95impl MaskValidity for FixedSizeListVector {
96    fn mask_validity(self, mask: &Mask) -> Self {
97        let (elements, list_size, validity) = self.into_parts();
98        // SAFETY: we are preserving the original elements and `list_size`, only modifying the
99        // validity.
100        unsafe { Self::new_unchecked(elements, list_size, validity.bitand(mask)) }
101    }
102}
103
104impl MaskValidity for StructVector {
105    fn mask_validity(self, mask: &Mask) -> Self {
106        let (fields, validity) = self.into_parts();
107        // SAFETY: we are preserving the original fields and only modifying the validity.
108        unsafe { StructVector::new_unchecked(fields, validity.bitand(mask)) }
109    }
110}