vortex_bytebool/
compute.rs1use num_traits::AsPrimitive;
5use vortex_array::compute::{MaskKernel, MaskKernelAdapter, TakeKernel, TakeKernelAdapter};
6use vortex_array::vtable::ValidityHelper;
7use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
8use vortex_dtype::match_each_integer_ptype;
9use vortex_error::VortexResult;
10use vortex_mask::Mask;
11
12use super::{ByteBoolArray, ByteBoolVTable};
13
14impl MaskKernel for ByteBoolVTable {
15 fn mask(&self, array: &ByteBoolArray, mask: &Mask) -> VortexResult<ArrayRef> {
16 Ok(ByteBoolArray::new(array.buffer().clone(), array.validity().mask(mask)?).into_array())
17 }
18}
19
20register_kernel!(MaskKernelAdapter(ByteBoolVTable).lift());
21
22impl TakeKernel for ByteBoolVTable {
23 fn take(&self, array: &ByteBoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
24 let validity = array.validity_mask()?;
25 let indices = indices.to_primitive()?;
26 let bools = array.as_slice();
27
28 let arr = match validity {
31 Mask::AllTrue(_) => {
32 let bools = match_each_integer_ptype!(indices.ptype(), |I| {
33 indices
34 .as_slice::<I>()
35 .iter()
36 .map(|&idx| {
37 let idx: usize = idx.as_();
38 bools[idx]
39 })
40 .collect::<Vec<_>>()
41 });
42
43 ByteBoolArray::from(bools).into_array()
44 }
45 Mask::AllFalse(_) => ByteBoolArray::from(vec![None; indices.len()]).into_array(),
46 Mask::Values(values) => {
47 let bools = match_each_integer_ptype!(indices.ptype(), |I| {
48 indices
49 .as_slice::<I>()
50 .iter()
51 .map(|&idx| {
52 let idx = idx.as_();
53 values.value(idx).then(|| bools[idx])
54 })
55 .collect::<Vec<Option<_>>>()
56 });
57
58 ByteBoolArray::from(bools).into_array()
59 }
60 };
61
62 Ok(arr)
63 }
64}
65
66register_kernel!(TakeKernelAdapter(ByteBoolVTable).lift());
67
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::mask::test_mask;
    use vortex_array::compute::{Operator, compare, take};

    use super::*;

    /// Slicing keeps values and nulls at their shifted positions.
    #[test]
    fn test_slice() {
        let original = vec![Some(true), Some(true), None, Some(false), None];
        let vortex_arr = ByteBoolArray::from(original);

        let sliced_arr = vortex_arr.slice(1, 4).unwrap();
        let sliced_arr = sliced_arr.as_::<ByteBoolVTable>();

        let s = sliced_arr.scalar_at(0).unwrap();
        assert_eq!(s.as_bool().value(), Some(true));

        let s = sliced_arr.scalar_at(1).unwrap();
        assert!(!sliced_arr.is_valid(1).unwrap());
        assert!(s.is_null());
        assert_eq!(s.as_bool().value(), None);

        let s = sliced_arr.scalar_at(2).unwrap();
        assert_eq!(s.as_bool().value(), Some(false));
    }

    /// Equal inputs compare to all-true, all-valid.
    #[test]
    fn test_compare_all_equal() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let arr = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

        for i in 0..arr.len() {
            let s = arr.scalar_at(i).unwrap();
            assert!(s.is_valid());
            assert_eq!(s.as_bool().value(), Some(true));
        }
    }

    /// Fully differing inputs compare to all-false, all-valid.
    #[test]
    fn test_compare_all_different() {
        let lhs = ByteBoolArray::from(vec![false; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let arr = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

        for i in 0..arr.len() {
            let s = arr.scalar_at(i).unwrap();
            assert!(s.is_valid());
            assert_eq!(s.as_bool().value(), Some(false));
        }
    }

    /// A null on either side yields a null comparison result.
    #[test]
    fn test_compare_with_nulls() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let arr = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

        for i in 0..3 {
            let s = arr.scalar_at(i).unwrap();
            assert!(s.is_valid());
            assert_eq!(s.as_bool().value(), Some(true));
        }

        let s = arr.scalar_at(3).unwrap();
        assert!(s.is_valid());
        assert_eq!(s.as_bool().value(), Some(false));

        let s = arr.scalar_at(4).unwrap();
        assert!(s.is_null());
    }

    /// Runs the shared mask conformance suite on non-nullable and nullable inputs.
    #[test]
    fn test_mask_byte_bool() {
        test_mask(ByteBoolArray::from(vec![true, false, true, true, false]).as_ref());
        test_mask(
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]).as_ref(),
        );
    }

    /// Taking from a nullable array moves both values and nulls to the requested positions.
    #[test]
    fn test_take_with_nulls() {
        let arr = ByteBoolArray::from(vec![Some(true), Some(false), None, Some(true), None]);
        let indices = PrimitiveArray::from_iter([3u32, 2, 0, 1]);

        let taken = take(arr.as_ref(), indices.as_ref()).unwrap();

        assert_eq!(taken.len(), 4);
        assert_eq!(taken.scalar_at(0).unwrap().as_bool().value(), Some(true));
        assert!(taken.scalar_at(1).unwrap().is_null());
        assert_eq!(taken.scalar_at(2).unwrap().as_bool().value(), Some(true));
        assert_eq!(taken.scalar_at(3).unwrap().as_bool().value(), Some(false));
    }
}