Skip to main content

vortex_array/arrays/dict/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::BitBuffer;
5use vortex_dtype::DType;
6use vortex_dtype::PType;
7use vortex_dtype::match_each_integer_ptype;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_ensure;
12use vortex_mask::AllOr;
13
14use crate::ArrayRef;
15use crate::ToCanonical;
16use crate::stats::ArrayStats;
17
18#[derive(Clone, prost::Message)]
19pub struct DictMetadata {
20    #[prost(uint32, tag = "1")]
21    pub(super) values_len: u32,
22    #[prost(enumeration = "PType", tag = "2")]
23    pub(super) codes_ptype: i32,
24    // nullable codes are optional since they were added after stabilisation.
25    #[prost(optional, bool, tag = "3")]
26    pub(super) is_nullable_codes: Option<bool>,
27    // all_values_referenced is optional for backward compatibility.
28    // true = all dictionary values are definitely referenced by at least one code.
29    // false/None = unknown whether all values are referenced (conservative default).
30    #[prost(optional, bool, tag = "4")]
31    pub(super) all_values_referenced: Option<bool>,
32}
33
34#[derive(Debug, Clone)]
35pub struct DictArray {
36    pub(super) codes: ArrayRef,
37    pub(super) values: ArrayRef,
38    pub(super) stats_set: ArrayStats,
39    pub(super) dtype: DType,
40    /// Indicates whether all dictionary values are definitely referenced by at least one code.
41    /// `true` = all values are referenced (computed during encoding).
42    /// `false` = unknown/might have unreferenced values.
43    /// In case this is incorrect never use this to enable memory unsafe behaviour just semantically
44    /// incorrect behaviour.
45    pub(super) all_values_referenced: bool,
46}
47
48pub struct DictArrayParts {
49    pub codes: ArrayRef,
50    pub values: ArrayRef,
51    pub dtype: DType,
52}
53
54impl DictArray {
55    /// Build a new `DictArray` without validating the codes or values.
56    ///
57    /// # Safety
58    /// This should be called only when you can guarantee the invariants checked
59    /// by the safe [`DictArray::try_new`] constructor are valid, for example when
60    /// you are filtering or slicing an existing valid `DictArray`.
61    pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
62        let dtype = values
63            .dtype()
64            .union_nullability(codes.dtype().nullability());
65        Self {
66            codes,
67            values,
68            stats_set: Default::default(),
69            dtype,
70            all_values_referenced: false,
71        }
72    }
73
74    /// Set whether all dictionary values are definitely referenced.
75    ///
76    /// # Safety
77    /// The caller must ensure that when setting `all_values_referenced = true`, ALL dictionary
78    /// values are actually referenced by at least one valid code. Setting this incorrectly can
79    /// lead to incorrect query results in operations like min/max.
80    ///
81    /// This is typically only set to `true` during dictionary encoding when we know for certain
82    /// that all values are referenced.
83    pub unsafe fn set_all_values_referenced(mut self, all_values_referenced: bool) -> Self {
84        self.all_values_referenced = all_values_referenced;
85
86        #[cfg(debug_assertions)]
87        {
88            use vortex_error::VortexExpect;
89            self.validate_all_values_referenced()
90                .vortex_expect("validation should succeed when all values are referenced")
91        }
92
93        self
94    }
95
96    /// Build a new `DictArray` from its components, `codes` and `values`.
97    ///
98    /// This constructor will panic if `codes` or `values` do not pass validation for building
99    /// a new `DictArray`. See [`DictArray::try_new`] for a description of the error conditions.
100    pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
101        Self::try_new(codes, values).vortex_expect("DictArray new")
102    }
103
104    /// Build a new `DictArray` from its components, `codes` and `values`.
105    ///
106    /// The codes must be integers, and may be nullable. Values can be any
107    /// type, and may also be nullable. This mirrors the nullability of the Arrow `DictionaryArray`.
108    ///
109    /// # Errors
110    ///
111    /// The `codes` **must** be integers, and the maximum code must be less than the length
112    /// of the `values` array. Otherwise, this constructor returns an error.
113    ///
114    /// It is an error to provide a nullable `codes` with non-nullable `values`.
115    pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
116        if !codes.dtype().is_int() {
117            vortex_bail!(MismatchedTypes: "int", codes.dtype());
118        }
119
120        Ok(unsafe { Self::new_unchecked(codes, values) })
121    }
122
123    pub fn into_parts(self) -> DictArrayParts {
124        DictArrayParts {
125            codes: self.codes,
126            values: self.values,
127            dtype: self.dtype,
128        }
129    }
130
131    #[inline]
132    pub fn codes(&self) -> &ArrayRef {
133        &self.codes
134    }
135
136    #[inline]
137    pub fn values(&self) -> &ArrayRef {
138        &self.values
139    }
140
141    /// Returns `true` if all dictionary values are definitely referenced by at least one code.
142    ///
143    /// When `true`, operations like min/max can safely operate on all values without needing to
144    /// compute which values are actually referenced. When `false`, it is unknown whether all
145    /// values are referenced (conservative default).
146    #[inline]
147    pub fn has_all_values_referenced(&self) -> bool {
148        self.all_values_referenced
149    }
150
151    /// Validates that the `all_values_referenced` flag matches reality.
152    ///
153    /// Returns `Ok(())` if the flag is consistent with the actual referenced values,
154    /// or an error describing the mismatch.
155    ///
156    /// This is primarily useful for testing and debugging.
157    pub fn validate_all_values_referenced(&self) -> VortexResult<()> {
158        if self.all_values_referenced {
159            // Skip host-only validation when codes are not host-resident.
160            if !self.codes().is_host() {
161                return Ok(());
162            }
163
164            let referenced_mask = self.compute_referenced_values_mask(true)?;
165            let all_referenced = referenced_mask.iter().all(|v| v);
166
167            vortex_ensure!(all_referenced, "value in dict not referenced");
168        }
169
170        Ok(())
171    }
172
173    /// Compute a mask indicating which values in the dictionary are referenced by at least one code.
174    ///
175    /// When `referenced = true`, returns a `BitBuffer` where set bits (true) correspond to
176    /// referenced values, and unset bits (false) correspond to unreferenced values.
177    ///
178    /// When `referenced = false` (default for unreferenced values), returns the inverse:
179    /// set bits (true) correspond to unreferenced values, and unset bits (false) correspond
180    /// to referenced values.
181    ///
182    /// This is useful for operations like min/max that need to ignore unreferenced values.
183    pub fn compute_referenced_values_mask(&self, referenced: bool) -> VortexResult<BitBuffer> {
184        let codes_validity = self.codes().validity_mask()?;
185        let codes_primitive = self.codes().to_primitive();
186        let values_len = self.values().len();
187
188        // Initialize with the starting value: false for referenced, true for unreferenced
189        let init_value = !referenced;
190        // Value to set when we find a referenced code: true for referenced, false for unreferenced
191        let referenced_value = referenced;
192
193        let mut values_vec = vec![init_value; values_len];
194        match codes_validity.bit_buffer() {
195            AllOr::All => {
196                match_each_integer_ptype!(codes_primitive.ptype(), |P| {
197                    #[allow(
198                        clippy::cast_possible_truncation,
199                        clippy::cast_sign_loss,
200                        reason = "codes are non-negative indices; a negative signed code would wrap to a large usize and panic on the bounds-checked array index"
201                    )]
202                    for &code in codes_primitive.as_slice::<P>().iter() {
203                        values_vec[code as usize] = referenced_value;
204                    }
205                });
206            }
207            AllOr::None => {}
208            AllOr::Some(buf) => {
209                match_each_integer_ptype!(codes_primitive.ptype(), |P| {
210                    let codes = codes_primitive.as_slice::<P>();
211
212                    #[allow(
213                        clippy::cast_possible_truncation,
214                        clippy::cast_sign_loss,
215                        reason = "codes are non-negative indices; a negative signed code would wrap to a large usize and panic on the bounds-checked array index"
216                    )]
217                    buf.set_indices().for_each(|idx| {
218                        values_vec[codes[idx] as usize] = referenced_value;
219                    })
220                });
221            }
222        }
223
224        Ok(BitBuffer::collect_bool(values_len, |idx| values_vec[idx]))
225    }
226}
227
#[cfg(test)]
mod test {
    use rand::Rng;
    use rand::SeedableRng;
    use rand::distr::Distribution;
    use rand::distr::StandardUniform;
    use rand::prelude::StdRng;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::NativePType;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType;
    use vortex_dtype::UnsignedPType;
    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;
    use vortex_error::vortex_panic;
    use vortex_mask::AllOr;

    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::ChunkedArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictArray;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::validity::Validity;

    // Null codes make the corresponding elements null.
    #[test]
    fn nullable_codes_validity() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0, 2, 4]);
    }

    // Elements pointing at null values are null.
    #[test]
    fn nullable_values_validity() {
        let dict = DictArray::try_new(
            buffer![0u32, 1, 2, 2, 1].into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BitBuffer::from(vec![true, false, false])),
            )
            .into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0]);
    }

    // An element is valid only if both its code and the referenced value are valid.
    #[test]
    fn nullable_codes_and_values() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BitBuffer::from(vec![false, true, true])),
            )
            .into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [2, 4]);
    }

    // Nullable codes are accepted together with non-nullable values.
    #[test]
    fn nullable_codes_and_non_null_values() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::NonNullable).into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0, 2, 4]);
    }

    // Null codes must not mark their target value as referenced, in either polarity.
    #[test]
    fn referenced_values_mask_ignores_null_codes() -> VortexResult<()> {
        // Code 3 is null, so dictionary value 3 is not referenced; value 1 is never used.
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 3, 2],
                Validity::from(BitBuffer::from(vec![true, false, true])),
            )
            .into_array(),
            buffer![10i32, 20, 30, 40].into_array(),
        )?;

        let referenced = dict.compute_referenced_values_mask(true)?;
        assert_eq!(
            referenced.iter().collect::<Vec<_>>(),
            vec![true, false, true, false]
        );

        let unreferenced = dict.compute_referenced_values_mask(false)?;
        assert_eq!(
            unreferenced.iter().collect::<Vec<_>>(),
            vec![false, true, false, true]
        );
        Ok(())
    }

    // The flag defaults to false, can be set, and validates when it is actually true.
    #[test]
    fn all_values_referenced_flag() -> VortexResult<()> {
        let dict = DictArray::try_new(
            buffer![1u8, 0, 1].into_array(),
            buffer![5i64, 7].into_array(),
        )?;
        assert!(!dict.has_all_values_referenced());

        // SAFETY: codes 0 and 1 together reference both dictionary values.
        let dict = unsafe { dict.set_all_values_referenced(true) };
        assert!(dict.has_all_values_referenced());
        dict.validate_all_values_referenced()
    }

    /// Build a chunked array of `chunk_count` dictionary chunks.
    ///
    /// Each chunk has `len` random codes over `unique_values` random values. The RNG is seeded,
    /// so the output is deterministic.
    fn make_dict_primitive_chunks<T: NativePType, Code: UnsignedPType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);

        (0..chunk_count)
            .map(|_| {
                let values = (0..unique_values)
                    .map(|_| rng.random::<T>())
                    .collect::<PrimitiveArray>();
                let codes = (0..len)
                    .map(|_| {
                        Code::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                    })
                    .collect::<PrimitiveArray>();

                DictArray::try_new(codes.into_array(), values.into_array())
                    .vortex_expect("DictArray creation should succeed in arbitrary impl")
                    .into_array()
            })
            .collect::<ChunkedArray>()
            .into_array()
    }

    // Appending a chunked dict array to a builder must match canonicalizing it directly.
    #[test]
    fn test_dict_array_from_primitive_chunks() -> VortexResult<()> {
        let len = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let mut builder = builder_with_capacity(
            &DType::Primitive(PType::U64, NonNullable),
            len * chunk_count,
        );
        array
            .clone()
            .append_to_builder(builder.as_mut(), &mut LEGACY_SESSION.create_execution_ctx())?;

        let into_prim = array.to_primitive();
        let prim_into = builder.finish_into_canonical().into_primitive();

        assert_arrays_eq!(into_prim, prim_into);
        Ok(())
    }

    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        use super::DictMetadata;
        use crate::ProstMetadata;
        use crate::test_harness::check_metadata;

        check_metadata(
            "dict.metadata",
            ProstMetadata(DictMetadata {
                codes_ptype: PType::U64 as i32,
                values_len: u32::MAX,
                is_nullable_codes: None,
                all_values_referenced: None,
            }),
        );
    }
}