vortex_dict/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod binary_numeric;
5mod cast;
6mod compare;
7mod fill_null;
8mod is_constant;
9mod is_sorted;
10mod like;
11mod min_max;
12
13use vortex_array::compute::{
14    FilterKernel, FilterKernelAdapter, TakeKernel, TakeKernelAdapter, cast, filter, take,
15};
16use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
17use vortex_error::VortexResult;
18use vortex_mask::Mask;
19
20use crate::{DictArray, DictVTable};
21
22impl TakeKernel for DictVTable {
23    fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
24        // TODO(joe): can we remove the cast and allow dict arrays to have nullable codes and values
25        let codes = take(array.codes(), indices)?;
26        let values_dtype = array
27            .values()
28            .dtype()
29            .union_nullability(codes.dtype().nullability());
30        // SAFETY: selecting codes doesn't change the invariants of DictArray
31        unsafe {
32            Ok(DictArray::new_unchecked(codes, cast(array.values(), &values_dtype)?).into_array())
33        }
34    }
35}
36
37register_kernel!(TakeKernelAdapter(DictVTable).lift());
38
39impl FilterKernel for DictVTable {
40    fn filter(&self, array: &DictArray, mask: &Mask) -> VortexResult<ArrayRef> {
41        let codes = filter(array.codes(), mask)?;
42
43        // SAFETY: filtering codes doesn't change invariants
44        unsafe { Ok(DictArray::new_unchecked(codes, array.values().clone()).into_array()) }
45    }
46}
47
48register_kernel!(FilterKernelAdapter(DictVTable).lift());
49
#[cfg(test)]
mod test {
    use vortex_array::accessor::ArrayAccessor;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray, VarBinArray, VarBinViewArray};
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::compute::{Operator, compare, take};
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::DictArray;
    use crate::builders::dict_encode;

    /// Dictionary-encoded fixtures shared by the mask, filter and take conformance tests:
    /// a non-nullable primitive array, a nullable primitive array and a nullable string array.
    fn conformance_dict_arrays() -> Vec<DictArray> {
        vec![
            dict_encode(&PrimitiveArray::from_iter([2, 0, 2, 0, 10]).into_array()).unwrap(),
            dict_encode(
                PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)])
                    .as_ref(),
            )
            .unwrap(),
            dict_encode(
                &VarBinArray::from_iter(
                    [
                        Some("hello"),
                        None,
                        Some("hello"),
                        Some("good"),
                        Some("good"),
                    ],
                    DType::Utf8(Nullability::Nullable),
                )
                .into_array(),
            )
            .unwrap(),
        ]
    }

    /// Dictionary-encodes 1000 values cycling through `0..unique` and asserts that
    /// canonicalising the dictionary reproduces the input exactly.
    fn assert_canonicalise_cycling_values(unique: i32) {
        let unique_values: Vec<i32> = (0..unique).collect();
        let expected: Vec<i32> = (0..1000)
            .map(|i| unique_values[i % unique_values.len()])
            .collect();

        let dict =
            dict_encode(PrimitiveArray::from_iter(expected.iter().copied()).as_ref()).unwrap();
        let actual = dict.to_primitive();

        assert_eq!(actual.as_slice::<i32>(), expected.as_slice());
    }

    #[test]
    fn canonicalise_nullable_primitive() {
        let values: Vec<Option<i32>> = (0..65)
            .map(|i| match i % 3 {
                0 => Some(42),
                1 => Some(-9),
                2 => None,
                _ => unreachable!(),
            })
            .collect();

        let dict = dict_encode(PrimitiveArray::from_option_iter(values.clone()).as_ref()).unwrap();
        let actual = dict.to_primitive();

        let expected: Vec<i32> = (0..65)
            .map(|i| match i % 3 {
                // The encoder assigns code 0 to invalid (null) positions, and that code is used in
                // the take during canonicalisation. Null slots therefore decode to whatever value
                // sits at position 0 of the dictionary (here 42).
                0 | 2 => 42,
                1 => -9,
                _ => unreachable!(),
            })
            .collect();

        assert_eq!(actual.as_slice::<i32>(), expected.as_slice());

        let expected_valid_count = values.iter().filter(|x| x.is_some()).count();
        assert_eq!(actual.validity_mask().true_count(), expected_valid_count);
    }

    #[test]
    fn canonicalise_non_nullable_primitive_32_unique_values() {
        assert_canonicalise_cycling_values(32);
    }

    #[test]
    fn canonicalise_non_nullable_primitive_100_unique_values() {
        assert_canonicalise_cycling_values(100);
    }

    #[test]
    fn canonicalise_nullable_varbin() {
        let reference = VarBinViewArray::from_iter(
            vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")],
            DType::Utf8(Nullability::Nullable),
        );
        assert_eq!(reference.len(), 6);
        let dict = dict_encode(reference.as_ref()).unwrap();
        let flattened_dict = dict.to_varbinview();
        assert_eq!(
            flattened_dict
                .with_iterator(|iter| iter
                    .map(|slice| slice.map(|s| s.to_vec()))
                    .collect::<Vec<_>>())
                .unwrap(),
            reference
                .with_iterator(|iter| iter
                    .map(|slice| slice.map(|s| s.to_vec()))
                    .collect::<Vec<_>>())
                .unwrap(),
        );
    }

    /// Returns rows `1..4` (`[-9, null, 42]`) of a dictionary-encoded nullable i32 array.
    fn sliced_dict_array() -> ArrayRef {
        let reference = PrimitiveArray::from_option_iter([
            Some(42),
            Some(-9),
            None,
            Some(42),
            Some(1),
            Some(5),
        ]);
        let dict = dict_encode(reference.as_ref()).unwrap();
        dict.slice(1..4)
    }

    #[test]
    fn compare_sliced_dict() {
        let sliced = sliced_dict_array();
        let compared = compare(&sliced, ConstantArray::new(42, 3).as_ref(), Operator::Eq).unwrap();

        assert_eq!(
            compared.scalar_at(0),
            Scalar::bool(false, Nullability::Nullable)
        );
        assert_eq!(
            compared.scalar_at(1),
            Scalar::null(DType::Bool(Nullability::Nullable))
        );
        assert_eq!(
            compared.scalar_at(2),
            Scalar::bool(true, Nullability::Nullable)
        );
    }

    #[test]
    fn test_mask_dict_array() {
        for array in conformance_dict_arrays() {
            test_mask_conformance(array.as_ref());
        }
    }

    #[test]
    fn test_filter_dict_array() {
        for array in conformance_dict_arrays() {
            test_filter_conformance(array.as_ref());
        }
    }

    /// Taking with a null index from a non-nullable dictionary must yield a nullable dtype.
    #[test]
    fn test_take_dict() {
        let array = dict_encode(PrimitiveArray::from_iter([1, 2]).as_ref()).unwrap();

        assert_eq!(
            take(
                array.as_ref(),
                PrimitiveArray::from_option_iter([Option::<i32>::None]).as_ref()
            )
            .unwrap()
            .dtype(),
            &DType::Primitive(I32, Nullability::Nullable)
        );
    }

    #[test]
    fn test_take_dict_conformance() {
        for array in conformance_dict_arrays() {
            test_take_conformance(array.as_ref());
        }
    }
}
273
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::arrays::{PrimitiveArray, VarBinArray};
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_dtype::{DType, Nullability};

    use crate::DictArray;
    use crate::builders::dict_encode;

    /// Runs the shared array-consistency conformance suite over a range of dictionary-encoded
    /// inputs: primitive and string arrays, nullable and non-nullable, and small edge cases.
    #[rstest]
    // Primitive arrays
    #[case::dict_i32(
        dict_encode(&PrimitiveArray::from_iter([1i32, 2, 3, 2, 1]).into_array()).unwrap()
    )]
    #[case::dict_nullable_i32(
        dict_encode(
            PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).as_ref()
        )
        .unwrap()
    )]
    #[case::dict_u64(
        dict_encode(&PrimitiveArray::from_iter([100u64, 200, 100, 300, 200]).into_array()).unwrap()
    )]
    // String arrays
    #[case::dict_str(
        dict_encode(
            &VarBinArray::from_iter(
                ["hello", "world", "hello", "test", "world"].map(Some),
                DType::Utf8(Nullability::NonNullable),
            )
            .into_array()
        )
        .unwrap()
    )]
    #[case::dict_nullable_str(
        dict_encode(
            &VarBinArray::from_iter(
                [Some("hello"), None, Some("world"), Some("hello"), None],
                DType::Utf8(Nullability::Nullable),
            )
            .into_array()
        )
        .unwrap()
    )]
    // Edge cases
    #[case::dict_single(dict_encode(&PrimitiveArray::from_iter([42i32]).into_array()).unwrap())]
    #[case::dict_all_same(
        dict_encode(&PrimitiveArray::from_iter([5i32, 5, 5, 5, 5]).into_array()).unwrap()
    )]
    #[case::dict_large(
        dict_encode(&PrimitiveArray::from_iter((0..1000).map(|i| i % 10)).into_array()).unwrap()
    )]
    fn test_dict_consistency(#[case] dict: DictArray) {
        test_array_consistency(dict.as_ref());
    }
}