vortex_dict/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_array::compute::{cast, take};
8use vortex_array::stats::{ArrayStats, StatsSetRef};
9use vortex_array::vtable::{ArrayVTable, CanonicalVTable, NotSupported, VTable, ValidityVTable};
10use vortex_array::{
11    Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable,
12};
13use vortex_dtype::{DType, match_each_integer_ptype};
14use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
15use vortex_mask::{AllOr, Mask};
16
17vtable!(Dict);
18
19impl VTable for DictVTable {
20    type Array = DictArray;
21    type Encoding = DictEncoding;
22
23    type ArrayVTable = Self;
24    type CanonicalVTable = Self;
25    type OperationsVTable = Self;
26    type ValidityVTable = Self;
27    type VisitorVTable = Self;
28    type ComputeVTable = NotSupported;
29    type EncodeVTable = Self;
30    type SerdeVTable = Self;
31
32    fn id(_encoding: &Self::Encoding) -> EncodingId {
33        EncodingId::new_ref("vortex.dict")
34    }
35
36    fn encoding(_array: &Self::Array) -> EncodingRef {
37        EncodingRef::new_ref(DictEncoding.as_ref())
38    }
39}
40
41#[derive(Debug, Clone)]
42pub struct DictArray {
43    codes: ArrayRef,
44    values: ArrayRef,
45    stats_set: ArrayStats,
46}
47
/// Stateless marker type identifying the dictionary encoding.
#[derive(Debug, Clone)]
pub struct DictEncoding;
50
51impl DictArray {
52    pub fn try_new(mut codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
53        if !codes.dtype().is_unsigned_int() {
54            vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
55        }
56
57        let dtype = values.dtype();
58        if dtype.is_nullable() {
59            // If the values are nullable, we force codes to be nullable as well.
60            codes = cast(&codes, &codes.dtype().as_nullable())?;
61        } else {
62            // If the values are non-nullable, we assert the codes are non-nullable as well.
63            if codes.dtype().is_nullable() {
64                vortex_bail!("Cannot have nullable codes for non-nullable dict array");
65            }
66        }
67        assert_eq!(
68            codes.dtype().nullability(),
69            values.dtype().nullability(),
70            "Mismatched nullability between codes and values"
71        );
72
73        Ok(Self {
74            codes,
75            values,
76            stats_set: Default::default(),
77        })
78    }
79
80    #[inline]
81    pub fn codes(&self) -> &ArrayRef {
82        &self.codes
83    }
84
85    #[inline]
86    pub fn values(&self) -> &ArrayRef {
87        &self.values
88    }
89}
90
91impl ArrayVTable<DictVTable> for DictVTable {
92    fn len(array: &DictArray) -> usize {
93        array.codes.len()
94    }
95
96    fn dtype(array: &DictArray) -> &DType {
97        array.values.dtype()
98    }
99
100    fn stats(array: &DictArray) -> StatsSetRef<'_> {
101        array.stats_set.to_ref(array.as_ref())
102    }
103}
104
105impl CanonicalVTable<DictVTable> for DictVTable {
106    fn canonicalize(array: &DictArray) -> VortexResult<Canonical> {
107        match array.dtype() {
108            // NOTE: Utf8 and Binary will decompress into VarBinViewArray, which requires a full
109            // decompression to construct the views child array.
110            // For this case, it is *always* faster to decompress the values first and then create
111            // copies of the view pointers.
112            DType::Utf8(_) | DType::Binary(_) => {
113                let canonical_values: ArrayRef = array.values().to_canonical()?.into_array();
114                take(&canonical_values, array.codes())?.to_canonical()
115            }
116            _ => take(array.values(), array.codes())?.to_canonical(),
117        }
118    }
119}
120
121impl ValidityVTable<DictVTable> for DictVTable {
122    fn is_valid(array: &DictArray, index: usize) -> VortexResult<bool> {
123        let scalar = array.codes().scalar_at(index).map_err(|err| {
124            err.with_context(format!("Failed to get index {index} from DictArray codes"))
125        })?;
126
127        if scalar.is_null() {
128            return Ok(false);
129        };
130        let values_index: usize = scalar
131            .as_ref()
132            .try_into()
133            .vortex_expect("Failed to convert dictionary code to usize");
134        array.values().is_valid(values_index)
135    }
136
137    fn all_valid(array: &DictArray) -> VortexResult<bool> {
138        Ok(array.codes().all_valid()? && array.values().all_valid()?)
139    }
140
141    fn all_invalid(array: &DictArray) -> VortexResult<bool> {
142        Ok(array.codes().all_invalid()? || array.values().all_invalid()?)
143    }
144
145    fn validity_mask(array: &DictArray) -> VortexResult<Mask> {
146        let codes_validity = array.codes().validity_mask()?;
147        match codes_validity.boolean_buffer() {
148            AllOr::All => {
149                let primitive_codes = array.codes().to_primitive()?;
150                let values_mask = array.values().validity_mask()?;
151                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
152                    let codes_slice = primitive_codes.as_slice::<P>();
153                    BooleanBuffer::collect_bool(array.len(), |idx| {
154                        #[allow(clippy::cast_possible_truncation)]
155                        values_mask.value(codes_slice[idx] as usize)
156                    })
157                });
158                Ok(Mask::from_buffer(is_valid_buffer))
159            }
160            AllOr::None => Ok(Mask::AllFalse(array.len())),
161            AllOr::Some(validity_buff) => {
162                let primitive_codes = array.codes().to_primitive()?;
163                let values_mask = array.values().validity_mask()?;
164                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
165                    let codes_slice = primitive_codes.as_slice::<P>();
166                    #[allow(clippy::cast_possible_truncation)]
167                    BooleanBuffer::collect_bool(array.len(), |idx| {
168                        validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
169                    })
170                });
171                Ok(Mask::from_buffer(is_valid_buffer))
172            }
173        }
174    }
175}
176
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rand::distr::{Distribution, StandardUniform};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};
    use vortex_array::arrays::{ChunkedArray, PrimitiveArray};
    use vortex_array::builders::builder_with_capacity;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, NativePType, PType};
    use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
    use vortex_mask::AllOr;

    use crate::DictArray;

    /// Asserts that the valid positions of `dict` are exactly `expected`.
    fn assert_valid_indices(dict: &DictArray, expected: &[usize]) {
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, expected);
    }

    #[test]
    fn nullable_codes_validity() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
        );
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid);
        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();

        assert_valid_indices(&dict, &[0, 2, 4]);
    }

    #[test]
    fn nullable_values_validity() {
        let codes = buffer![0u32, 1, 2, 2, 1].into_array();
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            Validity::from(BooleanBuffer::from(vec![true, false, false])),
        );
        let dict = DictArray::try_new(codes, values.into_array()).unwrap();

        assert_valid_indices(&dict, &[0]);
    }

    #[test]
    fn nullable_codes_and_values() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
        );
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            Validity::from(BooleanBuffer::from(vec![false, true, true])),
        );
        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();

        assert_valid_indices(&dict, &[2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dictionary chunks. Each chunk has `len` codes of
    /// type `U` drawn from `0..unique_values`, plus `unique_values` random values of type `T`.
    /// The generator is seeded, so the output is deterministic.
    fn make_dict_primitive_chunks<T: NativePType, U: NativePType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);
        let mut chunks = Vec::with_capacity(chunk_count);

        for _ in 0..chunk_count {
            // Per chunk, values are drawn before codes from the shared seeded generator.
            let values: PrimitiveArray = (0..unique_values).map(|_| rng.random::<T>()).collect();
            let codes: PrimitiveArray = (0..len)
                .map(|_| U::from(rng.random_range(0..unique_values)).vortex_expect("valid value"))
                .collect();

            let chunk = DictArray::try_new(codes.into_array(), values.into_array())
                .vortex_unwrap()
                .into_array();
            chunks.push(chunk);
        }

        chunks.into_iter().collect::<ChunkedArray>().into_array()
    }

    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let len = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let dtype = DType::Primitive(PType::U64, NonNullable);
        let mut builder = builder_with_capacity(&dtype, len * chunk_count);
        array
            .clone()
            .append_to_builder(builder.as_mut())
            .vortex_unwrap();

        let canonicalized = array.to_primitive().unwrap();
        let built = builder.finish().to_primitive().unwrap();

        assert_eq!(canonicalized.as_slice::<u64>(), built.as_slice::<u64>());
        assert_eq!(
            canonicalized.validity_mask().unwrap().boolean_buffer(),
            built.validity_mask().unwrap().boolean_buffer()
        );
    }
}