Skip to main content

vortex_array/arrays/dict/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6
7use smallvec::smallvec;
8use vortex_buffer::BitBuffer;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_ensure;
13use vortex_mask::AllOr;
14
15use crate::ArrayRef;
16use crate::ArraySlots;
17use crate::LEGACY_SESSION;
18#[expect(deprecated)]
19use crate::ToCanonical as _;
20use crate::VortexSessionExecute;
21use crate::array::Array;
22use crate::array::ArrayParts;
23use crate::array::TypedArrayRef;
24use crate::array_slots;
25use crate::arrays::Dict;
26use crate::dtype::DType;
27use crate::dtype::PType;
28use crate::match_each_integer_ptype;
29
/// Protobuf-encoded (via `prost`) metadata of a dictionary array.
///
/// Fields with tags 3 and 4 are declared `optional`, so either may be absent (`None`) after
/// decoding.
#[derive(Clone, prost::Message)]
pub struct DictMetadata {
    // Presumably the length of the dictionary `values` child -- TODO confirm against the
    // serialization code, which is not part of this file.
    #[prost(uint32, tag = "1")]
    pub(super) values_len: u32,
    // A `PType` enumeration value, held as its raw `i32` representation.
    #[prost(enumeration = "PType", tag = "2")]
    pub(super) codes_ptype: i32,
    // nullable codes are optional since they were added after stabilisation.
    #[prost(optional, bool, tag = "3")]
    pub(super) is_nullable_codes: Option<bool>,
    // all_values_referenced is optional for backward compatibility.
    // true = all dictionary values are definitely referenced by at least one code.
    // false/None = unknown whether all values are referenced (conservative default).
    #[prost(optional, bool, tag = "4")]
    pub(super) all_values_referenced: Option<bool>,
}
45
/// The child-array slots of a dictionary array: `codes` first, then `values`.
#[array_slots(Dict)]
pub struct DictSlots {
    /// The codes array mapping each element to a dictionary entry.
    pub codes: ArrayRef,
    /// The dictionary values array containing the unique values.
    pub values: ArrayRef,
}
53
/// Data stored on a dictionary array alongside its `codes` and `values` slots.
#[derive(Debug, Clone)]
pub struct DictData {
    /// Indicates whether all dictionary values are definitely referenced by at least one code.
    /// `true` = all values are referenced (computed during encoding).
    /// `false` = unknown/might have unreferenced values.
    ///
    /// The flag may be wrong, so never rely on it to enable memory-unsafe behaviour: an
    /// incorrect flag must, at worst, lead to semantically incorrect results.
    pub(super) all_values_referenced: bool,
}
63
64impl Display for DictData {
65    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
66        write!(f, "all_values_referenced: {}", self.all_values_referenced)
67    }
68}
69
70impl DictData {
71    /// Build a new `DictArray` without validating the codes or values.
72    ///
73    /// # Safety
74    /// This should be called only when you can guarantee the invariants checked
75    /// by the safe `DictArray::try_new` constructor are valid, for example when
76    /// you are filtering or slicing an existing valid `DictArray`.
77    pub unsafe fn new_unchecked() -> Self {
78        Self {
79            all_values_referenced: false,
80        }
81    }
82
83    /// Set whether all dictionary values are definitely referenced.
84    ///
85    /// # Safety
86    /// The caller must ensure that when setting `all_values_referenced = true`, ALL dictionary
87    /// values are actually referenced by at least one valid code. Setting this incorrectly can
88    /// lead to incorrect query results in operations like min/max.
89    ///
90    /// This is typically only set to `true` during dictionary encoding when we know for certain
91    /// that all values are referenced.
92    pub unsafe fn set_all_values_referenced(mut self, all_values_referenced: bool) -> Self {
93        self.all_values_referenced = all_values_referenced;
94        self
95    }
96
97    /// Build a new `DictArray` from its components, `codes` and `values`.
98    ///
99    /// This constructor will panic if `codes` or `values` do not pass validation for building
100    /// a new `DictArray`. See `DictArray::try_new` for a description of the error conditions.
101    pub fn new(codes_dtype: &DType) -> Self {
102        Self::try_new(codes_dtype).vortex_expect("DictArray new")
103    }
104
105    /// Build a new `DictArray` from its components, `codes` and `values`.
106    ///
107    /// The codes must be integers, and may be nullable. Values can be any
108    /// type, and may also be nullable. This mirrors the nullability of the Arrow `DictionaryArray`.
109    ///
110    /// # Errors
111    ///
112    /// The `codes` **must** be integers, and the maximum code must be less than the length
113    /// of the `values` array. Otherwise, this constructor returns an error.
114    ///
115    /// It is an error to provide a nullable `codes` with non-nullable `values`.
116    pub(crate) fn try_new(codes_dtype: &DType) -> VortexResult<Self> {
117        if !codes_dtype.is_int() {
118            vortex_bail!(MismatchedTypes: "int", codes_dtype);
119        }
120
121        Ok(unsafe { Self::new_unchecked() })
122    }
123}
124
125pub trait DictArrayExt: TypedArrayRef<Dict> + DictArraySlotsExt {
126    #[inline]
127    fn has_all_values_referenced(&self) -> bool {
128        self.all_values_referenced
129    }
130
131    fn validate_all_values_referenced(&self) -> VortexResult<()> {
132        if self.has_all_values_referenced() {
133            if !self.codes().is_host() {
134                return Ok(());
135            }
136
137            let referenced_mask = self.compute_referenced_values_mask(true)?;
138            let all_referenced = referenced_mask.iter().all(|v| v);
139
140            vortex_ensure!(all_referenced, "value in dict not referenced");
141        }
142
143        Ok(())
144    }
145
146    fn compute_referenced_values_mask(&self, referenced: bool) -> VortexResult<BitBuffer> {
147        let codes = self.codes();
148        let codes_validity = codes
149            .validity()?
150            .execute_mask(codes.len(), &mut LEGACY_SESSION.create_execution_ctx())?;
151        #[expect(deprecated)]
152        let codes_primitive = self.codes().to_primitive();
153        let values_len = self.values().len();
154
155        let init_value = !referenced;
156        let referenced_value = referenced;
157
158        let mut values_vec = vec![init_value; values_len];
159        match codes_validity.bit_buffer() {
160            AllOr::All => {
161                match_each_integer_ptype!(codes_primitive.ptype(), |P| {
162                    #[allow(
163                        clippy::cast_possible_truncation,
164                        clippy::cast_sign_loss,
165                        reason = "codes are non-negative indices; a negative signed code would wrap to a large usize and panic on the bounds-checked array index"
166                    )]
167                    for &idx in codes_primitive.as_slice::<P>() {
168                        values_vec[idx as usize] = referenced_value;
169                    }
170                });
171            }
172            AllOr::None => {}
173            AllOr::Some(mask) => {
174                match_each_integer_ptype!(codes_primitive.ptype(), |P| {
175                    let codes = codes_primitive.as_slice::<P>();
176
177                    #[allow(
178                        clippy::cast_possible_truncation,
179                        clippy::cast_sign_loss,
180                        reason = "codes are non-negative indices; a negative signed code would wrap to a large usize and panic on the bounds-checked array index"
181                    )]
182                    mask.set_indices().for_each(|idx| {
183                        values_vec[codes[idx] as usize] = referenced_value;
184                    });
185                });
186            }
187        }
188
189        Ok(BitBuffer::from(values_vec))
190    }
191}
192impl<T: TypedArrayRef<Dict>> DictArrayExt for T {}
193
/// Concrete parts of a [`DictArray`](super::DictArray) after iterative execution.
pub struct DictParts {
    /// The dtype of the dictionary array the parts were taken from.
    pub dtype: DType,
    /// The codes child of the dictionary array.
    pub codes: ArrayRef,
    /// The values child of the dictionary array.
    pub values: ArrayRef,
}
200
/// Extension methods for a dictionary array that is held by value.
pub trait DictOwnedExt {
    /// Consumes the array and returns its dtype, codes and values as [`DictParts`].
    fn into_parts(self) -> DictParts;
}
204
205impl DictOwnedExt for Array<Dict> {
206    fn into_parts(self) -> DictParts {
207        match self.try_into_parts() {
208            Ok(array_parts) => {
209                let slots = DictSlots::from_slots(array_parts.slots);
210                DictParts {
211                    dtype: array_parts.dtype,
212                    codes: slots.codes,
213                    values: slots.values,
214                }
215            }
216            Err(array) => {
217                let slots = DictSlotsView::from_slots(array.slots());
218                DictParts {
219                    dtype: array.dtype().clone(),
220                    codes: slots.codes.clone(),
221                    values: slots.values.clone(),
222                }
223            }
224        }
225    }
226}
227
228impl Array<Dict> {
229    /// Build a new `DictArray` from its components, `codes` and `values`.
230    pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
231        Self::try_new(codes, values).vortex_expect("DictArray new")
232    }
233
234    /// Build a new `DictArray` from its components, `codes` and `values`.
235    pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
236        let dtype = values
237            .dtype()
238            .union_nullability(codes.dtype().nullability());
239        let len = codes.len();
240        let data = DictData::try_new(codes.dtype())?;
241        Array::try_from_parts(
242            ArrayParts::new(Dict, dtype, len, data)
243                .with_slots(smallvec![Some(codes), Some(values)]),
244        )
245    }
246
247    /// Build a new `DictArray` without validating the codes or values.
248    ///
249    /// # Safety
250    ///
251    /// See [`DictData::new_unchecked`].
252    pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
253        let dtype = values
254            .dtype()
255            .union_nullability(codes.dtype().nullability());
256        let len = codes.len();
257        let data = unsafe { DictData::new_unchecked() };
258        unsafe {
259            Array::from_parts_unchecked(
260                ArrayParts::new(Dict, dtype, len, data)
261                    .with_slots(smallvec![Some(codes), Some(values)]),
262            )
263        }
264    }
265
266    /// Set whether all values in the dictionary are referenced by at least one code.
267    ///
268    /// # Safety
269    ///
270    /// See [`DictData::set_all_values_referenced`].
271    pub unsafe fn set_all_values_referenced(self, all_values_referenced: bool) -> Self {
272        let dtype = self.dtype().clone();
273        let len = self.len();
274        let slots: ArraySlots = self.slots().iter().cloned().collect();
275        let data = unsafe {
276            self.into_data()
277                .set_all_values_referenced(all_values_referenced)
278        };
279        let array = unsafe {
280            Array::from_parts_unchecked(ArrayParts::new(Dict, dtype, len, data).with_slots(slots))
281        };
282
283        #[cfg(debug_assertions)]
284        if all_values_referenced {
285            array
286                .validate_all_values_referenced()
287                .vortex_expect("validation should succeed when all values are referenced");
288        }
289
290        array
291    }
292}
293
#[cfg(test)]
mod test {
    use rand::RngExt;
    use rand::SeedableRng;
    use rand::distr::Distribution;
    use rand::distr::StandardUniform;
    use rand::prelude::StdRng;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;
    use vortex_error::vortex_panic;
    use vortex_mask::AllOr;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::arrays::ChunkedArray;
    use crate::arrays::DictArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::dtype::DType;
    use crate::dtype::NativePType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType;
    use crate::dtype::UnsignedPType;
    use crate::validity::Validity;

    /// Codes `[0, 1, 2, 2, 1]` of which only positions 0, 2 and 4 are valid.
    fn codes_with_nulls() -> ArrayRef {
        PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
        )
        .into_array()
    }

    /// Executes the validity mask of `dict` and asserts that exactly `expected` are valid.
    ///
    /// Panics if the mask does not expose an explicit list of indices.
    fn assert_valid_indices<const N: usize>(dict: DictArray, expected: [usize; N]) {
        let mask = dict
            .as_ref()
            .validity()
            .unwrap()
            .execute_mask(
                dict.as_ref().len(),
                &mut LEGACY_SESSION.create_execution_ctx(),
            )
            .unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, expected);
    }

    #[test]
    fn nullable_codes_validity() {
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array();
        let dict = DictArray::try_new(codes_with_nulls(), values).unwrap();

        // Validity follows the codes because every value is valid.
        assert_valid_indices(dict, [0, 2, 4]);
    }

    #[test]
    fn nullable_values_validity() {
        let codes = buffer![0u32, 1, 2, 2, 1].into_array();
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            Validity::from(BitBuffer::from(vec![true, false, false])),
        )
        .into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        // Only code 0 points at a valid value.
        assert_valid_indices(dict, [0]);
    }

    #[test]
    fn nullable_codes_and_values() {
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            Validity::from(BitBuffer::from(vec![false, true, true])),
        )
        .into_array();
        let dict = DictArray::try_new(codes_with_nulls(), values).unwrap();

        // Position 0 has a valid code that points at a null value.
        assert_valid_indices(dict, [2, 4]);
    }

    #[test]
    fn nullable_codes_and_non_null_values() {
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::NonNullable).into_array();
        let dict = DictArray::try_new(codes_with_nulls(), values).unwrap();

        assert_valid_indices(dict, [0, 2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dictionary chunks from a fixed-seed RNG.
    ///
    /// Each chunk has `unique_values` random values and `len` random in-range codes.
    fn make_dict_primitive_chunks<T: NativePType, Code: UnsignedPType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);

        (0..chunk_count)
            .map(|_| {
                // The values are drawn before the codes; keep that order so that the random
                // sequence stays reproducible.
                let values = std::iter::repeat_with(|| rng.random::<T>())
                    .take(unique_values)
                    .collect::<PrimitiveArray>();
                let codes = std::iter::repeat_with(|| {
                    Code::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                })
                .take(len)
                .collect::<PrimitiveArray>();

                DictArray::try_new(codes.into_array(), values.into_array())
                    .vortex_expect("DictArray creation should succeed in arbitrary impl")
                    .into_array()
            })
            .collect::<ChunkedArray>()
            .into_array()
    }

    #[test]
    fn test_dict_array_from_primitive_chunks() -> VortexResult<()> {
        let len = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let mut builder = builder_with_capacity(
            &DType::Primitive(PType::U64, NonNullable),
            len * chunk_count,
        );
        array.append_to_builder(builder.as_mut(), &mut LEGACY_SESSION.create_execution_ctx())?;

        // Both routes to a primitive array must agree.
        #[expect(deprecated)]
        let into_prim = array.to_primitive();
        let prim_into = builder.finish_into_canonical().into_primitive();

        assert_arrays_eq!(into_prim, prim_into);
        Ok(())
    }

    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        use prost::Message;

        use super::DictMetadata;
        use crate::test_harness::check_metadata;

        let metadata = DictMetadata {
            values_len: u32::MAX,
            codes_ptype: PType::U64 as i32,
            is_nullable_codes: None,
            all_values_referenced: None,
        };

        check_metadata("dict.metadata", &metadata.encode_to_vec());
    }
}