vortex_bytebool/
compute.rs1use num_traits::AsPrimitive;
2use vortex_array::compute::{FillForwardFn, MaskFn, ScalarAtFn, SliceFn, TakeFn};
3use vortex_array::validity::Validity;
4use vortex_array::variants::PrimitiveArrayTrait;
5use vortex_array::vtable::ComputeVTable;
6use vortex_array::{Array, ArrayRef, ToCanonical};
7use vortex_dtype::{Nullability, match_each_integer_ptype};
8use vortex_error::{VortexResult, vortex_err};
9use vortex_mask::Mask;
10use vortex_scalar::Scalar;
11
12use super::{ByteBoolArray, ByteBoolEncoding};
13
14impl ComputeVTable for ByteBoolEncoding {
15 fn fill_forward_fn(&self) -> Option<&dyn FillForwardFn<&dyn Array>> {
16 None
17 }
18
19 fn mask_fn(&self) -> Option<&dyn MaskFn<&dyn Array>> {
20 Some(self)
21 }
22
23 fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
24 Some(self)
25 }
26
27 fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
28 Some(self)
29 }
30
31 fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
32 Some(self)
33 }
34}
35
36impl MaskFn<&ByteBoolArray> for ByteBoolEncoding {
37 fn mask(&self, array: &ByteBoolArray, mask: Mask) -> VortexResult<ArrayRef> {
38 Ok(ByteBoolArray::new(array.buffer().clone(), array.validity().mask(&mask)?).into_array())
39 }
40}
41
42impl ScalarAtFn<&ByteBoolArray> for ByteBoolEncoding {
43 fn scalar_at(&self, array: &ByteBoolArray, index: usize) -> VortexResult<Scalar> {
44 Ok(Scalar::bool(
45 array.buffer()[index] == 1,
46 array.dtype().nullability(),
47 ))
48 }
49}
50
51impl SliceFn<&ByteBoolArray> for ByteBoolEncoding {
52 fn slice(&self, array: &ByteBoolArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
53 Ok(ByteBoolArray::new(
54 array.buffer().slice(start..stop),
55 array.validity().slice(start, stop)?,
56 )
57 .into_array())
58 }
59}
60
61impl TakeFn<&ByteBoolArray> for ByteBoolEncoding {
62 fn take(&self, array: &ByteBoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
63 let validity = array.validity_mask()?;
64 let indices = indices.to_primitive()?;
65 let bools = array.as_slice();
66
67 let arr = match validity {
70 Mask::AllTrue(_) => {
71 let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
72 indices.as_slice::<$I>()
73 .iter()
74 .map(|&idx| {
75 let idx: usize = idx.as_();
76 bools[idx]
77 })
78 .collect::<Vec<_>>()
79 });
80
81 ByteBoolArray::from(bools).into_array()
82 }
83 Mask::AllFalse(_) => ByteBoolArray::from(vec![None; indices.len()]).into_array(),
84 Mask::Values(values) => {
85 let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
86 indices.as_slice::<$I>()
87 .iter()
88 .map(|&idx| {
89 let idx = idx.as_();
90 if values.value(idx) {
91 Some(bools[idx])
92 } else {
93 None
94 }
95 })
96 .collect::<Vec<Option<_>>>()
97 });
98
99 ByteBoolArray::from(bools).into_array()
100 }
101 };
102
103 Ok(arr)
104 }
105}
106
107impl FillForwardFn<&ByteBoolArray> for ByteBoolEncoding {
108 fn fill_forward(&self, array: &ByteBoolArray) -> VortexResult<ArrayRef> {
109 let validity = array.validity_mask()?;
110 if array.dtype().nullability() == Nullability::NonNullable {
111 return Ok(array.to_array().into_array());
112 }
113 if validity.all_true() {
115 return Ok(ByteBoolArray::new(array.buffer().clone(), Validity::AllValid).into_array());
116 }
117 if validity.all_false() {
119 return Ok(
120 ByteBoolArray::from_vec(vec![false; array.len()], Validity::AllValid).into_array(),
121 );
122 }
123
124 let validity = validity
125 .to_null_buffer()
126 .ok_or_else(|| vortex_err!("Failed to convert array validity to null buffer"))?;
127
128 let bools = array.as_slice();
129 let mut last_value = bool::default();
130
131 let filled = bools
132 .iter()
133 .zip(validity.inner().iter())
134 .map(|(&v, is_valid)| {
135 if is_valid {
136 last_value = v
137 }
138
139 last_value
140 })
141 .collect::<Vec<_>>();
142
143 Ok(ByteBoolArray::from_vec(filled, Validity::AllValid).into_array())
144 }
145}
146
#[cfg(test)]
mod tests {
    use vortex_array::compute::test_harness::test_mask;
    use vortex_array::compute::{Operator, compare, scalar_at, slice};

    use super::*;

    #[test]
    fn test_slice() {
        let vortex_arr =
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);

        let sliced_arr = ByteBoolArray::try_from(slice(&vortex_arr, 1, 4).unwrap()).unwrap();

        // Slice position 0 is original position 1: a valid `true`.
        let first = scalar_at(&sliced_arr, 0).unwrap();
        assert_eq!(first.as_bool().value(), Some(true));

        // Slice position 1 is original position 2: a null.
        let second = scalar_at(&sliced_arr, 1).unwrap();
        assert!(!sliced_arr.is_valid(1).unwrap());
        assert!(second.is_null());
        assert_eq!(second.as_bool().value(), None);

        // Slice position 2 is original position 3: a valid `false`.
        let third = scalar_at(&sliced_arr, 2).unwrap();
        assert_eq!(third.as_bool().value(), Some(false));
    }

    #[test]
    fn test_compare_all_equal() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let result = compare(&lhs, &rhs, Operator::Eq).unwrap();

        (0..result.len())
            .map(|i| scalar_at(&result, i).unwrap())
            .for_each(|s| {
                assert!(s.is_valid());
                assert_eq!(s.as_bool().value(), Some(true));
            });
    }

    #[test]
    fn test_compare_all_different() {
        let lhs = ByteBoolArray::from(vec![false; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let result = compare(&lhs, &rhs, Operator::Eq).unwrap();

        (0..result.len())
            .map(|i| scalar_at(&result, i).unwrap())
            .for_each(|s| {
                assert!(s.is_valid());
                assert_eq!(s.as_bool().value(), Some(false));
            });
    }

    #[test]
    fn test_compare_with_nulls() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let result = compare(&lhs, &rhs, Operator::Eq).unwrap();

        // The first four positions compare two valid values.
        let expected = [Some(true), Some(true), Some(true), Some(false)];
        for (i, want) in expected.into_iter().enumerate() {
            let s = scalar_at(&result, i).unwrap();
            assert!(s.is_valid());
            assert_eq!(s.as_bool().value(), want);
        }

        // Comparing against a null yields a null.
        let last = scalar_at(&result, 4).unwrap();
        assert!(last.is_null());
    }

    #[test]
    fn test_mask_byte_bool() {
        let non_nullable = ByteBoolArray::from(vec![true, false, true, true, false]);
        test_mask(&non_nullable);

        let nullable = ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);
        test_mask(&nullable);
    }
}