vortex_bytebool/
use num_traits::AsPrimitive;
use vortex_array::compute::{MaskKernel, MaskKernelAdapter, TakeKernel, TakeKernelAdapter};
use vortex_array::vtable::ValidityHelper;
use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
use vortex_dtype::match_each_integer_ptype;
use vortex_error::{VortexResult, vortex_err};
use vortex_mask::Mask;

use super::{ByteBoolArray, ByteBoolVTable};
13
14impl MaskKernel for ByteBoolVTable {
15 fn mask(&self, array: &ByteBoolArray, mask: &Mask) -> VortexResult<ArrayRef> {
16 Ok(ByteBoolArray::new(array.buffer().clone(), array.validity().mask(mask)?).into_array())
17 }
18}
19
20register_kernel!(MaskKernelAdapter(ByteBoolVTable).lift());
21
22impl TakeKernel for ByteBoolVTable {
23 fn take(&self, array: &ByteBoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
24 let indices = indices.to_primitive()?;
25 let bools = array.as_slice();
26
27 let validity = array.validity().take(indices.as_ref())?;
29
30 let taken_bools = match_each_integer_ptype!(indices.ptype(), |I| {
31 indices
32 .as_slice::<I>()
33 .iter()
34 .map(|&idx| {
35 let idx: usize = idx.as_();
36 bools[idx]
37 })
38 .collect::<Vec<bool>>()
39 });
40
41 Ok(ByteBoolArray::from_vec(taken_bools, validity).into_array())
42 }
43}
44
45register_kernel!(TakeKernelAdapter(ByteBoolVTable).lift());
46
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::compute::{Operator, compare};

    use super::*;

    /// Asserts that each position yielded by `positions` holds a non-null boolean scalar
    /// equal to `expected`.
    fn assert_valid_bools(
        arr: &ArrayRef,
        positions: impl IntoIterator<Item = usize>,
        expected: bool,
    ) {
        for pos in positions {
            let scalar = arr.scalar_at(pos).unwrap();
            assert!(scalar.is_valid());
            assert_eq!(scalar.as_bool().value(), Some(expected));
        }
    }

    #[test]
    fn test_slice() {
        let array = ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);

        let sliced = array.slice(1, 4).unwrap();
        let sliced = sliced.as_::<ByteBoolVTable>();

        // Slice position 0 corresponds to original position 1.
        let first = sliced.scalar_at(0).unwrap();
        assert_eq!(first.as_bool().value(), Some(true));

        // Slice position 1 corresponds to original position 2, which is null.
        let middle = sliced.scalar_at(1).unwrap();
        assert!(!sliced.is_valid(1).unwrap());
        assert!(middle.is_null());
        assert_eq!(middle.as_bool().value(), None);

        let last = sliced.scalar_at(2).unwrap();
        assert_eq!(last.as_bool().value(), Some(false));
    }

    #[test]
    fn test_compare_all_equal() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

        assert_valid_bools(&result, 0..result.len(), true);
    }

    #[test]
    fn test_compare_all_different() {
        let lhs = ByteBoolArray::from(vec![false; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

        assert_valid_bools(&result, 0..result.len(), false);
    }

    #[test]
    fn test_compare_with_nulls() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let result = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

        assert_valid_bools(&result, 0..3, true);
        assert_valid_bools(&result, 3..4, false);

        // A null on either side yields a null comparison result.
        let tail = result.scalar_at(4).unwrap();
        assert!(tail.is_null());
    }

    #[test]
    fn test_mask_byte_bool() {
        let non_nullable = ByteBoolArray::from(vec![true, false, true, true, false]);
        test_mask_conformance(non_nullable.as_ref());

        let nullable = ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);
        test_mask_conformance(nullable.as_ref());
    }

    #[test]
    fn test_filter_byte_bool() {
        let non_nullable = ByteBoolArray::from(vec![true, false, true, true, false]);
        test_filter_conformance(non_nullable.as_ref());

        let nullable = ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);
        test_filter_conformance(nullable.as_ref());
    }

    #[rstest]
    #[case(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case(ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]))]
    #[case(ByteBoolArray::from(vec![true, false]))]
    #[case(ByteBoolArray::from(vec![true]))]
    fn test_take_byte_bool_conformance(#[case] array: ByteBoolArray) {
        test_take_conformance(array.as_ref());
    }
}