1use num_traits::AsPrimitive;
5use vortex_array::compute::{
6 CastKernel, CastKernelAdapter, MaskKernel, MaskKernelAdapter, TakeKernel, TakeKernelAdapter,
7};
8use vortex_array::vtable::ValidityHelper;
9use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
10use vortex_dtype::{DType, match_each_integer_ptype};
11use vortex_error::VortexResult;
12use vortex_mask::Mask;
13
14use super::{ByteBoolArray, ByteBoolVTable};
15
16impl CastKernel for ByteBoolVTable {
17 fn cast(&self, array: &ByteBoolArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
18 if array.dtype().eq_ignore_nullability(dtype) {
24 let new_validity = array
25 .validity()
26 .clone()
27 .cast_nullability(dtype.nullability())?;
28
29 return Ok(Some(
30 ByteBoolArray::new(array.buffer().clone(), new_validity).into_array(),
31 ));
32 }
33
34 Ok(None)
36 }
37}
38
39register_kernel!(CastKernelAdapter(ByteBoolVTable).lift());
40
41impl MaskKernel for ByteBoolVTable {
42 fn mask(&self, array: &ByteBoolArray, mask: &Mask) -> VortexResult<ArrayRef> {
43 Ok(ByteBoolArray::new(array.buffer().clone(), array.validity().mask(mask)).into_array())
44 }
45}
46
47register_kernel!(MaskKernelAdapter(ByteBoolVTable).lift());
48
49impl TakeKernel for ByteBoolVTable {
50 fn take(&self, array: &ByteBoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
51 let indices = indices.to_primitive();
52 let bools = array.as_slice();
53
54 let validity = array.validity().take(indices.as_ref())?;
56
57 let taken_bools = match_each_integer_ptype!(indices.ptype(), |I| {
58 indices
59 .as_slice::<I>()
60 .iter()
61 .map(|&idx| {
62 let idx: usize = idx.as_();
63 bools[idx]
64 })
65 .collect::<Vec<bool>>()
66 });
67
68 Ok(ByteBoolArray::from_vec(taken_bools, validity).into_array())
69 }
70}
71
72register_kernel!(TakeKernelAdapter(ByteBoolVTable).lift());
73
#[cfg(test)]
mod tests {
    // All imports live in one group at the top of the module rather than being scattered
    // between test functions.
    use rstest::rstest;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::compute::{Operator, cast, compare};
    use vortex_dtype::{DType, Nullability};

    use super::*;

    /// Slicing keeps both the values and the nulls of the selected range.
    #[test]
    fn test_slice() {
        let original = vec![Some(true), Some(true), None, Some(false), None];
        let vortex_arr = ByteBoolArray::from(original);

        let sliced_arr = vortex_arr.slice(1..4);
        let sliced_arr = sliced_arr.as_::<ByteBoolVTable>();

        let s = sliced_arr.scalar_at(0);
        assert_eq!(s.as_bool().value(), Some(true));

        let s = sliced_arr.scalar_at(1);
        assert!(!sliced_arr.is_valid(1));
        assert!(s.is_null());
        assert_eq!(s.as_bool().value(), None);

        let s = sliced_arr.scalar_at(2);
        assert_eq!(s.as_bool().value(), Some(false));
    }

    /// Equal inputs compare as all-true and all-valid.
    #[test]
    fn test_compare_all_equal() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let arr = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

        for i in 0..arr.len() {
            let s = arr.scalar_at(i);
            assert!(s.is_valid());
            assert_eq!(s.as_bool().value(), Some(true));
        }
    }

    /// Fully differing inputs compare as all-false and all-valid.
    #[test]
    fn test_compare_all_different() {
        let lhs = ByteBoolArray::from(vec![false; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let arr = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

        for i in 0..arr.len() {
            let s = arr.scalar_at(i);
            assert!(s.is_valid());
            assert_eq!(s.as_bool().value(), Some(false));
        }
    }

    /// A null on either side yields a null comparison result at that position.
    #[test]
    fn test_compare_with_nulls() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let arr = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

        for i in 0..3 {
            let s = arr.scalar_at(i);
            assert!(s.is_valid());
            assert_eq!(s.as_bool().value(), Some(true));
        }

        let s = arr.scalar_at(3);
        assert!(s.is_valid());
        assert_eq!(s.as_bool().value(), Some(false));

        let s = arr.scalar_at(4);
        assert!(s.is_null());
    }

    #[test]
    fn test_mask_byte_bool() {
        test_mask_conformance(ByteBoolArray::from(vec![true, false, true, true, false]).as_ref());
        test_mask_conformance(
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]).as_ref(),
        );
    }

    #[test]
    fn test_filter_byte_bool() {
        test_filter_conformance(ByteBoolArray::from(vec![true, false, true, true, false]).as_ref());
        test_filter_conformance(
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]).as_ref(),
        );
    }

    #[rstest]
    #[case(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case(ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]))]
    #[case(ByteBoolArray::from(vec![true, false]))]
    #[case(ByteBoolArray::from(vec![true]))]
    fn test_take_byte_bool_conformance(#[case] array: ByteBoolArray) {
        test_take_conformance(array.as_ref());
    }

    /// Widening a non-nullable array to nullable must keep every value and introduce no nulls.
    #[test]
    fn test_cast_bytebool_to_nullable() {
        let values = [true, false, true, false];
        let array = ByteBoolArray::from(values.to_vec());
        let casted = cast(array.as_ref(), &DType::Bool(Nullability::Nullable)).unwrap();
        assert_eq!(casted.dtype(), &DType::Bool(Nullability::Nullable));
        assert_eq!(casted.len(), values.len());

        for (i, expected) in values.into_iter().enumerate() {
            let s = casted.scalar_at(i);
            assert!(s.is_valid());
            assert_eq!(s.as_bool().value(), Some(expected));
        }
    }

    #[rstest]
    #[case(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case(ByteBoolArray::from(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case(ByteBoolArray::from(vec![false]))]
    #[case(ByteBoolArray::from(vec![true]))]
    #[case(ByteBoolArray::from(vec![Some(true), None]))]
    fn test_cast_bytebool_conformance(#[case] array: ByteBoolArray) {
        test_cast_conformance(array.as_ref());
    }

    #[rstest]
    #[case::non_nullable(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case::nullable(ByteBoolArray::from(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case::all_true(ByteBoolArray::from(vec![true, true, true, true]))]
    #[case::all_false(ByteBoolArray::from(vec![false, false, false, false]))]
    #[case::single_true(ByteBoolArray::from(vec![true]))]
    #[case::single_false(ByteBoolArray::from(vec![false]))]
    #[case::single_null(ByteBoolArray::from(vec![None]))]
    #[case::mixed_with_nulls(ByteBoolArray::from(vec![Some(true), None, Some(false), None, Some(true)]))]
    fn test_bytebool_consistency(#[case] array: ByteBoolArray) {
        test_array_consistency(array.as_ref());
    }
}