Skip to main content

vortex_array/arrays/bool/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools as _;
5use num_traits::AsPrimitive;
6use vortex_buffer::BitBuffer;
7use vortex_buffer::get_bit;
8use vortex_error::VortexResult;
9use vortex_mask::Mask;
10
11use crate::ArrayRef;
12use crate::IntoArray;
13use crate::array::ArrayView;
14use crate::arrays::Bool;
15use crate::arrays::BoolArray;
16use crate::arrays::ConstantArray;
17use crate::arrays::PrimitiveArray;
18use crate::arrays::bool::BoolArrayExt;
19use crate::arrays::dict::TakeExecute;
20use crate::builtins::ArrayBuiltins;
21use crate::executor::ExecutionCtx;
22use crate::match_each_integer_ptype;
23use crate::scalar::Scalar;
24
25impl TakeExecute for Bool {
26    fn take(
27        array: ArrayView<'_, Bool>,
28        indices: &ArrayRef,
29        ctx: &mut ExecutionCtx,
30    ) -> VortexResult<Option<ArrayRef>> {
31        let indices_nulls_zeroed = match indices.validity()?.execute_mask(indices.len(), ctx)? {
32            Mask::AllTrue(_) => indices.clone(),
33            Mask::AllFalse(_) => {
34                return Ok(Some(
35                    ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
36                        .into_array(),
37                ));
38            }
39            Mask::Values(_) => indices
40                .clone()
41                .fill_null(Scalar::from(0).cast(indices.dtype())?)?,
42        };
43        let indices_nulls_zeroed = indices_nulls_zeroed.execute::<PrimitiveArray>(ctx)?;
44        let buffer = match_each_integer_ptype!(indices_nulls_zeroed.ptype(), |I| {
45            take_valid_indices(&array.to_bit_buffer(), indices_nulls_zeroed.as_slice::<I>())
46        });
47
48        Ok(Some(
49            BoolArray::new(buffer, array.validity()?.take(indices)?).into_array(),
50        ))
51    }
52}
53
54fn take_valid_indices<I: AsPrimitive<usize>>(bools: &BitBuffer, indices: &[I]) -> BitBuffer {
55    // For boolean arrays that roughly fit into a single page (at least, on Linux), it's worth
56    // the overhead to convert to a Vec<bool>.
57    if bools.len() <= 4096 {
58        let bools = bools.iter().collect_vec();
59        take_byte_bool(bools, indices)
60    } else {
61        take_bool_impl(bools, indices)
62    }
63}
64
65fn take_byte_bool<I: AsPrimitive<usize>>(bools: Vec<bool>, indices: &[I]) -> BitBuffer {
66    BitBuffer::collect_bool(indices.len(), |idx| {
67        bools[unsafe { indices.get_unchecked(idx).as_() }]
68    })
69}
70
71fn take_bool_impl<I: AsPrimitive<usize>>(bools: &BitBuffer, indices: &[I]) -> BitBuffer {
72    // We dereference to underlying buffer to avoid access cost on every index.
73    let buffer = bools.inner().as_ref();
74    BitBuffer::collect_bool(indices.len(), |idx| {
75        // SAFETY: we can take from the indices unchecked since collect_bool just iterates len.
76        let idx = unsafe { indices.get_unchecked(idx).as_() };
77        get_bit(buffer, bools.offset() + idx)
78    })
79}
80
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray as _;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::bool::BoolArrayExt;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::validity::Validity;

    /// Builds the index array `[0, 3, 100]` whose per-slot validity is given by `valid`.
    ///
    /// Index `100` is out of bounds for the five-element arrays used below, so the tests only
    /// ever leave it marked as null.
    fn indices_0_3_100(valid: [bool; 3]) -> crate::ArrayRef {
        let validity = Validity::Array(BoolArray::from_iter(valid).into_array());
        PrimitiveArray::new(buffer![0, 3, 100], validity).into_array()
    }

    /// Taking from a nullable array, first with plain indices and then with all-null indices.
    #[test]
    fn take_nullable() {
        let source = BoolArray::from_iter(vec![
            Some(false),
            Some(true),
            Some(false),
            None,
            Some(false),
        ]);

        #[expect(deprecated)]
        let taken = source
            .take(buffer![0, 3, 4].into_array())
            .unwrap()
            .to_bool();
        let expected = BoolArray::from_iter([Some(false), None, Some(false)]);
        assert_eq!(taken.to_bit_buffer(), expected.to_bit_buffer());

        let null_indices = PrimitiveArray::from_option_iter([None::<i32>, None, None]);
        let taken = source.take(null_indices.into_array()).unwrap();
        assert_arrays_eq!(taken, BoolArray::from_iter([None, None, None]));
    }

    /// A null index may hold an out-of-bounds value when the source array is nullable.
    #[test]
    fn test_bool_array_take_with_null_out_of_bounds_indices() {
        let source = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let result = source
            .take(indices_0_3_100([true, true, false]))
            .unwrap();

        // position 3 is null, the third index is null
        assert_arrays_eq!(result, BoolArray::from_iter([Some(false), None, None]));
    }

    /// A null index may hold an out-of-bounds value when the source array is non-nullable.
    #[test]
    fn test_non_null_bool_array_take_with_null_out_of_bounds_indices() {
        let source = BoolArray::from_iter(vec![false, true, false, true, false]);
        let result = source
            .take(indices_0_3_100([true, true, false]))
            .unwrap();

        // the third index is null
        assert_arrays_eq!(
            result,
            BoolArray::from_iter([Some(false), Some(true), None])
        );
    }

    /// Indices whose validity array is all-false yield an all-null result (nullable source).
    #[test]
    fn test_bool_array_take_all_null_indices() {
        let source = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let result = source
            .take(indices_0_3_100([false, false, false]))
            .unwrap();
        assert_arrays_eq!(result, BoolArray::from_iter([None, None, None]));
    }

    /// Indices whose validity array is all-false yield an all-null result (non-null source).
    #[test]
    fn test_non_null_bool_array_take_all_null_indices() {
        let source = BoolArray::from_iter(vec![false, true, false, true, false]);
        let result = source
            .take(indices_0_3_100([false, false, false]))
            .unwrap();
        assert_arrays_eq!(result, BoolArray::from_iter([None, None, None]));
    }

    /// Runs the shared `take` conformance suite over a handful of boolean arrays.
    #[rstest]
    #[case(BoolArray::from_iter([true, false, true, true, false]))]
    #[case(BoolArray::from_iter([Some(true), None, Some(false), Some(true), None]))]
    #[case(BoolArray::from_iter([true, false]))]
    #[case(BoolArray::from_iter([true]))]
    fn test_take_bool_conformance(#[case] array: BoolArray) {
        test_take_conformance(&array.into_array());
    }
}