vortex_array/arrays/bool/compute/
take.rs1use arrow_buffer::BooleanBuffer;
5use itertools::Itertools as _;
6use num_traits::AsPrimitive;
7use vortex_dtype::match_each_integer_ptype;
8use vortex_error::VortexResult;
9use vortex_mask::Mask;
10use vortex_scalar::Scalar;
11
12use crate::arrays::{BoolArray, BoolVTable, ConstantArray};
13use crate::compute::{TakeKernel, TakeKernelAdapter, fill_null};
14use crate::vtable::ValidityHelper;
15use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
16
17impl TakeKernel for BoolVTable {
18 fn take(&self, array: &BoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
19 let indices_nulls_zeroed = match indices.validity_mask()? {
20 Mask::AllTrue(_) => indices.to_array(),
21 Mask::AllFalse(_) => {
22 return Ok(ConstantArray::new(
23 Scalar::null(array.dtype().as_nullable()),
24 indices.len(),
25 )
26 .into_array());
27 }
28 Mask::Values(_) => fill_null(indices, &Scalar::from(0).cast(indices.dtype())?)?,
29 };
30 let indices_nulls_zeroed = indices_nulls_zeroed.to_primitive()?;
31 let buffer = match_each_integer_ptype!(indices_nulls_zeroed.ptype(), |I| {
32 take_valid_indices(array.boolean_buffer(), indices_nulls_zeroed.as_slice::<I>())
33 });
34
35 Ok(BoolArray::new(buffer, array.validity().take(indices)?).to_array())
36 }
37}
38
// Registers the `TakeKernel` implementation above for `BoolVTable`.
// NOTE(review): `TakeKernelAdapter(..).lift()` presumably adapts the typed kernel into the form the
// generic `take` compute entry point dispatches on — confirm against the `register_kernel!` and
// `TakeKernelAdapter` definitions.
register_kernel!(TakeKernelAdapter(BoolVTable).lift());
40
41fn take_valid_indices<I: AsPrimitive<usize>>(
42 bools: &BooleanBuffer,
43 indices: &[I],
44) -> BooleanBuffer {
45 if bools.len() <= 4096 {
48 let bools = bools.into_iter().collect_vec();
49 take_byte_bool(bools, indices)
50 } else {
51 take_bool(bools, indices)
52 }
53}
54
55fn take_byte_bool<I: AsPrimitive<usize>>(bools: Vec<bool>, indices: &[I]) -> BooleanBuffer {
56 BooleanBuffer::collect_bool(indices.len(), |idx| {
57 bools[unsafe { indices.get_unchecked(idx).as_() }]
58 })
59}
60
61fn take_bool<I: AsPrimitive<usize>>(bools: &BooleanBuffer, indices: &[I]) -> BooleanBuffer {
62 BooleanBuffer::collect_bool(indices.len(), |idx| {
63 bools.value(unsafe { indices.get_unchecked(idx).as_() })
65 })
66}
67
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::arrays::BoolArray;
    use crate::arrays::primitive::PrimitiveArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, ToCanonical};

    /// Taking from a nullable array keeps both the values and the nulls of the selected slots,
    /// and all-null indices yield an all-null result.
    #[test]
    fn take_nullable() {
        let nullable_bool_dtype = DType::Bool(Nullability::Nullable);
        let reference = BoolArray::from_iter(vec![
            Some(false),
            Some(true),
            Some(false),
            None,
            Some(false),
        ]);

        let taken = take(
            reference.as_ref(),
            PrimitiveArray::from_iter([0, 3, 4]).as_ref(),
        )
        .unwrap();
        // The value buffer alone cannot distinguish `None` from `Some(false)`, so check the
        // per-slot scalars (which include validity) as well.
        assert_eq!(taken.scalar_at(0), Scalar::from(Some(false)));
        assert_eq!(taken.scalar_at(1), Scalar::null(nullable_bool_dtype.clone()));
        assert_eq!(taken.scalar_at(2), Scalar::from(Some(false)));

        let b = taken.to_bool().unwrap();
        assert_eq!(
            b.boolean_buffer(),
            BoolArray::from_iter([Some(false), None, Some(false)]).boolean_buffer()
        );

        let all_invalid_indices = PrimitiveArray::from_option_iter([None::<u32>, None, None]);
        let b = take(reference.as_ref(), all_invalid_indices.as_ref()).unwrap();
        assert_eq!(b.dtype(), &nullable_bool_dtype);
        assert_eq!(b.scalar_at(0), Scalar::null(nullable_bool_dtype.clone()));
        assert_eq!(b.scalar_at(1), Scalar::null(nullable_bool_dtype.clone()));
        assert_eq!(b.scalar_at(2), Scalar::null(nullable_bool_dtype));
    }

    /// Sources longer than the 4096-value byte-unpacking threshold in `take_valid_indices` are
    /// gathered straight from the bit-packed buffer. Check that this path returns the right values
    /// and handles null (out-of-bounds) indices.
    #[test]
    fn take_from_large_array() {
        let values = BoolArray::from_iter((0..5_000).map(|i| i % 3 == 0));

        let positions = [0u32, 1, 3, 4_096, 4_097, 4_998, 4_999];
        let actual = take(
            values.as_ref(),
            PrimitiveArray::from_iter(positions).as_ref(),
        )
        .unwrap()
        .to_bool()
        .unwrap();
        let expected: Vec<bool> = positions.iter().map(|&i| i % 3 == 0).collect();
        assert_eq!(actual.boolean_buffer().iter().collect::<Vec<_>>(), expected);

        let nullable_indices = PrimitiveArray::new(
            buffer![4_999u32, 100_000],
            Validity::Array(BoolArray::from_iter([true, false]).to_array()),
        );
        let actual = take(values.as_ref(), nullable_indices.as_ref()).unwrap();
        assert_eq!(actual.scalar_at(0), Scalar::from(Some(false)));
        assert_eq!(actual.scalar_at(1), Scalar::null_typed::<bool>());
    }

    #[test]
    fn test_bool_array_take_with_null_out_of_bounds_indices() {
        let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([true, true, false]).to_array()),
        );
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(actual.scalar_at(0), Scalar::from(Some(false)));
        assert_eq!(actual.scalar_at(1), Scalar::null_typed::<bool>());
        assert_eq!(actual.scalar_at(2), Scalar::null_typed::<bool>());
    }

    #[test]
    fn test_non_null_bool_array_take_with_null_out_of_bounds_indices() {
        let values = BoolArray::from_iter(vec![false, true, false, true, false]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([true, true, false]).to_array()),
        );
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(actual.scalar_at(0), Scalar::from(Some(false)));
        assert_eq!(actual.scalar_at(1), Scalar::from(Some(true)));
        assert_eq!(actual.scalar_at(2), Scalar::null_typed::<bool>());
    }

    #[test]
    fn test_bool_array_take_all_null_indices() {
        let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([false, false, false]).to_array()),
        );
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(actual.scalar_at(0), Scalar::null_typed::<bool>());
        assert_eq!(actual.scalar_at(1), Scalar::null_typed::<bool>());
        assert_eq!(actual.scalar_at(2), Scalar::null_typed::<bool>());
    }

    #[test]
    fn test_non_null_bool_array_take_all_null_indices() {
        let values = BoolArray::from_iter(vec![false, true, false, true, false]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([false, false, false]).to_array()),
        );
        let actual = take(values.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(actual.scalar_at(0), Scalar::null_typed::<bool>());
        assert_eq!(actual.scalar_at(1), Scalar::null_typed::<bool>());
        assert_eq!(actual.scalar_at(2), Scalar::null_typed::<bool>());
    }

    #[rstest]
    #[case(BoolArray::from_iter([true, false, true, true, false]))]
    #[case(BoolArray::from_iter([Some(true), None, Some(false), Some(true), None]))]
    #[case(BoolArray::from_iter([true, false]))]
    #[case(BoolArray::from_iter([true]))]
    fn test_take_bool_conformance(#[case] array: BoolArray) {
        test_take_conformance(array.as_ref());
    }
}