vortex_compute/take/
mask.rs1use vortex_buffer::BitBuffer;
5use vortex_buffer::get_bit;
6use vortex_dtype::UnsignedPType;
7use vortex_mask::Mask;
8use vortex_vector::VectorOps;
9use vortex_vector::primitive::PVector;
10
11use crate::take::LINUX_PAGE_SIZE;
12use crate::take::Take;
13
14impl<I: UnsignedPType> Take<[I]> for &Mask {
15 type Output = Mask;
16
17 fn take(self, indices: &[I]) -> Mask {
18 match self {
19 Mask::AllTrue(_) => Mask::AllTrue(indices.len()),
20 Mask::AllFalse(_) => Mask::AllFalse(indices.len()),
21 Mask::Values(mask_values) => {
22 let taken_bit_buffer = mask_values.bit_buffer().take(indices);
23 Mask::from_buffer(taken_bit_buffer)
24 }
25 }
26 }
27}
28
29impl<I: UnsignedPType> Take<PVector<I>> for &Mask {
30 type Output = Mask;
31
32 fn take(self, indices: &PVector<I>) -> Mask {
39 let indices_validity = indices.validity();
40 let indices_len = indices.len();
41
42 let indices_validity_values = match indices_validity {
43 Mask::AllTrue(_) => return self.take(indices.elements().as_slice()),
44 Mask::AllFalse(_) => return Mask::AllFalse(indices_len),
45 Mask::Values(indices_validity_values) => indices_validity_values,
46 };
47
48 match self {
49 Mask::AllTrue(_) => Mask::Values(indices_validity_values.clone()),
51 Mask::AllFalse(_) => Mask::AllFalse(indices_len),
53 Mask::Values(mask_values) => {
54 if self.len() <= LINUX_PAGE_SIZE {
57 let bools = mask_values.bit_buffer().iter().collect();
58 Mask::from_buffer(take_byte_bool_nullable(bools, indices))
59 } else {
60 Mask::from_buffer(take_bool_nullable(mask_values.bit_buffer(), indices))
61 }
62 }
63 }
64 }
65}
66
67fn take_byte_bool_nullable<I: UnsignedPType>(bools: Vec<bool>, indices: &PVector<I>) -> BitBuffer {
68 BitBuffer::collect_bool(indices.len(), |idx| {
69 indices
70 .get(idx)
71 .is_some_and(|bool_idx| bools[bool_idx.as_()])
72 })
73}
74
75fn take_bool_nullable<I: UnsignedPType>(bools: &BitBuffer, indices: &PVector<I>) -> BitBuffer {
76 let buffer = bools.inner().as_ref();
78 let offset = bools.offset();
79
80 BitBuffer::collect_bool(indices.len(), |idx| {
81 indices
82 .get(idx)
83 .is_some_and(|bool_idx| get_bit(buffer, offset + bool_idx.as_()))
84 })
85}