vortex_bytebool/
compute.rs1use num_traits::AsPrimitive;
2use vortex_array::compute::{MaskKernel, MaskKernelAdapter, TakeKernel, TakeKernelAdapter};
3use vortex_array::vtable::ValidityHelper;
4use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
5use vortex_dtype::match_each_integer_ptype;
6use vortex_error::VortexResult;
7use vortex_mask::Mask;
8
9use super::{ByteBoolArray, ByteBoolVTable};
10
11impl MaskKernel for ByteBoolVTable {
12 fn mask(&self, array: &ByteBoolArray, mask: &Mask) -> VortexResult<ArrayRef> {
13 Ok(ByteBoolArray::new(array.buffer().clone(), array.validity().mask(mask)?).into_array())
14 }
15}
16
17register_kernel!(MaskKernelAdapter(ByteBoolVTable).lift());
18
19impl TakeKernel for ByteBoolVTable {
20 fn take(&self, array: &ByteBoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
21 let validity = array.validity_mask()?;
22 let indices = indices.to_primitive()?;
23 let bools = array.as_slice();
24
25 let arr = match validity {
28 Mask::AllTrue(_) => {
29 let bools = match_each_integer_ptype!(indices.ptype(), |I| {
30 indices
31 .as_slice::<I>()
32 .iter()
33 .map(|&idx| {
34 let idx: usize = idx.as_();
35 bools[idx]
36 })
37 .collect::<Vec<_>>()
38 });
39
40 ByteBoolArray::from(bools).into_array()
41 }
42 Mask::AllFalse(_) => ByteBoolArray::from(vec![None; indices.len()]).into_array(),
43 Mask::Values(values) => {
44 let bools = match_each_integer_ptype!(indices.ptype(), |I| {
45 indices
46 .as_slice::<I>()
47 .iter()
48 .map(|&idx| {
49 let idx = idx.as_();
50 values.value(idx).then(|| bools[idx])
51 })
52 .collect::<Vec<Option<_>>>()
53 });
54
55 ByteBoolArray::from(bools).into_array()
56 }
57 };
58
59 Ok(arr)
60 }
61}
62
63register_kernel!(TakeKernelAdapter(ByteBoolVTable).lift());
64
#[cfg(test)]
mod tests {
    use vortex_array::compute::conformance::mask::test_mask;
    use vortex_array::compute::{Operator, compare};

    use super::*;

    #[test]
    fn test_slice() {
        let source = ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);

        // Slicing [1, 4) keeps the elements `Some(true), None, Some(false)`.
        let sliced = source.slice(1, 4).unwrap();
        let sliced = sliced.as_::<ByteBoolVTable>();

        let first = sliced.scalar_at(0).unwrap();
        assert_eq!(first.as_bool().value(), Some(true));

        let second = sliced.scalar_at(1).unwrap();
        assert!(!sliced.is_valid(1).unwrap());
        assert!(second.is_null());
        assert_eq!(second.as_bool().value(), None);

        let third = sliced.scalar_at(2).unwrap();
        assert_eq!(third.as_bool().value(), Some(false));
    }

    #[test]
    fn test_compare_all_equal() {
        let left = ByteBoolArray::from(vec![true; 5]);
        let right = ByteBoolArray::from(vec![true; 5]);

        let eq = compare(left.as_ref(), right.as_ref(), Operator::Eq).unwrap();

        // Identical inputs: every position is valid and compares equal.
        for idx in 0..eq.len() {
            let scalar = eq.scalar_at(idx).unwrap();
            assert!(scalar.is_valid());
            assert_eq!(scalar.as_bool().value(), Some(true));
        }
    }

    #[test]
    fn test_compare_all_different() {
        let left = ByteBoolArray::from(vec![false; 5]);
        let right = ByteBoolArray::from(vec![true; 5]);

        let eq = compare(left.as_ref(), right.as_ref(), Operator::Eq).unwrap();

        // Opposite inputs: every position is valid and compares unequal.
        for idx in 0..eq.len() {
            let scalar = eq.scalar_at(idx).unwrap();
            assert!(scalar.is_valid());
            assert_eq!(scalar.as_bool().value(), Some(false));
        }
    }

    #[test]
    fn test_compare_with_nulls() {
        let left = ByteBoolArray::from(vec![true; 5]);
        let right =
            ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let eq = compare(left.as_ref(), right.as_ref(), Operator::Eq).unwrap();

        // The first three positions match `true == true`.
        for idx in 0..3 {
            let scalar = eq.scalar_at(idx).unwrap();
            assert!(scalar.is_valid());
            assert_eq!(scalar.as_bool().value(), Some(true));
        }

        // `true == false` is a valid `false`.
        let mismatch = eq.scalar_at(3).unwrap();
        assert!(mismatch.is_valid());
        assert_eq!(mismatch.as_bool().value(), Some(false));

        // Comparing against a null yields null.
        let null_cmp = eq.scalar_at(4).unwrap();
        assert!(null_cmp.is_null());
    }

    #[test]
    fn test_mask_byte_bool() {
        let all_valid = ByteBoolArray::from(vec![true, false, true, true, false]);
        test_mask(all_valid.as_ref());

        let with_nulls =
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);
        test_mask(with_nulls.as_ref());
    }
}