vortex_bytebool/
compute.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::AsPrimitive;
5use vortex_array::Array;
6use vortex_array::ArrayRef;
7use vortex_array::IntoArray;
8use vortex_array::ToCanonical;
9use vortex_array::compute::CastKernel;
10use vortex_array::compute::CastKernelAdapter;
11use vortex_array::compute::MaskKernel;
12use vortex_array::compute::MaskKernelAdapter;
13use vortex_array::compute::TakeKernel;
14use vortex_array::compute::TakeKernelAdapter;
15use vortex_array::register_kernel;
16use vortex_array::vtable::ValidityHelper;
17use vortex_dtype::DType;
18use vortex_dtype::match_each_integer_ptype;
19use vortex_error::VortexResult;
20use vortex_mask::Mask;
21
22use super::ByteBoolArray;
23use super::ByteBoolVTable;
24
25impl CastKernel for ByteBoolVTable {
26    fn cast(&self, array: &ByteBoolArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
27        // ByteBool is essentially a bool array stored as bytes
28        // The main difference from BoolArray is the storage format
29        // For casting, we can decode to canonical (BoolArray) and let it handle the cast
30
31        // If just changing nullability, we can optimize
32        if array.dtype().eq_ignore_nullability(dtype) {
33            let new_validity = array
34                .validity()
35                .clone()
36                .cast_nullability(dtype.nullability(), array.len())?;
37
38            return Ok(Some(
39                ByteBoolArray::new(array.buffer().clone(), new_validity).into_array(),
40            ));
41        }
42
43        // For other casts, decode to canonical and let BoolArray handle it
44        Ok(None)
45    }
46}
47
48register_kernel!(CastKernelAdapter(ByteBoolVTable).lift());
49
50impl MaskKernel for ByteBoolVTable {
51    fn mask(&self, array: &ByteBoolArray, mask: &Mask) -> VortexResult<ArrayRef> {
52        Ok(ByteBoolArray::new(array.buffer().clone(), array.validity().mask(mask)).into_array())
53    }
54}
55
56register_kernel!(MaskKernelAdapter(ByteBoolVTable).lift());
57
58impl TakeKernel for ByteBoolVTable {
59    fn take(&self, array: &ByteBoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
60        let indices = indices.to_primitive();
61        let bools = array.as_slice();
62
63        // This handles combining validity from both source array and nullable indices
64        let validity = array.validity().take(indices.as_ref())?;
65
66        let taken_bools = match_each_integer_ptype!(indices.ptype(), |I| {
67            indices
68                .as_slice::<I>()
69                .iter()
70                .map(|&idx| {
71                    let idx: usize = idx.as_();
72                    bools[idx]
73                })
74                .collect::<Vec<bool>>()
75        });
76
77        Ok(ByteBoolArray::from_vec(taken_bools, validity).into_array())
78    }
79}
80
81register_kernel!(TakeKernelAdapter(ByteBoolVTable).lift());
82
#[cfg(test)]
mod tests {
    // All imports are grouped at the top of the module rather than being
    // interleaved between test functions.
    use rstest::rstest;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::Operator;
    use vortex_array::compute::cast;
    use vortex_array::compute::compare;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;

    use super::*;

    // Slicing a nullable array keeps both values and nulls of the selected range.
    #[test]
    fn test_slice() {
        let original = vec![Some(true), Some(true), None, Some(false), None];
        let vortex_arr = ByteBoolArray::from(original);

        let sliced_arr = vortex_arr.slice(1..4);

        let expected = ByteBoolArray::from(vec![Some(true), None, Some(false)]);
        assert_arrays_eq!(sliced_arr, expected.to_array());
    }

    // Equality comparison of two identical all-true arrays is all true.
    #[test]
    fn test_compare_all_equal() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let arr = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

        let expected = ByteBoolArray::from(vec![true; 5]);
        assert_arrays_eq!(arr, expected.to_array());
    }

    // Equality comparison of all-false against all-true is all false.
    #[test]
    fn test_compare_all_different() {
        let lhs = ByteBoolArray::from(vec![false; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let arr = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

        let expected = ByteBoolArray::from(vec![false; 5]);
        assert_arrays_eq!(arr, expected.to_array());
    }

    // A null operand yields a null comparison result at that position.
    #[test]
    fn test_compare_with_nulls() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let arr = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

        let expected =
            ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);
        assert_arrays_eq!(arr, expected.to_array());
    }

    // Runs the shared mask conformance suite on non-nullable and nullable inputs.
    #[test]
    fn test_mask_byte_bool() {
        test_mask_conformance(ByteBoolArray::from(vec![true, false, true, true, false]).as_ref());
        test_mask_conformance(
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]).as_ref(),
        );
    }

    // Runs the shared filter conformance suite on non-nullable and nullable inputs.
    #[test]
    fn test_filter_byte_bool() {
        test_filter_conformance(ByteBoolArray::from(vec![true, false, true, true, false]).as_ref());
        test_filter_conformance(
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]).as_ref(),
        );
    }

    // Runs the shared take conformance suite over several array shapes.
    #[rstest]
    #[case(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case(ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]))]
    #[case(ByteBoolArray::from(vec![true, false]))]
    #[case(ByteBoolArray::from(vec![true]))]
    fn test_take_byte_bool_conformance(#[case] array: ByteBoolArray) {
        test_take_conformance(array.as_ref());
    }

    // Casting non-nullable -> nullable must change the dtype while preserving
    // the length and every value.
    #[test]
    fn test_cast_bytebool_to_nullable() {
        let array = ByteBoolArray::from(vec![true, false, true, false]);
        let casted = cast(array.as_ref(), &DType::Bool(Nullability::Nullable)).unwrap();
        assert_eq!(casted.dtype(), &DType::Bool(Nullability::Nullable));
        assert_eq!(casted.len(), 4);

        // Check the values too, not just the shape of the result.
        let expected =
            ByteBoolArray::from(vec![Some(true), Some(false), Some(true), Some(false)]);
        assert_arrays_eq!(casted, expected.to_array());
    }

    // Runs the shared cast conformance suite over several array shapes.
    #[rstest]
    #[case(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case(ByteBoolArray::from(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case(ByteBoolArray::from(vec![false]))]
    #[case(ByteBoolArray::from(vec![true]))]
    #[case(ByteBoolArray::from(vec![Some(true), None]))]
    fn test_cast_bytebool_conformance(#[case] array: ByteBoolArray) {
        test_cast_conformance(array.as_ref());
    }

    // Runs the shared cross-kernel consistency suite over a range of inputs,
    // including single-element and all-null-adjacent cases.
    #[rstest]
    #[case::non_nullable(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case::nullable(ByteBoolArray::from(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case::all_true(ByteBoolArray::from(vec![true, true, true, true]))]
    #[case::all_false(ByteBoolArray::from(vec![false, false, false, false]))]
    #[case::single_true(ByteBoolArray::from(vec![true]))]
    #[case::single_false(ByteBoolArray::from(vec![false]))]
    #[case::single_null(ByteBoolArray::from(vec![None]))]
    #[case::mixed_with_nulls(ByteBoolArray::from(vec![Some(true), None, Some(false), None, Some(true)]))]
    fn test_bytebool_consistency(#[case] array: ByteBoolArray) {
        test_array_consistency(array.as_ref());
    }
}