vortex_bytebool/
compute.rs

1use num_traits::AsPrimitive;
2use vortex_array::compute::{FillForwardFn, MaskFn, ScalarAtFn, SliceFn, TakeFn};
3use vortex_array::validity::Validity;
4use vortex_array::variants::PrimitiveArrayTrait;
5use vortex_array::vtable::ComputeVTable;
6use vortex_array::{Array, ArrayComputeImpl, ArrayRef, ToCanonical};
7use vortex_dtype::{Nullability, match_each_integer_ptype};
8use vortex_error::{VortexResult, vortex_err};
9use vortex_mask::Mask;
10use vortex_scalar::Scalar;
11
12use super::{ByteBoolArray, ByteBoolEncoding};
13
14impl ArrayComputeImpl for ByteBoolArray {}
15
16impl ComputeVTable for ByteBoolEncoding {
17    fn fill_forward_fn(&self) -> Option<&dyn FillForwardFn<&dyn Array>> {
18        None
19    }
20
21    fn mask_fn(&self) -> Option<&dyn MaskFn<&dyn Array>> {
22        Some(self)
23    }
24
25    fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
26        Some(self)
27    }
28
29    fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
30        Some(self)
31    }
32
33    fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
34        Some(self)
35    }
36}
37
38impl MaskFn<&ByteBoolArray> for ByteBoolEncoding {
39    fn mask(&self, array: &ByteBoolArray, mask: Mask) -> VortexResult<ArrayRef> {
40        Ok(ByteBoolArray::new(array.buffer().clone(), array.validity().mask(&mask)?).into_array())
41    }
42}
43
44impl ScalarAtFn<&ByteBoolArray> for ByteBoolEncoding {
45    fn scalar_at(&self, array: &ByteBoolArray, index: usize) -> VortexResult<Scalar> {
46        Ok(Scalar::bool(
47            array.buffer()[index] == 1,
48            array.dtype().nullability(),
49        ))
50    }
51}
52
53impl SliceFn<&ByteBoolArray> for ByteBoolEncoding {
54    fn slice(&self, array: &ByteBoolArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
55        Ok(ByteBoolArray::new(
56            array.buffer().slice(start..stop),
57            array.validity().slice(start, stop)?,
58        )
59        .into_array())
60    }
61}
62
63impl TakeFn<&ByteBoolArray> for ByteBoolEncoding {
64    fn take(&self, array: &ByteBoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
65        let validity = array.validity_mask()?;
66        let indices = indices.to_primitive()?;
67        let bools = array.as_slice();
68
69        // FIXME(ngates): we should be operating over canonical validity, which doesn't
70        //  have fallible is_valid function.
71        let arr = match validity {
72            Mask::AllTrue(_) => {
73                let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
74                    indices.as_slice::<$I>()
75                    .iter()
76                    .map(|&idx| {
77                        let idx: usize = idx.as_();
78                        bools[idx]
79                    })
80                    .collect::<Vec<_>>()
81                });
82
83                ByteBoolArray::from(bools).into_array()
84            }
85            Mask::AllFalse(_) => ByteBoolArray::from(vec![None; indices.len()]).into_array(),
86            Mask::Values(values) => {
87                let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
88                    indices.as_slice::<$I>()
89                    .iter()
90                    .map(|&idx| {
91                        let idx = idx.as_();
92                        if values.value(idx) {
93                            Some(bools[idx])
94                        } else {
95                            None
96                        }
97                    })
98                    .collect::<Vec<Option<_>>>()
99                });
100
101                ByteBoolArray::from(bools).into_array()
102            }
103        };
104
105        Ok(arr)
106    }
107}
108
109impl FillForwardFn<&ByteBoolArray> for ByteBoolEncoding {
110    fn fill_forward(&self, array: &ByteBoolArray) -> VortexResult<ArrayRef> {
111        let validity = array.validity_mask()?;
112        if array.dtype().nullability() == Nullability::NonNullable {
113            return Ok(array.to_array().into_array());
114        }
115        // all valid, but we need to convert to non-nullable
116        if validity.all_true() {
117            return Ok(ByteBoolArray::new(array.buffer().clone(), Validity::AllValid).into_array());
118        }
119        // all invalid => fill with default value (false)
120        if validity.all_false() {
121            return Ok(
122                ByteBoolArray::from_vec(vec![false; array.len()], Validity::AllValid).into_array(),
123            );
124        }
125
126        let validity = validity
127            .to_null_buffer()
128            .ok_or_else(|| vortex_err!("Failed to convert array validity to null buffer"))?;
129
130        let bools = array.as_slice();
131        let mut last_value = bool::default();
132
133        let filled = bools
134            .iter()
135            .zip(validity.inner().iter())
136            .map(|(&v, is_valid)| {
137                if is_valid {
138                    last_value = v
139                }
140
141                last_value
142            })
143            .collect::<Vec<_>>();
144
145        Ok(ByteBoolArray::from_vec(filled, Validity::AllValid).into_array())
146    }
147}
148
#[cfg(test)]
mod tests {
    use vortex_array::compute::conformance::mask::test_mask;
    use vortex_array::compute::{Operator, compare, scalar_at, slice};

    use super::*;

    /// Asserts that every element of `arr` is a valid boolean equal to `expected`.
    fn assert_all_valid(arr: &dyn Array, expected: bool) {
        for i in 0..arr.len() {
            let s = scalar_at(arr, i).unwrap();
            assert!(s.is_valid());
            assert_eq!(s.as_bool().value(), Some(expected));
        }
    }

    #[test]
    fn test_slice() {
        let vortex_arr =
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);

        let sliced_arr = ByteBoolArray::try_from(slice(&vortex_arr, 1, 4).unwrap()).unwrap();

        // Slice index 0 is original index 1: a valid `true`.
        let first = scalar_at(&sliced_arr, 0).unwrap();
        assert_eq!(first.as_bool().value(), Some(true));

        // Slice index 1 is original index 2: a null.
        let second = scalar_at(&sliced_arr, 1).unwrap();
        assert!(!sliced_arr.is_valid(1).unwrap());
        assert!(second.is_null());
        assert_eq!(second.as_bool().value(), None);

        // Slice index 2 is original index 3: a valid `false`.
        let third = scalar_at(&sliced_arr, 2).unwrap();
        assert_eq!(third.as_bool().value(), Some(false));
    }

    #[test]
    fn test_compare_all_equal() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let arr = compare(&lhs, &rhs, Operator::Eq).unwrap();
        assert_all_valid(&*arr, true);
    }

    #[test]
    fn test_compare_all_different() {
        let lhs = ByteBoolArray::from(vec![false; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let arr = compare(&lhs, &rhs, Operator::Eq).unwrap();
        assert_all_valid(&*arr, false);
    }

    #[test]
    fn test_compare_with_nulls() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let arr = compare(&lhs, &rhs, Operator::Eq).unwrap();

        // `None` marks a position whose comparison result must be null.
        let expected = [Some(true), Some(true), Some(true), Some(false), None];
        for (i, want) in expected.into_iter().enumerate() {
            let s = scalar_at(&arr, i).unwrap();
            match want {
                Some(value) => {
                    assert!(s.is_valid());
                    assert_eq!(s.as_bool().value(), Some(value));
                }
                None => assert!(s.is_null()),
            }
        }
    }

    #[test]
    fn test_mask_byte_bool() {
        let non_null = ByteBoolArray::from(vec![true, false, true, true, false]);
        test_mask(&non_null);

        let with_nulls =
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);
        test_mask(&with_nulls);
    }
}