Skip to main content

vortex_array/arrays/dict/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6
7use vortex_buffer::BitBuffer;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_ensure;
12use vortex_mask::AllOr;
13
14use crate::ArrayRef;
15use crate::LEGACY_SESSION;
16use crate::ToCanonical;
17use crate::VortexSessionExecute;
18use crate::array::Array;
19use crate::array::ArrayParts;
20use crate::array::TypedArrayRef;
21use crate::array_slots;
22use crate::arrays::Dict;
23use crate::dtype::DType;
24use crate::dtype::PType;
25use crate::match_each_integer_ptype;
26
/// Protobuf-encoded metadata for a dictionary array.
///
/// The `tag` numbers below identify each field in the encoded form. Fields 3 and 4 are declared
/// `optional` so that metadata written without them still decodes (they come back as `None`).
#[derive(Clone, prost::Message)]
pub struct DictMetadata {
    /// Length of the dictionary values array.
    #[prost(uint32, tag = "1")]
    pub(super) values_len: u32,
    /// The `PType` of the codes, stored as the `i32` value of the protobuf enumeration.
    #[prost(enumeration = "PType", tag = "2")]
    pub(super) codes_ptype: i32,
    // Whether the codes are nullable. Optional because nullable codes were added after the
    // format was stabilised.
    #[prost(optional, bool, tag = "3")]
    pub(super) is_nullable_codes: Option<bool>,
    // `all_values_referenced` is optional for backward compatibility.
    // `Some(true)` = all dictionary values are definitely referenced by at least one code.
    // `Some(false)` / `None` = unknown whether all values are referenced (conservative default).
    #[prost(optional, bool, tag = "4")]
    pub(super) all_values_referenced: Option<bool>,
}
42
// The child arrays ("slots") of a dictionary array. The field order (codes, then values) matches
// the order in which the constructors in this file pass the slots to `with_slots`.
// NOTE(review): `array_slots` presumably generates the `DictArraySlotsExt` accessors
// (`codes()` / `values()`) used further down in this file — confirm against the macro.
#[array_slots(Dict)]
pub struct DictSlots {
    /// The codes array: one integer per element, selecting an entry of `values`.
    pub codes: ArrayRef,
    /// The dictionary values array containing the unique values.
    pub values: ArrayRef,
}
50
/// Per-array data carried by a dictionary array in addition to its `codes` and `values` slots.
#[derive(Debug, Clone)]
pub struct DictData {
    /// Indicates whether all dictionary values are definitely referenced by at least one code.
    ///
    /// * `true` = all values are referenced (computed during encoding).
    /// * `false` = unknown; the dictionary might contain unreferenced values.
    ///
    /// This flag may turn out to be wrong, so it must never be used to justify memory-unsafe
    /// behaviour: relying on an incorrect flag may only ever produce semantically incorrect
    /// results.
    pub(super) all_values_referenced: bool,
}
60
61impl Display for DictData {
62    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
63        write!(f, "all_values_referenced: {}", self.all_values_referenced)
64    }
65}
66
67impl DictData {
68    /// Build a new `DictArray` without validating the codes or values.
69    ///
70    /// # Safety
71    /// This should be called only when you can guarantee the invariants checked
72    /// by the safe `DictArray::try_new` constructor are valid, for example when
73    /// you are filtering or slicing an existing valid `DictArray`.
74    pub unsafe fn new_unchecked() -> Self {
75        Self {
76            all_values_referenced: false,
77        }
78    }
79
80    /// Set whether all dictionary values are definitely referenced.
81    ///
82    /// # Safety
83    /// The caller must ensure that when setting `all_values_referenced = true`, ALL dictionary
84    /// values are actually referenced by at least one valid code. Setting this incorrectly can
85    /// lead to incorrect query results in operations like min/max.
86    ///
87    /// This is typically only set to `true` during dictionary encoding when we know for certain
88    /// that all values are referenced.
89    pub unsafe fn set_all_values_referenced(mut self, all_values_referenced: bool) -> Self {
90        self.all_values_referenced = all_values_referenced;
91        self
92    }
93
94    /// Build a new `DictArray` from its components, `codes` and `values`.
95    ///
96    /// This constructor will panic if `codes` or `values` do not pass validation for building
97    /// a new `DictArray`. See `DictArray::try_new` for a description of the error conditions.
98    pub fn new(codes_dtype: &DType) -> Self {
99        Self::try_new(codes_dtype).vortex_expect("DictArray new")
100    }
101
102    /// Build a new `DictArray` from its components, `codes` and `values`.
103    ///
104    /// The codes must be integers, and may be nullable. Values can be any
105    /// type, and may also be nullable. This mirrors the nullability of the Arrow `DictionaryArray`.
106    ///
107    /// # Errors
108    ///
109    /// The `codes` **must** be integers, and the maximum code must be less than the length
110    /// of the `values` array. Otherwise, this constructor returns an error.
111    ///
112    /// It is an error to provide a nullable `codes` with non-nullable `values`.
113    pub(crate) fn try_new(codes_dtype: &DType) -> VortexResult<Self> {
114        if !codes_dtype.is_int() {
115            vortex_bail!(MismatchedTypes: "int", codes_dtype);
116        }
117
118        Ok(unsafe { Self::new_unchecked() })
119    }
120}
121
122pub trait DictArrayExt: TypedArrayRef<Dict> + DictArraySlotsExt {
123    #[inline]
124    fn has_all_values_referenced(&self) -> bool {
125        self.all_values_referenced
126    }
127
128    fn validate_all_values_referenced(&self) -> VortexResult<()> {
129        if self.has_all_values_referenced() {
130            if !self.codes().is_host() {
131                return Ok(());
132            }
133
134            let referenced_mask = self.compute_referenced_values_mask(true)?;
135            let all_referenced = referenced_mask.iter().all(|v| v);
136
137            vortex_ensure!(all_referenced, "value in dict not referenced");
138        }
139
140        Ok(())
141    }
142
143    fn compute_referenced_values_mask(&self, referenced: bool) -> VortexResult<BitBuffer> {
144        let codes = self.codes();
145        let codes_validity = codes
146            .validity()?
147            .to_mask(codes.len(), &mut LEGACY_SESSION.create_execution_ctx())?;
148        let codes_primitive = self.codes().to_primitive();
149        let values_len = self.values().len();
150
151        let init_value = !referenced;
152        let referenced_value = referenced;
153
154        let mut values_vec = vec![init_value; values_len];
155        match codes_validity.bit_buffer() {
156            AllOr::All => {
157                match_each_integer_ptype!(codes_primitive.ptype(), |P| {
158                    #[allow(
159                        clippy::cast_possible_truncation,
160                        clippy::cast_sign_loss,
161                        reason = "codes are non-negative indices; a negative signed code would wrap to a large usize and panic on the bounds-checked array index"
162                    )]
163                    for &idx in codes_primitive.as_slice::<P>() {
164                        values_vec[idx as usize] = referenced_value;
165                    }
166                });
167            }
168            AllOr::None => {}
169            AllOr::Some(mask) => {
170                match_each_integer_ptype!(codes_primitive.ptype(), |P| {
171                    let codes = codes_primitive.as_slice::<P>();
172
173                    #[allow(
174                        clippy::cast_possible_truncation,
175                        clippy::cast_sign_loss,
176                        reason = "codes are non-negative indices; a negative signed code would wrap to a large usize and panic on the bounds-checked array index"
177                    )]
178                    mask.set_indices().for_each(|idx| {
179                        values_vec[codes[idx] as usize] = referenced_value;
180                    });
181                });
182            }
183        }
184
185        Ok(BitBuffer::from(values_vec))
186    }
187}
188impl<T: TypedArrayRef<Dict>> DictArrayExt for T {}
189
190impl Array<Dict> {
191    /// Build a new `DictArray` from its components, `codes` and `values`.
192    pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
193        Self::try_new(codes, values).vortex_expect("DictArray new")
194    }
195
196    /// Build a new `DictArray` from its components, `codes` and `values`.
197    pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
198        let dtype = values
199            .dtype()
200            .union_nullability(codes.dtype().nullability());
201        let len = codes.len();
202        let data = DictData::try_new(codes.dtype())?;
203        Array::try_from_parts(
204            ArrayParts::new(Dict, dtype, len, data).with_slots(vec![Some(codes), Some(values)]),
205        )
206    }
207
208    /// Build a new `DictArray` without validating the codes or values.
209    ///
210    /// # Safety
211    ///
212    /// See [`DictData::new_unchecked`].
213    pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
214        let dtype = values
215            .dtype()
216            .union_nullability(codes.dtype().nullability());
217        let len = codes.len();
218        let data = unsafe { DictData::new_unchecked() };
219        unsafe {
220            Array::from_parts_unchecked(
221                ArrayParts::new(Dict, dtype, len, data).with_slots(vec![Some(codes), Some(values)]),
222            )
223        }
224    }
225
226    /// Set whether all values in the dictionary are referenced by at least one code.
227    ///
228    /// # Safety
229    ///
230    /// See [`DictData::set_all_values_referenced`].
231    pub unsafe fn set_all_values_referenced(self, all_values_referenced: bool) -> Self {
232        let dtype = self.dtype().clone();
233        let len = self.len();
234        let slots = self.slots().to_vec();
235        let data = unsafe {
236            self.into_data()
237                .set_all_values_referenced(all_values_referenced)
238        };
239        let array = unsafe {
240            Array::from_parts_unchecked(ArrayParts::new(Dict, dtype, len, data).with_slots(slots))
241        };
242
243        #[cfg(debug_assertions)]
244        if all_values_referenced {
245            array
246                .validate_all_values_referenced()
247                .vortex_expect("validation should succeed when all values are referenced");
248        }
249
250        array
251    }
252}
253
#[cfg(test)]
mod test {
    #[expect(unused_imports)]
    use itertools::Itertools;
    use rand::RngExt;
    use rand::SeedableRng;
    use rand::distr::Distribution;
    use rand::distr::StandardUniform;
    use rand::prelude::StdRng;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;
    use vortex_error::vortex_panic;
    use vortex_mask::AllOr;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::ChunkedArray;
    use crate::arrays::DictArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::dtype::DType;
    use crate::dtype::NativePType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType;
    use crate::dtype::UnsignedPType;
    use crate::validity::Validity;

    /// Codes `[0, 1, 2, 2, 1]` where only positions 0, 2 and 4 are valid.
    fn alternating_nullable_codes() -> ArrayRef {
        PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
        )
        .into_array()
    }

    /// Returns the positions of the valid elements of `dict`.
    ///
    /// Panics unless the validity mask is a partial one (`AllOr::Some`); every caller builds an
    /// array that mixes valid and null elements.
    fn valid_indices(dict: &DictArray) -> Vec<usize> {
        let array = dict.as_ref();
        let mask = array
            .validity()
            .unwrap()
            .to_mask(array.len(), &mut LEGACY_SESSION.create_execution_ctx())
            .unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        indices.to_vec()
    }

    /// Null codes make the corresponding elements null even when all values are valid.
    #[test]
    fn nullable_codes_validity() {
        let dict = DictArray::try_new(
            alternating_nullable_codes(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array(),
        )
        .unwrap();
        assert_eq!(valid_indices(&dict), [0, 2, 4]);
    }

    /// With non-nullable codes, an element is null exactly when its value is null.
    #[test]
    fn nullable_values_validity() {
        let dict = DictArray::try_new(
            buffer![0u32, 1, 2, 2, 1].into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BitBuffer::from(vec![true, false, false])),
            )
            .into_array(),
        )
        .unwrap();
        assert_eq!(valid_indices(&dict), [0]);
    }

    /// An element is valid only when both its code and the value it points to are valid.
    #[test]
    fn nullable_codes_and_values() {
        let dict = DictArray::try_new(
            alternating_nullable_codes(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BitBuffer::from(vec![false, true, true])),
            )
            .into_array(),
        )
        .unwrap();
        assert_eq!(valid_indices(&dict), [2, 4]);
    }

    /// Nullable codes are accepted together with non-nullable values.
    #[test]
    fn nullable_codes_and_non_null_values() {
        let dict = DictArray::try_new(
            alternating_nullable_codes(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::NonNullable).into_array(),
        )
        .unwrap();
        assert_eq!(valid_indices(&dict), [0, 2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dictionary arrays, each with `len` random codes
    /// into `unique_values` random values. The RNG is seeded with a fixed value, so the output
    /// is deterministic.
    fn make_dict_primitive_chunks<T: NativePType, Code: UnsignedPType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);

        (0..chunk_count)
            .map(|_| {
                // Values are drawn before codes so the RNG stream is consumed in a fixed order.
                let values = (0..unique_values)
                    .map(|_| rng.random::<T>())
                    .collect::<PrimitiveArray>();
                let codes = (0..len)
                    .map(|_| {
                        Code::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                    })
                    .collect::<PrimitiveArray>();

                DictArray::try_new(codes.into_array(), values.into_array())
                    .vortex_expect("DictArray creation should succeed in arbitrary impl")
                    .into_array()
            })
            .collect::<ChunkedArray>()
            .into_array()
    }

    /// Appending the chunked dict array to a builder must match canonicalising it directly.
    #[test]
    fn test_dict_array_from_primitive_chunks() -> VortexResult<()> {
        let len = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let mut builder = builder_with_capacity(
            &DType::Primitive(PType::U64, NonNullable),
            len * chunk_count,
        );
        array.append_to_builder(builder.as_mut(), &mut LEGACY_SESSION.create_execution_ctx())?;

        let into_prim = array.to_primitive();
        let prim_into = builder.finish_into_canonical().into_primitive();

        assert_arrays_eq!(into_prim, prim_into);
        Ok(())
    }

    /// Passes the encoded metadata bytes to `check_metadata` under the name "dict.metadata".
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        use prost::Message;

        use super::DictMetadata;
        use crate::test_harness::check_metadata;

        check_metadata(
            "dict.metadata",
            &DictMetadata {
                codes_ptype: PType::U64 as i32,
                values_len: u32::MAX,
                is_nullable_codes: None,
                all_values_referenced: None,
            }
            .encode_to_vec(),
        );
    }
}