vortex_array/arrays/bool/compute/
take.rs1use itertools::Itertools as _;
5use num_traits::AsPrimitive;
6use vortex_buffer::BitBuffer;
7use vortex_buffer::BitBufferView;
8use vortex_buffer::get_bit;
9use vortex_error::VortexResult;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::IntoArray;
14use crate::array::ArrayView;
15use crate::arrays::Bool;
16use crate::arrays::BoolArray;
17use crate::arrays::ConstantArray;
18use crate::arrays::PrimitiveArray;
19use crate::arrays::bool::BoolArrayExt;
20use crate::arrays::dict::TakeExecute;
21use crate::builtins::ArrayBuiltins;
22use crate::executor::ExecutionCtx;
23use crate::match_each_integer_ptype;
24use crate::scalar::Scalar;
25
26impl TakeExecute for Bool {
27 fn take(
28 array: ArrayView<'_, Bool>,
29 indices: &ArrayRef,
30 ctx: &mut ExecutionCtx,
31 ) -> VortexResult<Option<ArrayRef>> {
32 let indices_nulls_zeroed = match indices.validity()?.execute_mask(indices.len(), ctx)? {
33 Mask::AllTrue(_) => indices.clone(),
34 Mask::AllFalse(_) => {
35 return Ok(Some(
36 ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
37 .into_array(),
38 ));
39 }
40 Mask::Values(_) => indices
41 .clone()
42 .fill_null(Scalar::from(0).cast(indices.dtype())?)?,
43 };
44 let indices_nulls_zeroed = indices_nulls_zeroed.execute::<PrimitiveArray>(ctx)?;
45 let buffer = match_each_integer_ptype!(indices_nulls_zeroed.ptype(), |I| {
46 take_valid_indices(
47 array.bit_buffer_view(),
48 indices_nulls_zeroed.as_slice::<I>(),
49 )
50 });
51
52 Ok(Some(
53 BoolArray::new(buffer, array.validity()?.take(indices)?).into_array(),
54 ))
55 }
56}
57
58fn take_valid_indices<I: AsPrimitive<usize>>(bools: BitBufferView<'_>, indices: &[I]) -> BitBuffer {
59 if bools.len() <= 4096 {
62 let bools = bools.iter().collect_vec();
63 take_byte_bool(bools, indices)
64 } else {
65 take_bool_impl(bools, indices)
66 }
67}
68
69fn take_byte_bool<I: AsPrimitive<usize>>(bools: Vec<bool>, indices: &[I]) -> BitBuffer {
70 BitBuffer::collect_bool(indices.len(), |idx| {
71 bools[unsafe { indices.get_unchecked(idx).as_() }]
72 })
73}
74
75fn take_bool_impl<I: AsPrimitive<usize>>(bools: BitBufferView<'_>, indices: &[I]) -> BitBuffer {
76 let buffer = bools.inner();
78 BitBuffer::collect_bool(indices.len(), |idx| {
79 let idx = unsafe { indices.get_unchecked(idx).as_() };
81 get_bit(buffer, bools.offset() + idx)
82 })
83}
84
#[cfg(test)]
mod test {
    //! Tests for the `Bool` take kernel: null values in the source, null (including
    //! out-of-bounds) indices, all-null indices, and the shared take conformance suite.

    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray as _;
    // `ToCanonical` is deprecated; the expectation keeps the import warning-free.
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::bool::BoolArrayExt;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::validity::Validity;

    /// Taking from a nullable source with non-null indices, then with all-null indices.
    #[test]
    fn take_nullable() {
        let reference = BoolArray::from_iter(vec![
            Some(false),
            Some(true),
            Some(false),
            None,
            Some(false),
        ]);

        // Index 3 points at the null source value. Only the bit buffers are compared here,
        // so this assertion checks the gathered bits, not the validity.
        // `to_bool` comes from the deprecated `ToCanonical` trait.
        #[expect(deprecated)]
        let b = reference
            .take(buffer![0, 3, 4].into_array())
            .unwrap()
            .to_bool();
        assert_eq!(
            b.to_bit_buffer(),
            BoolArray::from_iter([Some(false), None, Some(false)]).to_bit_buffer()
        );

        // When every index is null, every output slot must be null.
        let all_invalid_indices = PrimitiveArray::from_option_iter([None::<i32>, None, None]);
        let b = reference.take(all_invalid_indices.into_array()).unwrap();
        assert_arrays_eq!(b, BoolArray::from_iter([None, None, None]));
    }

    /// A null index whose payload (100) is out of bounds must not be dereferenced and must
    /// produce a null slot; index 3 hits a null source value and also yields null.
    #[test]
    fn test_bool_array_take_with_null_out_of_bounds_indices() {
        let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = values.take(indices.into_array()).unwrap();

        assert_arrays_eq!(actual, BoolArray::from_iter([Some(false), None, None]));
    }

    /// Same as above with a non-nullable source: the result is still nullable because the
    /// indices are, and only the null index produces a null slot.
    #[test]
    fn test_non_null_bool_array_take_with_null_out_of_bounds_indices() {
        let values = BoolArray::from_iter(vec![false, true, false, true, false]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = values.take(indices.into_array()).unwrap();
        assert_arrays_eq!(
            actual,
            BoolArray::from_iter([Some(false), Some(true), None])
        );
    }

    /// All indices null (one of them out of bounds) on a nullable source: all-null output.
    #[test]
    fn test_bool_array_take_all_null_indices() {
        let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([false, false, false]).into_array()),
        );
        let actual = values.take(indices.into_array()).unwrap();
        assert_arrays_eq!(actual, BoolArray::from_iter([None, None, None]));
    }

    /// All indices null on a non-nullable source: all-null output.
    #[test]
    fn test_non_null_bool_array_take_all_null_indices() {
        let values = BoolArray::from_iter(vec![false, true, false, true, false]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([false, false, false]).into_array()),
        );
        let actual = values.take(indices.into_array()).unwrap();
        assert_arrays_eq!(actual, BoolArray::from_iter([None, None, None]));
    }

    /// Runs the shared take conformance suite over non-nullable, nullable, and very short
    /// boolean arrays.
    #[rstest]
    #[case(BoolArray::from_iter([true, false, true, true, false]))]
    #[case(BoolArray::from_iter([Some(true), None, Some(false), Some(true), None]))]
    #[case(BoolArray::from_iter([true, false]))]
    #[case(BoolArray::from_iter([true]))]
    fn test_take_bool_conformance(#[case] array: BoolArray) {
        test_take_conformance(&array.into_array());
    }
}