1use std::fmt::Debug;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_array::compute::{cast, take};
8use vortex_array::stats::{ArrayStats, StatsSetRef};
9use vortex_array::vtable::{ArrayVTable, CanonicalVTable, NotSupported, VTable, ValidityVTable};
10use vortex_array::{
11 Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable,
12};
13use vortex_dtype::{DType, match_each_integer_ptype};
14use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_ensure};
15use vortex_mask::{AllOr, Mask};
16
// Invokes the `vtable!` macro (imported from `vortex_array`) for the `Dict` encoding.
// NOTE(review): presumably this is what declares `DictVTable`, which the impls below use
// but which is not defined by hand in this file — confirm against the macro definition.
vtable!(Dict);
18
19impl VTable for DictVTable {
20 type Array = DictArray;
21 type Encoding = DictEncoding;
22
23 type ArrayVTable = Self;
24 type CanonicalVTable = Self;
25 type OperationsVTable = Self;
26 type ValidityVTable = Self;
27 type VisitorVTable = Self;
28 type ComputeVTable = NotSupported;
29 type EncodeVTable = Self;
30 type SerdeVTable = Self;
31
32 fn id(_encoding: &Self::Encoding) -> EncodingId {
33 EncodingId::new_ref("vortex.dict")
34 }
35
36 fn encoding(_array: &Self::Array) -> EncodingRef {
37 EncodingRef::new_ref(DictEncoding.as_ref())
38 }
39}
40
41#[derive(Debug, Clone)]
42pub struct DictArray {
43 codes: ArrayRef,
44 values: ArrayRef,
45 stats_set: ArrayStats,
46}
47
/// Stateless marker type for the dictionary encoding (id `"vortex.dict"`).
#[derive(Debug, Clone)]
pub struct DictEncoding;
50
51impl DictArray {
52 pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
59 Self {
60 codes,
61 values,
62 stats_set: Default::default(),
63 }
64 }
65
66 pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
71 Self::try_new(codes, values).vortex_expect("DictArray new")
72 }
73
74 pub fn try_new(mut codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
86 if !codes.dtype().is_unsigned_int() {
87 vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
88 }
89
90 let dtype = values.dtype();
91 if dtype.is_nullable() {
92 codes = cast(&codes, &codes.dtype().as_nullable())?;
94 } else {
95 vortex_ensure!(
97 !codes.dtype().is_nullable(),
98 "Cannot have nullable codes for non-nullable dict array"
99 );
100 }
101
102 vortex_ensure!(
103 codes.dtype().nullability() == values.dtype().nullability(),
104 "Mismatched nullability between codes and values"
105 );
106
107 Ok(Self {
108 codes,
109 values,
110 stats_set: Default::default(),
111 })
112 }
113
114 #[inline]
115 pub fn codes(&self) -> &ArrayRef {
116 &self.codes
117 }
118
119 #[inline]
120 pub fn values(&self) -> &ArrayRef {
121 &self.values
122 }
123}
124
125impl ArrayVTable<DictVTable> for DictVTable {
126 fn len(array: &DictArray) -> usize {
127 array.codes.len()
128 }
129
130 fn dtype(array: &DictArray) -> &DType {
131 array.values.dtype()
132 }
133
134 fn stats(array: &DictArray) -> StatsSetRef<'_> {
135 array.stats_set.to_ref(array.as_ref())
136 }
137}
138
139impl CanonicalVTable<DictVTable> for DictVTable {
140 fn canonicalize(array: &DictArray) -> VortexResult<Canonical> {
141 match array.dtype() {
142 DType::Utf8(_) | DType::Binary(_) => {
147 let canonical_values: ArrayRef = array.values().to_canonical()?.into_array();
148 take(&canonical_values, array.codes())?.to_canonical()
149 }
150 _ => take(array.values(), array.codes())?.to_canonical(),
151 }
152 }
153}
154
155impl ValidityVTable<DictVTable> for DictVTable {
156 fn is_valid(array: &DictArray, index: usize) -> VortexResult<bool> {
157 let scalar = array.codes().scalar_at(index);
158
159 if scalar.is_null() {
160 return Ok(false);
161 };
162 let values_index: usize = scalar
163 .as_ref()
164 .try_into()
165 .vortex_expect("Failed to convert dictionary code to usize");
166 array.values().is_valid(values_index)
167 }
168
169 fn all_valid(array: &DictArray) -> VortexResult<bool> {
170 Ok(array.codes().all_valid()? && array.values().all_valid()?)
171 }
172
173 fn all_invalid(array: &DictArray) -> VortexResult<bool> {
174 Ok(array.codes().all_invalid()? || array.values().all_invalid()?)
175 }
176
177 fn validity_mask(array: &DictArray) -> VortexResult<Mask> {
178 let codes_validity = array.codes().validity_mask()?;
179 match codes_validity.boolean_buffer() {
180 AllOr::All => {
181 let primitive_codes = array.codes().to_primitive()?;
182 let values_mask = array.values().validity_mask()?;
183 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
184 let codes_slice = primitive_codes.as_slice::<P>();
185 BooleanBuffer::collect_bool(array.len(), |idx| {
186 #[allow(clippy::cast_possible_truncation)]
187 values_mask.value(codes_slice[idx] as usize)
188 })
189 });
190 Ok(Mask::from_buffer(is_valid_buffer))
191 }
192 AllOr::None => Ok(Mask::AllFalse(array.len())),
193 AllOr::Some(validity_buff) => {
194 let primitive_codes = array.codes().to_primitive()?;
195 let values_mask = array.values().validity_mask()?;
196 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
197 let codes_slice = primitive_codes.as_slice::<P>();
198 #[allow(clippy::cast_possible_truncation)]
199 BooleanBuffer::collect_bool(array.len(), |idx| {
200 validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
201 })
202 });
203 Ok(Mask::from_buffer(is_valid_buffer))
204 }
205 }
206 }
207}
208
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rand::distr::{Distribution, StandardUniform};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};
    use vortex_array::arrays::{ChunkedArray, PrimitiveArray};
    use vortex_array::builders::builder_with_capacity;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, NativePType, PType};
    use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
    use vortex_mask::AllOr;

    use crate::DictArray;

    /// Builds a `Validity` from explicit per-slot flags.
    fn validity_of(flags: &[bool]) -> Validity {
        Validity::from(BooleanBuffer::from(flags.to_vec()))
    }

    /// Asserts that the validity mask of `dict` is a partial mask whose set positions
    /// are exactly `expected`.
    fn assert_valid_indices(dict: &DictArray, expected: &[usize]) {
        let mask = dict.validity_mask().unwrap();
        match mask.indices() {
            AllOr::Some(indices) => assert_eq!(indices, expected),
            AllOr::All | AllOr::None => vortex_panic!("Expected indices from mask"),
        }
    }

    /// Null codes over fully valid values: only rows with a valid code are valid.
    #[test]
    fn nullable_codes_validity() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            validity_of(&[true, false, true, false, true]),
        );
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid);

        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();

        assert_valid_indices(&dict, &[0, 2, 4]);
    }

    /// Valid codes over partially null values: only rows pointing at a valid entry are valid.
    #[test]
    fn nullable_values_validity() {
        let codes = buffer![0u32, 1, 2, 2, 1].into_array();
        let values = PrimitiveArray::new(buffer![3, 6, 9], validity_of(&[true, false, false]));

        let dict = DictArray::try_new(codes, values.into_array()).unwrap();

        assert_valid_indices(&dict, &[0]);
    }

    /// Null codes and null values combined: a row needs both a valid code and a valid entry.
    #[test]
    fn nullable_codes_and_values() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            validity_of(&[true, false, true, false, true]),
        );
        let values = PrimitiveArray::new(buffer![3, 6, 9], validity_of(&[false, true, true]));

        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();

        assert_valid_indices(&dict, &[2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dictionary chunks, each with `unique_values`
    /// random values of type `T` and `len` random codes of type `U`.
    ///
    /// The generator is seeded with a fixed value, so the output is reproducible.
    fn make_dict_primitive_chunks<T: NativePType, U: NativePType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);
        let mut chunks: Vec<ArrayRef> = Vec::with_capacity(chunk_count);

        for _ in 0..chunk_count {
            // Draw order matters for reproducibility: values first, then codes.
            let values: PrimitiveArray = (0..unique_values).map(|_| rng.random::<T>()).collect();
            let codes: PrimitiveArray = (0..len)
                .map(|_| {
                    let code = rng.random_range(0..unique_values);
                    U::from(code).vortex_expect("valid value")
                })
                .collect();

            let dict = DictArray::try_new(codes.into_array(), values.into_array()).vortex_unwrap();
            chunks.push(dict.into_array());
        }

        chunks.into_iter().collect::<ChunkedArray>().into_array()
    }

    /// Appending chunked dictionary arrays to a builder must agree with canonicalizing them.
    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let len = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let builder_dtype = DType::Primitive(PType::U64, NonNullable);
        let mut builder = builder_with_capacity(&builder_dtype, len * chunk_count);
        array
            .clone()
            .append_to_builder(builder.as_mut())
            .vortex_unwrap();

        let canonicalized = array.to_primitive().unwrap();
        let from_builder = builder.finish().to_primitive().unwrap();

        assert_eq!(
            canonicalized.as_slice::<u64>(),
            from_builder.as_slice::<u64>()
        );

        let canonicalized_mask = canonicalized.validity_mask().unwrap();
        let from_builder_mask = from_builder.validity_mask().unwrap();
        assert_eq!(
            canonicalized_mask.boolean_buffer(),
            from_builder_mask.boolean_buffer()
        );
    }
}