vortex_dict/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_array::stats::{ArrayStats, StatsSetRef};
8use vortex_array::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable};
9use vortex_array::{Array, ArrayRef, EncodingId, EncodingRef, ToCanonical, vtable};
10use vortex_dtype::{DType, match_each_integer_ptype};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
12use vortex_mask::{AllOr, Mask};
13
14vtable!(Dict);
15
16impl VTable for DictVTable {
17    type Array = DictArray;
18    type Encoding = DictEncoding;
19
20    type ArrayVTable = Self;
21    type CanonicalVTable = Self;
22    type OperationsVTable = Self;
23    type ValidityVTable = Self;
24    type VisitorVTable = Self;
25    type ComputeVTable = NotSupported;
26    type EncodeVTable = Self;
27    type SerdeVTable = Self;
28    type PipelineVTable = NotSupported;
29
30    fn id(_encoding: &Self::Encoding) -> EncodingId {
31        EncodingId::new_ref("vortex.dict")
32    }
33
34    fn encoding(_array: &Self::Array) -> EncodingRef {
35        EncodingRef::new_ref(DictEncoding.as_ref())
36    }
37}
38
/// A dictionary-encoded array: `codes` select entries of `values` by position.
///
/// The array has one logical element per code.
#[derive(Debug, Clone)]
pub struct DictArray {
    /// One code per logical element; a non-null code is used as an index into `values`.
    codes: ArrayRef,
    /// The dictionary entries that the codes refer to.
    values: ArrayRef,
    /// Statistics storage for this array; starts out default-initialised.
    stats_set: ArrayStats,
    /// Logical dtype, computed at construction as the dtype of `values` with its nullability
    /// unioned with the nullability of `codes`.
    dtype: DType,
}
46
/// Stateless marker type for the dictionary encoding.
///
/// It is the `Encoding` associated type of `DictVTable`, and `VTable::encoding` returns a
/// reference to it for every [`DictArray`].
#[derive(Clone, Debug)]
pub struct DictEncoding;
49
50impl DictArray {
51    /// Build a new `DictArray` without validating the codes or values.
52    ///
53    /// # Safety
54    /// This should be called only when you can guarantee the invariants checked
55    /// by the safe [`DictArray::try_new`] constructor are valid, for example when
56    /// you are filtering or slicing an existing valid `DictArray`.
57    pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
58        let dtype = values
59            .dtype()
60            .union_nullability(codes.dtype().nullability());
61        Self {
62            codes,
63            values,
64            stats_set: Default::default(),
65            dtype,
66        }
67    }
68
69    /// Build a new `DictArray` from its components, `codes` and `values`.
70    ///
71    /// This constructor will panic if `codes` or `values` do not pass validation for building
72    /// a new `DictArray`. See [`DictArray::try_new`] for a description of the error conditions.
73    pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
74        Self::try_new(codes, values).vortex_expect("DictArray new")
75    }
76
77    /// Build a new `DictArray` from its components, `codes` and `values`.
78    ///
79    /// The codes must be unsigned integers, and may be nullable. Values can be any type, and
80    /// may also be nullable. This mirrors the nullability of the Arrow `DictionaryArray`.
81    ///
82    /// # Errors
83    ///
84    /// The `codes` **must** be unsigned integers, and the maximum code must be less than the length
85    /// of the `values` array. Otherwise, this constructor returns an error.
86    ///
87    /// It is an error to provide a nullable `codes` with non-nullable `values`.
88    pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
89        if !codes.dtype().is_unsigned_int() {
90            vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
91        }
92
93        Ok(unsafe { Self::new_unchecked(codes, values) })
94    }
95
96    #[inline]
97    pub fn codes(&self) -> &ArrayRef {
98        &self.codes
99    }
100
101    #[inline]
102    pub fn values(&self) -> &ArrayRef {
103        &self.values
104    }
105}
106
107impl ArrayVTable<DictVTable> for DictVTable {
108    fn len(array: &DictArray) -> usize {
109        array.codes.len()
110    }
111
112    fn dtype(array: &DictArray) -> &DType {
113        &array.dtype
114    }
115
116    fn stats(array: &DictArray) -> StatsSetRef<'_> {
117        array.stats_set.to_ref(array.as_ref())
118    }
119}
120
121impl ValidityVTable<DictVTable> for DictVTable {
122    fn is_valid(array: &DictArray, index: usize) -> bool {
123        let scalar = array.codes().scalar_at(index);
124
125        if scalar.is_null() {
126            return false;
127        };
128        let values_index: usize = scalar
129            .as_ref()
130            .try_into()
131            .vortex_expect("Failed to convert dictionary code to usize");
132        array.values().is_valid(values_index)
133    }
134
135    fn all_valid(array: &DictArray) -> bool {
136        array.codes().all_valid() && array.values().all_valid()
137    }
138
139    fn all_invalid(array: &DictArray) -> bool {
140        array.codes().all_invalid() || array.values().all_invalid()
141    }
142
143    fn validity_mask(array: &DictArray) -> Mask {
144        let codes_validity = array.codes().validity_mask();
145        match codes_validity.boolean_buffer() {
146            AllOr::All => {
147                let primitive_codes = array.codes().to_primitive();
148                let values_mask = array.values().validity_mask();
149                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
150                    let codes_slice = primitive_codes.as_slice::<P>();
151                    BooleanBuffer::collect_bool(array.len(), |idx| {
152                        #[allow(clippy::cast_possible_truncation)]
153                        values_mask.value(codes_slice[idx] as usize)
154                    })
155                });
156                Mask::from_buffer(is_valid_buffer)
157            }
158            AllOr::None => Mask::AllFalse(array.len()),
159            AllOr::Some(validity_buff) => {
160                let primitive_codes = array.codes().to_primitive();
161                let values_mask = array.values().validity_mask();
162                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
163                    let codes_slice = primitive_codes.as_slice::<P>();
164                    #[allow(clippy::cast_possible_truncation)]
165                    BooleanBuffer::collect_bool(array.len(), |idx| {
166                        validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
167                    })
168                });
169                Mask::from_buffer(is_valid_buffer)
170            }
171        }
172    }
173}
174
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rand::distr::{Distribution, StandardUniform};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};
    use vortex_array::arrays::{ChunkedArray, PrimitiveArray};
    use vortex_array::builders::builder_with_capacity;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, NativePType, PType, UnsignedPType};
    use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
    use vortex_mask::AllOr;

    use crate::DictArray;

    /// Codes `[0, 1, 2, 2, 1]` where positions 1 and 3 are null.
    fn nullable_codes() -> ArrayRef {
        PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
        )
        .into_array()
    }

    /// Dictionary values `[3, 6, 9]` carrying the given validity.
    fn values_with(validity: Validity) -> ArrayRef {
        PrimitiveArray::new(buffer![3, 6, 9], validity).into_array()
    }

    /// Asserts that the validity mask of `dict` is a partial mask whose set positions are
    /// exactly `expected`.
    #[track_caller]
    fn assert_valid_indices<const N: usize>(dict: &DictArray, expected: [usize; N]) {
        let mask = dict.validity_mask();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, expected);
    }

    /// Null codes make the corresponding elements invalid even when all values are valid.
    #[test]
    fn nullable_codes_validity() {
        let dict = DictArray::try_new(nullable_codes(), values_with(Validity::AllValid)).unwrap();
        assert_valid_indices(&dict, [0, 2, 4]);
    }

    /// Non-null codes that select a null value are invalid.
    #[test]
    fn nullable_values_validity() {
        let codes = buffer![0u32, 1, 2, 2, 1].into_array();
        let values = values_with(Validity::from(BooleanBuffer::from(vec![
            true, false, false,
        ])));
        let dict = DictArray::try_new(codes, values).unwrap();
        assert_valid_indices(&dict, [0]);
    }

    /// An element is valid only when both its code and the selected value are valid.
    #[test]
    fn nullable_codes_and_values() {
        let values = values_with(Validity::from(BooleanBuffer::from(vec![false, true, true])));
        let dict = DictArray::try_new(nullable_codes(), values).unwrap();
        assert_valid_indices(&dict, [2, 4]);
    }

    /// Nullable codes may be combined with non-nullable values.
    #[test]
    fn nullable_codes_and_non_null_values() {
        let dict =
            DictArray::try_new(nullable_codes(), values_with(Validity::NonNullable)).unwrap();
        assert_valid_indices(&dict, [0, 2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dictionary arrays, each with `unique_values`
    /// random values and `len` random codes, drawn from a generator with a fixed seed.
    fn make_dict_primitive_chunks<T: NativePType, Code: UnsignedPType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);
        let mut chunks = Vec::with_capacity(chunk_count);

        for _ in 0..chunk_count {
            // Values are drawn before codes so the random stream is consumed in a fixed order.
            let values: PrimitiveArray = (0..unique_values).map(|_| rng.random::<T>()).collect();
            let codes: PrimitiveArray = (0..len)
                .map(|_| {
                    Code::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                })
                .collect();

            let dict = DictArray::try_new(codes.into_array(), values.into_array()).vortex_unwrap();
            chunks.push(dict.into_array());
        }

        chunks.into_iter().collect::<ChunkedArray>().into_array()
    }

    /// Appending the chunked dictionary array to a builder yields the same primitive data
    /// and validity as canonicalising it directly.
    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let (len, chunk_count) = (2, 2);
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let target_dtype = DType::Primitive(PType::U64, NonNullable);
        let mut builder = builder_with_capacity(&target_dtype, len * chunk_count);
        array.clone().append_to_builder(builder.as_mut());

        let canonical = array.to_primitive();
        let built = builder.finish_into_canonical().into_primitive();

        assert_eq!(canonical.as_slice::<u64>(), built.as_slice::<u64>());
        assert_eq!(
            canonical.validity_mask().boolean_buffer(),
            built.validity_mask().boolean_buffer()
        );
    }
}
319}