Skip to main content

vortex_array/arrays/dict/compute/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod cast;
5mod compare;
6mod fill_null;
7mod is_constant;
8mod is_sorted;
9mod like;
10mod mask;
11mod min_max;
12pub(crate) mod rules;
13mod slice;
14
15use vortex_error::VortexResult;
16use vortex_mask::Mask;
17
18use super::DictArray;
19use super::DictVTable;
20use super::TakeExecute;
21use crate::Array;
22use crate::ArrayRef;
23use crate::ExecutionCtx;
24use crate::IntoArray;
25use crate::arrays::filter::FilterReduce;
26
27impl TakeExecute for DictVTable {
28    fn take(
29        array: &DictArray,
30        indices: &ArrayRef,
31        _ctx: &mut ExecutionCtx,
32    ) -> VortexResult<Option<ArrayRef>> {
33        let codes = array.codes().take(indices.to_array())?;
34        // SAFETY: selecting codes doesn't change the invariants of DictArray
35        // Preserve all_values_referenced since taking codes doesn't affect which values are referenced
36        Ok(Some(unsafe {
37            DictArray::new_unchecked(codes, array.values().clone()).into_array()
38        }))
39    }
40}
41
42impl FilterReduce for DictVTable {
43    fn filter(array: &DictArray, mask: &Mask) -> VortexResult<Option<ArrayRef>> {
44        let codes = array.codes().filter(mask.clone())?;
45
46        // SAFETY: filtering codes doesn't change invariants
47        // Preserve all_values_referenced since filtering codes doesn't affect which values are referenced
48        Ok(Some(unsafe {
49            DictArray::new_unchecked(codes, array.values().clone()).into_array()
50        }))
51    }
52}
53
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;

    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::accessor::ArrayAccessor;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::assert_arrays_eq;
    use crate::builders::dict::dict_encode;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::filter::test_filter_conformance;
    use crate::compute::conformance::mask::test_mask_conformance;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType::I32;
    use crate::scalar_fn::fns::operators::Operator;

    /// Dictionary-encoded fixtures shared by the mask, filter and take conformance tests:
    /// non-nullable primitives, nullable primitives, and nullable strings.
    fn conformance_dict_arrays() -> Vec<ArrayRef> {
        let non_nullable = dict_encode(&buffer![2, 0, 2, 0, 10].into_array()).unwrap();

        let nullable = dict_encode(
            &PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)])
                .to_array(),
        )
        .unwrap();

        let strings = dict_encode(
            &VarBinArray::from_iter(
                [
                    Some("hello"),
                    None,
                    Some("hello"),
                    Some("good"),
                    Some("good"),
                ],
                DType::Utf8(Nullability::Nullable),
            )
            .into_array(),
        )
        .unwrap();

        vec![
            non_nullable.to_array(),
            nullable.to_array(),
            strings.to_array(),
        ]
    }

    /// Dict-encodes 1000 non-nullable values cycling through `0..unique` and checks that
    /// canonicalising the dictionary reproduces the original input.
    fn assert_cycling_values_roundtrip(unique: i32) {
        let expected = PrimitiveArray::from_iter((0..1000).map(|i| i % unique));

        let dict = dict_encode(&expected.to_array()).unwrap();
        let actual = dict.to_primitive();

        assert_arrays_eq!(actual, expected);
    }

    #[test]
    fn canonicalise_nullable_primitive() {
        let values: Vec<Option<i32>> = (0..65)
            .map(|i| match i % 3 {
                0 => Some(42),
                1 => Some(-9),
                2 => None,
                _ => unreachable!(),
            })
            .collect();

        let dict =
            dict_encode(&PrimitiveArray::from_option_iter(values.clone()).to_array()).unwrap();
        let actual = dict.to_primitive();

        let expected = PrimitiveArray::from_option_iter(values);

        assert_arrays_eq!(actual, expected);
    }

    #[test]
    fn canonicalise_non_nullable_primitive_32_unique_values() {
        assert_cycling_values_roundtrip(32);
    }

    #[test]
    fn canonicalise_non_nullable_primitive_100_unique_values() {
        assert_cycling_values_roundtrip(100);
    }

    #[test]
    fn canonicalise_nullable_varbin() {
        let reference = VarBinViewArray::from_iter(
            vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")],
            DType::Utf8(Nullability::Nullable),
        );
        assert_eq!(reference.len(), 6);
        let dict = dict_encode(&reference.to_array()).unwrap();
        let flattened_dict = dict.to_varbinview();
        assert_eq!(
            flattened_dict.with_iterator(|iter| iter
                .map(|slice| slice.map(|s| s.to_vec()))
                .collect::<Vec<_>>()),
            reference.with_iterator(|iter| iter
                .map(|slice| slice.map(|s| s.to_vec()))
                .collect::<Vec<_>>()),
        );
    }

    /// Dictionary array over `[42, -9, null, 42, 1, 5]`, sliced to `[-9, null, 42]`.
    fn sliced_dict_array() -> ArrayRef {
        let reference = PrimitiveArray::from_option_iter([
            Some(42),
            Some(-9),
            None,
            Some(42),
            Some(1),
            Some(5),
        ]);
        let dict = dict_encode(&reference.to_array()).unwrap();
        dict.slice(1..4).unwrap()
    }

    #[test]
    fn compare_sliced_dict() {
        use crate::arrays::BoolArray;
        let sliced = sliced_dict_array();
        let compared = sliced
            .binary(ConstantArray::new(42, 3).to_array(), Operator::Eq)
            .unwrap();

        let expected = BoolArray::from_iter([Some(false), None, Some(true)]);
        assert_arrays_eq!(compared, expected.to_array());
    }

    #[test]
    fn test_mask_dict_array() {
        for array in conformance_dict_arrays() {
            test_mask_conformance(&array);
        }
    }

    #[test]
    fn test_filter_dict_array() {
        for array in conformance_dict_arrays() {
            test_filter_conformance(&array);
        }
    }

    /// Taking with a null index must widen the result dtype to nullable.
    #[test]
    fn test_take_dict() {
        let array = dict_encode(&buffer![1, 2].into_array()).unwrap();

        assert_eq!(
            array
                .take(PrimitiveArray::from_option_iter([Option::<i32>::None]).to_array())
                .unwrap()
                .dtype(),
            &DType::Primitive(I32, Nullability::Nullable)
        );
    }

    #[test]
    fn test_take_dict_conformance() {
        for array in conformance_dict_arrays() {
            test_take_conformance(&array);
        }
    }
}
265
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::dict::DictArray;
    use crate::builders::dict::dict_encode;
    use crate::compute::conformance::consistency::test_array_consistency;
    use crate::dtype::DType;
    use crate::dtype::Nullability;

    /// Runs the shared array-consistency checks against a spread of dictionary arrays:
    /// primitive and string values, nullable and non-nullable, and small/degenerate shapes.
    #[rstest]
    // ---- Primitive values ----
    #[case::dict_i32(dict_encode(&buffer![1i32, 2, 3, 2, 1].into_array()).unwrap())]
    // The codes themselves are non-nullable; code 2 refers to the null entry in the values.
    #[case::dict_nullable_codes(DictArray::try_new(
        buffer![0u32, 1, 2, 2, 0].into_array(),
        PrimitiveArray::from_option_iter([Some(10), Some(20), None]).into_array(),
    ).unwrap())]
    #[case::dict_nullable_values(dict_encode(
        &PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).to_array()
    ).unwrap())]
    #[case::dict_u64(dict_encode(&buffer![100u64, 200, 100, 300, 200].into_array()).unwrap())]
    // ---- String values ----
    #[case::dict_str(dict_encode(
        &VarBinArray::from_iter(
            ["hello", "world", "hello", "test", "world"].map(Some),
            DType::Utf8(Nullability::NonNullable),
        ).into_array()
    ).unwrap())]
    #[case::dict_nullable_str(dict_encode(
        &VarBinArray::from_iter(
            [Some("hello"), None, Some("world"), Some("hello"), None],
            DType::Utf8(Nullability::Nullable),
        ).into_array()
    ).unwrap())]
    // ---- Edge cases: single element, a single distinct value, and a longer array ----
    #[case::dict_single(dict_encode(&buffer![42i32].into_array()).unwrap())]
    #[case::dict_all_same(dict_encode(&buffer![5i32, 5, 5, 5, 5].into_array()).unwrap())]
    #[case::dict_large(dict_encode(&PrimitiveArray::from_iter((0..1000).map(|i| i % 10)).into_array()).unwrap())]
    fn test_dict_consistency(#[case] dict: DictArray) {
        let array = dict.into_array();
        test_array_consistency(&array);
    }
}