1mod binary_numeric;
2mod compare;
3mod is_constant;
4mod like;
5
6use vortex_array::compute::{
7 BinaryNumericFn, CompareFn, FilterFn, IsConstantFn, LikeFn, ScalarAtFn, SliceFn, TakeFn,
8 filter, scalar_at, slice, take,
9};
10use vortex_array::vtable::ComputeVTable;
11use vortex_array::{Array, ArrayRef};
12use vortex_error::VortexResult;
13use vortex_mask::Mask;
14use vortex_scalar::Scalar;
15
16use crate::{DictArray, DictEncoding};
17
18impl ComputeVTable for DictEncoding {
19 fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn<&dyn Array>> {
20 Some(self)
21 }
22
23 fn compare_fn(&self) -> Option<&dyn CompareFn<&dyn Array>> {
24 Some(self)
25 }
26
27 fn is_constant_fn(&self) -> Option<&dyn IsConstantFn<&dyn Array>> {
28 Some(self)
29 }
30
31 fn filter_fn(&self) -> Option<&dyn FilterFn<&dyn Array>> {
32 Some(self)
33 }
34
35 fn like_fn(&self) -> Option<&dyn LikeFn<&dyn Array>> {
36 Some(self)
37 }
38
39 fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
40 Some(self)
41 }
42
43 fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
44 Some(self)
45 }
46
47 fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
48 Some(self)
49 }
50}
51
52impl ScalarAtFn<&DictArray> for DictEncoding {
53 fn scalar_at(&self, array: &DictArray, index: usize) -> VortexResult<Scalar> {
54 let dict_index: usize = scalar_at(array.codes(), index)?.as_ref().try_into()?;
55 scalar_at(array.values(), dict_index)
56 }
57}
58
59impl TakeFn<&DictArray> for DictEncoding {
60 fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
61 let codes = take(array.codes(), indices)?;
62 DictArray::try_new(codes, array.values().clone()).map(|a| a.into_array())
63 }
64}
65
66impl FilterFn<&DictArray> for DictEncoding {
67 fn filter(&self, array: &DictArray, mask: &Mask) -> VortexResult<ArrayRef> {
68 let codes = filter(array.codes(), mask)?;
69 DictArray::try_new(codes, array.values().clone()).map(|a| a.into_array())
70 }
71}
72
73impl SliceFn<&DictArray> for DictEncoding {
74 fn slice(&self, array: &DictArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
76 DictArray::try_new(slice(array.codes(), start, stop)?, array.values().clone())
77 .map(|a| a.into_array())
78 }
79}
80
#[cfg(test)]
mod test {
    use vortex_array::accessor::ArrayAccessor;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray, VarBinArray, VarBinViewArray};
    use vortex_array::compute::test_harness::test_mask;
    use vortex_array::compute::{Operator, compare, scalar_at, slice};
    use vortex_array::{Array, ArrayRef, ToCanonical};
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::builders::dict_encode;

    /// Dictionary-encodes 1000 `i32`s that cycle through `0..unique_count` and checks that
    /// canonicalising the result yields the original values.
    fn assert_cyclic_round_trip(unique_count: i32) {
        let expected: Vec<i32> = (0..unique_count).cycle().take(1000).collect();

        let encoded = dict_encode(&PrimitiveArray::from_iter(expected.iter().copied())).unwrap();
        let decoded = encoded.to_primitive().unwrap();

        assert_eq!(decoded.as_slice::<i32>(), expected.as_slice());
    }

    /// A nullable primitive array survives a dict-encode / canonicalise round trip: the
    /// value buffer matches and the number of valid slots is preserved.
    #[test]
    fn canonicalise_nullable_primitive() {
        // Repeating pattern: 42, -9, null, 42, -9, null, ...
        let values: Vec<Option<i32>> = [Some(42), Some(-9), None]
            .into_iter()
            .cycle()
            .take(65)
            .collect();

        let input = PrimitiveArray::from_option_iter(values.clone());
        let encoded = dict_encode(&input).unwrap();
        let decoded = encoded.to_primitive().unwrap();

        // In the raw value buffer the null slots are expected to read as 42.
        // NOTE(review): presumably because null positions resolve to the first dictionary
        // entry — confirm against the dictionary builder.
        let expected: Vec<i32> = values.iter().map(|slot| slot.unwrap_or(42)).collect();
        assert_eq!(decoded.as_slice::<i32>(), expected.as_slice());

        let expected_valid_count = values.iter().flatten().count();
        let actual_valid_count = decoded.validity_mask().unwrap().true_count();
        assert_eq!(actual_valid_count, expected_valid_count);
    }

    /// Round trip with 32 distinct values.
    #[test]
    fn canonicalise_non_nullable_primitive_32_unique_values() {
        assert_cyclic_round_trip(32);
    }

    /// Round trip with 100 distinct values.
    #[test]
    fn canonicalise_non_nullable_primitive_100_unique_values() {
        assert_cyclic_round_trip(100);
    }

    /// A nullable UTF-8 view array survives a dict-encode / canonicalise round trip,
    /// compared element by element (nulls included).
    #[test]
    fn canonicalise_nullable_varbin() {
        let reference = VarBinViewArray::from_iter(
            vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")],
            DType::Utf8(Nullability::Nullable),
        );
        assert_eq!(reference.len(), 6);

        let encoded = dict_encode(&reference).unwrap();
        let flattened = encoded.to_varbinview().unwrap();

        // Copy each side into owned byte vectors so the two arrays can be compared directly.
        let decoded_items = flattened
            .with_iterator(|iter| {
                iter.map(|item| item.map(|bytes| bytes.to_vec()))
                    .collect::<Vec<_>>()
            })
            .unwrap();
        let reference_items = reference
            .with_iterator(|iter| {
                iter.map(|item| item.map(|bytes| bytes.to_vec()))
                    .collect::<Vec<_>>()
            })
            .unwrap();

        assert_eq!(decoded_items, reference_items);
    }

    /// Builds a dict-encoded `[42, -9, null, 42, 1, 5]` and returns its `1..4` slice,
    /// i.e. the logical elements `[-9, null, 42]`.
    fn sliced_dict_array() -> ArrayRef {
        let source = PrimitiveArray::from_option_iter([
            Some(42),
            Some(-9),
            None,
            Some(42),
            Some(1),
            Some(5),
        ]);
        let encoded = dict_encode(&source).unwrap();

        slice(&encoded, 1, 4).unwrap()
    }

    /// Comparing the sliced dictionary against the constant 42 yields `[false, null, true]`.
    #[test]
    fn compare_sliced_dict() {
        let sliced = sliced_dict_array();
        let rhs = ConstantArray::new(42, 3);
        let compared = compare(&sliced, &rhs, Operator::Eq).unwrap();

        let expected = [
            Scalar::bool(false, Nullability::Nullable),
            Scalar::null(DType::Bool(Nullability::Nullable)),
            Scalar::bool(true, Nullability::Nullable),
        ];
        for (index, want) in expected.into_iter().enumerate() {
            assert_eq!(scalar_at(&compared, index).unwrap(), want);
        }
    }

    /// Runs the shared `test_mask` harness over three dictionary arrays: non-nullable
    /// integers, nullable integers, and nullable strings.
    #[test]
    fn test_mask_dict_array() {
        let non_nullable_ints = PrimitiveArray::from_iter([2, 0, 2, 0, 10]).into_array();
        let encoded = dict_encode(&non_nullable_ints).unwrap();
        test_mask(&encoded);

        let nullable_ints =
            PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)])
                .into_array();
        let encoded = dict_encode(&nullable_ints).unwrap();
        test_mask(&encoded);

        let nullable_strings = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("hello"),
                Some("good"),
                Some("good"),
            ],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();
        let encoded = dict_encode(&nullable_strings).unwrap();
        test_mask(&encoded);
    }
}