vortex_array/arrays/dict/
array.rs1use std::fmt::Display;
5use std::fmt::Formatter;
6
7use smallvec::smallvec;
8use vortex_buffer::BitBuffer;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_ensure;
13use vortex_mask::AllOr;
14
15use crate::ArrayRef;
16use crate::ArraySlots;
17use crate::LEGACY_SESSION;
18#[expect(deprecated)]
19use crate::ToCanonical as _;
20use crate::VortexSessionExecute;
21use crate::array::Array;
22use crate::array::ArrayParts;
23use crate::array::TypedArrayRef;
24use crate::array_slots;
25use crate::arrays::Dict;
26use crate::dtype::DType;
27use crate::dtype::PType;
28use crate::match_each_integer_ptype;
29
/// Serialized (protobuf) metadata for a dictionary-encoded array.
#[derive(Clone, prost::Message)]
pub struct DictMetadata {
    /// Length of the values child (presumably; inferred from the field name — confirm against
    /// the serializer).
    #[prost(uint32, tag = "1")]
    pub(super) values_len: u32,
    /// Primitive type of the codes child, stored as the raw `PType` enumeration value.
    #[prost(enumeration = "PType", tag = "2")]
    pub(super) codes_ptype: i32,
    /// Optional flag for whether the codes child is nullable; `None` when not written.
    /// NOTE(review): how a missing value is interpreted is decided by the deserializer, not here.
    #[prost(optional, bool, tag = "3")]
    pub(super) is_nullable_codes: Option<bool>,
    /// Optional persisted form of the `all_values_referenced` flag; `None` when not written.
    #[prost(optional, bool, tag = "4")]
    pub(super) all_values_referenced: Option<bool>,
}
45
/// Child slots of a dictionary array.
///
/// NOTE(review): the `#[array_slots]` macro presumably maps field order to slot index; the
/// constructors in this file store slots as `[codes, values]`, matching the order below.
#[array_slots(Dict)]
pub struct DictSlots {
    /// Integer codes, one per logical element; each valid code indexes into `values`.
    pub codes: ArrayRef,
    /// The dictionary values that `codes` index into.
    pub values: ArrayRef,
}
53
/// Encoding-specific data carried by a dictionary array alongside its slots.
#[derive(Debug, Clone)]
pub struct DictData {
    /// Claims that every entry of the values child is referenced by at least one code.
    ///
    /// Starts out `false` (see `DictData::new_unchecked`). It is only verified in debug builds,
    /// and only when set through `Array::<Dict>::set_all_values_referenced`.
    pub(super) all_values_referenced: bool,
}
63
64impl Display for DictData {
65 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
66 write!(f, "all_values_referenced: {}", self.all_values_referenced)
67 }
68}
69
70impl DictData {
71 pub unsafe fn new_unchecked() -> Self {
78 Self {
79 all_values_referenced: false,
80 }
81 }
82
83 pub unsafe fn set_all_values_referenced(mut self, all_values_referenced: bool) -> Self {
93 self.all_values_referenced = all_values_referenced;
94 self
95 }
96
97 pub fn new(codes_dtype: &DType) -> Self {
102 Self::try_new(codes_dtype).vortex_expect("DictArray new")
103 }
104
105 pub(crate) fn try_new(codes_dtype: &DType) -> VortexResult<Self> {
117 if !codes_dtype.is_int() {
118 vortex_bail!(MismatchedTypes: "int", codes_dtype);
119 }
120
121 Ok(unsafe { Self::new_unchecked() })
122 }
123}
124
125pub trait DictArrayExt: TypedArrayRef<Dict> + DictArraySlotsExt {
126 #[inline]
127 fn has_all_values_referenced(&self) -> bool {
128 self.all_values_referenced
129 }
130
131 fn validate_all_values_referenced(&self) -> VortexResult<()> {
132 if self.has_all_values_referenced() {
133 if !self.codes().is_host() {
134 return Ok(());
135 }
136
137 let referenced_mask = self.compute_referenced_values_mask(true)?;
138 let all_referenced = referenced_mask.iter().all(|v| v);
139
140 vortex_ensure!(all_referenced, "value in dict not referenced");
141 }
142
143 Ok(())
144 }
145
146 fn compute_referenced_values_mask(&self, referenced: bool) -> VortexResult<BitBuffer> {
147 let codes = self.codes();
148 let codes_validity = codes
149 .validity()?
150 .execute_mask(codes.len(), &mut LEGACY_SESSION.create_execution_ctx())?;
151 #[expect(deprecated)]
152 let codes_primitive = self.codes().to_primitive();
153 let values_len = self.values().len();
154
155 let init_value = !referenced;
156 let referenced_value = referenced;
157
158 let mut values_vec = vec![init_value; values_len];
159 match codes_validity.bit_buffer() {
160 AllOr::All => {
161 match_each_integer_ptype!(codes_primitive.ptype(), |P| {
162 #[allow(
163 clippy::cast_possible_truncation,
164 clippy::cast_sign_loss,
165 reason = "codes are non-negative indices; a negative signed code would wrap to a large usize and panic on the bounds-checked array index"
166 )]
167 for &idx in codes_primitive.as_slice::<P>() {
168 values_vec[idx as usize] = referenced_value;
169 }
170 });
171 }
172 AllOr::None => {}
173 AllOr::Some(mask) => {
174 match_each_integer_ptype!(codes_primitive.ptype(), |P| {
175 let codes = codes_primitive.as_slice::<P>();
176
177 #[allow(
178 clippy::cast_possible_truncation,
179 clippy::cast_sign_loss,
180 reason = "codes are non-negative indices; a negative signed code would wrap to a large usize and panic on the bounds-checked array index"
181 )]
182 mask.set_indices().for_each(|idx| {
183 values_vec[codes[idx] as usize] = referenced_value;
184 });
185 });
186 }
187 }
188
189 Ok(BitBuffer::from(values_vec))
190 }
191}
192impl<T: TypedArrayRef<Dict>> DictArrayExt for T {}
193
/// The owned components of a dictionary array, as produced by [`DictOwnedExt::into_parts`].
pub struct DictParts {
    /// The logical dtype of the dictionary array.
    pub dtype: DType,
    /// The codes child.
    pub codes: ArrayRef,
    /// The values child.
    pub values: ArrayRef,
}
200
/// Decomposition of an owned dictionary array into its [`DictParts`].
pub trait DictOwnedExt {
    /// Consumes the array and returns its dtype, codes child and values child.
    fn into_parts(self) -> DictParts;
}
204
205impl DictOwnedExt for Array<Dict> {
206 fn into_parts(self) -> DictParts {
207 match self.try_into_parts() {
208 Ok(array_parts) => {
209 let slots = DictSlots::from_slots(array_parts.slots);
210 DictParts {
211 dtype: array_parts.dtype,
212 codes: slots.codes,
213 values: slots.values,
214 }
215 }
216 Err(array) => {
217 let slots = DictSlotsView::from_slots(array.slots());
218 DictParts {
219 dtype: array.dtype().clone(),
220 codes: slots.codes.clone(),
221 values: slots.values.clone(),
222 }
223 }
224 }
225 }
226}
227
228impl Array<Dict> {
229 pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
231 Self::try_new(codes, values).vortex_expect("DictArray new")
232 }
233
234 pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
236 let dtype = values
237 .dtype()
238 .union_nullability(codes.dtype().nullability());
239 let len = codes.len();
240 let data = DictData::try_new(codes.dtype())?;
241 Array::try_from_parts(
242 ArrayParts::new(Dict, dtype, len, data)
243 .with_slots(smallvec![Some(codes), Some(values)]),
244 )
245 }
246
247 pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
253 let dtype = values
254 .dtype()
255 .union_nullability(codes.dtype().nullability());
256 let len = codes.len();
257 let data = unsafe { DictData::new_unchecked() };
258 unsafe {
259 Array::from_parts_unchecked(
260 ArrayParts::new(Dict, dtype, len, data)
261 .with_slots(smallvec![Some(codes), Some(values)]),
262 )
263 }
264 }
265
266 pub unsafe fn set_all_values_referenced(self, all_values_referenced: bool) -> Self {
272 let dtype = self.dtype().clone();
273 let len = self.len();
274 let slots: ArraySlots = self.slots().iter().cloned().collect();
275 let data = unsafe {
276 self.into_data()
277 .set_all_values_referenced(all_values_referenced)
278 };
279 let array = unsafe {
280 Array::from_parts_unchecked(ArrayParts::new(Dict, dtype, len, data).with_slots(slots))
281 };
282
283 #[cfg(debug_assertions)]
284 if all_values_referenced {
285 array
286 .validate_all_values_referenced()
287 .vortex_expect("validation should succeed when all values are referenced");
288 }
289
290 array
291 }
292}
293
#[cfg(test)]
mod test {
    use rand::RngExt;
    use rand::SeedableRng;
    use rand::distr::Distribution;
    use rand::distr::StandardUniform;
    use rand::prelude::StdRng;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;
    use vortex_error::vortex_panic;
    use vortex_mask::AllOr;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::arrays::ChunkedArray;
    use crate::arrays::DictArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::dtype::DType;
    use crate::dtype::NativePType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType;
    use crate::dtype::UnsignedPType;
    use crate::validity::Validity;

    // Executes the validity of a dict array into a mask and asserts that exactly the
    // `$expected` positions are valid. It panics if the mask is all-valid or all-invalid.
    macro_rules! assert_valid_indices {
        ($dict:expr, $expected:expr) => {{
            let array = $dict;
            let mut ctx = LEGACY_SESSION.create_execution_ctx();
            let validity_mask = array
                .as_ref()
                .validity()
                .unwrap()
                .execute_mask(array.as_ref().len(), &mut ctx)
                .unwrap();
            let AllOr::Some(valid_positions) = validity_mask.indices() else {
                vortex_panic!("Expected indices from mask")
            };
            assert_eq!(valid_positions, $expected);
        }};
    }

    #[test]
    fn nullable_codes_validity() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
        );
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid);
        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();

        assert_valid_indices!(dict, [0, 2, 4]);
    }

    #[test]
    fn nullable_values_validity() {
        let codes = buffer![0u32, 1, 2, 2, 1].into_array();
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            Validity::from(BitBuffer::from(vec![true, false, false])),
        );
        let dict = DictArray::try_new(codes, values.into_array()).unwrap();

        assert_valid_indices!(dict, [0]);
    }

    #[test]
    fn nullable_codes_and_values() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
        );
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            Validity::from(BitBuffer::from(vec![false, true, true])),
        );
        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();

        assert_valid_indices!(dict, [2, 4]);
    }

    #[test]
    fn nullable_codes_and_non_null_values() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
        );
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::NonNullable);
        let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();

        assert_valid_indices!(dict, [0, 2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dict chunks from a fixed-seed RNG.
    ///
    /// Each chunk has `unique_values` random values and `len` random codes into them.
    fn make_dict_primitive_chunks<T: NativePType, Code: UnsignedPType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);
        let mut chunks = Vec::with_capacity(chunk_count);

        for _ in 0..chunk_count {
            // Values are drawn before codes within each chunk to keep the RNG sequence stable.
            let values: PrimitiveArray = (0..unique_values).map(|_| rng.random::<T>()).collect();
            let codes: PrimitiveArray = (0..len)
                .map(|_| {
                    Code::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                })
                .collect();

            let dict = DictArray::try_new(codes.into_array(), values.into_array())
                .vortex_expect("DictArray creation should succeed in arbitrary impl");
            chunks.push(dict.into_array());
        }

        chunks.into_iter().collect::<ChunkedArray>().into_array()
    }

    #[test]
    fn test_dict_array_from_primitive_chunks() -> VortexResult<()> {
        let (len, chunk_count) = (2, 2);
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let dtype = DType::Primitive(PType::U64, NonNullable);
        let mut builder = builder_with_capacity(&dtype, len * chunk_count);
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        array.append_to_builder(builder.as_mut(), &mut ctx)?;

        #[expect(deprecated)]
        let canonicalized = array.to_primitive();
        let built = builder.finish_into_canonical().into_primitive();

        assert_arrays_eq!(canonicalized, built);
        Ok(())
    }

    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        use prost::Message;

        use super::DictMetadata;
        use crate::test_harness::check_metadata;

        let metadata = DictMetadata {
            codes_ptype: PType::U64 as i32,
            values_len: u32::MAX,
            is_nullable_codes: None,
            all_values_referenced: None,
        };
        check_metadata("dict.metadata", &metadata.encode_to_vec());
    }
}