Skip to main content

vortex_array/arrays/dict/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6
7use num_traits::AsPrimitive;
8use smallvec::smallvec;
9use vortex_buffer::BitBuffer;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_ensure;
14use vortex_mask::AllOr;
15
16use crate::ArrayRef;
17use crate::ArraySlots;
18use crate::LEGACY_SESSION;
19#[expect(deprecated)]
20use crate::ToCanonical as _;
21use crate::VortexSessionExecute;
22use crate::array::Array;
23use crate::array::ArrayParts;
24use crate::array::TypedArrayRef;
25use crate::array_slots;
26use crate::arrays::Dict;
27use crate::dtype::DType;
28use crate::dtype::PType;
29use crate::match_each_integer_ptype;
30
31#[derive(Clone, prost::Message)]
32pub struct DictMetadata {
33    #[prost(uint32, tag = "1")]
34    pub(super) values_len: u32,
35    #[prost(enumeration = "PType", tag = "2")]
36    pub(super) codes_ptype: i32,
37    // nullable codes are optional since they were added after stabilisation.
38    #[prost(optional, bool, tag = "3")]
39    pub(super) is_nullable_codes: Option<bool>,
40    // all_values_referenced is optional for backward compatibility.
41    // true = all dictionary values are definitely referenced by at least one code.
42    // false/None = unknown whether all values are referenced (conservative default).
43    #[prost(optional, bool, tag = "4")]
44    pub(super) all_values_referenced: Option<bool>,
45}
46
47#[array_slots(Dict)]
48pub struct DictSlots {
49    /// The codes array mapping each element to a dictionary entry.
50    pub codes: ArrayRef,
51    /// The dictionary values array containing the unique values.
52    pub values: ArrayRef,
53}
54
55#[derive(Debug, Clone)]
56pub struct DictData {
57    /// Indicates whether all dictionary values are definitely referenced by at least one code.
58    /// `true` = all values are referenced (computed during encoding).
59    /// `false` = unknown/might have unreferenced values.
60    /// In case this is incorrect never use this to enable memory unsafe behaviour just semantically
61    /// incorrect behaviour.
62    pub(super) all_values_referenced: bool,
63}
64
65impl Display for DictData {
66    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
67        write!(f, "all_values_referenced: {}", self.all_values_referenced)
68    }
69}
70
71impl DictData {
72    /// Build a new `DictArray` without validating the codes or values.
73    ///
74    /// # Safety
75    /// This should be called only when you can guarantee the invariants checked
76    /// by the safe `DictArray::try_new` constructor are valid, for example when
77    /// you are filtering or slicing an existing valid `DictArray`.
78    pub unsafe fn new_unchecked() -> Self {
79        Self {
80            all_values_referenced: false,
81        }
82    }
83
84    /// Set whether all dictionary values are definitely referenced.
85    ///
86    /// # Safety
87    /// The caller must ensure that when setting `all_values_referenced = true`, ALL dictionary
88    /// values are actually referenced by at least one valid code. Setting this incorrectly can
89    /// lead to incorrect query results in operations like min/max.
90    ///
91    /// This is typically only set to `true` during dictionary encoding when we know for certain
92    /// that all values are referenced.
93    pub unsafe fn set_all_values_referenced(mut self, all_values_referenced: bool) -> Self {
94        self.all_values_referenced = all_values_referenced;
95        self
96    }
97
98    /// Build a new `DictArray` from its components, `codes` and `values`.
99    ///
100    /// This constructor will panic if `codes` or `values` do not pass validation for building
101    /// a new `DictArray`. See `DictArray::try_new` for a description of the error conditions.
102    pub fn new(codes_dtype: &DType) -> Self {
103        Self::try_new(codes_dtype).vortex_expect("DictArray new")
104    }
105
106    /// Build a new `DictArray` from its components, `codes` and `values`.
107    ///
108    /// The codes must be integers, and may be nullable. Values can be any
109    /// type, and may also be nullable. This mirrors the nullability of the Arrow `DictionaryArray`.
110    ///
111    /// # Errors
112    ///
113    /// The `codes` **must** be integers, and the maximum code must be less than the length
114    /// of the `values` array. Otherwise, this constructor returns an error.
115    ///
116    /// It is an error to provide a nullable `codes` with non-nullable `values`.
117    pub(crate) fn try_new(codes_dtype: &DType) -> VortexResult<Self> {
118        if !codes_dtype.is_int() {
119            vortex_bail!(MismatchedTypes: "int", codes_dtype);
120        }
121
122        Ok(unsafe { Self::new_unchecked() })
123    }
124}
125
126pub trait DictArrayExt: TypedArrayRef<Dict> + DictArraySlotsExt {
127    #[inline]
128    fn has_all_values_referenced(&self) -> bool {
129        self.all_values_referenced
130    }
131
132    fn validate_all_values_referenced(&self) -> VortexResult<()> {
133        if self.has_all_values_referenced() {
134            if !self.codes().is_host() {
135                return Ok(());
136            }
137
138            let referenced_mask = self.compute_referenced_values_mask(true)?;
139            let all_referenced = referenced_mask.true_count() == referenced_mask.len();
140
141            vortex_ensure!(all_referenced, "value in dict not referenced");
142        }
143
144        Ok(())
145    }
146
147    fn compute_referenced_values_mask(&self, referenced: bool) -> VortexResult<BitBuffer> {
148        let codes = self.codes();
149        let codes_validity = codes
150            .validity()?
151            .execute_mask(codes.len(), &mut LEGACY_SESSION.create_execution_ctx())?;
152        #[expect(deprecated)]
153        let codes_primitive = self.codes().to_primitive();
154        let values_len = self.values().len();
155
156        let init_value = !referenced;
157        let referenced_value = referenced;
158
159        let mut values_vec = vec![init_value; values_len];
160        match codes_validity.bit_buffer() {
161            AllOr::All => {
162                match_each_integer_ptype!(codes_primitive.ptype(), |P| {
163                    for idx in codes_primitive.as_slice::<P>() {
164                        let idxu: usize = idx.as_();
165                        values_vec[idxu] = referenced_value;
166                    }
167                });
168            }
169            AllOr::None => {}
170            AllOr::Some(mask) => {
171                match_each_integer_ptype!(codes_primitive.ptype(), |P| {
172                    let codes = codes_primitive.as_slice::<P>();
173                    mask.set_indices().for_each(|idx| {
174                        let idxu: usize = codes[idx].as_();
175                        values_vec[idxu] = referenced_value;
176                    });
177                });
178            }
179        }
180
181        Ok(BitBuffer::from(values_vec))
182    }
183}
184impl<T: TypedArrayRef<Dict>> DictArrayExt for T {}
185
186/// Concrete parts of a [`DictArray`](super::DictArray) after iterative execution.
187pub struct DictParts {
188    pub dtype: DType,
189    pub codes: ArrayRef,
190    pub values: ArrayRef,
191}
192
193pub trait DictOwnedExt {
194    fn into_parts(self) -> DictParts;
195}
196
197impl DictOwnedExt for Array<Dict> {
198    fn into_parts(self) -> DictParts {
199        match self.try_into_parts() {
200            Ok(array_parts) => {
201                let slots = DictSlots::from_slots(array_parts.slots);
202                DictParts {
203                    dtype: array_parts.dtype,
204                    codes: slots.codes,
205                    values: slots.values,
206                }
207            }
208            Err(array) => {
209                let slots = DictSlotsView::from_slots(array.slots());
210                DictParts {
211                    dtype: array.dtype().clone(),
212                    codes: slots.codes.clone(),
213                    values: slots.values.clone(),
214                }
215            }
216        }
217    }
218}
219
220impl Array<Dict> {
221    /// Build a new `DictArray` from its components, `codes` and `values`.
222    pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
223        Self::try_new(codes, values).vortex_expect("DictArray new")
224    }
225
226    /// Build a new `DictArray` from its components, `codes` and `values`.
227    pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
228        let dtype = values
229            .dtype()
230            .union_nullability(codes.dtype().nullability());
231        let len = codes.len();
232        let data = DictData::try_new(codes.dtype())?;
233        Array::try_from_parts(
234            ArrayParts::new(Dict, dtype, len, data)
235                .with_slots(smallvec![Some(codes), Some(values)]),
236        )
237    }
238
239    /// Build a new `DictArray` without validating the codes or values.
240    ///
241    /// # Safety
242    ///
243    /// See [`DictData::new_unchecked`].
244    pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
245        let dtype = values
246            .dtype()
247            .union_nullability(codes.dtype().nullability());
248        let len = codes.len();
249        let data = unsafe { DictData::new_unchecked() };
250        unsafe {
251            Array::from_parts_unchecked(
252                ArrayParts::new(Dict, dtype, len, data)
253                    .with_slots(smallvec![Some(codes), Some(values)]),
254            )
255        }
256    }
257
258    /// Set whether all values in the dictionary are referenced by at least one code.
259    ///
260    /// # Safety
261    ///
262    /// See [`DictData::set_all_values_referenced`].
263    pub unsafe fn set_all_values_referenced(self, all_values_referenced: bool) -> Self {
264        let dtype = self.dtype().clone();
265        let len = self.len();
266        let slots: ArraySlots = self.slots().iter().cloned().collect();
267        let data = unsafe {
268            self.into_data()
269                .set_all_values_referenced(all_values_referenced)
270        };
271        let array = unsafe {
272            Array::from_parts_unchecked(ArrayParts::new(Dict, dtype, len, data).with_slots(slots))
273        };
274
275        #[cfg(debug_assertions)]
276        if all_values_referenced {
277            array
278                .validate_all_values_referenced()
279                .vortex_expect("validation should succeed when all values are referenced");
280        }
281
282        array
283    }
284}
285
286#[cfg(test)]
287mod test {
288    use rand::RngExt;
289    use rand::SeedableRng;
290    use rand::distr::Distribution;
291    use rand::distr::StandardUniform;
292    use rand::prelude::StdRng;
293    use vortex_buffer::BitBuffer;
294    use vortex_buffer::buffer;
295    use vortex_error::VortexExpect;
296    use vortex_error::VortexResult;
297    use vortex_error::vortex_panic;
298    use vortex_mask::AllOr;
299
300    use crate::ArrayRef;
301    use crate::IntoArray;
302    #[expect(deprecated)]
303    use crate::ToCanonical as _;
304    use crate::VortexSessionExecute;
305    use crate::array_session;
306    use crate::arrays::ChunkedArray;
307    use crate::arrays::DictArray;
308    use crate::arrays::PrimitiveArray;
309    use crate::assert_arrays_eq;
310    use crate::builders::builder_with_capacity;
311    use crate::dtype::DType;
312    use crate::dtype::NativePType;
313    use crate::dtype::Nullability::NonNullable;
314    use crate::dtype::PType;
315    use crate::dtype::UnsignedPType;
316    use crate::validity::Validity;
317
318    #[test]
319    fn nullable_codes_validity() {
320        let dict = DictArray::try_new(
321            PrimitiveArray::new(
322                buffer![0u32, 1, 2, 2, 1],
323                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
324            )
325            .into_array(),
326            PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array(),
327        )
328        .unwrap();
329        let mask = dict
330            .as_ref()
331            .validity()
332            .unwrap()
333            .execute_mask(
334                dict.as_ref().len(),
335                &mut array_session().create_execution_ctx(),
336            )
337            .unwrap();
338        let AllOr::Some(indices) = mask.indices() else {
339            vortex_panic!("Expected indices from mask")
340        };
341        assert_eq!(indices, [0, 2, 4]);
342    }
343
344    #[test]
345    fn nullable_values_validity() {
346        let dict = DictArray::try_new(
347            buffer![0u32, 1, 2, 2, 1].into_array(),
348            PrimitiveArray::new(
349                buffer![3, 6, 9],
350                Validity::from(BitBuffer::from(vec![true, false, false])),
351            )
352            .into_array(),
353        )
354        .unwrap();
355        let mask = dict
356            .as_ref()
357            .validity()
358            .unwrap()
359            .execute_mask(
360                dict.as_ref().len(),
361                &mut array_session().create_execution_ctx(),
362            )
363            .unwrap();
364        let AllOr::Some(indices) = mask.indices() else {
365            vortex_panic!("Expected indices from mask")
366        };
367        assert_eq!(indices, [0]);
368    }
369
370    #[test]
371    fn nullable_codes_and_values() {
372        let dict = DictArray::try_new(
373            PrimitiveArray::new(
374                buffer![0u32, 1, 2, 2, 1],
375                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
376            )
377            .into_array(),
378            PrimitiveArray::new(
379                buffer![3, 6, 9],
380                Validity::from(BitBuffer::from(vec![false, true, true])),
381            )
382            .into_array(),
383        )
384        .unwrap();
385        let mask = dict
386            .as_ref()
387            .validity()
388            .unwrap()
389            .execute_mask(
390                dict.as_ref().len(),
391                &mut array_session().create_execution_ctx(),
392            )
393            .unwrap();
394        let AllOr::Some(indices) = mask.indices() else {
395            vortex_panic!("Expected indices from mask")
396        };
397        assert_eq!(indices, [2, 4]);
398    }
399
400    #[test]
401    fn nullable_codes_and_non_null_values() {
402        let dict = DictArray::try_new(
403            PrimitiveArray::new(
404                buffer![0u32, 1, 2, 2, 1],
405                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
406            )
407            .into_array(),
408            PrimitiveArray::new(buffer![3, 6, 9], Validity::NonNullable).into_array(),
409        )
410        .unwrap();
411        let mask = dict
412            .as_ref()
413            .validity()
414            .unwrap()
415            .execute_mask(
416                dict.as_ref().len(),
417                &mut array_session().create_execution_ctx(),
418            )
419            .unwrap();
420        let AllOr::Some(indices) = mask.indices() else {
421            vortex_panic!("Expected indices from mask")
422        };
423        assert_eq!(indices, [0, 2, 4]);
424    }
425
426    fn make_dict_primitive_chunks<T: NativePType, Code: UnsignedPType>(
427        len: usize,
428        unique_values: usize,
429        chunk_count: usize,
430    ) -> ArrayRef
431    where
432        StandardUniform: Distribution<T>,
433    {
434        let mut rng = StdRng::seed_from_u64(0);
435
436        (0..chunk_count)
437            .map(|_| {
438                let values = (0..unique_values)
439                    .map(|_| rng.random::<T>())
440                    .collect::<PrimitiveArray>();
441                let codes = (0..len)
442                    .map(|_| {
443                        Code::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
444                    })
445                    .collect::<PrimitiveArray>();
446
447                DictArray::try_new(codes.into_array(), values.into_array())
448                    .vortex_expect("DictArray creation should succeed in arbitrary impl")
449                    .into_array()
450            })
451            .collect::<ChunkedArray>()
452            .into_array()
453    }
454
455    #[test]
456    fn test_dict_array_from_primitive_chunks() -> VortexResult<()> {
457        let mut ctx = array_session().create_execution_ctx();
458        let len = 2;
459        let chunk_count = 2;
460        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);
461
462        let mut builder = builder_with_capacity(
463            &DType::Primitive(PType::U64, NonNullable),
464            len * chunk_count,
465        );
466        array.append_to_builder(
467            builder.as_mut(),
468            &mut array_session().create_execution_ctx(),
469        )?;
470
471        #[expect(deprecated)]
472        let into_prim = array.to_primitive();
473        let prim_into = builder.finish_into_canonical().into_primitive();
474
475        assert_arrays_eq!(into_prim, prim_into, &mut ctx);
476        Ok(())
477    }
478
479    #[cfg_attr(miri, ignore)]
480    #[test]
481    fn test_dict_metadata() {
482        use prost::Message;
483
484        use super::DictMetadata;
485        use crate::test_harness::check_metadata;
486
487        check_metadata(
488            "dict.metadata",
489            &DictMetadata {
490                codes_ptype: PType::U64 as i32,
491                values_len: u32::MAX,
492                is_nullable_codes: None,
493                all_values_referenced: None,
494            }
495            .encode_to_vec(),
496        );
497    }
498}