vortex_dict/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod binary_numeric;
5mod compare;
6mod fill_null;
7mod is_constant;
8mod is_sorted;
9mod like;
10mod min_max;
11
12use vortex_array::compute::{
13    FilterKernel, FilterKernelAdapter, TakeKernel, TakeKernelAdapter, cast, filter, take,
14};
15use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
16use vortex_error::VortexResult;
17use vortex_mask::Mask;
18
19use crate::{DictArray, DictVTable};
20
21impl TakeKernel for DictVTable {
22    fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
23        // TODO(joe): can we remove the cast and allow dict arrays to have nullable codes and values
24        let codes = take(array.codes(), indices)?;
25        let values_dtype = array
26            .values()
27            .dtype()
28            .union_nullability(codes.dtype().nullability());
29        DictArray::try_new(codes, cast(array.values(), &values_dtype)?).map(|a| a.into_array())
30    }
31}
32
33register_kernel!(TakeKernelAdapter(DictVTable).lift());
34
35impl FilterKernel for DictVTable {
36    fn filter(&self, array: &DictArray, mask: &Mask) -> VortexResult<ArrayRef> {
37        let codes = filter(array.codes(), mask)?;
38        DictArray::try_new(codes, array.values().clone()).map(|a| a.into_array())
39    }
40}
41
42register_kernel!(FilterKernelAdapter(DictVTable).lift());
43
44#[cfg(test)]
45mod test {
46    use vortex_array::accessor::ArrayAccessor;
47    use vortex_array::arrays::{ConstantArray, PrimitiveArray, VarBinArray, VarBinViewArray};
48    use vortex_array::compute::conformance::filter::test_filter_conformance;
49    use vortex_array::compute::conformance::mask::test_mask_conformance;
50    use vortex_array::compute::conformance::take::test_take_conformance;
51    use vortex_array::compute::{Operator, compare, take};
52    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
53    use vortex_dtype::PType::I32;
54    use vortex_dtype::{DType, Nullability};
55    use vortex_scalar::Scalar;
56
57    use crate::builders::dict_encode;
58
59    #[test]
60    fn canonicalise_nullable_primitive() {
61        let values: Vec<Option<i32>> = (0..65)
62            .map(|i| match i % 3 {
63                0 => Some(42),
64                1 => Some(-9),
65                2 => None,
66                _ => unreachable!(),
67            })
68            .collect();
69
70        let dict = dict_encode(PrimitiveArray::from_option_iter(values.clone()).as_ref()).unwrap();
71        let actual = dict.to_primitive().unwrap();
72
73        let expected: Vec<i32> = (0..65)
74            .map(|i| match i % 3 {
75                // Compressor puts 0 as a code for invalid values which we end up using in take
76                // thus invalid values on decompression turn into whatever is at 0th position in dictionary
77                0 | 2 => 42,
78                1 => -9,
79                _ => unreachable!(),
80            })
81            .collect();
82
83        assert_eq!(actual.as_slice::<i32>(), expected.as_slice());
84
85        let expected_valid_count = values.iter().filter(|x| x.is_some()).count();
86        assert_eq!(
87            actual.validity_mask().unwrap().true_count(),
88            expected_valid_count
89        );
90    }
91
92    #[test]
93    fn canonicalise_non_nullable_primitive_32_unique_values() {
94        let unique_values: Vec<i32> = (0..32).collect();
95        let expected: Vec<i32> = (0..1000).map(|i| unique_values[i % 32]).collect();
96
97        let dict =
98            dict_encode(PrimitiveArray::from_iter(expected.iter().copied()).as_ref()).unwrap();
99        let actual = dict.to_primitive().unwrap();
100
101        assert_eq!(actual.as_slice::<i32>(), expected.as_slice());
102    }
103
104    #[test]
105    fn canonicalise_non_nullable_primitive_100_unique_values() {
106        let unique_values: Vec<i32> = (0..100).collect();
107        let expected: Vec<i32> = (0..1000).map(|i| unique_values[i % 100]).collect();
108
109        let dict =
110            dict_encode(PrimitiveArray::from_iter(expected.iter().copied()).as_ref()).unwrap();
111        let actual = dict.to_primitive().unwrap();
112
113        assert_eq!(actual.as_slice::<i32>(), expected.as_slice());
114    }
115
116    #[test]
117    fn canonicalise_nullable_varbin() {
118        let reference = VarBinViewArray::from_iter(
119            vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")],
120            DType::Utf8(Nullability::Nullable),
121        );
122        assert_eq!(reference.len(), 6);
123        let dict = dict_encode(reference.as_ref()).unwrap();
124        let flattened_dict = dict.to_varbinview().unwrap();
125        assert_eq!(
126            flattened_dict
127                .with_iterator(|iter| iter
128                    .map(|slice| slice.map(|s| s.to_vec()))
129                    .collect::<Vec<_>>())
130                .unwrap(),
131            reference
132                .with_iterator(|iter| iter
133                    .map(|slice| slice.map(|s| s.to_vec()))
134                    .collect::<Vec<_>>())
135                .unwrap(),
136        );
137    }
138
139    fn sliced_dict_array() -> ArrayRef {
140        let reference = PrimitiveArray::from_option_iter([
141            Some(42),
142            Some(-9),
143            None,
144            Some(42),
145            Some(1),
146            Some(5),
147        ]);
148        let dict = dict_encode(reference.as_ref()).unwrap();
149        dict.slice(1, 4).unwrap()
150    }
151
152    #[test]
153    fn compare_sliced_dict() {
154        let sliced = sliced_dict_array();
155        let compared = compare(&sliced, ConstantArray::new(42, 3).as_ref(), Operator::Eq).unwrap();
156
157        assert_eq!(
158            compared.scalar_at(0).unwrap(),
159            Scalar::bool(false, Nullability::Nullable)
160        );
161        assert_eq!(
162            compared.scalar_at(1).unwrap(),
163            Scalar::null(DType::Bool(Nullability::Nullable))
164        );
165        assert_eq!(
166            compared.scalar_at(2).unwrap(),
167            Scalar::bool(true, Nullability::Nullable)
168        );
169    }
170
171    #[test]
172    fn test_mask_dict_array() {
173        let array = dict_encode(&PrimitiveArray::from_iter([2, 0, 2, 0, 10]).into_array()).unwrap();
174        test_mask_conformance(array.as_ref());
175
176        let array = dict_encode(
177            PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)]).as_ref(),
178        )
179        .unwrap();
180        test_mask_conformance(array.as_ref());
181
182        let array = dict_encode(
183            &VarBinArray::from_iter(
184                [
185                    Some("hello"),
186                    None,
187                    Some("hello"),
188                    Some("good"),
189                    Some("good"),
190                ],
191                DType::Utf8(Nullability::Nullable),
192            )
193            .into_array(),
194        )
195        .unwrap();
196        test_mask_conformance(array.as_ref());
197    }
198
199    #[test]
200    fn test_filter_dict_array() {
201        let array = dict_encode(&PrimitiveArray::from_iter([2, 0, 2, 0, 10]).into_array()).unwrap();
202        test_filter_conformance(array.as_ref());
203
204        let array = dict_encode(
205            PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)]).as_ref(),
206        )
207        .unwrap();
208        test_filter_conformance(array.as_ref());
209
210        let array = dict_encode(
211            &VarBinArray::from_iter(
212                [
213                    Some("hello"),
214                    None,
215                    Some("hello"),
216                    Some("good"),
217                    Some("good"),
218                ],
219                DType::Utf8(Nullability::Nullable),
220            )
221            .into_array(),
222        )
223        .unwrap();
224        test_filter_conformance(array.as_ref());
225    }
226
227    #[test]
228    fn test_take_dict() {
229        let array = dict_encode(PrimitiveArray::from_iter([1, 2]).as_ref()).unwrap();
230
231        assert_eq!(
232            take(
233                array.as_ref(),
234                PrimitiveArray::from_option_iter([Option::<i32>::None]).as_ref()
235            )
236            .unwrap()
237            .dtype(),
238            &DType::Primitive(I32, Nullability::Nullable)
239        );
240    }
241
242    #[test]
243    fn test_take_dict_conformance() {
244        let array = dict_encode(&PrimitiveArray::from_iter([2, 0, 2, 0, 10]).into_array()).unwrap();
245        test_take_conformance(array.as_ref());
246
247        let array = dict_encode(
248            PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)]).as_ref(),
249        )
250        .unwrap();
251        test_take_conformance(array.as_ref());
252
253        let array = dict_encode(
254            &VarBinArray::from_iter(
255                [
256                    Some("hello"),
257                    None,
258                    Some("hello"),
259                    Some("good"),
260                    Some("good"),
261                ],
262                DType::Utf8(Nullability::Nullable),
263            )
264            .into_array(),
265        )
266        .unwrap();
267        test_take_conformance(array.as_ref());
268    }
269}
270
271#[cfg(test)]
272mod tests {
273    use rstest::rstest;
274    use vortex_array::IntoArray;
275    use vortex_array::arrays::{PrimitiveArray, VarBinArray};
276    use vortex_array::compute::conformance::consistency::test_array_consistency;
277    use vortex_dtype::{DType, Nullability};
278
279    use crate::DictArray;
280    use crate::builders::dict_encode;
281
282    #[rstest]
283    // Primitive arrays
284    #[case::dict_i32(dict_encode(&PrimitiveArray::from_iter([1i32, 2, 3, 2, 1]).into_array()).unwrap())]
285    #[case::dict_nullable_i32(dict_encode(
286        PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).as_ref()
287    ).unwrap())]
288    #[case::dict_u64(dict_encode(&PrimitiveArray::from_iter([100u64, 200, 100, 300, 200]).into_array()).unwrap())]
289    // String arrays
290    #[case::dict_str(dict_encode(
291        &VarBinArray::from_iter(
292            ["hello", "world", "hello", "test", "world"].map(Some),
293            DType::Utf8(Nullability::NonNullable),
294        ).into_array()
295    ).unwrap())]
296    #[case::dict_nullable_str(dict_encode(
297        &VarBinArray::from_iter(
298            [Some("hello"), None, Some("world"), Some("hello"), None],
299            DType::Utf8(Nullability::Nullable),
300        ).into_array()
301    ).unwrap())]
302    // Edge cases
303    #[case::dict_single(dict_encode(&PrimitiveArray::from_iter([42i32]).into_array()).unwrap())]
304    #[case::dict_all_same(dict_encode(&PrimitiveArray::from_iter([5i32, 5, 5, 5, 5]).into_array()).unwrap())]
305    #[case::dict_large(dict_encode(&PrimitiveArray::from_iter((0..1000).map(|i| i % 10)).into_array()).unwrap())]
306
307    fn test_dict_consistency(#[case] array: DictArray) {
308        test_array_consistency(array.as_ref());
309    }
310}