vortex_dict/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_array::compute::{cast, take};
8use vortex_array::stats::{ArrayStats, StatsSetRef};
9use vortex_array::vtable::{ArrayVTable, CanonicalVTable, NotSupported, VTable, ValidityVTable};
10use vortex_array::{
11    Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable,
12};
13use vortex_dtype::{DType, match_each_integer_ptype};
14use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_ensure};
15use vortex_mask::{AllOr, Mask};
16
17vtable!(Dict);
18
19impl VTable for DictVTable {
20    type Array = DictArray;
21    type Encoding = DictEncoding;
22
23    type ArrayVTable = Self;
24    type CanonicalVTable = Self;
25    type OperationsVTable = Self;
26    type ValidityVTable = Self;
27    type VisitorVTable = Self;
28    type ComputeVTable = NotSupported;
29    type EncodeVTable = Self;
30    type SerdeVTable = Self;
31    type PipelineVTable = NotSupported;
32
33    fn id(_encoding: &Self::Encoding) -> EncodingId {
34        EncodingId::new_ref("vortex.dict")
35    }
36
37    fn encoding(_array: &Self::Array) -> EncodingRef {
38        EncodingRef::new_ref(DictEncoding.as_ref())
39    }
40}
41
42#[derive(Debug, Clone)]
43pub struct DictArray {
44    codes: ArrayRef,
45    values: ArrayRef,
46    stats_set: ArrayStats,
47}
48
/// Unit marker type for the dictionary encoding (identified as `vortex.dict`).
#[derive(Debug, Clone)]
pub struct DictEncoding;
51
52impl DictArray {
53    /// Build a new `DictArray` without validating the codes or values.
54    ///
55    /// # Safety
56    /// This should be called only when you can guarantee the invariants checked
57    /// by the safe [`DictArray::try_new`] constructor are valid, for example when
58    /// you are filtering or slicing an existing valid `DictArray`.
59    pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
60        Self {
61            codes,
62            values,
63            stats_set: Default::default(),
64        }
65    }
66
67    /// Build a new `DictArray` from its components, `codes` and `values`.
68    ///
69    /// This constructor will panic if `codes` or `values` do not pass validation for building
70    /// a new `DictArray`. See [`DictArray::try_new`] for a description of the error conditions.
71    pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
72        Self::try_new(codes, values).vortex_expect("DictArray new")
73    }
74
75    /// Build a new `DictArray` from its components, `codes` and `values`.
76    ///
77    /// The codes must be unsigned integers, and may be nullable. Values can be any type, and
78    /// may also be nullable. This mirrors the nullability of the Arrow `DictionaryArray`.
79    ///
80    /// # Errors
81    ///
82    /// The `codes` **must** be unsigned integers, and the maximum code must be less than the length
83    /// of the `values` array. Otherwise, this constructor returns an error.
84    ///
85    /// It is an error to provide a nullable `codes` with non-nullable `values`.
86    pub fn try_new(mut codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
87        if !codes.dtype().is_unsigned_int() {
88            vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
89        }
90
91        let dtype = values.dtype();
92        if dtype.is_nullable() {
93            // If the values are nullable, we force codes to be nullable as well.
94            codes = cast(&codes, &codes.dtype().as_nullable())?;
95        } else {
96            // If the values are non-nullable, we assert the codes are non-nullable as well.
97            vortex_ensure!(
98                !codes.dtype().is_nullable(),
99                "Cannot have nullable codes for non-nullable dict array"
100            );
101        }
102
103        vortex_ensure!(
104            codes.dtype().nullability() == values.dtype().nullability(),
105            "Mismatched nullability between codes and values"
106        );
107
108        Ok(Self {
109            codes,
110            values,
111            stats_set: Default::default(),
112        })
113    }
114
115    #[inline]
116    pub fn codes(&self) -> &ArrayRef {
117        &self.codes
118    }
119
120    #[inline]
121    pub fn values(&self) -> &ArrayRef {
122        &self.values
123    }
124}
125
126impl ArrayVTable<DictVTable> for DictVTable {
127    fn len(array: &DictArray) -> usize {
128        array.codes.len()
129    }
130
131    fn dtype(array: &DictArray) -> &DType {
132        array.values.dtype()
133    }
134
135    fn stats(array: &DictArray) -> StatsSetRef<'_> {
136        array.stats_set.to_ref(array.as_ref())
137    }
138}
139
140impl CanonicalVTable<DictVTable> for DictVTable {
141    fn canonicalize(array: &DictArray) -> VortexResult<Canonical> {
142        match array.dtype() {
143            // NOTE: Utf8 and Binary will decompress into VarBinViewArray, which requires a full
144            // decompression to construct the views child array.
145            // For this case, it is *always* faster to decompress the values first and then create
146            // copies of the view pointers.
147            DType::Utf8(_) | DType::Binary(_) => {
148                let canonical_values: ArrayRef = array.values().to_canonical()?.into_array();
149                take(&canonical_values, array.codes())?.to_canonical()
150            }
151            _ => take(array.values(), array.codes())?.to_canonical(),
152        }
153    }
154}
155
156impl ValidityVTable<DictVTable> for DictVTable {
157    fn is_valid(array: &DictArray, index: usize) -> VortexResult<bool> {
158        let scalar = array.codes().scalar_at(index);
159
160        if scalar.is_null() {
161            return Ok(false);
162        };
163        let values_index: usize = scalar
164            .as_ref()
165            .try_into()
166            .vortex_expect("Failed to convert dictionary code to usize");
167        array.values().is_valid(values_index)
168    }
169
170    fn all_valid(array: &DictArray) -> VortexResult<bool> {
171        Ok(array.codes().all_valid()? && array.values().all_valid()?)
172    }
173
174    fn all_invalid(array: &DictArray) -> VortexResult<bool> {
175        Ok(array.codes().all_invalid()? || array.values().all_invalid()?)
176    }
177
178    fn validity_mask(array: &DictArray) -> VortexResult<Mask> {
179        let codes_validity = array.codes().validity_mask()?;
180        match codes_validity.boolean_buffer() {
181            AllOr::All => {
182                let primitive_codes = array.codes().to_primitive()?;
183                let values_mask = array.values().validity_mask()?;
184                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
185                    let codes_slice = primitive_codes.as_slice::<P>();
186                    BooleanBuffer::collect_bool(array.len(), |idx| {
187                        #[allow(clippy::cast_possible_truncation)]
188                        values_mask.value(codes_slice[idx] as usize)
189                    })
190                });
191                Ok(Mask::from_buffer(is_valid_buffer))
192            }
193            AllOr::None => Ok(Mask::AllFalse(array.len())),
194            AllOr::Some(validity_buff) => {
195                let primitive_codes = array.codes().to_primitive()?;
196                let values_mask = array.values().validity_mask()?;
197                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
198                    let codes_slice = primitive_codes.as_slice::<P>();
199                    #[allow(clippy::cast_possible_truncation)]
200                    BooleanBuffer::collect_bool(array.len(), |idx| {
201                        validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
202                    })
203                });
204                Ok(Mask::from_buffer(is_valid_buffer))
205            }
206        }
207    }
208}
209
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rand::distr::{Distribution, StandardUniform};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};
    use vortex_array::arrays::{ChunkedArray, PrimitiveArray};
    use vortex_array::builders::builder_with_capacity;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, NativePType, PType};
    use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
    use vortex_mask::AllOr;

    use crate::DictArray;

    /// Wraps per-element validity flags into a [`Validity`].
    fn flags(bits: Vec<bool>) -> Validity {
        Validity::from(BooleanBuffer::from(bits))
    }

    /// Asserts that the valid positions of `dict` are exactly `expected`.
    fn assert_valid_positions(dict: &DictArray, expected: &[usize]) {
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(positions) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(positions, expected);
    }

    #[test]
    fn nullable_codes_validity() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            flags(vec![true, false, true, false, true]),
        );
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid);
        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();

        assert_valid_positions(&dict, &[0, 2, 4]);
    }

    #[test]
    fn nullable_values_validity() {
        let codes = buffer![0u32, 1, 2, 2, 1].into_array();
        let values = PrimitiveArray::new(buffer![3, 6, 9], flags(vec![true, false, false]));
        let dict = DictArray::try_new(codes, values.into_array()).unwrap();

        assert_valid_positions(&dict, &[0]);
    }

    #[test]
    fn nullable_codes_and_values() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            flags(vec![true, false, true, false, true]),
        );
        let values = PrimitiveArray::new(buffer![3, 6, 9], flags(vec![false, true, true]));
        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();

        assert_valid_positions(&dict, &[2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dictionary chunks. Each chunk holds `len`
    /// random codes of type `U` indexing `unique_values` random values of type `T`, drawn
    /// from a fixed-seed RNG.
    fn make_dict_primitive_chunks<T: NativePType, U: NativePType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);
        let mut chunks = Vec::with_capacity(chunk_count);
        for _ in 0..chunk_count {
            // Values are drawn before codes in every chunk; keep this order so the seeded
            // sequence produces the same data.
            let values: PrimitiveArray = (0..unique_values).map(|_| rng.random::<T>()).collect();
            let codes: PrimitiveArray = (0..len)
                .map(|_| U::from(rng.random_range(0..unique_values)).vortex_expect("valid value"))
                .collect();
            let chunk = DictArray::try_new(codes.into_array(), values.into_array())
                .vortex_unwrap()
                .into_array();
            chunks.push(chunk);
        }
        chunks.into_iter().collect::<ChunkedArray>().into_array()
    }

    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let len = 2;
        let chunk_count = 2;
        let chunked = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let dtype = DType::Primitive(PType::U64, NonNullable);
        let mut builder = builder_with_capacity(&dtype, len * chunk_count);
        chunked
            .clone()
            .append_to_builder(builder.as_mut())
            .vortex_unwrap();

        let canonicalized = chunked.to_primitive().unwrap();
        let built = builder.finish().to_primitive().unwrap();

        assert_eq!(canonicalized.as_slice::<u64>(), built.as_slice::<u64>());
        assert_eq!(
            canonicalized.validity_mask().unwrap().boolean_buffer(),
            built.validity_mask().unwrap().boolean_buffer()
        )
    }
}