Skip to main content

vortex_bytebool/
compute.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::AsPrimitive;
5use vortex_array::Array;
6use vortex_array::ArrayRef;
7use vortex_array::ExecutionCtx;
8use vortex_array::IntoArray;
9use vortex_array::ToCanonical;
10use vortex_array::arrays::TakeExecute;
11use vortex_array::dtype::DType;
12use vortex_array::match_each_integer_ptype;
13use vortex_array::scalar_fn::fns::cast::CastReduce;
14use vortex_array::scalar_fn::fns::mask::MaskReduce;
15use vortex_array::validity::Validity;
16use vortex_array::vtable::ValidityHelper;
17use vortex_error::VortexResult;
18
19use super::ByteBoolArray;
20use super::ByteBoolVTable;
21
22impl CastReduce for ByteBoolVTable {
23    fn cast(array: &ByteBoolArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
24        // ByteBool is essentially a bool array stored as bytes
25        // The main difference from BoolArray is the storage format
26        // For casting, we can decode to canonical (BoolArray) and let it handle the cast
27
28        // If just changing nullability, we can optimize
29        if array.dtype().eq_ignore_nullability(dtype) {
30            let new_validity = array
31                .validity()
32                .clone()
33                .cast_nullability(dtype.nullability(), array.len())?;
34
35            return Ok(Some(
36                ByteBoolArray::new(array.buffer().clone(), new_validity).into_array(),
37            ));
38        }
39
40        // For other casts, decode to canonical and let BoolArray handle it
41        Ok(None)
42    }
43}
44
45impl MaskReduce for ByteBoolVTable {
46    fn mask(array: &ByteBoolArray, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
47        Ok(Some(
48            ByteBoolArray::new(
49                array.buffer().clone(),
50                array
51                    .validity()
52                    .clone()
53                    .and(Validity::Array(mask.clone()))?,
54            )
55            .into_array(),
56        ))
57    }
58}
59
60impl TakeExecute for ByteBoolVTable {
61    fn take(
62        array: &ByteBoolArray,
63        indices: &dyn Array,
64        _ctx: &mut ExecutionCtx,
65    ) -> VortexResult<Option<ArrayRef>> {
66        let indices = indices.to_primitive();
67        let bools = array.as_slice();
68
69        // This handles combining validity from both source array and nullable indices
70        let validity = array.validity().take(indices.as_ref())?;
71
72        let taken_bools = match_each_integer_ptype!(indices.ptype(), |I| {
73            indices
74                .as_slice::<I>()
75                .iter()
76                .map(|&idx| {
77                    let idx: usize = idx.as_();
78                    bools[idx]
79                })
80                .collect::<Vec<bool>>()
81        });
82
83        Ok(Some(
84            ByteBoolArray::from_vec(taken_bools, validity).into_array(),
85        ))
86    }
87}
88
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use super::*;

    #[test]
    fn test_slice() {
        let original = vec![Some(true), Some(true), None, Some(false), None];
        let vortex_arr = ByteBoolArray::from(original);

        let sliced_arr = vortex_arr.slice(1..4).unwrap();

        let expected = ByteBoolArray::from(vec![Some(true), None, Some(false)]);
        assert_arrays_eq!(sliced_arr, expected.to_array());
    }

    #[test]
    fn test_compare_all_equal() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let arr = lhs.to_array().binary(rhs.to_array(), Operator::Eq).unwrap();

        let expected = ByteBoolArray::from(vec![true; 5]);
        assert_arrays_eq!(arr, expected.to_array());
    }

    #[test]
    fn test_compare_all_different() {
        let lhs = ByteBoolArray::from(vec![false; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let arr = lhs.to_array().binary(rhs.to_array(), Operator::Eq).unwrap();

        let expected = ByteBoolArray::from(vec![false; 5]);
        assert_arrays_eq!(arr, expected.to_array());
    }

    #[test]
    fn test_compare_with_nulls() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let arr = lhs.to_array().binary(rhs.to_array(), Operator::Eq).unwrap();

        // A null on either side propagates to a null comparison result.
        let expected =
            ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);
        assert_arrays_eq!(arr, expected.to_array());
    }

    #[test]
    fn test_mask_byte_bool() {
        test_mask_conformance(ByteBoolArray::from(vec![true, false, true, true, false]).as_ref());
        test_mask_conformance(
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]).as_ref(),
        );
    }

    #[test]
    fn test_filter_byte_bool() {
        test_filter_conformance(ByteBoolArray::from(vec![true, false, true, true, false]).as_ref());
        test_filter_conformance(
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]).as_ref(),
        );
    }

    #[rstest]
    #[case(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case(ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]))]
    #[case(ByteBoolArray::from(vec![true, false]))]
    #[case(ByteBoolArray::from(vec![true]))]
    fn test_take_byte_bool_conformance(#[case] array: ByteBoolArray) {
        test_take_conformance(array.as_ref());
    }

    #[test]
    fn test_cast_bytebool_to_nullable() {
        let array = ByteBoolArray::from(vec![true, false, true, false]);
        let casted = array
            .to_array()
            .cast(DType::Bool(Nullability::Nullable))
            .unwrap();
        assert_eq!(casted.dtype(), &DType::Bool(Nullability::Nullable));
        assert_eq!(casted.len(), 4);

        // Checking dtype and length alone would not catch a cast that corrupted
        // the buffer or nulled values: a nullability-only cast must keep every
        // value and leave every slot valid.
        let expected =
            ByteBoolArray::from(vec![Some(true), Some(false), Some(true), Some(false)]);
        assert_arrays_eq!(casted, expected.to_array());
    }

    #[rstest]
    #[case(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case(ByteBoolArray::from(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case(ByteBoolArray::from(vec![false]))]
    #[case(ByteBoolArray::from(vec![true]))]
    #[case(ByteBoolArray::from(vec![Some(true), None]))]
    fn test_cast_bytebool_conformance(#[case] array: ByteBoolArray) {
        test_cast_conformance(array.as_ref());
    }

    #[rstest]
    #[case::non_nullable(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case::nullable(ByteBoolArray::from(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case::all_true(ByteBoolArray::from(vec![true, true, true, true]))]
    #[case::all_false(ByteBoolArray::from(vec![false, false, false, false]))]
    #[case::single_true(ByteBoolArray::from(vec![true]))]
    #[case::single_false(ByteBoolArray::from(vec![false]))]
    #[case::single_null(ByteBoolArray::from(vec![None]))]
    #[case::mixed_with_nulls(ByteBoolArray::from(vec![Some(true), None, Some(false), None, Some(true)]))]
    fn test_bytebool_consistency(#[case] array: ByteBoolArray) {
        test_array_consistency(array.as_ref());
    }
}