vortex_dict/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_array::compute::take;
8use vortex_array::stats::{ArrayStats, StatsSetRef};
9use vortex_array::vtable::{ArrayVTable, CanonicalVTable, NotSupported, VTable, ValidityVTable};
10use vortex_array::{
11    Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable,
12};
13use vortex_dtype::{DType, match_each_integer_ptype};
14use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
15use vortex_mask::{AllOr, Mask};
16
17vtable!(Dict);
18
19impl VTable for DictVTable {
20    type Array = DictArray;
21    type Encoding = DictEncoding;
22
23    type ArrayVTable = Self;
24    type CanonicalVTable = Self;
25    type OperationsVTable = Self;
26    type ValidityVTable = Self;
27    type VisitorVTable = Self;
28    type ComputeVTable = NotSupported;
29    type EncodeVTable = Self;
30    type SerdeVTable = Self;
31    type PipelineVTable = NotSupported;
32
33    fn id(_encoding: &Self::Encoding) -> EncodingId {
34        EncodingId::new_ref("vortex.dict")
35    }
36
37    fn encoding(_array: &Self::Array) -> EncodingRef {
38        EncodingRef::new_ref(DictEncoding.as_ref())
39    }
40}
41
42#[derive(Debug, Clone)]
43pub struct DictArray {
44    codes: ArrayRef,
45    values: ArrayRef,
46    stats_set: ArrayStats,
47}
48
49#[derive(Clone, Debug)]
50pub struct DictEncoding;
51
52impl DictArray {
53    /// Build a new `DictArray` without validating the codes or values.
54    ///
55    /// # Safety
56    /// This should be called only when you can guarantee the invariants checked
57    /// by the safe [`DictArray::try_new`] constructor are valid, for example when
58    /// you are filtering or slicing an existing valid `DictArray`.
59    pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
60        Self {
61            codes,
62            values,
63            stats_set: Default::default(),
64        }
65    }
66
67    /// Build a new `DictArray` from its components, `codes` and `values`.
68    ///
69    /// This constructor will panic if `codes` or `values` do not pass validation for building
70    /// a new `DictArray`. See [`DictArray::try_new`] for a description of the error conditions.
71    pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
72        Self::try_new(codes, values).vortex_expect("DictArray new")
73    }
74
75    /// Build a new `DictArray` from its components, `codes` and `values`.
76    ///
77    /// The codes must be unsigned integers, and may be nullable. Values can be any type, and
78    /// may also be nullable. This mirrors the nullability of the Arrow `DictionaryArray`.
79    ///
80    /// # Errors
81    ///
82    /// The `codes` **must** be unsigned integers, and the maximum code must be less than the length
83    /// of the `values` array. Otherwise, this constructor returns an error.
84    ///
85    /// It is an error to provide a nullable `codes` with non-nullable `values`.
86    pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
87        if !codes.dtype().is_unsigned_int() {
88            vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
89        }
90
91        Ok(Self {
92            codes,
93            values,
94            stats_set: Default::default(),
95        })
96    }
97
98    #[inline]
99    pub fn codes(&self) -> &ArrayRef {
100        &self.codes
101    }
102
103    #[inline]
104    pub fn values(&self) -> &ArrayRef {
105        &self.values
106    }
107}
108
109impl ArrayVTable<DictVTable> for DictVTable {
110    fn len(array: &DictArray) -> usize {
111        array.codes.len()
112    }
113
114    fn dtype(array: &DictArray) -> &DType {
115        array.values.dtype()
116    }
117
118    fn stats(array: &DictArray) -> StatsSetRef<'_> {
119        array.stats_set.to_ref(array.as_ref())
120    }
121}
122
123impl CanonicalVTable<DictVTable> for DictVTable {
124    fn canonicalize(array: &DictArray) -> Canonical {
125        match array.dtype() {
126            // NOTE: Utf8 and Binary will decompress into VarBinViewArray, which requires a full
127            // decompression to construct the views child array.
128            // For this case, it is *always* faster to decompress the values first and then create
129            // copies of the view pointers.
130            DType::Utf8(_) | DType::Binary(_) => {
131                let canonical_values: ArrayRef = array.values().to_canonical().into_array();
132                take(&canonical_values, array.codes())
133                    .vortex_expect("taking codes from dictionary values shouldn't fail")
134                    .to_canonical()
135            }
136            _ => take(array.values(), array.codes())
137                .vortex_expect("taking codes from dictionary values shouldn't fail")
138                .to_canonical(),
139        }
140    }
141}
142
143impl ValidityVTable<DictVTable> for DictVTable {
144    fn is_valid(array: &DictArray, index: usize) -> bool {
145        let scalar = array.codes().scalar_at(index);
146
147        if scalar.is_null() {
148            return false;
149        };
150        let values_index: usize = scalar
151            .as_ref()
152            .try_into()
153            .vortex_expect("Failed to convert dictionary code to usize");
154        array.values().is_valid(values_index)
155    }
156
157    fn all_valid(array: &DictArray) -> bool {
158        array.codes().all_valid() && array.values().all_valid()
159    }
160
161    fn all_invalid(array: &DictArray) -> bool {
162        array.codes().all_invalid() || array.values().all_invalid()
163    }
164
165    fn validity_mask(array: &DictArray) -> Mask {
166        let codes_validity = array.codes().validity_mask();
167        match codes_validity.boolean_buffer() {
168            AllOr::All => {
169                let primitive_codes = array.codes().to_primitive();
170                let values_mask = array.values().validity_mask();
171                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
172                    let codes_slice = primitive_codes.as_slice::<P>();
173                    BooleanBuffer::collect_bool(array.len(), |idx| {
174                        #[allow(clippy::cast_possible_truncation)]
175                        values_mask.value(codes_slice[idx] as usize)
176                    })
177                });
178                Mask::from_buffer(is_valid_buffer)
179            }
180            AllOr::None => Mask::AllFalse(array.len()),
181            AllOr::Some(validity_buff) => {
182                let primitive_codes = array.codes().to_primitive();
183                let values_mask = array.values().validity_mask();
184                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
185                    let codes_slice = primitive_codes.as_slice::<P>();
186                    #[allow(clippy::cast_possible_truncation)]
187                    BooleanBuffer::collect_bool(array.len(), |idx| {
188                        validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
189                    })
190                });
191                Mask::from_buffer(is_valid_buffer)
192            }
193        }
194    }
195}
196
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rand::distr::{Distribution, StandardUniform};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};
    use vortex_array::arrays::{ChunkedArray, PrimitiveArray};
    use vortex_array::builders::builder_with_capacity;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, NativePType, PType};
    use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
    use vortex_mask::AllOr;

    use crate::DictArray;

    /// Codes `[0, 1, 2, 2, 1]` where positions 1 and 3 are null.
    fn masked_codes() -> ArrayRef {
        PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
        )
        .into_array()
    }

    /// Dictionary values `[3, 6, 9]` with the given validity.
    fn dict_values(validity: Validity) -> ArrayRef {
        PrimitiveArray::new(buffer![3, 6, 9], validity).into_array()
    }

    /// Asserts that the valid positions of `dict` are exactly `expected`.
    fn assert_valid_indices(dict: &DictArray, expected: &[usize]) {
        let mask = dict.validity_mask();
        match mask.indices() {
            AllOr::Some(indices) => assert_eq!(indices, expected),
            _ => vortex_panic!("Expected indices from mask"),
        }
    }

    #[test]
    fn nullable_codes_validity() {
        let values = dict_values(Validity::AllValid);
        let dict = DictArray::try_new(masked_codes(), values).unwrap();
        assert_valid_indices(&dict, &[0, 2, 4]);
    }

    #[test]
    fn nullable_values_validity() {
        let codes = buffer![0u32, 1, 2, 2, 1].into_array();
        let values = dict_values(Validity::from(BooleanBuffer::from(vec![
            true, false, false,
        ])));
        let dict = DictArray::try_new(codes, values).unwrap();
        assert_valid_indices(&dict, &[0]);
    }

    #[test]
    fn nullable_codes_and_values() {
        let values = dict_values(Validity::from(BooleanBuffer::from(vec![
            false, true, true,
        ])));
        let dict = DictArray::try_new(masked_codes(), values).unwrap();
        assert_valid_indices(&dict, &[2, 4]);
    }

    #[test]
    fn nullable_codes_and_non_null_values() {
        let values = dict_values(Validity::NonNullable);
        let dict = DictArray::try_new(masked_codes(), values).unwrap();
        assert_valid_indices(&dict, &[0, 2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dictionary chunks. Each chunk has
    /// `unique_values` random values of type `T` and `len` random codes of type `U`, all
    /// drawn from a fixed-seed RNG.
    fn make_dict_primitive_chunks<T: NativePType, U: NativePType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);
        let mut chunks: Vec<ArrayRef> = Vec::with_capacity(chunk_count);

        for _ in 0..chunk_count {
            // Values are drawn before codes so the random sequence stays the same.
            let values: PrimitiveArray = (0..unique_values)
                .map(|_| rng.random::<T>())
                .collect();
            let codes: PrimitiveArray = (0..len)
                .map(|_| U::from(rng.random_range(0..unique_values)).vortex_expect("valid value"))
                .collect();

            let dict = DictArray::try_new(codes.into_array(), values.into_array()).vortex_unwrap();
            chunks.push(dict.into_array());
        }

        chunks.into_iter().collect::<ChunkedArray>().into_array()
    }

    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let len = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let expected_dtype = DType::Primitive(PType::U64, NonNullable);
        let mut builder = builder_with_capacity(&expected_dtype, len * chunk_count);
        array.clone().append_to_builder(builder.as_mut());

        // Canonicalizing directly and going through a builder must agree.
        let direct = array.to_primitive();
        let via_builder = builder.finish_into_canonical().into_primitive();

        assert_eq!(direct.as_slice::<u64>(), via_builder.as_slice::<u64>());
        assert_eq!(
            direct.validity_mask().boolean_buffer(),
            via_builder.validity_mask().boolean_buffer()
        )
    }
}