vortex_array/arrays/bool/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use arrow_buffer::BooleanBuffer;
5use itertools::Itertools as _;
6use num_traits::AsPrimitive;
7use vortex_dtype::match_each_integer_ptype;
8use vortex_error::VortexResult;
9use vortex_mask::Mask;
10use vortex_scalar::Scalar;
11
12use crate::arrays::{BoolArray, BoolVTable, ConstantArray};
13use crate::compute::{TakeKernel, TakeKernelAdapter, fill_null};
14use crate::vtable::ValidityHelper;
15use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
16
17impl TakeKernel for BoolVTable {
18    fn take(&self, array: &BoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
19        let indices_nulls_zeroed = match indices.validity_mask()? {
20            Mask::AllTrue(_) => indices.to_array(),
21            Mask::AllFalse(_) => {
22                return Ok(ConstantArray::new(
23                    Scalar::null(array.dtype().as_nullable()),
24                    indices.len(),
25                )
26                .into_array());
27            }
28            Mask::Values(_) => fill_null(indices, &Scalar::from(0).cast(indices.dtype())?)?,
29        };
30        let indices_nulls_zeroed = indices_nulls_zeroed.to_primitive()?;
31        let buffer = match_each_integer_ptype!(indices_nulls_zeroed.ptype(), |I| {
32            take_valid_indices(array.boolean_buffer(), indices_nulls_zeroed.as_slice::<I>())
33        });
34
35        Ok(BoolArray::new(buffer, array.validity().take(indices)?).to_array())
36    }
37}
38
39register_kernel!(TakeKernelAdapter(BoolVTable).lift());
40
41fn take_valid_indices<I: AsPrimitive<usize>>(
42    bools: &BooleanBuffer,
43    indices: &[I],
44) -> BooleanBuffer {
45    // For boolean arrays that roughly fit into a single page (at least, on Linux), it's worth
46    // the overhead to convert to a Vec<bool>.
47    if bools.len() <= 4096 {
48        let bools = bools.into_iter().collect_vec();
49        take_byte_bool(bools, indices)
50    } else {
51        take_bool(bools, indices)
52    }
53}
54
55fn take_byte_bool<I: AsPrimitive<usize>>(bools: Vec<bool>, indices: &[I]) -> BooleanBuffer {
56    BooleanBuffer::collect_bool(indices.len(), |idx| {
57        bools[unsafe { indices.get_unchecked(idx).as_() }]
58    })
59}
60
61fn take_bool<I: AsPrimitive<usize>>(bools: &BooleanBuffer, indices: &[I]) -> BooleanBuffer {
62    BooleanBuffer::collect_bool(indices.len(), |idx| {
63        // We can always take from the indices unchecked since collect_bool just iterates len.
64        bools.value(unsafe { indices.get_unchecked(idx).as_() })
65    })
66}
67
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::arrays::BoolArray;
    use crate::arrays::primitive::PrimitiveArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, ToCanonical};

    /// Taking from a nullable array must carry over both values and nulls, and all-null
    /// indices must produce an all-null, nullable result.
    #[test]
    fn take_nullable() {
        let reference = BoolArray::from_iter(vec![
            Some(false),
            Some(true),
            Some(false),
            None,
            Some(false),
        ]);

        let taken = take(
            reference.as_ref(),
            PrimitiveArray::from_iter([0, 3, 4]).as_ref(),
        )
        .unwrap();
        // The buffer comparison below ignores validity, so check the null slot explicitly.
        assert_eq!(taken.scalar_at(0).unwrap(), Scalar::from(Some(false)));
        assert_eq!(taken.scalar_at(1).unwrap(), Scalar::null_typed::<bool>());
        assert_eq!(taken.scalar_at(2).unwrap(), Scalar::from(Some(false)));

        let b = taken.to_bool().unwrap();
        assert_eq!(
            b.boolean_buffer(),
            BoolArray::from_iter([Some(false), None, Some(false)]).boolean_buffer()
        );

        let nullable_bool_dtype = DType::Bool(Nullability::Nullable);
        let all_invalid_indices = PrimitiveArray::from_option_iter([None::<u32>, None, None]);
        let b = take(reference.as_ref(), all_invalid_indices.as_ref()).unwrap();
        assert_eq!(b.dtype(), &nullable_bool_dtype);
        assert_eq!(
            b.scalar_at(0).unwrap(),
            Scalar::null(nullable_bool_dtype.clone())
        );
        assert_eq!(
            b.scalar_at(1).unwrap(),
            Scalar::null(nullable_bool_dtype.clone())
        );
        assert_eq!(b.scalar_at(2).unwrap(), Scalar::null(nullable_bool_dtype));
    }

    /// Sources longer than 4096 values are gathered straight from the bit-packed buffer
    /// rather than an unpacked `Vec<bool>`; exercise that path too.
    #[test]
    fn take_from_large_array() {
        let values = BoolArray::from_iter((0..5000).map(|i| i % 3 == 0));
        let indices = PrimitiveArray::from_iter([0u32, 1, 4998, 4999]);
        let actual = take(values.as_ref(), indices.as_ref())
            .unwrap()
            .to_bool()
            .unwrap();
        assert_eq!(
            actual.boolean_buffer(),
            BoolArray::from_iter([true, false, true, false]).boolean_buffer()
        );
    }

    /// A null index whose stored value is out of bounds must not be dereferenced.
    #[test]
    fn test_bool_array_take_with_null_out_of_bounds_indices() {
        let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([true, true, false]).to_array()),
        );
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(actual.scalar_at(0).unwrap(), Scalar::from(Some(false)));
        // position 3 is null
        assert_eq!(actual.scalar_at(1).unwrap(), Scalar::null_typed::<bool>());
        // the third index is null
        assert_eq!(actual.scalar_at(2).unwrap(), Scalar::null_typed::<bool>());
    }

    /// Null indices into a non-nullable array still produce nulls in the output.
    #[test]
    fn test_non_null_bool_array_take_with_null_out_of_bounds_indices() {
        let values = BoolArray::from_iter(vec![false, true, false, true, false]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([true, true, false]).to_array()),
        );
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(actual.scalar_at(0).unwrap(), Scalar::from(Some(false)));
        assert_eq!(actual.scalar_at(1).unwrap(), Scalar::from(Some(true)));
        // the third index is null
        assert_eq!(actual.scalar_at(2).unwrap(), Scalar::null_typed::<bool>());
    }

    /// All-null indices (given as an explicit validity array) yield an all-null result.
    #[test]
    fn test_bool_array_take_all_null_indices() {
        let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([false, false, false]).to_array()),
        );
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(actual.scalar_at(0).unwrap(), Scalar::null_typed::<bool>());
        assert_eq!(actual.scalar_at(1).unwrap(), Scalar::null_typed::<bool>());
        assert_eq!(actual.scalar_at(2).unwrap(), Scalar::null_typed::<bool>());
    }

    /// Same as above, but the source array is non-nullable.
    #[test]
    fn test_non_null_bool_array_take_all_null_indices() {
        let values = BoolArray::from_iter(vec![false, true, false, true, false]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([false, false, false]).to_array()),
        );
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(actual.scalar_at(0).unwrap(), Scalar::null_typed::<bool>());
        assert_eq!(actual.scalar_at(1).unwrap(), Scalar::null_typed::<bool>());
        assert_eq!(actual.scalar_at(2).unwrap(), Scalar::null_typed::<bool>());
    }

    #[rstest]
    #[case(BoolArray::from_iter([true, false, true, true, false]))]
    #[case(BoolArray::from_iter([Some(true), None, Some(false), Some(true), None]))]
    #[case(BoolArray::from_iter([true, false]))]
    #[case(BoolArray::from_iter([true]))]
    fn test_take_bool_conformance(#[case] array: BoolArray) {
        test_take_conformance(array.as_ref());
    }
}