vortex_array/arrays/dict/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod binary_numeric;
5mod cast;
6mod compare;
7mod fill_null;
8mod is_constant;
9mod is_sorted;
10mod like;
11mod min_max;
12
13use vortex_error::VortexResult;
14use vortex_mask::Mask;
15
16use super::DictArray;
17use super::DictVTable;
18use crate::Array;
19use crate::ArrayRef;
20use crate::IntoArray;
21use crate::compute::FilterKernel;
22use crate::compute::FilterKernelAdapter;
23use crate::compute::TakeKernel;
24use crate::compute::TakeKernelAdapter;
25use crate::compute::filter;
26use crate::compute::take;
27use crate::register_kernel;
28
29impl TakeKernel for DictVTable {
30    fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
31        let codes = take(array.codes(), indices)?;
32        // SAFETY: selecting codes doesn't change the invariants of DictArray
33        // Preserve all_values_referenced since taking codes doesn't affect which values are referenced
34        Ok(unsafe { DictArray::new_unchecked(codes, array.values().clone()).into_array() })
35    }
36}
37
38register_kernel!(TakeKernelAdapter(DictVTable).lift());
39
40impl FilterKernel for DictVTable {
41    fn filter(&self, array: &DictArray, mask: &Mask) -> VortexResult<ArrayRef> {
42        let codes = filter(array.codes(), mask)?;
43
44        // SAFETY: filtering codes doesn't change invariants
45        // Preserve all_values_referenced since filtering codes doesn't affect which values are referenced
46        unsafe { Ok(DictArray::new_unchecked(codes, array.values().clone()).into_array()) }
47    }
48}
49
50register_kernel!(FilterKernelAdapter(DictVTable).lift());
51
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;

    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::accessor::ArrayAccessor;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::arrays::dict::DictArray;
    use crate::assert_arrays_eq;
    use crate::builders::dict::dict_encode;
    use crate::compute::Operator;
    use crate::compute::compare;
    use crate::compute::conformance::filter::test_filter_conformance;
    use crate::compute::conformance::mask::test_mask_conformance;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;

    /// Dictionary-encoded fixtures shared by the mask, filter and take conformance tests.
    ///
    /// The fixtures are non-nullable primitives, nullable primitives and nullable UTF-8
    /// strings.
    fn conformance_dict_arrays() -> Vec<DictArray> {
        vec![
            dict_encode(&buffer![2, 0, 2, 0, 10].into_array()).unwrap(),
            dict_encode(
                PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)])
                    .as_ref(),
            )
            .unwrap(),
            dict_encode(
                &VarBinArray::from_iter(
                    [
                        Some("hello"),
                        None,
                        Some("hello"),
                        Some("good"),
                        Some("good"),
                    ],
                    DType::Utf8(Nullability::Nullable),
                )
                .into_array(),
            )
            .unwrap(),
        ]
    }

    #[test]
    fn canonicalise_nullable_primitive() {
        let values: Vec<Option<i32>> = (0..65)
            .map(|i| match i % 3 {
                0 => Some(42),
                1 => Some(-9),
                2 => None,
                _ => unreachable!(),
            })
            .collect();

        let dict = dict_encode(PrimitiveArray::from_option_iter(values.clone()).as_ref()).unwrap();
        let actual = dict.to_primitive();

        let expected = PrimitiveArray::from_option_iter(values);

        assert_arrays_eq!(actual, expected);
    }

    #[test]
    fn canonicalise_non_nullable_primitive_32_unique_values() {
        let unique_values: Vec<i32> = (0..32).collect();
        let expected = PrimitiveArray::from_iter((0..1000).map(|i| unique_values[i % 32]));

        let dict = dict_encode(expected.as_ref()).unwrap();
        let actual = dict.to_primitive();

        assert_arrays_eq!(actual, expected);
    }

    #[test]
    fn canonicalise_non_nullable_primitive_100_unique_values() {
        let unique_values: Vec<i32> = (0..100).collect();
        let expected = PrimitiveArray::from_iter((0..1000).map(|i| unique_values[i % 100]));

        let dict = dict_encode(expected.as_ref()).unwrap();
        let actual = dict.to_primitive();

        assert_arrays_eq!(actual, expected);
    }

    #[test]
    fn canonicalise_nullable_varbin() {
        let reference = VarBinViewArray::from_iter(
            vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")],
            DType::Utf8(Nullability::Nullable),
        );
        assert_eq!(reference.len(), 6);
        let dict = dict_encode(reference.as_ref()).unwrap();
        let flattened_dict = dict.to_varbinview();
        assert_eq!(
            flattened_dict.with_iterator(|iter| iter
                .map(|slice| slice.map(|s| s.to_vec()))
                .collect::<Vec<_>>()),
            reference.with_iterator(|iter| iter
                .map(|slice| slice.map(|s| s.to_vec()))
                .collect::<Vec<_>>()),
        );
    }

    /// Dictionary-encodes `[42, -9, null, 42, 1, 5]` and slices it to rows 1..4,
    /// i.e. `[-9, null, 42]`.
    fn sliced_dict_array() -> ArrayRef {
        let reference = PrimitiveArray::from_option_iter([
            Some(42),
            Some(-9),
            None,
            Some(42),
            Some(1),
            Some(5),
        ]);
        let dict = dict_encode(reference.as_ref()).unwrap();
        dict.slice(1..4)
    }

    #[test]
    fn compare_sliced_dict() {
        use crate::arrays::BoolArray;
        let sliced = sliced_dict_array();
        let compared = compare(&sliced, ConstantArray::new(42, 3).as_ref(), Operator::Eq).unwrap();

        let expected = BoolArray::from_iter([Some(false), None, Some(true)]);
        assert_arrays_eq!(compared, expected.to_array());
    }

    #[test]
    fn test_mask_dict_array() {
        for array in conformance_dict_arrays() {
            test_mask_conformance(array.as_ref());
        }
    }

    #[test]
    fn test_filter_dict_array() {
        for array in conformance_dict_arrays() {
            test_filter_conformance(array.as_ref());
        }
    }

    #[test]
    fn test_take_dict() {
        let array = dict_encode(buffer![1, 2].into_array().as_ref()).unwrap();

        // Taking with a nullable index array must yield a nullable result dtype.
        assert_eq!(
            take(
                array.as_ref(),
                PrimitiveArray::from_option_iter([Option::<i32>::None]).as_ref()
            )
            .unwrap()
            .dtype(),
            &DType::Primitive(I32, Nullability::Nullable)
        );
    }

    #[test]
    fn test_take_dict_conformance() {
        for array in conformance_dict_arrays() {
            test_take_conformance(array.as_ref());
        }
    }
}
261
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::dict::DictArray;
    use crate::builders::dict::dict_encode;
    use crate::compute::conformance::consistency::test_array_consistency;

    // Runs the shared array-consistency suite against a spread of dictionary-encoded inputs:
    // primitive, string, nullable and degenerate (single element / single distinct value).
    #[rstest]
    // ---- Primitive values ----
    #[case::dict_i32(dict_encode(&buffer![1i32, 2, 3, 2, 1].into_array()).unwrap())]
    // NOTE: the codes in this case are a plain (non-nullable) u32 buffer; the null comes from
    // the `None` entry in `values`, which code 2 points at.
    #[case::dict_nullable_codes(DictArray::try_new(
        buffer![0u32, 1, 2, 2, 0].into_array(),
        PrimitiveArray::from_option_iter([Some(10), Some(20), None]).into_array(),
    ).unwrap())]
    #[case::dict_nullable_values(dict_encode(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).as_ref()
    ).unwrap())]
    #[case::dict_u64(dict_encode(&buffer![100u64, 200, 100, 300, 200].into_array()).unwrap())]
    // ---- String values ----
    #[case::dict_str(dict_encode(
        &VarBinArray::from_iter(
            [Some("hello"), Some("world"), Some("hello"), Some("test"), Some("world")],
            DType::Utf8(Nullability::NonNullable),
        ).into_array()
    ).unwrap())]
    #[case::dict_nullable_str(dict_encode(
        &VarBinArray::from_iter(
            [Some("hello"), None, Some("world"), Some("hello"), None],
            DType::Utf8(Nullability::Nullable),
        ).into_array()
    ).unwrap())]
    // ---- Degenerate shapes ----
    #[case::dict_single(dict_encode(&buffer![42i32].into_array()).unwrap())]
    #[case::dict_all_same(dict_encode(&buffer![5i32, 5, 5, 5, 5].into_array()).unwrap())]
    #[case::dict_large(dict_encode(&PrimitiveArray::from_iter((0..1000i32).map(|n| n % 10)).into_array()).unwrap())]
    fn test_dict_consistency(#[case] dict: DictArray) {
        test_array_consistency(dict.as_ref());
    }
}