vortex_array/arrays/dict/
array.rs1use std::fmt::Display;
5use std::fmt::Formatter;
6
7use vortex_buffer::BitBuffer;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_ensure;
12use vortex_mask::AllOr;
13
14use crate::ArrayRef;
15use crate::LEGACY_SESSION;
16#[expect(deprecated)]
17use crate::ToCanonical as _;
18use crate::VortexSessionExecute;
19use crate::array::Array;
20use crate::array::ArrayParts;
21use crate::array::TypedArrayRef;
22use crate::array_slots;
23use crate::arrays::Dict;
24use crate::dtype::DType;
25use crate::dtype::PType;
26use crate::match_each_integer_ptype;
27
/// Metadata of a dictionary array, encoded as a protobuf message via `prost`.
///
/// The `tag` attributes identify each field in the encoded message.
#[derive(Clone, prost::Message)]
pub struct DictMetadata {
    /// Encoded as a `uint32` under tag 1.
    ///
    /// NOTE(review): presumably the length of the dictionary's `values` child — confirm against
    /// the (de)serialization code, which is not part of this file.
    #[prost(uint32, tag = "1")]
    pub(super) values_len: u32,
    /// A [`PType`] enumeration value, held as the raw `i32` that prost uses for enumerations.
    ///
    /// NOTE(review): presumably the primitive type of the `codes` child — confirm.
    #[prost(enumeration = "PType", tag = "2")]
    pub(super) codes_ptype: i32,
    /// Optional flag; `None` when the field is absent from the encoded message.
    ///
    /// NOTE(review): presumably records whether the codes are nullable — confirm how readers
    /// treat `None`.
    #[prost(optional, bool, tag = "3")]
    pub(super) is_nullable_codes: Option<bool>,
    /// Optional flag; `None` when the field is absent from the encoded message.
    ///
    /// NOTE(review): presumably mirrors `DictData::all_values_referenced` — confirm how readers
    /// treat `None`.
    #[prost(optional, bool, tag = "4")]
    pub(super) all_values_referenced: Option<bool>,
}
43
/// The child arrays of a [`Dict`] array, declared through the `array_slots` attribute macro.
///
/// The constructors in this file store the slots in the order `codes`, then `values`.
#[array_slots(Dict)]
pub struct DictSlots {
    /// Integer codes; each valid code is used as an index into `values`.
    pub codes: ArrayRef,
    /// The dictionary values that the codes index into.
    pub values: ArrayRef,
}
51
/// Per-array state of a [`Dict`] array that is not held in its child slots.
#[derive(Debug, Clone)]
pub struct DictData {
    /// Flag claiming that every dictionary value is referenced by at least one valid code.
    ///
    /// Starts out `false` and is only changed through the `unsafe`
    /// `set_all_values_referenced` setters.
    pub(super) all_values_referenced: bool,
}
61
62impl Display for DictData {
63 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
64 write!(f, "all_values_referenced: {}", self.all_values_referenced)
65 }
66}
67
68impl DictData {
69 pub unsafe fn new_unchecked() -> Self {
76 Self {
77 all_values_referenced: false,
78 }
79 }
80
81 pub unsafe fn set_all_values_referenced(mut self, all_values_referenced: bool) -> Self {
91 self.all_values_referenced = all_values_referenced;
92 self
93 }
94
95 pub fn new(codes_dtype: &DType) -> Self {
100 Self::try_new(codes_dtype).vortex_expect("DictArray new")
101 }
102
103 pub(crate) fn try_new(codes_dtype: &DType) -> VortexResult<Self> {
115 if !codes_dtype.is_int() {
116 vortex_bail!(MismatchedTypes: "int", codes_dtype);
117 }
118
119 Ok(unsafe { Self::new_unchecked() })
120 }
121}
122
123pub trait DictArrayExt: TypedArrayRef<Dict> + DictArraySlotsExt {
124 #[inline]
125 fn has_all_values_referenced(&self) -> bool {
126 self.all_values_referenced
127 }
128
129 fn validate_all_values_referenced(&self) -> VortexResult<()> {
130 if self.has_all_values_referenced() {
131 if !self.codes().is_host() {
132 return Ok(());
133 }
134
135 let referenced_mask = self.compute_referenced_values_mask(true)?;
136 let all_referenced = referenced_mask.iter().all(|v| v);
137
138 vortex_ensure!(all_referenced, "value in dict not referenced");
139 }
140
141 Ok(())
142 }
143
144 fn compute_referenced_values_mask(&self, referenced: bool) -> VortexResult<BitBuffer> {
145 let codes = self.codes();
146 let codes_validity = codes
147 .validity()?
148 .execute_mask(codes.len(), &mut LEGACY_SESSION.create_execution_ctx())?;
149 #[expect(deprecated)]
150 let codes_primitive = self.codes().to_primitive();
151 let values_len = self.values().len();
152
153 let init_value = !referenced;
154 let referenced_value = referenced;
155
156 let mut values_vec = vec![init_value; values_len];
157 match codes_validity.bit_buffer() {
158 AllOr::All => {
159 match_each_integer_ptype!(codes_primitive.ptype(), |P| {
160 #[allow(
161 clippy::cast_possible_truncation,
162 clippy::cast_sign_loss,
163 reason = "codes are non-negative indices; a negative signed code would wrap to a large usize and panic on the bounds-checked array index"
164 )]
165 for &idx in codes_primitive.as_slice::<P>() {
166 values_vec[idx as usize] = referenced_value;
167 }
168 });
169 }
170 AllOr::None => {}
171 AllOr::Some(mask) => {
172 match_each_integer_ptype!(codes_primitive.ptype(), |P| {
173 let codes = codes_primitive.as_slice::<P>();
174
175 #[allow(
176 clippy::cast_possible_truncation,
177 clippy::cast_sign_loss,
178 reason = "codes are non-negative indices; a negative signed code would wrap to a large usize and panic on the bounds-checked array index"
179 )]
180 mask.set_indices().for_each(|idx| {
181 values_vec[codes[idx] as usize] = referenced_value;
182 });
183 });
184 }
185 }
186
187 Ok(BitBuffer::from(values_vec))
188 }
189}
190impl<T: TypedArrayRef<Dict>> DictArrayExt for T {}
191
/// The owned components of a dictionary array, as produced by `DictOwnedExt::into_parts`.
pub struct DictParts {
    /// The dtype of the dictionary array the parts were taken from.
    pub dtype: DType,
    /// The codes child array.
    pub codes: ArrayRef,
    /// The values child array.
    pub values: ArrayRef,
}
198
/// Operations that consume an owned dictionary array.
pub trait DictOwnedExt {
    /// Consumes `self` and returns its dtype, codes and values as a [`DictParts`].
    fn into_parts(self) -> DictParts;
}
202
203impl DictOwnedExt for Array<Dict> {
204 fn into_parts(self) -> DictParts {
205 match self.try_into_parts() {
206 Ok(array_parts) => {
207 let slots = DictSlots::from_slots(array_parts.slots);
208 DictParts {
209 dtype: array_parts.dtype,
210 codes: slots.codes,
211 values: slots.values,
212 }
213 }
214 Err(array) => {
215 let slots = DictSlotsView::from_slots(array.slots());
216 DictParts {
217 dtype: array.dtype().clone(),
218 codes: slots.codes.clone(),
219 values: slots.values.clone(),
220 }
221 }
222 }
223 }
224}
225
226impl Array<Dict> {
227 pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
229 Self::try_new(codes, values).vortex_expect("DictArray new")
230 }
231
232 pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
234 let dtype = values
235 .dtype()
236 .union_nullability(codes.dtype().nullability());
237 let len = codes.len();
238 let data = DictData::try_new(codes.dtype())?;
239 Array::try_from_parts(
240 ArrayParts::new(Dict, dtype, len, data).with_slots(vec![Some(codes), Some(values)]),
241 )
242 }
243
244 pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
250 let dtype = values
251 .dtype()
252 .union_nullability(codes.dtype().nullability());
253 let len = codes.len();
254 let data = unsafe { DictData::new_unchecked() };
255 unsafe {
256 Array::from_parts_unchecked(
257 ArrayParts::new(Dict, dtype, len, data).with_slots(vec![Some(codes), Some(values)]),
258 )
259 }
260 }
261
262 pub unsafe fn set_all_values_referenced(self, all_values_referenced: bool) -> Self {
268 let dtype = self.dtype().clone();
269 let len = self.len();
270 let slots = self.slots().to_vec();
271 let data = unsafe {
272 self.into_data()
273 .set_all_values_referenced(all_values_referenced)
274 };
275 let array = unsafe {
276 Array::from_parts_unchecked(ArrayParts::new(Dict, dtype, len, data).with_slots(slots))
277 };
278
279 #[cfg(debug_assertions)]
280 if all_values_referenced {
281 array
282 .validate_all_values_referenced()
283 .vortex_expect("validation should succeed when all values are referenced");
284 }
285
286 array
287 }
288}
289
#[cfg(test)]
mod test {
    use rand::RngExt;
    use rand::SeedableRng;
    use rand::distr::Distribution;
    use rand::distr::StandardUniform;
    use rand::prelude::StdRng;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;
    use vortex_error::vortex_panic;
    use vortex_mask::AllOr;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::arrays::ChunkedArray;
    use crate::arrays::DictArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::dtype::DType;
    use crate::dtype::NativePType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType;
    use crate::dtype::UnsignedPType;
    use crate::validity::Validity;

    /// Executes the validity of `dict` into a mask and returns the indices it reports.
    ///
    /// Shared by the validity tests below so each test only states its input and expectation.
    /// Panics when the validity cannot be computed, or when the mask does not expose explicit
    /// indices (i.e. `indices()` is not `AllOr::Some`).
    fn valid_indices(dict: &DictArray) -> Vec<usize> {
        let array = dict.as_ref();
        let mask = array
            .validity()
            .unwrap()
            .execute_mask(array.len(), &mut LEGACY_SESSION.create_execution_ctx())
            .unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        indices.to_vec()
    }

    /// Null codes make the corresponding dictionary elements null.
    #[test]
    fn nullable_codes_validity() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array(),
        )
        .unwrap();

        assert_eq!(valid_indices(&dict), [0, 2, 4]);
    }

    /// Codes that point at a null value make the corresponding elements null.
    #[test]
    fn nullable_values_validity() {
        let dict = DictArray::try_new(
            buffer![0u32, 1, 2, 2, 1].into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BitBuffer::from(vec![true, false, false])),
            )
            .into_array(),
        )
        .unwrap();

        assert_eq!(valid_indices(&dict), [0]);
    }

    /// An element is valid only when both its code and the value it points at are valid.
    #[test]
    fn nullable_codes_and_values() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BitBuffer::from(vec![false, true, true])),
            )
            .into_array(),
        )
        .unwrap();

        assert_eq!(valid_indices(&dict), [2, 4]);
    }

    /// Nullable codes over non-nullable values: validity follows the codes alone.
    #[test]
    fn nullable_codes_and_non_null_values() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::NonNullable).into_array(),
        )
        .unwrap();

        assert_eq!(valid_indices(&dict), [0, 2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dictionary arrays.
    ///
    /// Each chunk has `unique_values` random values of type `T` and `len` random codes of type
    /// `Code` drawn from `0..unique_values`. The RNG is seeded with a constant, so the result is
    /// deterministic.
    fn make_dict_primitive_chunks<T: NativePType, Code: UnsignedPType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);

        (0..chunk_count)
            .map(|_| {
                let values = (0..unique_values)
                    .map(|_| rng.random::<T>())
                    .collect::<PrimitiveArray>();
                let codes = (0..len)
                    .map(|_| {
                        Code::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                    })
                    .collect::<PrimitiveArray>();

                DictArray::try_new(codes.into_array(), values.into_array())
                    .vortex_expect("DictArray creation should succeed in arbitrary impl")
                    .into_array()
            })
            .collect::<ChunkedArray>()
            .into_array()
    }

    /// Appending chunked dictionary arrays to a builder matches canonicalizing them directly.
    #[test]
    fn test_dict_array_from_primitive_chunks() -> VortexResult<()> {
        let len = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let mut builder = builder_with_capacity(
            &DType::Primitive(PType::U64, NonNullable),
            len * chunk_count,
        );
        array.append_to_builder(builder.as_mut(), &mut LEGACY_SESSION.create_execution_ctx())?;

        #[expect(deprecated)]
        let into_prim = array.to_primitive();
        let prim_into = builder.finish_into_canonical().into_primitive();

        assert_arrays_eq!(into_prim, prim_into);
        Ok(())
    }

    /// Checks the encoded metadata bytes against the `dict.metadata` fixture.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        use prost::Message;

        use super::DictMetadata;
        use crate::test_harness::check_metadata;

        check_metadata(
            "dict.metadata",
            &DictMetadata {
                codes_ptype: PType::U64 as i32,
                values_len: u32::MAX,
                is_nullable_codes: None,
                all_values_referenced: None,
            }
            .encode_to_vec(),
        );
    }
}