// vortex_dict/compute/mod.rs

mod binary_numeric;
mod compare;
mod fill_null;
mod is_constant;
mod is_sorted;
mod like;
mod min_max;
mod optimize;

use vortex_array::compute::{
    FillNullFn, FilterKernel, FilterKernelAdapter, IsConstantFn, IsSortedFn, LikeFn, MinMaxFn,
    OptimizeFn, ScalarAtFn, SliceFn, TakeFn, filter, scalar_at, slice, take,
};
use vortex_array::vtable::ComputeVTable;
use vortex_array::{Array, ArrayRef, register_kernel};
use vortex_error::VortexResult;
use vortex_mask::Mask;
use vortex_scalar::Scalar;

use crate::{DictArray, DictEncoding};

22impl ComputeVTable for DictEncoding {
23    fn fill_null_fn(&self) -> Option<&dyn FillNullFn<&dyn Array>> {
24        Some(self)
25    }
26
27    fn is_constant_fn(&self) -> Option<&dyn IsConstantFn<&dyn Array>> {
28        Some(self)
29    }
30
31    fn is_sorted_fn(&self) -> Option<&dyn IsSortedFn<&dyn Array>> {
32        Some(self)
33    }
34
35    fn like_fn(&self) -> Option<&dyn LikeFn<&dyn Array>> {
36        Some(self)
37    }
38
39    fn optimize_fn(&self) -> Option<&dyn OptimizeFn<&dyn Array>> {
40        Some(self)
41    }
42
43    fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
44        Some(self)
45    }
46
47    fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
48        Some(self)
49    }
50
51    fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
52        Some(self)
53    }
54
55    fn min_max_fn(&self) -> Option<&dyn MinMaxFn<&dyn Array>> {
56        Some(self)
57    }
58}
59
60impl ScalarAtFn<&DictArray> for DictEncoding {
61    fn scalar_at(&self, array: &DictArray, index: usize) -> VortexResult<Scalar> {
62        let dict_index: usize = scalar_at(array.codes(), index)?.as_ref().try_into()?;
63        scalar_at(array.values(), dict_index)
64    }
65}
66
67impl TakeFn<&DictArray> for DictEncoding {
68    fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
69        let codes = take(array.codes(), indices)?;
70        DictArray::try_new(codes, array.values().clone()).map(|a| a.into_array())
71    }
72}
73
74impl FilterKernel for DictEncoding {
75    fn filter(&self, array: &DictArray, mask: &Mask) -> VortexResult<ArrayRef> {
76        let codes = filter(array.codes(), mask)?;
77        DictArray::try_new(codes, array.values().clone()).map(|a| a.into_array())
78    }
79}
80
81register_kernel!(FilterKernelAdapter(DictEncoding).lift());
82
83impl SliceFn<&DictArray> for DictEncoding {
84    // TODO(robert): Add function to trim the dictionary
85    fn slice(&self, array: &DictArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
86        DictArray::try_new(slice(array.codes(), start, stop)?, array.values().clone())
87            .map(|a| a.into_array())
88    }
89}
90
#[cfg(test)]
mod test {
    use vortex_array::accessor::ArrayAccessor;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray, VarBinArray, VarBinViewArray};
    use vortex_array::compute::conformance::mask::test_mask;
    use vortex_array::compute::{Operator, compare, scalar_at, slice};
    use vortex_array::{Array, ArrayRef, ToCanonical};
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::builders::dict_encode;

    /// Dict-encodes `0..distinct` repeated out to 1000 elements and checks that
    /// canonicalising back to a primitive array reproduces the input exactly.
    fn assert_primitive_round_trip(distinct: i32) {
        let expected: Vec<i32> = (0..distinct).cycle().take(1000).collect();
        let dict = dict_encode(&PrimitiveArray::from_iter(expected.iter().copied())).unwrap();
        let actual = dict.to_primitive().unwrap();
        assert_eq!(actual.as_slice::<i32>(), expected.as_slice());
    }

    #[test]
    fn canonicalise_nullable_primitive() {
        // 65 elements following the pattern 42, -9, null, 42, -9, null, ...
        let pattern: [Option<i32>; 3] = [Some(42), Some(-9), None];
        let values: Vec<Option<i32>> = pattern.iter().copied().cycle().take(65).collect();

        let dict = dict_encode(&PrimitiveArray::from_option_iter(values.iter().copied())).unwrap();
        let actual = dict.to_primitive().unwrap();

        // The compressor uses code 0 for invalid values and decoding takes through that code,
        // so null slots physically hold whatever is at dictionary position 0 (42 here).
        let expected: Vec<i32> = [42, -9, 42].iter().copied().cycle().take(65).collect();
        assert_eq!(actual.as_slice::<i32>(), expected.as_slice());

        let expected_valid_count = values.iter().flatten().count();
        assert_eq!(
            actual.validity_mask().unwrap().true_count(),
            expected_valid_count
        );
    }

    #[test]
    fn canonicalise_non_nullable_primitive_32_unique_values() {
        assert_primitive_round_trip(32);
    }

    #[test]
    fn canonicalise_non_nullable_primitive_100_unique_values() {
        assert_primitive_round_trip(100);
    }

    #[test]
    fn canonicalise_nullable_varbin() {
        let reference = VarBinViewArray::from_iter(
            vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")],
            DType::Utf8(Nullability::Nullable),
        );
        assert_eq!(reference.len(), 6);

        let dict = dict_encode(&reference).unwrap();
        let flattened_dict = dict.to_varbinview().unwrap();

        let decoded = flattened_dict
            .with_iterator(|iter| iter.map(|bytes| bytes.map(|b| b.to_vec())).collect::<Vec<_>>())
            .unwrap();
        let original = reference
            .with_iterator(|iter| iter.map(|bytes| bytes.map(|b| b.to_vec())).collect::<Vec<_>>())
            .unwrap();
        assert_eq!(decoded, original);
    }

    /// Dict-encodes `[42, -9, null, 42, 1, 5]` and slices it to rows `1..4`,
    /// i.e. `[-9, null, 42]`.
    fn sliced_dict_array() -> ArrayRef {
        let source = [Some(42), Some(-9), None, Some(42), Some(1), Some(5)];
        let dict = dict_encode(&PrimitiveArray::from_option_iter(source)).unwrap();
        slice(&dict, 1, 4).unwrap()
    }

    #[test]
    fn compare_sliced_dict() {
        let sliced = sliced_dict_array();
        let compared = compare(&sliced, &ConstantArray::new(42, 3), Operator::Eq).unwrap();

        let expected = [
            Scalar::bool(false, Nullability::Nullable),
            Scalar::null(DType::Bool(Nullability::Nullable)),
            Scalar::bool(true, Nullability::Nullable),
        ];
        for (position, want) in expected.into_iter().enumerate() {
            assert_eq!(scalar_at(&compared, position).unwrap(), want);
        }
    }

    #[test]
    fn test_mask_dict_array() {
        let non_nullable = PrimitiveArray::from_iter([2, 0, 2, 0, 10]).into_array();
        let nullable =
            PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)])
                .into_array();
        let strings = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("hello"),
                Some("good"),
                Some("good"),
            ],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();

        // Each input is encoded and then run through the mask conformance suite, in order.
        for source in [non_nullable, nullable, strings] {
            let encoded = dict_encode(&source).unwrap();
            test_mask(&encoded);
        }
    }
}