1use std::fmt::Debug;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_array::compute::{cast, take};
8use vortex_array::stats::{ArrayStats, StatsSetRef};
9use vortex_array::vtable::{ArrayVTable, CanonicalVTable, NotSupported, VTable, ValidityVTable};
10use vortex_array::{
11 Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable,
12};
13use vortex_dtype::{DType, match_each_integer_ptype};
14use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_ensure};
15use vortex_mask::{AllOr, Mask};
16
// Invokes the `vtable!` macro (imported from `vortex_array`) for the `Dict` encoding.
// NOTE(review): `DictVTable` is used throughout this file but never declared in it, so this
// macro presumably generates that type — confirm against the macro definition.
vtable!(Dict);
18
19impl VTable for DictVTable {
20 type Array = DictArray;
21 type Encoding = DictEncoding;
22
23 type ArrayVTable = Self;
24 type CanonicalVTable = Self;
25 type OperationsVTable = Self;
26 type ValidityVTable = Self;
27 type VisitorVTable = Self;
28 type ComputeVTable = NotSupported;
29 type EncodeVTable = Self;
30 type SerdeVTable = Self;
31 type PipelineVTable = NotSupported;
32
33 fn id(_encoding: &Self::Encoding) -> EncodingId {
34 EncodingId::new_ref("vortex.dict")
35 }
36
37 fn encoding(_array: &Self::Array) -> EncodingRef {
38 EncodingRef::new_ref(DictEncoding.as_ref())
39 }
40}
41
42#[derive(Debug, Clone)]
43pub struct DictArray {
44 codes: ArrayRef,
45 values: ArrayRef,
46 stats_set: ArrayStats,
47}
48
/// Unit marker type naming the dictionary encoding; it is the `Encoding` associated type of
/// `DictVTable`.
#[derive(Debug, Clone)]
pub struct DictEncoding;
51
52impl DictArray {
53 pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
60 Self {
61 codes,
62 values,
63 stats_set: Default::default(),
64 }
65 }
66
67 pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
72 Self::try_new(codes, values).vortex_expect("DictArray new")
73 }
74
75 pub fn try_new(mut codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
87 if !codes.dtype().is_unsigned_int() {
88 vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
89 }
90
91 let dtype = values.dtype();
92 if dtype.is_nullable() {
93 codes = cast(&codes, &codes.dtype().as_nullable())?;
95 } else {
96 vortex_ensure!(
98 !codes.dtype().is_nullable(),
99 "Cannot have nullable codes for non-nullable dict array"
100 );
101 }
102
103 vortex_ensure!(
104 codes.dtype().nullability() == values.dtype().nullability(),
105 "Mismatched nullability between codes and values"
106 );
107
108 Ok(Self {
109 codes,
110 values,
111 stats_set: Default::default(),
112 })
113 }
114
115 #[inline]
116 pub fn codes(&self) -> &ArrayRef {
117 &self.codes
118 }
119
120 #[inline]
121 pub fn values(&self) -> &ArrayRef {
122 &self.values
123 }
124}
125
126impl ArrayVTable<DictVTable> for DictVTable {
127 fn len(array: &DictArray) -> usize {
128 array.codes.len()
129 }
130
131 fn dtype(array: &DictArray) -> &DType {
132 array.values.dtype()
133 }
134
135 fn stats(array: &DictArray) -> StatsSetRef<'_> {
136 array.stats_set.to_ref(array.as_ref())
137 }
138}
139
140impl CanonicalVTable<DictVTable> for DictVTable {
141 fn canonicalize(array: &DictArray) -> VortexResult<Canonical> {
142 match array.dtype() {
143 DType::Utf8(_) | DType::Binary(_) => {
148 let canonical_values: ArrayRef = array.values().to_canonical()?.into_array();
149 take(&canonical_values, array.codes())?.to_canonical()
150 }
151 _ => take(array.values(), array.codes())?.to_canonical(),
152 }
153 }
154}
155
156impl ValidityVTable<DictVTable> for DictVTable {
157 fn is_valid(array: &DictArray, index: usize) -> VortexResult<bool> {
158 let scalar = array.codes().scalar_at(index);
159
160 if scalar.is_null() {
161 return Ok(false);
162 };
163 let values_index: usize = scalar
164 .as_ref()
165 .try_into()
166 .vortex_expect("Failed to convert dictionary code to usize");
167 array.values().is_valid(values_index)
168 }
169
170 fn all_valid(array: &DictArray) -> VortexResult<bool> {
171 Ok(array.codes().all_valid()? && array.values().all_valid()?)
172 }
173
174 fn all_invalid(array: &DictArray) -> VortexResult<bool> {
175 Ok(array.codes().all_invalid()? || array.values().all_invalid()?)
176 }
177
178 fn validity_mask(array: &DictArray) -> VortexResult<Mask> {
179 let codes_validity = array.codes().validity_mask()?;
180 match codes_validity.boolean_buffer() {
181 AllOr::All => {
182 let primitive_codes = array.codes().to_primitive()?;
183 let values_mask = array.values().validity_mask()?;
184 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
185 let codes_slice = primitive_codes.as_slice::<P>();
186 BooleanBuffer::collect_bool(array.len(), |idx| {
187 #[allow(clippy::cast_possible_truncation)]
188 values_mask.value(codes_slice[idx] as usize)
189 })
190 });
191 Ok(Mask::from_buffer(is_valid_buffer))
192 }
193 AllOr::None => Ok(Mask::AllFalse(array.len())),
194 AllOr::Some(validity_buff) => {
195 let primitive_codes = array.codes().to_primitive()?;
196 let values_mask = array.values().validity_mask()?;
197 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
198 let codes_slice = primitive_codes.as_slice::<P>();
199 #[allow(clippy::cast_possible_truncation)]
200 BooleanBuffer::collect_bool(array.len(), |idx| {
201 validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
202 })
203 });
204 Ok(Mask::from_buffer(is_valid_buffer))
205 }
206 }
207 }
208}
209
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rand::distr::{Distribution, StandardUniform};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};
    use vortex_array::arrays::{ChunkedArray, PrimitiveArray};
    use vortex_array::builders::builder_with_capacity;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, NativePType, PType};
    use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
    use vortex_mask::AllOr;

    use crate::DictArray;

    /// Builds an explicit per-element validity from a list of booleans.
    fn validity_of(bits: Vec<bool>) -> Validity {
        Validity::from(BooleanBuffer::from(bits))
    }

    /// Asserts that the valid positions of `dict` are exactly `expected`.
    ///
    /// Panics if the mask does not expose an explicit index list.
    fn assert_valid_positions(dict: &DictArray, expected: &[usize]) {
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, expected);
    }

    // Null codes over fully valid values: only the non-null codes are valid.
    #[test]
    fn nullable_codes_validity() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            validity_of(vec![true, false, true, false, true]),
        );
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid);

        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();
        assert_valid_positions(&dict, &[0, 2, 4]);
    }

    // Non-null codes over partially null values: only codes pointing at valid values are valid.
    #[test]
    fn nullable_values_validity() {
        let codes = buffer![0u32, 1, 2, 2, 1].into_array();
        let values = PrimitiveArray::new(buffer![3, 6, 9], validity_of(vec![true, false, false]));

        let dict = DictArray::try_new(codes, values.into_array()).unwrap();
        assert_valid_positions(&dict, &[0]);
    }

    // Nulls in both children: a position is valid only if its code and its value are valid.
    #[test]
    fn nullable_codes_and_values() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            validity_of(vec![true, false, true, false, true]),
        );
        let values = PrimitiveArray::new(buffer![3, 6, 9], validity_of(vec![false, true, true]));

        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();
        assert_valid_positions(&dict, &[2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dictionary arrays, each holding `len` random
    /// codes (of type `U`) over `unique_values` random values (of type `T`).
    ///
    /// The generator is seeded with a constant, so the output is deterministic.
    fn make_dict_primitive_chunks<T: NativePType, U: NativePType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);
        let mut chunks: Vec<ArrayRef> = Vec::with_capacity(chunk_count);

        for _ in 0..chunk_count {
            // Draw the values before the codes so the random stream is consumed in a fixed order.
            let values = (0..unique_values)
                .map(|_| rng.random::<T>())
                .collect::<PrimitiveArray>();
            let codes = (0..len)
                .map(|_| U::from(rng.random_range(0..unique_values)).vortex_expect("valid value"))
                .collect::<PrimitiveArray>();

            let dict = DictArray::try_new(codes.into_array(), values.into_array()).vortex_unwrap();
            chunks.push(dict.into_array());
        }

        chunks.into_iter().collect::<ChunkedArray>().into_array()
    }

    // Canonicalizing chunked dict arrays must agree with appending them to a builder.
    #[test]
    fn test_dict_array_from_primitive_chunks() {
        const ROWS_PER_CHUNK: usize = 2;
        const CHUNK_COUNT: usize = 2;
        let chunked = make_dict_primitive_chunks::<u64, u64>(ROWS_PER_CHUNK, 2, CHUNK_COUNT);

        let mut builder = builder_with_capacity(
            &DType::Primitive(PType::U64, NonNullable),
            ROWS_PER_CHUNK * CHUNK_COUNT,
        );
        chunked
            .clone()
            .append_to_builder(builder.as_mut())
            .vortex_unwrap();

        let via_canonical = chunked.to_primitive().unwrap();
        let via_builder = builder.finish().to_primitive().unwrap();

        assert_eq!(
            via_canonical.as_slice::<u64>(),
            via_builder.as_slice::<u64>()
        );
        assert_eq!(
            via_canonical.validity_mask().unwrap().boolean_buffer(),
            via_builder.validity_mask().unwrap().boolean_buffer()
        );
    }
}