vortex_bytebool/
compute.rs

1use num_traits::AsPrimitive;
2use vortex_array::compute::{MaskKernel, MaskKernelAdapter, TakeKernel, TakeKernelAdapter};
3use vortex_array::vtable::ValidityHelper;
4use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
5use vortex_dtype::match_each_integer_ptype;
6use vortex_error::VortexResult;
7use vortex_mask::Mask;
8
9use super::{ByteBoolArray, ByteBoolVTable};
10
11impl MaskKernel for ByteBoolVTable {
12    fn mask(&self, array: &ByteBoolArray, mask: &Mask) -> VortexResult<ArrayRef> {
13        Ok(ByteBoolArray::new(array.buffer().clone(), array.validity().mask(mask)?).into_array())
14    }
15}
16
17register_kernel!(MaskKernelAdapter(ByteBoolVTable).lift());
18
19impl TakeKernel for ByteBoolVTable {
20    fn take(&self, array: &ByteBoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
21        let validity = array.validity_mask()?;
22        let indices = indices.to_primitive()?;
23        let bools = array.as_slice();
24
25        // FIXME(ngates): we should be operating over canonical validity, which doesn't
26        //  have fallible is_valid function.
27        let arr = match validity {
28            Mask::AllTrue(_) => {
29                let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
30                    indices.as_slice::<$I>()
31                    .iter()
32                    .map(|&idx| {
33                        let idx: usize = idx.as_();
34                        bools[idx]
35                    })
36                    .collect::<Vec<_>>()
37                });
38
39                ByteBoolArray::from(bools).into_array()
40            }
41            Mask::AllFalse(_) => ByteBoolArray::from(vec![None; indices.len()]).into_array(),
42            Mask::Values(values) => {
43                let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
44                    indices.as_slice::<$I>()
45                    .iter()
46                    .map(|&idx| {
47                        let idx = idx.as_();
48                        if values.value(idx) {
49                            Some(bools[idx])
50                        } else {
51                            None
52                        }
53                    })
54                    .collect::<Vec<Option<_>>>()
55                });
56
57                ByteBoolArray::from(bools).into_array()
58            }
59        };
60
61        Ok(arr)
62    }
63}
64
65register_kernel!(TakeKernelAdapter(ByteBoolVTable).lift());
66
#[cfg(test)]
mod tests {
    use vortex_array::compute::conformance::mask::test_mask;
    use vortex_array::compute::{Operator, compare};

    use super::*;

    /// Asserts that every position in `positions` of `arr` holds a valid boolean scalar
    /// equal to `expected`.
    fn assert_valid_bools(arr: &ArrayRef, positions: std::ops::Range<usize>, expected: bool) {
        for pos in positions {
            let scalar = arr.scalar_at(pos).unwrap();
            assert!(scalar.is_valid());
            assert_eq!(scalar.as_bool().value(), Some(expected));
        }
    }

    #[test]
    fn test_slice() {
        let input = ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);

        // Slicing [1, 4) keeps the elements `Some(true), None, Some(false)`.
        let sliced = input.slice(1, 4).unwrap();
        let sliced = sliced.as_::<ByteBoolVTable>();

        let first = sliced.scalar_at(0).unwrap();
        assert_eq!(first.as_bool().value(), Some(true));

        let second = sliced.scalar_at(1).unwrap();
        assert!(!sliced.is_valid(1).unwrap());
        assert!(second.is_null());
        assert_eq!(second.as_bool().value(), None);

        let third = sliced.scalar_at(2).unwrap();
        assert_eq!(third.as_bool().value(), Some(false));
    }

    #[test]
    fn test_compare_all_equal() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let eq = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        assert_valid_bools(&eq, 0..eq.len(), true);
    }

    #[test]
    fn test_compare_all_different() {
        let lhs = ByteBoolArray::from(vec![false; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let eq = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        assert_valid_bools(&eq, 0..eq.len(), false);
    }

    #[test]
    fn test_compare_with_nulls() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let eq = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

        // Matching non-null values compare equal, the mismatch compares unequal...
        assert_valid_bools(&eq, 0..3, true);
        assert_valid_bools(&eq, 3..4, false);

        // ...and a null operand yields a null result.
        assert!(eq.scalar_at(4).unwrap().is_null());
    }

    #[test]
    fn test_mask_byte_bool() {
        let no_nulls = ByteBoolArray::from(vec![true, false, true, true, false]);
        test_mask(no_nulls.as_ref());

        let with_nulls =
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);
        test_mask(with_nulls.as_ref());
    }
}