vortex_array/builders/dict/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use bytes::bytes_dict_builder;
5use primitive::primitive_dict_builder;
6use vortex_dtype::match_each_native_ptype;
7use vortex_error::{VortexResult, vortex_bail, vortex_panic};
8
9use crate::arrays::{DictArray, PrimitiveVTable, VarBinVTable, VarBinViewVTable};
10use crate::{Array, ArrayRef, IntoArray, ToCanonical};
11
12mod bytes;
13mod primitive;
14
/// Limits applied while building a dictionary with a [`DictEncoder`].
///
/// A bound of `usize::MAX` means "no limit" (see `UNCONSTRAINED`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DictConstraints {
    /// Upper bound on the dictionary's size in bytes (presumably the bytes of the
    /// dictionary values; TODO confirm against the encoder implementations).
    pub max_bytes: usize,
    /// Upper bound on the dictionary's length (presumably the number of distinct
    /// values; TODO confirm against the encoder implementations).
    pub max_len: usize,
}
20
21pub const UNCONSTRAINED: DictConstraints = DictConstraints {
22    max_bytes: usize::MAX,
23    max_len: usize::MAX,
24};
25
26pub trait DictEncoder: Send {
27    /// Assign dictionary codes to the given input array.
28    fn encode(&mut self, array: &dyn Array) -> ArrayRef;
29
30    /// Clear the encoder state to make it ready for a new round of decoding.
31    fn reset(&mut self) -> ArrayRef;
32}
33
34pub fn dict_encoder(array: &dyn Array, constraints: &DictConstraints) -> Box<dyn DictEncoder> {
35    let dict_builder: Box<dyn DictEncoder> = if let Some(pa) = array.as_opt::<PrimitiveVTable>() {
36        match_each_native_ptype!(pa.ptype(), |P| {
37            primitive_dict_builder::<P>(pa.dtype().nullability(), constraints)
38        })
39    } else if let Some(vbv) = array.as_opt::<VarBinViewVTable>() {
40        bytes_dict_builder(vbv.dtype().clone(), constraints)
41    } else if let Some(vb) = array.as_opt::<VarBinVTable>() {
42        bytes_dict_builder(vb.dtype().clone(), constraints)
43    } else {
44        vortex_panic!("Can only encode primitive or varbin/view arrays")
45    };
46    dict_builder
47}
48
49pub fn dict_encode_with_constraints(
50    array: &dyn Array,
51    constraints: &DictConstraints,
52) -> VortexResult<DictArray> {
53    let mut encoder = dict_encoder(array, constraints);
54    let codes = encoder.encode(array).to_primitive().narrow()?;
55    // SAFETY: The encoding process will produce a value set of codes and values
56    unsafe {
57        Ok(DictArray::new_unchecked(
58            codes.into_array(),
59            encoder.reset(),
60        ))
61    }
62}
63
64pub fn dict_encode(array: &dyn Array) -> VortexResult<DictArray> {
65    let dict_array = dict_encode_with_constraints(array, &UNCONSTRAINED)?;
66    if dict_array.len() != array.len() {
67        vortex_bail!(
68            "must have encoded all {} elements, but only encoded {}",
69            array.len(),
70            dict_array.len(),
71        );
72    }
73    Ok(dict_array)
74}