vortex_bytebool/
compute.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
use num_traits::AsPrimitive;
use vortex_array::compute::{MaskKernel, MaskKernelAdapter, TakeKernel, TakeKernelAdapter};
use vortex_array::vtable::ValidityHelper;
use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
use vortex_dtype::match_each_integer_ptype;
use vortex_error::{VortexResult, vortex_err};
use vortex_mask::Mask;

use super::{ByteBoolArray, ByteBoolVTable};
13
14impl MaskKernel for ByteBoolVTable {
15    fn mask(&self, array: &ByteBoolArray, mask: &Mask) -> VortexResult<ArrayRef> {
16        Ok(ByteBoolArray::new(array.buffer().clone(), array.validity().mask(mask)?).into_array())
17    }
18}
19
20register_kernel!(MaskKernelAdapter(ByteBoolVTable).lift());
21
22impl TakeKernel for ByteBoolVTable {
23    fn take(&self, array: &ByteBoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
24        let indices = indices.to_primitive()?;
25        let bools = array.as_slice();
26
27        // This handles combining validity from both source array and nullable indices
28        let validity = array.validity().take(indices.as_ref())?;
29
30        let taken_bools = match_each_integer_ptype!(indices.ptype(), |I| {
31            indices
32                .as_slice::<I>()
33                .iter()
34                .map(|&idx| {
35                    let idx: usize = idx.as_();
36                    bools[idx]
37                })
38                .collect::<Vec<bool>>()
39        });
40
41        Ok(ByteBoolArray::from_vec(taken_bools, validity).into_array())
42    }
43}
44
45register_kernel!(TakeKernelAdapter(ByteBoolVTable).lift());
46
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::compute::{Operator, compare};

    use super::*;

    /// Non-nullable fixture shared by the conformance tests.
    fn non_nullable_sample() -> ByteBoolArray {
        ByteBoolArray::from(vec![true, false, true, true, false])
    }

    /// Nullable fixture (nulls at positions 2 and 4) shared by several tests.
    fn nullable_sample() -> ByteBoolArray {
        ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None])
    }

    #[test]
    fn test_slice() {
        let source = nullable_sample();

        // Elements 1..4 of the source: [true, null, false].
        let sliced = source.slice(1, 4).unwrap();
        let sliced = sliced.as_::<ByteBoolVTable>();

        let head = sliced.scalar_at(0).unwrap();
        assert_eq!(head.as_bool().value(), Some(true));

        let middle = sliced.scalar_at(1).unwrap();
        assert!(!sliced.is_valid(1).unwrap());
        assert!(middle.is_null());
        assert_eq!(middle.as_bool().value(), None);

        let tail = sliced.scalar_at(2).unwrap();
        assert_eq!(tail.as_bool().value(), Some(false));
    }

    #[test]
    fn test_compare_all_equal() {
        let left = ByteBoolArray::from(vec![true; 5]);
        let right = ByteBoolArray::from(vec![true; 5]);

        let equality = compare(left.as_ref(), right.as_ref(), Operator::Eq).unwrap();

        for scalar in (0..equality.len()).map(|pos| equality.scalar_at(pos).unwrap()) {
            assert!(scalar.is_valid());
            assert_eq!(scalar.as_bool().value(), Some(true));
        }
    }

    #[test]
    fn test_compare_all_different() {
        let left = ByteBoolArray::from(vec![false; 5]);
        let right = ByteBoolArray::from(vec![true; 5]);

        let equality = compare(left.as_ref(), right.as_ref(), Operator::Eq).unwrap();

        for scalar in (0..equality.len()).map(|pos| equality.scalar_at(pos).unwrap()) {
            assert!(scalar.is_valid());
            assert_eq!(scalar.as_bool().value(), Some(false));
        }
    }

    #[test]
    fn test_compare_with_nulls() {
        let left = ByteBoolArray::from(vec![true; 5]);
        let right = ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let equality = compare(left.as_ref(), right.as_ref(), Operator::Eq).unwrap();

        // Positions 0..3 compare `true` with `true`.
        for scalar in (0..3).map(|pos| equality.scalar_at(pos).unwrap()) {
            assert!(scalar.is_valid());
            assert_eq!(scalar.as_bool().value(), Some(true));
        }

        // Position 3 compares `true` with `false`: a valid `false`.
        let mismatch = equality.scalar_at(3).unwrap();
        assert!(mismatch.is_valid());
        assert_eq!(mismatch.as_bool().value(), Some(false));

        // Position 4 compares against a null and must itself be null.
        let against_null = equality.scalar_at(4).unwrap();
        assert!(against_null.is_null());
    }

    #[test]
    fn test_mask_byte_bool() {
        for fixture in [non_nullable_sample(), nullable_sample()] {
            test_mask_conformance(fixture.as_ref());
        }
    }

    #[test]
    fn test_filter_byte_bool() {
        for fixture in [non_nullable_sample(), nullable_sample()] {
            test_filter_conformance(fixture.as_ref());
        }
    }

    #[rstest]
    #[case(non_nullable_sample())]
    #[case(nullable_sample())]
    #[case(ByteBoolArray::from(vec![true, false]))]
    #[case(ByteBoolArray::from(vec![true]))]
    fn test_take_byte_bool_conformance(#[case] array: ByteBoolArray) {
        test_take_conformance(array.as_ref());
    }
}