vortex_array/arrays/bool/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools as _;
5use num_traits::AsPrimitive;
6use vortex_buffer::BitBuffer;
7use vortex_buffer::get_bit;
8use vortex_dtype::match_each_integer_ptype;
9use vortex_error::VortexResult;
10use vortex_mask::Mask;
11use vortex_scalar::Scalar;
12
13use crate::Array;
14use crate::ArrayRef;
15use crate::IntoArray;
16use crate::ToCanonical;
17use crate::arrays::BoolArray;
18use crate::arrays::BoolVTable;
19use crate::arrays::ConstantArray;
20use crate::compute::TakeKernel;
21use crate::compute::TakeKernelAdapter;
22use crate::compute::fill_null;
23use crate::register_kernel;
24use crate::vtable::ValidityHelper;
25
26impl TakeKernel for BoolVTable {
27    fn take(&self, array: &BoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
28        let indices_nulls_zeroed = match indices.validity_mask() {
29            Mask::AllTrue(_) => indices.to_array(),
30            Mask::AllFalse(_) => {
31                return Ok(ConstantArray::new(
32                    Scalar::null(array.dtype().as_nullable()),
33                    indices.len(),
34                )
35                .into_array());
36            }
37            Mask::Values(_) => fill_null(indices, &Scalar::from(0).cast(indices.dtype())?)?,
38        };
39        let indices_nulls_zeroed = indices_nulls_zeroed.to_primitive();
40        let buffer = match_each_integer_ptype!(indices_nulls_zeroed.ptype(), |I| {
41            take_valid_indices(array.bit_buffer(), indices_nulls_zeroed.as_slice::<I>())
42        });
43
44        Ok(BoolArray::from_bit_buffer(buffer, array.validity().take(indices)?).to_array())
45    }
46}
47
48register_kernel!(TakeKernelAdapter(BoolVTable).lift());
49
50fn take_valid_indices<I: AsPrimitive<usize>>(bools: &BitBuffer, indices: &[I]) -> BitBuffer {
51    // For boolean arrays that roughly fit into a single page (at least, on Linux), it's worth
52    // the overhead to convert to a Vec<bool>.
53    if bools.len() <= 4096 {
54        let bools = bools.iter().collect_vec();
55        take_byte_bool(bools, indices)
56    } else {
57        take_bool(bools, indices)
58    }
59}
60
61fn take_byte_bool<I: AsPrimitive<usize>>(bools: Vec<bool>, indices: &[I]) -> BitBuffer {
62    BitBuffer::collect_bool(indices.len(), |idx| {
63        bools[unsafe { indices.get_unchecked(idx).as_() }]
64    })
65}
66
67fn take_bool<I: AsPrimitive<usize>>(bools: &BitBuffer, indices: &[I]) -> BitBuffer {
68    // We dereference to underlying buffer to avoid access cost on every index.
69    let buffer = bools.inner().as_ref();
70    BitBuffer::collect_bool(indices.len(), |idx| {
71        // SAFETY: we can take from the indices unchecked since collect_bool just iterates len.
72        let idx = unsafe { indices.get_unchecked(idx).as_() };
73        get_bit(buffer, bools.offset() + idx)
74    })
75}
76
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::Array;
    use crate::IntoArray as _;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::primitive::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;

    /// Builds the indices `[0, 3, 100]` with the given per-slot validity. The last index
    /// is out of bounds for the five-element arrays used by the tests below.
    fn indices_with_validity(valid: [bool; 3]) -> PrimitiveArray {
        PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter(valid).to_array()),
        )
    }

    /// Taking from a nullable array, first with valid indices and then with all-null ones.
    #[test]
    fn take_nullable() {
        let source = BoolArray::from_iter(vec![
            Some(false),
            Some(true),
            Some(false),
            None,
            Some(false),
        ]);

        let valid_indices = buffer![0, 3, 4].into_array();
        let taken = take(source.as_ref(), valid_indices.as_ref())
            .unwrap()
            .to_bool();
        let expected = BoolArray::from_iter([Some(false), None, Some(false)]);
        assert_eq!(taken.bit_buffer(), expected.bit_buffer());

        let null_indices = PrimitiveArray::from_option_iter([None::<u32>, None, None]);
        let taken = take(source.as_ref(), null_indices.as_ref()).unwrap();
        assert_arrays_eq!(taken, BoolArray::from_iter([None, None, None]));
    }

    #[test]
    fn test_bool_array_take_with_null_out_of_bounds_indices() {
        let source = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let picks = indices_with_validity([true, true, false]);
        let taken = take(source.as_ref(), picks.as_ref()).unwrap();

        // Position 3 of the source is null, and the third index itself is null.
        assert_arrays_eq!(taken, BoolArray::from_iter([Some(false), None, None]));
    }

    #[test]
    fn test_non_null_bool_array_take_with_null_out_of_bounds_indices() {
        let source = BoolArray::from_iter(vec![false, true, false, true, false]);
        let picks = indices_with_validity([true, true, false]);
        let taken = take(source.as_ref(), picks.as_ref()).unwrap();

        // Only the third index is null.
        assert_arrays_eq!(
            taken,
            BoolArray::from_iter([Some(false), Some(true), None])
        );
    }

    #[test]
    fn test_bool_array_take_all_null_indices() {
        let source = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let picks = indices_with_validity([false, false, false]);
        let taken = take(source.as_ref(), picks.as_ref()).unwrap();
        assert_arrays_eq!(taken, BoolArray::from_iter([None, None, None]));
    }

    #[test]
    fn test_non_null_bool_array_take_all_null_indices() {
        let source = BoolArray::from_iter(vec![false, true, false, true, false]);
        let picks = indices_with_validity([false, false, false]);
        let taken = take(source.as_ref(), picks.as_ref()).unwrap();
        assert_arrays_eq!(taken, BoolArray::from_iter([None, None, None]));
    }

    /// Runs the shared take conformance checks over several boolean array shapes.
    #[rstest]
    #[case(BoolArray::from_iter([true, false, true, true, false]))]
    #[case(BoolArray::from_iter([Some(true), None, Some(false), Some(true), None]))]
    #[case(BoolArray::from_iter([true, false]))]
    #[case(BoolArray::from_iter([true]))]
    fn test_take_bool_conformance(#[case] array: BoolArray) {
        test_take_conformance(array.as_ref());
    }
}
173}