vortex_array/arrays/dict/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::Hash;
6
7use vortex_buffer::BitBuffer;
8use vortex_dtype::{DType, match_each_integer_ptype};
9use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
10use vortex_mask::{AllOr, Mask};
11
12use crate::stats::{ArrayStats, StatsSetRef};
13use crate::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable};
14use crate::{
15    Array, ArrayEq, ArrayHash, ArrayRef, EncodingId, EncodingRef, Precision, ToCanonical, vtable,
16};
17
18vtable!(Dict);
19
20impl VTable for DictVTable {
21    type Array = DictArray;
22    type Encoding = DictEncoding;
23
24    type ArrayVTable = Self;
25    type CanonicalVTable = Self;
26    type OperationsVTable = Self;
27    type ValidityVTable = Self;
28    type VisitorVTable = Self;
29    type ComputeVTable = NotSupported;
30    type EncodeVTable = Self;
31    type SerdeVTable = Self;
32    type OperatorVTable = NotSupported;
33
34    fn id(_encoding: &Self::Encoding) -> EncodingId {
35        EncodingId::new_ref("vortex.dict")
36    }
37
38    fn encoding(_array: &Self::Array) -> EncodingRef {
39        EncodingRef::new_ref(DictEncoding.as_ref())
40    }
41}
42
43#[derive(Debug, Clone)]
44pub struct DictArray {
45    codes: ArrayRef,
46    values: ArrayRef,
47    stats_set: ArrayStats,
48    dtype: DType,
49}
50
/// Unit type standing for the dictionary encoding, identified as `"vortex.dict"`.
#[derive(Debug, Clone)]
pub struct DictEncoding;
53
54impl DictArray {
55    /// Build a new `DictArray` without validating the codes or values.
56    ///
57    /// # Safety
58    /// This should be called only when you can guarantee the invariants checked
59    /// by the safe [`DictArray::try_new`] constructor are valid, for example when
60    /// you are filtering or slicing an existing valid `DictArray`.
61    pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
62        let dtype = values
63            .dtype()
64            .union_nullability(codes.dtype().nullability());
65        Self {
66            codes,
67            values,
68            stats_set: Default::default(),
69            dtype,
70        }
71    }
72
73    /// Build a new `DictArray` from its components, `codes` and `values`.
74    ///
75    /// This constructor will panic if `codes` or `values` do not pass validation for building
76    /// a new `DictArray`. See [`DictArray::try_new`] for a description of the error conditions.
77    pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
78        Self::try_new(codes, values).vortex_expect("DictArray new")
79    }
80
81    /// Build a new `DictArray` from its components, `codes` and `values`.
82    ///
83    /// The codes must be unsigned integers, and may be nullable. Values can be any type, and
84    /// may also be nullable. This mirrors the nullability of the Arrow `DictionaryArray`.
85    ///
86    /// # Errors
87    ///
88    /// The `codes` **must** be unsigned integers, and the maximum code must be less than the length
89    /// of the `values` array. Otherwise, this constructor returns an error.
90    ///
91    /// It is an error to provide a nullable `codes` with non-nullable `values`.
92    pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
93        if !codes.dtype().is_unsigned_int() {
94            vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
95        }
96
97        Ok(unsafe { Self::new_unchecked(codes, values) })
98    }
99
100    #[inline]
101    pub fn codes(&self) -> &ArrayRef {
102        &self.codes
103    }
104
105    #[inline]
106    pub fn values(&self) -> &ArrayRef {
107        &self.values
108    }
109}
110
111impl ArrayVTable<DictVTable> for DictVTable {
112    fn len(array: &DictArray) -> usize {
113        array.codes.len()
114    }
115
116    fn dtype(array: &DictArray) -> &DType {
117        &array.dtype
118    }
119
120    fn stats(array: &DictArray) -> StatsSetRef<'_> {
121        array.stats_set.to_ref(array.as_ref())
122    }
123
124    fn array_hash<H: std::hash::Hasher>(array: &DictArray, state: &mut H, precision: Precision) {
125        array.dtype.hash(state);
126        array.codes.array_hash(state, precision);
127        array.values.array_hash(state, precision);
128    }
129
130    fn array_eq(array: &DictArray, other: &DictArray, precision: Precision) -> bool {
131        array.dtype == other.dtype
132            && array.codes.array_eq(&other.codes, precision)
133            && array.values.array_eq(&other.values, precision)
134    }
135}
136
137impl ValidityVTable<DictVTable> for DictVTable {
138    fn is_valid(array: &DictArray, index: usize) -> bool {
139        let scalar = array.codes().scalar_at(index);
140
141        if scalar.is_null() {
142            return false;
143        };
144        let values_index: usize = scalar
145            .as_ref()
146            .try_into()
147            .vortex_expect("Failed to convert dictionary code to usize");
148        array.values().is_valid(values_index)
149    }
150
151    fn all_valid(array: &DictArray) -> bool {
152        array.codes().all_valid() && array.values().all_valid()
153    }
154
155    fn all_invalid(array: &DictArray) -> bool {
156        array.codes().all_invalid() || array.values().all_invalid()
157    }
158
159    fn validity_mask(array: &DictArray) -> Mask {
160        let codes_validity = array.codes().validity_mask();
161        match codes_validity.bit_buffer() {
162            AllOr::All => {
163                let primitive_codes = array.codes().to_primitive();
164                let values_mask = array.values().validity_mask();
165                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
166                    let codes_slice = primitive_codes.as_slice::<P>();
167                    BitBuffer::collect_bool(array.len(), |idx| {
168                        #[allow(clippy::cast_possible_truncation)]
169                        values_mask.value(codes_slice[idx] as usize)
170                    })
171                });
172                Mask::from_buffer(is_valid_buffer)
173            }
174            AllOr::None => Mask::AllFalse(array.len()),
175            AllOr::Some(validity_buff) => {
176                let primitive_codes = array.codes().to_primitive();
177                let values_mask = array.values().validity_mask();
178                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
179                    let codes_slice = primitive_codes.as_slice::<P>();
180                    #[allow(clippy::cast_possible_truncation)]
181                    BitBuffer::collect_bool(array.len(), |idx| {
182                        validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
183                    })
184                });
185                Mask::from_buffer(is_valid_buffer)
186            }
187        }
188    }
189}
190
191#[cfg(test)]
192mod test {
193    #[allow(unused_imports)]
194    use itertools::Itertools;
195    use rand::distr::{Distribution, StandardUniform};
196    use rand::prelude::StdRng;
197    use rand::{Rng, SeedableRng};
198    use vortex_buffer::{BitBuffer, buffer};
199    use vortex_dtype::Nullability::NonNullable;
200    use vortex_dtype::{DType, NativePType, PType, UnsignedPType};
201    use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
202    use vortex_mask::AllOr;
203
204    use crate::arrays::dict::DictArray;
205    use crate::arrays::{ChunkedArray, PrimitiveArray};
206    use crate::builders::builder_with_capacity;
207    use crate::validity::Validity;
208    use crate::{Array, ArrayRef, IntoArray, ToCanonical, assert_arrays_eq};
209
210    #[test]
211    fn nullable_codes_validity() {
212        let dict = DictArray::try_new(
213            PrimitiveArray::new(
214                buffer![0u32, 1, 2, 2, 1],
215                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
216            )
217            .into_array(),
218            PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array(),
219        )
220        .unwrap();
221        let mask = dict.validity_mask();
222        let AllOr::Some(indices) = mask.indices() else {
223            vortex_panic!("Expected indices from mask")
224        };
225        assert_eq!(indices, [0, 2, 4]);
226    }
227
228    #[test]
229    fn nullable_values_validity() {
230        let dict = DictArray::try_new(
231            buffer![0u32, 1, 2, 2, 1].into_array(),
232            PrimitiveArray::new(
233                buffer![3, 6, 9],
234                Validity::from(BitBuffer::from(vec![true, false, false])),
235            )
236            .into_array(),
237        )
238        .unwrap();
239        let mask = dict.validity_mask();
240        let AllOr::Some(indices) = mask.indices() else {
241            vortex_panic!("Expected indices from mask")
242        };
243        assert_eq!(indices, [0]);
244    }
245
246    #[test]
247    fn nullable_codes_and_values() {
248        let dict = DictArray::try_new(
249            PrimitiveArray::new(
250                buffer![0u32, 1, 2, 2, 1],
251                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
252            )
253            .into_array(),
254            PrimitiveArray::new(
255                buffer![3, 6, 9],
256                Validity::from(BitBuffer::from(vec![false, true, true])),
257            )
258            .into_array(),
259        )
260        .unwrap();
261        let mask = dict.validity_mask();
262        let AllOr::Some(indices) = mask.indices() else {
263            vortex_panic!("Expected indices from mask")
264        };
265        assert_eq!(indices, [2, 4]);
266    }
267
268    #[test]
269    fn nullable_codes_and_non_null_values() {
270        let dict = DictArray::try_new(
271            PrimitiveArray::new(
272                buffer![0u32, 1, 2, 2, 1],
273                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
274            )
275            .into_array(),
276            PrimitiveArray::new(buffer![3, 6, 9], Validity::NonNullable).into_array(),
277        )
278        .unwrap();
279        let mask = dict.validity_mask();
280        let AllOr::Some(indices) = mask.indices() else {
281            vortex_panic!("Expected indices from mask")
282        };
283        assert_eq!(indices, [0, 2, 4]);
284    }
285
286    fn make_dict_primitive_chunks<T: NativePType, Code: UnsignedPType>(
287        len: usize,
288        unique_values: usize,
289        chunk_count: usize,
290    ) -> ArrayRef
291    where
292        StandardUniform: Distribution<T>,
293    {
294        let mut rng = StdRng::seed_from_u64(0);
295
296        (0..chunk_count)
297            .map(|_| {
298                let values = (0..unique_values)
299                    .map(|_| rng.random::<T>())
300                    .collect::<PrimitiveArray>();
301                let codes = (0..len)
302                    .map(|_| {
303                        Code::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
304                    })
305                    .collect::<PrimitiveArray>();
306
307                DictArray::try_new(codes.into_array(), values.into_array())
308                    .vortex_unwrap()
309                    .into_array()
310            })
311            .collect::<ChunkedArray>()
312            .into_array()
313    }
314
315    #[test]
316    fn test_dict_array_from_primitive_chunks() {
317        let len = 2;
318        let chunk_count = 2;
319        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);
320
321        let mut builder = builder_with_capacity(
322            &DType::Primitive(PType::U64, NonNullable),
323            len * chunk_count,
324        );
325        array.clone().append_to_builder(builder.as_mut());
326
327        let into_prim = array.to_primitive();
328        let prim_into = builder.finish_into_canonical().into_primitive();
329
330        assert_arrays_eq!(into_prim, prim_into);
331    }
332}