Skip to main content

vortex_array/arrays/dict/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6
7use vortex_buffer::BitBuffer;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_ensure;
12use vortex_mask::AllOr;
13
14use crate::ArrayRef;
15use crate::ToCanonical;
16use crate::array::Array;
17use crate::array::ArrayParts;
18use crate::array::TypedArrayRef;
19use crate::array_slots;
20use crate::arrays::Dict;
21use crate::dtype::DType;
22use crate::dtype::PType;
23use crate::match_each_integer_ptype;
24
/// Protobuf-encoded (prost) metadata for a dictionary-encoded array.
///
/// The `tag` numbers are protobuf field numbers and therefore part of the serialized format;
/// never renumber or reuse them. Fields added after stabilisation (tags 3 and 4) are
/// `optional` so that metadata written before they existed still decodes.
#[derive(Clone, prost::Message)]
pub struct DictMetadata {
    /// Presumably the number of entries in the dictionary values array (name-derived —
    /// confirm against the serializer).
    #[prost(uint32, tag = "1")]
    pub(super) values_len: u32,
    /// Primitive type of the codes, stored as the `i32` discriminant of `PType`.
    #[prost(enumeration = "PType", tag = "2")]
    pub(super) codes_ptype: i32,
    /// Whether the codes are nullable. Optional because nullable codes were added after
    /// stabilisation; metadata written before then carries `None`.
    #[prost(optional, bool, tag = "3")]
    pub(super) is_nullable_codes: Option<bool>,
    /// Whether all dictionary values are referenced, optional for backward compatibility.
    /// `Some(true)` means every dictionary value is definitely referenced by at least one code.
    /// `Some(false)` or `None` means it is unknown whether all values are referenced (the
    /// conservative default).
    #[prost(optional, bool, tag = "4")]
    pub(super) all_values_referenced: Option<bool>,
}
40
/// Child arrays (slots) of a dictionary-encoded array.
///
/// NOTE(review): the constructors in this file build the slot list as `[codes, values]`;
/// presumably the field order here must match that order — confirm against `array_slots`.
#[array_slots(Dict)]
pub struct DictSlots {
    /// The codes array mapping each element to a dictionary entry.
    pub codes: ArrayRef,
    /// The dictionary values array; each valid code is an index into it.
    pub values: ArrayRef,
}
48
/// Encoding-specific data carried by a dictionary-encoded (`Dict`) array.
#[derive(Debug, Clone)]
pub struct DictData {
    /// Whether all dictionary values are definitely referenced by at least one code.
    ///
    /// `true` means every value is referenced (established during encoding); `false` means
    /// this is unknown and some values may be unreferenced.
    ///
    /// If this flag is ever wrong, the consequence must only be semantically incorrect results
    /// (e.g. wrong min/max), never memory-unsafe behaviour: do not rely on it for memory safety.
    pub(super) all_values_referenced: bool,
}
58
59impl Display for DictData {
60    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
61        write!(f, "all_values_referenced: {}", self.all_values_referenced)
62    }
63}
64
65impl DictData {
66    /// Build a new `DictArray` without validating the codes or values.
67    ///
68    /// # Safety
69    /// This should be called only when you can guarantee the invariants checked
70    /// by the safe `DictArray::try_new` constructor are valid, for example when
71    /// you are filtering or slicing an existing valid `DictArray`.
72    pub unsafe fn new_unchecked() -> Self {
73        Self {
74            all_values_referenced: false,
75        }
76    }
77
78    /// Set whether all dictionary values are definitely referenced.
79    ///
80    /// # Safety
81    /// The caller must ensure that when setting `all_values_referenced = true`, ALL dictionary
82    /// values are actually referenced by at least one valid code. Setting this incorrectly can
83    /// lead to incorrect query results in operations like min/max.
84    ///
85    /// This is typically only set to `true` during dictionary encoding when we know for certain
86    /// that all values are referenced.
87    pub unsafe fn set_all_values_referenced(mut self, all_values_referenced: bool) -> Self {
88        self.all_values_referenced = all_values_referenced;
89        self
90    }
91
92    /// Build a new `DictArray` from its components, `codes` and `values`.
93    ///
94    /// This constructor will panic if `codes` or `values` do not pass validation for building
95    /// a new `DictArray`. See `DictArray::try_new` for a description of the error conditions.
96    pub fn new(codes_dtype: &DType) -> Self {
97        Self::try_new(codes_dtype).vortex_expect("DictArray new")
98    }
99
100    /// Build a new `DictArray` from its components, `codes` and `values`.
101    ///
102    /// The codes must be integers, and may be nullable. Values can be any
103    /// type, and may also be nullable. This mirrors the nullability of the Arrow `DictionaryArray`.
104    ///
105    /// # Errors
106    ///
107    /// The `codes` **must** be integers, and the maximum code must be less than the length
108    /// of the `values` array. Otherwise, this constructor returns an error.
109    ///
110    /// It is an error to provide a nullable `codes` with non-nullable `values`.
111    pub(crate) fn try_new(codes_dtype: &DType) -> VortexResult<Self> {
112        if !codes_dtype.is_int() {
113            vortex_bail!(MismatchedTypes: "int", codes_dtype);
114        }
115
116        Ok(unsafe { Self::new_unchecked() })
117    }
118}
119
120pub trait DictArrayExt: TypedArrayRef<Dict> + DictArraySlotsExt {
121    #[inline]
122    fn has_all_values_referenced(&self) -> bool {
123        self.all_values_referenced
124    }
125
126    fn validate_all_values_referenced(&self) -> VortexResult<()> {
127        if self.has_all_values_referenced() {
128            if !self.codes().is_host() {
129                return Ok(());
130            }
131
132            let referenced_mask = self.compute_referenced_values_mask(true)?;
133            let all_referenced = referenced_mask.iter().all(|v| v);
134
135            vortex_ensure!(all_referenced, "value in dict not referenced");
136        }
137
138        Ok(())
139    }
140
141    #[allow(
142        clippy::cognitive_complexity,
143        reason = "branching depends on validity representation and code type"
144    )]
145    fn compute_referenced_values_mask(&self, referenced: bool) -> VortexResult<BitBuffer> {
146        let codes_validity = self.codes().validity_mask()?;
147        let codes_primitive = self.codes().to_primitive();
148        let values_len = self.values().len();
149
150        let init_value = !referenced;
151        let referenced_value = referenced;
152
153        let mut values_vec = vec![init_value; values_len];
154        match codes_validity.bit_buffer() {
155            AllOr::All => {
156                match_each_integer_ptype!(codes_primitive.ptype(), |P| {
157                    #[allow(
158                        clippy::cast_possible_truncation,
159                        clippy::cast_sign_loss,
160                        reason = "codes are non-negative indices; a negative signed code would wrap to a large usize and panic on the bounds-checked array index"
161                    )]
162                    for &idx in codes_primitive.as_slice::<P>() {
163                        values_vec[idx as usize] = referenced_value;
164                    }
165                });
166            }
167            AllOr::None => {}
168            AllOr::Some(mask) => {
169                match_each_integer_ptype!(codes_primitive.ptype(), |P| {
170                    let codes = codes_primitive.as_slice::<P>();
171
172                    #[allow(
173                        clippy::cast_possible_truncation,
174                        clippy::cast_sign_loss,
175                        reason = "codes are non-negative indices; a negative signed code would wrap to a large usize and panic on the bounds-checked array index"
176                    )]
177                    mask.set_indices().for_each(|idx| {
178                        values_vec[codes[idx] as usize] = referenced_value;
179                    });
180                });
181            }
182        }
183
184        Ok(BitBuffer::from(values_vec))
185    }
186}
187impl<T: TypedArrayRef<Dict>> DictArrayExt for T {}
188
189impl Array<Dict> {
190    /// Build a new `DictArray` from its components, `codes` and `values`.
191    pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
192        Self::try_new(codes, values).vortex_expect("DictArray new")
193    }
194
195    /// Build a new `DictArray` from its components, `codes` and `values`.
196    pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
197        let dtype = values
198            .dtype()
199            .union_nullability(codes.dtype().nullability());
200        let len = codes.len();
201        let data = DictData::try_new(codes.dtype())?;
202        Array::try_from_parts(
203            ArrayParts::new(Dict, dtype, len, data).with_slots(vec![Some(codes), Some(values)]),
204        )
205    }
206
207    /// Build a new `DictArray` without validating the codes or values.
208    ///
209    /// # Safety
210    ///
211    /// See [`DictData::new_unchecked`].
212    pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
213        let dtype = values
214            .dtype()
215            .union_nullability(codes.dtype().nullability());
216        let len = codes.len();
217        let data = unsafe { DictData::new_unchecked() };
218        unsafe {
219            Array::from_parts_unchecked(
220                ArrayParts::new(Dict, dtype, len, data).with_slots(vec![Some(codes), Some(values)]),
221            )
222        }
223    }
224
225    /// Set whether all values in the dictionary are referenced by at least one code.
226    ///
227    /// # Safety
228    ///
229    /// See [`DictData::set_all_values_referenced`].
230    pub unsafe fn set_all_values_referenced(self, all_values_referenced: bool) -> Self {
231        let dtype = self.dtype().clone();
232        let len = self.len();
233        let slots = self.slots().to_vec();
234        let data = unsafe {
235            self.into_data()
236                .set_all_values_referenced(all_values_referenced)
237        };
238        let array = unsafe {
239            Array::from_parts_unchecked(ArrayParts::new(Dict, dtype, len, data).with_slots(slots))
240        };
241
242        #[cfg(debug_assertions)]
243        if all_values_referenced {
244            array
245                .validate_all_values_referenced()
246                .vortex_expect("validation should succeed when all values are referenced");
247        }
248
249        array
250    }
251}
252
#[cfg(test)]
mod test {
    use rand::RngExt;
    use rand::SeedableRng;
    use rand::distr::Distribution;
    use rand::distr::StandardUniform;
    use rand::prelude::StdRng;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;
    use vortex_error::vortex_panic;
    use vortex_mask::AllOr;

    use super::DictArrayExt;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::ChunkedArray;
    use crate::arrays::DictArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::dtype::DType;
    use crate::dtype::NativePType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType;
    use crate::dtype::UnsignedPType;
    use crate::validity::Validity;

    #[test]
    fn nullable_codes_validity() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0, 2, 4]);
    }

    #[test]
    fn nullable_values_validity() {
        let dict = DictArray::try_new(
            buffer![0u32, 1, 2, 2, 1].into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BitBuffer::from(vec![true, false, false])),
            )
            .into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0]);
    }

    #[test]
    fn nullable_codes_and_values() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BitBuffer::from(vec![false, true, true])),
            )
            .into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [2, 4]);
    }

    #[test]
    fn nullable_codes_and_non_null_values() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::NonNullable).into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0, 2, 4]);
    }

    /// Collect a `BitBuffer` into a `Vec<bool>` so it can be compared in assertions.
    fn mask_bits(mask: &BitBuffer) -> Vec<bool> {
        mask.iter().collect()
    }

    #[test]
    fn referenced_values_mask_non_nullable_codes() -> VortexResult<()> {
        let dict = DictArray::try_new(
            buffer![2u32, 0, 2].into_array(),
            buffer![10i32, 20, 30].into_array(),
        )?;

        assert_eq!(
            mask_bits(&dict.compute_referenced_values_mask(true)?),
            [true, false, true]
        );
        // `referenced = false` inverts the encoding, yielding the unreferenced values.
        assert_eq!(
            mask_bits(&dict.compute_referenced_values_mask(false)?),
            [false, true, false]
        );
        Ok(())
    }

    #[test]
    fn referenced_values_mask_ignores_null_codes() -> VortexResult<()> {
        // The code at position 1 points at value 3 but is null, so value 3 is unreferenced.
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 3, 1],
                Validity::from(BitBuffer::from(vec![true, false, true])),
            )
            .into_array(),
            buffer![10i32, 20, 30, 40].into_array(),
        )?;

        assert_eq!(
            mask_bits(&dict.compute_referenced_values_mask(true)?),
            [true, true, false, false]
        );
        assert_eq!(
            mask_bits(&dict.compute_referenced_values_mask(false)?),
            [false, false, true, true]
        );
        Ok(())
    }

    #[test]
    fn referenced_values_mask_all_null_codes() -> VortexResult<()> {
        let dict = DictArray::try_new(
            PrimitiveArray::new(buffer![0u32, 1], Validity::AllInvalid).into_array(),
            buffer![10i32, 20].into_array(),
        )?;

        assert_eq!(
            mask_bits(&dict.compute_referenced_values_mask(true)?),
            [false, false]
        );
        assert_eq!(
            mask_bits(&dict.compute_referenced_values_mask(false)?),
            [true, true]
        );
        Ok(())
    }

    #[test]
    fn validate_all_values_referenced_flag() -> VortexResult<()> {
        let codes = buffer![1u8, 0, 1].into_array();
        let values = buffer![5i32, 6].into_array();

        // Without the flag there is nothing to validate.
        let unflagged = DictArray::new(codes.clone(), values.clone());
        assert!(!unflagged.has_all_values_referenced());
        unflagged.validate_all_values_referenced()?;

        // SAFETY: both dictionary values (indices 0 and 1) are referenced by the codes above.
        let flagged = unsafe { DictArray::new(codes, values).set_all_values_referenced(true) };
        assert!(flagged.has_all_values_referenced());
        flagged.validate_all_values_referenced()?;
        Ok(())
    }

    fn make_dict_primitive_chunks<T: NativePType, Code: UnsignedPType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);

        (0..chunk_count)
            .map(|_| {
                let values = (0..unique_values)
                    .map(|_| rng.random::<T>())
                    .collect::<PrimitiveArray>();
                let codes = (0..len)
                    .map(|_| {
                        Code::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                    })
                    .collect::<PrimitiveArray>();

                DictArray::try_new(codes.into_array(), values.into_array())
                    .vortex_expect("DictArray creation should succeed in arbitrary impl")
                    .into_array()
            })
            .collect::<ChunkedArray>()
            .into_array()
    }

    #[test]
    fn test_dict_array_from_primitive_chunks() -> VortexResult<()> {
        let len = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let mut builder = builder_with_capacity(
            &DType::Primitive(PType::U64, NonNullable),
            len * chunk_count,
        );
        array.append_to_builder(builder.as_mut(), &mut LEGACY_SESSION.create_execution_ctx())?;

        let into_prim = array.to_primitive();
        let prim_into = builder.finish_into_canonical().into_primitive();

        assert_arrays_eq!(into_prim, prim_into);
        Ok(())
    }

    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        use prost::Message;

        use super::DictMetadata;
        use crate::test_harness::check_metadata;

        check_metadata(
            "dict.metadata",
            &DictMetadata {
                codes_ptype: PType::U64 as i32,
                values_len: u32::MAX,
                is_nullable_codes: None,
                all_values_referenced: None,
            }
            .encode_to_vec(),
        );
    }
}