vortex_array/arrays/dict/vtable/
mod.rs1use std::hash::Hash;
5use std::sync::Arc;
6
7use kernel::PARENT_KERNELS;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_err;
13use vortex_error::vortex_panic;
14use vortex_session::VortexSession;
15
16use super::DictArray;
17use super::DictArrayParts;
18use super::DictMetadata;
19use super::take_canonical;
20use crate::AnyCanonical;
21use crate::ArrayRef;
22use crate::Canonical;
23use crate::DeserializeMetadata;
24use crate::DynArray;
25use crate::IntoArray;
26use crate::Precision;
27use crate::ProstMetadata;
28use crate::SerializeMetadata;
29use crate::arrays::ConstantArray;
30use crate::arrays::Primitive;
31use crate::arrays::dict::compute::rules::PARENT_RULES;
32use crate::buffer::BufferHandle;
33use crate::dtype::DType;
34use crate::dtype::Nullability;
35use crate::dtype::PType;
36use crate::executor::ExecutionCtx;
37use crate::executor::ExecutionResult;
38use crate::hash::ArrayEq;
39use crate::hash::ArrayHash;
40use crate::require_child;
41use crate::scalar::Scalar;
42use crate::serde::ArrayChildren;
43use crate::stats::StatsSetRef;
44use crate::vtable;
45use crate::vtable::Array;
46use crate::vtable::ArrayId;
47use crate::vtable::VTable;
48mod kernel;
49mod operations;
50mod validity;
51
52vtable!(Dict);
53
/// Zero-sized marker type for the dictionary encoding.
///
/// It carries no data; it exists to hang the encoding's trait
/// implementations and its identifier on.
#[derive(Debug, Clone)]
pub struct Dict;
56
57impl Dict {
58 pub const ID: ArrayId = ArrayId::new_ref("vortex.dict");
59}
60
61impl VTable for Dict {
62 type Array = DictArray;
63
64 type Metadata = ProstMetadata<DictMetadata>;
65 type OperationsVTable = Self;
66 type ValidityVTable = Self;
67
68 fn vtable(_array: &Self::Array) -> &Self {
69 &Dict
70 }
71
72 fn id(&self) -> ArrayId {
73 Self::ID
74 }
75
76 fn len(array: &DictArray) -> usize {
77 array.codes.len()
78 }
79
80 fn dtype(array: &DictArray) -> &DType {
81 &array.dtype
82 }
83
84 fn stats(array: &DictArray) -> StatsSetRef<'_> {
85 array.stats_set.to_ref(array.as_ref())
86 }
87
88 fn array_hash<H: std::hash::Hasher>(array: &DictArray, state: &mut H, precision: Precision) {
89 array.dtype.hash(state);
90 array.codes.array_hash(state, precision);
91 array.values.array_hash(state, precision);
92 }
93
94 fn array_eq(array: &DictArray, other: &DictArray, precision: Precision) -> bool {
95 array.dtype == other.dtype
96 && array.codes.array_eq(&other.codes, precision)
97 && array.values.array_eq(&other.values, precision)
98 }
99
100 fn nbuffers(_array: &DictArray) -> usize {
101 0
102 }
103
104 fn buffer(_array: &DictArray, idx: usize) -> BufferHandle {
105 vortex_panic!("DictArray buffer index {idx} out of bounds")
106 }
107
108 fn buffer_name(_array: &DictArray, _idx: usize) -> Option<String> {
109 None
110 }
111
112 fn nchildren(_array: &DictArray) -> usize {
113 2
114 }
115
116 fn child(array: &DictArray, idx: usize) -> ArrayRef {
117 match idx {
118 0 => array.codes().clone(),
119 1 => array.values().clone(),
120 _ => vortex_panic!("DictArray child index {idx} out of bounds"),
121 }
122 }
123
124 fn child_name(_array: &DictArray, idx: usize) -> String {
125 match idx {
126 0 => "codes".to_string(),
127 1 => "values".to_string(),
128 _ => vortex_panic!("DictArray child_name index {idx} out of bounds"),
129 }
130 }
131
132 fn metadata(array: &DictArray) -> VortexResult<Self::Metadata> {
133 Ok(ProstMetadata(DictMetadata {
134 codes_ptype: PType::try_from(array.codes().dtype())? as i32,
135 values_len: u32::try_from(array.values().len()).map_err(|_| {
136 vortex_err!(
137 "Dictionary values size {} overflowed u32",
138 array.values().len()
139 )
140 })?,
141 is_nullable_codes: Some(array.codes().dtype().is_nullable()),
142 all_values_referenced: Some(array.all_values_referenced),
143 }))
144 }
145
146 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
147 Ok(Some(metadata.serialize()))
148 }
149
150 fn deserialize(
151 bytes: &[u8],
152 _dtype: &DType,
153 _len: usize,
154 _buffers: &[BufferHandle],
155 _session: &VortexSession,
156 ) -> VortexResult<Self::Metadata> {
157 let metadata = <Self::Metadata as DeserializeMetadata>::deserialize(bytes)?;
158 Ok(ProstMetadata(metadata))
159 }
160
161 fn build(
162 dtype: &DType,
163 len: usize,
164 metadata: &Self::Metadata,
165 _buffers: &[BufferHandle],
166 children: &dyn ArrayChildren,
167 ) -> VortexResult<DictArray> {
168 if children.len() != 2 {
169 vortex_bail!(
170 "Expected 2 children for dict encoding, found {}",
171 children.len()
172 )
173 }
174 let codes_nullable = metadata
175 .is_nullable_codes
176 .map(Nullability::from)
177 .unwrap_or_else(|| dtype.nullability());
180 let codes_dtype = DType::Primitive(metadata.codes_ptype(), codes_nullable);
181 let codes = children.get(0, &codes_dtype, len)?;
182 let values = children.get(1, dtype, metadata.values_len as usize)?;
183 let all_values_referenced = metadata.all_values_referenced.unwrap_or(false);
184
185 Ok(unsafe {
187 DictArray::new_unchecked(codes, values).set_all_values_referenced(all_values_referenced)
188 })
189 }
190
191 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
192 vortex_ensure!(
193 children.len() == 2,
194 "DictArray expects exactly 2 children (codes, values), got {}",
195 children.len()
196 );
197 let [codes, values]: [ArrayRef; 2] = children
198 .try_into()
199 .map_err(|_| vortex_err!("Failed to convert children to array"))?;
200 array.codes = codes;
201 array.values = values;
202 Ok(())
203 }
204
205 fn execute(array: Arc<Array<Self>>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
206 if array.is_empty() {
207 let result_dtype = array
208 .dtype()
209 .union_nullability(array.codes().dtype().nullability());
210 return Ok(ExecutionResult::done(Canonical::empty(&result_dtype)));
211 }
212
213 let array = require_child!(array, array.codes(), 0 => Primitive);
214
215 if array.codes().all_invalid()? {
219 return Ok(ExecutionResult::done(
220 ConstantArray::new(Scalar::null(array.dtype().as_nullable()), array.codes.len())
221 .into_array(),
222 ));
223 }
224
225 let array = require_child!(array, array.values(), 1 => AnyCanonical);
226
227 let DictArrayParts { codes, values, .. } =
228 Arc::unwrap_or_clone(array).into_inner().into_parts();
229
230 let codes = codes
231 .try_into::<Primitive>()
232 .ok()
233 .vortex_expect("must be primitive");
234 debug_assert!(values.is_canonical());
235 let values = values.to_canonical()?;
237
238 Ok(ExecutionResult::done(
239 take_canonical(values, &codes, ctx)?.into_array(),
240 ))
241 }
242
243 fn reduce_parent(
244 array: &Array<Self>,
245 parent: &ArrayRef,
246 child_idx: usize,
247 ) -> VortexResult<Option<ArrayRef>> {
248 PARENT_RULES.evaluate(array, parent, child_idx)
249 }
250
251 fn execute_parent(
252 array: &Array<Self>,
253 parent: &ArrayRef,
254 child_idx: usize,
255 ctx: &mut ExecutionCtx,
256 ) -> VortexResult<Option<ArrayRef>> {
257 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
258 }
259}