vortex_dict/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_array::compute::{cast, take};
8use vortex_array::stats::{ArrayStats, StatsSetRef};
9use vortex_array::vtable::{ArrayVTable, CanonicalVTable, NotSupported, VTable, ValidityVTable};
10use vortex_array::{
11    Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable,
12};
13use vortex_dtype::{DType, match_each_integer_ptype};
14use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_ensure};
15use vortex_mask::{AllOr, Mask};
16
17vtable!(Dict);
18
19impl VTable for DictVTable {
20    type Array = DictArray;
21    type Encoding = DictEncoding;
22
23    type ArrayVTable = Self;
24    type CanonicalVTable = Self;
25    type OperationsVTable = Self;
26    type ValidityVTable = Self;
27    type VisitorVTable = Self;
28    type ComputeVTable = NotSupported;
29    type EncodeVTable = Self;
30    type SerdeVTable = Self;
31
32    fn id(_encoding: &Self::Encoding) -> EncodingId {
33        EncodingId::new_ref("vortex.dict")
34    }
35
36    fn encoding(_array: &Self::Array) -> EncodingRef {
37        EncodingRef::new_ref(DictEncoding.as_ref())
38    }
39}
40
41#[derive(Debug, Clone)]
42pub struct DictArray {
43    codes: ArrayRef,
44    values: ArrayRef,
45    stats_set: ArrayStats,
46}
47
48#[derive(Clone, Debug)]
49pub struct DictEncoding;
50
51impl DictArray {
52    /// Build a new `DictArray` without validating the codes or values.
53    ///
54    /// # Safety
55    /// This should be called only when you can guarantee the invariants checked
56    /// by the safe [`DictArray::try_new`] constructor are valid, for example when
57    /// you are filtering or slicing an existing valid `DictArray`.
58    pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
59        Self {
60            codes,
61            values,
62            stats_set: Default::default(),
63        }
64    }
65
66    /// Build a new `DictArray` from its components, `codes` and `values`.
67    ///
68    /// This constructor will panic if `codes` or `values` do not pass validation for building
69    /// a new `DictArray`. See [`DictArray::try_new`] for a description of the error conditions.
70    pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
71        Self::try_new(codes, values).vortex_expect("DictArray new")
72    }
73
74    /// Build a new `DictArray` from its components, `codes` and `values`.
75    ///
76    /// The codes must be unsigned integers, and may be nullable. Values can be any type, and
77    /// may also be nullable. This mirrors the nullability of the Arrow `DictionaryArray`.
78    ///
79    /// # Errors
80    ///
81    /// The `codes` **must** be unsigned integers, and the maximum code must be less than the length
82    /// of the `values` array. Otherwise, this constructor returns an error.
83    ///
84    /// It is an error to provide a nullable `codes` with non-nullable `values`.
85    pub fn try_new(mut codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
86        if !codes.dtype().is_unsigned_int() {
87            vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
88        }
89
90        let dtype = values.dtype();
91        if dtype.is_nullable() {
92            // If the values are nullable, we force codes to be nullable as well.
93            codes = cast(&codes, &codes.dtype().as_nullable())?;
94        } else {
95            // If the values are non-nullable, we assert the codes are non-nullable as well.
96            vortex_ensure!(
97                !codes.dtype().is_nullable(),
98                "Cannot have nullable codes for non-nullable dict array"
99            );
100        }
101
102        vortex_ensure!(
103            codes.dtype().nullability() == values.dtype().nullability(),
104            "Mismatched nullability between codes and values"
105        );
106
107        Ok(Self {
108            codes,
109            values,
110            stats_set: Default::default(),
111        })
112    }
113
114    #[inline]
115    pub fn codes(&self) -> &ArrayRef {
116        &self.codes
117    }
118
119    #[inline]
120    pub fn values(&self) -> &ArrayRef {
121        &self.values
122    }
123}
124
125impl ArrayVTable<DictVTable> for DictVTable {
126    fn len(array: &DictArray) -> usize {
127        array.codes.len()
128    }
129
130    fn dtype(array: &DictArray) -> &DType {
131        array.values.dtype()
132    }
133
134    fn stats(array: &DictArray) -> StatsSetRef<'_> {
135        array.stats_set.to_ref(array.as_ref())
136    }
137}
138
139impl CanonicalVTable<DictVTable> for DictVTable {
140    fn canonicalize(array: &DictArray) -> VortexResult<Canonical> {
141        match array.dtype() {
142            // NOTE: Utf8 and Binary will decompress into VarBinViewArray, which requires a full
143            // decompression to construct the views child array.
144            // For this case, it is *always* faster to decompress the values first and then create
145            // copies of the view pointers.
146            DType::Utf8(_) | DType::Binary(_) => {
147                let canonical_values: ArrayRef = array.values().to_canonical()?.into_array();
148                take(&canonical_values, array.codes())?.to_canonical()
149            }
150            _ => take(array.values(), array.codes())?.to_canonical(),
151        }
152    }
153}
154
155impl ValidityVTable<DictVTable> for DictVTable {
156    fn is_valid(array: &DictArray, index: usize) -> VortexResult<bool> {
157        let scalar = array.codes().scalar_at(index);
158
159        if scalar.is_null() {
160            return Ok(false);
161        };
162        let values_index: usize = scalar
163            .as_ref()
164            .try_into()
165            .vortex_expect("Failed to convert dictionary code to usize");
166        array.values().is_valid(values_index)
167    }
168
169    fn all_valid(array: &DictArray) -> VortexResult<bool> {
170        Ok(array.codes().all_valid()? && array.values().all_valid()?)
171    }
172
173    fn all_invalid(array: &DictArray) -> VortexResult<bool> {
174        Ok(array.codes().all_invalid()? || array.values().all_invalid()?)
175    }
176
177    fn validity_mask(array: &DictArray) -> VortexResult<Mask> {
178        let codes_validity = array.codes().validity_mask()?;
179        match codes_validity.boolean_buffer() {
180            AllOr::All => {
181                let primitive_codes = array.codes().to_primitive()?;
182                let values_mask = array.values().validity_mask()?;
183                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
184                    let codes_slice = primitive_codes.as_slice::<P>();
185                    BooleanBuffer::collect_bool(array.len(), |idx| {
186                        #[allow(clippy::cast_possible_truncation)]
187                        values_mask.value(codes_slice[idx] as usize)
188                    })
189                });
190                Ok(Mask::from_buffer(is_valid_buffer))
191            }
192            AllOr::None => Ok(Mask::AllFalse(array.len())),
193            AllOr::Some(validity_buff) => {
194                let primitive_codes = array.codes().to_primitive()?;
195                let values_mask = array.values().validity_mask()?;
196                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
197                    let codes_slice = primitive_codes.as_slice::<P>();
198                    #[allow(clippy::cast_possible_truncation)]
199                    BooleanBuffer::collect_bool(array.len(), |idx| {
200                        validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
201                    })
202                });
203                Ok(Mask::from_buffer(is_valid_buffer))
204            }
205        }
206    }
207}
208
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rand::distr::{Distribution, StandardUniform};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};
    use vortex_array::arrays::{ChunkedArray, PrimitiveArray};
    use vortex_array::builders::builder_with_capacity;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, NativePType, PType};
    use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
    use vortex_mask::AllOr;

    use crate::DictArray;

    /// Null codes mask out their elements even when every dictionary value is valid.
    #[test]
    fn nullable_codes_validity() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0, 2, 4]);
    }

    /// Valid codes that reference null dictionary values produce null elements.
    #[test]
    fn nullable_values_validity() {
        let dict = DictArray::try_new(
            buffer![0u32, 1, 2, 2, 1].into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BooleanBuffer::from(vec![true, false, false])),
            )
            .into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0]);
    }

    /// An element is valid only when both its code and its referenced value are valid.
    #[test]
    fn nullable_codes_and_values() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BooleanBuffer::from(vec![false, true, true])),
            )
            .into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [2, 4]);
    }

    /// When every code is null, every element is null regardless of the values.
    #[test]
    fn all_null_codes_validity() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(buffer![0u32, 1, 2], Validity::AllInvalid).into_array(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        assert!((0..dict.len()).all(|idx| !mask.value(idx)));
    }

    /// When every dictionary value is null, every element is null even if all codes are valid.
    #[test]
    fn all_null_values_validity() {
        let dict = DictArray::try_new(
            buffer![0u32, 1, 2, 2, 1].into_array(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::AllInvalid).into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        assert!((0..dict.len()).all(|idx| !mask.value(idx)));
    }

    /// Builds a `ChunkedArray` of `chunk_count` dictionary chunks. Each chunk has
    /// `unique_values` random values of type `T` and `len` random codes of type `U` drawn from
    /// `0..unique_values`. The RNG is seeded with 0, so the output is deterministic.
    fn make_dict_primitive_chunks<T: NativePType, U: NativePType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);

        (0..chunk_count)
            .map(|_| {
                let values = (0..unique_values)
                    .map(|_| rng.random::<T>())
                    .collect::<PrimitiveArray>();
                let codes = (0..len)
                    .map(|_| {
                        U::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                    })
                    .collect::<PrimitiveArray>();

                DictArray::try_new(codes.into_array(), values.into_array())
                    .vortex_unwrap()
                    .into_array()
            })
            .collect::<ChunkedArray>()
            .into_array()
    }

    /// Canonicalizing chunked dictionaries directly must match appending them to a builder.
    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let len = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let mut builder = builder_with_capacity(
            &DType::Primitive(PType::U64, NonNullable),
            len * chunk_count,
        );
        array
            .clone()
            .append_to_builder(builder.as_mut())
            .vortex_unwrap();

        let into_prim = array.to_primitive().unwrap();
        let prim_into = builder.finish().to_primitive().unwrap();

        assert_eq!(into_prim.as_slice::<u64>(), prim_into.as_slice::<u64>());
        assert_eq!(
            into_prim.validity_mask().unwrap().boolean_buffer(),
            prim_into.validity_mask().unwrap().boolean_buffer()
        )
    }
}