vortex_dict/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod binary_numeric;
5mod compare;
6mod fill_null;
7mod is_constant;
8mod is_sorted;
9mod like;
10mod min_max;
11
12use vortex_array::compute::{
13    FilterKernel, FilterKernelAdapter, TakeKernel, TakeKernelAdapter, cast, filter, take,
14};
15use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
16use vortex_error::VortexResult;
17use vortex_mask::Mask;
18
19use crate::{DictArray, DictVTable};
20
21impl TakeKernel for DictVTable {
22    fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
23        // TODO(joe): can we remove the cast and allow dict arrays to have nullable codes and values
24        let codes = take(array.codes(), indices)?;
25        let values_dtype = array
26            .values()
27            .dtype()
28            .union_nullability(codes.dtype().nullability());
29        DictArray::try_new(codes, cast(array.values(), &values_dtype)?).map(|a| a.into_array())
30    }
31}
32
33register_kernel!(TakeKernelAdapter(DictVTable).lift());
34
35impl FilterKernel for DictVTable {
36    fn filter(&self, array: &DictArray, mask: &Mask) -> VortexResult<ArrayRef> {
37        let codes = filter(array.codes(), mask)?;
38        DictArray::try_new(codes, array.values().clone()).map(|a| a.into_array())
39    }
40}
41
42register_kernel!(FilterKernelAdapter(DictVTable).lift());
43
44#[cfg(test)]
45mod test {
46    use vortex_array::accessor::ArrayAccessor;
47    use vortex_array::arrays::{ConstantArray, PrimitiveArray, VarBinArray, VarBinViewArray};
48    use vortex_array::compute::conformance::mask::test_mask;
49    use vortex_array::compute::{Operator, compare, take};
50    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
51    use vortex_dtype::PType::I32;
52    use vortex_dtype::{DType, Nullability};
53    use vortex_scalar::Scalar;
54
55    use crate::builders::dict_encode;
56
57    #[test]
58    fn canonicalise_nullable_primitive() {
59        let values: Vec<Option<i32>> = (0..65)
60            .map(|i| match i % 3 {
61                0 => Some(42),
62                1 => Some(-9),
63                2 => None,
64                _ => unreachable!(),
65            })
66            .collect();
67
68        let dict = dict_encode(PrimitiveArray::from_option_iter(values.clone()).as_ref()).unwrap();
69        let actual = dict.to_primitive().unwrap();
70
71        let expected: Vec<i32> = (0..65)
72            .map(|i| match i % 3 {
73                // Compressor puts 0 as a code for invalid values which we end up using in take
74                // thus invalid values on decompression turn into whatever is at 0th position in dictionary
75                0 | 2 => 42,
76                1 => -9,
77                _ => unreachable!(),
78            })
79            .collect();
80
81        assert_eq!(actual.as_slice::<i32>(), expected.as_slice());
82
83        let expected_valid_count = values.iter().filter(|x| x.is_some()).count();
84        assert_eq!(
85            actual.validity_mask().unwrap().true_count(),
86            expected_valid_count
87        );
88    }
89
90    #[test]
91    fn canonicalise_non_nullable_primitive_32_unique_values() {
92        let unique_values: Vec<i32> = (0..32).collect();
93        let expected: Vec<i32> = (0..1000).map(|i| unique_values[i % 32]).collect();
94
95        let dict =
96            dict_encode(PrimitiveArray::from_iter(expected.iter().copied()).as_ref()).unwrap();
97        let actual = dict.to_primitive().unwrap();
98
99        assert_eq!(actual.as_slice::<i32>(), expected.as_slice());
100    }
101
102    #[test]
103    fn canonicalise_non_nullable_primitive_100_unique_values() {
104        let unique_values: Vec<i32> = (0..100).collect();
105        let expected: Vec<i32> = (0..1000).map(|i| unique_values[i % 100]).collect();
106
107        let dict =
108            dict_encode(PrimitiveArray::from_iter(expected.iter().copied()).as_ref()).unwrap();
109        let actual = dict.to_primitive().unwrap();
110
111        assert_eq!(actual.as_slice::<i32>(), expected.as_slice());
112    }
113
114    #[test]
115    fn canonicalise_nullable_varbin() {
116        let reference = VarBinViewArray::from_iter(
117            vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")],
118            DType::Utf8(Nullability::Nullable),
119        );
120        assert_eq!(reference.len(), 6);
121        let dict = dict_encode(reference.as_ref()).unwrap();
122        let flattened_dict = dict.to_varbinview().unwrap();
123        assert_eq!(
124            flattened_dict
125                .with_iterator(|iter| iter
126                    .map(|slice| slice.map(|s| s.to_vec()))
127                    .collect::<Vec<_>>())
128                .unwrap(),
129            reference
130                .with_iterator(|iter| iter
131                    .map(|slice| slice.map(|s| s.to_vec()))
132                    .collect::<Vec<_>>())
133                .unwrap(),
134        );
135    }
136
137    fn sliced_dict_array() -> ArrayRef {
138        let reference = PrimitiveArray::from_option_iter([
139            Some(42),
140            Some(-9),
141            None,
142            Some(42),
143            Some(1),
144            Some(5),
145        ]);
146        let dict = dict_encode(reference.as_ref()).unwrap();
147        dict.slice(1, 4).unwrap()
148    }
149
150    #[test]
151    fn compare_sliced_dict() {
152        let sliced = sliced_dict_array();
153        let compared = compare(&sliced, ConstantArray::new(42, 3).as_ref(), Operator::Eq).unwrap();
154
155        assert_eq!(
156            compared.scalar_at(0).unwrap(),
157            Scalar::bool(false, Nullability::Nullable)
158        );
159        assert_eq!(
160            compared.scalar_at(1).unwrap(),
161            Scalar::null(DType::Bool(Nullability::Nullable))
162        );
163        assert_eq!(
164            compared.scalar_at(2).unwrap(),
165            Scalar::bool(true, Nullability::Nullable)
166        );
167    }
168
169    #[test]
170    fn test_mask_dict_array() {
171        let array = dict_encode(&PrimitiveArray::from_iter([2, 0, 2, 0, 10]).into_array()).unwrap();
172        test_mask(array.as_ref());
173
174        let array = dict_encode(
175            PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)]).as_ref(),
176        )
177        .unwrap();
178        test_mask(array.as_ref());
179
180        let array = dict_encode(
181            &VarBinArray::from_iter(
182                [
183                    Some("hello"),
184                    None,
185                    Some("hello"),
186                    Some("good"),
187                    Some("good"),
188                ],
189                DType::Utf8(Nullability::Nullable),
190            )
191            .into_array(),
192        )
193        .unwrap();
194        test_mask(array.as_ref());
195    }
196
197    #[test]
198    fn test_take_dict() {
199        let array = dict_encode(PrimitiveArray::from_iter([1, 2]).as_ref()).unwrap();
200
201        assert_eq!(
202            take(
203                array.as_ref(),
204                PrimitiveArray::from_option_iter([Option::<i32>::None]).as_ref()
205            )
206            .unwrap()
207            .dtype(),
208            &DType::Primitive(I32, Nullability::Nullable)
209        );
210    }
211}