Skip to main content

vortex_array/arrays/bool/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools as _;
5use num_traits::AsPrimitive;
6use vortex_buffer::BitBuffer;
7use vortex_buffer::BitBufferView;
8use vortex_buffer::get_bit;
9use vortex_error::VortexResult;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::IntoArray;
14use crate::array::ArrayView;
15use crate::arrays::Bool;
16use crate::arrays::BoolArray;
17use crate::arrays::ConstantArray;
18use crate::arrays::PrimitiveArray;
19use crate::arrays::bool::BoolArrayExt;
20use crate::arrays::dict::TakeExecute;
21use crate::builtins::ArrayBuiltins;
22use crate::executor::ExecutionCtx;
23use crate::match_each_integer_ptype;
24use crate::scalar::Scalar;
25
26impl TakeExecute for Bool {
27    fn take(
28        array: ArrayView<'_, Bool>,
29        indices: &ArrayRef,
30        ctx: &mut ExecutionCtx,
31    ) -> VortexResult<Option<ArrayRef>> {
32        let indices_nulls_zeroed = match indices.validity()?.execute_mask(indices.len(), ctx)? {
33            Mask::AllTrue(_) => indices.clone(),
34            Mask::AllFalse(_) => {
35                return Ok(Some(
36                    ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
37                        .into_array(),
38                ));
39            }
40            Mask::Values(_) => indices
41                .clone()
42                .fill_null(Scalar::from(0).cast(indices.dtype())?)?,
43        };
44        let indices_nulls_zeroed = indices_nulls_zeroed.execute::<PrimitiveArray>(ctx)?;
45        let buffer = match_each_integer_ptype!(indices_nulls_zeroed.ptype(), |I| {
46            take_valid_indices(
47                array.bit_buffer_view(),
48                indices_nulls_zeroed.as_slice::<I>(),
49            )
50        });
51
52        Ok(Some(
53            BoolArray::new(buffer, array.validity()?.take(indices)?).into_array(),
54        ))
55    }
56}
57
58fn take_valid_indices<I: AsPrimitive<usize>>(bools: BitBufferView<'_>, indices: &[I]) -> BitBuffer {
59    // For boolean arrays that roughly fit into a single page (at least, on Linux), it's worth
60    // the overhead to convert to a Vec<bool>.
61    if bools.len() <= 4096 {
62        let bools = bools.iter().collect_vec();
63        take_byte_bool(bools, indices)
64    } else {
65        take_bool_impl(bools, indices)
66    }
67}
68
69fn take_byte_bool<I: AsPrimitive<usize>>(bools: Vec<bool>, indices: &[I]) -> BitBuffer {
70    BitBuffer::collect_bool(indices.len(), |idx| {
71        bools[unsafe { indices.get_unchecked(idx).as_() }]
72    })
73}
74
75fn take_bool_impl<I: AsPrimitive<usize>>(bools: BitBufferView<'_>, indices: &[I]) -> BitBuffer {
76    // We dereference to underlying buffer to avoid access cost on every index.
77    let buffer = bools.inner();
78    BitBuffer::collect_bool(indices.len(), |idx| {
79        // SAFETY: we can take from the indices unchecked since collect_bool just iterates len.
80        let idx = unsafe { indices.get_unchecked(idx).as_() };
81        get_bit(buffer, bools.offset() + idx)
82    })
83}
84
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray as _;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::validity::Validity;

    #[test]
    fn take_nullable() {
        let reference = BoolArray::from_iter(vec![
            Some(false),
            Some(true),
            Some(false),
            None,
            Some(false),
        ]);

        // Compare whole arrays (values *and* validity). That way the null picked up from
        // position 3 is actually verified, not just the bit stored underneath it.
        let b = reference.take(buffer![0, 3, 4].into_array()).unwrap();
        assert_arrays_eq!(b, BoolArray::from_iter([Some(false), None, Some(false)]));

        let all_invalid_indices = PrimitiveArray::from_option_iter([None::<i32>, None, None]);
        let b = reference.take(all_invalid_indices.into_array()).unwrap();
        assert_arrays_eq!(b, BoolArray::from_iter([None, None, None]));
    }

    #[test]
    fn test_bool_array_take_with_null_out_of_bounds_indices() {
        let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = values.take(indices.into_array()).unwrap();

        // position 3 is null, the third index is null
        assert_arrays_eq!(actual, BoolArray::from_iter([Some(false), None, None]));
    }

    #[test]
    fn test_non_null_bool_array_take_with_null_out_of_bounds_indices() {
        let values = BoolArray::from_iter(vec![false, true, false, true, false]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = values.take(indices.into_array()).unwrap();
        // the third index is null
        assert_arrays_eq!(
            actual,
            BoolArray::from_iter([Some(false), Some(true), None])
        );
    }

    #[test]
    fn test_bool_array_take_all_null_indices() {
        let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([false, false, false]).into_array()),
        );
        let actual = values.take(indices.into_array()).unwrap();
        assert_arrays_eq!(actual, BoolArray::from_iter([None, None, None]));
    }

    #[test]
    fn test_non_null_bool_array_take_all_null_indices() {
        let values = BoolArray::from_iter(vec![false, true, false, true, false]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([false, false, false]).into_array()),
        );
        let actual = values.take(indices.into_array()).unwrap();
        assert_arrays_eq!(actual, BoolArray::from_iter([None, None, None]));
    }

    /// Arrays longer than 4096 values skip the `Vec<bool>` unpacking and read bits directly
    /// from the packed buffer. The small arrays above never reach that path.
    #[test]
    fn take_large_array() {
        let values = BoolArray::from_iter((0..5000).map(|i| i % 3 == 0));
        let actual = values
            .take(buffer![0, 1, 3, 4998, 4999].into_array())
            .unwrap();
        assert_arrays_eq!(actual, BoolArray::from_iter([true, false, true, true, false]));
    }

    /// The large-array path combined with null (and out-of-range-under-null) indices.
    #[test]
    fn take_large_array_with_null_indices() {
        let values = BoolArray::from_iter((0..5000).map(|i| i % 3 == 0));
        let indices = PrimitiveArray::new(
            buffer![4998, 100_000, 1],
            Validity::Array(BoolArray::from_iter([true, false, true]).into_array()),
        );
        let actual = values.take(indices.into_array()).unwrap();
        assert_arrays_eq!(
            actual,
            BoolArray::from_iter([Some(true), None, Some(false)])
        );
    }

    #[rstest]
    #[case(BoolArray::from_iter([true, false, true, true, false]))]
    #[case(BoolArray::from_iter([Some(true), None, Some(false), Some(true), None]))]
    #[case(BoolArray::from_iter([true, false]))]
    #[case(BoolArray::from_iter([true]))]
    fn test_take_bool_conformance(#[case] array: BoolArray) {
        test_take_conformance(&array.into_array());
    }
}