Skip to main content

vortex_array/arrays/dict/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod cast;
5mod compare;
6mod fill_null;
7pub(crate) mod is_constant;
8pub(crate) mod is_sorted;
9mod like;
10mod mask;
11pub(crate) mod min_max;
12pub(crate) mod rules;
13mod slice;
14
15use vortex_error::VortexResult;
16use vortex_mask::Mask;
17
18use super::Dict;
19use super::DictArray;
20use super::TakeExecute;
21use crate::ArrayRef;
22use crate::ExecutionCtx;
23use crate::IntoArray;
24use crate::array::ArrayView;
25use crate::arrays::dict::DictArraySlotsExt;
26use crate::arrays::filter::FilterReduce;
27
28impl TakeExecute for Dict {
29    fn take(
30        array: ArrayView<'_, Dict>,
31        indices: &ArrayRef,
32        _ctx: &mut ExecutionCtx,
33    ) -> VortexResult<Option<ArrayRef>> {
34        let codes = array.codes().take(indices.clone())?;
35        // SAFETY: selecting codes doesn't change the invariants of DictArray
36        // Preserve all_values_referenced since taking codes doesn't affect which values are referenced
37        Ok(Some(unsafe {
38            DictArray::new_unchecked(codes, array.values().clone()).into_array()
39        }))
40    }
41}
42
43impl FilterReduce for Dict {
44    fn filter(array: ArrayView<'_, Dict>, mask: &Mask) -> VortexResult<Option<ArrayRef>> {
45        let codes = array.codes().filter(mask.clone())?;
46
47        // SAFETY: filtering codes doesn't change invariants
48        // Preserve all_values_referenced since filtering codes doesn't affect which values are referenced
49        Ok(Some(unsafe {
50            DictArray::new_unchecked(codes, array.values().clone()).into_array()
51        }))
52    }
53}
54
#[cfg(test)]
mod test {
    // Tests for dictionary arrays: round-trips through `dict_encode` followed by execution to
    // a canonical array, a comparison on a sliced dictionary, and the shared mask / filter /
    // take conformance suites.

    use std::sync::LazyLock;

    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::arrays::BoolArray;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::arrays::VarBinViewArray;
    use crate::assert_arrays_eq;
    use crate::builders::dict::dict_encode;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::filter::test_filter_conformance;
    use crate::compute::conformance::mask::test_mask_conformance;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType::I32;
    use crate::scalar_fn::fns::operators::Operator;

    /// Session shared by every test in this module; each use creates its own execution context.
    static SESSION: LazyLock<VortexSession> = LazyLock::new(crate::array_session);

    /// Dict-encodes 1000 `i32`s cycling through `0..unique_count` and checks that executing the
    /// dictionary to a `PrimitiveArray` reproduces the input.
    fn assert_cyclic_i32_round_trips(unique_count: i32) {
        let mut ctx = SESSION.create_execution_ctx();
        let expected = PrimitiveArray::from_iter((0..1000).map(|i| i % unique_count));

        let encoded = dict_encode(
            &expected.clone().into_array(),
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap()
        .into_array();
        let actual = encoded
            .execute::<PrimitiveArray>(&mut SESSION.create_execution_ctx())
            .unwrap();

        assert_arrays_eq!(actual, expected, &mut ctx);
    }

    /// Dictionary over the non-nullable integers `[2, 0, 2, 0, 10]`.
    fn non_nullable_int_dict() -> ArrayRef {
        let input = buffer![2, 0, 2, 0, 10].into_array();
        dict_encode(&input, &mut SESSION.create_execution_ctx())
            .unwrap()
            .into_array()
    }

    /// Dictionary over the nullable integers `[2, null, 2, 0, 10]`.
    fn nullable_int_dict() -> ArrayRef {
        let input = PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)])
            .into_array();
        dict_encode(&input, &mut SESSION.create_execution_ctx())
            .unwrap()
            .into_array()
    }

    /// Dictionary over the nullable strings `["hello", null, "hello", "good", "good"]`.
    fn nullable_utf8_dict() -> ArrayRef {
        let input = VarBinArray::from_iter(
            [
                Some("hello"),
                None,
                Some("hello"),
                Some("good"),
                Some("good"),
            ],
            DType::Utf8(Nullability::Nullable),
        )
        .into_array();
        dict_encode(&input, &mut SESSION.create_execution_ctx())
            .unwrap()
            .into_array()
    }

    /// Nullable primitives survive a dict-encode / execute round-trip.
    #[test]
    fn canonicalise_nullable_primitive() {
        let mut ctx = SESSION.create_execution_ctx();
        // 65 entries repeating the pattern 42, -9, null.
        let values: Vec<Option<i32>> = [Some(42), Some(-9), None]
            .into_iter()
            .cycle()
            .take(65)
            .collect();

        let input = PrimitiveArray::from_option_iter(values.clone()).into_array();
        let encoded = dict_encode(&input, &mut SESSION.create_execution_ctx()).unwrap();
        let actual = encoded
            .into_array()
            .execute::<PrimitiveArray>(&mut SESSION.create_execution_ctx())
            .unwrap();

        let expected = PrimitiveArray::from_option_iter(values);

        assert_arrays_eq!(actual, expected, &mut ctx);
    }

    #[test]
    fn canonicalise_non_nullable_primitive_32_unique_values() {
        assert_cyclic_i32_round_trips(32);
    }

    #[test]
    fn canonicalise_non_nullable_primitive_100_unique_values() {
        assert_cyclic_i32_round_trips(100);
    }

    /// Nullable strings survive a dict-encode / execute round-trip, compared row by row as
    /// optional byte vectors.
    #[test]
    fn canonicalise_nullable_varbin() -> VortexResult<()> {
        let reference = VarBinViewArray::from_iter(
            vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")],
            DType::Utf8(Nullability::Nullable),
        );
        assert_eq!(reference.len(), 6);

        let mut ctx = SESSION.create_execution_ctx();
        let encoded = dict_encode(&reference.clone().into_array(), &mut ctx)?;
        let decoded = encoded.into_array().execute::<VarBinViewArray>(&mut ctx)?;

        // Rows whose validity bit is unset become `None`; valid rows carry their bytes.
        let decoded_validity = decoded
            .validity()?
            .execute_mask(decoded.len(), &mut ctx)?;
        let decoded_rows: Vec<_> = (0..decoded.len())
            .map(|row| {
                decoded_validity
                    .value(row)
                    .then(|| decoded.bytes_at(row).to_vec())
            })
            .collect();

        let reference_validity = reference
            .validity()?
            .execute_mask(reference.len(), &mut ctx)?;
        let reference_rows: Vec<_> = (0..reference.len())
            .map(|row| {
                reference_validity
                    .value(row)
                    .then(|| reference.bytes_at(row).to_vec())
            })
            .collect();

        assert_eq!(decoded_rows, reference_rows);
        Ok(())
    }

    /// Rows `1..4` of the dictionary over `[42, -9, null, 42, 1, 5]`, i.e. `[-9, null, 42]`.
    fn sliced_dict_array() -> ArrayRef {
        let source = PrimitiveArray::from_option_iter([
            Some(42),
            Some(-9),
            None,
            Some(42),
            Some(1),
            Some(5),
        ]);
        let encoded =
            dict_encode(&source.into_array(), &mut SESSION.create_execution_ctx()).unwrap();
        encoded.slice(1..4).unwrap()
    }

    /// Comparing a sliced dictionary against a constant yields one result per sliced row.
    #[test]
    fn compare_sliced_dict() {
        let mut ctx = SESSION.create_execution_ctx();
        let sliced = sliced_dict_array();
        let rhs = ConstantArray::new(42, 3).into_array();
        let compared = sliced.binary(rhs, Operator::Eq).unwrap();

        let expected = BoolArray::from_iter([Some(false), None, Some(true)]);
        assert_arrays_eq!(compared, expected.into_array(), &mut ctx);
    }

    #[test]
    fn test_mask_dict_array() {
        test_mask_conformance(&non_nullable_int_dict());
        test_mask_conformance(&nullable_int_dict());
        test_mask_conformance(&nullable_utf8_dict());
    }

    #[test]
    fn test_filter_dict_array() {
        test_filter_conformance(&non_nullable_int_dict());
        test_filter_conformance(&nullable_int_dict());
        test_filter_conformance(&nullable_utf8_dict());
    }

    /// Taking with a single null index from a non-nullable dictionary gives a nullable dtype.
    #[test]
    fn test_take_dict() {
        let encoded = dict_encode(
            &buffer![1, 2].into_array(),
            &mut SESSION.create_execution_ctx(),
        )
        .unwrap();

        let null_index = PrimitiveArray::from_option_iter([None::<i32>]).into_array();
        let taken = encoded.take(null_index).unwrap();

        assert_eq!(
            taken.dtype(),
            &DType::Primitive(I32, Nullability::Nullable)
        );
    }

    #[test]
    fn test_take_dict_conformance() {
        test_take_conformance(&non_nullable_int_dict());
        test_take_conformance(&nullable_int_dict());
        test_take_conformance(&nullable_utf8_dict());
    }
}
333
#[cfg(test)]
mod tests {
    // Runs the shared array-consistency suite over a range of dictionary arrays.

    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::arrays::DictArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinArray;
    use crate::builders::dict::dict_encode;
    use crate::compute::conformance::consistency::test_array_consistency;
    use crate::dtype::DType;
    use crate::dtype::Nullability;

    /// Session shared by every case; each use creates its own execution context.
    static SESSION: LazyLock<VortexSession> = LazyLock::new(crate::array_session);

    /// Dict-encodes `input` with a fresh execution context, panicking if encoding fails.
    fn encode(input: ArrayRef) -> DictArray {
        dict_encode(&input, &mut SESSION.create_execution_ctx()).unwrap()
    }

    #[rstest]
    // Primitive arrays
    #[case::dict_i32(encode(buffer![1i32, 2, 3, 2, 1].into_array()))]
    // Built directly rather than encoded: the codes are a plain `u32` buffer and the values
    // contain a null.
    #[case::dict_nullable_codes(DictArray::try_new(
        buffer![0u32, 1, 2, 2, 0].into_array(),
        PrimitiveArray::from_option_iter([Some(10), Some(20), None]).into_array(),
    ).unwrap())]
    #[case::dict_nullable_values(encode(
        PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).into_array()
    ))]
    #[case::dict_u64(encode(buffer![100u64, 200, 100, 300, 200].into_array()))]
    // String arrays
    #[case::dict_str(encode(
        VarBinArray::from_iter(
            ["hello", "world", "hello", "test", "world"].map(Some),
            DType::Utf8(Nullability::NonNullable),
        ).into_array()
    ))]
    #[case::dict_nullable_str(encode(
        VarBinArray::from_iter(
            [Some("hello"), None, Some("world"), Some("hello"), None],
            DType::Utf8(Nullability::Nullable),
        ).into_array()
    ))]
    // Edge cases
    #[case::dict_single(encode(buffer![42i32].into_array()))]
    #[case::dict_all_same(encode(buffer![5i32, 5, 5, 5, 5].into_array()))]
    #[case::dict_large(encode(
        PrimitiveArray::from_iter((0..1000).map(|i| i % 10)).into_array()
    ))]
    fn test_dict_consistency(#[case] array: DictArray) {
        let array = array.into_array();
        test_array_consistency(&array, &mut SESSION.create_execution_ctx());
    }
}
385}