vortex_array/arrays/bool/compute/
take.rs1use itertools::Itertools as _;
5use num_traits::AsPrimitive;
6use vortex_buffer::BitBuffer;
7use vortex_buffer::BitBufferView;
8use vortex_buffer::get_bit;
9use vortex_error::VortexResult;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::IntoArray;
14use crate::array::ArrayView;
15use crate::arrays::Bool;
16use crate::arrays::BoolArray;
17use crate::arrays::ConstantArray;
18use crate::arrays::PrimitiveArray;
19use crate::arrays::bool::BoolArrayExt;
20use crate::arrays::dict::TakeExecute;
21use crate::builtins::ArrayBuiltins;
22use crate::executor::ExecutionCtx;
23use crate::match_each_integer_ptype;
24use crate::scalar::Scalar;
25
26impl TakeExecute for Bool {
27 fn take(
28 array: ArrayView<'_, Bool>,
29 indices: &ArrayRef,
30 ctx: &mut ExecutionCtx,
31 ) -> VortexResult<Option<ArrayRef>> {
32 let indices_nulls_zeroed = match indices.validity()?.execute_mask(indices.len(), ctx)? {
33 Mask::AllTrue(_) => indices.clone(),
34 Mask::AllFalse(_) => {
35 return Ok(Some(
36 ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
37 .into_array(),
38 ));
39 }
40 Mask::Values(_) => indices
41 .clone()
42 .fill_null(Scalar::from(0).cast(indices.dtype())?)?,
43 };
44 let indices_nulls_zeroed = indices_nulls_zeroed.execute::<PrimitiveArray>(ctx)?;
45 let buffer = match_each_integer_ptype!(indices_nulls_zeroed.ptype(), |I| {
46 take_valid_indices(
47 array.bit_buffer_view(),
48 indices_nulls_zeroed.as_slice::<I>(),
49 )
50 });
51
52 Ok(Some(
53 BoolArray::new(buffer, array.validity()?.take(indices)?).into_array(),
54 ))
55 }
56}
57
58fn take_valid_indices<I: AsPrimitive<usize>>(bools: BitBufferView<'_>, indices: &[I]) -> BitBuffer {
59 if bools.len() <= 4096 {
62 let bools = bools.iter().collect_vec();
63 take_byte_bool(bools, indices)
64 } else {
65 take_bool_impl(bools, indices)
66 }
67}
68
69fn take_byte_bool<I: AsPrimitive<usize>>(bools: Vec<bool>, indices: &[I]) -> BitBuffer {
70 BitBuffer::collect_bool(indices.len(), |idx| {
71 bools[unsafe { indices.get_unchecked(idx).as_() }]
72 })
73}
74
75fn take_bool_impl<I: AsPrimitive<usize>>(bools: BitBufferView<'_>, indices: &[I]) -> BitBuffer {
76 let buffer = bools.inner();
78 BitBuffer::collect_bool(indices.len(), |idx| {
79 let idx = unsafe { indices.get_unchecked(idx).as_() };
81 get_bit(buffer, bools.offset() + idx)
82 })
83}
84
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray as _;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::bool::BoolArrayExt;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::validity::Validity;

    #[test]
    fn take_nullable() {
        let mut ctx = array_session().create_execution_ctx();
        let reference = BoolArray::from_iter(vec![
            Some(false),
            Some(true),
            Some(false),
            None,
            Some(false),
        ]);

        #[expect(deprecated)]
        let b = reference
            .take(buffer![0, 3, 4].into_array())
            .unwrap()
            .to_bool();
        assert_eq!(
            b.to_bit_buffer(),
            BoolArray::from_iter([Some(false), None, Some(false)]).to_bit_buffer()
        );

        // The bit-buffer comparison above ignores validity. Also compare the whole array so a
        // lost or spurious null at position 1 is caught.
        let taken = reference.take(buffer![0, 3, 4].into_array()).unwrap();
        assert_arrays_eq!(
            taken,
            BoolArray::from_iter([Some(false), None, Some(false)]),
            &mut ctx
        );

        let all_invalid_indices = PrimitiveArray::from_option_iter([None::<i32>, None, None]);
        let b = reference.take(all_invalid_indices.into_array()).unwrap();
        assert_arrays_eq!(b, BoolArray::from_iter([None, None, None]), &mut ctx);
    }

    #[test]
    fn test_bool_array_take_with_null_out_of_bounds_indices() {
        let mut ctx = array_session().create_execution_ctx();
        let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = values.take(indices.into_array()).unwrap();

        assert_arrays_eq!(
            actual,
            BoolArray::from_iter([Some(false), None, None]),
            &mut ctx
        );
    }

    #[test]
    fn test_non_null_bool_array_take_with_null_out_of_bounds_indices() {
        let mut ctx = array_session().create_execution_ctx();
        let values = BoolArray::from_iter(vec![false, true, false, true, false]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
        );
        let actual = values.take(indices.into_array()).unwrap();
        assert_arrays_eq!(
            actual,
            BoolArray::from_iter([Some(false), Some(true), None]),
            &mut ctx
        );
    }

    #[test]
    fn test_bool_array_take_all_null_indices() {
        let mut ctx = array_session().create_execution_ctx();
        let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([false, false, false]).into_array()),
        );
        let actual = values.take(indices.into_array()).unwrap();
        assert_arrays_eq!(actual, BoolArray::from_iter([None, None, None]), &mut ctx);
    }

    #[test]
    fn test_non_null_bool_array_take_all_null_indices() {
        let mut ctx = array_session().create_execution_ctx();
        let values = BoolArray::from_iter(vec![false, true, false, true, false]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([false, false, false]).into_array()),
        );
        let actual = values.take(indices.into_array()).unwrap();
        assert_arrays_eq!(actual, BoolArray::from_iter([None, None, None]), &mut ctx);
    }

    /// Arrays longer than the 4096-bit cutoff use the packed-bit gather path, which the small
    /// fixtures in the other tests never reach.
    #[test]
    fn take_large_non_nullable() {
        let mut ctx = array_session().create_execution_ctx();
        let values = BoolArray::from_iter((0..5000).map(|i| i % 3 == 0));
        let actual = values
            .take(buffer![0, 4999, 3, 1, 2].into_array())
            .unwrap();
        assert_arrays_eq!(
            actual,
            BoolArray::from_iter([true, false, true, false, false]),
            &mut ctx
        );
    }

    /// Large nullable values combined with null (and out-of-range) indices on the
    /// packed-bit path.
    #[test]
    fn take_large_nullable_with_null_indices() {
        let mut ctx = array_session().create_execution_ctx();
        let values = BoolArray::from_iter((0..5000).map(|i| (i % 7 != 0).then_some(i % 2 == 0)));
        let indices = PrimitiveArray::new(
            buffer![7, 8, 100_000, 4999],
            Validity::Array(BoolArray::from_iter([true, true, false, true]).into_array()),
        );
        let actual = values.take(indices.into_array()).unwrap();
        assert_arrays_eq!(
            actual,
            BoolArray::from_iter([None, Some(true), None, Some(false)]),
            &mut ctx
        );
    }

    #[rstest]
    #[case(BoolArray::from_iter([true, false, true, true, false]))]
    #[case(BoolArray::from_iter([Some(true), None, Some(false), Some(true), None]))]
    #[case(BoolArray::from_iter([true, false]))]
    #[case(BoolArray::from_iter([true]))]
    fn test_take_bool_conformance(#[case] array: BoolArray) {
        test_take_conformance(&array.into_array());
    }
}