vortex_bytebool/
compute.rs1use num_traits::AsPrimitive;
2use vortex_array::compute::{FillForwardFn, MaskFn, ScalarAtFn, SliceFn, TakeFn};
3use vortex_array::validity::Validity;
4use vortex_array::variants::PrimitiveArrayTrait;
5use vortex_array::vtable::ComputeVTable;
6use vortex_array::{Array, ArrayComputeImpl, ArrayRef, ToCanonical};
7use vortex_dtype::{Nullability, match_each_integer_ptype};
8use vortex_error::{VortexResult, vortex_err};
9use vortex_mask::Mask;
10use vortex_scalar::Scalar;
11
12use super::{ByteBoolArray, ByteBoolEncoding};
13
14impl ArrayComputeImpl for ByteBoolArray {}
15
16impl ComputeVTable for ByteBoolEncoding {
17 fn fill_forward_fn(&self) -> Option<&dyn FillForwardFn<&dyn Array>> {
18 None
19 }
20
21 fn mask_fn(&self) -> Option<&dyn MaskFn<&dyn Array>> {
22 Some(self)
23 }
24
25 fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
26 Some(self)
27 }
28
29 fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
30 Some(self)
31 }
32
33 fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
34 Some(self)
35 }
36}
37
38impl MaskFn<&ByteBoolArray> for ByteBoolEncoding {
39 fn mask(&self, array: &ByteBoolArray, mask: Mask) -> VortexResult<ArrayRef> {
40 Ok(ByteBoolArray::new(array.buffer().clone(), array.validity().mask(&mask)?).into_array())
41 }
42}
43
44impl ScalarAtFn<&ByteBoolArray> for ByteBoolEncoding {
45 fn scalar_at(&self, array: &ByteBoolArray, index: usize) -> VortexResult<Scalar> {
46 Ok(Scalar::bool(
47 array.buffer()[index] == 1,
48 array.dtype().nullability(),
49 ))
50 }
51}
52
53impl SliceFn<&ByteBoolArray> for ByteBoolEncoding {
54 fn slice(&self, array: &ByteBoolArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
55 Ok(ByteBoolArray::new(
56 array.buffer().slice(start..stop),
57 array.validity().slice(start, stop)?,
58 )
59 .into_array())
60 }
61}
62
63impl TakeFn<&ByteBoolArray> for ByteBoolEncoding {
64 fn take(&self, array: &ByteBoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
65 let validity = array.validity_mask()?;
66 let indices = indices.to_primitive()?;
67 let bools = array.as_slice();
68
69 let arr = match validity {
72 Mask::AllTrue(_) => {
73 let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
74 indices.as_slice::<$I>()
75 .iter()
76 .map(|&idx| {
77 let idx: usize = idx.as_();
78 bools[idx]
79 })
80 .collect::<Vec<_>>()
81 });
82
83 ByteBoolArray::from(bools).into_array()
84 }
85 Mask::AllFalse(_) => ByteBoolArray::from(vec![None; indices.len()]).into_array(),
86 Mask::Values(values) => {
87 let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
88 indices.as_slice::<$I>()
89 .iter()
90 .map(|&idx| {
91 let idx = idx.as_();
92 if values.value(idx) {
93 Some(bools[idx])
94 } else {
95 None
96 }
97 })
98 .collect::<Vec<Option<_>>>()
99 });
100
101 ByteBoolArray::from(bools).into_array()
102 }
103 };
104
105 Ok(arr)
106 }
107}
108
109impl FillForwardFn<&ByteBoolArray> for ByteBoolEncoding {
110 fn fill_forward(&self, array: &ByteBoolArray) -> VortexResult<ArrayRef> {
111 let validity = array.validity_mask()?;
112 if array.dtype().nullability() == Nullability::NonNullable {
113 return Ok(array.to_array().into_array());
114 }
115 if validity.all_true() {
117 return Ok(ByteBoolArray::new(array.buffer().clone(), Validity::AllValid).into_array());
118 }
119 if validity.all_false() {
121 return Ok(
122 ByteBoolArray::from_vec(vec![false; array.len()], Validity::AllValid).into_array(),
123 );
124 }
125
126 let validity = validity
127 .to_null_buffer()
128 .ok_or_else(|| vortex_err!("Failed to convert array validity to null buffer"))?;
129
130 let bools = array.as_slice();
131 let mut last_value = bool::default();
132
133 let filled = bools
134 .iter()
135 .zip(validity.inner().iter())
136 .map(|(&v, is_valid)| {
137 if is_valid {
138 last_value = v
139 }
140
141 last_value
142 })
143 .collect::<Vec<_>>();
144
145 Ok(ByteBoolArray::from_vec(filled, Validity::AllValid).into_array())
146 }
147}
148
#[cfg(test)]
mod tests {
    use vortex_array::compute::conformance::mask::test_mask;
    use vortex_array::compute::{Operator, compare, scalar_at, slice};

    use super::*;

    /// Slicing `1..4` keeps values and nulls aligned with their new positions.
    #[test]
    fn test_slice() {
        let input = ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);

        let sliced = ByteBoolArray::try_from(slice(&input, 1, 4).unwrap()).unwrap();

        let first = scalar_at(&sliced, 0).unwrap();
        assert_eq!(first.as_bool().value(), Some(true));

        let second = scalar_at(&sliced, 1).unwrap();
        assert!(!sliced.is_valid(1).unwrap());
        assert!(second.is_null());
        assert_eq!(second.as_bool().value(), None);

        let third = scalar_at(&sliced, 2).unwrap();
        assert_eq!(third.as_bool().value(), Some(false));
    }

    /// Equal inputs compare to an all-valid, all-`true` result.
    #[test]
    fn test_compare_all_equal() {
        let left = ByteBoolArray::from(vec![true; 5]);
        let right = ByteBoolArray::from(vec![true; 5]);

        let result = compare(&left, &right, Operator::Eq).unwrap();

        for scalar in (0..result.len()).map(|i| scalar_at(&result, i).unwrap()) {
            assert!(scalar.is_valid());
            assert_eq!(scalar.as_bool().value(), Some(true));
        }
    }

    /// Entirely different inputs compare to an all-valid, all-`false` result.
    #[test]
    fn test_compare_all_different() {
        let left = ByteBoolArray::from(vec![false; 5]);
        let right = ByteBoolArray::from(vec![true; 5]);

        let result = compare(&left, &right, Operator::Eq).unwrap();

        for scalar in (0..result.len()).map(|i| scalar_at(&result, i).unwrap()) {
            assert!(scalar.is_valid());
            assert_eq!(scalar.as_bool().value(), Some(false));
        }
    }

    /// A null on one side yields a null comparison result at that position.
    #[test]
    fn test_compare_with_nulls() {
        let left = ByteBoolArray::from(vec![true; 5]);
        let right =
            ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let result = compare(&left, &right, Operator::Eq).unwrap();

        for scalar in (0..3).map(|i| scalar_at(&result, i).unwrap()) {
            assert!(scalar.is_valid());
            assert_eq!(scalar.as_bool().value(), Some(true));
        }

        let mismatch = scalar_at(&result, 3).unwrap();
        assert!(mismatch.is_valid());
        assert_eq!(mismatch.as_bool().value(), Some(false));

        let null_slot = scalar_at(&result, 4).unwrap();
        assert!(null_slot.is_null());
    }

    /// Runs the shared mask conformance suite on inputs built without and with nulls.
    #[test]
    fn test_mask_byte_bool() {
        let without_nulls = ByteBoolArray::from(vec![true, false, true, true, false]);
        test_mask(&without_nulls);

        let with_nulls =
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);
        test_mask(&with_nulls);
    }
}