vortex_bytebool/
compute.rs1use num_traits::AsPrimitive;
2use vortex_array::compute::{MaskKernel, MaskKernelAdapter, TakeKernel, TakeKernelAdapter};
3use vortex_array::vtable::ValidityHelper;
4use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
5use vortex_dtype::match_each_integer_ptype;
6use vortex_error::VortexResult;
7use vortex_mask::Mask;
8
9use super::{ByteBoolArray, ByteBoolVTable};
10
11impl MaskKernel for ByteBoolVTable {
12 fn mask(&self, array: &ByteBoolArray, mask: &Mask) -> VortexResult<ArrayRef> {
13 Ok(ByteBoolArray::new(array.buffer().clone(), array.validity().mask(mask)?).into_array())
14 }
15}
16
17register_kernel!(MaskKernelAdapter(ByteBoolVTable).lift());
18
19impl TakeKernel for ByteBoolVTable {
20 fn take(&self, array: &ByteBoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
21 let validity = array.validity_mask()?;
22 let indices = indices.to_primitive()?;
23 let bools = array.as_slice();
24
25 let arr = match validity {
28 Mask::AllTrue(_) => {
29 let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
30 indices.as_slice::<$I>()
31 .iter()
32 .map(|&idx| {
33 let idx: usize = idx.as_();
34 bools[idx]
35 })
36 .collect::<Vec<_>>()
37 });
38
39 ByteBoolArray::from(bools).into_array()
40 }
41 Mask::AllFalse(_) => ByteBoolArray::from(vec![None; indices.len()]).into_array(),
42 Mask::Values(values) => {
43 let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
44 indices.as_slice::<$I>()
45 .iter()
46 .map(|&idx| {
47 let idx = idx.as_();
48 if values.value(idx) {
49 Some(bools[idx])
50 } else {
51 None
52 }
53 })
54 .collect::<Vec<Option<_>>>()
55 });
56
57 ByteBoolArray::from(bools).into_array()
58 }
59 };
60
61 Ok(arr)
62 }
63}
64
65register_kernel!(TakeKernelAdapter(ByteBoolVTable).lift());
66
#[cfg(test)]
mod tests {
    use vortex_array::compute::conformance::mask::test_mask;
    use vortex_array::compute::{Operator, compare};

    use super::*;

    /// Asserts that each listed position of `result` is a valid boolean equal to `expected`.
    fn assert_valid_at(result: &ArrayRef, positions: impl IntoIterator<Item = usize>, expected: bool) {
        for pos in positions {
            let scalar = result.scalar_at(pos).unwrap();
            assert!(scalar.is_valid());
            assert_eq!(scalar.as_bool().value(), Some(expected));
        }
    }

    /// Slicing `[1, 4)` keeps values and nulls at their shifted positions.
    #[test]
    fn test_slice() {
        let source = ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);

        let window = source.slice(1, 4).unwrap();
        let window = window.as_::<ByteBoolVTable>();

        let head = window.scalar_at(0).unwrap();
        assert_eq!(head.as_bool().value(), Some(true));

        let middle = window.scalar_at(1).unwrap();
        assert!(!window.is_valid(1).unwrap());
        assert!(middle.is_null());
        assert_eq!(middle.as_bool().value(), None);

        let tail = window.scalar_at(2).unwrap();
        assert_eq!(tail.as_bool().value(), Some(false));
    }

    /// Equal inputs compare `Eq` to all-true.
    #[test]
    fn test_compare_all_equal() {
        let left = ByteBoolArray::from(vec![true; 5]);
        let right = ByteBoolArray::from(vec![true; 5]);

        let result = compare(left.as_ref(), right.as_ref(), Operator::Eq).unwrap();
        assert_valid_at(&result, 0..result.len(), true);
    }

    /// Fully different inputs compare `Eq` to all-false.
    #[test]
    fn test_compare_all_different() {
        let left = ByteBoolArray::from(vec![false; 5]);
        let right = ByteBoolArray::from(vec![true; 5]);

        let result = compare(left.as_ref(), right.as_ref(), Operator::Eq).unwrap();
        assert_valid_at(&result, 0..result.len(), false);
    }

    /// A null on one side yields a null comparison result at that position.
    #[test]
    fn test_compare_with_nulls() {
        let left = ByteBoolArray::from(vec![true; 5]);
        let right = ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let result = compare(left.as_ref(), right.as_ref(), Operator::Eq).unwrap();

        assert_valid_at(&result, 0..3, true);
        assert_valid_at(&result, [3], false);

        let last = result.scalar_at(4).unwrap();
        assert!(last.is_null());
    }

    /// Runs the shared mask conformance suite on arrays without and with nulls.
    #[test]
    fn test_mask_byte_bool() {
        let all_present = ByteBoolArray::from(vec![true, false, true, true, false]);
        let with_nulls = ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);

        test_mask(all_present.as_ref());
        test_mask(with_nulls.as_ref());
    }
}