1mod binary_numeric;
2mod compare;
3mod is_constant;
4mod is_sorted;
5mod like;
6mod min_max;
7mod optimize;
8
9use vortex_array::compute::{
10 BinaryNumericFn, CompareFn, FilterKernel, FilterKernelAdapter, IsConstantFn, IsSortedFn,
11 KernelRef, LikeFn, MinMaxFn, OptimizeFn, ScalarAtFn, SliceFn, TakeFn, filter, scalar_at, slice,
12 take,
13};
14use vortex_array::vtable::ComputeVTable;
15use vortex_array::{Array, ArrayComputeImpl, ArrayRef};
16use vortex_error::VortexResult;
17use vortex_mask::Mask;
18use vortex_scalar::Scalar;
19
20use crate::{DictArray, DictEncoding};
21
/// Kernel registrations for [`DictArray`] that go through `ArrayComputeImpl` rather than
/// the `ComputeVTable` hooks.
impl ArrayComputeImpl for DictArray {
    // Filtering is registered here (not in `ComputeVTable`): the `FilterKernel` impl for
    // `DictEncoding` is wrapped in `FilterKernelAdapter` to produce the kernel reference.
    const FILTER: Option<KernelRef> = FilterKernelAdapter(DictEncoding).some();
}
25
26impl ComputeVTable for DictEncoding {
27 fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn<&dyn Array>> {
28 Some(self)
29 }
30
31 fn compare_fn(&self) -> Option<&dyn CompareFn<&dyn Array>> {
32 Some(self)
33 }
34
35 fn is_constant_fn(&self) -> Option<&dyn IsConstantFn<&dyn Array>> {
36 Some(self)
37 }
38
39 fn is_sorted_fn(&self) -> Option<&dyn IsSortedFn<&dyn Array>> {
40 Some(self)
41 }
42
43 fn like_fn(&self) -> Option<&dyn LikeFn<&dyn Array>> {
44 Some(self)
45 }
46
47 fn optimize_fn(&self) -> Option<&dyn OptimizeFn<&dyn Array>> {
48 Some(self)
49 }
50
51 fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
52 Some(self)
53 }
54
55 fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
56 Some(self)
57 }
58
59 fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
60 Some(self)
61 }
62
63 fn min_max_fn(&self) -> Option<&dyn MinMaxFn<&dyn Array>> {
64 Some(self)
65 }
66}
67
68impl ScalarAtFn<&DictArray> for DictEncoding {
69 fn scalar_at(&self, array: &DictArray, index: usize) -> VortexResult<Scalar> {
70 let dict_index: usize = scalar_at(array.codes(), index)?.as_ref().try_into()?;
71 scalar_at(array.values(), dict_index)
72 }
73}
74
75impl TakeFn<&DictArray> for DictEncoding {
76 fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
77 let codes = take(array.codes(), indices)?;
78 DictArray::try_new(codes, array.values().clone()).map(|a| a.into_array())
79 }
80}
81
82impl FilterKernel for DictEncoding {
83 fn filter(&self, array: &DictArray, mask: &Mask) -> VortexResult<ArrayRef> {
84 let codes = filter(array.codes(), mask)?;
85 DictArray::try_new(codes, array.values().clone()).map(|a| a.into_array())
86 }
87}
88
89impl SliceFn<&DictArray> for DictEncoding {
90 fn slice(&self, array: &DictArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
92 DictArray::try_new(slice(array.codes(), start, stop)?, array.values().clone())
93 .map(|a| a.into_array())
94 }
95}
96
/// Tests for dictionary-array canonicalisation and compute kernels.
#[cfg(test)]
mod test {
    use vortex_array::accessor::ArrayAccessor;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray, VarBinArray, VarBinViewArray};
    use vortex_array::compute::conformance::mask::test_mask;
    use vortex_array::compute::{Operator, compare, scalar_at, slice};
    use vortex_array::{Array, ArrayRef, ToCanonical};
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::builders::dict_encode;

    /// Dictionary-encodes a nullable `i32` column (cycle of 42, -9, null) and checks the
    /// canonical buffer contents and the number of valid entries.
    #[test]
    fn canonicalise_nullable_primitive() {
        let values: Vec<Option<i32>> = (0..65)
            .map(|i| match i % 3 {
                0 => Some(42),
                1 => Some(-9),
                2 => None,
                _ => unreachable!(),
            })
            .collect();

        let dict = dict_encode(&PrimitiveArray::from_option_iter(values.clone())).unwrap();
        let actual = dict.to_primitive().unwrap();

        // Null slots (i % 3 == 2) are expected to hold 42 in the raw decoded buffer.
        // Presumably the encoder stores null codes as 0, which indexes the first dictionary
        // value (42). NOTE(review): this pins encoder-internal behaviour — confirm before
        // relying on it elsewhere.
        let expected: Vec<i32> = (0..65)
            .map(|i| match i % 3 {
                0 | 2 => 42,
                1 => -9,
                _ => unreachable!(),
            })
            .collect();

        assert_eq!(actual.as_slice::<i32>(), expected.as_slice());

        // Validity must survive the round trip even though the buffer holds 42 at null slots.
        let expected_valid_count = values.iter().filter(|x| x.is_some()).count();
        assert_eq!(
            actual.validity_mask().unwrap().true_count(),
            expected_valid_count
        );
    }

    /// Non-nullable round trip with 32 distinct values repeated over 1000 rows.
    #[test]
    fn canonicalise_non_nullable_primitive_32_unique_values() {
        let unique_values: Vec<i32> = (0..32).collect();
        let expected: Vec<i32> = (0..1000).map(|i| unique_values[i % 32]).collect();

        let dict = dict_encode(&PrimitiveArray::from_iter(expected.iter().copied())).unwrap();
        let actual = dict.to_primitive().unwrap();

        assert_eq!(actual.as_slice::<i32>(), expected.as_slice());
    }

    /// Non-nullable round trip with 100 distinct values repeated over 1000 rows.
    #[test]
    fn canonicalise_non_nullable_primitive_100_unique_values() {
        let unique_values: Vec<i32> = (0..100).collect();
        let expected: Vec<i32> = (0..1000).map(|i| unique_values[i % 100]).collect();

        let dict = dict_encode(&PrimitiveArray::from_iter(expected.iter().copied())).unwrap();
        let actual = dict.to_primitive().unwrap();

        assert_eq!(actual.as_slice::<i32>(), expected.as_slice());
    }

    /// Nullable UTF-8 round trip: the decoded view array must yield the same optional byte
    /// strings, in order, as the reference input.
    #[test]
    fn canonicalise_nullable_varbin() {
        let reference = VarBinViewArray::from_iter(
            vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")],
            DType::Utf8(Nullability::Nullable),
        );
        assert_eq!(reference.len(), 6);
        let dict = dict_encode(&reference).unwrap();
        let flattened_dict = dict.to_varbinview().unwrap();
        // Compare element-by-element as owned `Option<Vec<u8>>` so nulls are checked too.
        assert_eq!(
            flattened_dict
                .with_iterator(|iter| iter
                    .map(|slice| slice.map(|s| s.to_vec()))
                    .collect::<Vec<_>>())
                .unwrap(),
            reference
                .with_iterator(|iter| iter
                    .map(|slice| slice.map(|s| s.to_vec()))
                    .collect::<Vec<_>>())
                .unwrap(),
        );
    }

    /// Builds a dictionary array from `[42, -9, null, 42, 1, 5]` and slices rows `1..4`,
    /// giving the logical values `[-9, null, 42]`.
    fn sliced_dict_array() -> ArrayRef {
        let reference = PrimitiveArray::from_option_iter([
            Some(42),
            Some(-9),
            None,
            Some(42),
            Some(1),
            Some(5),
        ]);
        let dict = dict_encode(&reference).unwrap();
        slice(&dict, 1, 4).unwrap()
    }

    /// Equality against a constant on a sliced dictionary: `[-9, null, 42] == 42` must give
    /// `[false, null, true]`, checking that the slice offset and nulls are respected.
    #[test]
    fn compare_sliced_dict() {
        let sliced = sliced_dict_array();
        let compared = compare(&sliced, &ConstantArray::new(42, 3), Operator::Eq).unwrap();

        assert_eq!(
            scalar_at(&compared, 0).unwrap(),
            Scalar::bool(false, Nullability::Nullable)
        );
        assert_eq!(
            scalar_at(&compared, 1).unwrap(),
            Scalar::null(DType::Bool(Nullability::Nullable))
        );
        assert_eq!(
            scalar_at(&compared, 2).unwrap(),
            Scalar::bool(true, Nullability::Nullable)
        );
    }

    /// Runs the shared mask conformance suite over non-nullable primitive, nullable
    /// primitive and nullable UTF-8 dictionary arrays.
    #[test]
    fn test_mask_dict_array() {
        let array = dict_encode(&PrimitiveArray::from_iter([2, 0, 2, 0, 10]).into_array()).unwrap();
        test_mask(&array);

        let array = dict_encode(
            &PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)])
                .into_array(),
        )
        .unwrap();
        test_mask(&array);

        let array = dict_encode(
            &VarBinArray::from_iter(
                [
                    Some("hello"),
                    None,
                    Some("hello"),
                    Some("good"),
                    Some("good"),
                ],
                DType::Utf8(Nullability::Nullable),
            )
            .into_array(),
        )
        .unwrap();
        test_mask(&array);
    }
}