vortex_bytebool/
compute.rs1use num_traits::AsPrimitive;
2use vortex_array::compute::{MaskKernel, MaskKernelAdapter, ScalarAtFn, SliceFn, TakeFn};
3use vortex_array::variants::PrimitiveArrayTrait;
4use vortex_array::vtable::ComputeVTable;
5use vortex_array::{Array, ArrayRef, ToCanonical, register_kernel};
6use vortex_dtype::match_each_integer_ptype;
7use vortex_error::VortexResult;
8use vortex_mask::Mask;
9use vortex_scalar::Scalar;
10
11use super::{ByteBoolArray, ByteBoolEncoding};
12
13impl ComputeVTable for ByteBoolEncoding {
14 fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
15 Some(self)
16 }
17
18 fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
19 Some(self)
20 }
21
22 fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
23 Some(self)
24 }
25}
26
27impl MaskKernel for ByteBoolEncoding {
28 fn mask(&self, array: &ByteBoolArray, mask: &Mask) -> VortexResult<ArrayRef> {
29 Ok(ByteBoolArray::new(array.buffer().clone(), array.validity().mask(mask)?).into_array())
30 }
31}
32
33register_kernel!(MaskKernelAdapter(ByteBoolEncoding).lift());
34
35impl ScalarAtFn<&ByteBoolArray> for ByteBoolEncoding {
36 fn scalar_at(&self, array: &ByteBoolArray, index: usize) -> VortexResult<Scalar> {
37 Ok(Scalar::bool(
38 array.buffer()[index] == 1,
39 array.dtype().nullability(),
40 ))
41 }
42}
43
44impl SliceFn<&ByteBoolArray> for ByteBoolEncoding {
45 fn slice(&self, array: &ByteBoolArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
46 Ok(ByteBoolArray::new(
47 array.buffer().slice(start..stop),
48 array.validity().slice(start, stop)?,
49 )
50 .into_array())
51 }
52}
53
54impl TakeFn<&ByteBoolArray> for ByteBoolEncoding {
55 fn take(&self, array: &ByteBoolArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
56 let validity = array.validity_mask()?;
57 let indices = indices.to_primitive()?;
58 let bools = array.as_slice();
59
60 let arr = match validity {
63 Mask::AllTrue(_) => {
64 let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
65 indices.as_slice::<$I>()
66 .iter()
67 .map(|&idx| {
68 let idx: usize = idx.as_();
69 bools[idx]
70 })
71 .collect::<Vec<_>>()
72 });
73
74 ByteBoolArray::from(bools).into_array()
75 }
76 Mask::AllFalse(_) => ByteBoolArray::from(vec![None; indices.len()]).into_array(),
77 Mask::Values(values) => {
78 let bools = match_each_integer_ptype!(indices.ptype(), |$I| {
79 indices.as_slice::<$I>()
80 .iter()
81 .map(|&idx| {
82 let idx = idx.as_();
83 if values.value(idx) {
84 Some(bools[idx])
85 } else {
86 None
87 }
88 })
89 .collect::<Vec<Option<_>>>()
90 });
91
92 ByteBoolArray::from(bools).into_array()
93 }
94 };
95
96 Ok(arr)
97 }
98}
99
#[cfg(test)]
mod tests {
    use vortex_array::compute::conformance::mask::test_mask;
    use vortex_array::compute::{Operator, compare, scalar_at, slice};

    use super::*;

    /// Slicing `1..4` keeps values and nulls aligned with their original positions.
    #[test]
    fn test_slice() {
        let vortex_arr =
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);

        let sliced_arr = ByteBoolArray::try_from(slice(&vortex_arr, 1, 4).unwrap()).unwrap();

        let first = scalar_at(&sliced_arr, 0).unwrap();
        assert_eq!(first.as_bool().value(), Some(true));

        let second = scalar_at(&sliced_arr, 1).unwrap();
        assert!(!sliced_arr.is_valid(1).unwrap());
        assert!(second.is_null());
        assert_eq!(second.as_bool().value(), None);

        let third = scalar_at(&sliced_arr, 2).unwrap();
        assert_eq!(third.as_bool().value(), Some(false));
    }

    /// Comparing two identical all-true arrays yields valid `true` everywhere.
    #[test]
    fn test_compare_all_equal() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let arr = compare(&lhs, &rhs, Operator::Eq).unwrap();

        let expect_true_at = |i: usize| {
            let s = scalar_at(&arr, i).unwrap();
            assert!(s.is_valid());
            assert_eq!(s.as_bool().value(), Some(true));
        };
        (0..arr.len()).for_each(expect_true_at);
    }

    /// Comparing all-false against all-true yields valid `false` everywhere.
    #[test]
    fn test_compare_all_different() {
        let lhs = ByteBoolArray::from(vec![false; 5]);
        let rhs = ByteBoolArray::from(vec![true; 5]);

        let arr = compare(&lhs, &rhs, Operator::Eq).unwrap();

        let expect_false_at = |i: usize| {
            let s = scalar_at(&arr, i).unwrap();
            assert!(s.is_valid());
            assert_eq!(s.as_bool().value(), Some(false));
        };
        (0..arr.len()).for_each(expect_false_at);
    }

    /// A null on either side of the comparison produces a null result at that position.
    #[test]
    fn test_compare_with_nulls() {
        let lhs = ByteBoolArray::from(vec![true; 5]);
        let rhs = ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);

        let arr = compare(&lhs, &rhs, Operator::Eq).unwrap();

        let expect_valid = |i: usize, expected: bool| {
            let s = scalar_at(&arr, i).unwrap();
            assert!(s.is_valid());
            assert_eq!(s.as_bool().value(), Some(expected));
        };

        // The first three positions match on both sides.
        for i in 0..3 {
            expect_valid(i, true);
        }

        // Position 3 differs but is still valid.
        expect_valid(3, false);

        // Position 4 is null on the right-hand side.
        let last = scalar_at(&arr, 4).unwrap();
        assert!(last.is_null());
    }

    /// Runs the shared mask conformance suite on non-nullable and nullable inputs.
    #[test]
    fn test_mask_byte_bool() {
        let non_nullable = ByteBoolArray::from(vec![true, false, true, true, false]);
        test_mask(&non_nullable);

        let nullable = ByteBoolArray::from(vec![
            Some(true),
            Some(true),
            None,
            Some(false),
            None,
        ]);
        test_mask(&nullable);
    }
}