vortex_bytebool/
compute.rs

1use num_traits::AsPrimitive;
2use vortex_array::compute::{MaskKernel, MaskKernelAdapter, TakeKernel, TakeKernelAdapter};
3use vortex_array::vtable::ValidityHelper;
4use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
5use vortex_dtype::match_each_integer_ptype;
6use vortex_error::VortexResult;
7use vortex_mask::Mask;
8
9use super::{ByteBoolArray, ByteBoolVTable};
10
11impl MaskKernel for ByteBoolVTable {
12    fn mask(&self, array: &ByteBoolArray, mask: &Mask) -> VortexResult<ArrayRef> {
13        Ok(ByteBoolArray::new(array.buffer().clone(), array.validity().mask(mask)?).into_array())
14    }
15}
16
17register_kernel!(MaskKernelAdapter(ByteBoolVTable).lift());
18
19impl TakeKernel for ByteBoolVTable {
20    fn take(&self, array: &ByteBoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
21        let validity = array.validity_mask()?;
22        let indices = indices.to_primitive()?;
23        let bools = array.as_slice();
24
25        // FIXME(ngates): we should be operating over canonical validity, which doesn't
26        //  have fallible is_valid function.
27        let arr = match validity {
28            Mask::AllTrue(_) => {
29                let bools = match_each_integer_ptype!(indices.ptype(), |I| {
30                    indices
31                        .as_slice::<I>()
32                        .iter()
33                        .map(|&idx| {
34                            let idx: usize = idx.as_();
35                            bools[idx]
36                        })
37                        .collect::<Vec<_>>()
38                });
39
40                ByteBoolArray::from(bools).into_array()
41            }
42            Mask::AllFalse(_) => ByteBoolArray::from(vec![None; indices.len()]).into_array(),
43            Mask::Values(values) => {
44                let bools = match_each_integer_ptype!(indices.ptype(), |I| {
45                    indices
46                        .as_slice::<I>()
47                        .iter()
48                        .map(|&idx| {
49                            let idx = idx.as_();
50                            values.value(idx).then(|| bools[idx])
51                        })
52                        .collect::<Vec<Option<_>>>()
53                });
54
55                ByteBoolArray::from(bools).into_array()
56            }
57        };
58
59        Ok(arr)
60    }
61}
62
63register_kernel!(TakeKernelAdapter(ByteBoolVTable).lift());
64
#[cfg(test)]
mod tests {
    use vortex_array::compute::conformance::mask::test_mask;
    use vortex_array::compute::{Operator, compare};

    use super::*;

    /// Slicing `1..4` out of a five-element nullable array must expose the original
    /// elements at positions 1, 2 and 3, including the null in the middle.
    #[test]
    fn test_slice() {
        let source = ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);

        let window = source.slice(1, 4).unwrap();
        let window = window.as_::<ByteBoolVTable>();

        // Position 0 of the slice is a valid `true`.
        let first = window.scalar_at(0).unwrap();
        assert_eq!(first.as_bool().value(), Some(true));

        // Position 1 of the slice is the null element.
        let second = window.scalar_at(1).unwrap();
        assert!(!window.is_valid(1).unwrap());
        assert!(second.is_null());
        assert_eq!(second.as_bool().value(), None);

        // Position 2 of the slice is a valid `false`.
        let third = window.scalar_at(2).unwrap();
        assert_eq!(third.as_bool().value(), Some(false));
    }

    /// Comparing two all-`true` arrays for equality yields a valid `true` at every position.
    #[test]
    fn test_compare_all_equal() {
        let left = ByteBoolArray::from(vec![true; 5]);
        let right = ByteBoolArray::from(vec![true; 5]);

        let result = compare(left.as_ref(), right.as_ref(), Operator::Eq).unwrap();

        for scalar in (0..result.len()).map(|i| result.scalar_at(i).unwrap()) {
            assert!(scalar.is_valid());
            assert_eq!(scalar.as_bool().value(), Some(true));
        }
    }

    /// Comparing an all-`false` array with an all-`true` array for equality yields a valid
    /// `false` at every position.
    #[test]
    fn test_compare_all_different() {
        let left = ByteBoolArray::from(vec![false; 5]);
        let right = ByteBoolArray::from(vec![true; 5]);

        let result = compare(left.as_ref(), right.as_ref(), Operator::Eq).unwrap();

        for scalar in (0..result.len()).map(|i| result.scalar_at(i).unwrap()) {
            assert!(scalar.is_valid());
            assert_eq!(scalar.as_bool().value(), Some(false));
        }
    }

    /// Equality against a nullable right-hand side: matching values give `true`, a differing
    /// value gives `false`, and a null operand gives a null result.
    #[test]
    fn test_compare_with_nulls() {
        let left = ByteBoolArray::from(vec![true; 5]);
        let right =
            ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let result = compare(left.as_ref(), right.as_ref(), Operator::Eq).unwrap();

        // The first three positions hold equal, valid values.
        for scalar in (0..3).map(|i| result.scalar_at(i).unwrap()) {
            assert!(scalar.is_valid());
            assert_eq!(scalar.as_bool().value(), Some(true));
        }

        // Position 3 differs (`true` vs `false`) but is still valid.
        let mismatch = result.scalar_at(3).unwrap();
        assert!(mismatch.is_valid());
        assert_eq!(mismatch.as_bool().value(), Some(false));

        // Position 4 has a null on the right-hand side.
        let missing = result.scalar_at(4).unwrap();
        assert!(missing.is_null());
    }

    /// Runs the shared mask conformance suite against an array built from plain `bool`s and
    /// one built from `Option<bool>`s.
    #[test]
    fn test_mask_byte_bool() {
        let plain = ByteBoolArray::from(vec![true, false, true, true, false]);
        test_mask(plain.as_ref());

        let with_nulls =
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);
        test_mask(with_nulls.as_ref());
    }
}