Skip to main content

vortex_array/arrays/dict/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod cast;
5mod compare;
6mod fill_null;
7mod is_constant;
8mod is_sorted;
9mod like;
10mod mask;
11mod min_max;
12pub(crate) mod rules;
13mod slice;
14
15use vortex_error::VortexResult;
16use vortex_mask::Mask;
17
18use super::DictArray;
19use super::DictVTable;
20use super::TakeExecute;
21use crate::Array;
22use crate::ArrayRef;
23use crate::ExecutionCtx;
24use crate::IntoArray;
25use crate::arrays::filter::FilterReduce;
26
27impl TakeExecute for DictVTable {
28    fn take(
29        array: &DictArray,
30        indices: &dyn Array,
31        _ctx: &mut ExecutionCtx,
32    ) -> VortexResult<Option<ArrayRef>> {
33        let codes = array
34            .codes()
35            .take(indices.to_array())?
36            .to_canonical()?
37            .into_array();
38        // SAFETY: selecting codes doesn't change the invariants of DictArray
39        // Preserve all_values_referenced since taking codes doesn't affect which values are referenced
40        Ok(Some(unsafe {
41            DictArray::new_unchecked(codes, array.values().clone()).into_array()
42        }))
43    }
44}
45
46impl FilterReduce for DictVTable {
47    fn filter(array: &DictArray, mask: &Mask) -> VortexResult<Option<ArrayRef>> {
48        let codes = array.codes().filter(mask.clone())?;
49
50        // SAFETY: filtering codes doesn't change invariants
51        // Preserve all_values_referenced since filtering codes doesn't affect which values are referenced
52        Ok(Some(unsafe {
53            DictArray::new_unchecked(codes, array.values().clone()).into_array()
54        }))
55    }
56}
57
#[cfg(test)]
mod test {
    #[allow(unused_imports)]
    use itertools::Itertools;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType::I32;

    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::accessor::ArrayAccessor;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::assert_arrays_eq;
    use crate::builders::dict::dict_encode;
    use crate::compute::Operator;
    use crate::compute::compare;
    use crate::compute::conformance::filter::test_filter_conformance;
    use crate::compute::conformance::mask::test_mask_conformance;
    use crate::compute::conformance::take::test_take_conformance;

    /// Dictionary-encodes 1000 `i32` elements cycling through `0..distinct` and checks that
    /// converting the dictionary back to a primitive array reproduces the input.
    fn assert_cyclic_round_trip(distinct: i32) {
        let expected = PrimitiveArray::from_iter((0..distinct).cycle().take(1000));

        let encoded = dict_encode(expected.as_ref()).unwrap();
        let actual = encoded.to_primitive();

        assert_arrays_eq!(actual, expected);
    }

    /// Builds the three dictionary arrays shared by the conformance tests (non-nullable
    /// integers, nullable integers, nullable strings) and runs `check` on each in turn.
    fn for_each_conformance_input(check: impl Fn(&dyn Array)) {
        let plain_ints = dict_encode(&buffer![2, 0, 2, 0, 10].into_array()).unwrap();
        check(plain_ints.as_ref());

        let nullable_ints =
            PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)]);
        let nullable_ints = dict_encode(nullable_ints.as_ref()).unwrap();
        check(nullable_ints.as_ref());

        let nullable_strings = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("hello"),
                Some("good"),
                Some("good"),
            ],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();
        let nullable_strings = dict_encode(&nullable_strings).unwrap();
        check(nullable_strings.as_ref());
    }

    #[test]
    fn canonicalise_nullable_primitive() {
        // 65 elements repeating the pattern: 42, -9, null.
        let pattern = [Some(42), Some(-9), None];
        let values: Vec<Option<i32>> = pattern.iter().copied().cycle().take(65).collect();
        let expected = PrimitiveArray::from_option_iter(values);

        let encoded = dict_encode(expected.as_ref()).unwrap();
        let actual = encoded.to_primitive();

        assert_arrays_eq!(actual, expected);
    }

    #[test]
    fn canonicalise_non_nullable_primitive_32_unique_values() {
        assert_cyclic_round_trip(32);
    }

    #[test]
    fn canonicalise_non_nullable_primitive_100_unique_values() {
        assert_cyclic_round_trip(100);
    }

    #[test]
    fn canonicalise_nullable_varbin() {
        let strings = vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")];
        let reference = VarBinViewArray::from_iter(strings, DType::Utf8(Nullability::Nullable));
        assert_eq!(reference.len(), 6);

        let encoded = dict_encode(reference.as_ref()).unwrap();
        let flattened = encoded.to_varbinview();

        // Compare element-wise as owned byte vectors so nulls are compared too.
        let actual_bytes = flattened.with_iterator(|iter| {
            iter.map(|item| item.map(|bytes| bytes.to_vec()))
                .collect::<Vec<_>>()
        });
        let expected_bytes = reference.with_iterator(|iter| {
            iter.map(|item| item.map(|bytes| bytes.to_vec()))
                .collect::<Vec<_>>()
        });
        assert_eq!(actual_bytes, expected_bytes);
    }

    /// Dictionary-encodes a small nullable array and returns the slice covering
    /// positions 1, 2 and 3 (`-9`, null, `42`).
    fn sliced_dict_array() -> ArrayRef {
        let source = PrimitiveArray::from_option_iter([
            Some(42),
            Some(-9),
            None,
            Some(42),
            Some(1),
            Some(5),
        ]);
        let encoded = dict_encode(source.as_ref()).unwrap();
        encoded.slice(1..4).unwrap()
    }

    #[test]
    fn compare_sliced_dict() {
        use crate::arrays::BoolArray;

        let sliced = sliced_dict_array();
        let rhs = ConstantArray::new(42, 3);
        let compared = compare(&sliced, rhs.as_ref(), Operator::Eq).unwrap();

        let expected = BoolArray::from_iter([Some(false), None, Some(true)]);
        assert_arrays_eq!(compared, expected.to_array());
    }

    #[test]
    fn test_mask_dict_array() {
        for_each_conformance_input(test_mask_conformance);
    }

    #[test]
    fn test_filter_dict_array() {
        for_each_conformance_input(test_filter_conformance);
    }

    #[test]
    fn test_take_dict() {
        let encoded = dict_encode(&buffer![1, 2].into_array()).unwrap();

        // Taking with a single null index must yield a nullable result dtype.
        let null_index = PrimitiveArray::from_option_iter([Option::<i32>::None]).to_array();
        let taken = encoded.take(null_index).unwrap();

        assert_eq!(
            taken.dtype(),
            &DType::Primitive(I32, Nullability::Nullable)
        );
    }

    #[test]
    fn test_take_dict_conformance() {
        for_each_conformance_input(test_take_conformance);
    }
}
263
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::dict::DictArray;
    use crate::builders::dict::dict_encode;
    use crate::compute::conformance::consistency::test_array_consistency;

    // Runs the shared array consistency checks against each dictionary array case below.
    #[rstest]
    // Primitive values
    #[case::dict_i32(
        dict_encode(&buffer![1i32, 2, 3, 2, 1].into_array()).unwrap()
    )]
    // Built directly from codes and a values array that contains a null entry.
    #[case::dict_nullable_codes(
        DictArray::try_new(
            buffer![0u32, 1, 2, 2, 0].into_array(),
            PrimitiveArray::from_option_iter([Some(10), Some(20), None]).into_array(),
        )
        .unwrap()
    )]
    #[case::dict_nullable_values(
        dict_encode(
            PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).as_ref()
        )
        .unwrap()
    )]
    #[case::dict_u64(
        dict_encode(&buffer![100u64, 200, 100, 300, 200].into_array()).unwrap()
    )]
    // String values
    #[case::dict_str(
        dict_encode(
            &VarBinArray::from_iter(
                [Some("hello"), Some("world"), Some("hello"), Some("test"), Some("world")],
                DType::Utf8(Nullability::NonNullable),
            )
            .into_array()
        )
        .unwrap()
    )]
    #[case::dict_nullable_str(
        dict_encode(
            &VarBinArray::from_iter(
                [Some("hello"), None, Some("world"), Some("hello"), None],
                DType::Utf8(Nullability::Nullable),
            )
            .into_array()
        )
        .unwrap()
    )]
    // Edge cases: one element, one distinct value, many repeats of few values
    #[case::dict_single(
        dict_encode(&buffer![42i32].into_array()).unwrap()
    )]
    #[case::dict_all_same(
        dict_encode(&buffer![5i32, 5, 5, 5, 5].into_array()).unwrap()
    )]
    #[case::dict_large(
        dict_encode(&PrimitiveArray::from_iter((0..10).cycle().take(1000)).into_array()).unwrap()
    )]
    fn test_dict_consistency(#[case] dict: DictArray) {
        test_array_consistency(dict.as_ref());
    }
}