Skip to main content

vortex_bytebool/
compute.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::AsPrimitive;
5use vortex_array::ArrayRef;
6use vortex_array::ArrayView;
7use vortex_array::ExecutionCtx;
8use vortex_array::IntoArray;
9use vortex_array::arrays::PrimitiveArray;
10use vortex_array::arrays::dict::TakeExecute;
11use vortex_array::buffer::BufferHandle;
12use vortex_array::dtype::DType;
13use vortex_array::match_each_integer_ptype;
14use vortex_array::scalar_fn::fns::cast::CastKernel;
15use vortex_array::scalar_fn::fns::cast::CastReduce;
16use vortex_array::scalar_fn::fns::mask::MaskReduce;
17use vortex_array::validity::Validity;
18use vortex_buffer::ByteBuffer;
19use vortex_error::VortexResult;
20
21use super::ByteBool;
22
23impl CastReduce for ByteBool {
24    fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
25        // ByteBool is essentially a bool array stored as bytes
26        // The main difference from BoolArray is the storage format
27        // For casting, we can decode to canonical (BoolArray) and let it handle the cast
28        // If just changing nullability, we can optimize
29        if !dtype.is_boolean() {
30            return Ok(None);
31        }
32
33        let Some(new_validity) = array
34            .validity()?
35            .trivially_cast_nullability(dtype.nullability(), array.len())?
36        else {
37            return Ok(None);
38        };
39
40        Ok(Some(
41            ByteBool::new(array.buffer().clone(), new_validity).into_array(),
42        ))
43    }
44}
45
46impl CastKernel for ByteBool {
47    fn cast(
48        array: ArrayView<'_, Self>,
49        dtype: &DType,
50        ctx: &mut ExecutionCtx,
51    ) -> VortexResult<Option<ArrayRef>> {
52        // Only handle nullability changes here; non-bool targets fall through to canonicalization.
53        if !dtype.is_boolean() {
54            return Ok(None);
55        }
56
57        let new_validity =
58            array
59                .validity()?
60                .cast_nullability(dtype.nullability(), array.len(), ctx)?;
61
62        Ok(Some(
63            ByteBool::new(array.buffer().clone(), new_validity).into_array(),
64        ))
65    }
66}
67
68impl MaskReduce for ByteBool {
69    fn mask(array: ArrayView<'_, Self>, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
70        Ok(Some(
71            ByteBool::new(
72                array.buffer().clone(),
73                array.validity()?.and(Validity::Array(mask.clone()))?,
74            )
75            .into_array(),
76        ))
77    }
78}
79
80impl TakeExecute for ByteBool {
81    fn take(
82        array: ArrayView<'_, Self>,
83        indices: &ArrayRef,
84        ctx: &mut ExecutionCtx,
85    ) -> VortexResult<Option<ArrayRef>> {
86        let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
87        let values = array.truthy_bytes();
88
89        // This handles combining validity from both source array and nullable indices
90        let validity = array.validity()?.take(&indices.clone().into_array())?;
91
92        let taken = match_each_integer_ptype!(indices.ptype(), |I| {
93            indices
94                .as_slice::<I>()
95                .iter()
96                .map(|&idx| {
97                    let idx: usize = idx.as_();
98                    values[idx]
99                })
100                .collect::<ByteBuffer>()
101        });
102
103        Ok(Some(
104            ByteBool::new(BufferHandle::new_host(taken), validity).into_array(),
105        ))
106    }
107}
108
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use super::*;
    use crate::ByteBoolArray;

    /// Builds a `ByteBoolArray` from plain booleans with `Validity::AllValid`.
    fn bb(values: Vec<bool>) -> ByteBoolArray {
        ByteBool::from_vec(values, Validity::AllValid)
    }

    /// Builds a `ByteBoolArray` from optional booleans via `ByteBool::from_option_vec`.
    fn bb_opt(values: Vec<Option<bool>>) -> ByteBoolArray {
        ByteBool::from_option_vec(values)
    }

    /// Slicing `1..4` keeps the middle three elements, including the null.
    #[test]
    fn test_slice() {
        let source = bb_opt(vec![Some(true), Some(true), None, Some(false), None]);

        let window = source.slice(1..4).unwrap();

        let expected = bb_opt(vec![Some(true), None, Some(false)]);
        assert_arrays_eq!(window, expected.into_array());
    }

    /// Comparing two all-true arrays for equality yields all true.
    #[test]
    fn test_compare_all_equal() {
        let left = bb(vec![true; 5]).into_array();
        let right = bb(vec![true; 5]).into_array();

        let outcome = left.binary(right, Operator::Eq).unwrap();

        let expected = bb(vec![true; 5]);
        assert_arrays_eq!(outcome, expected.into_array());
    }

    /// Comparing all-false against all-true for equality yields all false.
    #[test]
    fn test_compare_all_different() {
        let left = bb(vec![false; 5]).into_array();
        let right = bb(vec![true; 5]).into_array();

        let outcome = left.binary(right, Operator::Eq).unwrap();

        let expected = bb(vec![false; 5]);
        assert_arrays_eq!(outcome, expected.into_array());
    }

    /// A null on the right-hand side produces a null in the comparison result.
    #[test]
    fn test_compare_with_nulls() {
        let left = bb(vec![true; 5]).into_array();
        let right =
            bb_opt(vec![Some(true), Some(true), Some(true), Some(false), None]).into_array();

        let outcome = left.binary(right, Operator::Eq).unwrap();

        let expected = bb_opt(vec![Some(true), Some(true), Some(true), Some(false), None]);
        assert_arrays_eq!(outcome, expected.into_array());
    }

    /// Runs the mask conformance suite on a non-nullable and a nullable input.
    #[test]
    fn test_mask_byte_bool() {
        let inputs = [
            bb(vec![true, false, true, true, false]).into_array(),
            bb_opt(vec![Some(true), Some(true), None, Some(false), None]).into_array(),
        ];
        for input in &inputs {
            test_mask_conformance(input);
        }
    }

    /// Runs the filter conformance suite on a non-nullable and a nullable input.
    #[test]
    fn test_filter_byte_bool() {
        let inputs = [
            bb(vec![true, false, true, true, false]).into_array(),
            bb_opt(vec![Some(true), Some(true), None, Some(false), None]).into_array(),
        ];
        for input in &inputs {
            test_filter_conformance(input);
        }
    }

    /// Runs the take conformance suite over several array shapes.
    #[rstest]
    #[case(bb(vec![true, false, true, true, false]))]
    #[case(bb_opt(vec![Some(true), Some(true), None, Some(false), None]))]
    #[case(bb(vec![true, false]))]
    #[case(bb(vec![true]))]
    fn test_take_byte_bool_conformance(#[case] array: ByteBoolArray) {
        let input = array.into_array();
        test_take_conformance(&input);
    }

    /// Casting a non-nullable array to nullable bool keeps the length and updates the dtype.
    #[test]
    fn test_cast_bytebool_to_nullable() {
        let target = DType::Bool(Nullability::Nullable);

        let casted = bb(vec![true, false, true, false])
            .into_array()
            .cast(target.clone())
            .unwrap();

        assert_eq!(casted.dtype(), &target);
        assert_eq!(casted.len(), 4);
    }

    /// Runs the cast conformance suite over nullable and non-nullable inputs.
    #[rstest]
    #[case(bb(vec![true, false, true, true, false]))]
    #[case(bb_opt(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case(bb(vec![false]))]
    #[case(bb(vec![true]))]
    #[case(bb_opt(vec![Some(true), None]))]
    fn test_cast_bytebool_conformance(#[case] array: ByteBoolArray) {
        let input = array.into_array();
        test_cast_conformance(&input);
    }

    /// Runs the general array consistency suite over a range of value/null patterns.
    #[rstest]
    #[case::non_nullable(bb(vec![true, false, true, true, false]))]
    #[case::nullable(bb_opt(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case::all_true(bb(vec![true, true, true, true]))]
    #[case::all_false(bb(vec![false, false, false, false]))]
    #[case::single_true(bb(vec![true]))]
    #[case::single_false(bb(vec![false]))]
    #[case::single_null(bb_opt(vec![None]))]
    #[case::mixed_with_nulls(bb_opt(vec![Some(true), None, Some(false), None, Some(true)]))]
    fn test_bytebool_consistency(#[case] array: ByteBoolArray) {
        let input = array.into_array();
        test_array_consistency(&input);
    }
}