Skip to main content

vortex_bytebool/
compute.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::AsPrimitive;
5use vortex_array::ArrayRef;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::arrays::PrimitiveArray;
9use vortex_array::arrays::TakeExecute;
10use vortex_array::dtype::DType;
11use vortex_array::match_each_integer_ptype;
12use vortex_array::scalar_fn::fns::cast::CastReduce;
13use vortex_array::scalar_fn::fns::mask::MaskReduce;
14use vortex_array::validity::Validity;
15use vortex_array::vtable::ValidityHelper;
16use vortex_error::VortexResult;
17
18use super::ByteBoolArray;
19use super::ByteBoolVTable;
20
21impl CastReduce for ByteBoolVTable {
22    fn cast(array: &ByteBoolArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
23        // ByteBool is essentially a bool array stored as bytes
24        // The main difference from BoolArray is the storage format
25        // For casting, we can decode to canonical (BoolArray) and let it handle the cast
26
27        // If just changing nullability, we can optimize
28        if array.dtype().eq_ignore_nullability(dtype) {
29            let new_validity = array
30                .validity()
31                .clone()
32                .cast_nullability(dtype.nullability(), array.len())?;
33
34            return Ok(Some(
35                ByteBoolArray::new(array.buffer().clone(), new_validity).into_array(),
36            ));
37        }
38
39        // For other casts, decode to canonical and let BoolArray handle it
40        Ok(None)
41    }
42}
43
44impl MaskReduce for ByteBoolVTable {
45    fn mask(array: &ByteBoolArray, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
46        Ok(Some(
47            ByteBoolArray::new(
48                array.buffer().clone(),
49                array
50                    .validity()
51                    .clone()
52                    .and(Validity::Array(mask.clone()))?,
53            )
54            .into_array(),
55        ))
56    }
57}
58
59impl TakeExecute for ByteBoolVTable {
60    fn take(
61        array: &ByteBoolArray,
62        indices: &ArrayRef,
63        ctx: &mut ExecutionCtx,
64    ) -> VortexResult<Option<ArrayRef>> {
65        let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
66        let bools = array.as_slice();
67
68        // This handles combining validity from both source array and nullable indices
69        let validity = array.validity().take(&indices.to_array())?;
70
71        let taken_bools = match_each_integer_ptype!(indices.ptype(), |I| {
72            indices
73                .as_slice::<I>()
74                .iter()
75                .map(|&idx| {
76                    let idx: usize = idx.as_();
77                    bools[idx]
78                })
79                .collect::<Vec<bool>>()
80        });
81
82        Ok(Some(
83            ByteBoolArray::from_vec(taken_bools, validity).into_array(),
84        ))
85    }
86}
87
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use super::*;

    /// Slicing `1..4` keeps the middle three elements, including the null.
    #[test]
    fn test_slice() {
        let source = ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);

        let window = source.slice(1..4).unwrap();

        let want = ByteBoolArray::from(vec![Some(true), None, Some(false)]);
        assert_arrays_eq!(window, want.to_array());
    }

    /// `Eq` over two all-true arrays yields all true.
    #[test]
    fn test_compare_all_equal() {
        let left = ByteBoolArray::from(vec![true; 5]).to_array();
        let right = ByteBoolArray::from(vec![true; 5]).to_array();

        let result = left.binary(right, Operator::Eq).unwrap();

        let want = ByteBoolArray::from(vec![true; 5]);
        assert_arrays_eq!(result, want.to_array());
    }

    /// `Eq` over an all-false and an all-true array yields all false.
    #[test]
    fn test_compare_all_different() {
        let left = ByteBoolArray::from(vec![false; 5]).to_array();
        let right = ByteBoolArray::from(vec![true; 5]).to_array();

        let result = left.binary(right, Operator::Eq).unwrap();

        let want = ByteBoolArray::from(vec![false; 5]);
        assert_arrays_eq!(result, want.to_array());
    }

    /// `Eq` against a nullable right-hand side propagates the null.
    #[test]
    fn test_compare_with_nulls() {
        let left = ByteBoolArray::from(vec![true; 5]).to_array();
        let right =
            ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None])
                .to_array();

        let result = left.binary(right, Operator::Eq).unwrap();

        let want = ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);
        assert_arrays_eq!(result, want.to_array());
    }

    /// Runs the shared mask conformance suite on a non-nullable and a nullable array.
    #[test]
    fn test_mask_byte_bool() {
        let fixtures = [
            ByteBoolArray::from(vec![true, false, true, true, false]),
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]),
        ];
        for fixture in fixtures {
            test_mask_conformance(&fixture.to_array());
        }
    }

    /// Runs the shared filter conformance suite on a non-nullable and a nullable array.
    #[test]
    fn test_filter_byte_bool() {
        let fixtures = [
            ByteBoolArray::from(vec![true, false, true, true, false]),
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]),
        ];
        for fixture in fixtures {
            test_filter_conformance(&fixture.to_array());
        }
    }

    /// Runs the shared take conformance suite over several array shapes.
    #[rstest]
    #[case(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case(ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]))]
    #[case(ByteBoolArray::from(vec![true, false]))]
    #[case(ByteBoolArray::from(vec![true]))]
    fn test_take_byte_bool_conformance(#[case] array: ByteBoolArray) {
        let erased = array.to_array();
        test_take_conformance(&erased);
    }

    /// Casting a non-nullable array to nullable bool keeps the length and changes the dtype.
    #[test]
    fn test_cast_bytebool_to_nullable() {
        let target = DType::Bool(Nullability::Nullable);
        let source = ByteBoolArray::from(vec![true, false, true, false]);

        let casted = source.to_array().cast(target.clone()).unwrap();

        assert_eq!(casted.dtype(), &target);
        assert_eq!(casted.len(), 4);
    }

    /// Runs the shared cast conformance suite over several array shapes.
    #[rstest]
    #[case(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case(ByteBoolArray::from(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case(ByteBoolArray::from(vec![false]))]
    #[case(ByteBoolArray::from(vec![true]))]
    #[case(ByteBoolArray::from(vec![Some(true), None]))]
    fn test_cast_bytebool_conformance(#[case] array: ByteBoolArray) {
        let erased = array.to_array();
        test_cast_conformance(&erased);
    }

    /// Runs the shared cross-operation consistency suite over several array shapes.
    #[rstest]
    #[case::non_nullable(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case::nullable(ByteBoolArray::from(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case::all_true(ByteBoolArray::from(vec![true, true, true, true]))]
    #[case::all_false(ByteBoolArray::from(vec![false, false, false, false]))]
    #[case::single_true(ByteBoolArray::from(vec![true]))]
    #[case::single_false(ByteBoolArray::from(vec![false]))]
    #[case::single_null(ByteBoolArray::from(vec![None]))]
    #[case::mixed_with_nulls(ByteBoolArray::from(vec![Some(true), None, Some(false), None, Some(true)]))]
    fn test_bytebool_consistency(#[case] array: ByteBoolArray) {
        let erased = array.to_array();
        test_array_consistency(&erased);
    }
}