1use std::fmt::Debug;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_array::stats::{ArrayStats, StatsSetRef};
8use vortex_array::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable};
9use vortex_array::{Array, ArrayRef, EncodingId, EncodingRef, ToCanonical, vtable};
10use vortex_dtype::{DType, match_each_integer_ptype};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
12use vortex_mask::{AllOr, Mask};
13
14vtable!(Dict);
15
16impl VTable for DictVTable {
17 type Array = DictArray;
18 type Encoding = DictEncoding;
19
20 type ArrayVTable = Self;
21 type CanonicalVTable = Self;
22 type OperationsVTable = Self;
23 type ValidityVTable = Self;
24 type VisitorVTable = Self;
25 type ComputeVTable = NotSupported;
26 type EncodeVTable = Self;
27 type SerdeVTable = Self;
28 type PipelineVTable = NotSupported;
29
30 fn id(_encoding: &Self::Encoding) -> EncodingId {
31 EncodingId::new_ref("vortex.dict")
32 }
33
34 fn encoding(_array: &Self::Array) -> EncodingRef {
35 EncodingRef::new_ref(DictEncoding.as_ref())
36 }
37}
38
39#[derive(Debug, Clone)]
40pub struct DictArray {
41 codes: ArrayRef,
42 values: ArrayRef,
43 stats_set: ArrayStats,
44 dtype: DType,
45}
46
/// Zero-sized marker type naming the dictionary encoding; it is the
/// `Encoding` associated type of `DictVTable`.
#[derive(Debug, Clone)]
pub struct DictEncoding;
49
50impl DictArray {
51 pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
58 let dtype = values
59 .dtype()
60 .union_nullability(codes.dtype().nullability());
61 Self {
62 codes,
63 values,
64 stats_set: Default::default(),
65 dtype,
66 }
67 }
68
69 pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
74 Self::try_new(codes, values).vortex_expect("DictArray new")
75 }
76
77 pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
89 if !codes.dtype().is_unsigned_int() {
90 vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
91 }
92
93 Ok(unsafe { Self::new_unchecked(codes, values) })
94 }
95
96 #[inline]
97 pub fn codes(&self) -> &ArrayRef {
98 &self.codes
99 }
100
101 #[inline]
102 pub fn values(&self) -> &ArrayRef {
103 &self.values
104 }
105}
106
107impl ArrayVTable<DictVTable> for DictVTable {
108 fn len(array: &DictArray) -> usize {
109 array.codes.len()
110 }
111
112 fn dtype(array: &DictArray) -> &DType {
113 &array.dtype
114 }
115
116 fn stats(array: &DictArray) -> StatsSetRef<'_> {
117 array.stats_set.to_ref(array.as_ref())
118 }
119}
120
121impl ValidityVTable<DictVTable> for DictVTable {
122 fn is_valid(array: &DictArray, index: usize) -> bool {
123 let scalar = array.codes().scalar_at(index);
124
125 if scalar.is_null() {
126 return false;
127 };
128 let values_index: usize = scalar
129 .as_ref()
130 .try_into()
131 .vortex_expect("Failed to convert dictionary code to usize");
132 array.values().is_valid(values_index)
133 }
134
135 fn all_valid(array: &DictArray) -> bool {
136 array.codes().all_valid() && array.values().all_valid()
137 }
138
139 fn all_invalid(array: &DictArray) -> bool {
140 array.codes().all_invalid() || array.values().all_invalid()
141 }
142
143 fn validity_mask(array: &DictArray) -> Mask {
144 let codes_validity = array.codes().validity_mask();
145 match codes_validity.boolean_buffer() {
146 AllOr::All => {
147 let primitive_codes = array.codes().to_primitive();
148 let values_mask = array.values().validity_mask();
149 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
150 let codes_slice = primitive_codes.as_slice::<P>();
151 BooleanBuffer::collect_bool(array.len(), |idx| {
152 #[allow(clippy::cast_possible_truncation)]
153 values_mask.value(codes_slice[idx] as usize)
154 })
155 });
156 Mask::from_buffer(is_valid_buffer)
157 }
158 AllOr::None => Mask::AllFalse(array.len()),
159 AllOr::Some(validity_buff) => {
160 let primitive_codes = array.codes().to_primitive();
161 let values_mask = array.values().validity_mask();
162 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
163 let codes_slice = primitive_codes.as_slice::<P>();
164 #[allow(clippy::cast_possible_truncation)]
165 BooleanBuffer::collect_bool(array.len(), |idx| {
166 validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
167 })
168 });
169 Mask::from_buffer(is_valid_buffer)
170 }
171 }
172 }
173}
174
175#[cfg(test)]
176mod test {
177 use arrow_buffer::BooleanBuffer;
178 use rand::distr::{Distribution, StandardUniform};
179 use rand::prelude::StdRng;
180 use rand::{Rng, SeedableRng};
181 use vortex_array::arrays::{ChunkedArray, PrimitiveArray};
182 use vortex_array::builders::builder_with_capacity;
183 use vortex_array::validity::Validity;
184 use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
185 use vortex_buffer::buffer;
186 use vortex_dtype::Nullability::NonNullable;
187 use vortex_dtype::{DType, NativePType, PType, UnsignedPType};
188 use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
189 use vortex_mask::AllOr;
190
191 use crate::DictArray;
192
193 #[test]
194 fn nullable_codes_validity() {
195 let dict = DictArray::try_new(
196 PrimitiveArray::new(
197 buffer![0u32, 1, 2, 2, 1],
198 Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
199 )
200 .into_array(),
201 PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array(),
202 )
203 .unwrap();
204 let mask = dict.validity_mask();
205 let AllOr::Some(indices) = mask.indices() else {
206 vortex_panic!("Expected indices from mask")
207 };
208 assert_eq!(indices, [0, 2, 4]);
209 }
210
211 #[test]
212 fn nullable_values_validity() {
213 let dict = DictArray::try_new(
214 buffer![0u32, 1, 2, 2, 1].into_array(),
215 PrimitiveArray::new(
216 buffer![3, 6, 9],
217 Validity::from(BooleanBuffer::from(vec![true, false, false])),
218 )
219 .into_array(),
220 )
221 .unwrap();
222 let mask = dict.validity_mask();
223 let AllOr::Some(indices) = mask.indices() else {
224 vortex_panic!("Expected indices from mask")
225 };
226 assert_eq!(indices, [0]);
227 }
228
229 #[test]
230 fn nullable_codes_and_values() {
231 let dict = DictArray::try_new(
232 PrimitiveArray::new(
233 buffer![0u32, 1, 2, 2, 1],
234 Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
235 )
236 .into_array(),
237 PrimitiveArray::new(
238 buffer![3, 6, 9],
239 Validity::from(BooleanBuffer::from(vec![false, true, true])),
240 )
241 .into_array(),
242 )
243 .unwrap();
244 let mask = dict.validity_mask();
245 let AllOr::Some(indices) = mask.indices() else {
246 vortex_panic!("Expected indices from mask")
247 };
248 assert_eq!(indices, [2, 4]);
249 }
250
251 #[test]
252 fn nullable_codes_and_non_null_values() {
253 let dict = DictArray::try_new(
254 PrimitiveArray::new(
255 buffer![0u32, 1, 2, 2, 1],
256 Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
257 )
258 .into_array(),
259 PrimitiveArray::new(buffer![3, 6, 9], Validity::NonNullable).into_array(),
260 )
261 .unwrap();
262 let mask = dict.validity_mask();
263 let AllOr::Some(indices) = mask.indices() else {
264 vortex_panic!("Expected indices from mask")
265 };
266 assert_eq!(indices, [0, 2, 4]);
267 }
268
269 fn make_dict_primitive_chunks<T: NativePType, Code: UnsignedPType>(
270 len: usize,
271 unique_values: usize,
272 chunk_count: usize,
273 ) -> ArrayRef
274 where
275 StandardUniform: Distribution<T>,
276 {
277 let mut rng = StdRng::seed_from_u64(0);
278
279 (0..chunk_count)
280 .map(|_| {
281 let values = (0..unique_values)
282 .map(|_| rng.random::<T>())
283 .collect::<PrimitiveArray>();
284 let codes = (0..len)
285 .map(|_| {
286 Code::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
287 })
288 .collect::<PrimitiveArray>();
289
290 DictArray::try_new(codes.into_array(), values.into_array())
291 .vortex_unwrap()
292 .into_array()
293 })
294 .collect::<ChunkedArray>()
295 .into_array()
296 }
297
298 #[test]
299 fn test_dict_array_from_primitive_chunks() {
300 let len = 2;
301 let chunk_count = 2;
302 let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);
303
304 let mut builder = builder_with_capacity(
305 &DType::Primitive(PType::U64, NonNullable),
306 len * chunk_count,
307 );
308 array.clone().append_to_builder(builder.as_mut());
309
310 let into_prim = array.to_primitive();
311 let prim_into = builder.finish_into_canonical().into_primitive();
312
313 assert_eq!(into_prim.as_slice::<u64>(), prim_into.as_slice::<u64>());
314 assert_eq!(
315 into_prim.validity_mask().boolean_buffer(),
316 prim_into.validity_mask().boolean_buffer()
317 )
318 }
319}