vortex_dict/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod binary_numeric;
5mod cast;
6mod compare;
7mod fill_null;
8mod is_constant;
9mod is_sorted;
10mod like;
11mod min_max;
12
13use vortex_array::compute::{
14    FilterKernel, FilterKernelAdapter, TakeKernel, TakeKernelAdapter, cast, filter, take,
15};
16use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
17use vortex_error::VortexResult;
18use vortex_mask::Mask;
19
20use crate::{DictArray, DictVTable};
21
22impl TakeKernel for DictVTable {
23    fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
24        // TODO(joe): can we remove the cast and allow dict arrays to have nullable codes and values
25        let codes = take(array.codes(), indices)?;
26        let values_dtype = array
27            .values()
28            .dtype()
29            .union_nullability(codes.dtype().nullability());
30        DictArray::try_new(codes, cast(array.values(), &values_dtype)?).map(|a| a.into_array())
31    }
32}
33
34register_kernel!(TakeKernelAdapter(DictVTable).lift());
35
36impl FilterKernel for DictVTable {
37    fn filter(&self, array: &DictArray, mask: &Mask) -> VortexResult<ArrayRef> {
38        let codes = filter(array.codes(), mask)?;
39        DictArray::try_new(codes, array.values().clone()).map(|a| a.into_array())
40    }
41}
42
43register_kernel!(FilterKernelAdapter(DictVTable).lift());
44
#[cfg(test)]
mod test {
    use vortex_array::accessor::ArrayAccessor;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray, VarBinArray, VarBinViewArray};
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::compute::{Operator, compare, take};
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::builders::dict_encode;

    // Round-trips a nullable i32 array (repeating 42, -9, null) through dict encoding and checks
    // both the raw primitive buffer and the number of valid entries.
    #[test]
    fn canonicalise_nullable_primitive() {
        let input: Vec<Option<i32>> = (0..65)
            .map(|i| match i % 3 {
                0 => Some(42),
                1 => Some(-9),
                _ => None,
            })
            .collect();
        let valid_count = input.iter().flatten().count();

        let encoded = dict_encode(PrimitiveArray::from_option_iter(input).as_ref()).unwrap();
        let decoded = encoded.to_primitive().unwrap();

        // Compressor puts 0 as a code for invalid values which we end up using in take,
        // thus invalid values on decompression turn into whatever is at 0th position in the
        // dictionary (42 here).
        let expected: Vec<i32> = (0..65)
            .map(|i| if i % 3 == 1 { -9 } else { 42 })
            .collect();

        assert_eq!(decoded.as_slice::<i32>(), expected.as_slice());
        assert_eq!(decoded.validity_mask().unwrap().true_count(), valid_count);
    }

    // 1000 non-nullable values cycling through 32 distinct entries survive an encode/decode
    // round trip unchanged.
    #[test]
    fn canonicalise_non_nullable_primitive_32_unique_values() {
        let expected: Vec<i32> = (0..32).cycle().take(1000).collect();

        let encoded =
            dict_encode(PrimitiveArray::from_iter(expected.iter().copied()).as_ref()).unwrap();
        let decoded = encoded.to_primitive().unwrap();

        assert_eq!(decoded.as_slice::<i32>(), expected.as_slice());
    }

    // Same round trip as above, with 100 distinct entries.
    #[test]
    fn canonicalise_non_nullable_primitive_100_unique_values() {
        let expected: Vec<i32> = (0..100).cycle().take(1000).collect();

        let encoded =
            dict_encode(PrimitiveArray::from_iter(expected.iter().copied()).as_ref()).unwrap();
        let decoded = encoded.to_primitive().unwrap();

        assert_eq!(decoded.as_slice::<i32>(), expected.as_slice());
    }

    // A nullable utf8 array must read back element-for-element (nulls included) after dict
    // encoding and canonicalisation to a varbinview array.
    #[test]
    fn canonicalise_nullable_varbin() {
        let reference = VarBinViewArray::from_iter(
            vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")],
            DType::Utf8(Nullability::Nullable),
        );
        assert_eq!(reference.len(), 6);

        let encoded = dict_encode(reference.as_ref()).unwrap();
        let decoded_rows = encoded
            .to_varbinview()
            .unwrap()
            .with_iterator(|rows| {
                rows.map(|row| row.map(|bytes| bytes.to_vec()))
                    .collect::<Vec<_>>()
            })
            .unwrap();
        let reference_rows = reference
            .with_iterator(|rows| {
                rows.map(|row| row.map(|bytes| bytes.to_vec()))
                    .collect::<Vec<_>>()
            })
            .unwrap();

        assert_eq!(decoded_rows, reference_rows);
    }

    // Dict-encodes a small nullable array and returns the slice covering positions 1..4,
    // i.e. the elements -9, null, 42.
    fn sliced_dict_array() -> ArrayRef {
        let raw = [Some(42), Some(-9), None, Some(42), Some(1), Some(5)];
        let source = PrimitiveArray::from_option_iter(raw);
        dict_encode(source.as_ref()).unwrap().slice(1, 4).unwrap()
    }

    // Comparing the sliced dictionary against the constant 42 yields false, null, true.
    #[test]
    fn compare_sliced_dict() {
        let sliced = sliced_dict_array();
        let compared = compare(&sliced, ConstantArray::new(42, 3).as_ref(), Operator::Eq).unwrap();

        let expected = [
            Scalar::bool(false, Nullability::Nullable),
            Scalar::null(DType::Bool(Nullability::Nullable)),
            Scalar::bool(true, Nullability::Nullable),
        ];
        for (position, want) in expected.into_iter().enumerate() {
            assert_eq!(compared.scalar_at(position).unwrap(), want);
        }
    }

    // Mask conformance over non-nullable ints, nullable ints and nullable strings.
    #[test]
    fn test_mask_dict_array() {
        let plain_ints = dict_encode(PrimitiveArray::from_iter([2, 0, 2, 0, 10]).as_ref()).unwrap();
        test_mask_conformance(plain_ints.as_ref());

        let int_values = [Some(2), None, Some(2), Some(0), Some(10)];
        let nullable_ints =
            dict_encode(PrimitiveArray::from_option_iter(int_values).as_ref()).unwrap();
        test_mask_conformance(nullable_ints.as_ref());

        let words = [Some("hello"), None, Some("hello"), Some("good"), Some("good")];
        let strings = VarBinArray::from_iter(words, DType::Utf8(Nullability::Nullable)).into_array();
        let nullable_strings = dict_encode(&strings).unwrap();
        test_mask_conformance(nullable_strings.as_ref());
    }

    // Filter conformance over non-nullable ints, nullable ints and nullable strings.
    #[test]
    fn test_filter_dict_array() {
        let plain_ints = dict_encode(PrimitiveArray::from_iter([2, 0, 2, 0, 10]).as_ref()).unwrap();
        test_filter_conformance(plain_ints.as_ref());

        let int_values = [Some(2), None, Some(2), Some(0), Some(10)];
        let nullable_ints =
            dict_encode(PrimitiveArray::from_option_iter(int_values).as_ref()).unwrap();
        test_filter_conformance(nullable_ints.as_ref());

        let words = [Some("hello"), None, Some("hello"), Some("good"), Some("good")];
        let strings = VarBinArray::from_iter(words, DType::Utf8(Nullability::Nullable)).into_array();
        let nullable_strings = dict_encode(&strings).unwrap();
        test_filter_conformance(nullable_strings.as_ref());
    }

    // Taking with a single null index from a non-nullable dictionary must produce a nullable
    // i32 dtype.
    #[test]
    fn test_take_dict() {
        let encoded = dict_encode(PrimitiveArray::from_iter([1, 2]).as_ref()).unwrap();
        let null_index = PrimitiveArray::from_option_iter([Option::<i32>::None]);

        let taken = take(encoded.as_ref(), null_index.as_ref()).unwrap();

        assert_eq!(
            taken.dtype(),
            &DType::Primitive(I32, Nullability::Nullable)
        );
    }

    // Take conformance over non-nullable ints, nullable ints and nullable strings.
    #[test]
    fn test_take_dict_conformance() {
        let plain_ints = dict_encode(PrimitiveArray::from_iter([2, 0, 2, 0, 10]).as_ref()).unwrap();
        test_take_conformance(plain_ints.as_ref());

        let int_values = [Some(2), None, Some(2), Some(0), Some(10)];
        let nullable_ints =
            dict_encode(PrimitiveArray::from_option_iter(int_values).as_ref()).unwrap();
        test_take_conformance(nullable_ints.as_ref());

        let words = [Some("hello"), None, Some("hello"), Some("good"), Some("good")];
        let strings = VarBinArray::from_iter(words, DType::Utf8(Nullability::Nullable)).into_array();
        let nullable_strings = dict_encode(&strings).unwrap();
        test_take_conformance(nullable_strings.as_ref());
    }
}
271
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{PrimitiveArray, VarBinArray};
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_dtype::{DType, Nullability};

    use crate::DictArray;
    use crate::builders::dict_encode;

    // Feeds a set of dictionary-encoded inputs (primitive, string, and a few edge cases) to the
    // shared `test_array_consistency` check, one rstest case per input.
    #[rstest]
    // Primitive arrays
    #[case::dict_i32(dict_encode(PrimitiveArray::from_iter([1i32, 2, 3, 2, 1]).as_ref()).unwrap())]
    #[case::dict_nullable_i32(dict_encode(
        &PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array()
    ).unwrap())]
    #[case::dict_u64(dict_encode(PrimitiveArray::from_iter([100u64, 200, 100, 300, 200]).as_ref()).unwrap())]
    // String arrays
    #[case::dict_str(dict_encode(
        &VarBinArray::from_iter(
            ["hello", "world", "hello", "test", "world"].map(Some),
            DType::Utf8(Nullability::NonNullable),
        ).into_array()
    ).unwrap())]
    #[case::dict_nullable_str(dict_encode(
        &VarBinArray::from_iter(
            [Some("hello"), None, Some("world"), Some("hello"), None],
            DType::Utf8(Nullability::Nullable),
        ).into_array()
    ).unwrap())]
    // Edge cases
    #[case::dict_single(dict_encode(PrimitiveArray::from_iter([42i32]).as_ref()).unwrap())]
    #[case::dict_all_same(dict_encode(PrimitiveArray::from_iter([5i32, 5, 5, 5, 5]).as_ref()).unwrap())]
    #[case::dict_large(dict_encode(PrimitiveArray::from_iter((0..1000).map(|i| i % 10)).as_ref()).unwrap())]
    fn test_dict_consistency(#[case] array: DictArray) {
        test_array_consistency(array.as_ref());
    }
}