vortex_array/arrays/bool/compute/
take.rs1use arrow_buffer::BooleanBuffer;
2use itertools::Itertools;
3use num_traits::AsPrimitive;
4use vortex_dtype::match_each_integer_ptype;
5use vortex_error::VortexResult;
6use vortex_mask::Mask;
7use vortex_scalar::Scalar;
8
9use crate::arrays::{BoolArray, BoolEncoding, ConstantArray};
10use crate::builders::ArrayBuilder;
11use crate::compute::{TakeFn, fill_null};
12use crate::variants::PrimitiveArrayTrait;
13use crate::{Array, ArrayRef, ToCanonical};
14
15impl TakeFn<&BoolArray> for BoolEncoding {
16 fn take(&self, array: &BoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
17 let indices_nulls_zeroed = match indices.validity_mask()? {
18 Mask::AllTrue(_) => indices.to_array(),
19 Mask::AllFalse(_) => {
20 return Ok(ConstantArray::new(
21 Scalar::null(array.dtype().as_nullable()),
22 indices.len(),
23 )
24 .into_array());
25 }
26 Mask::Values(_) => fill_null(indices, Scalar::from(0).cast(indices.dtype())?)?,
27 };
28 let indices_nulls_zeroed = indices_nulls_zeroed.to_primitive()?;
29 let buffer = match_each_integer_ptype!(indices_nulls_zeroed.ptype(), |$I| {
30 take_valid_indices(array.boolean_buffer(), indices_nulls_zeroed.as_slice::<$I>())
31 });
32
33 Ok(BoolArray::new(buffer, array.validity().take(indices)?).into_array())
34 }
35
36 fn take_into(
37 &self,
38 array: &BoolArray,
39 indices: &dyn Array,
40 builder: &mut dyn ArrayBuilder,
41 ) -> VortexResult<()> {
42 builder.extend_from_array(&self.take(array, indices)?)
43 }
44}
45
46fn take_valid_indices<I: AsPrimitive<usize>>(
47 bools: &BooleanBuffer,
48 indices: &[I],
49) -> BooleanBuffer {
50 if bools.len() <= 4096 {
53 let bools = bools.into_iter().collect_vec();
54 take_byte_bool(bools, indices)
55 } else {
56 take_bool(bools, indices)
57 }
58}
59
60fn take_byte_bool<I: AsPrimitive<usize>>(bools: Vec<bool>, indices: &[I]) -> BooleanBuffer {
61 BooleanBuffer::collect_bool(indices.len(), |idx| {
62 bools[unsafe { indices.get_unchecked(idx).as_() }]
63 })
64}
65
66fn take_bool<I: AsPrimitive<usize>>(bools: &BooleanBuffer, indices: &[I]) -> BooleanBuffer {
67 BooleanBuffer::collect_bool(indices.len(), |idx| {
68 bools.value(unsafe { indices.get_unchecked(idx).as_() })
70 })
71}
72
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::arrays::BoolArray;
    use crate::arrays::primitive::PrimitiveArray;
    use crate::compute::{scalar_at, take};
    use crate::validity::Validity;
    use crate::{Array, ToCanonical};

    #[test]
    fn take_nullable() {
        let reference = BoolArray::from_iter(vec![
            Some(false),
            Some(true),
            Some(false),
            None,
            Some(false),
        ]);

        let b = take(&reference, &PrimitiveArray::from_iter([0, 3, 4]))
            .unwrap()
            .to_bool()
            .unwrap();
        assert_eq!(
            b.boolean_buffer(),
            BoolArray::from_iter([Some(false), None, Some(false)]).boolean_buffer()
        );
        // The value buffer alone cannot distinguish a null slot from `false`, so also check
        // validity: index 3 refers to a null value and must produce a null.
        assert_eq!(scalar_at(&b, 0).unwrap(), Scalar::from(Some(false)));
        assert_eq!(scalar_at(&b, 1).unwrap(), Scalar::null_typed::<bool>());
        assert_eq!(scalar_at(&b, 2).unwrap(), Scalar::from(Some(false)));

        let nullable_bool_dtype = DType::Bool(Nullability::Nullable);
        let all_invalid_indices = PrimitiveArray::from_option_iter([None::<u32>, None, None]);
        let b = take(&reference, &all_invalid_indices).unwrap();
        assert_eq!(b.dtype(), &nullable_bool_dtype);
        assert_eq!(
            scalar_at(&b, 0).unwrap(),
            Scalar::null(nullable_bool_dtype.clone())
        );
        assert_eq!(
            scalar_at(&b, 1).unwrap(),
            Scalar::null(nullable_bool_dtype.clone())
        );
        assert_eq!(scalar_at(&b, 2).unwrap(), Scalar::null(nullable_bool_dtype));
    }

    #[test]
    fn take_from_large_array() {
        // More than 4096 values, so the take goes through the packed-bit path instead of
        // unpacking the buffer into a `Vec<bool>` first.
        let values = BoolArray::from_iter((0..5000).map(|i| i % 3 == 0));
        let indices = PrimitiveArray::from_iter([0, 1, 4998, 4999, 3]);
        let actual = take(&values, &indices).unwrap().to_bool().unwrap();
        assert_eq!(
            actual.boolean_buffer().iter().collect::<Vec<_>>(),
            vec![true, false, true, false, true]
        );
    }

    #[test]
    fn test_bool_array_take_with_null_out_of_bounds_indices() {
        let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([true, true, false]).to_array()),
        );
        let actual = take(&values, &indices).unwrap();
        assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::from(Some(false)));
        // Position 3 of `values` is null.
        assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::null_typed::<bool>());
        // The index itself is null, so its out-of-bounds value (100) is never dereferenced.
        assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::null_typed::<bool>());
    }

    #[test]
    fn test_non_null_bool_array_take_with_null_out_of_bounds_indices() {
        let values = BoolArray::from_iter(vec![false, true, false, true, false]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([true, true, false]).to_array()),
        );
        let actual = take(&values, &indices).unwrap();
        assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::from(Some(false)));
        assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::from(Some(true)));
        // A null index on a non-nullable array still yields a null output.
        assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::null_typed::<bool>());
    }

    #[test]
    fn test_bool_array_take_all_null_indices() {
        let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([false, false, false]).to_array()),
        );
        let actual = take(&values, &indices).unwrap();
        assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::null_typed::<bool>());
        assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::null_typed::<bool>());
        assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::null_typed::<bool>());
    }

    #[test]
    fn test_non_null_bool_array_take_all_null_indices() {
        let values = BoolArray::from_iter(vec![false, true, false, true, false]);
        let indices = PrimitiveArray::new(
            buffer![0, 3, 100],
            Validity::Array(BoolArray::from_iter([false, false, false]).to_array()),
        );
        let actual = take(&values, &indices).unwrap();
        assert_eq!(scalar_at(&actual, 0).unwrap(), Scalar::null_typed::<bool>());
        assert_eq!(scalar_at(&actual, 1).unwrap(), Scalar::null_typed::<bool>());
        assert_eq!(scalar_at(&actual, 2).unwrap(), Scalar::null_typed::<bool>());
    }
}