vortex_array/arrays/dict/array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BitBuffer;
5use vortex_dtype::DType;
6use vortex_dtype::PType;
7use vortex_dtype::match_each_integer_ptype;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_ensure;
12use vortex_mask::AllOr;
13
14use crate::Array;
15use crate::ArrayRef;
16use crate::ToCanonical;
17use crate::stats::ArrayStats;
18
/// Protobuf (prost) metadata for a [`DictArray`].
///
/// Field tags are part of the serialized format; new fields must use new tags and be
/// `optional` so that previously written metadata still decodes.
#[derive(Clone, prost::Message)]
pub struct DictMetadata {
    /// Length of the dictionary values (written by serialization code not shown here).
    #[prost(uint32, tag = "1")]
    pub(super) values_len: u32,
    /// Primitive type of the codes, stored as the integer representation of the `PType` enum.
    #[prost(enumeration = "PType", tag = "2")]
    pub(super) codes_ptype: i32,
    /// Whether the codes are nullable.
    /// Optional because this field was added after the format was stabilised.
    #[prost(optional, bool, tag = "3")]
    pub(super) is_nullable_codes: Option<bool>,
    /// Optional for backward compatibility.
    /// `Some(true)` = all dictionary values are definitely referenced by at least one code.
    /// `Some(false)` / `None` = unknown whether all values are referenced (conservative default).
    #[prost(optional, bool, tag = "4")]
    pub(super) all_values_referenced: Option<bool>,
}
34
/// A dictionary-encoded array: element `i` is `values[codes[i]]`.
///
/// An element is null when its code is null or when the value it points at is null.
#[derive(Debug, Clone)]
pub struct DictArray {
    /// Integer codes indexing into `values`; may be nullable.
    pub(super) codes: ArrayRef,
    /// The dictionary values referenced by `codes`; may be nullable.
    pub(super) values: ArrayRef,
    /// Statistics storage for this array (initialised with `Default::default()`).
    pub(super) stats_set: ArrayStats,
    /// Logical dtype: the dtype of `values` with its nullability widened by that of `codes`.
    pub(super) dtype: DType,
    /// Indicates whether all dictionary values are definitely referenced by at least one code.
    /// `true` = all values are referenced (computed during encoding).
    /// `false` = unknown; some values might be unreferenced.
    /// This flag may only be used to enable semantic shortcuts (e.g. computing min/max over all
    /// values). If it is wrong, results may be semantically incorrect, so it must never be relied
    /// upon for memory safety.
    pub(super) all_values_referenced: bool,
}
48
49impl DictArray {
50    /// Build a new `DictArray` without validating the codes or values.
51    ///
52    /// # Safety
53    /// This should be called only when you can guarantee the invariants checked
54    /// by the safe [`DictArray::try_new`] constructor are valid, for example when
55    /// you are filtering or slicing an existing valid `DictArray`.
56    pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
57        let dtype = values
58            .dtype()
59            .union_nullability(codes.dtype().nullability());
60        Self {
61            codes,
62            values,
63            stats_set: Default::default(),
64            dtype,
65            all_values_referenced: false,
66        }
67    }
68
69    /// Set whether all dictionary values are definitely referenced.
70    ///
71    /// # Safety
72    /// The caller must ensure that when setting `all_values_referenced = true`, ALL dictionary
73    /// values are actually referenced by at least one valid code. Setting this incorrectly can
74    /// lead to incorrect query results in operations like min/max.
75    ///
76    /// This is typically only set to `true` during dictionary encoding when we know for certain
77    /// that all values are referenced.
78    pub unsafe fn set_all_values_referenced(mut self, all_values_referenced: bool) -> Self {
79        self.all_values_referenced = all_values_referenced;
80
81        #[cfg(debug_assertions)]
82        {
83            use vortex_error::VortexExpect;
84            self.validate_all_values_referenced()
85                .vortex_expect("validation should succeed when all values are referenced")
86        }
87
88        self
89    }
90
91    /// Build a new `DictArray` from its components, `codes` and `values`.
92    ///
93    /// This constructor will panic if `codes` or `values` do not pass validation for building
94    /// a new `DictArray`. See [`DictArray::try_new`] for a description of the error conditions.
95    pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
96        Self::try_new(codes, values).vortex_expect("DictArray new")
97    }
98
99    /// Build a new `DictArray` from its components, `codes` and `values`.
100    ///
101    /// The codes must be unsigned integers, and may be nullable. Values can be any type, and
102    /// may also be nullable. This mirrors the nullability of the Arrow `DictionaryArray`.
103    ///
104    /// # Errors
105    ///
106    /// The `codes` **must** be unsigned integers, and the maximum code must be less than the length
107    /// of the `values` array. Otherwise, this constructor returns an error.
108    ///
109    /// It is an error to provide a nullable `codes` with non-nullable `values`.
110    pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
111        if !codes.dtype().is_unsigned_int() {
112            vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
113        }
114
115        Ok(unsafe { Self::new_unchecked(codes, values) })
116    }
117
118    pub fn into_parts(self) -> (ArrayRef, ArrayRef) {
119        (self.codes, self.values)
120    }
121
122    #[inline]
123    pub fn codes(&self) -> &ArrayRef {
124        &self.codes
125    }
126
127    #[inline]
128    pub fn values(&self) -> &ArrayRef {
129        &self.values
130    }
131
132    /// Returns `true` if all dictionary values are definitely referenced by at least one code.
133    ///
134    /// When `true`, operations like min/max can safely operate on all values without needing to
135    /// compute which values are actually referenced. When `false`, it is unknown whether all
136    /// values are referenced (conservative default).
137    #[inline]
138    pub fn has_all_values_referenced(&self) -> bool {
139        self.all_values_referenced
140    }
141
142    /// Validates that the `all_values_referenced` flag matches reality.
143    ///
144    /// Returns `Ok(())` if the flag is consistent with the actual referenced values,
145    /// or an error describing the mismatch.
146    ///
147    /// This is primarily useful for testing and debugging.
148    pub fn validate_all_values_referenced(&self) -> VortexResult<()> {
149        if self.all_values_referenced {
150            let referenced_mask = self.compute_referenced_values_mask(true)?;
151            let all_referenced = referenced_mask.iter().all(|v| v);
152
153            vortex_ensure!(all_referenced, "value in dict not referenced");
154        }
155
156        Ok(())
157    }
158
159    /// Compute a mask indicating which values in the dictionary are referenced by at least one code.
160    ///
161    /// When `referenced = true`, returns a `BitBuffer` where set bits (true) correspond to
162    /// referenced values, and unset bits (false) correspond to unreferenced values.
163    ///
164    /// When `referenced = false` (default for unreferenced values), returns the inverse:
165    /// set bits (true) correspond to unreferenced values, and unset bits (false) correspond
166    /// to referenced values.
167    ///
168    /// This is useful for operations like min/max that need to ignore unreferenced values.
169    pub fn compute_referenced_values_mask(&self, referenced: bool) -> VortexResult<BitBuffer> {
170        let codes_validity = self.codes().validity_mask();
171        let codes_primitive = self.codes().to_primitive();
172        let values_len = self.values().len();
173
174        // Initialize with the starting value: false for referenced, true for unreferenced
175        let init_value = !referenced;
176        // Value to set when we find a referenced code: true for referenced, false for unreferenced
177        let referenced_value = referenced;
178
179        let mut values_vec = vec![init_value; values_len];
180        match codes_validity.bit_buffer() {
181            AllOr::All => {
182                match_each_integer_ptype!(codes_primitive.ptype(), |P| {
183                    #[allow(clippy::cast_possible_truncation)]
184                    for &code in codes_primitive.as_slice::<P>().iter() {
185                        values_vec[code as usize] = referenced_value;
186                    }
187                });
188            }
189            AllOr::None => {}
190            AllOr::Some(buf) => {
191                match_each_integer_ptype!(codes_primitive.ptype(), |P| {
192                    let codes = codes_primitive.as_slice::<P>();
193
194                    #[allow(clippy::cast_possible_truncation)]
195                    buf.set_indices().for_each(|idx| {
196                        values_vec[codes[idx] as usize] = referenced_value;
197                    })
198                });
199            }
200        }
201
202        Ok(BitBuffer::collect_bool(values_len, |idx| values_vec[idx]))
203    }
204}
205
#[cfg(test)]
mod test {
    // NOTE(review): appears unused in this module; kept in case a macro or cfg relies on it.
    #[allow(unused_imports)]
    use itertools::Itertools;
    use rand::Rng;
    use rand::SeedableRng;
    use rand::distr::Distribution;
    use rand::distr::StandardUniform;
    use rand::prelude::StdRng;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::NativePType;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType;
    use vortex_dtype::UnsignedPType;
    use vortex_error::VortexExpect;
    use vortex_error::vortex_panic;
    use vortex_mask::AllOr;

    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::ChunkedArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictArray;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::validity::Validity;

    #[test]
    fn nullable_codes_validity() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0, 2, 4]);
    }

    #[test]
    fn nullable_values_validity() {
        let dict = DictArray::try_new(
            buffer![0u32, 1, 2, 2, 1].into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BitBuffer::from(vec![true, false, false])),
            )
            .into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0]);
    }

    #[test]
    fn nullable_codes_and_values() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BitBuffer::from(vec![false, true, true])),
            )
            .into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [2, 4]);
    }

    #[test]
    fn nullable_codes_and_non_null_values() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::NonNullable).into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0, 2, 4]);
    }

    #[test]
    fn referenced_values_mask_ignores_null_codes() {
        // Code 3 sits in a null slot, so value 3 must not count as referenced.
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 3, 2],
                Validity::from(BitBuffer::from(vec![true, false, true])),
            )
            .into_array(),
            buffer![10i32, 20, 30, 40].into_array(),
        )
        .unwrap();

        let referenced = dict.compute_referenced_values_mask(true).unwrap();
        assert_eq!(
            referenced.iter().collect::<Vec<_>>(),
            vec![true, false, true, false]
        );

        let unreferenced = dict.compute_referenced_values_mask(false).unwrap();
        assert_eq!(
            unreferenced.iter().collect::<Vec<_>>(),
            vec![false, true, false, true]
        );
    }

    #[test]
    fn all_values_referenced_flag() {
        let dict = DictArray::try_new(
            buffer![0u32, 1, 1, 0].into_array(),
            buffer![5i64, 7].into_array(),
        )
        .unwrap();
        assert!(!dict.has_all_values_referenced());

        // SAFETY: codes 0 and 1 reference both dictionary values.
        let dict = unsafe { dict.set_all_values_referenced(true) };
        assert!(dict.has_all_values_referenced());
        dict.validate_all_values_referenced().unwrap();

        // With the flag unset there is nothing to validate, even if a value is unreferenced.
        let partial = DictArray::try_new(
            buffer![0u32, 0].into_array(),
            buffer![5i64, 7].into_array(),
        )
        .unwrap();
        assert!(partial.validate_all_values_referenced().is_ok());
    }

    fn make_dict_primitive_chunks<T: NativePType, Code: UnsignedPType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);

        (0..chunk_count)
            .map(|_| {
                let values = (0..unique_values)
                    .map(|_| rng.random::<T>())
                    .collect::<PrimitiveArray>();
                let codes = (0..len)
                    .map(|_| {
                        Code::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                    })
                    .collect::<PrimitiveArray>();

                DictArray::try_new(codes.into_array(), values.into_array())
                    .vortex_expect("DictArray creation should succeed in arbitrary impl")
                    .into_array()
            })
            .collect::<ChunkedArray>()
            .into_array()
    }

    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let len = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let mut builder = builder_with_capacity(
            &DType::Primitive(PType::U64, NonNullable),
            len * chunk_count,
        );
        array.clone().append_to_builder(builder.as_mut());

        let into_prim = array.to_primitive();
        let prim_into = builder.finish_into_canonical().into_primitive();

        assert_arrays_eq!(into_prim, prim_into);
    }

    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        use super::DictMetadata;
        use crate::ProstMetadata;
        use crate::test_harness::check_metadata;

        check_metadata(
            "dict.metadata",
            ProstMetadata(DictMetadata {
                codes_ptype: PType::U64 as i32,
                values_len: u32::MAX,
                is_nullable_codes: None,
                all_values_referenced: None,
            }),
        );
    }
}