vortex_array/arrays/bool/compute/
take.rs

1use arrow_buffer::BooleanBuffer;
2use itertools::Itertools;
3use num_traits::AsPrimitive;
4use vortex_dtype::match_each_integer_ptype;
5use vortex_error::VortexResult;
6use vortex_mask::Mask;
7use vortex_scalar::Scalar;
8
9use crate::arrays::{BoolArray, BoolEncoding, ConstantArray};
10use crate::builders::ArrayBuilder;
11use crate::compute::{TakeFn, fill_null};
12use crate::variants::PrimitiveArrayTrait;
13use crate::{Array, ArrayRef, ToCanonical};
14
15impl TakeFn<&BoolArray> for BoolEncoding {
16    fn take(&self, array: &BoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
17        let indices_nulls_zeroed = match indices.validity_mask()? {
18            Mask::AllTrue(_) => indices.to_array(),
19            Mask::AllFalse(_) => {
20                return Ok(ConstantArray::new(
21                    Scalar::null(array.dtype().as_nullable()),
22                    indices.len(),
23                )
24                .into_array());
25            }
26            Mask::Values(_) => fill_null(indices, Scalar::from(0).cast(indices.dtype())?)?,
27        };
28        let indices_nulls_zeroed = indices_nulls_zeroed.to_primitive()?;
29        let buffer = match_each_integer_ptype!(indices_nulls_zeroed.ptype(), |$I| {
30            take_valid_indices(array.boolean_buffer(), indices_nulls_zeroed.as_slice::<$I>())
31        });
32
33        Ok(BoolArray::new(buffer, array.validity().take(indices)?).into_array())
34    }
35
36    fn take_into(
37        &self,
38        array: &BoolArray,
39        indices: &dyn Array,
40        builder: &mut dyn ArrayBuilder,
41    ) -> VortexResult<()> {
42        builder.extend_from_array(&self.take(array, indices)?)
43    }
44}
45
46fn take_valid_indices<I: AsPrimitive<usize>>(
47    bools: &BooleanBuffer,
48    indices: &[I],
49) -> BooleanBuffer {
50    // For boolean arrays that roughly fit into a single page (at least, on Linux), it's worth
51    // the overhead to convert to a Vec<bool>.
52    if bools.len() <= 4096 {
53        let bools = bools.into_iter().collect_vec();
54        take_byte_bool(bools, indices)
55    } else {
56        take_bool(bools, indices)
57    }
58}
59
60fn take_byte_bool<I: AsPrimitive<usize>>(bools: Vec<bool>, indices: &[I]) -> BooleanBuffer {
61    BooleanBuffer::collect_bool(indices.len(), |idx| {
62        bools[unsafe { indices.get_unchecked(idx).as_() }]
63    })
64}
65
66fn take_bool<I: AsPrimitive<usize>>(bools: &BooleanBuffer, indices: &[I]) -> BooleanBuffer {
67    BooleanBuffer::collect_bool(indices.len(), |idx| {
68        // We can always take from the indices unchecked since collect_bool just iterates len.
69        bools.value(unsafe { indices.get_unchecked(idx).as_() })
70    })
71}
72
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::arrays::BoolArray;
    use crate::arrays::primitive::PrimitiveArray;
    use crate::compute::{scalar_at, take};
    use crate::validity::Validity;
    use crate::{Array, ToCanonical};

    /// Taking from a nullable bool array, first with all-valid indices and then with
    /// all-null indices.
    #[test]
    fn take_nullable() {
        let reference = BoolArray::from_iter(vec![
            Some(false),
            Some(true),
            Some(false),
            None,
            Some(false),
        ]);

        let taken = take(&reference, &PrimitiveArray::from_iter([0, 3, 4])).unwrap();
        // Check each element including its nullness: the buffer comparison further down
        // only looks at the value bits and says nothing about validity.
        assert_eq!(scalar_at(&taken, 0).unwrap(), Scalar::from(Some(false)));
        // position 3 of the reference is null
        assert_eq!(scalar_at(&taken, 1).unwrap(), Scalar::null_typed::<bool>());
        assert_eq!(scalar_at(&taken, 2).unwrap(), Scalar::from(Some(false)));

        let b = taken.to_bool().unwrap();
        assert_eq!(
            b.boolean_buffer(),
            BoolArray::from_iter([Some(false), None, Some(false)]).boolean_buffer()
        );

        // With every index null, the result must be a nullable bool array of nulls.
        let nullable_bool_dtype = DType::Bool(Nullability::Nullable);
        let all_invalid_indices = PrimitiveArray::from_option_iter([None::<u32>, None, None]);
        let b = take(&reference, &all_invalid_indices).unwrap();
        assert_eq!(b.dtype(), &nullable_bool_dtype);
        assert_eq!(
            scalar_at(&b, 0).unwrap(),
            Scalar::null(nullable_bool_dtype.clone())
        );
        assert_eq!(
            scalar_at(&b, 1).unwrap(),
            Scalar::null(nullable_bool_dtype.clone())
        );
        assert_eq!(scalar_at(&b, 2).unwrap(), Scalar::null(nullable_bool_dtype));
    }

    /// A null index whose stored value (100) is past the end of a nullable array must
    /// yield a null instead of failing.
    #[test]
    fn test_bool_array_take_with_null_out_of_bounds_indices() {
        let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([true, true, false]).to_array()),
        );
        let actual = take(&values, &indices).unwrap();
        assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::from(Some(false)));
        // position 3 is null
        assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::null_typed::<bool>());
        // the third index is null
        assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::null_typed::<bool>());
    }

    /// Same as above, but the source array has no nulls of its own.
    #[test]
    fn test_non_null_bool_array_take_with_null_out_of_bounds_indices() {
        let values = BoolArray::from_iter(vec![false, true, false, true, false]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([true, true, false]).to_array()),
        );
        let actual = take(&values, &indices).unwrap();
        assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::from(Some(false)));
        assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::from(Some(true)));
        // the third index is null
        assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::null_typed::<bool>());
    }

    /// Indices whose validity array marks every entry as null produce only nulls from a
    /// nullable source.
    #[test]
    fn test_bool_array_take_all_null_indices() {
        let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([false, false, false]).to_array()),
        );
        let actual = take(&values, &indices).unwrap();
        assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::null_typed::<bool>());
        assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::null_typed::<bool>());
        assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::null_typed::<bool>());
    }

    /// Indices whose validity array marks every entry as null produce only nulls even
    /// when the source array itself is non-nullable.
    #[test]
    fn test_non_null_bool_array_take_all_null_indices() {
        let values = BoolArray::from_iter(vec![false, true, false, true, false]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([false, false, false]).to_array()),
        );
        let actual = take(&values, &indices).unwrap();
        assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::null_typed::<bool>());
        assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::null_typed::<bool>());
        assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::null_typed::<bool>());
    }
}