vortex_dict/
array.rs

1use std::fmt::Debug;
2
3use arrow_buffer::BooleanBuffer;
4use vortex_array::builders::ArrayBuilder;
5use vortex_array::compute::{cast, scalar_at, take, take_into};
6use vortex_array::stats::{ArrayStats, StatsSetRef};
7use vortex_array::variants::PrimitiveArrayTrait;
8use vortex_array::vtable::VTableRef;
9use vortex_array::{
10    Array, ArrayCanonicalImpl, ArrayImpl, ArrayRef, ArrayStatisticsImpl, ArrayValidityImpl,
11    Canonical, Encoding, IntoArray, ProstMetadata, ToCanonical,
12};
13use vortex_dtype::{DType, match_each_integer_ptype};
14use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
15use vortex_mask::{AllOr, Mask};
16
17use crate::serde::DictMetadata;
18
19#[derive(Debug, Clone)]
20pub struct DictArray {
21    codes: ArrayRef,
22    values: ArrayRef,
23    stats_set: ArrayStats,
24}
25
26#[derive(Debug)]
27pub struct DictEncoding;
28impl Encoding for DictEncoding {
29    type Array = DictArray;
30    type Metadata = ProstMetadata<DictMetadata>;
31}
32
33impl DictArray {
34    pub fn try_new(mut codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
35        if !codes.dtype().is_unsigned_int() {
36            vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
37        }
38
39        let dtype = values.dtype();
40        if dtype.is_nullable() {
41            // If the values are nullable, we force codes to be nullable as well.
42            codes = cast(&codes, &codes.dtype().as_nullable())?;
43        } else {
44            // If the values are non-nullable, we assert the codes are non-nullable as well.
45            if codes.dtype().is_nullable() {
46                vortex_bail!("Cannot have nullable codes for non-nullable dict array");
47            }
48        }
49        assert_eq!(
50            codes.dtype().nullability(),
51            values.dtype().nullability(),
52            "Mismatched nullability between codes and values"
53        );
54
55        Ok(Self {
56            codes,
57            values,
58            stats_set: Default::default(),
59        })
60    }
61
62    #[inline]
63    pub fn codes(&self) -> &ArrayRef {
64        &self.codes
65    }
66
67    #[inline]
68    pub fn values(&self) -> &ArrayRef {
69        &self.values
70    }
71}
72
73impl ArrayImpl for DictArray {
74    type Encoding = DictEncoding;
75
76    fn _len(&self) -> usize {
77        self.codes.len()
78    }
79
80    fn _dtype(&self) -> &DType {
81        self.values.dtype()
82    }
83
84    fn _vtable(&self) -> VTableRef {
85        VTableRef::new_ref(&DictEncoding)
86    }
87
88    fn _with_children(&self, children: &[ArrayRef]) -> VortexResult<Self> {
89        let codes = children[0].clone();
90        let values = children[1].clone();
91
92        Self::try_new(codes, values)
93    }
94}
95
96impl ArrayCanonicalImpl for DictArray {
97    fn _to_canonical(&self) -> VortexResult<Canonical> {
98        match self.dtype() {
99            // NOTE: Utf8 and Binary will decompress into VarBinViewArray, which requires a full
100            // decompression to construct the views child array.
101            // For this case, it is *always* faster to decompress the values first and then create
102            // copies of the view pointers.
103            DType::Utf8(_) | DType::Binary(_) => {
104                let canonical_values: ArrayRef = self.values().to_canonical()?.into_array();
105                take(&canonical_values, self.codes())?.to_canonical()
106            }
107            _ => take(self.values(), self.codes())?.to_canonical(),
108        }
109    }
110
111    fn _append_to_builder(&self, builder: &mut dyn ArrayBuilder) -> VortexResult<()> {
112        match self.dtype() {
113            // NOTE: Utf8 and Binary will decompress into VarBinViewArray, which requires a full
114            // decompression to construct the views child array.
115            // For this case, it is *always* faster to decompress the values first and then create
116            // copies of the view pointers.
117            // TODO(joe): is the above still true?, investigate this.
118            DType::Utf8(_) | DType::Binary(_) => {
119                let canonical_values: ArrayRef = self.values().to_canonical()?.into_array();
120                take_into(&canonical_values, self.codes(), builder)
121            }
122            // Non-string case: take and then canonicalize
123            _ => take_into(self.values(), self.codes(), builder),
124        }
125    }
126}
127
128impl ArrayValidityImpl for DictArray {
129    fn _is_valid(&self, index: usize) -> VortexResult<bool> {
130        let scalar = scalar_at(self.codes(), index).map_err(|err| {
131            err.with_context(format!(
132                "Failed to get index {} from DictArray codes",
133                index
134            ))
135        })?;
136
137        if scalar.is_null() {
138            return Ok(false);
139        };
140        let values_index: usize = scalar
141            .as_ref()
142            .try_into()
143            .vortex_expect("Failed to convert dictionary code to usize");
144        self.values().is_valid(values_index)
145    }
146
147    fn _all_valid(&self) -> VortexResult<bool> {
148        if !self.dtype().is_nullable() {
149            return Ok(true);
150        }
151
152        Ok(self.codes().all_valid()? && self.values().all_valid()?)
153    }
154
155    fn _all_invalid(&self) -> VortexResult<bool> {
156        if !self.dtype().is_nullable() {
157            return Ok(false);
158        }
159
160        Ok(self.codes().all_invalid()? || self.values().all_invalid()?)
161    }
162
163    fn _validity_mask(&self) -> VortexResult<Mask> {
164        let codes_validity = self.codes().validity_mask()?;
165        match codes_validity.boolean_buffer() {
166            AllOr::All => {
167                let primitive_codes = self.codes().to_primitive()?;
168                let values_mask = self.values().validity_mask()?;
169                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |$P| {
170                    let codes_slice = primitive_codes.as_slice::<$P>();
171                    BooleanBuffer::collect_bool(self.len(), |idx| {
172                       values_mask.value(codes_slice[idx] as usize)
173                    })
174                });
175                Ok(Mask::from_buffer(is_valid_buffer))
176            }
177            AllOr::None => Ok(Mask::AllFalse(self.len())),
178            AllOr::Some(validity_buff) => {
179                let primitive_codes = self.codes().to_primitive()?;
180                let values_mask = self.values().validity_mask()?;
181                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |$P| {
182                    let codes_slice = primitive_codes.as_slice::<$P>();
183                    BooleanBuffer::collect_bool(self.len(), |idx| {
184                       validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
185                    })
186                });
187                Ok(Mask::from_buffer(is_valid_buffer))
188            }
189        }
190    }
191}
192
193impl ArrayStatisticsImpl for DictArray {
194    fn _stats_ref(&self) -> StatsSetRef<'_> {
195        self.stats_set.to_ref(self)
196    }
197}
198
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rand::distr::{Distribution, StandardUniform};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};
    use vortex_array::arrays::{ChunkedArray, PrimitiveArray};
    use vortex_array::builders::builder_with_capacity;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, NativePType, PType};
    use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
    use vortex_mask::AllOr;

    use crate::DictArray;

    /// Builds a `Validity` from explicit per-element validity bits.
    fn validity_from_bits(bits: &[bool]) -> Validity {
        Validity::from(BooleanBuffer::from(bits.to_vec()))
    }

    /// Asserts that the validity mask of `dict` exposes exactly `expected` as its indices.
    fn assert_valid_indices<const N: usize>(dict: &DictArray, expected: [usize; N]) {
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, expected);
    }

    // Only the codes carry nulls; every dictionary value is valid.
    #[test]
    fn nullable_codes_validity() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            validity_from_bits(&[true, false, true, false, true]),
        );
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid);

        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();

        assert_valid_indices(&dict, [0, 2, 4]);
    }

    // Only the values carry nulls; the codes are passed in without validity.
    #[test]
    fn nullable_values_validity() {
        let codes = buffer![0u32, 1, 2, 2, 1];
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            validity_from_bits(&[true, false, false]),
        );

        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();

        assert_valid_indices(&dict, [0]);
    }

    // Both children carry nulls; an element is valid only if its code and its value are.
    #[test]
    fn nullable_codes_and_values() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            validity_from_bits(&[true, false, true, false, true]),
        );
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            validity_from_bits(&[false, true, true]),
        );

        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();

        assert_valid_indices(&dict, [2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dictionary chunks from a fixed-seed RNG.
    ///
    /// Each chunk gets `unique_values` random values of type `T` and `len` random codes of
    /// type `U` drawn from `0..unique_values`.
    fn make_dict_primitive_chunks<T: NativePType, U: NativePType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);
        let mut chunks: Vec<ArrayRef> = Vec::with_capacity(chunk_count);

        for _ in 0..chunk_count {
            // Values are drawn before codes so the RNG sequence stays the same per chunk.
            let values: PrimitiveArray = (0..unique_values)
                .map(|_| rng.random::<T>())
                .collect();
            let codes: PrimitiveArray = (0..len)
                .map(|_| U::from(rng.random_range(0..unique_values)).vortex_expect("valid value"))
                .collect();

            let dict = DictArray::try_new(codes.into_array(), values.into_array()).vortex_unwrap();
            chunks.push(dict.into_array());
        }

        chunks.into_iter().collect::<ChunkedArray>().into_array()
    }

    // Appending to a builder must produce the same data as canonicalizing to a primitive array.
    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let (len, chunk_count) = (2, 2);
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let mut builder = builder_with_capacity(
            &DType::Primitive(PType::U64, NonNullable),
            len * chunk_count,
        );
        array.append_to_builder(builder.as_mut()).vortex_unwrap();

        let canonicalized = array.to_primitive().unwrap();
        let built = builder.finish().to_primitive().unwrap();

        assert_eq!(canonicalized.as_slice::<u64>(), built.as_slice::<u64>());
        assert_eq!(
            canonicalized.validity_mask().unwrap().boolean_buffer(),
            built.validity_mask().unwrap().boolean_buffer()
        );
    }
}