vortex_dict/builders/
mod.rs1use bytes::bytes_dict_builder;
5use primitive::primitive_dict_builder;
6use vortex_array::arrays::{PrimitiveVTable, VarBinVTable, VarBinViewVTable};
7use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
8use vortex_dtype::match_each_native_ptype;
9use vortex_error::{VortexResult, vortex_bail};
10
11use crate::DictArray;
12
13mod bytes;
14mod primitive;
15
/// Limits handed to a dictionary encoder when it is built.
///
/// `UNCONSTRAINED` sets both limits to `usize::MAX`, effectively disabling them.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DictConstraints {
    /// Upper bound on dictionary size in bytes.
    /// NOTE(review): presumably measured over the dictionary values — confirm in the builders.
    pub max_bytes: usize,
    /// Upper bound on dictionary length.
    /// NOTE(review): presumably the number of distinct values — confirm in the builders.
    pub max_len: usize,
}
21
22pub const UNCONSTRAINED: DictConstraints = DictConstraints {
23 max_bytes: usize::MAX,
24 max_len: usize::MAX,
25};
26
27pub trait DictEncoder: Send {
28 fn encode(&mut self, array: &dyn Array) -> VortexResult<ArrayRef>;
30
31 fn values(&mut self) -> VortexResult<ArrayRef>;
32}
33
34pub fn dict_encoder(
35 array: &dyn Array,
36 constraints: &DictConstraints,
37) -> VortexResult<Box<dyn DictEncoder>> {
38 let dict_builder: Box<dyn DictEncoder> = if let Some(pa) = array.as_opt::<PrimitiveVTable>() {
39 match_each_native_ptype!(pa.ptype(), |P| {
40 primitive_dict_builder::<P>(pa.dtype().nullability(), constraints)
41 })
42 } else if let Some(vbv) = array.as_opt::<VarBinViewVTable>() {
43 bytes_dict_builder(vbv.dtype().clone(), constraints)
44 } else if let Some(vb) = array.as_opt::<VarBinVTable>() {
45 bytes_dict_builder(vb.dtype().clone(), constraints)
46 } else {
47 vortex_bail!("Can only encode primitive or varbin/view arrays")
48 };
49 Ok(dict_builder)
50}
51
52pub fn dict_encode_with_constraints(
53 array: &dyn Array,
54 constraints: &DictConstraints,
55) -> VortexResult<DictArray> {
56 let mut encoder = dict_encoder(array, constraints)?;
57 let codes = encoder.encode(array)?.to_primitive()?.downcast()?;
58 unsafe {
60 Ok(DictArray::new_unchecked(
61 codes.into_array(),
62 encoder.values()?,
63 ))
64 }
65}
66
67pub fn dict_encode(array: &dyn Array) -> VortexResult<DictArray> {
68 let dict_array = dict_encode_with_constraints(array, &UNCONSTRAINED)?;
69 if dict_array.len() != array.len() {
70 vortex_bail!(
71 "must have encoded all {} elements, but only encoded {}",
72 array.len(),
73 dict_array.len(),
74 );
75 }
76 Ok(dict_array)
77}