vortex_array/arrays/bool/compute/
take.rs1use arrow_buffer::BooleanBuffer;
2use itertools::Itertools as _;
3use num_traits::AsPrimitive;
4use vortex_dtype::match_each_integer_ptype;
5use vortex_error::VortexResult;
6use vortex_mask::Mask;
7use vortex_scalar::Scalar;
8
9use crate::arrays::{BoolArray, BoolVTable, ConstantArray};
10use crate::compute::{TakeKernel, TakeKernelAdapter, fill_null};
11use crate::vtable::ValidityHelper;
12use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
13
14impl TakeKernel for BoolVTable {
15 fn take(&self, array: &BoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
16 let indices_nulls_zeroed = match indices.validity_mask()? {
17 Mask::AllTrue(_) => indices.to_array(),
18 Mask::AllFalse(_) => {
19 return Ok(ConstantArray::new(
20 Scalar::null(array.dtype().as_nullable()),
21 indices.len(),
22 )
23 .into_array());
24 }
25 Mask::Values(_) => fill_null(indices, &Scalar::from(0).cast(indices.dtype())?)?,
26 };
27 let indices_nulls_zeroed = indices_nulls_zeroed.to_primitive()?;
28 let buffer = match_each_integer_ptype!(indices_nulls_zeroed.ptype(), |I| {
29 take_valid_indices(array.boolean_buffer(), indices_nulls_zeroed.as_slice::<I>())
30 });
31
32 Ok(BoolArray::new(buffer, array.validity().take(indices)?).to_array())
33 }
34}
35
36register_kernel!(TakeKernelAdapter(BoolVTable).lift());
37
38fn take_valid_indices<I: AsPrimitive<usize>>(
39 bools: &BooleanBuffer,
40 indices: &[I],
41) -> BooleanBuffer {
42 if bools.len() <= 4096 {
45 let bools = bools.into_iter().collect_vec();
46 take_byte_bool(bools, indices)
47 } else {
48 take_bool(bools, indices)
49 }
50}
51
52fn take_byte_bool<I: AsPrimitive<usize>>(bools: Vec<bool>, indices: &[I]) -> BooleanBuffer {
53 BooleanBuffer::collect_bool(indices.len(), |idx| {
54 bools[unsafe { indices.get_unchecked(idx).as_() }]
55 })
56}
57
58fn take_bool<I: AsPrimitive<usize>>(bools: &BooleanBuffer, indices: &[I]) -> BooleanBuffer {
59 BooleanBuffer::collect_bool(indices.len(), |idx| {
60 bools.value(unsafe { indices.get_unchecked(idx).as_() })
62 })
63}
64
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::arrays::BoolArray;
    use crate::arrays::primitive::PrimitiveArray;
    use crate::compute::take;
    use crate::validity::Validity;
    use crate::{Array, ToCanonical};

    /// Builds the indices `[0, 3, 100]` with the given per-slot validity.
    ///
    /// Position `100` is out of range for the five-element arrays used in these tests.
    fn indices_with_validity(valid: [bool; 3]) -> PrimitiveArray {
        PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter(valid).to_array()),
        )
    }

    #[test]
    fn take_nullable() {
        let source = BoolArray::from_iter(vec![
            Some(false),
            Some(true),
            Some(false),
            None,
            Some(false),
        ]);

        // All-valid indices; position 3 points at the null element of `source`.
        let picks = PrimitiveArray::from_iter([0, 3, 4]);
        let taken = take(source.as_ref(), picks.as_ref())
            .unwrap()
            .to_bool()
            .unwrap();
        let expected = BoolArray::from_iter([Some(false), None, Some(false)]);
        assert_eq!(taken.boolean_buffer(), expected.boolean_buffer());

        // All-null indices: the result is nullable bool and every element is null.
        let nullable_bool = DType::Bool(Nullability::Nullable);
        let null_picks = PrimitiveArray::from_option_iter([None::<u32>, None, None]);
        let all_null = take(source.as_ref(), null_picks.as_ref()).unwrap();
        assert_eq!(all_null.dtype(), &nullable_bool);
        for slot in 0..3 {
            assert_eq!(
                all_null.scalar_at(slot).unwrap(),
                Scalar::null(nullable_bool.clone())
            );
        }
    }

    #[test]
    fn test_bool_array_take_with_null_out_of_bounds_indices() {
        let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let indices = indices_with_validity([true, true, false]);

        let actual = take(values.as_ref(), indices.as_ref()).unwrap();

        assert_eq!(actual.scalar_at(0).unwrap(), Scalar::from(Some(false)));
        // Position 3 of `values` is null.
        assert_eq!(actual.scalar_at(1).unwrap(), Scalar::null_typed::<bool>());
        // The third index is itself null, so its out-of-range position is never observed.
        assert_eq!(actual.scalar_at(2).unwrap(), Scalar::null_typed::<bool>());
    }

    #[test]
    fn test_non_null_bool_array_take_with_null_out_of_bounds_indices() {
        let values = BoolArray::from_iter(vec![false, true, false, true, false]);
        let indices = indices_with_validity([true, true, false]);

        let actual = take(values.as_ref(), indices.as_ref()).unwrap();

        assert_eq!(actual.scalar_at(0).unwrap(), Scalar::from(Some(false)));
        assert_eq!(actual.scalar_at(1).unwrap(), Scalar::from(Some(true)));
        // The third index is itself null, so the result is null even though `values` has none.
        assert_eq!(actual.scalar_at(2).unwrap(), Scalar::null_typed::<bool>());
    }

    #[test]
    fn test_bool_array_take_all_null_indices() {
        let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let indices = indices_with_validity([false, false, false]);

        let actual = take(values.as_ref(), indices.as_ref()).unwrap();

        for slot in 0..3 {
            assert_eq!(
                actual.scalar_at(slot).unwrap(),
                Scalar::null_typed::<bool>()
            );
        }
    }

    #[test]
    fn test_non_null_bool_array_take_all_null_indices() {
        let values = BoolArray::from_iter(vec![false, true, false, true, false]);
        let indices = indices_with_validity([false, false, false]);

        let actual = take(values.as_ref(), indices.as_ref()).unwrap();

        for slot in 0..3 {
            assert_eq!(
                actual.scalar_at(slot).unwrap(),
                Scalar::null_typed::<bool>()
            );
        }
    }
}