// vortex_bytebool/compute.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use num_traits::AsPrimitive;
use vortex_array::compute::{MaskKernel, MaskKernelAdapter, TakeKernel, TakeKernelAdapter};
use vortex_array::vtable::ValidityHelper;
use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
use vortex_dtype::match_each_integer_ptype;
use vortex_error::{VortexResult, vortex_err};
use vortex_mask::Mask;

use super::{ByteBoolArray, ByteBoolVTable};

14impl MaskKernel for ByteBoolVTable {
15    fn mask(&self, array: &ByteBoolArray, mask: &Mask) -> VortexResult<ArrayRef> {
16        Ok(ByteBoolArray::new(array.buffer().clone(), array.validity().mask(mask)?).into_array())
17    }
18}
19
20register_kernel!(MaskKernelAdapter(ByteBoolVTable).lift());
21
22impl TakeKernel for ByteBoolVTable {
23    fn take(&self, array: &ByteBoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
24        let validity = array.validity_mask()?;
25        let indices = indices.to_primitive()?;
26        let bools = array.as_slice();
27
28        // FIXME(ngates): we should be operating over canonical validity, which doesn't
29        //  have fallible is_valid function.
30        let arr = match validity {
31            Mask::AllTrue(_) => {
32                let bools = match_each_integer_ptype!(indices.ptype(), |I| {
33                    indices
34                        .as_slice::<I>()
35                        .iter()
36                        .map(|&idx| {
37                            let idx: usize = idx.as_();
38                            bools[idx]
39                        })
40                        .collect::<Vec<_>>()
41                });
42
43                ByteBoolArray::from(bools).into_array()
44            }
45            Mask::AllFalse(_) => ByteBoolArray::from(vec![None; indices.len()]).into_array(),
46            Mask::Values(values) => {
47                let bools = match_each_integer_ptype!(indices.ptype(), |I| {
48                    indices
49                        .as_slice::<I>()
50                        .iter()
51                        .map(|&idx| {
52                            let idx = idx.as_();
53                            values.value(idx).then(|| bools[idx])
54                        })
55                        .collect::<Vec<Option<_>>>()
56                });
57
58                ByteBoolArray::from(bools).into_array()
59            }
60        };
61
62        Ok(arr)
63    }
64}
65
66register_kernel!(TakeKernelAdapter(ByteBoolVTable).lift());
67
#[cfg(test)]
mod tests {
    use vortex_array::compute::conformance::mask::test_mask;
    use vortex_array::compute::{Operator, compare};

    use super::*;

    /// Asserts that each position in `positions` of `arr` is valid and holds `expected`.
    fn assert_valid_bools(
        arr: &ArrayRef,
        positions: impl IntoIterator<Item = usize>,
        expected: bool,
    ) {
        for pos in positions {
            let scalar = arr.scalar_at(pos).unwrap();
            assert!(scalar.is_valid());
            assert_eq!(scalar.as_bool().value(), Some(expected));
        }
    }

    /// Slicing `[1, 4)` keeps values and nulls aligned with their original positions.
    #[test]
    fn test_slice() {
        let array = ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);

        let sliced = array.slice(1, 4).unwrap();
        let sliced = sliced.as_::<ByteBoolVTable>();

        let first = sliced.scalar_at(0).unwrap();
        assert_eq!(first.as_bool().value(), Some(true));

        let second = sliced.scalar_at(1).unwrap();
        assert!(!sliced.is_valid(1).unwrap());
        assert!(second.is_null());
        assert_eq!(second.as_bool().value(), None);

        let third = sliced.scalar_at(2).unwrap();
        assert_eq!(third.as_bool().value(), Some(false));
    }

    /// Equal inputs compare to all-`true`, with every position valid.
    #[test]
    fn test_compare_all_equal() {
        let left = ByteBoolArray::from(vec![true; 5]);
        let right = ByteBoolArray::from(vec![true; 5]);

        let result = compare(left.as_ref(), right.as_ref(), Operator::Eq).unwrap();

        assert_valid_bools(&result, 0..result.len(), true);
    }

    /// Pairwise-different inputs compare to all-`false`, with every position valid.
    #[test]
    fn test_compare_all_different() {
        let left = ByteBoolArray::from(vec![false; 5]);
        let right = ByteBoolArray::from(vec![true; 5]);

        let result = compare(left.as_ref(), right.as_ref(), Operator::Eq).unwrap();

        assert_valid_bools(&result, 0..result.len(), false);
    }

    /// A null on one side yields a null comparison result at that position.
    #[test]
    fn test_compare_with_nulls() {
        let left = ByteBoolArray::from(vec![true; 5]);
        let right =
            ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let result = compare(left.as_ref(), right.as_ref(), Operator::Eq).unwrap();

        assert_valid_bools(&result, 0..3, true);
        assert_valid_bools(&result, [3], false);

        let last = result.scalar_at(4).unwrap();
        assert!(last.is_null());
    }

    /// Runs the shared mask conformance suite on non-null and nullable inputs.
    #[test]
    fn test_mask_byte_bool() {
        let without_nulls = ByteBoolArray::from(vec![true, false, true, true, false]);
        test_mask(without_nulls.as_ref());

        let with_nulls =
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);
        test_mask(with_nulls.as_ref());
    }
}