vortex_dict/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod cast;
5mod compare;
6mod fill_null;
7mod is_constant;
8mod is_sorted;
9mod like;
10mod min_max;
11
12use vortex_array::compute::{
13    FilterKernel, FilterKernelAdapter, TakeKernel, TakeKernelAdapter, filter, take,
14};
15use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
16use vortex_error::VortexResult;
17use vortex_mask::Mask;
18
19use crate::{DictArray, DictVTable};
20
21impl TakeKernel for DictVTable {
22    fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
23        let codes = take(array.codes(), indices)?;
24        // SAFETY: selecting codes doesn't change the invariants of DictArray
25        Ok(unsafe { DictArray::new_unchecked(codes, array.values().clone()) }.into_array())
26    }
27}
28
29register_kernel!(TakeKernelAdapter(DictVTable).lift());
30
31impl FilterKernel for DictVTable {
32    fn filter(&self, array: &DictArray, mask: &Mask) -> VortexResult<ArrayRef> {
33        let codes = filter(array.codes(), mask)?;
34
35        // SAFETY: filtering codes doesn't change invariants
36        unsafe { Ok(DictArray::new_unchecked(codes, array.values().clone()).into_array()) }
37    }
38}
39
40register_kernel!(FilterKernelAdapter(DictVTable).lift());
41
#[cfg(test)]
mod test {
    use vortex_array::accessor::ArrayAccessor;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray, VarBinArray, VarBinViewArray};
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::compute::{Operator, compare, take};
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::DictArray;
    use crate::builders::dict_encode;

    #[test]
    fn canonicalise_nullable_primitive() {
        let values: Vec<Option<i32>> = (0..65)
            .map(|i| match i % 3 {
                0 => Some(42),
                1 => Some(-9),
                2 => None,
                _ => unreachable!(),
            })
            .collect();

        let dict = dict_encode(PrimitiveArray::from_option_iter(values.clone()).as_ref()).unwrap();
        let actual = dict.to_primitive();

        // Null positions are expected to hold 0 in the canonical value buffer; whether they
        // are actually null is checked through the validity mask below.
        let expected: Vec<i32> = values.iter().copied().map(|v| v.unwrap_or(0)).collect();

        assert_eq!(actual.as_slice::<i32>(), expected.as_slice());

        let expected_valid_count = values.iter().filter(|x| x.is_some()).count();
        assert_eq!(actual.validity_mask().true_count(), expected_valid_count);
    }

    /// Dictionary-encodes 1000 values cycling through `0..unique` and asserts that
    /// canonicalising the dictionary reproduces them exactly.
    fn assert_cyclic_primitive_roundtrip(unique: i32) {
        let expected: Vec<i32> = (0..1000).map(|i| i % unique).collect();

        let dict =
            dict_encode(PrimitiveArray::from_iter(expected.iter().copied()).as_ref()).unwrap();
        let actual = dict.to_primitive();

        assert_eq!(actual.as_slice::<i32>(), expected.as_slice());
    }

    #[test]
    fn canonicalise_non_nullable_primitive_32_unique_values() {
        assert_cyclic_primitive_roundtrip(32);
    }

    #[test]
    fn canonicalise_non_nullable_primitive_100_unique_values() {
        assert_cyclic_primitive_roundtrip(100);
    }

    #[test]
    fn canonicalise_nullable_varbin() {
        let reference = VarBinViewArray::from_iter(
            vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")],
            DType::Utf8(Nullability::Nullable),
        );
        assert_eq!(reference.len(), 6);
        let dict = dict_encode(reference.as_ref()).unwrap();
        let flattened_dict = dict.to_varbinview();
        assert_eq!(
            flattened_dict
                .with_iterator(|iter| iter
                    .map(|slice| slice.map(|s| s.to_vec()))
                    .collect::<Vec<_>>())
                .unwrap(),
            reference
                .with_iterator(|iter| iter
                    .map(|slice| slice.map(|s| s.to_vec()))
                    .collect::<Vec<_>>())
                .unwrap(),
        );
    }

    /// Returns rows 1..4 of a dictionary-encoded `[42, -9, null, 42, 1, 5]`,
    /// i.e. a sliced dictionary over `[-9, null, 42]`.
    fn sliced_dict_array() -> ArrayRef {
        let reference = PrimitiveArray::from_option_iter([
            Some(42),
            Some(-9),
            None,
            Some(42),
            Some(1),
            Some(5),
        ]);
        let dict = dict_encode(reference.as_ref()).unwrap();
        dict.slice(1..4)
    }

    #[test]
    fn compare_sliced_dict() {
        let sliced = sliced_dict_array();
        let compared = compare(&sliced, ConstantArray::new(42, 3).as_ref(), Operator::Eq).unwrap();

        assert_eq!(
            compared.scalar_at(0),
            Scalar::bool(false, Nullability::Nullable)
        );
        assert_eq!(
            compared.scalar_at(1),
            Scalar::null(DType::Bool(Nullability::Nullable))
        );
        assert_eq!(
            compared.scalar_at(2),
            Scalar::bool(true, Nullability::Nullable)
        );
    }

    /// Dictionary-encoded fixtures shared by the mask, filter and take conformance tests.
    ///
    /// The fixtures cover three cases: non-nullable primitives, nullable primitives, and
    /// nullable strings with repeated values.
    fn conformance_arrays() -> Vec<DictArray> {
        vec![
            dict_encode(&buffer![2, 0, 2, 0, 10].into_array()).unwrap(),
            dict_encode(
                PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)])
                    .as_ref(),
            )
            .unwrap(),
            dict_encode(
                &VarBinArray::from_iter(
                    [
                        Some("hello"),
                        None,
                        Some("hello"),
                        Some("good"),
                        Some("good"),
                    ],
                    DType::Utf8(Nullability::Nullable),
                )
                .into_array(),
            )
            .unwrap(),
        ]
    }

    #[test]
    fn test_mask_dict_array() {
        for array in conformance_arrays() {
            test_mask_conformance(array.as_ref());
        }
    }

    #[test]
    fn test_filter_dict_array() {
        for array in conformance_arrays() {
            test_filter_conformance(array.as_ref());
        }
    }

    #[test]
    fn test_take_dict() {
        let array = dict_encode(buffer![1, 2].into_array().as_ref()).unwrap();

        // Taking with a null index must yield a nullable result dtype.
        assert_eq!(
            take(
                array.as_ref(),
                PrimitiveArray::from_option_iter([Option::<i32>::None]).as_ref()
            )
            .unwrap()
            .dtype(),
            &DType::Primitive(I32, Nullability::Nullable)
        );
    }

    #[test]
    fn test_take_dict_conformance() {
        for array in conformance_arrays() {
            test_take_conformance(array.as_ref());
        }
    }
}
266
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{PrimitiveArray, VarBinArray};
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability};

    use crate::DictArray;
    use crate::builders::dict_encode;

    // Runs the generic array-consistency suite over a spread of dictionary encodings.
    // Case order is significant: rstest embeds the case index in each generated test name.
    #[rstest]
    // --- Primitive values ---
    #[case::dict_i32(dict_encode(&buffer![1i32, 2, 3, 2, 1].into_array()).unwrap())]
    // Hand-built dictionary: non-nullable u32 codes, where code 2 points at a null value.
    #[case::dict_nullable_codes(DictArray::try_new(
        buffer![0u32, 1, 2, 2, 0].into_array(),
        PrimitiveArray::from_option_iter([Some(10), Some(20), None]).into_array(),
    ).unwrap())]
    #[case::dict_nullable_values(dict_encode(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).as_ref()
    ).unwrap())]
    #[case::dict_u64(dict_encode(&buffer![100u64, 200, 100, 300, 200].into_array()).unwrap())]
    // --- String values ---
    #[case::dict_str(dict_encode(
        &VarBinArray::from_iter(
            [Some("hello"), Some("world"), Some("hello"), Some("test"), Some("world")],
            DType::Utf8(Nullability::NonNullable),
        ).into_array()
    ).unwrap())]
    #[case::dict_nullable_str(dict_encode(
        &VarBinArray::from_iter(
            [Some("hello"), None, Some("world"), Some("hello"), None],
            DType::Utf8(Nullability::Nullable),
        ).into_array()
    ).unwrap())]
    // --- Edge cases: single row, one distinct value, many rows over few values ---
    #[case::dict_single(dict_encode(&buffer![42i32].into_array()).unwrap())]
    #[case::dict_all_same(dict_encode(&buffer![5i32, 5, 5, 5, 5].into_array()).unwrap())]
    #[case::dict_large(dict_encode(&PrimitiveArray::from_iter((0..1000).map(|i| i % 10)).into_array()).unwrap())]
    fn test_dict_consistency(#[case] dict: DictArray) {
        test_array_consistency(dict.as_ref());
    }
}