vortex_array/arrays/bool/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools as _;
5use num_traits::AsPrimitive;
6use vortex_buffer::{BitBuffer, get_bit};
7use vortex_dtype::match_each_integer_ptype;
8use vortex_error::VortexResult;
9use vortex_mask::Mask;
10use vortex_scalar::Scalar;
11
12use crate::arrays::{BoolArray, BoolVTable, ConstantArray};
13use crate::compute::{TakeKernel, TakeKernelAdapter, fill_null};
14use crate::vtable::ValidityHelper;
15use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
16
17impl TakeKernel for BoolVTable {
18    fn take(&self, array: &BoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
19        let indices_nulls_zeroed = match indices.validity_mask() {
20            Mask::AllTrue(_) => indices.to_array(),
21            Mask::AllFalse(_) => {
22                return Ok(ConstantArray::new(
23                    Scalar::null(array.dtype().as_nullable()),
24                    indices.len(),
25                )
26                .into_array());
27            }
28            Mask::Values(_) => fill_null(indices, &Scalar::from(0).cast(indices.dtype())?)?,
29        };
30        let indices_nulls_zeroed = indices_nulls_zeroed.to_primitive();
31        let buffer = match_each_integer_ptype!(indices_nulls_zeroed.ptype(), |I| {
32            take_valid_indices(array.bit_buffer(), indices_nulls_zeroed.as_slice::<I>())
33        });
34
35        Ok(BoolArray::from_bit_buffer(buffer, array.validity().take(indices)?).to_array())
36    }
37}
38
39register_kernel!(TakeKernelAdapter(BoolVTable).lift());
40
41fn take_valid_indices<I: AsPrimitive<usize>>(bools: &BitBuffer, indices: &[I]) -> BitBuffer {
42    // For boolean arrays that roughly fit into a single page (at least, on Linux), it's worth
43    // the overhead to convert to a Vec<bool>.
44    if bools.len() <= 4096 {
45        let bools = bools.iter().collect_vec();
46        take_byte_bool(bools, indices)
47    } else {
48        take_bool(bools, indices)
49    }
50}
51
52fn take_byte_bool<I: AsPrimitive<usize>>(bools: Vec<bool>, indices: &[I]) -> BitBuffer {
53    BitBuffer::collect_bool(indices.len(), |idx| {
54        bools[unsafe { indices.get_unchecked(idx).as_() }]
55    })
56}
57
58fn take_bool<I: AsPrimitive<usize>>(bools: &BitBuffer, indices: &[I]) -> BitBuffer {
59    // We dereference to underlying buffer to avoid access cost on every index.
60    let buffer = bools.inner().as_ref();
61    BitBuffer::collect_bool(indices.len(), |idx| {
62        // SAFETY: we can take from the indices unchecked since collect_bool just iterates len.
63        let idx = unsafe { indices.get_unchecked(idx).as_() };
64        get_bit(buffer, bools.offset() + idx)
65    })
66}
67
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::arrays::BoolArray;
    use crate::arrays::primitive::PrimitiveArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _, ToCanonical, assert_arrays_eq};

    /// Taking from a nullable array: checks both the gathered bits and the gathered nulls,
    /// and that all-null indices produce an all-null result.
    #[test]
    fn take_nullable() {
        let reference = BoolArray::from_iter(vec![
            Some(false),
            Some(true),
            Some(false),
            None,
            Some(false),
        ]);

        let taken = take(reference.as_ref(), buffer![0, 3, 4].into_array().as_ref()).unwrap();
        let b = taken.to_bool();
        assert_eq!(
            b.bit_buffer(),
            BoolArray::from_iter([Some(false), None, Some(false)]).bit_buffer()
        );
        // The bit buffers alone say nothing about nulls, so also compare the arrays as a
        // whole: position 3 of the reference is null and must stay null after the take.
        assert_arrays_eq!(taken, BoolArray::from_iter([Some(false), None, Some(false)]));

        let all_invalid_indices = PrimitiveArray::from_option_iter([None::<u32>, None, None]);
        let b = take(reference.as_ref(), all_invalid_indices.as_ref()).unwrap();
        assert_arrays_eq!(b, BoolArray::from_iter([None, None, None]));
    }

    /// A null index may hold an out-of-bounds value (100 here) without affecting the take
    /// of a nullable array.
    #[test]
    fn test_bool_array_take_with_null_out_of_bounds_indices() {
        let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([true, true, false]).to_array()),
        );
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();

        // position 3 is null, the third index is null
        assert_arrays_eq!(actual, BoolArray::from_iter([Some(false), None, None]));
    }

    /// Same as above, but the values themselves are non-nullable; the result still carries
    /// a null for the null index.
    #[test]
    fn test_non_null_bool_array_take_with_null_out_of_bounds_indices() {
        let values = BoolArray::from_iter(vec![false, true, false, true, false]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([true, true, false]).to_array()),
        );
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        // the third index is null
        assert_arrays_eq!(
            actual,
            BoolArray::from_iter([Some(false), Some(true), None])
        );
    }

    /// Indices whose validity array marks every slot as null produce an all-null result
    /// from a nullable array.
    #[test]
    fn test_bool_array_take_all_null_indices() {
        let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([false, false, false]).to_array()),
        );
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        assert_arrays_eq!(actual, BoolArray::from_iter([None, None, None]));
    }

    /// Indices whose validity array marks every slot as null produce an all-null result
    /// from a non-nullable array as well.
    #[test]
    fn test_non_null_bool_array_take_all_null_indices() {
        let values = BoolArray::from_iter(vec![false, true, false, true, false]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([false, false, false]).to_array()),
        );
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        assert_arrays_eq!(actual, BoolArray::from_iter([None, None, None]));
    }

    /// Runs the shared take conformance suite over a few small bool arrays.
    #[rstest]
    #[case(BoolArray::from_iter([true, false, true, true, false]))]
    #[case(BoolArray::from_iter([Some(true), None, Some(false), Some(true), None]))]
    #[case(BoolArray::from_iter([true, false]))]
    #[case(BoolArray::from_iter([true]))]
    fn test_take_bool_conformance(#[case] array: BoolArray) {
        test_take_conformance(array.as_ref());
    }
}
161}