1use std::fmt::Debug;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_array::compute::take;
8use vortex_array::stats::{ArrayStats, StatsSetRef};
9use vortex_array::vtable::{ArrayVTable, CanonicalVTable, NotSupported, VTable, ValidityVTable};
10use vortex_array::{
11 Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable,
12};
13use vortex_dtype::{DType, match_each_integer_ptype};
14use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
15use vortex_mask::{AllOr, Mask};
16
// Invokes the `vtable!` macro imported from `vortex_array` for the `Dict` encoding.
// NOTE(review): presumably this is what declares `DictVTable`, which is used below but
// not otherwise defined in this file — confirm against the macro definition.
vtable!(Dict);
18
19impl VTable for DictVTable {
20 type Array = DictArray;
21 type Encoding = DictEncoding;
22
23 type ArrayVTable = Self;
24 type CanonicalVTable = Self;
25 type OperationsVTable = Self;
26 type ValidityVTable = Self;
27 type VisitorVTable = Self;
28 type ComputeVTable = NotSupported;
29 type EncodeVTable = Self;
30 type SerdeVTable = Self;
31 type PipelineVTable = NotSupported;
32
33 fn id(_encoding: &Self::Encoding) -> EncodingId {
34 EncodingId::new_ref("vortex.dict")
35 }
36
37 fn encoding(_array: &Self::Array) -> EncodingRef {
38 EncodingRef::new_ref(DictEncoding.as_ref())
39 }
40}
41
42#[derive(Debug, Clone)]
43pub struct DictArray {
44 codes: ArrayRef,
45 values: ArrayRef,
46 stats_set: ArrayStats,
47}
48
/// Field-less marker type for the dictionary encoding.
///
/// It is the `Encoding` associated type of `DictVTable`'s `VTable` implementation.
#[derive(Clone, Debug)]
pub struct DictEncoding;
51
52impl DictArray {
53 pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
60 Self {
61 codes,
62 values,
63 stats_set: Default::default(),
64 }
65 }
66
67 pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
72 Self::try_new(codes, values).vortex_expect("DictArray new")
73 }
74
75 pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
87 if !codes.dtype().is_unsigned_int() {
88 vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
89 }
90
91 Ok(Self {
92 codes,
93 values,
94 stats_set: Default::default(),
95 })
96 }
97
98 #[inline]
99 pub fn codes(&self) -> &ArrayRef {
100 &self.codes
101 }
102
103 #[inline]
104 pub fn values(&self) -> &ArrayRef {
105 &self.values
106 }
107}
108
109impl ArrayVTable<DictVTable> for DictVTable {
110 fn len(array: &DictArray) -> usize {
111 array.codes.len()
112 }
113
114 fn dtype(array: &DictArray) -> &DType {
115 array.values.dtype()
116 }
117
118 fn stats(array: &DictArray) -> StatsSetRef<'_> {
119 array.stats_set.to_ref(array.as_ref())
120 }
121}
122
123impl CanonicalVTable<DictVTable> for DictVTable {
124 fn canonicalize(array: &DictArray) -> Canonical {
125 match array.dtype() {
126 DType::Utf8(_) | DType::Binary(_) => {
131 let canonical_values: ArrayRef = array.values().to_canonical().into_array();
132 take(&canonical_values, array.codes())
133 .vortex_expect("taking codes from dictionary values shouldn't fail")
134 .to_canonical()
135 }
136 _ => take(array.values(), array.codes())
137 .vortex_expect("taking codes from dictionary values shouldn't fail")
138 .to_canonical(),
139 }
140 }
141}
142
143impl ValidityVTable<DictVTable> for DictVTable {
144 fn is_valid(array: &DictArray, index: usize) -> bool {
145 let scalar = array.codes().scalar_at(index);
146
147 if scalar.is_null() {
148 return false;
149 };
150 let values_index: usize = scalar
151 .as_ref()
152 .try_into()
153 .vortex_expect("Failed to convert dictionary code to usize");
154 array.values().is_valid(values_index)
155 }
156
157 fn all_valid(array: &DictArray) -> bool {
158 array.codes().all_valid() && array.values().all_valid()
159 }
160
161 fn all_invalid(array: &DictArray) -> bool {
162 array.codes().all_invalid() || array.values().all_invalid()
163 }
164
165 fn validity_mask(array: &DictArray) -> Mask {
166 let codes_validity = array.codes().validity_mask();
167 match codes_validity.boolean_buffer() {
168 AllOr::All => {
169 let primitive_codes = array.codes().to_primitive();
170 let values_mask = array.values().validity_mask();
171 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
172 let codes_slice = primitive_codes.as_slice::<P>();
173 BooleanBuffer::collect_bool(array.len(), |idx| {
174 #[allow(clippy::cast_possible_truncation)]
175 values_mask.value(codes_slice[idx] as usize)
176 })
177 });
178 Mask::from_buffer(is_valid_buffer)
179 }
180 AllOr::None => Mask::AllFalse(array.len()),
181 AllOr::Some(validity_buff) => {
182 let primitive_codes = array.codes().to_primitive();
183 let values_mask = array.values().validity_mask();
184 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
185 let codes_slice = primitive_codes.as_slice::<P>();
186 #[allow(clippy::cast_possible_truncation)]
187 BooleanBuffer::collect_bool(array.len(), |idx| {
188 validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
189 })
190 });
191 Mask::from_buffer(is_valid_buffer)
192 }
193 }
194 }
195}
196
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rand::distr::{Distribution, StandardUniform};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};
    use vortex_array::arrays::{ChunkedArray, PrimitiveArray};
    use vortex_array::builders::builder_with_capacity;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, NativePType, PType};
    use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
    use vortex_mask::AllOr;

    use crate::DictArray;

    /// Builds an explicit per-element validity from a slice of booleans.
    fn validity_of(bits: &[bool]) -> Validity {
        Validity::from(BooleanBuffer::from(bits.to_vec()))
    }

    /// Codes `[0, 1, 2, 2, 1]` where positions 1 and 3 are null.
    fn codes_with_nulls() -> ArrayRef {
        PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            validity_of(&[true, false, true, false, true]),
        )
        .into_array()
    }

    /// Dictionary values `[3, 6, 9]` with the supplied validity.
    fn dictionary_values(validity: Validity) -> ArrayRef {
        PrimitiveArray::new(buffer![3, 6, 9], validity).into_array()
    }

    /// Builds a dictionary from `codes` and `values` and checks that its validity mask
    /// marks exactly the positions listed in `expected` as valid.
    fn assert_valid_positions(codes: ArrayRef, values: ArrayRef, expected: &[usize]) {
        let dict = DictArray::try_new(codes, values).unwrap();
        let mask = dict.validity_mask();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, expected);
    }

    // Null codes over fully valid values: only the null codes are invalid.
    #[test]
    fn nullable_codes_validity() {
        assert_valid_positions(
            codes_with_nulls(),
            dictionary_values(Validity::AllValid),
            &[0, 2, 4],
        );
    }

    // Non-null codes over partially null values: only codes hitting entry 0 are valid.
    #[test]
    fn nullable_values_validity() {
        assert_valid_positions(
            buffer![0u32, 1, 2, 2, 1].into_array(),
            dictionary_values(validity_of(&[true, false, false])),
            &[0],
        );
    }

    // Nulls on both sides: a position is valid only if code and entry are both valid.
    #[test]
    fn nullable_codes_and_values() {
        assert_valid_positions(
            codes_with_nulls(),
            dictionary_values(validity_of(&[false, true, true])),
            &[2, 4],
        );
    }

    // Null codes over non-nullable values: validity follows the codes alone.
    #[test]
    fn nullable_codes_and_non_null_values() {
        assert_valid_positions(
            codes_with_nulls(),
            dictionary_values(Validity::NonNullable),
            &[0, 2, 4],
        );
    }

    /// Builds a chunked array of `chunk_count` dictionary arrays, each with `unique_values`
    /// random values of type `T` and `len` random codes of type `U`, from a fixed seed.
    fn make_dict_primitive_chunks<T: NativePType, U: NativePType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);
        let mut chunks: Vec<ArrayRef> = Vec::with_capacity(chunk_count);

        for _ in 0..chunk_count {
            // Values are drawn before codes; the draw order determines the generated data.
            let values: PrimitiveArray = (0..unique_values).map(|_| rng.random::<T>()).collect();
            let codes: PrimitiveArray = (0..len)
                .map(|_| {
                    let code = rng.random_range(0..unique_values);
                    U::from(code).vortex_expect("valid value")
                })
                .collect();

            let dict = DictArray::try_new(codes.into_array(), values.into_array()).vortex_unwrap();
            chunks.push(dict.into_array());
        }

        chunks.into_iter().collect::<ChunkedArray>().into_array()
    }

    // Canonicalizing chunked dictionaries directly and via a builder must agree.
    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let (len, chunk_count) = (2, 2);
        let chunked = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let dtype = DType::Primitive(PType::U64, NonNullable);
        let mut builder = builder_with_capacity(&dtype, len * chunk_count);
        chunked.clone().append_to_builder(builder.as_mut());

        let via_canonical = chunked.to_primitive();
        let via_builder = builder.finish_into_canonical().into_primitive();

        assert_eq!(
            via_canonical.as_slice::<u64>(),
            via_builder.as_slice::<u64>()
        );
        assert_eq!(
            via_canonical.validity_mask().boolean_buffer(),
            via_builder.validity_mask().boolean_buffer()
        )
    }
}