vortex_dict/
array.rs

1use std::fmt::Debug;
2
3use arrow_buffer::BooleanBuffer;
4use vortex_array::builders::ArrayBuilder;
5use vortex_array::compute::{scalar_at, take, take_into, try_cast};
6use vortex_array::stats::{ArrayStats, StatsSetRef};
7use vortex_array::variants::PrimitiveArrayTrait;
8use vortex_array::vtable::VTableRef;
9use vortex_array::{
10    Array, ArrayCanonicalImpl, ArrayImpl, ArrayRef, ArrayStatisticsImpl, ArrayValidityImpl,
11    Canonical, Encoding, IntoArray, RkyvMetadata, ToCanonical,
12};
13use vortex_dtype::{DType, match_each_integer_ptype};
14use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
15use vortex_mask::{AllOr, Mask};
16
17use crate::serde::DictMetadata;
18
19#[derive(Debug, Clone)]
20pub struct DictArray {
21    codes: ArrayRef,
22    values: ArrayRef,
23    stats_set: ArrayStats,
24}
25
26pub struct DictEncoding;
27impl Encoding for DictEncoding {
28    type Array = DictArray;
29    type Metadata = RkyvMetadata<DictMetadata>;
30}
31
32impl DictArray {
33    pub fn try_new(mut codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
34        if !codes.dtype().is_unsigned_int() {
35            vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
36        }
37
38        let dtype = values.dtype();
39        if dtype.is_nullable() {
40            // If the values are nullable, we force codes to be nullable as well.
41            codes = try_cast(&codes, &codes.dtype().as_nullable())?;
42        } else {
43            // If the values are non-nullable, we assert the codes are non-nullable as well.
44            if codes.dtype().is_nullable() {
45                vortex_bail!("Cannot have nullable codes for non-nullable dict array");
46            }
47        }
48        assert_eq!(
49            codes.dtype().nullability(),
50            values.dtype().nullability(),
51            "Mismatched nullability between codes and values"
52        );
53
54        Ok(Self {
55            codes,
56            values,
57            stats_set: Default::default(),
58        })
59    }
60
61    #[inline]
62    pub fn codes(&self) -> &ArrayRef {
63        &self.codes
64    }
65
66    #[inline]
67    pub fn values(&self) -> &ArrayRef {
68        &self.values
69    }
70}
71
72impl ArrayImpl for DictArray {
73    type Encoding = DictEncoding;
74
75    fn _len(&self) -> usize {
76        self.codes.len()
77    }
78
79    fn _dtype(&self) -> &DType {
80        self.values.dtype()
81    }
82
83    fn _vtable(&self) -> VTableRef {
84        VTableRef::new_ref(&DictEncoding)
85    }
86
87    fn _with_children(&self, children: &[ArrayRef]) -> VortexResult<Self> {
88        let codes = children[0].clone();
89        let values = children[1].clone();
90
91        Self::try_new(codes, values)
92    }
93}
94
95impl ArrayCanonicalImpl for DictArray {
96    fn _to_canonical(&self) -> VortexResult<Canonical> {
97        match self.dtype() {
98            // NOTE: Utf8 and Binary will decompress into VarBinViewArray, which requires a full
99            // decompression to construct the views child array.
100            // For this case, it is *always* faster to decompress the values first and then create
101            // copies of the view pointers.
102            DType::Utf8(_) | DType::Binary(_) => {
103                let canonical_values: ArrayRef = self.values().to_canonical()?.into_array();
104                take(&canonical_values, self.codes())?.to_canonical()
105            }
106            _ => take(self.values(), self.codes())?.to_canonical(),
107        }
108    }
109
110    fn _append_to_builder(&self, builder: &mut dyn ArrayBuilder) -> VortexResult<()> {
111        match self.dtype() {
112            // NOTE: Utf8 and Binary will decompress into VarBinViewArray, which requires a full
113            // decompression to construct the views child array.
114            // For this case, it is *always* faster to decompress the values first and then create
115            // copies of the view pointers.
116            // TODO(joe): is the above still true?, investigate this.
117            DType::Utf8(_) | DType::Binary(_) => {
118                let canonical_values: ArrayRef = self.values().to_canonical()?.into_array();
119                take_into(&canonical_values, self.codes(), builder)
120            }
121            // Non-string case: take and then canonicalize
122            _ => take_into(self.values(), self.codes(), builder),
123        }
124    }
125}
126
127impl ArrayValidityImpl for DictArray {
128    fn _is_valid(&self, index: usize) -> VortexResult<bool> {
129        let scalar = scalar_at(self.codes(), index).map_err(|err| {
130            err.with_context(format!(
131                "Failed to get index {} from DictArray codes",
132                index
133            ))
134        })?;
135
136        if scalar.is_null() {
137            return Ok(false);
138        };
139        let values_index: usize = scalar
140            .as_ref()
141            .try_into()
142            .vortex_expect("Failed to convert dictionary code to usize");
143        self.values().is_valid(values_index)
144    }
145
146    fn _all_valid(&self) -> VortexResult<bool> {
147        if !self.dtype().is_nullable() {
148            return Ok(true);
149        }
150
151        Ok(self.codes().all_valid()? && self.values().all_valid()?)
152    }
153
154    fn _all_invalid(&self) -> VortexResult<bool> {
155        if !self.dtype().is_nullable() {
156            return Ok(false);
157        }
158
159        Ok(self.codes().all_invalid()? || self.values().all_invalid()?)
160    }
161
162    fn _validity_mask(&self) -> VortexResult<Mask> {
163        let codes_validity = self.codes().validity_mask()?;
164        match codes_validity.boolean_buffer() {
165            AllOr::All => {
166                let primitive_codes = self.codes().to_primitive()?;
167                let values_mask = self.values().validity_mask()?;
168                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |$P| {
169                    let codes_slice = primitive_codes.as_slice::<$P>();
170                    BooleanBuffer::collect_bool(self.len(), |idx| {
171                       values_mask.value(codes_slice[idx] as usize)
172                    })
173                });
174                Ok(Mask::from_buffer(is_valid_buffer))
175            }
176            AllOr::None => Ok(Mask::AllFalse(self.len())),
177            AllOr::Some(validity_buff) => {
178                let primitive_codes = self.codes().to_primitive()?;
179                let values_mask = self.values().validity_mask()?;
180                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |$P| {
181                    let codes_slice = primitive_codes.as_slice::<$P>();
182                    BooleanBuffer::collect_bool(self.len(), |idx| {
183                       validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
184                    })
185                });
186                Ok(Mask::from_buffer(is_valid_buffer))
187            }
188        }
189    }
190}
191
192impl ArrayStatisticsImpl for DictArray {
193    fn _stats_ref(&self) -> StatsSetRef<'_> {
194        self.stats_set.to_ref(self)
195    }
196}
197
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rand::distr::{Distribution, StandardUniform};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};
    use vortex_array::arrays::{ChunkedArray, PrimitiveArray};
    use vortex_array::builders::builder_with_capacity;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, NativePType, PType};
    use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
    use vortex_mask::AllOr;

    use crate::DictArray;

    /// Returns the positions that the validity mask of `dict` reports as valid.
    ///
    /// Panics if the mask does not expose an explicit list of indices.
    fn valid_positions(dict: &DictArray) -> Vec<usize> {
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        indices.to_vec()
    }

    /// Nulls in the codes alone determine validity when the dictionary has no nulls.
    #[test]
    fn nullable_codes_validity() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
        );
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid);

        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();

        assert_eq!(valid_positions(&dict), [0, 2, 4]);
    }

    /// Codes without nulls inherit validity from the dictionary entries they point at.
    #[test]
    fn nullable_values_validity() {
        let codes = buffer![0u32, 1, 2, 2, 1].into_array();
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            Validity::from(BooleanBuffer::from(vec![true, false, false])),
        );

        let dict = DictArray::try_new(codes, values.into_array()).unwrap();

        assert_eq!(valid_positions(&dict), [0]);
    }

    /// With nulls on both sides an element is valid only if its code and its value are valid.
    #[test]
    fn nullable_codes_and_values() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
        );
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            Validity::from(BooleanBuffer::from(vec![false, true, true])),
        );

        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();

        assert_eq!(valid_positions(&dict), [2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dictionary chunks from a fixed-seed RNG.
    ///
    /// Each chunk has `unique_values` random dictionary entries of type `T` and `len` random
    /// codes of type `U` drawn from `0..unique_values`.
    fn make_dict_primitive_chunks<T: NativePType, U: NativePType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);

        // Draw order matters for reproducibility: dictionary entries first, then codes.
        let mut next_chunk = |_chunk: usize| -> ArrayRef {
            let values: PrimitiveArray = (0..unique_values)
                .map(|_| rng.random::<T>())
                .collect();
            let codes: PrimitiveArray = (0..len)
                .map(|_| U::from(rng.random_range(0..unique_values)).vortex_expect("valid value"))
                .collect();

            DictArray::try_new(codes.into_array(), values.into_array())
                .vortex_unwrap()
                .into_array()
        };

        let chunked: ChunkedArray = (0..chunk_count).map(|chunk| next_chunk(chunk)).collect();
        chunked.into_array()
    }

    /// Appending chunked dict arrays to a builder matches converting them to primitive.
    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let len = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let target_dtype = DType::Primitive(PType::U64, NonNullable);
        let mut builder = builder_with_capacity(&target_dtype, len * chunk_count);
        array
            .clone()
            .append_to_builder(builder.as_mut())
            .vortex_unwrap();

        let into_prim = array.to_primitive().unwrap();
        let prim_into = builder.finish().to_primitive().unwrap();

        assert_eq!(into_prim.as_slice::<u64>(), prim_into.as_slice::<u64>());
        assert_eq!(
            into_prim.validity_mask().unwrap().boolean_buffer(),
            prim_into.validity_mask().unwrap().boolean_buffer()
        )
    }
}