Skip to main content

vortex_array/arrays/dict/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod cast;
5mod compare;
6mod fill_null;
7pub(crate) mod is_constant;
8pub(crate) mod is_sorted;
9mod like;
10mod mask;
11pub(crate) mod min_max;
12pub(crate) mod rules;
13mod slice;
14
15use vortex_error::VortexResult;
16use vortex_mask::Mask;
17
18use super::Dict;
19use super::DictArray;
20use super::TakeExecute;
21use crate::ArrayRef;
22use crate::ExecutionCtx;
23use crate::IntoArray;
24use crate::array::ArrayView;
25use crate::arrays::dict::DictArraySlotsExt;
26use crate::arrays::filter::FilterReduce;
27
28impl TakeExecute for Dict {
29    fn take(
30        array: ArrayView<'_, Dict>,
31        indices: &ArrayRef,
32        _ctx: &mut ExecutionCtx,
33    ) -> VortexResult<Option<ArrayRef>> {
34        let codes = array.codes().take(indices.clone())?;
35        // SAFETY: selecting codes doesn't change the invariants of DictArray
36        // Preserve all_values_referenced since taking codes doesn't affect which values are referenced
37        Ok(Some(unsafe {
38            DictArray::new_unchecked(codes, array.values().clone()).into_array()
39        }))
40    }
41}
42
43impl FilterReduce for Dict {
44    fn filter(array: ArrayView<'_, Dict>, mask: &Mask) -> VortexResult<Option<ArrayRef>> {
45        let codes = array.codes().filter(mask.clone())?;
46
47        // SAFETY: filtering codes doesn't change invariants
48        // Preserve all_values_referenced since filtering codes doesn't affect which values are referenced
49        Ok(Some(unsafe {
50            DictArray::new_unchecked(codes, array.values().clone()).into_array()
51        }))
52    }
53}
54
55#[cfg(test)]
56mod test {
57    #[expect(unused_imports)]
58    use itertools::Itertools;
59    use vortex_buffer::buffer;
60
61    use crate::ArrayRef;
62    use crate::IntoArray;
63    #[expect(deprecated)]
64    use crate::ToCanonical as _;
65    use crate::accessor::ArrayAccessor;
66    use crate::arrays::ConstantArray;
67    use crate::arrays::PrimitiveArray;
68    use crate::arrays::VarBinArray;
69    use crate::arrays::VarBinViewArray;
70    use crate::assert_arrays_eq;
71    use crate::builders::dict::dict_encode;
72    use crate::builtins::ArrayBuiltins;
73    use crate::compute::conformance::filter::test_filter_conformance;
74    use crate::compute::conformance::mask::test_mask_conformance;
75    use crate::compute::conformance::take::test_take_conformance;
76    use crate::dtype::DType;
77    use crate::dtype::Nullability;
78    use crate::dtype::PType::I32;
79    use crate::scalar_fn::fns::operators::Operator;
80    #[test]
81    fn canonicalise_nullable_primitive() {
82        let values: Vec<Option<i32>> = (0..65)
83            .map(|i| match i % 3 {
84                0 => Some(42),
85                1 => Some(-9),
86                2 => None,
87                _ => unreachable!(),
88            })
89            .collect();
90
91        let dict =
92            dict_encode(&PrimitiveArray::from_option_iter(values.clone()).into_array()).unwrap();
93        #[expect(deprecated)]
94        let actual = dict.as_array().to_primitive();
95
96        let expected = PrimitiveArray::from_option_iter(values);
97
98        assert_arrays_eq!(actual, expected);
99    }
100
101    #[test]
102    fn canonicalise_non_nullable_primitive_32_unique_values() {
103        let unique_values: Vec<i32> = (0..32).collect();
104        let expected = PrimitiveArray::from_iter((0..1000).map(|i| unique_values[i % 32]));
105
106        let dict = dict_encode(&expected.clone().into_array()).unwrap();
107        #[expect(deprecated)]
108        let actual = dict.as_array().to_primitive();
109
110        assert_arrays_eq!(actual, expected);
111    }
112
113    #[test]
114    fn canonicalise_non_nullable_primitive_100_unique_values() {
115        let unique_values: Vec<i32> = (0..100).collect();
116        let expected = PrimitiveArray::from_iter((0..1000).map(|i| unique_values[i % 100]));
117
118        let dict = dict_encode(&expected.clone().into_array()).unwrap();
119        #[expect(deprecated)]
120        let actual = dict.as_array().to_primitive();
121
122        assert_arrays_eq!(actual, expected);
123    }
124
125    #[test]
126    fn canonicalise_nullable_varbin() {
127        let reference = VarBinViewArray::from_iter(
128            vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")],
129            DType::Utf8(Nullability::Nullable),
130        );
131        assert_eq!(reference.len(), 6);
132        let dict = dict_encode(&reference.clone().into_array()).unwrap();
133        #[expect(deprecated)]
134        let flattened_dict = dict.as_array().to_varbinview();
135        assert_eq!(
136            flattened_dict.with_iterator(|iter| iter
137                .map(|slice| slice.map(|s| s.to_vec()))
138                .collect::<Vec<_>>()),
139            reference.with_iterator(|iter| iter
140                .map(|slice| slice.map(|s| s.to_vec()))
141                .collect::<Vec<_>>()),
142        );
143    }
144
145    fn sliced_dict_array() -> ArrayRef {
146        let reference = PrimitiveArray::from_option_iter([
147            Some(42),
148            Some(-9),
149            None,
150            Some(42),
151            Some(1),
152            Some(5),
153        ]);
154        let dict = dict_encode(&reference.into_array()).unwrap();
155        dict.slice(1..4).unwrap()
156    }
157
158    #[test]
159    fn compare_sliced_dict() {
160        use crate::arrays::BoolArray;
161        let sliced = sliced_dict_array();
162        let compared = sliced
163            .binary(ConstantArray::new(42, 3).into_array(), Operator::Eq)
164            .unwrap();
165
166        let expected = BoolArray::from_iter([Some(false), None, Some(true)]);
167        assert_arrays_eq!(compared, expected.into_array());
168    }
169
170    #[test]
171    fn test_mask_dict_array() {
172        let array = dict_encode(&buffer![2, 0, 2, 0, 10].into_array()).unwrap();
173        test_mask_conformance(&array.into_array());
174
175        let array = dict_encode(
176            &PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)])
177                .into_array(),
178        )
179        .unwrap();
180        test_mask_conformance(&array.into_array());
181
182        let array = dict_encode(
183            &VarBinArray::from_iter(
184                [
185                    Some("hello"),
186                    None,
187                    Some("hello"),
188                    Some("good"),
189                    Some("good"),
190                ],
191                DType::Utf8(Nullability::Nullable),
192            )
193            .into_array(),
194        )
195        .unwrap();
196        test_mask_conformance(&array.into_array());
197    }
198
199    #[test]
200    fn test_filter_dict_array() {
201        let array = dict_encode(&buffer![2, 0, 2, 0, 10].into_array()).unwrap();
202        test_filter_conformance(&array.into_array());
203
204        let array = dict_encode(
205            &PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)])
206                .into_array(),
207        )
208        .unwrap();
209        test_filter_conformance(&array.into_array());
210
211        let array = dict_encode(
212            &VarBinArray::from_iter(
213                [
214                    Some("hello"),
215                    None,
216                    Some("hello"),
217                    Some("good"),
218                    Some("good"),
219                ],
220                DType::Utf8(Nullability::Nullable),
221            )
222            .into_array(),
223        )
224        .unwrap();
225        test_filter_conformance(&array.into_array());
226    }
227
228    #[test]
229    fn test_take_dict() {
230        let array = dict_encode(&buffer![1, 2].into_array()).unwrap();
231
232        assert_eq!(
233            array
234                .take(PrimitiveArray::from_option_iter([Option::<i32>::None]).into_array())
235                .unwrap()
236                .dtype(),
237            &DType::Primitive(I32, Nullability::Nullable)
238        );
239    }
240
241    #[test]
242    fn test_take_dict_conformance() {
243        let array = dict_encode(&buffer![2, 0, 2, 0, 10].into_array()).unwrap();
244        test_take_conformance(&array.into_array());
245
246        let array = dict_encode(
247            &PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)])
248                .into_array(),
249        )
250        .unwrap();
251        test_take_conformance(&array.into_array());
252
253        let array = dict_encode(
254            &VarBinArray::from_iter(
255                [
256                    Some("hello"),
257                    None,
258                    Some("hello"),
259                    Some("good"),
260                    Some("good"),
261                ],
262                DType::Utf8(Nullability::Nullable),
263            )
264            .into_array(),
265        )
266        .unwrap();
267        test_take_conformance(&array.into_array());
268    }
269}
270
271#[cfg(test)]
272mod tests {
273    use rstest::rstest;
274    use vortex_buffer::buffer;
275
276    use crate::IntoArray;
277    use crate::arrays::DictArray;
278    use crate::arrays::PrimitiveArray;
279    use crate::arrays::VarBinArray;
280    use crate::builders::dict::dict_encode;
281    use crate::compute::conformance::consistency::test_array_consistency;
282    use crate::dtype::DType;
283    use crate::dtype::Nullability;
284
285    #[rstest]
286    // Primitive arrays
287    #[case::dict_i32(dict_encode(&buffer![1i32, 2, 3, 2, 1].into_array()).unwrap())]
288    #[case::dict_nullable_codes(DictArray::try_new(
289        buffer![0u32, 1, 2, 2, 0].into_array(),
290        PrimitiveArray::from_option_iter([Some(10), Some(20), None]).into_array(),
291    ).unwrap())]
292    #[case::dict_nullable_values(dict_encode(
293        &PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array()
294    ).unwrap())]
295    #[case::dict_u64(dict_encode(&buffer![100u64, 200, 100, 300, 200].into_array()).unwrap())]
296    // String arrays
297    #[case::dict_str(dict_encode(
298        &VarBinArray::from_iter(
299            ["hello", "world", "hello", "test", "world"].map(Some),
300            DType::Utf8(Nullability::NonNullable),
301        ).into_array()
302    ).unwrap())]
303    #[case::dict_nullable_str(dict_encode(
304        &VarBinArray::from_iter(
305            [Some("hello"), None, Some("world"), Some("hello"), None],
306            DType::Utf8(Nullability::Nullable),
307        ).into_array()
308    ).unwrap())]
309    // Edge cases
310    #[case::dict_single(dict_encode(&buffer![42i32].into_array()).unwrap())]
311    #[case::dict_all_same(dict_encode(&buffer![5i32, 5, 5, 5, 5].into_array()).unwrap())]
312    #[case::dict_large(dict_encode(&PrimitiveArray::from_iter((0..1000).map(|i| i % 10)).into_array()).unwrap())]
313    fn test_dict_consistency(#[case] array: DictArray) {
314        test_array_consistency(&array.into_array());
315    }
316}