vortex_dict/compute/
mod.rs

1mod binary_numeric;
2mod compare;
3mod is_constant;
4mod is_sorted;
5mod like;
6mod min_max;
7mod optimize;
8
9use vortex_array::compute::{
10    BinaryNumericFn, CompareFn, FilterKernel, FilterKernelAdapter, IsConstantFn, IsSortedFn,
11    KernelRef, LikeFn, MinMaxFn, OptimizeFn, ScalarAtFn, SliceFn, TakeFn, filter, scalar_at, slice,
12    take,
13};
14use vortex_array::vtable::ComputeVTable;
15use vortex_array::{Array, ArrayComputeImpl, ArrayRef};
16use vortex_error::VortexResult;
17use vortex_mask::Mask;
18use vortex_scalar::Scalar;
19
20use crate::{DictArray, DictEncoding};
21
/// Compute-kernel wiring for [`DictArray`].
///
/// Filtering is not exposed through the `ComputeVTable` impl for `DictEncoding`; instead the
/// `FilterKernel` impl for `DictEncoding` (in this module) is wrapped in a
/// `FilterKernelAdapter` and registered here as the array's filter kernel.
impl ArrayComputeImpl for DictArray {
    const FILTER: Option<KernelRef> = FilterKernelAdapter(DictEncoding).some();
}
25
26impl ComputeVTable for DictEncoding {
27    fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn<&dyn Array>> {
28        Some(self)
29    }
30
31    fn compare_fn(&self) -> Option<&dyn CompareFn<&dyn Array>> {
32        Some(self)
33    }
34
35    fn is_constant_fn(&self) -> Option<&dyn IsConstantFn<&dyn Array>> {
36        Some(self)
37    }
38
39    fn is_sorted_fn(&self) -> Option<&dyn IsSortedFn<&dyn Array>> {
40        Some(self)
41    }
42
43    fn like_fn(&self) -> Option<&dyn LikeFn<&dyn Array>> {
44        Some(self)
45    }
46
47    fn optimize_fn(&self) -> Option<&dyn OptimizeFn<&dyn Array>> {
48        Some(self)
49    }
50
51    fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
52        Some(self)
53    }
54
55    fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
56        Some(self)
57    }
58
59    fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
60        Some(self)
61    }
62
63    fn min_max_fn(&self) -> Option<&dyn MinMaxFn<&dyn Array>> {
64        Some(self)
65    }
66}
67
68impl ScalarAtFn<&DictArray> for DictEncoding {
69    fn scalar_at(&self, array: &DictArray, index: usize) -> VortexResult<Scalar> {
70        let dict_index: usize = scalar_at(array.codes(), index)?.as_ref().try_into()?;
71        scalar_at(array.values(), dict_index)
72    }
73}
74
75impl TakeFn<&DictArray> for DictEncoding {
76    fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
77        let codes = take(array.codes(), indices)?;
78        DictArray::try_new(codes, array.values().clone()).map(|a| a.into_array())
79    }
80}
81
82impl FilterKernel for DictEncoding {
83    fn filter(&self, array: &DictArray, mask: &Mask) -> VortexResult<ArrayRef> {
84        let codes = filter(array.codes(), mask)?;
85        DictArray::try_new(codes, array.values().clone()).map(|a| a.into_array())
86    }
87}
88
89impl SliceFn<&DictArray> for DictEncoding {
90    // TODO(robert): Add function to trim the dictionary
91    fn slice(&self, array: &DictArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
92        DictArray::try_new(slice(array.codes(), start, stop)?, array.values().clone())
93            .map(|a| a.into_array())
94    }
95}
96
#[cfg(test)]
mod test {
    use vortex_array::accessor::ArrayAccessor;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray, VarBinArray, VarBinViewArray};
    use vortex_array::compute::conformance::mask::test_mask;
    use vortex_array::compute::{Operator, compare, scalar_at, slice};
    use vortex_array::{Array, ArrayRef, ToCanonical};
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::builders::dict_encode;

    /// Dict-encodes `len` non-nullable `i32`s that cycle through `0..unique_count`, then asserts
    /// that canonicalising back to a primitive array reproduces the input exactly.
    ///
    /// Shared by the `canonicalise_non_nullable_primitive_*_unique_values` tests, which differ
    /// only in dictionary cardinality.
    fn assert_cycling_primitive_round_trips(unique_count: i32, len: usize) {
        let unique_values: Vec<i32> = (0..unique_count).collect();
        let expected: Vec<i32> = (0..len)
            .map(|i| unique_values[i % unique_values.len()])
            .collect();

        let dict = dict_encode(&PrimitiveArray::from_iter(expected.iter().copied())).unwrap();
        let actual = dict.to_primitive().unwrap();

        assert_eq!(actual.as_slice::<i32>(), expected.as_slice());
    }

    /// Nullable primitives: physical values at null positions and the resulting validity
    /// count after decoding.
    #[test]
    fn canonicalise_nullable_primitive() {
        let values: Vec<Option<i32>> = (0..65)
            .map(|i| match i % 3 {
                0 => Some(42),
                1 => Some(-9),
                2 => None,
                _ => unreachable!(),
            })
            .collect();

        let dict = dict_encode(&PrimitiveArray::from_option_iter(values.clone())).unwrap();
        let actual = dict.to_primitive().unwrap();

        let expected: Vec<i32> = (0..65)
            .map(|i| match i % 3 {
                // Compressor puts 0 as a code for invalid values which we end up using in take
                // thus invalid values on decompression turn into whatever is at 0th position in dictionary
                0 | 2 => 42,
                1 => -9,
                _ => unreachable!(),
            })
            .collect();

        assert_eq!(actual.as_slice::<i32>(), expected.as_slice());

        let expected_valid_count = values.iter().filter(|x| x.is_some()).count();
        assert_eq!(
            actual.validity_mask().unwrap().true_count(),
            expected_valid_count
        );
    }

    #[test]
    fn canonicalise_non_nullable_primitive_32_unique_values() {
        assert_cycling_primitive_round_trips(32, 1000);
    }

    #[test]
    fn canonicalise_non_nullable_primitive_100_unique_values() {
        assert_cycling_primitive_round_trips(100, 1000);
    }

    /// Nullable UTF-8: decoding must reproduce the reference element-wise, nulls included.
    #[test]
    fn canonicalise_nullable_varbin() {
        let reference = VarBinViewArray::from_iter(
            vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")],
            DType::Utf8(Nullability::Nullable),
        );
        assert_eq!(reference.len(), 6);
        let dict = dict_encode(&reference).unwrap();
        let flattened_dict = dict.to_varbinview().unwrap();
        assert_eq!(
            flattened_dict
                .with_iterator(|iter| iter
                    .map(|slice| slice.map(|s| s.to_vec()))
                    .collect::<Vec<_>>())
                .unwrap(),
            reference
                .with_iterator(|iter| iter
                    .map(|slice| slice.map(|s| s.to_vec()))
                    .collect::<Vec<_>>())
                .unwrap(),
        );
    }

    /// Builds a dict-encoded nullable primitive array and slices it to rows `1..4`,
    /// i.e. the logical values `[-9, null, 42]`.
    fn sliced_dict_array() -> ArrayRef {
        let reference = PrimitiveArray::from_option_iter([
            Some(42),
            Some(-9),
            None,
            Some(42),
            Some(1),
            Some(5),
        ]);
        let dict = dict_encode(&reference).unwrap();
        slice(&dict, 1, 4).unwrap()
    }

    /// Comparing a sliced dictionary against a constant must respect the slice offset and
    /// propagate nulls.
    #[test]
    fn compare_sliced_dict() {
        let sliced = sliced_dict_array();
        let compared = compare(&sliced, &ConstantArray::new(42, 3), Operator::Eq).unwrap();

        assert_eq!(
            scalar_at(&compared, 0).unwrap(),
            Scalar::bool(false, Nullability::Nullable)
        );
        assert_eq!(
            scalar_at(&compared, 1).unwrap(),
            Scalar::null(DType::Bool(Nullability::Nullable))
        );
        assert_eq!(
            scalar_at(&compared, 2).unwrap(),
            Scalar::bool(true, Nullability::Nullable)
        );
    }

    /// Runs the shared mask conformance suite over non-nullable primitive, nullable primitive
    /// and nullable string dictionaries.
    #[test]
    fn test_mask_dict_array() {
        let array = dict_encode(&PrimitiveArray::from_iter([2, 0, 2, 0, 10]).into_array()).unwrap();
        test_mask(&array);

        let array = dict_encode(
            &PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)])
                .into_array(),
        )
        .unwrap();
        test_mask(&array);

        let array = dict_encode(
            &VarBinArray::from_iter(
                [
                    Some("hello"),
                    None,
                    Some("hello"),
                    Some("good"),
                    Some("good"),
                ],
                DType::Utf8(Nullability::Nullable),
            )
            .into_array(),
        )
        .unwrap();
        test_mask(&array);
    }
}