vortex_dict/compute/mod.rs

1mod binary_numeric;
2mod compare;
3mod fill_null;
4mod is_constant;
5mod is_sorted;
6mod like;
7mod min_max;
8
9use vortex_array::compute::{
10    FilterKernel, FilterKernelAdapter, TakeKernel, TakeKernelAdapter, cast, filter, take,
11};
12use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
13use vortex_error::VortexResult;
14use vortex_mask::Mask;
15
16use crate::{DictArray, DictVTable};
17
18impl TakeKernel for DictVTable {
19    fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
20        // TODO(joe): can we remove the cast and allow dict arrays to have nullable codes and values
21        let codes = take(array.codes(), indices)?;
22        let values_dtype = array
23            .values()
24            .dtype()
25            .union_nullability(codes.dtype().nullability());
26        DictArray::try_new(codes, cast(array.values(), &values_dtype)?).map(|a| a.into_array())
27    }
28}
29
30register_kernel!(TakeKernelAdapter(DictVTable).lift());
31
32impl FilterKernel for DictVTable {
33    fn filter(&self, array: &DictArray, mask: &Mask) -> VortexResult<ArrayRef> {
34        let codes = filter(array.codes(), mask)?;
35        DictArray::try_new(codes, array.values().clone()).map(|a| a.into_array())
36    }
37}
38
39register_kernel!(FilterKernelAdapter(DictVTable).lift());
40
#[cfg(test)]
mod test {
    use vortex_array::accessor::ArrayAccessor;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray, VarBinArray, VarBinViewArray};
    use vortex_array::compute::conformance::mask::test_mask;
    use vortex_array::compute::{Operator, compare, take};
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, Nullability};
    use vortex_scalar::Scalar;

    use crate::builders::dict_encode;

    /// Dictionary-encodes `len` non-nullable values cycling through `0..unique` and asserts
    /// that canonicalising the encoded array reproduces the input exactly.
    fn assert_cyclic_primitive_roundtrip(unique: i32, len: usize) {
        let input: Vec<i32> = (0..unique).cycle().take(len).collect();

        let encoded =
            dict_encode(PrimitiveArray::from_iter(input.iter().copied()).as_ref()).unwrap();
        let decoded = encoded.to_primitive().unwrap();

        assert_eq!(decoded.as_slice::<i32>(), input.as_slice());
    }

    #[test]
    fn canonicalise_nullable_primitive() {
        // Repeats Some(42), Some(-9), None across 65 elements.
        let pattern = [Some(42), Some(-9), None];
        let values: Vec<Option<i32>> = (0..65).map(|idx| pattern[idx % 3]).collect();

        let encoded =
            dict_encode(PrimitiveArray::from_option_iter(values.iter().copied()).as_ref()).unwrap();
        let decoded = encoded.to_primitive().unwrap();

        // The encoder uses code 0 for invalid slots, and that code is used during take, so
        // null positions decode to whatever sits at dictionary index 0 (42 for this input).
        let expected: Vec<i32> = values.iter().map(|v| v.unwrap_or(42)).collect();
        assert_eq!(decoded.as_slice::<i32>(), expected.as_slice());

        let non_null = values.iter().flatten().count();
        assert_eq!(decoded.validity_mask().unwrap().true_count(), non_null);
    }

    #[test]
    fn canonicalise_non_nullable_primitive_32_unique_values() {
        assert_cyclic_primitive_roundtrip(32, 1000);
    }

    #[test]
    fn canonicalise_non_nullable_primitive_100_unique_values() {
        assert_cyclic_primitive_roundtrip(100, 1000);
    }

    #[test]
    fn canonicalise_nullable_varbin() {
        let reference = VarBinViewArray::from_iter(
            vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")],
            DType::Utf8(Nullability::Nullable),
        );
        assert_eq!(reference.len(), 6);

        let decoded = dict_encode(reference.as_ref())
            .unwrap()
            .to_varbinview()
            .unwrap();

        let decoded_bytes = decoded
            .with_iterator(|it| it.map(|opt| opt.map(|b| b.to_vec())).collect::<Vec<_>>())
            .unwrap();
        let reference_bytes = reference
            .with_iterator(|it| it.map(|opt| opt.map(|b| b.to_vec())).collect::<Vec<_>>())
            .unwrap();
        assert_eq!(decoded_bytes, reference_bytes);
    }

    /// Dictionary-encodes a small nullable primitive array and returns the slice `[1, 4)`,
    /// i.e. the logical values `Some(-9), None, Some(42)`.
    fn sliced_dict_array() -> ArrayRef {
        let reference = PrimitiveArray::from_option_iter([
            Some(42),
            Some(-9),
            None,
            Some(42),
            Some(1),
            Some(5),
        ]);
        dict_encode(reference.as_ref()).unwrap().slice(1, 4).unwrap()
    }

    #[test]
    fn compare_sliced_dict() {
        let sliced = sliced_dict_array();
        let compared = compare(&sliced, ConstantArray::new(42, 3).as_ref(), Operator::Eq).unwrap();

        let expected = [
            Scalar::bool(false, Nullability::Nullable),
            Scalar::null(DType::Bool(Nullability::Nullable)),
            Scalar::bool(true, Nullability::Nullable),
        ];
        for (idx, want) in expected.into_iter().enumerate() {
            assert_eq!(compared.scalar_at(idx).unwrap(), want);
        }
    }

    #[test]
    fn test_mask_dict_array() {
        let encoded = [
            dict_encode(&PrimitiveArray::from_iter([2, 0, 2, 0, 10]).into_array()).unwrap(),
            dict_encode(
                PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)])
                    .as_ref(),
            )
            .unwrap(),
            dict_encode(
                &VarBinArray::from_iter(
                    [
                        Some("hello"),
                        None,
                        Some("hello"),
                        Some("good"),
                        Some("good"),
                    ],
                    DType::Utf8(Nullability::Nullable),
                )
                .into_array(),
            )
            .unwrap(),
        ];

        for dict in &encoded {
            test_mask(dict.as_ref());
        }
    }

    #[test]
    fn test_take_dict() {
        let dict = dict_encode(PrimitiveArray::from_iter([1, 2]).as_ref()).unwrap();
        let null_index = PrimitiveArray::from_option_iter([Option::<i32>::None]);

        // Taking with a null index must make the result nullable.
        let taken = take(dict.as_ref(), null_index.as_ref()).unwrap();
        assert_eq!(taken.dtype(), &DType::Primitive(I32, Nullability::Nullable));
    }
}