vortex_dict/
array.rs

1use std::fmt::Debug;
2
3use arrow_buffer::BooleanBuffer;
4use vortex_array::builders::ArrayBuilder;
5use vortex_array::compute::{scalar_at, take, take_into, try_cast};
6use vortex_array::stats::{ArrayStats, StatsSetRef};
7use vortex_array::variants::PrimitiveArrayTrait;
8use vortex_array::vtable::{EncodingVTable, VTableRef};
9use vortex_array::{
10    Array, ArrayCanonicalImpl, ArrayImpl, ArrayRef, ArrayStatisticsImpl, ArrayValidityImpl,
11    Canonical, Encoding, EncodingId, IntoArray, RkyvMetadata, ToCanonical,
12};
13use vortex_dtype::{DType, match_each_integer_ptype};
14use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
15use vortex_mask::{AllOr, Mask};
16
17use crate::serde::DictMetadata;
18
/// A dictionary-encoded array: logically, element `i` is `values[codes[i]]`.
#[derive(Debug, Clone)]
pub struct DictArray {
    // Unsigned integer indices into `values`, one per logical element.
    codes: ArrayRef,
    // The dictionary entries that `codes` index into; also supplies the array's dtype.
    values: ArrayRef,
    // Statistics for this array, exposed through `_stats_ref`.
    stats_set: ArrayStats,
}
25
/// Encoding marker type for [`DictArray`].
pub struct DictEncoding;
impl Encoding for DictEncoding {
    type Array = DictArray;
    // Serialized metadata for a dict array (see `crate::serde::DictMetadata`).
    type Metadata = RkyvMetadata<DictMetadata>;
}

impl EncodingVTable for DictEncoding {
    /// Returns the encoding id `"vortex.dict"`.
    fn id(&self) -> EncodingId {
        EncodingId::new_ref("vortex.dict")
    }
}
37
38impl DictArray {
39    pub fn try_new(mut codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
40        if !codes.dtype().is_unsigned_int() {
41            vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
42        }
43
44        let dtype = values.dtype();
45        if dtype.is_nullable() {
46            // If the values are nullable, we force codes to be nullable as well.
47            codes = try_cast(&codes, &codes.dtype().as_nullable())?;
48        } else {
49            // If the values are non-nullable, we assert the codes are non-nullable as well.
50            if codes.dtype().is_nullable() {
51                vortex_bail!("Cannot have nullable codes for non-nullable dict array");
52            }
53        }
54        assert_eq!(
55            codes.dtype().nullability(),
56            values.dtype().nullability(),
57            "Mismatched nullability between codes and values"
58        );
59
60        Ok(Self {
61            codes,
62            values,
63            stats_set: Default::default(),
64        })
65    }
66
67    #[inline]
68    pub fn codes(&self) -> &ArrayRef {
69        &self.codes
70    }
71
72    #[inline]
73    pub fn values(&self) -> &ArrayRef {
74        &self.values
75    }
76}
77
78impl ArrayImpl for DictArray {
79    type Encoding = DictEncoding;
80
81    fn _len(&self) -> usize {
82        self.codes.len()
83    }
84
85    fn _dtype(&self) -> &DType {
86        self.values.dtype()
87    }
88
89    fn _vtable(&self) -> VTableRef {
90        VTableRef::new_ref(&DictEncoding)
91    }
92}
93
94impl ArrayCanonicalImpl for DictArray {
95    fn _to_canonical(&self) -> VortexResult<Canonical> {
96        match self.dtype() {
97            // NOTE: Utf8 and Binary will decompress into VarBinViewArray, which requires a full
98            // decompression to construct the views child array.
99            // For this case, it is *always* faster to decompress the values first and then create
100            // copies of the view pointers.
101            DType::Utf8(_) | DType::Binary(_) => {
102                let canonical_values: ArrayRef = self.values().to_canonical()?.into_array();
103                take(&canonical_values, self.codes())?.to_canonical()
104            }
105            _ => take(self.values(), self.codes())?.to_canonical(),
106        }
107    }
108
109    fn _append_to_builder(&self, builder: &mut dyn ArrayBuilder) -> VortexResult<()> {
110        match self.dtype() {
111            // NOTE: Utf8 and Binary will decompress into VarBinViewArray, which requires a full
112            // decompression to construct the views child array.
113            // For this case, it is *always* faster to decompress the values first and then create
114            // copies of the view pointers.
115            // TODO(joe): is the above still true?, investigate this.
116            DType::Utf8(_) | DType::Binary(_) => {
117                let canonical_values: ArrayRef = self.values().to_canonical()?.into_array();
118                take_into(&canonical_values, self.codes(), builder)
119            }
120            // Non-string case: take and then canonicalize
121            _ => take_into(self.values(), self.codes(), builder),
122        }
123    }
124}
125
126impl ArrayValidityImpl for DictArray {
127    fn _is_valid(&self, index: usize) -> VortexResult<bool> {
128        let scalar = scalar_at(self.codes(), index).map_err(|err| {
129            err.with_context(format!(
130                "Failed to get index {} from DictArray codes",
131                index
132            ))
133        })?;
134
135        if scalar.is_null() {
136            return Ok(false);
137        };
138        let values_index: usize = scalar
139            .as_ref()
140            .try_into()
141            .vortex_expect("Failed to convert dictionary code to usize");
142        self.values().is_valid(values_index)
143    }
144
145    fn _all_valid(&self) -> VortexResult<bool> {
146        if !self.dtype().is_nullable() {
147            return Ok(true);
148        }
149
150        Ok(self.codes().all_valid()? && self.values().all_valid()?)
151    }
152
153    fn _all_invalid(&self) -> VortexResult<bool> {
154        if !self.dtype().is_nullable() {
155            return Ok(false);
156        }
157
158        Ok(self.codes().all_invalid()? || self.values().all_invalid()?)
159    }
160
161    fn _validity_mask(&self) -> VortexResult<Mask> {
162        let codes_validity = self.codes().validity_mask()?;
163        match codes_validity.boolean_buffer() {
164            AllOr::All => {
165                let primitive_codes = self.codes().to_primitive()?;
166                let values_mask = self.values().validity_mask()?;
167                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |$P| {
168                    let codes_slice = primitive_codes.as_slice::<$P>();
169                    BooleanBuffer::collect_bool(self.len(), |idx| {
170                       values_mask.value(codes_slice[idx] as usize)
171                    })
172                });
173                Ok(Mask::from_buffer(is_valid_buffer))
174            }
175            AllOr::None => Ok(Mask::AllFalse(self.len())),
176            AllOr::Some(validity_buff) => {
177                let primitive_codes = self.codes().to_primitive()?;
178                let values_mask = self.values().validity_mask()?;
179                let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |$P| {
180                    let codes_slice = primitive_codes.as_slice::<$P>();
181                    BooleanBuffer::collect_bool(self.len(), |idx| {
182                       validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
183                    })
184                });
185                Ok(Mask::from_buffer(is_valid_buffer))
186            }
187        }
188    }
189}
190
191impl ArrayStatisticsImpl for DictArray {
192    fn _stats_ref(&self) -> StatsSetRef<'_> {
193        self.stats_set.to_ref(self)
194    }
195}
196
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rand::distr::{Distribution, StandardUniform};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};
    use vortex_array::arrays::{ChunkedArray, PrimitiveArray};
    use vortex_array::builders::builder_with_capacity;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, NativePType, PType};
    use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
    use vortex_mask::AllOr;

    use crate::DictArray;

    /// Null codes alone determine validity when every value is valid.
    #[test]
    fn nullable_codes_validity() {
        let dict = DictArray::try_new(
            // Codes at positions 1 and 3 are null.
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0, 2, 4]);
    }

    /// Null values propagate to every element whose (non-null) code points at them.
    #[test]
    fn nullable_values_validity() {
        let dict = DictArray::try_new(
            buffer![0u32, 1, 2, 2, 1].into_array(),
            // Only value 0 is valid, so only elements with code 0 are valid.
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BooleanBuffer::from(vec![true, false, false])),
            )
            .into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0]);
    }

    /// An element is valid only when both its code and the referenced value are valid.
    #[test]
    fn nullable_codes_and_values() {
        let dict = DictArray::try_new(
            // Valid codes: position 0 -> 0, position 2 -> 2, position 4 -> 1.
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            // Value 0 is null, so position 0 is invalid despite its valid code.
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BooleanBuffer::from(vec![false, true, true])),
            )
            .into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [2, 4]);
    }

    /// Builds a `ChunkedArray` of `chunk_count` dict chunks.
    ///
    /// Each chunk has `unique_values` random values of type `T` and `len` random codes of type
    /// `U` drawn from `0..unique_values`. The RNG is seeded with 0, so output is deterministic.
    fn make_dict_primitive_chunks<T: NativePType, U: NativePType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);

        (0..chunk_count)
            .map(|_| {
                let values = (0..unique_values)
                    .map(|_| rng.random::<T>())
                    .collect::<PrimitiveArray>();
                let codes = (0..len)
                    .map(|_| {
                        U::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                    })
                    .collect::<PrimitiveArray>();

                DictArray::try_new(codes.into_array(), values.into_array())
                    .vortex_unwrap()
                    .into_array()
            })
            .collect::<ChunkedArray>()
            .into_array()
    }

    /// Appending chunked dict arrays to a builder must match direct canonicalization, in both
    /// values and validity.
    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let len = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let mut builder = builder_with_capacity(
            &DType::Primitive(PType::U64, NonNullable),
            len * chunk_count,
        );
        array
            .clone()
            .append_to_builder(builder.as_mut())
            .vortex_unwrap();

        let into_prim = array.to_primitive().unwrap();
        let prim_into = builder.finish().to_primitive().unwrap();

        assert_eq!(into_prim.as_slice::<u64>(), prim_into.as_slice::<u64>());
        assert_eq!(
            into_prim.validity_mask().unwrap().boolean_buffer(),
            prim_into.validity_mask().unwrap().boolean_buffer()
        )
    }
}