vortex_array/arrays/dict/compute/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod cast;
5mod compare;
6mod fill_null;
7mod is_constant;
8mod is_sorted;
9mod like;
10mod min_max;
11
12use vortex_error::VortexResult;
13use vortex_mask::Mask;
14
15use super::{DictArray, DictVTable};
16use crate::compute::{
17    FilterKernel, FilterKernelAdapter, TakeKernel, TakeKernelAdapter, filter, take,
18};
19use crate::{Array, ArrayRef, IntoArray, register_kernel};
20
21impl TakeKernel for DictVTable {
22    fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
23        let codes = take(array.codes(), indices)?;
24        // SAFETY: selecting codes doesn't change the invariants of DictArray
25        Ok(unsafe { DictArray::new_unchecked(codes, array.values().clone()) }.into_array())
26    }
27}
28
29register_kernel!(TakeKernelAdapter(DictVTable).lift());
30
31impl FilterKernel for DictVTable {
32    fn filter(&self, array: &DictArray, mask: &Mask) -> VortexResult<ArrayRef> {
33        let codes = filter(array.codes(), mask)?;
34
35        // SAFETY: filtering codes doesn't change invariants
36        unsafe { Ok(DictArray::new_unchecked(codes, array.values().clone()).into_array()) }
37    }
38}
39
40register_kernel!(FilterKernelAdapter(DictVTable).lift());
41
#[cfg(test)]
mod test {
    // Unit tests for dictionary canonicalisation, comparison on sliced dictionaries, and the
    // mask/filter/take conformance suites.
    use vortex_buffer::buffer;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability};

    use crate::accessor::ArrayAccessor;
    use crate::arrays::dict::DictArray;
    use crate::arrays::{BoolArray, ConstantArray, PrimitiveArray, VarBinArray, VarBinViewArray};
    use crate::builders::dict::dict_encode;
    use crate::compute::conformance::filter::test_filter_conformance;
    use crate::compute::conformance::mask::test_mask_conformance;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::{Operator, compare, take};
    use crate::{Array, ArrayRef, IntoArray, ToCanonical, assert_arrays_eq};

    /// Dictionary-encodes 1000 non-nullable `i32`s cycling through `0..distinct`, then
    /// asserts that canonicalising the dictionary reproduces the original array.
    fn assert_cyclic_primitive_roundtrip(distinct: i32) {
        let expected = PrimitiveArray::from_iter((0..1000).map(|i| i % distinct));

        let dict = dict_encode(expected.as_ref()).unwrap();
        let actual = dict.to_primitive();

        assert_arrays_eq!(actual, expected);
    }

    /// Dictionary arrays shared by the mask, filter and take conformance tests:
    /// non-nullable primitive, nullable primitive, and nullable UTF-8.
    fn conformance_arrays() -> Vec<DictArray> {
        vec![
            dict_encode(&buffer![2, 0, 2, 0, 10].into_array()).unwrap(),
            dict_encode(
                PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)])
                    .as_ref(),
            )
            .unwrap(),
            dict_encode(
                &VarBinArray::from_iter(
                    [
                        Some("hello"),
                        None,
                        Some("hello"),
                        Some("good"),
                        Some("good"),
                    ],
                    DType::Utf8(Nullability::Nullable),
                )
                .into_array(),
            )
            .unwrap(),
        ]
    }

    #[test]
    fn canonicalise_nullable_primitive() {
        let values: Vec<Option<i32>> = (0..65)
            .map(|i| match i % 3 {
                0 => Some(42),
                1 => Some(-9),
                2 => None,
                _ => unreachable!(),
            })
            .collect();

        let dict = dict_encode(PrimitiveArray::from_option_iter(values.clone()).as_ref()).unwrap();
        let actual = dict.to_primitive();

        let expected = PrimitiveArray::from_option_iter(values);

        assert_arrays_eq!(actual, expected);
    }

    #[test]
    fn canonicalise_non_nullable_primitive_32_unique_values() {
        assert_cyclic_primitive_roundtrip(32);
    }

    #[test]
    fn canonicalise_non_nullable_primitive_100_unique_values() {
        assert_cyclic_primitive_roundtrip(100);
    }

    #[test]
    fn canonicalise_nullable_varbin() {
        let reference = VarBinViewArray::from_iter(
            vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")],
            DType::Utf8(Nullability::Nullable),
        );
        assert_eq!(reference.len(), 6);
        let dict = dict_encode(reference.as_ref()).unwrap();
        let flattened_dict = dict.to_varbinview();
        assert_eq!(
            flattened_dict.with_iterator(|iter| iter
                .map(|slice| slice.map(|s| s.to_vec()))
                .collect::<Vec<_>>()),
            reference.with_iterator(|iter| iter
                .map(|slice| slice.map(|s| s.to_vec()))
                .collect::<Vec<_>>()),
        );
    }

    /// Dictionary-encoded `[Some(42), Some(-9), None, Some(42), Some(1), Some(5)]`, sliced
    /// to rows `1..4`, i.e. `[Some(-9), None, Some(42)]`.
    fn sliced_dict_array() -> ArrayRef {
        let reference = PrimitiveArray::from_option_iter([
            Some(42),
            Some(-9),
            None,
            Some(42),
            Some(1),
            Some(5),
        ]);
        let dict = dict_encode(reference.as_ref()).unwrap();
        dict.slice(1..4)
    }

    #[test]
    fn compare_sliced_dict() {
        let sliced = sliced_dict_array();
        let compared = compare(&sliced, ConstantArray::new(42, 3).as_ref(), Operator::Eq).unwrap();

        let expected = BoolArray::from_iter([Some(false), None, Some(true)]);
        assert_arrays_eq!(compared, expected.to_array());
    }

    #[test]
    fn test_mask_dict_array() {
        for array in conformance_arrays() {
            test_mask_conformance(array.as_ref());
        }
    }

    #[test]
    fn test_filter_dict_array() {
        for array in conformance_arrays() {
            test_filter_conformance(array.as_ref());
        }
    }

    #[test]
    fn test_take_dict() {
        let array = dict_encode(buffer![1, 2].into_array().as_ref()).unwrap();

        // Taking with a null index must make the result nullable.
        assert_eq!(
            take(
                array.as_ref(),
                PrimitiveArray::from_option_iter([Option::<i32>::None]).as_ref()
            )
            .unwrap()
            .dtype(),
            &DType::Primitive(I32, Nullability::Nullable)
        );
    }

    #[test]
    fn test_take_dict_conformance() {
        for array in conformance_arrays() {
            test_take_conformance(array.as_ref());
        }
    }
}
241
#[cfg(test)]
mod tests {
    // Runs the generic array consistency suite over a range of dictionary-encoded arrays.
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability};

    use crate::arrays::dict::DictArray;
    use crate::arrays::{PrimitiveArray, VarBinArray};
    use crate::builders::dict::dict_encode;
    use crate::compute::conformance::consistency::test_array_consistency;
    use crate::{ArrayRef, IntoArray};

    /// Dictionary-encodes `array`, panicking if encoding fails.
    fn encoded(array: ArrayRef) -> DictArray {
        dict_encode(&array).unwrap()
    }

    #[rstest]
    // Primitive arrays
    #[case::dict_i32(encoded(buffer![1i32, 2, 3, 2, 1].into_array()))]
    #[case::dict_nullable_codes(DictArray::try_new(
        buffer![0u32, 1, 2, 2, 0].into_array(),
        PrimitiveArray::from_option_iter([Some(10), Some(20), None]).into_array(),
    ).unwrap())]
    #[case::dict_nullable_values(encoded(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array()
    ))]
    #[case::dict_u64(encoded(buffer![100u64, 200, 100, 300, 200].into_array()))]
    // String arrays
    #[case::dict_str(encoded(
        VarBinArray::from_iter(
            ["hello", "world", "hello", "test", "world"].map(Some),
            DType::Utf8(Nullability::NonNullable),
        ).into_array()
    ))]
    #[case::dict_nullable_str(encoded(
        VarBinArray::from_iter(
            [Some("hello"), None, Some("world"), Some("hello"), None],
            DType::Utf8(Nullability::Nullable),
        ).into_array()
    ))]
    // Edge cases
    #[case::dict_single(encoded(buffer![42i32].into_array()))]
    #[case::dict_all_same(encoded(buffer![5i32, 5, 5, 5, 5].into_array()))]
    #[case::dict_large(encoded(PrimitiveArray::from_iter((0..1000).map(|i| i % 10)).into_array()))]
    fn test_dict_consistency(#[case] dict: DictArray) {
        test_array_consistency(dict.as_ref());
    }
}