vortex_bytebool/
compute.rs

1use num_traits::AsPrimitive;
2use vortex_array::compute::{FillForwardFn, MaskFn, ScalarAtFn, SliceFn, TakeFn};
3use vortex_array::validity::Validity;
4use vortex_array::variants::PrimitiveArrayTrait;
5use vortex_array::vtable::ComputeVTable;
6use vortex_array::{Array, ArrayRef, ToCanonical};
7use vortex_dtype::{Nullability, match_each_integer_ptype};
8use vortex_error::{VortexResult, vortex_err};
9use vortex_mask::Mask;
10use vortex_scalar::Scalar;
11
12use super::{ByteBoolArray, ByteBoolEncoding};
13
14impl ComputeVTable for ByteBoolEncoding {
15    fn fill_forward_fn(&self) -> Option<&dyn FillForwardFn<&dyn Array>> {
16        None
17    }
18
19    fn mask_fn(&self) -> Option<&dyn MaskFn<&dyn Array>> {
20        Some(self)
21    }
22
23    fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
24        Some(self)
25    }
26
27    fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
28        Some(self)
29    }
30
31    fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
32        Some(self)
33    }
34}
35
36impl MaskFn<&ByteBoolArray> for ByteBoolEncoding {
37    fn mask(&self, array: &ByteBoolArray, mask: Mask) -> VortexResult<ArrayRef> {
38        Ok(ByteBoolArray::new(array.buffer().clone(), array.validity().mask(&mask)?).into_array())
39    }
40}
41
42impl ScalarAtFn<&ByteBoolArray> for ByteBoolEncoding {
43    fn scalar_at(&self, array: &ByteBoolArray, index: usize) -> VortexResult<Scalar> {
44        Ok(Scalar::bool(
45            array.buffer()[index] == 1,
46            array.dtype().nullability(),
47        ))
48    }
49}
50
51impl SliceFn<&ByteBoolArray> for ByteBoolEncoding {
52    fn slice(&self, array: &ByteBoolArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
53        Ok(ByteBoolArray::new(
54            array.buffer().slice(start..stop),
55            array.validity().slice(start, stop)?,
56        )
57        .into_array())
58    }
59}
60
61impl TakeFn<&ByteBoolArray> for ByteBoolEncoding {
62    fn take(&self, array: &ByteBoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
63        let validity = array.validity_mask()?;
64        let indices = indices.to_primitive()?;
65        let bools = array.as_slice();
66
67        // FIXME(ngates): we should be operating over canonical validity, which doesn't
68        //  have fallible is_valid function.
69        let arr = match validity {
70            Mask::AllTrue(_) => {
71                let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
72                    indices.as_slice::<$I>()
73                    .iter()
74                    .map(|&idx| {
75                        let idx: usize = idx.as_();
76                        bools[idx]
77                    })
78                    .collect::<Vec<_>>()
79                });
80
81                ByteBoolArray::from(bools).into_array()
82            }
83            Mask::AllFalse(_) => ByteBoolArray::from(vec![None; indices.len()]).into_array(),
84            Mask::Values(values) => {
85                let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
86                    indices.as_slice::<$I>()
87                    .iter()
88                    .map(|&idx| {
89                        let idx = idx.as_();
90                        if values.value(idx) {
91                            Some(bools[idx])
92                        } else {
93                            None
94                        }
95                    })
96                    .collect::<Vec<Option<_>>>()
97                });
98
99                ByteBoolArray::from(bools).into_array()
100            }
101        };
102
103        Ok(arr)
104    }
105}
106
107impl FillForwardFn<&ByteBoolArray> for ByteBoolEncoding {
108    fn fill_forward(&self, array: &ByteBoolArray) -> VortexResult<ArrayRef> {
109        let validity = array.validity_mask()?;
110        if array.dtype().nullability() == Nullability::NonNullable {
111            return Ok(array.to_array().into_array());
112        }
113        // all valid, but we need to convert to non-nullable
114        if validity.all_true() {
115            return Ok(ByteBoolArray::new(array.buffer().clone(), Validity::AllValid).into_array());
116        }
117        // all invalid => fill with default value (false)
118        if validity.all_false() {
119            return Ok(
120                ByteBoolArray::from_vec(vec![false; array.len()], Validity::AllValid).into_array(),
121            );
122        }
123
124        let validity = validity
125            .to_null_buffer()
126            .ok_or_else(|| vortex_err!("Failed to convert array validity to null buffer"))?;
127
128        let bools = array.as_slice();
129        let mut last_value = bool::default();
130
131        let filled = bools
132            .iter()
133            .zip(validity.inner().iter())
134            .map(|(&v, is_valid)| {
135                if is_valid {
136                    last_value = v
137                }
138
139                last_value
140            })
141            .collect::<Vec<_>>();
142
143        Ok(ByteBoolArray::from_vec(filled, Validity::AllValid).into_array())
144    }
145}
146
#[cfg(test)]
mod tests {
    use vortex_array::compute::test_harness::test_mask;
    use vortex_array::compute::{Operator, compare, scalar_at, slice};

    use super::*;

    /// Slicing `[1, 4)` out of a nullable array keeps values and nulls aligned.
    #[test]
    fn test_slice() {
        let vortex_arr =
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);

        let sliced_arr = ByteBoolArray::try_from(slice(&vortex_arr, 1, 4).unwrap()).unwrap();

        // The window holds: valid `true`, null, valid `false`.
        let expected = [Some(true), None, Some(false)];
        for (index, want) in expected.into_iter().enumerate() {
            let scalar = scalar_at(&sliced_arr, index).unwrap();
            assert_eq!(scalar.as_bool().value(), want);
        }

        // The middle element must also report itself as null through both APIs.
        assert!(!sliced_arr.is_valid(1).unwrap());
        assert!(scalar_at(&sliced_arr, 1).unwrap().is_null());
    }

    /// Comparing two identical all-`true` arrays yields valid `true` everywhere.
    #[test]
    fn test_compare_all_equal() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let outcome = compare(&lhs, &rhs, Operator::Eq).unwrap();

        (0..outcome.len())
            .map(|index| scalar_at(&outcome, index).unwrap())
            .for_each(|scalar| {
                assert!(scalar.is_valid());
                assert_eq!(scalar.as_bool().value(), Some(true));
            });
    }

    /// Comparing all-`false` against all-`true` yields valid `false` everywhere.
    #[test]
    fn test_compare_all_different() {
        let lhs = ByteBoolArray::from(vec![false; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let outcome = compare(&lhs, &rhs, Operator::Eq).unwrap();

        (0..outcome.len())
            .map(|index| scalar_at(&outcome, index).unwrap())
            .for_each(|scalar| {
                assert!(scalar.is_valid());
                assert_eq!(scalar.as_bool().value(), Some(false));
            });
    }

    /// A null on one side of the comparison produces a null result at that position.
    #[test]
    fn test_compare_with_nulls() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let outcome = compare(&lhs, &rhs, Operator::Eq).unwrap();

        // Positions 0..=3 are valid: three matches followed by one mismatch.
        let expected_valid = [true, true, true, false];
        for (index, want) in expected_valid.into_iter().enumerate() {
            let scalar = scalar_at(&outcome, index).unwrap();
            assert!(scalar.is_valid());
            assert_eq!(scalar.as_bool().value(), Some(want));
        }

        // Position 4 compares against a null and must itself be null.
        assert!(scalar_at(&outcome, 4).unwrap().is_null());
    }

    /// Runs the shared mask test harness over a plain and a nullable array.
    #[test]
    fn test_mask_byte_bool() {
        let plain = ByteBoolArray::from(vec![true, false, true, true, false]);
        test_mask(&plain);

        let with_nulls =
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);
        test_mask(&with_nulls);
    }
}