Skip to main content

vortex_array/arrays/dict/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod cast;
5mod compare;
6mod fill_null;
7pub(crate) mod is_constant;
8pub(crate) mod is_sorted;
9mod like;
10mod mask;
11pub(crate) mod min_max;
12pub(crate) mod rules;
13mod slice;
14
15use vortex_error::VortexResult;
16use vortex_mask::Mask;
17
18use super::Dict;
19use super::DictArray;
20use super::TakeExecute;
21use crate::ArrayRef;
22use crate::ExecutionCtx;
23use crate::IntoArray;
24use crate::array::ArrayView;
25use crate::arrays::dict::DictArraySlotsExt;
26use crate::arrays::filter::FilterReduce;
27
28impl TakeExecute for Dict {
29    fn take(
30        array: ArrayView<'_, Dict>,
31        indices: &ArrayRef,
32        _ctx: &mut ExecutionCtx,
33    ) -> VortexResult<Option<ArrayRef>> {
34        let codes = array.codes().take(indices.clone())?;
35        // SAFETY: selecting codes doesn't change the invariants of DictArray
36        // Preserve all_values_referenced since taking codes doesn't affect which values are referenced
37        Ok(Some(unsafe {
38            DictArray::new_unchecked(codes, array.values().clone()).into_array()
39        }))
40    }
41}
42
43impl FilterReduce for Dict {
44    fn filter(array: ArrayView<'_, Dict>, mask: &Mask) -> VortexResult<Option<ArrayRef>> {
45        let codes = array.codes().filter(mask.clone())?;
46
47        // SAFETY: filtering codes doesn't change invariants
48        // Preserve all_values_referenced since filtering codes doesn't affect which values are referenced
49        Ok(Some(unsafe {
50            DictArray::new_unchecked(codes, array.values().clone()).into_array()
51        }))
52    }
53}
54
#[cfg(test)]
mod test {
    use std::sync::LazyLock;

    #[expect(unused_imports)]
    use itertools::Itertools;
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::accessor::ArrayAccessor;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::assert_arrays_eq;
    use crate::builders::dict::dict_encode;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::filter::test_filter_conformance;
    use crate::compute::conformance::mask::test_mask_conformance;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType::I32;
    use crate::scalar_fn::fns::operators::Operator;
    use crate::session::ArraySession;

    /// Session shared by every test in this module.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Dictionary-encodes `array` with a fresh execution context from `SESSION` and returns
    /// the encoded array as an `ArrayRef`. Panics if encoding fails.
    fn dict_encoded(array: ArrayRef) -> ArrayRef {
        dict_encode(&array, &mut SESSION.create_execution_ctx())
            .unwrap()
            .into_array()
    }

    /// Dictionary-encoded inputs shared by the mask, filter and take conformance tests:
    /// a primitive buffer without nulls, a primitive array with a null, and a nullable
    /// UTF-8 array. Keeping them in one place keeps the three suites on identical inputs.
    fn conformance_inputs() -> Vec<ArrayRef> {
        vec![
            dict_encoded(buffer![2, 0, 2, 0, 10].into_array()),
            dict_encoded(
                PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)])
                    .into_array(),
            ),
            dict_encoded(
                VarBinArray::from_iter(
                    [
                        Some("hello"),
                        None,
                        Some("hello"),
                        Some("good"),
                        Some("good"),
                    ],
                    DType::Utf8(Nullability::Nullable),
                )
                .into_array(),
            ),
        ]
    }

    /// Encoding a nullable primitive array and converting back yields the original values.
    #[test]
    fn canonicalise_nullable_primitive() {
        let values: Vec<Option<i32>> = (0..65)
            .map(|i| match i % 3 {
                0 => Some(42),
                1 => Some(-9),
                2 => None,
                _ => unreachable!(),
            })
            .collect();

        let dict = dict_encode(
            &PrimitiveArray::from_option_iter(values.clone()).into_array(),
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap();
        #[expect(deprecated)]
        let actual = dict.as_array().to_primitive();

        let expected = PrimitiveArray::from_option_iter(values);

        assert_arrays_eq!(actual, expected);
    }

    /// Round trip of a non-nullable primitive array with 32 distinct values.
    #[test]
    fn canonicalise_non_nullable_primitive_32_unique_values() {
        let unique_values: Vec<i32> = (0..32).collect();
        let expected = PrimitiveArray::from_iter((0..1000).map(|i| unique_values[i % 32]));

        let dict = dict_encode(
            &expected.clone().into_array(),
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap();
        #[expect(deprecated)]
        let actual = dict.as_array().to_primitive();

        assert_arrays_eq!(actual, expected);
    }

    /// Round trip of a non-nullable primitive array with 100 distinct values.
    #[test]
    fn canonicalise_non_nullable_primitive_100_unique_values() {
        let unique_values: Vec<i32> = (0..100).collect();
        let expected = PrimitiveArray::from_iter((0..1000).map(|i| unique_values[i % 100]));

        let dict = dict_encode(
            &expected.clone().into_array(),
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap();
        #[expect(deprecated)]
        let actual = dict.as_array().to_primitive();

        assert_arrays_eq!(actual, expected);
    }

    /// Round trip of a nullable UTF-8 view array, compared element by element as bytes.
    #[test]
    fn canonicalise_nullable_varbin() {
        let reference = VarBinViewArray::from_iter(
            vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")],
            DType::Utf8(Nullability::Nullable),
        );
        assert_eq!(reference.len(), 6);
        let dict = dict_encode(
            &reference.clone().into_array(),
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap();
        #[expect(deprecated)]
        let flattened_dict = dict.as_array().to_varbinview();
        assert_eq!(
            flattened_dict.with_iterator(|iter| iter
                .map(|slice| slice.map(|s| s.to_vec()))
                .collect::<Vec<_>>()),
            reference.with_iterator(|iter| iter
                .map(|slice| slice.map(|s| s.to_vec()))
                .collect::<Vec<_>>()),
        );
    }

    /// Dictionary-encodes a six-element nullable array and returns its `1..4` slice,
    /// i.e. the elements `-9`, null, `42`.
    fn sliced_dict_array() -> ArrayRef {
        let reference = PrimitiveArray::from_option_iter([
            Some(42),
            Some(-9),
            None,
            Some(42),
            Some(1),
            Some(5),
        ]);
        let dict =
            dict_encode(&reference.into_array(), &mut SESSION.create_execution_ctx()).unwrap();
        dict.slice(1..4).unwrap()
    }

    /// Comparing a sliced dictionary array against a constant yields per-element results,
    /// with the null element staying null.
    #[test]
    fn compare_sliced_dict() {
        use crate::arrays::BoolArray;
        let sliced = sliced_dict_array();
        let compared = sliced
            .binary(ConstantArray::new(42, 3).into_array(), Operator::Eq)
            .unwrap();

        let expected = BoolArray::from_iter([Some(false), None, Some(true)]);
        assert_arrays_eq!(compared, expected.into_array());
    }

    /// Runs the mask conformance suite over every shared dictionary input.
    #[test]
    fn test_mask_dict_array() {
        for array in conformance_inputs() {
            test_mask_conformance(&array);
        }
    }

    /// Runs the filter conformance suite over every shared dictionary input.
    #[test]
    fn test_filter_dict_array() {
        for array in conformance_inputs() {
            test_filter_conformance(&array);
        }
    }

    /// Taking with a nullable index array produces a nullable result dtype even though
    /// the encoded input is non-nullable.
    #[test]
    fn test_take_dict() {
        let array = dict_encode(
            &buffer![1, 2].into_array(),
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap();

        assert_eq!(
            array
                .take(PrimitiveArray::from_option_iter([Option::<i32>::None]).into_array())
                .unwrap()
                .dtype(),
            &DType::Primitive(I32, Nullability::Nullable)
        );
    }

    /// Runs the take conformance suite over every shared dictionary input.
    #[test]
    fn test_take_dict_conformance() {
        for array in conformance_inputs() {
            test_take_conformance(&array);
        }
    }
}
317
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::arrays::DictArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::builders::dict::dict_encode;
    use crate::compute::conformance::consistency::test_array_consistency;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::session::ArraySession;

    /// Session shared by every consistency case below.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Converts `input` to an array and dictionary-encodes it with a fresh execution
    /// context from `SESSION`. Panics if encoding fails.
    fn encoded(input: impl IntoArray) -> DictArray {
        let plain = input.into_array();
        dict_encode(&plain, &mut SESSION.create_execution_ctx()).unwrap()
    }

    /// Runs the generic array consistency suite against each dictionary array case.
    #[rstest]
    // Primitive arrays
    #[case::dict_i32(encoded(buffer![1i32, 2, 3, 2, 1]))]
    #[case::dict_nullable_codes(
        DictArray::try_new(
            buffer![0u32, 1, 2, 2, 0].into_array(),
            PrimitiveArray::from_option_iter([Some(10), Some(20), None]).into_array(),
        )
        .unwrap()
    )]
    #[case::dict_nullable_values(encoded(PrimitiveArray::from_option_iter([
        Some(1i32),
        None,
        Some(2),
        Some(1),
        None,
    ])))]
    #[case::dict_u64(encoded(buffer![100u64, 200, 100, 300, 200]))]
    // String arrays
    #[case::dict_str(encoded(VarBinArray::from_iter(
        ["hello", "world", "hello", "test", "world"].map(Some),
        DType::Utf8(Nullability::NonNullable),
    )))]
    #[case::dict_nullable_str(encoded(VarBinArray::from_iter(
        [Some("hello"), None, Some("world"), Some("hello"), None],
        DType::Utf8(Nullability::Nullable),
    )))]
    // Edge cases
    #[case::dict_single(encoded(buffer![42i32]))]
    #[case::dict_all_same(encoded(buffer![5i32, 5, 5, 5, 5]))]
    #[case::dict_large(encoded(PrimitiveArray::from_iter((0..1000).map(|i| i % 10))))]
    fn test_dict_consistency(#[case] array: DictArray) {
        let array_ref = array.into_array();
        test_array_consistency(&array_ref);
    }
}
371}