vortex_array/arrays/bool/compute/
take.rs

1use arrow_buffer::BooleanBuffer;
2use itertools::Itertools as _;
3use num_traits::AsPrimitive;
4use vortex_dtype::match_each_integer_ptype;
5use vortex_error::VortexResult;
6use vortex_mask::Mask;
7use vortex_scalar::Scalar;
8
9use crate::arrays::{BoolArray, BoolVTable, ConstantArray};
10use crate::compute::{TakeKernel, TakeKernelAdapter, fill_null};
11use crate::vtable::ValidityHelper;
12use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
13
14impl TakeKernel for BoolVTable {
15    fn take(&self, array: &BoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
16        let indices_nulls_zeroed = match indices.validity_mask()? {
17            Mask::AllTrue(_) => indices.to_array(),
18            Mask::AllFalse(_) => {
19                return Ok(ConstantArray::new(
20                    Scalar::null(array.dtype().as_nullable()),
21                    indices.len(),
22                )
23                .into_array());
24            }
25            Mask::Values(_) => fill_null(indices, &Scalar::from(0).cast(indices.dtype())?)?,
26        };
27        let indices_nulls_zeroed = indices_nulls_zeroed.to_primitive()?;
28        let buffer = match_each_integer_ptype!(indices_nulls_zeroed.ptype(), |I| {
29            take_valid_indices(array.boolean_buffer(), indices_nulls_zeroed.as_slice::<I>())
30        });
31
32        Ok(BoolArray::new(buffer, array.validity().take(indices)?).to_array())
33    }
34}
35
36register_kernel!(TakeKernelAdapter(BoolVTable).lift());
37
38fn take_valid_indices<I: AsPrimitive<usize>>(
39    bools: &BooleanBuffer,
40    indices: &[I],
41) -> BooleanBuffer {
42    // For boolean arrays that roughly fit into a single page (at least, on Linux), it's worth
43    // the overhead to convert to a Vec<bool>.
44    if bools.len() <= 4096 {
45        let bools = bools.into_iter().collect_vec();
46        take_byte_bool(bools, indices)
47    } else {
48        take_bool(bools, indices)
49    }
50}
51
52fn take_byte_bool<I: AsPrimitive<usize>>(bools: Vec<bool>, indices: &[I]) -> BooleanBuffer {
53    BooleanBuffer::collect_bool(indices.len(), |idx| {
54        bools[unsafe { indices.get_unchecked(idx).as_() }]
55    })
56}
57
58fn take_bool<I: AsPrimitive<usize>>(bools: &BooleanBuffer, indices: &[I]) -> BooleanBuffer {
59    BooleanBuffer::collect_bool(indices.len(), |idx| {
60        // We can always take from the indices unchecked since collect_bool just iterates len.
61        bools.value(unsafe { indices.get_unchecked(idx).as_() })
62    })
63}
64
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::arrays::BoolArray;
    use crate::arrays::primitive::PrimitiveArray;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, ToCanonical};

    /// Builds the indices `[0, 3, 100]` with the given per-slot validity.
    ///
    /// Index `100` is out of bounds for the five-element arrays used below, so it must
    /// only ever appear in a null slot.
    fn indices_with_validity(valid: [bool; 3]) -> PrimitiveArray {
        PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter(valid).to_array()),
        )
    }

    #[test]
    fn take_nullable() {
        let source = BoolArray::from_iter(vec![
            Some(false),
            Some(true),
            Some(false),
            None,
            Some(false),
        ]);

        // Valid indices: compare the gathered bits against a directly built array.
        let picks = PrimitiveArray::from_iter([0, 3, 4]);
        let taken = take(source.as_ref(), picks.as_ref())
            .unwrap()
            .to_bool()
            .unwrap();
        let expected = BoolArray::from_iter([Some(false), None, Some(false)]);
        assert_eq!(taken.boolean_buffer(), expected.boolean_buffer());

        // All-null indices: every output slot is a null of nullable bool dtype.
        let null_picks = PrimitiveArray::from_option_iter([None::<u32>, None, None]);
        let taken = take(source.as_ref(), null_picks.as_ref()).unwrap();
        let nullable_bool = DType::Bool(Nullability::Nullable);
        assert_eq!(taken.dtype(), &nullable_bool);
        for position in 0..3 {
            assert_eq!(
                taken.scalar_at(position).unwrap(),
                Scalar::null(nullable_bool.clone())
            );
        }
    }

    #[test]
    fn test_bool_array_take_with_null_out_of_bounds_indices() {
        let source = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let picks = indices_with_validity([true, true, false]);

        let taken = take(source.as_ref(), picks.as_ref()).unwrap();

        assert_eq!(taken.scalar_at(0).unwrap(), Scalar::from(Some(false)));
        // Source position 3 holds a null.
        assert_eq!(taken.scalar_at(1).unwrap(), Scalar::null_typed::<bool>());
        // The third index is itself null.
        assert_eq!(taken.scalar_at(2).unwrap(), Scalar::null_typed::<bool>());
    }

    #[test]
    fn test_non_null_bool_array_take_with_null_out_of_bounds_indices() {
        let source = BoolArray::from_iter(vec![false, true, false, true, false]);
        let picks = indices_with_validity([true, true, false]);

        let taken = take(source.as_ref(), picks.as_ref()).unwrap();

        assert_eq!(taken.scalar_at(0).unwrap(), Scalar::from(Some(false)));
        assert_eq!(taken.scalar_at(1).unwrap(), Scalar::from(Some(true)));
        // The third index is itself null.
        assert_eq!(taken.scalar_at(2).unwrap(), Scalar::null_typed::<bool>());
    }

    #[test]
    fn test_bool_array_take_all_null_indices() {
        let source = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let picks = indices_with_validity([false, false, false]);

        let taken = take(source.as_ref(), picks.as_ref()).unwrap();

        for position in 0..3 {
            assert_eq!(
                taken.scalar_at(position).unwrap(),
                Scalar::null_typed::<bool>()
            );
        }
    }

    #[test]
    fn test_non_null_bool_array_take_all_null_indices() {
        let source = BoolArray::from_iter(vec![false, true, false, true, false]);
        let picks = indices_with_validity([false, false, false]);

        let taken = take(source.as_ref(), picks.as_ref()).unwrap();

        for position in 0..3 {
            assert_eq!(
                taken.scalar_at(position).unwrap(),
                Scalar::null_typed::<bool>()
            );
        }
    }
}