vortex_dict/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_array::compute::{cast, take};
8use vortex_array::stats::{ArrayStats, StatsSetRef};
9use vortex_array::vtable::{ArrayVTable, CanonicalVTable, NotSupported, VTable, ValidityVTable};
10use vortex_array::{
11    Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable,
12};
13use vortex_dtype::{DType, match_each_integer_ptype};
14use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_ensure};
15use vortex_mask::{AllOr, Mask};
16
17vtable!(Dict);
18
19impl VTable for DictVTable {
20    type Array = DictArray;
21    type Encoding = DictEncoding;
22
23    type ArrayVTable = Self;
24    type CanonicalVTable = Self;
25    type OperationsVTable = Self;
26    type ValidityVTable = Self;
27    type VisitorVTable = Self;
28    type ComputeVTable = NotSupported;
29    type EncodeVTable = Self;
30    type SerdeVTable = Self;
31    type PipelineVTable = NotSupported;
32
33    fn id(_encoding: &Self::Encoding) -> EncodingId {
34        EncodingId::new_ref("vortex.dict")
35    }
36
37    fn encoding(_array: &Self::Array) -> EncodingRef {
38        EncodingRef::new_ref(DictEncoding.as_ref())
39    }
40}
41
42#[derive(Debug, Clone)]
43pub struct DictArray {
44    codes: ArrayRef,
45    values: ArrayRef,
46    stats_set: ArrayStats,
47}
48
49#[derive(Clone, Debug)]
50pub struct DictEncoding;
51
52impl DictArray {
53    /// Build a new `DictArray` without validating the codes or values.
54    ///
55    /// # Safety
56    /// This should be called only when you can guarantee the invariants checked
57    /// by the safe [`DictArray::try_new`] constructor are valid, for example when
58    /// you are filtering or slicing an existing valid `DictArray`.
59    pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
60        Self {
61            codes,
62            values,
63            stats_set: Default::default(),
64        }
65    }
66
67    /// Build a new `DictArray` from its components, `codes` and `values`.
68    ///
69    /// This constructor will panic if `codes` or `values` do not pass validation for building
70    /// a new `DictArray`. See [`DictArray::try_new`] for a description of the error conditions.
71    pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
72        Self::try_new(codes, values).vortex_expect("DictArray new")
73    }
74
75    /// Build a new `DictArray` from its components, `codes` and `values`.
76    ///
77    /// The codes must be unsigned integers, and may be nullable. Values can be any type, and
78    /// may also be nullable. This mirrors the nullability of the Arrow `DictionaryArray`.
79    ///
80    /// # Errors
81    ///
82    /// The `codes` **must** be unsigned integers, and the maximum code must be less than the length
83    /// of the `values` array. Otherwise, this constructor returns an error.
84    ///
85    /// It is an error to provide a nullable `codes` with non-nullable `values`.
86    pub fn try_new(mut codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
87        if !codes.dtype().is_unsigned_int() {
88            vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
89        }
90
91        let dtype = values.dtype();
92        if dtype.is_nullable() {
93            // If the values are nullable, we force codes to be nullable as well.
94            codes = cast(&codes, &codes.dtype().as_nullable())?;
95        } else {
96            // If the values are non-nullable, we assert the codes are non-nullable as well.
97            vortex_ensure!(
98                !codes.dtype().is_nullable(),
99                "Cannot have nullable codes for non-nullable dict array"
100            );
101        }
102
103        vortex_ensure!(
104            codes.dtype().nullability() == values.dtype().nullability(),
105            "Mismatched nullability between codes and values"
106        );
107
108        Ok(Self {
109            codes,
110            values,
111            stats_set: Default::default(),
112        })
113    }
114
115    #[inline]
116    pub fn codes(&self) -> &ArrayRef {
117        &self.codes
118    }
119
120    #[inline]
121    pub fn values(&self) -> &ArrayRef {
122        &self.values
123    }
124}
125
126impl ArrayVTable<DictVTable> for DictVTable {
127    fn len(array: &DictArray) -> usize {
128        array.codes.len()
129    }
130
131    fn dtype(array: &DictArray) -> &DType {
132        array.values.dtype()
133    }
134
135    fn stats(array: &DictArray) -> StatsSetRef<'_> {
136        array.stats_set.to_ref(array.as_ref())
137    }
138}
139
140impl CanonicalVTable<DictVTable> for DictVTable {
141    fn canonicalize(array: &DictArray) -> VortexResult<Canonical> {
142        match array.dtype() {
143            // NOTE: Utf8 and Binary will decompress into VarBinViewArray, which requires a full
144            // decompression to construct the views child array.
145            // For this case, it is *always* faster to decompress the values first and then create
146            // copies of the view pointers.
147            DType::Utf8(_) | DType::Binary(_) => {
148                let canonical_values: ArrayRef = array.values().to_canonical()?.into_array();
149                take(&canonical_values, array.codes())?.to_canonical()
150            }
151            _ => take(array.values(), array.codes())?.to_canonical(),
152        }
153    }
154}
155
156impl ValidityVTable<DictVTable> for DictVTable {
157    fn is_valid(array: &DictArray, index: usize) -> bool {
158        let scalar = array.codes().scalar_at(index);
159
160        if scalar.is_null() {
161            return false;
162        };
163        let values_index: usize = scalar
164            .as_ref()
165            .try_into()
166            .vortex_expect("Failed to convert dictionary code to usize");
167        array.values().is_valid(values_index)
168    }
169
170    fn all_valid(array: &DictArray) -> bool {
171        array.codes().all_valid() && array.values().all_valid()
172    }
173
174    fn all_invalid(array: &DictArray) -> bool {
175        array.codes().all_invalid() || array.values().all_invalid()
176    }
177
178    fn validity_mask(array: &DictArray) -> Mask {
179        let codes_validity = array.codes().validity_mask();
180        match codes_validity.boolean_buffer() {
181            AllOr::All => {
182                let primitive_codes = array
183                    .codes()
184                    .to_primitive()
185                    .vortex_expect("dict codes must be primitive");
186                let values_mask = array.values().validity_mask();
187                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
188                    let codes_slice = primitive_codes.as_slice::<P>();
189                    BooleanBuffer::collect_bool(array.len(), |idx| {
190                        #[allow(clippy::cast_possible_truncation)]
191                        values_mask.value(codes_slice[idx] as usize)
192                    })
193                });
194                Mask::from_buffer(is_valid_buffer)
195            }
196            AllOr::None => Mask::AllFalse(array.len()),
197            AllOr::Some(validity_buff) => {
198                let primitive_codes = array
199                    .codes()
200                    .to_primitive()
201                    .vortex_expect("dict codes must be primitive");
202                let values_mask = array.values().validity_mask();
203                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
204                    let codes_slice = primitive_codes.as_slice::<P>();
205                    #[allow(clippy::cast_possible_truncation)]
206                    BooleanBuffer::collect_bool(array.len(), |idx| {
207                        validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
208                    })
209                });
210                Mask::from_buffer(is_valid_buffer)
211            }
212        }
213    }
214}
215
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rand::distr::{Distribution, StandardUniform};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};
    use vortex_array::arrays::{ChunkedArray, PrimitiveArray};
    use vortex_array::builders::builder_with_capacity;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, NativePType, PType};
    use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
    use vortex_mask::AllOr;

    use crate::DictArray;

    /// Null codes make their entries null even though every value is valid.
    #[test]
    fn nullable_codes_validity() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0, 2, 4]);
    }

    /// Entries referencing null values are null even though every code is valid.
    #[test]
    fn nullable_values_validity() {
        let dict = DictArray::try_new(
            buffer![0u32, 1, 2, 2, 1].into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BooleanBuffer::from(vec![true, false, false])),
            )
            .into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0]);
    }

    /// Validity combines null codes and null values.
    #[test]
    fn nullable_codes_and_values() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BooleanBuffer::from(vec![false, true, true])),
            )
            .into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [2, 4]);
    }

    /// If every value is null, every entry is null.
    #[test]
    fn all_null_values_validity() {
        let dict = DictArray::try_new(
            buffer![0u32, 1, 1].into_array(),
            PrimitiveArray::new(buffer![3, 6], Validity::AllInvalid).into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask();
        assert!((0..3).all(|idx| !mask.value(idx)));
        assert!((0..3).all(|idx| !dict.is_valid(idx)));
    }

    /// Nullable values cause non-nullable codes to be cast to nullable.
    #[test]
    fn nullable_values_make_codes_nullable() {
        let dict = DictArray::try_new(
            buffer![0u32, 1].into_array(),
            PrimitiveArray::new(buffer![3, 6], Validity::AllValid).into_array(),
        )
        .unwrap();
        assert!(dict.codes().dtype().is_nullable());
    }

    /// Codes must be unsigned integers.
    #[test]
    fn signed_codes_rejected() {
        let codes = buffer![0i32, 1].into_array();
        let values = buffer![10u8, 20].into_array();
        assert!(DictArray::try_new(codes, values).is_err());
    }

    /// Nullable codes cannot be paired with non-nullable values.
    #[test]
    fn nullable_codes_with_non_nullable_values_rejected() {
        let codes = PrimitiveArray::new(buffer![0u32, 1], Validity::AllValid).into_array();
        let values = buffer![10u8, 20].into_array();
        assert!(DictArray::try_new(codes, values).is_err());
    }

    /// Builds a chunked array of `chunk_count` random dictionary chunks, each with `len`
    /// entries drawn from `unique_values` random values. Seeded for determinism.
    fn make_dict_primitive_chunks<T: NativePType, U: NativePType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);

        (0..chunk_count)
            .map(|_| {
                let values = (0..unique_values)
                    .map(|_| rng.random::<T>())
                    .collect::<PrimitiveArray>();
                let codes = (0..len)
                    .map(|_| {
                        U::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                    })
                    .collect::<PrimitiveArray>();

                DictArray::try_new(codes.into_array(), values.into_array())
                    .vortex_unwrap()
                    .into_array()
            })
            .collect::<ChunkedArray>()
            .into_array()
    }

    /// Appending to a builder must produce the same data and validity as canonicalizing.
    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let len = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let mut builder = builder_with_capacity(
            &DType::Primitive(PType::U64, NonNullable),
            len * chunk_count,
        );
        array
            .clone()
            .append_to_builder(builder.as_mut())
            .vortex_unwrap();

        let into_prim = array.to_primitive().unwrap();
        let prim_into = builder.finish().to_primitive().unwrap();

        assert_eq!(into_prim.as_slice::<u64>(), prim_into.as_slice::<u64>());
        assert_eq!(
            into_prim.validity_mask().boolean_buffer(),
            prim_into.validity_mask().boolean_buffer()
        )
    }
}