1use num_traits::AsPrimitive;
5use vortex_array::ArrayRef;
6use vortex_array::ArrayView;
7use vortex_array::ExecutionCtx;
8use vortex_array::IntoArray;
9use vortex_array::arrays::PrimitiveArray;
10use vortex_array::arrays::dict::TakeExecute;
11use vortex_array::dtype::DType;
12use vortex_array::match_each_integer_ptype;
13use vortex_array::scalar_fn::fns::cast::CastReduce;
14use vortex_array::scalar_fn::fns::mask::MaskReduce;
15use vortex_array::validity::Validity;
16use vortex_error::VortexResult;
17
18use super::ByteBool;
19
20impl CastReduce for ByteBool {
21 fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
22 if array.dtype().eq_ignore_nullability(dtype) {
28 let new_validity = array
29 .validity()?
30 .cast_nullability(dtype.nullability(), array.len())?;
31
32 return Ok(Some(
33 ByteBool::new(array.buffer().clone(), new_validity).into_array(),
34 ));
35 }
36
37 Ok(None)
39 }
40}
41
42impl MaskReduce for ByteBool {
43 fn mask(array: ArrayView<'_, Self>, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
44 Ok(Some(
45 ByteBool::new(
46 array.buffer().clone(),
47 array.validity()?.and(Validity::Array(mask.clone()))?,
48 )
49 .into_array(),
50 ))
51 }
52}
53
54impl TakeExecute for ByteBool {
55 fn take(
56 array: ArrayView<'_, Self>,
57 indices: &ArrayRef,
58 ctx: &mut ExecutionCtx,
59 ) -> VortexResult<Option<ArrayRef>> {
60 let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
61 let bools = array.as_slice();
62
63 let validity = array.validity()?.take(&indices.clone().into_array())?;
65
66 let taken_bools = match_each_integer_ptype!(indices.ptype(), |I| {
67 indices
68 .as_slice::<I>()
69 .iter()
70 .map(|&idx| {
71 let idx: usize = idx.as_();
72 bools[idx]
73 })
74 .collect::<Vec<bool>>()
75 });
76
77 Ok(Some(ByteBool::from_vec(taken_bools, validity).into_array()))
78 }
79}
80
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use super::*;
    use crate::ByteBoolArray;

    /// Builds a ByteBool array with `Validity::AllValid`: every element is valid, but the
    /// dtype is nullable (pinned by `test_helper_nullability`).
    fn bb(v: Vec<bool>) -> ByteBoolArray {
        ByteBool::from_vec(v, Validity::AllValid)
    }

    /// Builds a ByteBool array with `Validity::NonNullable`, i.e. a non-nullable dtype.
    fn bb_nn(v: Vec<bool>) -> ByteBoolArray {
        ByteBool::from_vec(v, Validity::NonNullable)
    }

    /// Builds a nullable ByteBool array where `None` entries are null.
    fn bb_opt(v: Vec<Option<bool>>) -> ByteBoolArray {
        ByteBool::from_option_vec(v)
    }

    #[test]
    fn test_helper_nullability() {
        // Guards the helpers' intent: `bb` is nullable-all-valid, `bb_nn` is non-nullable.
        assert_eq!(
            bb(vec![true]).into_array().dtype(),
            &DType::Bool(Nullability::Nullable)
        );
        assert_eq!(
            bb_nn(vec![true]).into_array().dtype(),
            &DType::Bool(Nullability::NonNullable)
        );
    }

    #[test]
    fn test_slice() {
        let original = vec![Some(true), Some(true), None, Some(false), None];
        let vortex_arr = bb_opt(original);

        let sliced_arr = vortex_arr.slice(1..4).unwrap();

        let expected = bb_opt(vec![Some(true), None, Some(false)]);
        assert_arrays_eq!(sliced_arr, expected.into_array());
    }

    #[test]
    fn test_compare_all_equal() {
        let lhs = bb(vec![true; 5]);
        let rhs = bb(vec![true; 5]);

        let arr = lhs
            .into_array()
            .binary(rhs.into_array(), Operator::Eq)
            .unwrap();

        let expected = bb(vec![true; 5]);
        assert_arrays_eq!(arr, expected.into_array());
    }

    #[test]
    fn test_compare_all_different() {
        let lhs = bb(vec![false; 5]);
        let rhs = bb(vec![true; 5]);

        let arr = lhs
            .into_array()
            .binary(rhs.into_array(), Operator::Eq)
            .unwrap();

        let expected = bb(vec![false; 5]);
        assert_arrays_eq!(arr, expected.into_array());
    }

    #[test]
    fn test_compare_with_nulls() {
        let lhs = bb(vec![true; 5]);
        let rhs = bb_opt(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let arr = lhs
            .into_array()
            .binary(rhs.into_array(), Operator::Eq)
            .unwrap();

        let expected = bb_opt(vec![Some(true), Some(true), Some(true), Some(false), None]);
        assert_arrays_eq!(arr, expected.into_array());
    }

    #[test]
    fn test_mask_byte_bool() {
        test_mask_conformance(&bb(vec![true, false, true, true, false]).into_array());
        test_mask_conformance(&bb_nn(vec![true, false, true, true, false]).into_array());
        test_mask_conformance(
            &bb_opt(vec![Some(true), Some(true), None, Some(false), None]).into_array(),
        );
    }

    #[test]
    fn test_filter_byte_bool() {
        test_filter_conformance(&bb(vec![true, false, true, true, false]).into_array());
        test_filter_conformance(&bb_nn(vec![true, false, true, true, false]).into_array());
        test_filter_conformance(
            &bb_opt(vec![Some(true), Some(true), None, Some(false), None]).into_array(),
        );
    }

    #[rstest]
    #[case(bb(vec![true, false, true, true, false]))]
    #[case(bb_nn(vec![true, false, true, true, false]))]
    #[case(bb_opt(vec![Some(true), Some(true), None, Some(false), None]))]
    #[case(bb(vec![true, false]))]
    #[case(bb(vec![true]))]
    fn test_take_byte_bool_conformance(#[case] array: ByteBoolArray) {
        test_take_conformance(&array.into_array());
    }

    #[test]
    fn test_cast_bytebool_to_nullable() {
        // Start from a genuinely non-nullable array so the cast actually widens nullability.
        let array = bb_nn(vec![true, false, true, false]).into_array();
        assert_eq!(array.dtype(), &DType::Bool(Nullability::NonNullable));

        let casted = array.cast(DType::Bool(Nullability::Nullable)).unwrap();
        assert_eq!(casted.dtype(), &DType::Bool(Nullability::Nullable));
        assert_eq!(casted.len(), 4);
        assert_arrays_eq!(casted, bb(vec![true, false, true, false]).into_array());
    }

    #[rstest]
    #[case(bb(vec![true, false, true, true, false]))]
    #[case(bb_nn(vec![true, false, true, true, false]))]
    #[case(bb_opt(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case(bb(vec![false]))]
    #[case(bb(vec![true]))]
    #[case(bb_opt(vec![Some(true), None]))]
    fn test_cast_bytebool_conformance(#[case] array: ByteBoolArray) {
        test_cast_conformance(&array.into_array());
    }

    #[rstest]
    #[case::non_nullable(bb_nn(vec![true, false, true, true, false]))]
    #[case::all_valid(bb(vec![true, false, true, true, false]))]
    #[case::nullable(bb_opt(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case::all_true(bb(vec![true, true, true, true]))]
    #[case::all_false(bb(vec![false, false, false, false]))]
    #[case::single_true(bb(vec![true]))]
    #[case::single_false(bb(vec![false]))]
    #[case::single_null(bb_opt(vec![None]))]
    #[case::mixed_with_nulls(bb_opt(vec![Some(true), None, Some(false), None, Some(true)]))]
    fn test_bytebool_consistency(#[case] array: ByteBoolArray) {
        test_array_consistency(&array.into_array());
    }
}