vortex_array/arrays/bool/compute/
take.rs1use itertools::Itertools as _;
5use num_traits::AsPrimitive;
6use vortex_buffer::BitBuffer;
7use vortex_buffer::get_bit;
8use vortex_dtype::match_each_integer_ptype;
9use vortex_error::VortexResult;
10use vortex_mask::Mask;
11
12use crate::Array;
13use crate::ArrayRef;
14use crate::IntoArray;
15use crate::ToCanonical;
16use crate::arrays::BoolArray;
17use crate::arrays::BoolVTable;
18use crate::arrays::ConstantArray;
19use crate::arrays::TakeExecute;
20use crate::builtins::ArrayBuiltins;
21use crate::executor::ExecutionCtx;
22use crate::scalar::Scalar;
23use crate::vtable::ValidityHelper;
24
25impl TakeExecute for BoolVTable {
26 fn take(
27 array: &BoolArray,
28 indices: &dyn Array,
29 _ctx: &mut ExecutionCtx,
30 ) -> VortexResult<Option<ArrayRef>> {
31 let indices_nulls_zeroed = match indices.validity_mask()? {
32 Mask::AllTrue(_) => indices.to_array(),
33 Mask::AllFalse(_) => {
34 return Ok(Some(
35 ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
36 .into_array(),
37 ));
38 }
39 Mask::Values(_) => indices
40 .to_array()
41 .fill_null(Scalar::from(0).cast(indices.dtype())?)?,
42 };
43 let indices_nulls_zeroed = indices_nulls_zeroed.to_primitive();
44 let buffer = match_each_integer_ptype!(indices_nulls_zeroed.ptype(), |I| {
45 take_valid_indices(&array.to_bit_buffer(), indices_nulls_zeroed.as_slice::<I>())
46 });
47
48 Ok(Some(
49 BoolArray::new(buffer, array.validity().take(indices)?).to_array(),
50 ))
51 }
52}
53
54fn take_valid_indices<I: AsPrimitive<usize>>(bools: &BitBuffer, indices: &[I]) -> BitBuffer {
55 if bools.len() <= 4096 {
58 let bools = bools.iter().collect_vec();
59 take_byte_bool(bools, indices)
60 } else {
61 take_bool_impl(bools, indices)
62 }
63}
64
65fn take_byte_bool<I: AsPrimitive<usize>>(bools: Vec<bool>, indices: &[I]) -> BitBuffer {
66 BitBuffer::collect_bool(indices.len(), |idx| {
67 bools[unsafe { indices.get_unchecked(idx).as_() }]
68 })
69}
70
71fn take_bool_impl<I: AsPrimitive<usize>>(bools: &BitBuffer, indices: &[I]) -> BitBuffer {
72 let buffer = bools.inner().as_ref();
74 BitBuffer::collect_bool(indices.len(), |idx| {
75 let idx = unsafe { indices.get_unchecked(idx).as_() };
77 get_bit(buffer, bools.offset() + idx)
78 })
79}
80
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::Array;
    use crate::IntoArray as _;
    use crate::ToCanonical;
    use crate::arrays::BoolArray;
    use crate::arrays::primitive::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::validity::Validity;

    #[test]
    fn take_nullable() {
        let reference = BoolArray::from_iter(vec![
            Some(false),
            Some(true),
            Some(false),
            None,
            Some(false),
        ]);

        // Compare whole arrays (values *and* validity); comparing only the value bits would
        // not notice a regression in how validity is carried through `take`.
        let b = reference
            .take(buffer![0, 3, 4].into_array())
            .unwrap()
            .to_bool();
        assert_arrays_eq!(b, BoolArray::from_iter([Some(false), None, Some(false)]));

        let all_invalid_indices = PrimitiveArray::from_option_iter([None::<i32>, None, None]);
        let b = reference.take(all_invalid_indices.to_array()).unwrap();
        assert_arrays_eq!(b, BoolArray::from_iter([None, None, None]));
    }

    #[test]
    fn take_from_large_array() {
        // More than 4096 values, so the gather reads bits directly from the packed buffer
        // rather than from an unpacked lookup table.
        let values = BoolArray::from_iter((0..5000).map(|i| i % 3 == 0));
        let actual = values
            .take(buffer![4998u32, 1, 4099, 0].into_array())
            .unwrap();
        assert_arrays_eq!(actual, BoolArray::from_iter([true, false, false, true]));

        let nullable = BoolArray::from_iter((0..5000).map(|i| (i % 5 != 0).then_some(i % 2 == 0)));
        let actual = nullable
            .take(buffer![5u32, 6, 4097].into_array())
            .unwrap();
        assert_arrays_eq!(
            actual,
            BoolArray::from_iter([None, Some(true), Some(false)])
        );
    }

    #[test]
    fn test_bool_array_take_with_null_out_of_bounds_indices() {
        let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([true, true, false]).to_array()),
        );
        let actual = values.take(indices.to_array()).unwrap();

        assert_arrays_eq!(actual, BoolArray::from_iter([Some(false), None, None]));
    }

    #[test]
    fn test_non_null_bool_array_take_with_null_out_of_bounds_indices() {
        let values = BoolArray::from_iter(vec![false, true, false, true, false]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([true, true, false]).to_array()),
        );
        let actual = values.take(indices.to_array()).unwrap();
        assert_arrays_eq!(
            actual,
            BoolArray::from_iter([Some(false), Some(true), None])
        );
    }

    #[test]
    fn test_bool_array_take_all_null_indices() {
        let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([false, false, false]).to_array()),
        );
        let actual = values.take(indices.to_array()).unwrap();
        assert_arrays_eq!(actual, BoolArray::from_iter([None, None, None]));
    }

    #[test]
    fn test_non_null_bool_array_take_all_null_indices() {
        let values = BoolArray::from_iter(vec![false, true, false, true, false]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([false, false, false]).to_array()),
        );
        let actual = values.take(indices.to_array()).unwrap();
        assert_arrays_eq!(actual, BoolArray::from_iter([None, None, None]));
    }

    #[rstest]
    #[case(BoolArray::from_iter([true, false, true, true, false]))]
    #[case(BoolArray::from_iter([Some(true), None, Some(false), Some(true), None]))]
    #[case(BoolArray::from_iter([true, false]))]
    #[case(BoolArray::from_iter([true]))]
    fn test_take_bool_conformance(#[case] array: BoolArray) {
        test_take_conformance(array.as_ref());
    }
}