vortex_array/arrays/dict/
array.rs1use std::fmt::Debug;
5use std::hash::Hash;
6
7use vortex_buffer::BitBuffer;
8use vortex_dtype::{DType, match_each_integer_ptype};
9use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
10use vortex_mask::{AllOr, Mask};
11
12use crate::stats::{ArrayStats, StatsSetRef};
13use crate::vtable::{ArrayVTable, NotSupported, VTable, ValidityVTable};
14use crate::{
15 Array, ArrayEq, ArrayHash, ArrayRef, EncodingId, EncodingRef, Precision, ToCanonical, vtable,
16};
17
18vtable!(Dict);
19
20impl VTable for DictVTable {
21 type Array = DictArray;
22 type Encoding = DictEncoding;
23
24 type ArrayVTable = Self;
25 type CanonicalVTable = Self;
26 type OperationsVTable = Self;
27 type ValidityVTable = Self;
28 type VisitorVTable = Self;
29 type ComputeVTable = NotSupported;
30 type EncodeVTable = Self;
31 type SerdeVTable = Self;
32 type OperatorVTable = NotSupported;
33
34 fn id(_encoding: &Self::Encoding) -> EncodingId {
35 EncodingId::new_ref("vortex.dict")
36 }
37
38 fn encoding(_array: &Self::Array) -> EncodingRef {
39 EncodingRef::new_ref(DictEncoding.as_ref())
40 }
41}
42
43#[derive(Debug, Clone)]
44pub struct DictArray {
45 codes: ArrayRef,
46 values: ArrayRef,
47 stats_set: ArrayStats,
48 dtype: DType,
49}
50
/// Marker type for the dictionary encoding, identified as `vortex.dict`.
#[derive(Debug, Clone)]
pub struct DictEncoding;
53
54impl DictArray {
55 pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
62 let dtype = values
63 .dtype()
64 .union_nullability(codes.dtype().nullability());
65 Self {
66 codes,
67 values,
68 stats_set: Default::default(),
69 dtype,
70 }
71 }
72
73 pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
78 Self::try_new(codes, values).vortex_expect("DictArray new")
79 }
80
81 pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
93 if !codes.dtype().is_unsigned_int() {
94 vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
95 }
96
97 Ok(unsafe { Self::new_unchecked(codes, values) })
98 }
99
100 #[inline]
101 pub fn codes(&self) -> &ArrayRef {
102 &self.codes
103 }
104
105 #[inline]
106 pub fn values(&self) -> &ArrayRef {
107 &self.values
108 }
109}
110
111impl ArrayVTable<DictVTable> for DictVTable {
112 fn len(array: &DictArray) -> usize {
113 array.codes.len()
114 }
115
116 fn dtype(array: &DictArray) -> &DType {
117 &array.dtype
118 }
119
120 fn stats(array: &DictArray) -> StatsSetRef<'_> {
121 array.stats_set.to_ref(array.as_ref())
122 }
123
124 fn array_hash<H: std::hash::Hasher>(array: &DictArray, state: &mut H, precision: Precision) {
125 array.dtype.hash(state);
126 array.codes.array_hash(state, precision);
127 array.values.array_hash(state, precision);
128 }
129
130 fn array_eq(array: &DictArray, other: &DictArray, precision: Precision) -> bool {
131 array.dtype == other.dtype
132 && array.codes.array_eq(&other.codes, precision)
133 && array.values.array_eq(&other.values, precision)
134 }
135}
136
137impl ValidityVTable<DictVTable> for DictVTable {
138 fn is_valid(array: &DictArray, index: usize) -> bool {
139 let scalar = array.codes().scalar_at(index);
140
141 if scalar.is_null() {
142 return false;
143 };
144 let values_index: usize = scalar
145 .as_ref()
146 .try_into()
147 .vortex_expect("Failed to convert dictionary code to usize");
148 array.values().is_valid(values_index)
149 }
150
151 fn all_valid(array: &DictArray) -> bool {
152 array.codes().all_valid() && array.values().all_valid()
153 }
154
155 fn all_invalid(array: &DictArray) -> bool {
156 array.codes().all_invalid() || array.values().all_invalid()
157 }
158
159 fn validity_mask(array: &DictArray) -> Mask {
160 let codes_validity = array.codes().validity_mask();
161 match codes_validity.bit_buffer() {
162 AllOr::All => {
163 let primitive_codes = array.codes().to_primitive();
164 let values_mask = array.values().validity_mask();
165 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
166 let codes_slice = primitive_codes.as_slice::<P>();
167 BitBuffer::collect_bool(array.len(), |idx| {
168 #[allow(clippy::cast_possible_truncation)]
169 values_mask.value(codes_slice[idx] as usize)
170 })
171 });
172 Mask::from_buffer(is_valid_buffer)
173 }
174 AllOr::None => Mask::AllFalse(array.len()),
175 AllOr::Some(validity_buff) => {
176 let primitive_codes = array.codes().to_primitive();
177 let values_mask = array.values().validity_mask();
178 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
179 let codes_slice = primitive_codes.as_slice::<P>();
180 #[allow(clippy::cast_possible_truncation)]
181 BitBuffer::collect_bool(array.len(), |idx| {
182 validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
183 })
184 });
185 Mask::from_buffer(is_valid_buffer)
186 }
187 }
188 }
189}
190
#[cfg(test)]
mod test {
    #[allow(unused_imports)]
    use itertools::Itertools;
    use rand::distr::{Distribution, StandardUniform};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};
    use vortex_buffer::{BitBuffer, buffer};
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, NativePType, PType, UnsignedPType};
    use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
    use vortex_mask::AllOr;

    use crate::arrays::dict::DictArray;
    use crate::arrays::{ChunkedArray, PrimitiveArray};
    use crate::builders::builder_with_capacity;
    use crate::validity::Validity;
    use crate::{Array, ArrayRef, IntoArray, ToCanonical, assert_arrays_eq};

    /// `u32` codes `[0, 1, 2, 2, 1]` with nulls at positions 1 and 3.
    fn nullable_codes() -> ArrayRef {
        PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
        )
        .into_array()
    }

    /// Asserts that `dict` has a partial validity mask whose valid positions are `expected`.
    fn assert_valid_indices(dict: &DictArray, expected: &[usize]) {
        let mask = dict.validity_mask();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, expected);
    }

    #[test]
    fn nullable_codes_validity() {
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array();
        let dict = DictArray::try_new(nullable_codes(), values).unwrap();
        assert_valid_indices(&dict, &[0, 2, 4]);
    }

    #[test]
    fn nullable_values_validity() {
        let codes = buffer![0u32, 1, 2, 2, 1].into_array();
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            Validity::from(BitBuffer::from(vec![true, false, false])),
        )
        .into_array();
        let dict = DictArray::try_new(codes, values).unwrap();
        assert_valid_indices(&dict, &[0]);
    }

    #[test]
    fn nullable_codes_and_values() {
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            Validity::from(BitBuffer::from(vec![false, true, true])),
        )
        .into_array();
        let dict = DictArray::try_new(nullable_codes(), values).unwrap();
        assert_valid_indices(&dict, &[2, 4]);
    }

    #[test]
    fn nullable_codes_and_non_null_values() {
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::NonNullable).into_array();
        let dict = DictArray::try_new(nullable_codes(), values).unwrap();
        assert_valid_indices(&dict, &[0, 2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dictionary chunks from a fixed-seed RNG.
    ///
    /// Each chunk has `len` random codes over `unique_values` random dictionary entries.
    fn make_dict_primitive_chunks<T: NativePType, Code: UnsignedPType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);
        let mut chunks = Vec::with_capacity(chunk_count);

        for _ in 0..chunk_count {
            // Draw the dictionary first, then the codes, so the RNG sequence per chunk is fixed.
            let values: PrimitiveArray = (0..unique_values).map(|_| rng.random::<T>()).collect();
            let codes: PrimitiveArray = (0..len)
                .map(|_| {
                    Code::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                })
                .collect();

            let dict = DictArray::try_new(codes.into_array(), values.into_array()).vortex_unwrap();
            chunks.push(dict.into_array());
        }

        chunks.into_iter().collect::<ChunkedArray>().into_array()
    }

    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let (len, chunk_count) = (2, 2);
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let dtype = DType::Primitive(PType::U64, NonNullable);
        let mut builder = builder_with_capacity(&dtype, len * chunk_count);
        array.clone().append_to_builder(builder.as_mut());

        let into_prim = array.to_primitive();
        let prim_into = builder.finish_into_canonical().into_primitive();
        assert_arrays_eq!(into_prim, prim_into);
    }
}