Skip to main content

vortex_array/arrays/bool/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use itertools::Itertools as _;
5use num_traits::AsPrimitive;
6use vortex_buffer::BitBuffer;
7use vortex_buffer::BitBufferView;
8use vortex_buffer::get_bit;
9use vortex_error::VortexResult;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::IntoArray;
14use crate::array::ArrayView;
15use crate::arrays::Bool;
16use crate::arrays::BoolArray;
17use crate::arrays::ConstantArray;
18use crate::arrays::PrimitiveArray;
19use crate::arrays::bool::BoolArrayExt;
20use crate::arrays::dict::TakeExecute;
21use crate::builtins::ArrayBuiltins;
22use crate::executor::ExecutionCtx;
23use crate::match_each_integer_ptype;
24use crate::scalar::Scalar;
25
26impl TakeExecute for Bool {
27    fn take(
28        array: ArrayView<'_, Bool>,
29        indices: &ArrayRef,
30        ctx: &mut ExecutionCtx,
31    ) -> VortexResult<Option<ArrayRef>> {
32        let indices_nulls_zeroed = match indices.validity()?.execute_mask(indices.len(), ctx)? {
33            Mask::AllTrue(_) => indices.clone(),
34            Mask::AllFalse(_) => {
35                return Ok(Some(
36                    ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
37                        .into_array(),
38                ));
39            }
40            Mask::Values(_) => indices
41                .clone()
42                .fill_null(Scalar::from(0).cast(indices.dtype())?)?,
43        };
44        let indices_nulls_zeroed = indices_nulls_zeroed.execute::<PrimitiveArray>(ctx)?;
45        let buffer = match_each_integer_ptype!(indices_nulls_zeroed.ptype(), |I| {
46            take_valid_indices(
47                array.bit_buffer_view(),
48                indices_nulls_zeroed.as_slice::<I>(),
49            )
50        });
51
52        Ok(Some(
53            BoolArray::new(buffer, array.validity()?.take(indices)?).into_array(),
54        ))
55    }
56}
57
58fn take_valid_indices<I: AsPrimitive<usize>>(bools: BitBufferView<'_>, indices: &[I]) -> BitBuffer {
59    // For boolean arrays that roughly fit into a single page (at least, on Linux), it's worth
60    // the overhead to convert to a Vec<bool>.
61    if bools.len() <= 4096 {
62        let bools = bools.iter().collect_vec();
63        take_byte_bool(bools, indices)
64    } else {
65        take_bool_impl(bools, indices)
66    }
67}
68
69fn take_byte_bool<I: AsPrimitive<usize>>(bools: Vec<bool>, indices: &[I]) -> BitBuffer {
70    BitBuffer::collect_bool(indices.len(), |idx| {
71        bools[unsafe { indices.get_unchecked(idx).as_() }]
72    })
73}
74
75fn take_bool_impl<I: AsPrimitive<usize>>(bools: BitBufferView<'_>, indices: &[I]) -> BitBuffer {
76    // We dereference to underlying buffer to avoid access cost on every index.
77    let buffer = bools.inner();
78    BitBuffer::collect_bool(indices.len(), |idx| {
79        // SAFETY: we can take from the indices unchecked since collect_bool just iterates len.
80        let idx = unsafe { indices.get_unchecked(idx).as_() };
81        get_bit(buffer, bools.offset() + idx)
82    })
83}
84
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray as _;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::bool::BoolArrayExt;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::validity::Validity;

    /// Builds the index array `[0, 3, 100]` whose per-position validity is `valid`.
    ///
    /// `100` is past the end of the five-element fixtures used below; every caller marks
    /// that position as invalid.
    fn indices_0_3_100(valid: [bool; 3]) -> PrimitiveArray {
        PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter(valid).into_array()),
        )
    }

    /// Taking from a nullable source, first with plain indices and then with all-null ones.
    #[test]
    fn take_nullable() {
        let mut exec_ctx = array_session().create_execution_ctx();
        let source = BoolArray::from_iter([
            Some(false),
            Some(true),
            Some(false),
            None,
            Some(false),
        ]);

        // Only the bit buffers are compared for this first take.
        #[expect(deprecated)]
        let taken = source
            .take(buffer![0, 3, 4].into_array())
            .unwrap()
            .to_bool();
        assert_eq!(
            taken.to_bit_buffer(),
            BoolArray::from_iter([Some(false), None, Some(false)]).to_bit_buffer()
        );

        let null_indices = PrimitiveArray::from_option_iter([None::<i32>, None, None]);
        let all_null = source.take(null_indices.into_array()).unwrap();
        assert_arrays_eq!(
            all_null,
            BoolArray::from_iter([None, None, None]),
            &mut exec_ctx
        );
    }

    /// A null index carrying an out-of-bounds value yields null from a nullable source.
    #[test]
    fn test_bool_array_take_with_null_out_of_bounds_indices() {
        let mut exec_ctx = array_session().create_execution_ctx();
        let source = BoolArray::from_iter([Some(false), Some(true), None, None, Some(false)]);
        let picks = indices_0_3_100([true, true, false]);
        let taken = source.take(picks.into_array()).unwrap();

        // Source position 3 is null, and the third index is itself null.
        assert_arrays_eq!(
            taken,
            BoolArray::from_iter([Some(false), None, None]),
            &mut exec_ctx
        );
    }

    /// A null index carrying an out-of-bounds value yields null from a non-nullable source.
    #[test]
    fn test_non_null_bool_array_take_with_null_out_of_bounds_indices() {
        let mut exec_ctx = array_session().create_execution_ctx();
        let source = BoolArray::from_iter([false, true, false, true, false]);
        let picks = indices_0_3_100([true, true, false]);
        let taken = source.take(picks.into_array()).unwrap();

        // Only the third index is null.
        assert_arrays_eq!(
            taken,
            BoolArray::from_iter([Some(false), Some(true), None]),
            &mut exec_ctx
        );
    }

    /// All-null indices (given as an explicit validity array) over a nullable source.
    #[test]
    fn test_bool_array_take_all_null_indices() {
        let mut exec_ctx = array_session().create_execution_ctx();
        let source = BoolArray::from_iter([Some(false), Some(true), None, None, Some(false)]);
        let picks = indices_0_3_100([false, false, false]);
        let taken = source.take(picks.into_array()).unwrap();
        assert_arrays_eq!(
            taken,
            BoolArray::from_iter([None, None, None]),
            &mut exec_ctx
        );
    }

    /// All-null indices (given as an explicit validity array) over a non-nullable source.
    #[test]
    fn test_non_null_bool_array_take_all_null_indices() {
        let mut exec_ctx = array_session().create_execution_ctx();
        let source = BoolArray::from_iter([false, true, false, true, false]);
        let picks = indices_0_3_100([false, false, false]);
        let taken = source.take(picks.into_array()).unwrap();
        assert_arrays_eq!(
            taken,
            BoolArray::from_iter([None, None, None]),
            &mut exec_ctx
        );
    }

    /// Runs the shared take conformance suite over a few small boolean arrays.
    #[rstest]
    #[case(BoolArray::from_iter([true, false, true, true, false]))]
    #[case(BoolArray::from_iter([Some(true), None, Some(false), Some(true), None]))]
    #[case(BoolArray::from_iter([true, false]))]
    #[case(BoolArray::from_iter([true]))]
    fn test_take_bool_conformance(#[case] array: BoolArray) {
        let array = array.into_array();
        test_take_conformance(&array);
    }
}