vortex_dict/compute/
mod.rs

1mod binary_numeric;
2mod compare;
3mod is_constant;
4mod like;
5
6use vortex_array::compute::{
7    BinaryNumericFn, CompareFn, FilterFn, IsConstantFn, LikeFn, ScalarAtFn, SliceFn, TakeFn,
8    filter, scalar_at, slice, take,
9};
10use vortex_array::vtable::ComputeVTable;
11use vortex_array::{Array, ArrayRef};
12use vortex_error::VortexResult;
13use vortex_mask::Mask;
14use vortex_scalar::Scalar;
15
16use crate::{DictArray, DictEncoding};
17
18impl ComputeVTable for DictEncoding {
19    fn binary_numeric_fn(&self) -> Option<&dyn BinaryNumericFn<&dyn Array>> {
20        Some(self)
21    }
22
23    fn compare_fn(&self) -> Option<&dyn CompareFn<&dyn Array>> {
24        Some(self)
25    }
26
27    fn is_constant_fn(&self) -> Option<&dyn IsConstantFn<&dyn Array>> {
28        Some(self)
29    }
30
31    fn filter_fn(&self) -> Option<&dyn FilterFn<&dyn Array>> {
32        Some(self)
33    }
34
35    fn like_fn(&self) -> Option<&dyn LikeFn<&dyn Array>> {
36        Some(self)
37    }
38
39    fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn<&dyn Array>> {
40        Some(self)
41    }
42
43    fn slice_fn(&self) -> Option<&dyn SliceFn<&dyn Array>> {
44        Some(self)
45    }
46
47    fn take_fn(&self) -> Option<&dyn TakeFn<&dyn Array>> {
48        Some(self)
49    }
50}
51
52impl ScalarAtFn<&DictArray> for DictEncoding {
53    fn scalar_at(&self, array: &DictArray, index: usize) -> VortexResult<Scalar> {
54        let dict_index: usize = scalar_at(array.codes(), index)?.as_ref().try_into()?;
55        scalar_at(array.values(), dict_index)
56    }
57}
58
59impl TakeFn<&DictArray> for DictEncoding {
60    fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
61        let codes = take(array.codes(), indices)?;
62        DictArray::try_new(codes, array.values().clone()).map(|a| a.into_array())
63    }
64}
65
66impl FilterFn<&DictArray> for DictEncoding {
67    fn filter(&self, array: &DictArray, mask: &Mask) -> VortexResult<ArrayRef> {
68        let codes = filter(array.codes(), mask)?;
69        DictArray::try_new(codes, array.values().clone()).map(|a| a.into_array())
70    }
71}
72
73impl SliceFn<&DictArray> for DictEncoding {
74    // TODO(robert): Add function to trim the dictionary
75    fn slice(&self, array: &DictArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
76        DictArray::try_new(slice(array.codes(), start, stop)?, array.values().clone())
77            .map(|a| a.into_array())
78    }
79}
80
#[cfg(test)]
mod test {
    use vortex_array::accessor::ArrayAccessor;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray, VarBinArray, VarBinViewArray};
    use vortex_array::compute::test_harness::test_mask;
    use vortex_array::compute::{Operator, compare, scalar_at, slice};
    use vortex_array::{Array, ArrayRef, ToCanonical};
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::builders::dict_encode;

    /// Dict-encodes `len` values cycling through `0..unique`, canonicalises the result, and
    /// checks that the original sequence is reproduced exactly.
    fn assert_cyclic_primitive_roundtrip(unique: i32, len: usize) {
        let expected: Vec<i32> = (0..unique).cycle().take(len).collect();

        let encoded = dict_encode(&PrimitiveArray::from_iter(expected.iter().copied())).unwrap();
        let decoded = encoded.to_primitive().unwrap();

        assert_eq!(decoded.as_slice::<i32>(), expected.as_slice());
    }

    #[test]
    fn canonicalise_nullable_primitive() {
        const LEN: usize = 65;
        let pattern = [Some(42), Some(-9), None];
        let values: Vec<Option<i32>> = (0..LEN).map(|i| pattern[i % 3]).collect();

        let encoded = dict_encode(&PrimitiveArray::from_option_iter(values.clone())).unwrap();
        let decoded = encoded.to_primitive().unwrap();

        // Compressor puts 0 as a code for invalid values which we end up using in take
        // thus invalid values on decompression turn into whatever is at 0th position in dictionary
        let physical = [42, -9, 42];
        let expected: Vec<i32> = (0..LEN).map(|i| physical[i % 3]).collect();
        assert_eq!(decoded.as_slice::<i32>(), expected.as_slice());

        let valid_count = values.iter().flatten().count();
        assert_eq!(decoded.validity_mask().unwrap().true_count(), valid_count);
    }

    #[test]
    fn canonicalise_non_nullable_primitive_32_unique_values() {
        assert_cyclic_primitive_roundtrip(32, 1000);
    }

    #[test]
    fn canonicalise_non_nullable_primitive_100_unique_values() {
        assert_cyclic_primitive_roundtrip(100, 1000);
    }

    #[test]
    fn canonicalise_nullable_varbin() {
        let reference = VarBinViewArray::from_iter(
            vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")],
            DType::Utf8(Nullability::Nullable),
        );
        assert_eq!(reference.len(), 6);

        let decoded = dict_encode(&reference).unwrap().to_varbinview().unwrap();

        let decoded_rows = decoded
            .with_iterator(|rows| rows.map(|row| row.map(|bytes| bytes.to_vec())).collect::<Vec<_>>())
            .unwrap();
        let reference_rows = reference
            .with_iterator(|rows| rows.map(|row| row.map(|bytes| bytes.to_vec())).collect::<Vec<_>>())
            .unwrap();

        assert_eq!(decoded_rows, reference_rows);
    }

    /// Builds a nullable dictionary array and returns rows `1..4` of it,
    /// i.e. `[-9, null, 42]`.
    fn sliced_dict_array() -> ArrayRef {
        let source = PrimitiveArray::from_option_iter([
            Some(42),
            Some(-9),
            None,
            Some(42),
            Some(1),
            Some(5),
        ]);
        let encoded = dict_encode(&source).unwrap();
        slice(&encoded, 1, 4).unwrap()
    }

    #[test]
    fn compare_sliced_dict() {
        let window = sliced_dict_array();
        let result = compare(&window, &ConstantArray::new(42, 3), Operator::Eq).unwrap();

        let expected = [
            Scalar::bool(false, Nullability::Nullable),
            Scalar::null(DType::Bool(Nullability::Nullable)),
            Scalar::bool(true, Nullability::Nullable),
        ];
        for (position, want) in expected.into_iter().enumerate() {
            assert_eq!(scalar_at(&result, position).unwrap(), want);
        }
    }

    #[test]
    fn test_mask_dict_array() {
        let cases = [
            dict_encode(&PrimitiveArray::from_iter([2, 0, 2, 0, 10]).into_array()).unwrap(),
            dict_encode(
                &PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)])
                    .into_array(),
            )
            .unwrap(),
            dict_encode(
                &VarBinArray::from_iter(
                    [
                        Some("hello"),
                        None,
                        Some("hello"),
                        Some("good"),
                        Some("good"),
                    ],
                    DType::Utf8(Nullability::Nullable),
                )
                .into_array(),
            )
            .unwrap(),
        ];

        for encoded in &cases {
            test_mask(encoded);
        }
    }
}