vortex_dict/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod binary_numeric;
5mod cast;
6mod compare;
7mod fill_null;
8mod is_constant;
9mod is_sorted;
10mod like;
11mod min_max;
12
13use vortex_array::compute::{
14    FilterKernel, FilterKernelAdapter, TakeKernel, TakeKernelAdapter, cast, filter, take,
15};
16use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
17use vortex_error::VortexResult;
18use vortex_mask::Mask;
19
20use crate::{DictArray, DictVTable};
21
22impl TakeKernel for DictVTable {
23    fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
24        // TODO(joe): can we remove the cast and allow dict arrays to have nullable codes and values
25        let codes = take(array.codes(), indices)?;
26        let values_dtype = array
27            .values()
28            .dtype()
29            .union_nullability(codes.dtype().nullability());
30        // SAFETY: selecting codes doesn't change the invariants of DictArray
31        unsafe {
32            Ok(DictArray::new_unchecked(codes, cast(array.values(), &values_dtype)?).into_array())
33        }
34    }
35}
36
37register_kernel!(TakeKernelAdapter(DictVTable).lift());
38
39impl FilterKernel for DictVTable {
40    fn filter(&self, array: &DictArray, mask: &Mask) -> VortexResult<ArrayRef> {
41        let codes = filter(array.codes(), mask)?;
42
43        // SAFETY: filtering codes doesn't change invariants
44        unsafe { Ok(DictArray::new_unchecked(codes, array.values().clone()).into_array()) }
45    }
46}
47
48register_kernel!(FilterKernelAdapter(DictVTable).lift());
49
50#[cfg(test)]
51mod test {
52    use vortex_array::accessor::ArrayAccessor;
53    use vortex_array::arrays::{ConstantArray, PrimitiveArray, VarBinArray, VarBinViewArray};
54    use vortex_array::compute::conformance::filter::test_filter_conformance;
55    use vortex_array::compute::conformance::mask::test_mask_conformance;
56    use vortex_array::compute::conformance::take::test_take_conformance;
57    use vortex_array::compute::{Operator, compare, take};
58    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
59    use vortex_dtype::PType::I32;
60    use vortex_dtype::{DType, Nullability};
61    use vortex_scalar::Scalar;
62
63    use crate::builders::dict_encode;
64
65    #[test]
66    fn canonicalise_nullable_primitive() {
67        let values: Vec<Option<i32>> = (0..65)
68            .map(|i| match i % 3 {
69                0 => Some(42),
70                1 => Some(-9),
71                2 => None,
72                _ => unreachable!(),
73            })
74            .collect();
75
76        let dict = dict_encode(PrimitiveArray::from_option_iter(values.clone()).as_ref()).unwrap();
77        let actual = dict.to_primitive().unwrap();
78
79        let expected: Vec<i32> = (0..65)
80            .map(|i| match i % 3 {
81                // Compressor puts 0 as a code for invalid values which we end up using in take
82                // thus invalid values on decompression turn into whatever is at 0th position in dictionary
83                0 | 2 => 42,
84                1 => -9,
85                _ => unreachable!(),
86            })
87            .collect();
88
89        assert_eq!(actual.as_slice::<i32>(), expected.as_slice());
90
91        let expected_valid_count = values.iter().filter(|x| x.is_some()).count();
92        assert_eq!(
93            actual.validity_mask().unwrap().true_count(),
94            expected_valid_count
95        );
96    }
97
98    #[test]
99    fn canonicalise_non_nullable_primitive_32_unique_values() {
100        let unique_values: Vec<i32> = (0..32).collect();
101        let expected: Vec<i32> = (0..1000).map(|i| unique_values[i % 32]).collect();
102
103        let dict =
104            dict_encode(PrimitiveArray::from_iter(expected.iter().copied()).as_ref()).unwrap();
105        let actual = dict.to_primitive().unwrap();
106
107        assert_eq!(actual.as_slice::<i32>(), expected.as_slice());
108    }
109
110    #[test]
111    fn canonicalise_non_nullable_primitive_100_unique_values() {
112        let unique_values: Vec<i32> = (0..100).collect();
113        let expected: Vec<i32> = (0..1000).map(|i| unique_values[i % 100]).collect();
114
115        let dict =
116            dict_encode(PrimitiveArray::from_iter(expected.iter().copied()).as_ref()).unwrap();
117        let actual = dict.to_primitive().unwrap();
118
119        assert_eq!(actual.as_slice::<i32>(), expected.as_slice());
120    }
121
122    #[test]
123    fn canonicalise_nullable_varbin() {
124        let reference = VarBinViewArray::from_iter(
125            vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")],
126            DType::Utf8(Nullability::Nullable),
127        );
128        assert_eq!(reference.len(), 6);
129        let dict = dict_encode(reference.as_ref()).unwrap();
130        let flattened_dict = dict.to_varbinview().unwrap();
131        assert_eq!(
132            flattened_dict
133                .with_iterator(|iter| iter
134                    .map(|slice| slice.map(|s| s.to_vec()))
135                    .collect::<Vec<_>>())
136                .unwrap(),
137            reference
138                .with_iterator(|iter| iter
139                    .map(|slice| slice.map(|s| s.to_vec()))
140                    .collect::<Vec<_>>())
141                .unwrap(),
142        );
143    }
144
145    fn sliced_dict_array() -> ArrayRef {
146        let reference = PrimitiveArray::from_option_iter([
147            Some(42),
148            Some(-9),
149            None,
150            Some(42),
151            Some(1),
152            Some(5),
153        ]);
154        let dict = dict_encode(reference.as_ref()).unwrap();
155        dict.slice(1, 4)
156    }
157
158    #[test]
159    fn compare_sliced_dict() {
160        let sliced = sliced_dict_array();
161        let compared = compare(&sliced, ConstantArray::new(42, 3).as_ref(), Operator::Eq).unwrap();
162
163        assert_eq!(
164            compared.scalar_at(0),
165            Scalar::bool(false, Nullability::Nullable)
166        );
167        assert_eq!(
168            compared.scalar_at(1),
169            Scalar::null(DType::Bool(Nullability::Nullable))
170        );
171        assert_eq!(
172            compared.scalar_at(2),
173            Scalar::bool(true, Nullability::Nullable)
174        );
175    }
176
177    #[test]
178    fn test_mask_dict_array() {
179        let array = dict_encode(&PrimitiveArray::from_iter([2, 0, 2, 0, 10]).into_array()).unwrap();
180        test_mask_conformance(array.as_ref());
181
182        let array = dict_encode(
183            PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)]).as_ref(),
184        )
185        .unwrap();
186        test_mask_conformance(array.as_ref());
187
188        let array = dict_encode(
189            &VarBinArray::from_iter(
190                [
191                    Some("hello"),
192                    None,
193                    Some("hello"),
194                    Some("good"),
195                    Some("good"),
196                ],
197                DType::Utf8(Nullability::Nullable),
198            )
199            .into_array(),
200        )
201        .unwrap();
202        test_mask_conformance(array.as_ref());
203    }
204
205    #[test]
206    fn test_filter_dict_array() {
207        let array = dict_encode(&PrimitiveArray::from_iter([2, 0, 2, 0, 10]).into_array()).unwrap();
208        test_filter_conformance(array.as_ref());
209
210        let array = dict_encode(
211            PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)]).as_ref(),
212        )
213        .unwrap();
214        test_filter_conformance(array.as_ref());
215
216        let array = dict_encode(
217            &VarBinArray::from_iter(
218                [
219                    Some("hello"),
220                    None,
221                    Some("hello"),
222                    Some("good"),
223                    Some("good"),
224                ],
225                DType::Utf8(Nullability::Nullable),
226            )
227            .into_array(),
228        )
229        .unwrap();
230        test_filter_conformance(array.as_ref());
231    }
232
233    #[test]
234    fn test_take_dict() {
235        let array = dict_encode(PrimitiveArray::from_iter([1, 2]).as_ref()).unwrap();
236
237        assert_eq!(
238            take(
239                array.as_ref(),
240                PrimitiveArray::from_option_iter([Option::<i32>::None]).as_ref()
241            )
242            .unwrap()
243            .dtype(),
244            &DType::Primitive(I32, Nullability::Nullable)
245        );
246    }
247
248    #[test]
249    fn test_take_dict_conformance() {
250        let array = dict_encode(&PrimitiveArray::from_iter([2, 0, 2, 0, 10]).into_array()).unwrap();
251        test_take_conformance(array.as_ref());
252
253        let array = dict_encode(
254            PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)]).as_ref(),
255        )
256        .unwrap();
257        test_take_conformance(array.as_ref());
258
259        let array = dict_encode(
260            &VarBinArray::from_iter(
261                [
262                    Some("hello"),
263                    None,
264                    Some("hello"),
265                    Some("good"),
266                    Some("good"),
267                ],
268                DType::Utf8(Nullability::Nullable),
269            )
270            .into_array(),
271        )
272        .unwrap();
273        test_take_conformance(array.as_ref());
274    }
275}
276
277#[cfg(test)]
278mod tests {
279    use rstest::rstest;
280    use vortex_array::IntoArray;
281    use vortex_array::arrays::{PrimitiveArray, VarBinArray};
282    use vortex_array::compute::conformance::consistency::test_array_consistency;
283    use vortex_dtype::{DType, Nullability};
284
285    use crate::DictArray;
286    use crate::builders::dict_encode;
287
288    #[rstest]
289    // Primitive arrays
290    #[case::dict_i32(dict_encode(&PrimitiveArray::from_iter([1i32, 2, 3, 2, 1]).into_array()).unwrap())]
291    #[case::dict_nullable_i32(dict_encode(
292        PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).as_ref()
293    ).unwrap())]
294    #[case::dict_u64(dict_encode(&PrimitiveArray::from_iter([100u64, 200, 100, 300, 200]).into_array()).unwrap())]
295    // String arrays
296    #[case::dict_str(dict_encode(
297        &VarBinArray::from_iter(
298            ["hello", "world", "hello", "test", "world"].map(Some),
299            DType::Utf8(Nullability::NonNullable),
300        ).into_array()
301    ).unwrap())]
302    #[case::dict_nullable_str(dict_encode(
303        &VarBinArray::from_iter(
304            [Some("hello"), None, Some("world"), Some("hello"), None],
305            DType::Utf8(Nullability::Nullable),
306        ).into_array()
307    ).unwrap())]
308    // Edge cases
309    #[case::dict_single(dict_encode(&PrimitiveArray::from_iter([42i32]).into_array()).unwrap())]
310    #[case::dict_all_same(dict_encode(&PrimitiveArray::from_iter([5i32, 5, 5, 5, 5]).into_array()).unwrap())]
311    #[case::dict_large(dict_encode(&PrimitiveArray::from_iter((0..1000).map(|i| i % 10)).into_array()).unwrap())]
312
313    fn test_dict_consistency(#[case] array: DictArray) {
314        test_array_consistency(array.as_ref());
315    }
316}