vortex_compute/take/
bit_buffer.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Take operation on [`BitBuffer`].
5//!
6//! NB: We do NOT implement `impl<I: UnsignedPType> Take<PVector<I>> for &BitBuffer`, specifically
7//! because there is a very similar implementation on `Mask` that has special logic for working with
8//! null indices. That logic could also be implemented on `BitBuffer`, but since it is not
9//! immediately clear what should happen in the case of a null index when taking a `BitBuffer` (do
10//! you set it to true or false?), we do not implement this at all.
11
12use vortex_buffer::BitBuffer;
13use vortex_buffer::get_bit;
14use vortex_dtype::UnsignedPType;
15
16use crate::take::LINUX_PAGE_SIZE;
17use crate::take::Take;
18
19impl<I: UnsignedPType> Take<[I]> for &BitBuffer {
20    type Output = BitBuffer;
21
22    fn take(self, indices: &[I]) -> BitBuffer {
23        // For boolean arrays that roughly fit into a single page (at least, on Linux), it's worth
24        // the overhead to convert to a `Vec<bool>`.
25        if self.len() <= LINUX_PAGE_SIZE {
26            let bools = self.iter().collect();
27            take_byte_bool(bools, indices)
28        } else {
29            take_bool(self, indices)
30        }
31    }
32}
33
34/// # Panics
35///
36/// Panics if an index is out of bounds.
37fn take_byte_bool<I: UnsignedPType>(bools: Vec<bool>, indices: &[I]) -> BitBuffer {
38    BitBuffer::collect_bool(indices.len(), |idx| {
39        // SAFETY: We are iterating within the bounds of the `indices` array, so we are always
40        // within bounds of `indices`.
41        let bool_idx = unsafe { indices.get_unchecked(idx).as_() };
42        bools[bool_idx]
43    })
44}
45
46/// # Panics
47///
48/// Panics if an index is out of bounds.
49fn take_bool<I: UnsignedPType>(bools: &BitBuffer, indices: &[I]) -> BitBuffer {
50    // We dereference to the underlying buffer to avoid incurring an access cost on every index.
51    let buffer = bools.inner().as_ref();
52    let offset = bools.offset();
53
54    BitBuffer::collect_bool(indices.len(), |idx| {
55        // SAFETY: We are iterating within the bounds of the `indices` array, so we are always
56        // within bounds.
57        let bool_idx = unsafe { indices.get_unchecked(idx).as_() };
58        get_bit(buffer, offset + bool_idx)
59    })
60}
61
#[cfg(test)]
mod tests {
    use vortex_buffer::BitBuffer;

    use crate::take::Take;

    /// Checks `take` on a short buffer and on a long one, so that both the unpacked
    /// (`take_byte_bool`) and packed (`take_bool`) code paths are exercised.
    #[test]
    fn test_take_bit_buffer_take_small_and_large() {
        // Expands a `BitBuffer` into one `bool` per bit for easy comparison.
        let unpack = |bits: &BitBuffer| -> Vec<bool> {
            (0..bits.len()).map(|i| bits.value(i)).collect()
        };

        // Small buffer (uses take_byte_bool path).
        let small_source = [true, false, true, true, false, true, false, false];
        let small = small_source.into_iter().collect::<BitBuffer>();
        let small_indices: [u32; 5] = [7, 0, 2, 5, 1];
        let small_taken = (&small).take(&small_indices[..]);
        assert_eq!(
            unpack(&small_taken),
            vec![false, true, true, true, false]
        );

        // Large buffer (uses take_bool path, len > 4096): bit `i` is set iff `i` is a multiple
        // of three.
        let large = (0..5000).map(|i| i % 3 == 0).collect::<BitBuffer>();
        let large_indices: [u32; 6] = [4999, 0, 1, 2, 3, 4998];
        let large_taken = (&large).take(&large_indices[..]);
        assert_eq!(
            unpack(&large_taken),
            vec![false, true, false, false, true, true]
        );
    }
}