1use std::fmt::Debug;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_array::compute::{cast, take};
8use vortex_array::stats::{ArrayStats, StatsSetRef};
9use vortex_array::vtable::{ArrayVTable, CanonicalVTable, NotSupported, VTable, ValidityVTable};
10use vortex_array::{
11 Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable,
12};
13use vortex_dtype::{DType, match_each_integer_ptype};
14use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_ensure};
15use vortex_mask::{AllOr, Mask};
16
17vtable!(Dict);
18
19impl VTable for DictVTable {
20 type Array = DictArray;
21 type Encoding = DictEncoding;
22
23 type ArrayVTable = Self;
24 type CanonicalVTable = Self;
25 type OperationsVTable = Self;
26 type ValidityVTable = Self;
27 type VisitorVTable = Self;
28 type ComputeVTable = NotSupported;
29 type EncodeVTable = Self;
30 type SerdeVTable = Self;
31 type PipelineVTable = NotSupported;
32
33 fn id(_encoding: &Self::Encoding) -> EncodingId {
34 EncodingId::new_ref("vortex.dict")
35 }
36
37 fn encoding(_array: &Self::Array) -> EncodingRef {
38 EncodingRef::new_ref(DictEncoding.as_ref())
39 }
40}
41
42#[derive(Debug, Clone)]
43pub struct DictArray {
44 codes: ArrayRef,
45 values: ArrayRef,
46 stats_set: ArrayStats,
47}
48
49#[derive(Clone, Debug)]
50pub struct DictEncoding;
51
52impl DictArray {
53 pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
60 Self {
61 codes,
62 values,
63 stats_set: Default::default(),
64 }
65 }
66
67 pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
72 Self::try_new(codes, values).vortex_expect("DictArray new")
73 }
74
75 pub fn try_new(mut codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
87 if !codes.dtype().is_unsigned_int() {
88 vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
89 }
90
91 let dtype = values.dtype();
92 if dtype.is_nullable() {
93 codes = cast(&codes, &codes.dtype().as_nullable())?;
95 } else {
96 vortex_ensure!(
98 !codes.dtype().is_nullable(),
99 "Cannot have nullable codes for non-nullable dict array"
100 );
101 }
102
103 vortex_ensure!(
104 codes.dtype().nullability() == values.dtype().nullability(),
105 "Mismatched nullability between codes and values"
106 );
107
108 Ok(Self {
109 codes,
110 values,
111 stats_set: Default::default(),
112 })
113 }
114
115 #[inline]
116 pub fn codes(&self) -> &ArrayRef {
117 &self.codes
118 }
119
120 #[inline]
121 pub fn values(&self) -> &ArrayRef {
122 &self.values
123 }
124}
125
126impl ArrayVTable<DictVTable> for DictVTable {
127 fn len(array: &DictArray) -> usize {
128 array.codes.len()
129 }
130
131 fn dtype(array: &DictArray) -> &DType {
132 array.values.dtype()
133 }
134
135 fn stats(array: &DictArray) -> StatsSetRef<'_> {
136 array.stats_set.to_ref(array.as_ref())
137 }
138}
139
140impl CanonicalVTable<DictVTable> for DictVTable {
141 fn canonicalize(array: &DictArray) -> VortexResult<Canonical> {
142 match array.dtype() {
143 DType::Utf8(_) | DType::Binary(_) => {
148 let canonical_values: ArrayRef = array.values().to_canonical()?.into_array();
149 take(&canonical_values, array.codes())?.to_canonical()
150 }
151 _ => take(array.values(), array.codes())?.to_canonical(),
152 }
153 }
154}
155
156impl ValidityVTable<DictVTable> for DictVTable {
157 fn is_valid(array: &DictArray, index: usize) -> bool {
158 let scalar = array.codes().scalar_at(index);
159
160 if scalar.is_null() {
161 return false;
162 };
163 let values_index: usize = scalar
164 .as_ref()
165 .try_into()
166 .vortex_expect("Failed to convert dictionary code to usize");
167 array.values().is_valid(values_index)
168 }
169
170 fn all_valid(array: &DictArray) -> bool {
171 array.codes().all_valid() && array.values().all_valid()
172 }
173
174 fn all_invalid(array: &DictArray) -> bool {
175 array.codes().all_invalid() || array.values().all_invalid()
176 }
177
178 fn validity_mask(array: &DictArray) -> Mask {
179 let codes_validity = array.codes().validity_mask();
180 match codes_validity.boolean_buffer() {
181 AllOr::All => {
182 let primitive_codes = array
183 .codes()
184 .to_primitive()
185 .vortex_expect("dict codes must be primitive");
186 let values_mask = array.values().validity_mask();
187 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
188 let codes_slice = primitive_codes.as_slice::<P>();
189 BooleanBuffer::collect_bool(array.len(), |idx| {
190 #[allow(clippy::cast_possible_truncation)]
191 values_mask.value(codes_slice[idx] as usize)
192 })
193 });
194 Mask::from_buffer(is_valid_buffer)
195 }
196 AllOr::None => Mask::AllFalse(array.len()),
197 AllOr::Some(validity_buff) => {
198 let primitive_codes = array
199 .codes()
200 .to_primitive()
201 .vortex_expect("dict codes must be primitive");
202 let values_mask = array.values().validity_mask();
203 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
204 let codes_slice = primitive_codes.as_slice::<P>();
205 #[allow(clippy::cast_possible_truncation)]
206 BooleanBuffer::collect_bool(array.len(), |idx| {
207 validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
208 })
209 });
210 Mask::from_buffer(is_valid_buffer)
211 }
212 }
213 }
214}
215
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rand::distr::{Distribution, StandardUniform};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};
    use vortex_array::arrays::{ChunkedArray, PrimitiveArray};
    use vortex_array::builders::builder_with_capacity;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, NativePType, PType};
    use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
    use vortex_mask::AllOr;

    use crate::DictArray;

    /// Wraps per-element validity bits into a `Validity`.
    fn validity_of(bits: Vec<bool>) -> Validity {
        Validity::from(BooleanBuffer::from(bits))
    }

    /// Asserts that the validity mask of `dict` is a partial mask whose valid positions
    /// are exactly `expected`.
    fn assert_valid_positions<const N: usize>(dict: &DictArray, expected: [usize; N]) {
        let mask = dict.validity_mask();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, expected);
    }

    /// Null codes alone decide validity when every dictionary entry is valid.
    #[test]
    fn nullable_codes_validity() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            validity_of(vec![true, false, true, false, true]),
        );
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid);

        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();

        assert_valid_positions(&dict, [0, 2, 4]);
    }

    /// Null dictionary entries alone decide validity when the codes carry no nulls.
    #[test]
    fn nullable_values_validity() {
        let codes = buffer![0u32, 1, 2, 2, 1].into_array();
        let values = PrimitiveArray::new(buffer![3, 6, 9], validity_of(vec![true, false, false]));

        let dict = DictArray::try_new(codes, values.into_array()).unwrap();

        assert_valid_positions(&dict, [0]);
    }

    /// A position is valid only when both its code and the selected entry are valid.
    #[test]
    fn nullable_codes_and_values() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            validity_of(vec![true, false, true, false, true]),
        );
        let values = PrimitiveArray::new(buffer![3, 6, 9], validity_of(vec![false, true, true]));

        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();

        assert_valid_positions(&dict, [2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dictionary chunks. Each chunk has
    /// `unique_values` random values of type `T` and `len` random codes of type `U`,
    /// all drawn from a generator seeded with `0`.
    fn make_dict_primitive_chunks<T: NativePType, U: NativePType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);
        let mut chunks: Vec<ArrayRef> = Vec::with_capacity(chunk_count);

        for _ in 0..chunk_count {
            // Values are drawn before codes so the random sequence stays reproducible.
            let values = (0..unique_values)
                .map(|_| rng.random::<T>())
                .collect::<PrimitiveArray>();
            let codes = (0..len)
                .map(|_| U::from(rng.random_range(0..unique_values)).vortex_expect("valid value"))
                .collect::<PrimitiveArray>();

            let dict = DictArray::try_new(codes.into_array(), values.into_array()).vortex_unwrap();
            chunks.push(dict.into_array());
        }

        chunks.into_iter().collect::<ChunkedArray>().into_array()
    }

    /// Appending the chunked dictionary array to a builder must give the same data and
    /// validity as converting it to a primitive array directly.
    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let len = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let mut builder = builder_with_capacity(
            &DType::Primitive(PType::U64, NonNullable),
            len * chunk_count,
        );
        array
            .clone()
            .append_to_builder(builder.as_mut())
            .vortex_unwrap();

        let direct = array.to_primitive().unwrap();
        let via_builder = builder.finish().to_primitive().unwrap();

        assert_eq!(direct.as_slice::<u64>(), via_builder.as_slice::<u64>());
        assert_eq!(
            direct.validity_mask().boolean_buffer(),
            via_builder.validity_mask().boolean_buffer()
        );
    }
}