vortex_bytebool/
compute.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
use num_traits::AsPrimitive;
use vortex_array::compute::{
    CastKernel, CastKernelAdapter, MaskKernel, MaskKernelAdapter, TakeKernel, TakeKernelAdapter,
};
use vortex_array::vtable::ValidityHelper;
use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
use vortex_dtype::{DType, match_each_integer_ptype};
use vortex_error::{VortexResult, vortex_bail};
use vortex_mask::Mask;

use super::{ByteBoolArray, ByteBoolVTable};
15
16impl CastKernel for ByteBoolVTable {
17    fn cast(&self, array: &ByteBoolArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
18        // ByteBool is essentially a bool array stored as bytes
19        // The main difference from BoolArray is the storage format
20        // For casting, we can decode to canonical (BoolArray) and let it handle the cast
21
22        // If just changing nullability, we can optimize
23        if array.dtype().eq_ignore_nullability(dtype) {
24            let new_validity = array
25                .validity()
26                .clone()
27                .cast_nullability(dtype.nullability())?;
28
29            return Ok(Some(
30                ByteBoolArray::new(array.buffer().clone(), new_validity).into_array(),
31            ));
32        }
33
34        // For other casts, decode to canonical and let BoolArray handle it
35        Ok(None)
36    }
37}
38
39register_kernel!(CastKernelAdapter(ByteBoolVTable).lift());
40
41impl MaskKernel for ByteBoolVTable {
42    fn mask(&self, array: &ByteBoolArray, mask: &Mask) -> VortexResult<ArrayRef> {
43        Ok(ByteBoolArray::new(array.buffer().clone(), array.validity().mask(mask)).into_array())
44    }
45}
46
47register_kernel!(MaskKernelAdapter(ByteBoolVTable).lift());
48
49impl TakeKernel for ByteBoolVTable {
50    fn take(&self, array: &ByteBoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
51        let indices = indices.to_primitive();
52        let bools = array.as_slice();
53
54        // This handles combining validity from both source array and nullable indices
55        let validity = array.validity().take(indices.as_ref())?;
56
57        let taken_bools = match_each_integer_ptype!(indices.ptype(), |I| {
58            indices
59                .as_slice::<I>()
60                .iter()
61                .map(|&idx| {
62                    let idx: usize = idx.as_();
63                    bools[idx]
64                })
65                .collect::<Vec<bool>>()
66        });
67
68        Ok(ByteBoolArray::from_vec(taken_bools, validity).into_array())
69    }
70}
71
72register_kernel!(TakeKernelAdapter(ByteBoolVTable).lift());
73
#[cfg(test)]
mod tests {
    //! Unit and conformance tests for the `ByteBoolArray` compute kernels.

    use rstest::rstest;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::compute::{Operator, cast, compare};
    use vortex_dtype::{DType, Nullability};

    use super::*;

    #[test]
    fn test_slice() {
        let original = vec![Some(true), Some(true), None, Some(false), None];
        let vortex_arr = ByteBoolArray::from(original);

        let sliced_arr = vortex_arr.slice(1..4);
        let sliced_arr = sliced_arr.as_::<ByteBoolVTable>();

        let s = sliced_arr.scalar_at(0);
        assert_eq!(s.as_bool().value(), Some(true));

        let s = sliced_arr.scalar_at(1);
        assert!(!sliced_arr.is_valid(1));
        assert!(s.is_null());
        assert_eq!(s.as_bool().value(), None);

        let s = sliced_arr.scalar_at(2);
        assert_eq!(s.as_bool().value(), Some(false));
    }

    #[test]
    fn test_compare_all_equal() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let arr = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

        for i in 0..arr.len() {
            let s = arr.scalar_at(i);
            assert!(s.is_valid());
            assert_eq!(s.as_bool().value(), Some(true));
        }
    }

    #[test]
    fn test_compare_all_different() {
        let lhs = ByteBoolArray::from(vec![false; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let arr = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

        for i in 0..arr.len() {
            let s = arr.scalar_at(i);
            assert!(s.is_valid());
            assert_eq!(s.as_bool().value(), Some(false));
        }
    }

    #[test]
    fn test_compare_with_nulls() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let arr = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

        for i in 0..3 {
            let s = arr.scalar_at(i);
            assert!(s.is_valid());
            assert_eq!(s.as_bool().value(), Some(true));
        }

        let s = arr.scalar_at(3);
        assert!(s.is_valid());
        assert_eq!(s.as_bool().value(), Some(false));

        let s = arr.scalar_at(4);
        assert!(s.is_null());
    }

    #[test]
    fn test_mask_byte_bool() {
        test_mask_conformance(ByteBoolArray::from(vec![true, false, true, true, false]).as_ref());
        test_mask_conformance(
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]).as_ref(),
        );
    }

    #[test]
    fn test_filter_byte_bool() {
        test_filter_conformance(ByteBoolArray::from(vec![true, false, true, true, false]).as_ref());
        test_filter_conformance(
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]).as_ref(),
        );
    }

    #[rstest]
    #[case(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case(ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]))]
    #[case(ByteBoolArray::from(vec![true, false]))]
    #[case(ByteBoolArray::from(vec![true]))]
    fn test_take_byte_bool_conformance(#[case] array: ByteBoolArray) {
        test_take_conformance(array.as_ref());
    }

    #[test]
    fn test_cast_bytebool_to_nullable() {
        let values = [true, false, true, false];
        let array = ByteBoolArray::from(values.to_vec());
        let casted = cast(array.as_ref(), &DType::Bool(Nullability::Nullable)).unwrap();
        assert_eq!(casted.dtype(), &DType::Bool(Nullability::Nullable));
        assert_eq!(casted.len(), values.len());

        // Widening nullability must not alter any value or introduce nulls.
        for (i, expected) in values.into_iter().enumerate() {
            let s = casted.scalar_at(i);
            assert!(s.is_valid());
            assert_eq!(s.as_bool().value(), Some(expected));
        }
    }

    #[rstest]
    #[case(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case(ByteBoolArray::from(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case(ByteBoolArray::from(vec![false]))]
    #[case(ByteBoolArray::from(vec![true]))]
    #[case(ByteBoolArray::from(vec![Some(true), None]))]
    fn test_cast_bytebool_conformance(#[case] array: ByteBoolArray) {
        test_cast_conformance(array.as_ref());
    }

    #[rstest]
    #[case::non_nullable(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case::nullable(ByteBoolArray::from(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case::all_true(ByteBoolArray::from(vec![true, true, true, true]))]
    #[case::all_false(ByteBoolArray::from(vec![false, false, false, false]))]
    #[case::single_true(ByteBoolArray::from(vec![true]))]
    #[case::single_false(ByteBoolArray::from(vec![false]))]
    #[case::single_null(ByteBoolArray::from(vec![None]))]
    #[case::mixed_with_nulls(ByteBoolArray::from(vec![Some(true), None, Some(false), None, Some(true)]))]
    fn test_bytebool_consistency(#[case] array: ByteBoolArray) {
        test_array_consistency(array.as_ref());
    }
}