vortex_bytebool/
compute.rs

1use num_traits::AsPrimitive;
2use vortex_array::compute::{MaskKernel, MaskKernelAdapter, ScalarAtFn, SliceFn, TakeFn};
3use vortex_array::variants::PrimitiveArrayTrait;
4use vortex_array::vtable::ComputeVTable;
5use vortex_array::{Array, ArrayRef, ToCanonical, register_kernel};
6use vortex_dtype::match_each_integer_ptype;
7use vortex_error::VortexResult;
8use vortex_mask::Mask;
9use vortex_scalar::Scalar;
10
11use super::{ByteBoolArray, ByteBoolEncoding};
12
13impl ComputeVTable for ByteBoolEncoding {
14    fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
15        Some(self)
16    }
17
18    fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
19        Some(self)
20    }
21
22    fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
23        Some(self)
24    }
25}
26
27impl MaskKernel for ByteBoolEncoding {
28    fn mask(&self, array: &ByteBoolArray, mask: &Mask) -> VortexResult<ArrayRef> {
29        Ok(ByteBoolArray::new(array.buffer().clone(), array.validity().mask(mask)?).into_array())
30    }
31}
32
33register_kernel!(MaskKernelAdapter(ByteBoolEncoding).lift());
34
35impl ScalarAtFn<&ByteBoolArray> for ByteBoolEncoding {
36    fn scalar_at(&self, array: &ByteBoolArray, index: usize) -> VortexResult<Scalar> {
37        Ok(Scalar::bool(
38            array.buffer()[index] == 1,
39            array.dtype().nullability(),
40        ))
41    }
42}
43
44impl SliceFn<&ByteBoolArray> for ByteBoolEncoding {
45    fn slice(&self, array: &ByteBoolArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
46        Ok(ByteBoolArray::new(
47            array.buffer().slice(start..stop),
48            array.validity().slice(start, stop)?,
49        )
50        .into_array())
51    }
52}
53
54impl TakeFn<&ByteBoolArray> for ByteBoolEncoding {
55    fn take(&self, array: &ByteBoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
56        let validity = array.validity_mask()?;
57        let indices = indices.to_primitive()?;
58        let bools = array.as_slice();
59
60        // FIXME(ngates): we should be operating over canonical validity, which doesn't
61        //  have fallible is_valid function.
62        let arr = match validity {
63            Mask::AllTrue(_) => {
64                let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
65                    indices.as_slice::<$I>()
66                    .iter()
67                    .map(|&idx| {
68                        let idx: usize = idx.as_();
69                        bools[idx]
70                    })
71                    .collect::<Vec<_>>()
72                });
73
74                ByteBoolArray::from(bools).into_array()
75            }
76            Mask::AllFalse(_) => ByteBoolArray::from(vec![None; indices.len()]).into_array(),
77            Mask::Values(values) => {
78                let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
79                    indices.as_slice::<$I>()
80                    .iter()
81                    .map(|&idx| {
82                        let idx = idx.as_();
83                        if values.value(idx) {
84                            Some(bools[idx])
85                        } else {
86                            None
87                        }
88                    })
89                    .collect::<Vec<Option<_>>>()
90                });
91
92                ByteBoolArray::from(bools).into_array()
93            }
94        };
95
96        Ok(arr)
97    }
98}
99
#[cfg(test)]
mod tests {
    use vortex_array::compute::conformance::mask::test_mask;
    use vortex_array::compute::{Operator, compare, scalar_at, slice};

    use super::*;

    /// Slicing keeps values and nulls aligned with their original positions.
    #[test]
    fn test_slice() {
        let input = ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);

        let window = slice(&input, 1, 4).unwrap();
        let window = ByteBoolArray::try_from(window).unwrap();

        let first = scalar_at(&window, 0).unwrap();
        assert_eq!(first.as_bool().value(), Some(true));

        let second = scalar_at(&window, 1).unwrap();
        assert!(!window.is_valid(1).unwrap());
        assert!(second.is_null());
        assert_eq!(second.as_bool().value(), None);

        let third = scalar_at(&window, 2).unwrap();
        assert_eq!(third.as_bool().value(), Some(false));
    }

    /// Equality of two identical all-true arrays is true everywhere.
    #[test]
    fn test_compare_all_equal() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let result = compare(&lhs, &rhs, Operator::Eq).unwrap();

        for idx in 0..result.len() {
            let scalar = scalar_at(&result, idx).unwrap();
            assert!(scalar.is_valid());
            assert_eq!(scalar.as_bool().value(), Some(true));
        }
    }

    /// Equality of all-false against all-true is false everywhere.
    #[test]
    fn test_compare_all_different() {
        let lhs = ByteBoolArray::from(vec![false; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let result = compare(&lhs, &rhs, Operator::Eq).unwrap();

        for idx in 0..result.len() {
            let scalar = scalar_at(&result, idx).unwrap();
            assert!(scalar.is_valid());
            assert_eq!(scalar.as_bool().value(), Some(false));
        }
    }

    /// A null on either side of a comparison yields a null result.
    #[test]
    fn test_compare_with_nulls() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let result = compare(&lhs, &rhs, Operator::Eq).unwrap();
        let at = |idx: usize| scalar_at(&result, idx).unwrap();

        for idx in 0..3 {
            let scalar = at(idx);
            assert!(scalar.is_valid());
            assert_eq!(scalar.as_bool().value(), Some(true));
        }

        let mismatch = at(3);
        assert!(mismatch.is_valid());
        assert_eq!(mismatch.as_bool().value(), Some(false));

        let null_slot = at(4);
        assert!(null_slot.is_null());
    }

    /// Runs the shared mask conformance suite on non-nullable and nullable inputs.
    #[test]
    fn test_mask_byte_bool() {
        let non_nullable = ByteBoolArray::from(vec![true, false, true, true, false]);
        test_mask(&non_nullable);

        let nullable = ByteBoolArray::from(vec![
            Some(true),
            Some(true),
            None,
            Some(false),
            None,
        ]);
        test_mask(&nullable);
    }
}