1use num_traits::AsPrimitive;
5use vortex_array::ArrayRef;
6use vortex_array::ArrayView;
7use vortex_array::ExecutionCtx;
8use vortex_array::IntoArray;
9use vortex_array::arrays::PrimitiveArray;
10use vortex_array::arrays::dict::TakeExecute;
11use vortex_array::buffer::BufferHandle;
12use vortex_array::dtype::DType;
13use vortex_array::match_each_integer_ptype;
14use vortex_array::scalar_fn::fns::cast::CastKernel;
15use vortex_array::scalar_fn::fns::cast::CastReduce;
16use vortex_array::scalar_fn::fns::mask::MaskReduce;
17use vortex_array::validity::Validity;
18use vortex_buffer::ByteBuffer;
19use vortex_error::VortexResult;
20
21use super::ByteBool;
22
23impl CastReduce for ByteBool {
24 fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
25 if !dtype.is_boolean() {
30 return Ok(None);
31 }
32
33 let Some(new_validity) = array
34 .validity()?
35 .trivial_cast_nullability(dtype.nullability(), array.len())?
36 else {
37 return Ok(None);
38 };
39
40 Ok(Some(
41 ByteBool::new(array.buffer().clone(), new_validity).into_array(),
42 ))
43 }
44}
45
46impl CastKernel for ByteBool {
47 fn cast(
48 array: ArrayView<'_, Self>,
49 dtype: &DType,
50 ctx: &mut ExecutionCtx,
51 ) -> VortexResult<Option<ArrayRef>> {
52 if !dtype.is_boolean() {
54 return Ok(None);
55 }
56
57 let new_validity =
58 array
59 .validity()?
60 .cast_nullability(dtype.nullability(), array.len(), ctx)?;
61
62 Ok(Some(
63 ByteBool::new(array.buffer().clone(), new_validity).into_array(),
64 ))
65 }
66}
67
68impl MaskReduce for ByteBool {
69 fn mask(array: ArrayView<'_, Self>, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
70 Ok(Some(
71 ByteBool::new(
72 array.buffer().clone(),
73 array.validity()?.and(Validity::Array(mask.clone()))?,
74 )
75 .into_array(),
76 ))
77 }
78}
79
80impl TakeExecute for ByteBool {
81 fn take(
82 array: ArrayView<'_, Self>,
83 indices: &ArrayRef,
84 ctx: &mut ExecutionCtx,
85 ) -> VortexResult<Option<ArrayRef>> {
86 let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
87 let values = array.truthy_bytes();
88
89 let validity = array.validity()?.take(&indices.clone().into_array())?;
91
92 let taken = match_each_integer_ptype!(indices.ptype(), |I| {
93 indices
94 .as_slice::<I>()
95 .iter()
96 .map(|&idx| {
97 let idx: usize = idx.as_();
98 values[idx]
99 })
100 .collect::<ByteBuffer>()
101 });
102
103 Ok(Some(
104 ByteBool::new(BufferHandle::new_host(taken), validity).into_array(),
105 ))
106 }
107}
108
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar_fn::fns::operators::Operator;

    use super::*;
    use crate::ByteBoolArray;

    /// Builds a ByteBool array whose validity is `Validity::AllValid`.
    ///
    /// `AllValid` is the *nullable* all-valid validity, so the array's dtype is
    /// `Bool(Nullable)`. Use `bb_non_nullable` when a `Bool(NonNullable)` array is needed.
    fn bb(v: Vec<bool>) -> ByteBoolArray {
        ByteBool::from_vec(v, Validity::AllValid)
    }

    /// Builds a ByteBool array whose validity is `Validity::NonNullable`.
    fn bb_non_nullable(v: Vec<bool>) -> ByteBoolArray {
        ByteBool::from_vec(v, Validity::NonNullable)
    }

    /// Builds a nullable ByteBool array in which `None` entries are null.
    fn bb_opt(v: Vec<Option<bool>>) -> ByteBoolArray {
        ByteBool::from_option_vec(v)
    }

    #[test]
    fn test_slice() {
        let original = vec![Some(true), Some(true), None, Some(false), None];
        let vortex_arr = bb_opt(original);

        let sliced_arr = vortex_arr.slice(1..4).unwrap();

        let expected = bb_opt(vec![Some(true), None, Some(false)]);
        assert_arrays_eq!(sliced_arr, expected.into_array());
    }

    #[test]
    fn test_compare_all_equal() {
        let lhs = bb(vec![true; 5]);
        let rhs = bb(vec![true; 5]);

        let arr = lhs
            .into_array()
            .binary(rhs.into_array(), Operator::Eq)
            .unwrap();

        let expected = bb(vec![true; 5]);
        assert_arrays_eq!(arr, expected.into_array());
    }

    #[test]
    fn test_compare_all_different() {
        let lhs = bb(vec![false; 5]);
        let rhs = bb(vec![true; 5]);

        let arr = lhs
            .into_array()
            .binary(rhs.into_array(), Operator::Eq)
            .unwrap();

        let expected = bb(vec![false; 5]);
        assert_arrays_eq!(arr, expected.into_array());
    }

    #[test]
    fn test_compare_with_nulls() {
        let lhs = bb(vec![true; 5]);
        let rhs = bb_opt(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let arr = lhs
            .into_array()
            .binary(rhs.into_array(), Operator::Eq)
            .unwrap();

        let expected = bb_opt(vec![Some(true), Some(true), Some(true), Some(false), None]);
        assert_arrays_eq!(arr, expected.into_array());
    }

    #[test]
    fn test_mask_byte_bool() {
        test_mask_conformance(&bb(vec![true, false, true, true, false]).into_array());
        test_mask_conformance(
            &bb_opt(vec![Some(true), Some(true), None, Some(false), None]).into_array(),
        );
    }

    #[test]
    fn test_filter_byte_bool() {
        test_filter_conformance(&bb(vec![true, false, true, true, false]).into_array());
        test_filter_conformance(
            &bb_opt(vec![Some(true), Some(true), None, Some(false), None]).into_array(),
        );
    }

    #[rstest]
    #[case(bb(vec![true, false, true, true, false]))]
    #[case(bb_opt(vec![Some(true), Some(true), None, Some(false), None]))]
    #[case(bb(vec![true, false]))]
    #[case(bb(vec![true]))]
    fn test_take_byte_bool_conformance(#[case] array: ByteBoolArray) {
        test_take_conformance(&array.into_array());
    }

    #[test]
    fn test_cast_bytebool_to_nullable() {
        // Start from a genuinely non-nullable array. `bb` already yields a nullable dtype,
        // which would make this cast a no-op.
        let array = bb_non_nullable(vec![true, false, true, false]).into_array();
        assert_eq!(array.dtype(), &DType::Bool(Nullability::NonNullable));

        let casted = array.cast(DType::Bool(Nullability::Nullable)).unwrap();
        assert_eq!(casted.dtype(), &DType::Bool(Nullability::Nullable));
        assert_eq!(casted.len(), 4);

        // Widening nullability must not change any values.
        assert_arrays_eq!(casted, bb(vec![true, false, true, false]).into_array());
    }

    #[rstest]
    #[case(bb_non_nullable(vec![true, false, true, true, false]))]
    #[case(bb(vec![true, false, true, true, false]))]
    #[case(bb_opt(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case(bb(vec![false]))]
    #[case(bb(vec![true]))]
    #[case(bb_opt(vec![Some(true), None]))]
    fn test_cast_bytebool_conformance(#[case] array: ByteBoolArray) {
        test_cast_conformance(&array.into_array());
    }

    #[rstest]
    #[case::non_nullable(bb_non_nullable(vec![true, false, true, true, false]))]
    #[case::all_valid(bb(vec![true, false, true, true, false]))]
    #[case::nullable(bb_opt(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case::all_true(bb(vec![true, true, true, true]))]
    #[case::all_false(bb(vec![false, false, false, false]))]
    #[case::single_true(bb(vec![true]))]
    #[case::single_false(bb(vec![false]))]
    #[case::single_null(bb_opt(vec![None]))]
    #[case::mixed_with_nulls(bb_opt(vec![Some(true), None, Some(false), None, Some(true)]))]
    fn test_bytebool_consistency(#[case] array: ByteBoolArray) {
        test_array_consistency(&array.into_array());
    }
}