Skip to main content

vortex_bytebool/
compute.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::AsPrimitive;
5use vortex_array::ArrayRef;
6use vortex_array::ArrayView;
7use vortex_array::ExecutionCtx;
8use vortex_array::IntoArray;
9use vortex_array::arrays::PrimitiveArray;
10use vortex_array::arrays::dict::TakeExecute;
11use vortex_array::dtype::DType;
12use vortex_array::match_each_integer_ptype;
13use vortex_array::scalar_fn::fns::cast::CastReduce;
14use vortex_array::scalar_fn::fns::mask::MaskReduce;
15use vortex_array::validity::Validity;
16use vortex_error::VortexResult;
17
18use super::ByteBool;
19
20impl CastReduce for ByteBool {
21    fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
22        // ByteBool is essentially a bool array stored as bytes
23        // The main difference from BoolArray is the storage format
24        // For casting, we can decode to canonical (BoolArray) and let it handle the cast
25
26        // If just changing nullability, we can optimize
27        if array.dtype().eq_ignore_nullability(dtype) {
28            let new_validity = array
29                .validity()?
30                .cast_nullability(dtype.nullability(), array.len())?;
31
32            return Ok(Some(
33                ByteBool::new(array.buffer().clone(), new_validity).into_array(),
34            ));
35        }
36
37        // For other casts, decode to canonical and let BoolArray handle it
38        Ok(None)
39    }
40}
41
42impl MaskReduce for ByteBool {
43    fn mask(array: ArrayView<'_, Self>, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
44        Ok(Some(
45            ByteBool::new(
46                array.buffer().clone(),
47                array.validity()?.and(Validity::Array(mask.clone()))?,
48            )
49            .into_array(),
50        ))
51    }
52}
53
54impl TakeExecute for ByteBool {
55    fn take(
56        array: ArrayView<'_, Self>,
57        indices: &ArrayRef,
58        ctx: &mut ExecutionCtx,
59    ) -> VortexResult<Option<ArrayRef>> {
60        let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
61        let bools = array.as_slice();
62
63        // This handles combining validity from both source array and nullable indices
64        let validity = array.validity()?.take(&indices.clone().into_array())?;
65
66        let taken_bools = match_each_integer_ptype!(indices.ptype(), |I| {
67            indices
68                .as_slice::<I>()
69                .iter()
70                .map(|&idx| {
71                    let idx: usize = idx.as_();
72                    bools[idx]
73                })
74                .collect::<Vec<bool>>()
75        });
76
77        Ok(Some(ByteBool::from_vec(taken_bools, validity).into_array()))
78    }
79}
80
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use super::*;
    use crate::ByteBoolArray;

    /// Builds a `ByteBool` array from plain booleans with `Validity::AllValid` (no nulls).
    fn bb(v: Vec<bool>) -> ByteBoolArray {
        ByteBool::from_vec(v, Validity::AllValid)
    }

    /// Builds a `ByteBool` array where `None` entries become nulls.
    fn bb_opt(v: Vec<Option<bool>>) -> ByteBoolArray {
        ByteBool::from_option_vec(v)
    }

    #[test]
    fn test_slice() {
        let original = vec![Some(true), Some(true), None, Some(false), None];
        let vortex_arr = bb_opt(original);

        let sliced_arr = vortex_arr.slice(1..4).unwrap();

        let expected = bb_opt(vec![Some(true), None, Some(false)]);
        assert_arrays_eq!(sliced_arr, expected.into_array());
    }

    #[test]
    fn test_compare_all_equal() {
        let lhs = bb(vec![true; 5]);
        let rhs = bb(vec![true; 5]);

        let arr = lhs
            .into_array()
            .binary(rhs.into_array(), Operator::Eq)
            .unwrap();

        let expected = bb(vec![true; 5]);
        assert_arrays_eq!(arr, expected.into_array());
    }

    #[test]
    fn test_compare_all_different() {
        let lhs = bb(vec![false; 5]);
        let rhs = bb(vec![true; 5]);

        let arr = lhs
            .into_array()
            .binary(rhs.into_array(), Operator::Eq)
            .unwrap();

        let expected = bb(vec![false; 5]);
        assert_arrays_eq!(arr, expected.into_array());
    }

    #[test]
    fn test_compare_with_nulls() {
        let lhs = bb(vec![true; 5]);
        let rhs = bb_opt(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let arr = lhs
            .into_array()
            .binary(rhs.into_array(), Operator::Eq)
            .unwrap();

        let expected = bb_opt(vec![Some(true), Some(true), Some(true), Some(false), None]);
        assert_arrays_eq!(arr, expected.into_array());
    }

    #[test]
    fn test_mask_byte_bool() {
        test_mask_conformance(&bb(vec![true, false, true, true, false]).into_array());
        test_mask_conformance(
            &bb_opt(vec![Some(true), Some(true), None, Some(false), None]).into_array(),
        );
    }

    #[test]
    fn test_filter_byte_bool() {
        test_filter_conformance(&bb(vec![true, false, true, true, false]).into_array());
        test_filter_conformance(
            &bb_opt(vec![Some(true), Some(true), None, Some(false), None]).into_array(),
        );
    }

    #[rstest]
    #[case(bb(vec![true, false, true, true, false]))]
    #[case(bb_opt(vec![Some(true), Some(true), None, Some(false), None]))]
    #[case(bb(vec![true, false]))]
    #[case(bb(vec![true]))]
    fn test_take_byte_bool_conformance(#[case] array: ByteBoolArray) {
        test_take_conformance(&array.into_array());
    }

    #[test]
    fn test_cast_bytebool_to_nullable() {
        // Build the source with `Validity::NonNullable` (rather than via `bb`, which uses
        // `Validity::AllValid`) so the cast really changes nullability and exercises the
        // validity-conversion branch of `CastReduce`.
        let source = ByteBool::from_vec(vec![true, false, true, false], Validity::NonNullable)
            .into_array();
        assert_eq!(source.dtype(), &DType::Bool(Nullability::NonNullable));

        let casted = source.cast(DType::Bool(Nullability::Nullable)).unwrap();
        assert_eq!(casted.dtype(), &DType::Bool(Nullability::Nullable));
        assert_eq!(casted.len(), 4);

        // Values must survive the cast unchanged, with no nulls introduced.
        assert_arrays_eq!(casted, bb(vec![true, false, true, false]).into_array());
    }

    #[rstest]
    #[case(bb(vec![true, false, true, true, false]))]
    #[case(bb_opt(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case(bb(vec![false]))]
    #[case(bb(vec![true]))]
    #[case(bb_opt(vec![Some(true), None]))]
    fn test_cast_bytebool_conformance(#[case] array: ByteBoolArray) {
        test_cast_conformance(&array.into_array());
    }

    #[rstest]
    #[case::non_nullable(bb(vec![true, false, true, true, false]))]
    #[case::nullable(bb_opt(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case::all_true(bb(vec![true, true, true, true]))]
    #[case::all_false(bb(vec![false, false, false, false]))]
    #[case::single_true(bb(vec![true]))]
    #[case::single_false(bb(vec![false]))]
    #[case::single_null(bb_opt(vec![None]))]
    #[case::mixed_with_nulls(bb_opt(vec![Some(true), None, Some(false), None, Some(true)]))]
    fn test_bytebool_consistency(#[case] array: ByteBoolArray) {
        test_array_consistency(&array.into_array());
    }
}