use num_traits::AsPrimitive;
use vortex_array::Array;
use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::ToCanonical;
use vortex_array::arrays::TakeExecute;
use vortex_array::dtype::DType;
use vortex_array::match_each_integer_ptype;
use vortex_array::scalar_fn::fns::cast::CastReduce;
use vortex_array::scalar_fn::fns::mask::MaskReduce;
use vortex_array::validity::Validity;
use vortex_array::vtable::ValidityHelper;
use vortex_error::VortexResult;
use vortex_error::vortex_err;
18
19use super::ByteBoolArray;
20use super::ByteBoolVTable;
21
22impl CastReduce for ByteBoolVTable {
23 fn cast(array: &ByteBoolArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
24 if array.dtype().eq_ignore_nullability(dtype) {
30 let new_validity = array
31 .validity()
32 .clone()
33 .cast_nullability(dtype.nullability(), array.len())?;
34
35 return Ok(Some(
36 ByteBoolArray::new(array.buffer().clone(), new_validity).into_array(),
37 ));
38 }
39
40 Ok(None)
42 }
43}
44
45impl MaskReduce for ByteBoolVTable {
46 fn mask(array: &ByteBoolArray, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
47 Ok(Some(
48 ByteBoolArray::new(
49 array.buffer().clone(),
50 array
51 .validity()
52 .clone()
53 .and(Validity::Array(mask.clone()))?,
54 )
55 .into_array(),
56 ))
57 }
58}
59
60impl TakeExecute for ByteBoolVTable {
61 fn take(
62 array: &ByteBoolArray,
63 indices: &dyn Array,
64 _ctx: &mut ExecutionCtx,
65 ) -> VortexResult<Option<ArrayRef>> {
66 let indices = indices.to_primitive();
67 let bools = array.as_slice();
68
69 let validity = array.validity().take(indices.as_ref())?;
71
72 let taken_bools = match_each_integer_ptype!(indices.ptype(), |I| {
73 indices
74 .as_slice::<I>()
75 .iter()
76 .map(|&idx| {
77 let idx: usize = idx.as_();
78 bools[idx]
79 })
80 .collect::<Vec<bool>>()
81 });
82
83 Ok(Some(
84 ByteBoolArray::from_vec(taken_bools, validity).into_array(),
85 ))
86 }
87}
88
89#[cfg(test)]
90mod tests {
91 use rstest::rstest;
92 use vortex_array::assert_arrays_eq;
93 use vortex_array::builtins::ArrayBuiltins;
94 use vortex_array::compute::conformance::cast::test_cast_conformance;
95 use vortex_array::compute::conformance::consistency::test_array_consistency;
96 use vortex_array::compute::conformance::filter::test_filter_conformance;
97 use vortex_array::compute::conformance::mask::test_mask_conformance;
98 use vortex_array::compute::conformance::take::test_take_conformance;
99 use vortex_array::dtype::DType;
100 use vortex_array::dtype::Nullability;
101 use vortex_array::scalar_fn::fns::operators::Operator;
102
103 use super::*;
104
105 #[test]
106 fn test_slice() {
107 let original = vec![Some(true), Some(true), None, Some(false), None];
108 let vortex_arr = ByteBoolArray::from(original);
109
110 let sliced_arr = vortex_arr.slice(1..4).unwrap();
111
112 let expected = ByteBoolArray::from(vec![Some(true), None, Some(false)]);
113 assert_arrays_eq!(sliced_arr, expected.to_array());
114 }
115
116 #[test]
117 fn test_compare_all_equal() {
118 let lhs = ByteBoolArray::from(vec![true; 5]);
119 let rhs = ByteBoolArray::from(vec![true; 5]);
120
121 let arr = lhs.to_array().binary(rhs.to_array(), Operator::Eq).unwrap();
122
123 let expected = ByteBoolArray::from(vec![true; 5]);
124 assert_arrays_eq!(arr, expected.to_array());
125 }
126
127 #[test]
128 fn test_compare_all_different() {
129 let lhs = ByteBoolArray::from(vec![false; 5]);
130 let rhs = ByteBoolArray::from(vec![true; 5]);
131
132 let arr = lhs.to_array().binary(rhs.to_array(), Operator::Eq).unwrap();
133
134 let expected = ByteBoolArray::from(vec![false; 5]);
135 assert_arrays_eq!(arr, expected.to_array());
136 }
137
138 #[test]
139 fn test_compare_with_nulls() {
140 let lhs = ByteBoolArray::from(vec![true; 5]);
141 let rhs = ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);
142
143 let arr = lhs.to_array().binary(rhs.to_array(), Operator::Eq).unwrap();
144
145 let expected =
146 ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);
147 assert_arrays_eq!(arr, expected.to_array());
148 }
149
150 #[test]
151 fn test_mask_byte_bool() {
152 test_mask_conformance(ByteBoolArray::from(vec![true, false, true, true, false]).as_ref());
153 test_mask_conformance(
154 ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]).as_ref(),
155 );
156 }
157
158 #[test]
159 fn test_filter_byte_bool() {
160 test_filter_conformance(ByteBoolArray::from(vec![true, false, true, true, false]).as_ref());
161 test_filter_conformance(
162 ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]).as_ref(),
163 );
164 }
165
166 #[rstest]
167 #[case(ByteBoolArray::from(vec![true, false, true, true, false]))]
168 #[case(ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]))]
169 #[case(ByteBoolArray::from(vec![true, false]))]
170 #[case(ByteBoolArray::from(vec![true]))]
171 fn test_take_byte_bool_conformance(#[case] array: ByteBoolArray) {
172 test_take_conformance(array.as_ref());
173 }
174
175 #[test]
176 fn test_cast_bytebool_to_nullable() {
177 let array = ByteBoolArray::from(vec![true, false, true, false]);
178 let casted = array
179 .to_array()
180 .cast(DType::Bool(Nullability::Nullable))
181 .unwrap();
182 assert_eq!(casted.dtype(), &DType::Bool(Nullability::Nullable));
183 assert_eq!(casted.len(), 4);
184 }
185
186 #[rstest]
187 #[case(ByteBoolArray::from(vec![true, false, true, true, false]))]
188 #[case(ByteBoolArray::from(vec![Some(true), Some(false), None, Some(true), None]))]
189 #[case(ByteBoolArray::from(vec![false]))]
190 #[case(ByteBoolArray::from(vec![true]))]
191 #[case(ByteBoolArray::from(vec![Some(true), None]))]
192 fn test_cast_bytebool_conformance(#[case] array: ByteBoolArray) {
193 test_cast_conformance(array.as_ref());
194 }
195
196 #[rstest]
197 #[case::non_nullable(ByteBoolArray::from(vec![true, false, true, true, false]))]
198 #[case::nullable(ByteBoolArray::from(vec![Some(true), Some(false), None, Some(true), None]))]
199 #[case::all_true(ByteBoolArray::from(vec![true, true, true, true]))]
200 #[case::all_false(ByteBoolArray::from(vec![false, false, false, false]))]
201 #[case::single_true(ByteBoolArray::from(vec![true]))]
202 #[case::single_false(ByteBoolArray::from(vec![false]))]
203 #[case::single_null(ByteBoolArray::from(vec![None]))]
204 #[case::mixed_with_nulls(ByteBoolArray::from(vec![Some(true), None, Some(false), None, Some(true)]))]
205 fn test_bytebool_consistency(#[case] array: ByteBoolArray) {
206 test_array_consistency(array.as_ref());
207 }
208}