1use num_traits::AsPrimitive;
5use vortex_array::ArrayRef;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::arrays::PrimitiveArray;
9use vortex_array::arrays::dict::TakeExecute;
10use vortex_array::dtype::DType;
11use vortex_array::match_each_integer_ptype;
12use vortex_array::scalar_fn::fns::cast::CastReduce;
13use vortex_array::scalar_fn::fns::mask::MaskReduce;
14use vortex_array::validity::Validity;
15use vortex_array::vtable::ValidityHelper;
16use vortex_error::VortexResult;
17
18use super::ByteBoolArray;
19use super::ByteBoolVTable;
20
21impl CastReduce for ByteBoolVTable {
22 fn cast(array: &ByteBoolArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
23 if array.dtype().eq_ignore_nullability(dtype) {
29 let new_validity = array
30 .validity()
31 .clone()
32 .cast_nullability(dtype.nullability(), array.len())?;
33
34 return Ok(Some(
35 ByteBoolArray::new(array.buffer().clone(), new_validity).into_array(),
36 ));
37 }
38
39 Ok(None)
41 }
42}
43
44impl MaskReduce for ByteBoolVTable {
45 fn mask(array: &ByteBoolArray, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
46 Ok(Some(
47 ByteBoolArray::new(
48 array.buffer().clone(),
49 array
50 .validity()
51 .clone()
52 .and(Validity::Array(mask.clone()))?,
53 )
54 .into_array(),
55 ))
56 }
57}
58
59impl TakeExecute for ByteBoolVTable {
60 fn take(
61 array: &ByteBoolArray,
62 indices: &ArrayRef,
63 ctx: &mut ExecutionCtx,
64 ) -> VortexResult<Option<ArrayRef>> {
65 let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
66 let bools = array.as_slice();
67
68 let validity = array.validity().take(&indices.clone().into_array())?;
70
71 let taken_bools = match_each_integer_ptype!(indices.ptype(), |I| {
72 indices
73 .as_slice::<I>()
74 .iter()
75 .map(|&idx| {
76 let idx: usize = idx.as_();
77 bools[idx]
78 })
79 .collect::<Vec<bool>>()
80 });
81
82 Ok(Some(
83 ByteBoolArray::from_vec(taken_bools, validity).into_array(),
84 ))
85 }
86}
87
#[cfg(test)]
mod tests {
    // Unit and conformance tests for the `ByteBoolArray` kernels defined in the parent module
    // (cast, mask, take), plus slice, comparison, filter, and general consistency checks.
    use rstest::rstest;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use super::*;

    #[test]
    fn test_slice() {
        let original = vec![Some(true), Some(true), None, Some(false), None];
        let vortex_arr = ByteBoolArray::from(original);

        let sliced_arr = vortex_arr.slice(1..4).unwrap();

        let expected = ByteBoolArray::from(vec![Some(true), None, Some(false)]);
        assert_arrays_eq!(sliced_arr, expected.into_array());
    }

    #[test]
    fn test_compare_all_equal() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let arr = lhs
            .into_array()
            .binary(rhs.into_array(), Operator::Eq)
            .unwrap();

        let expected = ByteBoolArray::from(vec![true; 5]);
        assert_arrays_eq!(arr, expected.into_array());
    }

    #[test]
    fn test_compare_all_different() {
        let lhs = ByteBoolArray::from(vec![false; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let arr = lhs
            .into_array()
            .binary(rhs.into_array(), Operator::Eq)
            .unwrap();

        let expected = ByteBoolArray::from(vec![false; 5]);
        assert_arrays_eq!(arr, expected.into_array());
    }

    #[test]
    fn test_compare_with_nulls() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let arr = lhs
            .into_array()
            .binary(rhs.into_array(), Operator::Eq)
            .unwrap();

        let expected =
            ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);
        assert_arrays_eq!(arr, expected.into_array());
    }

    #[test]
    fn test_mask_byte_bool() {
        test_mask_conformance(
            &ByteBoolArray::from(vec![true, false, true, true, false]).into_array(),
        );
        test_mask_conformance(
            &ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None])
                .into_array(),
        );
    }

    #[test]
    fn test_filter_byte_bool() {
        test_filter_conformance(
            &ByteBoolArray::from(vec![true, false, true, true, false]).into_array(),
        );
        test_filter_conformance(
            &ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None])
                .into_array(),
        );
    }

    #[rstest]
    #[case(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case(ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]))]
    #[case(ByteBoolArray::from(vec![true, false]))]
    #[case(ByteBoolArray::from(vec![true]))]
    fn test_take_byte_bool_conformance(#[case] array: ByteBoolArray) {
        test_take_conformance(&array.into_array());
    }

    #[test]
    fn test_cast_bytebool_to_nullable() {
        let array = ByteBoolArray::from(vec![true, false, true, false]);
        let casted = array
            .into_array()
            .cast(DType::Bool(Nullability::Nullable))
            .unwrap();
        assert_eq!(casted.dtype(), &DType::Bool(Nullability::Nullable));
        assert_eq!(casted.len(), 4);

        // Widening nullability must not alter the data: every element stays valid and keeps
        // its original boolean value.
        let expected =
            ByteBoolArray::from(vec![Some(true), Some(false), Some(true), Some(false)]);
        assert_arrays_eq!(casted, expected.into_array());
    }

    #[rstest]
    #[case(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case(ByteBoolArray::from(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case(ByteBoolArray::from(vec![false]))]
    #[case(ByteBoolArray::from(vec![true]))]
    #[case(ByteBoolArray::from(vec![Some(true), None]))]
    fn test_cast_bytebool_conformance(#[case] array: ByteBoolArray) {
        test_cast_conformance(&array.into_array());
    }

    #[rstest]
    #[case::non_nullable(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case::nullable(ByteBoolArray::from(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case::all_true(ByteBoolArray::from(vec![true, true, true, true]))]
    #[case::all_false(ByteBoolArray::from(vec![false, false, false, false]))]
    #[case::single_true(ByteBoolArray::from(vec![true]))]
    #[case::single_false(ByteBoolArray::from(vec![false]))]
    #[case::single_null(ByteBoolArray::from(vec![None]))]
    #[case::mixed_with_nulls(ByteBoolArray::from(vec![Some(true), None, Some(false), None, Some(true)]))]
    fn test_bytebool_consistency(#[case] array: ByteBoolArray) {
        test_array_consistency(&array.into_array());
    }
}