vortex_array/arrays/dict/
array.rs1use std::fmt::Display;
5use std::fmt::Formatter;
6
7use num_traits::AsPrimitive;
8use smallvec::smallvec;
9use vortex_buffer::BitBuffer;
10use vortex_error::VortexExpect;
11use vortex_error::VortexResult;
12use vortex_error::vortex_bail;
13use vortex_error::vortex_ensure;
14use vortex_mask::AllOr;
15
16use crate::ArrayRef;
17use crate::ArraySlots;
18use crate::LEGACY_SESSION;
19#[expect(deprecated)]
20use crate::ToCanonical as _;
21use crate::VortexSessionExecute;
22use crate::array::Array;
23use crate::array::ArrayParts;
24use crate::array::TypedArrayRef;
25use crate::array_slots;
26use crate::arrays::Dict;
27use crate::dtype::DType;
28use crate::dtype::PType;
29use crate::match_each_integer_ptype;
30
/// Protobuf-encoded metadata for a dictionary-encoded array.
///
/// NOTE(review): the code that writes/reads this message is not in this file; the field
/// descriptions below follow the field names and should be confirmed against the `Dict` vtable.
#[derive(Clone, prost::Message)]
pub struct DictMetadata {
    /// Length of the values (dictionary) child — presumably; confirm against the serializer.
    #[prost(uint32, tag = "1")]
    pub(super) values_len: u32,
    /// Primitive type of the codes child, stored as the protobuf value of the `PType` enum.
    #[prost(enumeration = "PType", tag = "2")]
    pub(super) codes_ptype: i32,
    /// Whether the codes child is nullable; `None` when the field is absent from the message.
    #[prost(optional, bool, tag = "3")]
    pub(super) is_nullable_codes: Option<bool>,
    /// Persisted form of the `all_values_referenced` flag (presumably mirrors
    /// `DictData::all_values_referenced`); `None` when the field is absent from the message.
    #[prost(optional, bool, tag = "4")]
    pub(super) all_values_referenced: Option<bool>,
}
46
/// Child slots of a dictionary array.
///
/// `Array::<Dict>::try_new` assembles the slot list as `[codes, values]`; the field order here
/// presumably has to match that layout — confirm against the `array_slots` macro.
#[array_slots(Dict)]
pub struct DictSlots {
    /// Integer codes; every valid (non-null) code is an index into `values`.
    pub codes: ArrayRef,
    /// The dictionary values that the codes index into.
    pub values: ArrayRef,
}
54
55#[derive(Debug, Clone)]
56pub struct DictData {
57 pub(super) all_values_referenced: bool,
63}
64
65impl Display for DictData {
66 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
67 write!(f, "all_values_referenced: {}", self.all_values_referenced)
68 }
69}
70
71impl DictData {
72 pub unsafe fn new_unchecked() -> Self {
79 Self {
80 all_values_referenced: false,
81 }
82 }
83
84 pub unsafe fn set_all_values_referenced(mut self, all_values_referenced: bool) -> Self {
94 self.all_values_referenced = all_values_referenced;
95 self
96 }
97
98 pub fn new(codes_dtype: &DType) -> Self {
103 Self::try_new(codes_dtype).vortex_expect("DictArray new")
104 }
105
106 pub(crate) fn try_new(codes_dtype: &DType) -> VortexResult<Self> {
118 if !codes_dtype.is_int() {
119 vortex_bail!(MismatchedTypes: "int", codes_dtype);
120 }
121
122 Ok(unsafe { Self::new_unchecked() })
123 }
124}
125
126pub trait DictArrayExt: TypedArrayRef<Dict> + DictArraySlotsExt {
127 #[inline]
128 fn has_all_values_referenced(&self) -> bool {
129 self.all_values_referenced
130 }
131
132 fn validate_all_values_referenced(&self) -> VortexResult<()> {
133 if self.has_all_values_referenced() {
134 if !self.codes().is_host() {
135 return Ok(());
136 }
137
138 let referenced_mask = self.compute_referenced_values_mask(true)?;
139 let all_referenced = referenced_mask.true_count() == referenced_mask.len();
140
141 vortex_ensure!(all_referenced, "value in dict not referenced");
142 }
143
144 Ok(())
145 }
146
147 fn compute_referenced_values_mask(&self, referenced: bool) -> VortexResult<BitBuffer> {
148 let codes = self.codes();
149 let codes_validity = codes
150 .validity()?
151 .execute_mask(codes.len(), &mut LEGACY_SESSION.create_execution_ctx())?;
152 #[expect(deprecated)]
153 let codes_primitive = self.codes().to_primitive();
154 let values_len = self.values().len();
155
156 let init_value = !referenced;
157 let referenced_value = referenced;
158
159 let mut values_vec = vec![init_value; values_len];
160 match codes_validity.bit_buffer() {
161 AllOr::All => {
162 match_each_integer_ptype!(codes_primitive.ptype(), |P| {
163 for idx in codes_primitive.as_slice::<P>() {
164 let idxu: usize = idx.as_();
165 values_vec[idxu] = referenced_value;
166 }
167 });
168 }
169 AllOr::None => {}
170 AllOr::Some(mask) => {
171 match_each_integer_ptype!(codes_primitive.ptype(), |P| {
172 let codes = codes_primitive.as_slice::<P>();
173 mask.set_indices().for_each(|idx| {
174 let idxu: usize = codes[idx].as_();
175 values_vec[idxu] = referenced_value;
176 });
177 });
178 }
179 }
180
181 Ok(BitBuffer::from(values_vec))
182 }
183}
184impl<T: TypedArrayRef<Dict>> DictArrayExt for T {}
185
/// The owned components of a dictionary array, as returned by
/// [`DictOwnedExt::into_parts`].
pub struct DictParts {
    /// The dictionary array's dtype.
    pub dtype: DType,
    /// Integer codes indexing into `values`.
    pub codes: ArrayRef,
    /// The dictionary values.
    pub values: ArrayRef,
}

/// Decomposition of an owned dictionary array into its components.
pub trait DictOwnedExt {
    /// Consumes the array and returns its dtype, codes and values.
    fn into_parts(self) -> DictParts;
}
196
197impl DictOwnedExt for Array<Dict> {
198 fn into_parts(self) -> DictParts {
199 match self.try_into_parts() {
200 Ok(array_parts) => {
201 let slots = DictSlots::from_slots(array_parts.slots);
202 DictParts {
203 dtype: array_parts.dtype,
204 codes: slots.codes,
205 values: slots.values,
206 }
207 }
208 Err(array) => {
209 let slots = DictSlotsView::from_slots(array.slots());
210 DictParts {
211 dtype: array.dtype().clone(),
212 codes: slots.codes.clone(),
213 values: slots.values.clone(),
214 }
215 }
216 }
217 }
218}
219
220impl Array<Dict> {
221 pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
223 Self::try_new(codes, values).vortex_expect("DictArray new")
224 }
225
226 pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
228 let dtype = values
229 .dtype()
230 .union_nullability(codes.dtype().nullability());
231 let len = codes.len();
232 let data = DictData::try_new(codes.dtype())?;
233 Array::try_from_parts(
234 ArrayParts::new(Dict, dtype, len, data)
235 .with_slots(smallvec![Some(codes), Some(values)]),
236 )
237 }
238
239 pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
245 let dtype = values
246 .dtype()
247 .union_nullability(codes.dtype().nullability());
248 let len = codes.len();
249 let data = unsafe { DictData::new_unchecked() };
250 unsafe {
251 Array::from_parts_unchecked(
252 ArrayParts::new(Dict, dtype, len, data)
253 .with_slots(smallvec![Some(codes), Some(values)]),
254 )
255 }
256 }
257
258 pub unsafe fn set_all_values_referenced(self, all_values_referenced: bool) -> Self {
264 let dtype = self.dtype().clone();
265 let len = self.len();
266 let slots: ArraySlots = self.slots().iter().cloned().collect();
267 let data = unsafe {
268 self.into_data()
269 .set_all_values_referenced(all_values_referenced)
270 };
271 let array = unsafe {
272 Array::from_parts_unchecked(ArrayParts::new(Dict, dtype, len, data).with_slots(slots))
273 };
274
275 #[cfg(debug_assertions)]
276 if all_values_referenced {
277 array
278 .validate_all_values_referenced()
279 .vortex_expect("validation should succeed when all values are referenced");
280 }
281
282 array
283 }
284}
285
#[cfg(test)]
mod test {
    use rand::RngExt;
    use rand::SeedableRng;
    use rand::distr::Distribution;
    use rand::distr::StandardUniform;
    use rand::prelude::StdRng;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;
    use vortex_error::vortex_panic;
    use vortex_mask::AllOr;

    use crate::ArrayRef;
    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::ChunkedArray;
    use crate::arrays::DictArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::dtype::DType;
    use crate::dtype::NativePType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType;
    use crate::dtype::UnsignedPType;
    use crate::validity::Validity;

    /// Executes the dict array's validity and returns the indices of its valid rows.
    ///
    /// Panics unless the mask is partial (neither all-valid nor all-invalid), which every
    /// caller in this module expects.
    fn valid_indices(dict: &DictArray) -> Vec<usize> {
        let mask = dict
            .as_ref()
            .validity()
            .unwrap()
            .execute_mask(
                dict.as_ref().len(),
                &mut array_session().create_execution_ctx(),
            )
            .unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        indices.to_vec()
    }

    #[test]
    fn nullable_codes_validity() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array(),
        )
        .unwrap();
        assert_eq!(valid_indices(&dict), [0, 2, 4]);
    }

    #[test]
    fn nullable_values_validity() {
        let dict = DictArray::try_new(
            buffer![0u32, 1, 2, 2, 1].into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BitBuffer::from(vec![true, false, false])),
            )
            .into_array(),
        )
        .unwrap();
        assert_eq!(valid_indices(&dict), [0]);
    }

    #[test]
    fn nullable_codes_and_values() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BitBuffer::from(vec![false, true, true])),
            )
            .into_array(),
        )
        .unwrap();
        assert_eq!(valid_indices(&dict), [2, 4]);
    }

    #[test]
    fn nullable_codes_and_non_null_values() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::NonNullable).into_array(),
        )
        .unwrap();
        assert_eq!(valid_indices(&dict), [0, 2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dict chunks, each with `len` random codes over
    /// `unique_values` random values (deterministic: seeded RNG).
    fn make_dict_primitive_chunks<T: NativePType, Code: UnsignedPType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);

        (0..chunk_count)
            .map(|_| {
                let values = (0..unique_values)
                    .map(|_| rng.random::<T>())
                    .collect::<PrimitiveArray>();
                let codes = (0..len)
                    .map(|_| {
                        Code::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                    })
                    .collect::<PrimitiveArray>();

                DictArray::try_new(codes.into_array(), values.into_array())
                    .vortex_expect("DictArray creation should succeed in arbitrary impl")
                    .into_array()
            })
            .collect::<ChunkedArray>()
            .into_array()
    }

    #[test]
    fn test_dict_array_from_primitive_chunks() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let len = 2;
        let unique_values = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, unique_values, chunk_count);

        let mut builder = builder_with_capacity(
            &DType::Primitive(PType::U64, NonNullable),
            len * chunk_count,
        );
        // Reuse the test's execution context instead of spinning up a second one.
        array.append_to_builder(builder.as_mut(), &mut ctx)?;

        #[expect(deprecated)]
        let into_prim = array.to_primitive();
        let prim_into = builder.finish_into_canonical().into_primitive();

        assert_arrays_eq!(into_prim, prim_into, &mut ctx);
        Ok(())
    }

    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        use prost::Message;

        use super::DictMetadata;
        use crate::test_harness::check_metadata;

        check_metadata(
            "dict.metadata",
            &DictMetadata {
                codes_ptype: PType::U64 as i32,
                values_len: u32::MAX,
                is_nullable_codes: None,
                all_values_referenced: None,
            }
            .encode_to_vec(),
        );
    }
}