Skip to main content

vortex_array/arrays/dict/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6
7use vortex_buffer::BitBuffer;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_ensure;
12use vortex_mask::AllOr;
13
14use crate::ArrayRef;
15use crate::LEGACY_SESSION;
16#[expect(deprecated)]
17use crate::ToCanonical as _;
18use crate::VortexSessionExecute;
19use crate::array::Array;
20use crate::array::ArrayParts;
21use crate::array::TypedArrayRef;
22use crate::array_slots;
23use crate::arrays::Dict;
24use crate::dtype::DType;
25use crate::dtype::PType;
26use crate::match_each_integer_ptype;
27
28#[derive(Clone, prost::Message)]
29pub struct DictMetadata {
30    #[prost(uint32, tag = "1")]
31    pub(super) values_len: u32,
32    #[prost(enumeration = "PType", tag = "2")]
33    pub(super) codes_ptype: i32,
34    // nullable codes are optional since they were added after stabilisation.
35    #[prost(optional, bool, tag = "3")]
36    pub(super) is_nullable_codes: Option<bool>,
37    // all_values_referenced is optional for backward compatibility.
38    // true = all dictionary values are definitely referenced by at least one code.
39    // false/None = unknown whether all values are referenced (conservative default).
40    #[prost(optional, bool, tag = "4")]
41    pub(super) all_values_referenced: Option<bool>,
42}
43
44#[array_slots(Dict)]
45pub struct DictSlots {
46    /// The codes array mapping each element to a dictionary entry.
47    pub codes: ArrayRef,
48    /// The dictionary values array containing the unique values.
49    pub values: ArrayRef,
50}
51
52#[derive(Debug, Clone)]
53pub struct DictData {
54    /// Indicates whether all dictionary values are definitely referenced by at least one code.
55    /// `true` = all values are referenced (computed during encoding).
56    /// `false` = unknown/might have unreferenced values.
57    /// In case this is incorrect never use this to enable memory unsafe behaviour just semantically
58    /// incorrect behaviour.
59    pub(super) all_values_referenced: bool,
60}
61
62impl Display for DictData {
63    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
64        write!(f, "all_values_referenced: {}", self.all_values_referenced)
65    }
66}
67
68impl DictData {
69    /// Build a new `DictArray` without validating the codes or values.
70    ///
71    /// # Safety
72    /// This should be called only when you can guarantee the invariants checked
73    /// by the safe `DictArray::try_new` constructor are valid, for example when
74    /// you are filtering or slicing an existing valid `DictArray`.
75    pub unsafe fn new_unchecked() -> Self {
76        Self {
77            all_values_referenced: false,
78        }
79    }
80
81    /// Set whether all dictionary values are definitely referenced.
82    ///
83    /// # Safety
84    /// The caller must ensure that when setting `all_values_referenced = true`, ALL dictionary
85    /// values are actually referenced by at least one valid code. Setting this incorrectly can
86    /// lead to incorrect query results in operations like min/max.
87    ///
88    /// This is typically only set to `true` during dictionary encoding when we know for certain
89    /// that all values are referenced.
90    pub unsafe fn set_all_values_referenced(mut self, all_values_referenced: bool) -> Self {
91        self.all_values_referenced = all_values_referenced;
92        self
93    }
94
95    /// Build a new `DictArray` from its components, `codes` and `values`.
96    ///
97    /// This constructor will panic if `codes` or `values` do not pass validation for building
98    /// a new `DictArray`. See `DictArray::try_new` for a description of the error conditions.
99    pub fn new(codes_dtype: &DType) -> Self {
100        Self::try_new(codes_dtype).vortex_expect("DictArray new")
101    }
102
103    /// Build a new `DictArray` from its components, `codes` and `values`.
104    ///
105    /// The codes must be integers, and may be nullable. Values can be any
106    /// type, and may also be nullable. This mirrors the nullability of the Arrow `DictionaryArray`.
107    ///
108    /// # Errors
109    ///
110    /// The `codes` **must** be integers, and the maximum code must be less than the length
111    /// of the `values` array. Otherwise, this constructor returns an error.
112    ///
113    /// It is an error to provide a nullable `codes` with non-nullable `values`.
114    pub(crate) fn try_new(codes_dtype: &DType) -> VortexResult<Self> {
115        if !codes_dtype.is_int() {
116            vortex_bail!(MismatchedTypes: "int", codes_dtype);
117        }
118
119        Ok(unsafe { Self::new_unchecked() })
120    }
121}
122
123pub trait DictArrayExt: TypedArrayRef<Dict> + DictArraySlotsExt {
124    #[inline]
125    fn has_all_values_referenced(&self) -> bool {
126        self.all_values_referenced
127    }
128
129    fn validate_all_values_referenced(&self) -> VortexResult<()> {
130        if self.has_all_values_referenced() {
131            if !self.codes().is_host() {
132                return Ok(());
133            }
134
135            let referenced_mask = self.compute_referenced_values_mask(true)?;
136            let all_referenced = referenced_mask.iter().all(|v| v);
137
138            vortex_ensure!(all_referenced, "value in dict not referenced");
139        }
140
141        Ok(())
142    }
143
144    fn compute_referenced_values_mask(&self, referenced: bool) -> VortexResult<BitBuffer> {
145        let codes = self.codes();
146        let codes_validity = codes
147            .validity()?
148            .execute_mask(codes.len(), &mut LEGACY_SESSION.create_execution_ctx())?;
149        #[expect(deprecated)]
150        let codes_primitive = self.codes().to_primitive();
151        let values_len = self.values().len();
152
153        let init_value = !referenced;
154        let referenced_value = referenced;
155
156        let mut values_vec = vec![init_value; values_len];
157        match codes_validity.bit_buffer() {
158            AllOr::All => {
159                match_each_integer_ptype!(codes_primitive.ptype(), |P| {
160                    #[allow(
161                        clippy::cast_possible_truncation,
162                        clippy::cast_sign_loss,
163                        reason = "codes are non-negative indices; a negative signed code would wrap to a large usize and panic on the bounds-checked array index"
164                    )]
165                    for &idx in codes_primitive.as_slice::<P>() {
166                        values_vec[idx as usize] = referenced_value;
167                    }
168                });
169            }
170            AllOr::None => {}
171            AllOr::Some(mask) => {
172                match_each_integer_ptype!(codes_primitive.ptype(), |P| {
173                    let codes = codes_primitive.as_slice::<P>();
174
175                    #[allow(
176                        clippy::cast_possible_truncation,
177                        clippy::cast_sign_loss,
178                        reason = "codes are non-negative indices; a negative signed code would wrap to a large usize and panic on the bounds-checked array index"
179                    )]
180                    mask.set_indices().for_each(|idx| {
181                        values_vec[codes[idx] as usize] = referenced_value;
182                    });
183                });
184            }
185        }
186
187        Ok(BitBuffer::from(values_vec))
188    }
189}
190impl<T: TypedArrayRef<Dict>> DictArrayExt for T {}
191
192/// Concrete parts of a [`DictArray`](super::DictArray) after iterative execution.
193pub struct DictParts {
194    pub dtype: DType,
195    pub codes: ArrayRef,
196    pub values: ArrayRef,
197}
198
199pub trait DictOwnedExt {
200    fn into_parts(self) -> DictParts;
201}
202
203impl DictOwnedExt for Array<Dict> {
204    fn into_parts(self) -> DictParts {
205        match self.try_into_parts() {
206            Ok(array_parts) => {
207                let slots = DictSlots::from_slots(array_parts.slots);
208                DictParts {
209                    dtype: array_parts.dtype,
210                    codes: slots.codes,
211                    values: slots.values,
212                }
213            }
214            Err(array) => {
215                let slots = DictSlotsView::from_slots(array.slots());
216                DictParts {
217                    dtype: array.dtype().clone(),
218                    codes: slots.codes.clone(),
219                    values: slots.values.clone(),
220                }
221            }
222        }
223    }
224}
225
226impl Array<Dict> {
227    /// Build a new `DictArray` from its components, `codes` and `values`.
228    pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
229        Self::try_new(codes, values).vortex_expect("DictArray new")
230    }
231
232    /// Build a new `DictArray` from its components, `codes` and `values`.
233    pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
234        let dtype = values
235            .dtype()
236            .union_nullability(codes.dtype().nullability());
237        let len = codes.len();
238        let data = DictData::try_new(codes.dtype())?;
239        Array::try_from_parts(
240            ArrayParts::new(Dict, dtype, len, data).with_slots(vec![Some(codes), Some(values)]),
241        )
242    }
243
244    /// Build a new `DictArray` without validating the codes or values.
245    ///
246    /// # Safety
247    ///
248    /// See [`DictData::new_unchecked`].
249    pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
250        let dtype = values
251            .dtype()
252            .union_nullability(codes.dtype().nullability());
253        let len = codes.len();
254        let data = unsafe { DictData::new_unchecked() };
255        unsafe {
256            Array::from_parts_unchecked(
257                ArrayParts::new(Dict, dtype, len, data).with_slots(vec![Some(codes), Some(values)]),
258            )
259        }
260    }
261
262    /// Set whether all values in the dictionary are referenced by at least one code.
263    ///
264    /// # Safety
265    ///
266    /// See [`DictData::set_all_values_referenced`].
267    pub unsafe fn set_all_values_referenced(self, all_values_referenced: bool) -> Self {
268        let dtype = self.dtype().clone();
269        let len = self.len();
270        let slots = self.slots().to_vec();
271        let data = unsafe {
272            self.into_data()
273                .set_all_values_referenced(all_values_referenced)
274        };
275        let array = unsafe {
276            Array::from_parts_unchecked(ArrayParts::new(Dict, dtype, len, data).with_slots(slots))
277        };
278
279        #[cfg(debug_assertions)]
280        if all_values_referenced {
281            array
282                .validate_all_values_referenced()
283                .vortex_expect("validation should succeed when all values are referenced");
284        }
285
286        array
287    }
288}
289
#[cfg(test)]
mod test {
    use rand::RngExt;
    use rand::SeedableRng;
    use rand::distr::Distribution;
    use rand::distr::StandardUniform;
    use rand::prelude::StdRng;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;
    use vortex_error::vortex_panic;
    use vortex_mask::AllOr;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::arrays::ChunkedArray;
    use crate::arrays::DictArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::dtype::DType;
    use crate::dtype::NativePType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType;
    use crate::dtype::UnsignedPType;
    use crate::validity::Validity;

    /// Asserts that the valid (non-null) positions of `dict` are exactly `expected`.
    fn assert_valid_indices(dict: DictArray, expected: &[usize]) {
        let mask = dict
            .as_ref()
            .validity()
            .unwrap()
            .execute_mask(
                dict.as_ref().len(),
                &mut LEGACY_SESSION.create_execution_ctx(),
            )
            .unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, expected);
    }

    #[test]
    fn nullable_codes_validity() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
        );
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid);
        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();
        assert_valid_indices(dict, &[0, 2, 4]);
    }

    #[test]
    fn nullable_values_validity() {
        let codes = buffer![0u32, 1, 2, 2, 1].into_array();
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            Validity::from(BitBuffer::from(vec![true, false, false])),
        );
        let dict = DictArray::try_new(codes, values.into_array()).unwrap();
        assert_valid_indices(dict, &[0]);
    }

    #[test]
    fn nullable_codes_and_values() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
        );
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            Validity::from(BitBuffer::from(vec![false, true, true])),
        );
        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();
        assert_valid_indices(dict, &[2, 4]);
    }

    #[test]
    fn nullable_codes_and_non_null_values() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
        );
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::NonNullable);
        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();
        assert_valid_indices(dict, &[0, 2, 4]);
    }

    /// Builds a chunked array of `chunk_count` random dictionary chunks, each with `len` codes
    /// drawn from `unique_values` random values. The RNG is seeded, so the output is stable.
    fn make_dict_primitive_chunks<T: NativePType, Code: UnsignedPType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);
        let mut chunks = Vec::with_capacity(chunk_count);

        for _ in 0..chunk_count {
            // Values are drawn before codes for every chunk, fixing the RNG consumption order.
            let values: PrimitiveArray = (0..unique_values).map(|_| rng.random::<T>()).collect();
            let codes: PrimitiveArray = (0..len)
                .map(|_| {
                    Code::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                })
                .collect();

            let chunk = DictArray::try_new(codes.into_array(), values.into_array())
                .vortex_expect("DictArray creation should succeed in arbitrary impl")
                .into_array();
            chunks.push(chunk);
        }

        chunks.into_iter().collect::<ChunkedArray>().into_array()
    }

    #[test]
    fn test_dict_array_from_primitive_chunks() -> VortexResult<()> {
        let (len, chunk_count) = (2, 2);
        let chunked = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let mut builder = builder_with_capacity(
            &DType::Primitive(PType::U64, NonNullable),
            len * chunk_count,
        );
        chunked.append_to_builder(builder.as_mut(), &mut LEGACY_SESSION.create_execution_ctx())?;

        #[expect(deprecated)]
        let canonicalized = chunked.to_primitive();
        let built = builder.finish_into_canonical().into_primitive();

        assert_arrays_eq!(canonicalized, built);
        Ok(())
    }

    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        use prost::Message;

        use super::DictMetadata;
        use crate::test_harness::check_metadata;

        let metadata = DictMetadata {
            values_len: u32::MAX,
            codes_ptype: PType::U64 as i32,
            is_nullable_codes: None,
            all_values_referenced: None,
        };
        check_metadata("dict.metadata", &metadata.encode_to_vec());
    }
}