vortex_array/arrays/dict/
array.rs1use vortex_buffer::BitBuffer;
5use vortex_dtype::DType;
6use vortex_dtype::PType;
7use vortex_dtype::match_each_integer_ptype;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_ensure;
12use vortex_mask::AllOr;
13
14use crate::Array;
15use crate::ArrayRef;
16use crate::ToCanonical;
17use crate::stats::ArrayStats;
18
19#[derive(Clone, prost::Message)]
20pub struct DictMetadata {
21 #[prost(uint32, tag = "1")]
22 pub(super) values_len: u32,
23 #[prost(enumeration = "PType", tag = "2")]
24 pub(super) codes_ptype: i32,
25 #[prost(optional, bool, tag = "3")]
27 pub(super) is_nullable_codes: Option<bool>,
28 #[prost(optional, bool, tag = "4")]
32 pub(super) all_values_referenced: Option<bool>,
33}
34
35#[derive(Debug, Clone)]
36pub struct DictArray {
37 pub(super) codes: ArrayRef,
38 pub(super) values: ArrayRef,
39 pub(super) stats_set: ArrayStats,
40 pub(super) dtype: DType,
41 pub(super) all_values_referenced: bool,
47}
48
49impl DictArray {
50 pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
57 let dtype = values
58 .dtype()
59 .union_nullability(codes.dtype().nullability());
60 Self {
61 codes,
62 values,
63 stats_set: Default::default(),
64 dtype,
65 all_values_referenced: false,
66 }
67 }
68
69 pub unsafe fn set_all_values_referenced(mut self, all_values_referenced: bool) -> Self {
79 self.all_values_referenced = all_values_referenced;
80
81 #[cfg(debug_assertions)]
82 {
83 use vortex_error::VortexExpect;
84 self.validate_all_values_referenced()
85 .vortex_expect("validation should succeed when all values are referenced")
86 }
87
88 self
89 }
90
91 pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
96 Self::try_new(codes, values).vortex_expect("DictArray new")
97 }
98
99 pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
111 if !codes.dtype().is_unsigned_int() {
112 vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
113 }
114
115 Ok(unsafe { Self::new_unchecked(codes, values) })
116 }
117
118 pub fn into_parts(self) -> (ArrayRef, ArrayRef) {
119 (self.codes, self.values)
120 }
121
122 #[inline]
123 pub fn codes(&self) -> &ArrayRef {
124 &self.codes
125 }
126
127 #[inline]
128 pub fn values(&self) -> &ArrayRef {
129 &self.values
130 }
131
132 #[inline]
138 pub fn has_all_values_referenced(&self) -> bool {
139 self.all_values_referenced
140 }
141
142 pub fn validate_all_values_referenced(&self) -> VortexResult<()> {
149 if self.all_values_referenced {
150 let referenced_mask = self.compute_referenced_values_mask(true)?;
151 let all_referenced = referenced_mask.iter().all(|v| v);
152
153 vortex_ensure!(all_referenced, "value in dict not referenced");
154 }
155
156 Ok(())
157 }
158
159 pub fn compute_referenced_values_mask(&self, referenced: bool) -> VortexResult<BitBuffer> {
170 let codes_validity = self.codes().validity_mask();
171 let codes_primitive = self.codes().to_primitive();
172 let values_len = self.values().len();
173
174 let init_value = !referenced;
176 let referenced_value = referenced;
178
179 let mut values_vec = vec![init_value; values_len];
180 match codes_validity.bit_buffer() {
181 AllOr::All => {
182 match_each_integer_ptype!(codes_primitive.ptype(), |P| {
183 #[allow(clippy::cast_possible_truncation)]
184 for &code in codes_primitive.as_slice::<P>().iter() {
185 values_vec[code as usize] = referenced_value;
186 }
187 });
188 }
189 AllOr::None => {}
190 AllOr::Some(buf) => {
191 match_each_integer_ptype!(codes_primitive.ptype(), |P| {
192 let codes = codes_primitive.as_slice::<P>();
193
194 #[allow(clippy::cast_possible_truncation)]
195 buf.set_indices().for_each(|idx| {
196 values_vec[codes[idx] as usize] = referenced_value;
197 })
198 });
199 }
200 }
201
202 Ok(BitBuffer::collect_bool(values_len, |idx| values_vec[idx]))
203 }
204}
205
#[cfg(test)]
mod test {
    use rand::Rng;
    use rand::SeedableRng;
    use rand::distr::Distribution;
    use rand::distr::StandardUniform;
    use rand::prelude::StdRng;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::NativePType;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType;
    use vortex_dtype::UnsignedPType;
    use vortex_error::VortexExpect;
    use vortex_error::vortex_panic;
    use vortex_mask::AllOr;

    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::ChunkedArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictArray;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::validity::Validity;

    #[test]
    fn nullable_codes_validity() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0, 2, 4]);
    }

    #[test]
    fn nullable_values_validity() {
        let dict = DictArray::try_new(
            buffer![0u32, 1, 2, 2, 1].into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BitBuffer::from(vec![true, false, false])),
            )
            .into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0]);
    }

    #[test]
    fn nullable_codes_and_values() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BitBuffer::from(vec![false, true, true])),
            )
            .into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [2, 4]);
    }

    #[test]
    fn nullable_codes_and_non_null_values() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::NonNullable).into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0, 2, 4]);
    }

    #[test]
    fn referenced_values_mask_non_nullable_codes() {
        // Non-nullable codes take the all-valid path; value 1 is never referenced.
        let dict = DictArray::try_new(
            buffer![2u8, 0, 2].into_array(),
            buffer![1i64, 2, 3].into_array(),
        )
        .unwrap();

        let mask = dict.compute_referenced_values_mask(true).unwrap();
        assert_eq!(mask.iter().collect::<Vec<_>>(), vec![true, false, true]);
    }

    #[test]
    fn referenced_values_mask_ignores_null_codes() {
        // Code `1` is null, so value 1 must not be marked as referenced.
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 3],
                Validity::from(BitBuffer::from(vec![true, false, true])),
            )
            .into_array(),
            buffer![10i32, 20, 30, 40].into_array(),
        )
        .unwrap();

        let referenced = dict.compute_referenced_values_mask(true).unwrap();
        assert_eq!(
            referenced.iter().collect::<Vec<_>>(),
            vec![true, false, false, true]
        );

        let unreferenced = dict.compute_referenced_values_mask(false).unwrap();
        assert_eq!(
            unreferenced.iter().collect::<Vec<_>>(),
            vec![false, true, true, false]
        );
    }

    #[test]
    fn validate_all_values_referenced_flags_unreferenced_value() {
        let mut dict = DictArray::try_new(
            buffer![0u8, 0, 2].into_array(),
            buffer![1i64, 2, 3].into_array(),
        )
        .unwrap();
        assert!(!dict.has_all_values_referenced());
        // Without the flag there is nothing to validate.
        dict.validate_all_values_referenced().unwrap();

        // Set the field directly: `set_all_values_referenced` would panic in debug builds.
        dict.all_values_referenced = true;
        assert!(dict.validate_all_values_referenced().is_err());
    }

    #[test]
    fn set_all_values_referenced_accepts_fully_referenced_dict() {
        // SAFETY: the codes are unsigned and every value index in 0..3 appears among them.
        let dict = unsafe {
            DictArray::new_unchecked(
                buffer![2u8, 1, 0].into_array(),
                buffer![1i64, 2, 3].into_array(),
            )
            .set_all_values_referenced(true)
        };
        assert!(dict.has_all_values_referenced());
        dict.validate_all_values_referenced().unwrap();
    }

    /// Builds a chunked array of `chunk_count` dictionary chunks, each with `len` random codes
    /// drawn from `unique_values` random values (deterministic: fixed RNG seed).
    fn make_dict_primitive_chunks<T: NativePType, Code: UnsignedPType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);

        (0..chunk_count)
            .map(|_| {
                let values = (0..unique_values)
                    .map(|_| rng.random::<T>())
                    .collect::<PrimitiveArray>();
                let codes = (0..len)
                    .map(|_| {
                        Code::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                    })
                    .collect::<PrimitiveArray>();

                DictArray::try_new(codes.into_array(), values.into_array())
                    .vortex_expect("DictArray creation should succeed for unsigned in-range codes")
                    .into_array()
            })
            .collect::<ChunkedArray>()
            .into_array()
    }

    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let len = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let mut builder = builder_with_capacity(
            &DType::Primitive(PType::U64, NonNullable),
            len * chunk_count,
        );
        array.clone().append_to_builder(builder.as_mut());

        let into_prim = array.to_primitive();
        let prim_into = builder.finish_into_canonical().into_primitive();

        assert_arrays_eq!(into_prim, prim_into);
    }

    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        use super::DictMetadata;
        use crate::ProstMetadata;
        use crate::test_harness::check_metadata;

        check_metadata(
            "dict.metadata",
            ProstMetadata(DictMetadata {
                codes_ptype: PType::U64 as i32,
                values_len: u32::MAX,
                is_nullable_codes: None,
                all_values_referenced: None,
            }),
        );
    }
}