vortex_compute/take/
mask.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BitBuffer;
5use vortex_buffer::get_bit;
6use vortex_dtype::UnsignedPType;
7use vortex_mask::Mask;
8use vortex_vector::VectorOps;
9use vortex_vector::primitive::PVector;
10
11use crate::take::LINUX_PAGE_SIZE;
12use crate::take::Take;
13
14impl<I: UnsignedPType> Take<[I]> for &Mask {
15    type Output = Mask;
16
17    fn take(self, indices: &[I]) -> Mask {
18        match self {
19            Mask::AllTrue(_) => Mask::AllTrue(indices.len()),
20            Mask::AllFalse(_) => Mask::AllFalse(indices.len()),
21            Mask::Values(mask_values) => {
22                let taken_bit_buffer = mask_values.bit_buffer().take(indices);
23                Mask::from_buffer(taken_bit_buffer)
24            }
25        }
26    }
27}
28
29impl<I: UnsignedPType> Take<PVector<I>> for &Mask {
30    type Output = Mask;
31
32    /// Implementation of take on [`Mask`] that is null-aware.
33    ///
34    /// If an index is specified as null by the [`PVector`], then the taken mask value is set to
35    /// `false`.
36    ///
37    /// This is useful for many of the `take` implementations for vectors.
38    fn take(self, indices: &PVector<I>) -> Mask {
39        let indices_validity = indices.validity();
40        let indices_len = indices.len();
41
42        let indices_validity_values = match indices_validity {
43            Mask::AllTrue(_) => return self.take(indices.elements().as_slice()),
44            Mask::AllFalse(_) => return Mask::AllFalse(indices_len),
45            Mask::Values(indices_validity_values) => indices_validity_values,
46        };
47
48        match self {
49            // Since all the values are true, the only false values will be from the indices.
50            Mask::AllTrue(_) => Mask::Values(indices_validity_values.clone()),
51            // Since all the values are already false, the indices nullability wont change anything.
52            Mask::AllFalse(_) => Mask::AllFalse(indices_len),
53            Mask::Values(mask_values) => {
54                // For boolean arrays that roughly fit into a single page (at least, on Linux), it's
55                // worth the overhead to convert to a `Vec<bool>`.
56                if self.len() <= LINUX_PAGE_SIZE {
57                    let bools = mask_values.bit_buffer().iter().collect();
58                    Mask::from_buffer(take_byte_bool_nullable(bools, indices))
59                } else {
60                    Mask::from_buffer(take_bool_nullable(mask_values.bit_buffer(), indices))
61                }
62            }
63        }
64    }
65}
66
67fn take_byte_bool_nullable<I: UnsignedPType>(bools: Vec<bool>, indices: &PVector<I>) -> BitBuffer {
68    BitBuffer::collect_bool(indices.len(), |idx| {
69        indices
70            .get(idx)
71            .is_some_and(|bool_idx| bools[bool_idx.as_()])
72    })
73}
74
75fn take_bool_nullable<I: UnsignedPType>(bools: &BitBuffer, indices: &PVector<I>) -> BitBuffer {
76    // We dereference to the underlying buffer to avoid incurring an access cost on every index.
77    let buffer = bools.inner().as_ref();
78    let offset = bools.offset();
79
80    BitBuffer::collect_bool(indices.len(), |idx| {
81        indices
82            .get(idx)
83            .is_some_and(|bool_idx| get_bit(buffer, offset + bool_idx.as_()))
84    })
85}