// vortex_dict/array.rs

1use std::fmt::Debug;
2
3use arrow_buffer::BooleanBuffer;
4use vortex_array::compute::{cast, take};
5use vortex_array::stats::{ArrayStats, StatsSetRef};
6use vortex_array::vtable::{ArrayVTable, CanonicalVTable, NotSupported, VTable, ValidityVTable};
7use vortex_array::{
8    Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable,
9};
10use vortex_dtype::{DType, match_each_integer_ptype};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
12use vortex_mask::{AllOr, Mask};
13
// Invokes the `vtable!` macro (imported from `vortex_array`) for the `Dict` encoding.
// NOTE(review): presumably this is what declares `DictVTable`, which the impls below
// attach behaviour to — confirm against the macro definition.
vtable!(Dict);
15
16impl VTable for DictVTable {
17    type Array = DictArray;
18    type Encoding = DictEncoding;
19
20    type ArrayVTable = Self;
21    type CanonicalVTable = Self;
22    type OperationsVTable = Self;
23    type ValidityVTable = Self;
24    type VisitorVTable = Self;
25    type ComputeVTable = NotSupported;
26    type EncodeVTable = Self;
27    type SerdeVTable = Self;
28
29    fn id(_encoding: &Self::Encoding) -> EncodingId {
30        EncodingId::new_ref("vortex.dict")
31    }
32
33    fn encoding(_array: &Self::Array) -> EncodingRef {
34        EncodingRef::new_ref(DictEncoding.as_ref())
35    }
36}
37
38#[derive(Debug, Clone)]
39pub struct DictArray {
40    codes: ArrayRef,
41    values: ArrayRef,
42    stats_set: ArrayStats,
43}
44
/// Unit type naming the dictionary encoding; it carries no data of its own.
#[derive(Clone, Debug)]
pub struct DictEncoding;
47
48impl DictArray {
49    pub fn try_new(mut codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
50        if !codes.dtype().is_unsigned_int() {
51            vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
52        }
53
54        let dtype = values.dtype();
55        if dtype.is_nullable() {
56            // If the values are nullable, we force codes to be nullable as well.
57            codes = cast(&codes, &codes.dtype().as_nullable())?;
58        } else {
59            // If the values are non-nullable, we assert the codes are non-nullable as well.
60            if codes.dtype().is_nullable() {
61                vortex_bail!("Cannot have nullable codes for non-nullable dict array");
62            }
63        }
64        assert_eq!(
65            codes.dtype().nullability(),
66            values.dtype().nullability(),
67            "Mismatched nullability between codes and values"
68        );
69
70        Ok(Self {
71            codes,
72            values,
73            stats_set: Default::default(),
74        })
75    }
76
77    #[inline]
78    pub fn codes(&self) -> &ArrayRef {
79        &self.codes
80    }
81
82    #[inline]
83    pub fn values(&self) -> &ArrayRef {
84        &self.values
85    }
86}
87
88impl ArrayVTable<DictVTable> for DictVTable {
89    fn len(array: &DictArray) -> usize {
90        array.codes.len()
91    }
92
93    fn dtype(array: &DictArray) -> &DType {
94        array.values.dtype()
95    }
96
97    fn stats(array: &DictArray) -> StatsSetRef<'_> {
98        array.stats_set.to_ref(array.as_ref())
99    }
100}
101
102impl CanonicalVTable<DictVTable> for DictVTable {
103    fn canonicalize(array: &DictArray) -> VortexResult<Canonical> {
104        match array.dtype() {
105            // NOTE: Utf8 and Binary will decompress into VarBinViewArray, which requires a full
106            // decompression to construct the views child array.
107            // For this case, it is *always* faster to decompress the values first and then create
108            // copies of the view pointers.
109            DType::Utf8(_) | DType::Binary(_) => {
110                let canonical_values: ArrayRef = array.values().to_canonical()?.into_array();
111                take(&canonical_values, array.codes())?.to_canonical()
112            }
113            _ => take(array.values(), array.codes())?.to_canonical(),
114        }
115    }
116}
117
118impl ValidityVTable<DictVTable> for DictVTable {
119    fn is_valid(array: &DictArray, index: usize) -> VortexResult<bool> {
120        let scalar = array.codes().scalar_at(index).map_err(|err| {
121            err.with_context(format!("Failed to get index {index} from DictArray codes"))
122        })?;
123
124        if scalar.is_null() {
125            return Ok(false);
126        };
127        let values_index: usize = scalar
128            .as_ref()
129            .try_into()
130            .vortex_expect("Failed to convert dictionary code to usize");
131        array.values().is_valid(values_index)
132    }
133
134    fn all_valid(array: &DictArray) -> VortexResult<bool> {
135        Ok(array.codes().all_valid()? && array.values().all_valid()?)
136    }
137
138    fn all_invalid(array: &DictArray) -> VortexResult<bool> {
139        Ok(array.codes().all_invalid()? || array.values().all_invalid()?)
140    }
141
142    fn validity_mask(array: &DictArray) -> VortexResult<Mask> {
143        let codes_validity = array.codes().validity_mask()?;
144        match codes_validity.boolean_buffer() {
145            AllOr::All => {
146                let primitive_codes = array.codes().to_primitive()?;
147                let values_mask = array.values().validity_mask()?;
148                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
149                    let codes_slice = primitive_codes.as_slice::<P>();
150                    BooleanBuffer::collect_bool(array.len(), |idx| {
151                        #[allow(clippy::cast_possible_truncation)]
152                        values_mask.value(codes_slice[idx] as usize)
153                    })
154                });
155                Ok(Mask::from_buffer(is_valid_buffer))
156            }
157            AllOr::None => Ok(Mask::AllFalse(array.len())),
158            AllOr::Some(validity_buff) => {
159                let primitive_codes = array.codes().to_primitive()?;
160                let values_mask = array.values().validity_mask()?;
161                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
162                    let codes_slice = primitive_codes.as_slice::<P>();
163                    #[allow(clippy::cast_possible_truncation)]
164                    BooleanBuffer::collect_bool(array.len(), |idx| {
165                        validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
166                    })
167                });
168                Ok(Mask::from_buffer(is_valid_buffer))
169            }
170        }
171    }
172}
173
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rand::distr::{Distribution, StandardUniform};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};
    use vortex_array::arrays::{ChunkedArray, PrimitiveArray};
    use vortex_array::builders::builder_with_capacity;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, NativePType, PType};
    use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
    use vortex_mask::AllOr;

    use crate::DictArray;

    /// Asserts that the validity mask of `dict` is a partial mask whose valid positions are
    /// exactly `expected`.
    fn assert_valid_positions(dict: &DictArray, expected: &[usize]) {
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, expected);
    }

    /// Nulls in the codes alone must show up in the dictionary's validity.
    #[test]
    fn nullable_codes_validity() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
        );
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid);

        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();

        assert_valid_positions(&dict, &[0, 2, 4]);
    }

    /// Nulls in the values alone must propagate to every element that references them.
    #[test]
    fn nullable_values_validity() {
        let codes = buffer![0u32, 1, 2, 2, 1].into_array();
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            Validity::from(BooleanBuffer::from(vec![true, false, false])),
        );

        let dict = DictArray::try_new(codes, values.into_array()).unwrap();

        assert_valid_positions(&dict, &[0]);
    }

    /// With nulls on both sides, an element is valid only if its code and its value are.
    #[test]
    fn nullable_codes_and_values() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
        );
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            Validity::from(BooleanBuffer::from(vec![false, true, true])),
        );

        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();

        assert_valid_positions(&dict, &[2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dictionary arrays, each with `unique_values`
    /// random values of type `T` and `len` random codes of type `U`, from a fixed RNG seed.
    fn make_dict_primitive_chunks<T: NativePType, U: NativePType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);
        let mut chunks: Vec<ArrayRef> = Vec::with_capacity(chunk_count);

        for _ in 0..chunk_count {
            // Values are drawn before codes so the sequence of RNG draws stays fixed.
            let values: PrimitiveArray = (0..unique_values)
                .map(|_| rng.random::<T>())
                .collect();
            let codes: PrimitiveArray = (0..len)
                .map(|_| U::from(rng.random_range(0..unique_values)).vortex_expect("valid value"))
                .collect();

            let dict = DictArray::try_new(codes.into_array(), values.into_array()).vortex_unwrap();
            chunks.push(dict.into_array());
        }

        chunks.into_iter().collect::<ChunkedArray>().into_array()
    }

    /// Appending chunked dictionaries to a builder must match direct canonicalisation.
    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let len = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let expected_dtype = DType::Primitive(PType::U64, NonNullable);
        let mut builder = builder_with_capacity(&expected_dtype, len * chunk_count);
        array
            .clone()
            .append_to_builder(builder.as_mut())
            .vortex_unwrap();

        let direct = array.to_primitive().unwrap();
        let via_builder = builder.finish().to_primitive().unwrap();

        assert_eq!(direct.as_slice::<u64>(), via_builder.as_slice::<u64>());
        assert_eq!(
            direct.validity_mask().unwrap().boolean_buffer(),
            via_builder.validity_mask().unwrap().boolean_buffer()
        );
    }
}