vortex_array/arrays/dict/
array.rs1use std::fmt::Display;
5use std::fmt::Formatter;
6
7use vortex_buffer::BitBuffer;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_ensure;
12use vortex_mask::AllOr;
13
14use crate::ArrayRef;
15use crate::LEGACY_SESSION;
16use crate::ToCanonical;
17use crate::VortexSessionExecute;
18use crate::array::Array;
19use crate::array::ArrayParts;
20use crate::array::TypedArrayRef;
21use crate::array_slots;
22use crate::arrays::Dict;
23use crate::dtype::DType;
24use crate::dtype::PType;
25use crate::match_each_integer_ptype;
26
27#[derive(Clone, prost::Message)]
28pub struct DictMetadata {
29 #[prost(uint32, tag = "1")]
30 pub(super) values_len: u32,
31 #[prost(enumeration = "PType", tag = "2")]
32 pub(super) codes_ptype: i32,
33 #[prost(optional, bool, tag = "3")]
35 pub(super) is_nullable_codes: Option<bool>,
36 #[prost(optional, bool, tag = "4")]
40 pub(super) all_values_referenced: Option<bool>,
41}
42
43#[array_slots(Dict)]
44pub struct DictSlots {
45 pub codes: ArrayRef,
47 pub values: ArrayRef,
49}
50
51#[derive(Debug, Clone)]
52pub struct DictData {
53 pub(super) all_values_referenced: bool,
59}
60
61impl Display for DictData {
62 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
63 write!(f, "all_values_referenced: {}", self.all_values_referenced)
64 }
65}
66
67impl DictData {
68 pub unsafe fn new_unchecked() -> Self {
75 Self {
76 all_values_referenced: false,
77 }
78 }
79
80 pub unsafe fn set_all_values_referenced(mut self, all_values_referenced: bool) -> Self {
90 self.all_values_referenced = all_values_referenced;
91 self
92 }
93
94 pub fn new(codes_dtype: &DType) -> Self {
99 Self::try_new(codes_dtype).vortex_expect("DictArray new")
100 }
101
102 pub(crate) fn try_new(codes_dtype: &DType) -> VortexResult<Self> {
114 if !codes_dtype.is_int() {
115 vortex_bail!(MismatchedTypes: "int", codes_dtype);
116 }
117
118 Ok(unsafe { Self::new_unchecked() })
119 }
120}
121
122pub trait DictArrayExt: TypedArrayRef<Dict> + DictArraySlotsExt {
123 #[inline]
124 fn has_all_values_referenced(&self) -> bool {
125 self.all_values_referenced
126 }
127
128 fn validate_all_values_referenced(&self) -> VortexResult<()> {
129 if self.has_all_values_referenced() {
130 if !self.codes().is_host() {
131 return Ok(());
132 }
133
134 let referenced_mask = self.compute_referenced_values_mask(true)?;
135 let all_referenced = referenced_mask.iter().all(|v| v);
136
137 vortex_ensure!(all_referenced, "value in dict not referenced");
138 }
139
140 Ok(())
141 }
142
143 fn compute_referenced_values_mask(&self, referenced: bool) -> VortexResult<BitBuffer> {
144 let codes = self.codes();
145 let codes_validity = codes
146 .validity()?
147 .to_mask(codes.len(), &mut LEGACY_SESSION.create_execution_ctx())?;
148 let codes_primitive = self.codes().to_primitive();
149 let values_len = self.values().len();
150
151 let init_value = !referenced;
152 let referenced_value = referenced;
153
154 let mut values_vec = vec![init_value; values_len];
155 match codes_validity.bit_buffer() {
156 AllOr::All => {
157 match_each_integer_ptype!(codes_primitive.ptype(), |P| {
158 #[allow(
159 clippy::cast_possible_truncation,
160 clippy::cast_sign_loss,
161 reason = "codes are non-negative indices; a negative signed code would wrap to a large usize and panic on the bounds-checked array index"
162 )]
163 for &idx in codes_primitive.as_slice::<P>() {
164 values_vec[idx as usize] = referenced_value;
165 }
166 });
167 }
168 AllOr::None => {}
169 AllOr::Some(mask) => {
170 match_each_integer_ptype!(codes_primitive.ptype(), |P| {
171 let codes = codes_primitive.as_slice::<P>();
172
173 #[allow(
174 clippy::cast_possible_truncation,
175 clippy::cast_sign_loss,
176 reason = "codes are non-negative indices; a negative signed code would wrap to a large usize and panic on the bounds-checked array index"
177 )]
178 mask.set_indices().for_each(|idx| {
179 values_vec[codes[idx] as usize] = referenced_value;
180 });
181 });
182 }
183 }
184
185 Ok(BitBuffer::from(values_vec))
186 }
187}
188impl<T: TypedArrayRef<Dict>> DictArrayExt for T {}
189
190impl Array<Dict> {
191 pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
193 Self::try_new(codes, values).vortex_expect("DictArray new")
194 }
195
196 pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
198 let dtype = values
199 .dtype()
200 .union_nullability(codes.dtype().nullability());
201 let len = codes.len();
202 let data = DictData::try_new(codes.dtype())?;
203 Array::try_from_parts(
204 ArrayParts::new(Dict, dtype, len, data).with_slots(vec![Some(codes), Some(values)]),
205 )
206 }
207
208 pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
214 let dtype = values
215 .dtype()
216 .union_nullability(codes.dtype().nullability());
217 let len = codes.len();
218 let data = unsafe { DictData::new_unchecked() };
219 unsafe {
220 Array::from_parts_unchecked(
221 ArrayParts::new(Dict, dtype, len, data).with_slots(vec![Some(codes), Some(values)]),
222 )
223 }
224 }
225
226 pub unsafe fn set_all_values_referenced(self, all_values_referenced: bool) -> Self {
232 let dtype = self.dtype().clone();
233 let len = self.len();
234 let slots = self.slots().to_vec();
235 let data = unsafe {
236 self.into_data()
237 .set_all_values_referenced(all_values_referenced)
238 };
239 let array = unsafe {
240 Array::from_parts_unchecked(ArrayParts::new(Dict, dtype, len, data).with_slots(slots))
241 };
242
243 #[cfg(debug_assertions)]
244 if all_values_referenced {
245 array
246 .validate_all_values_referenced()
247 .vortex_expect("validation should succeed when all values are referenced");
248 }
249
250 array
251 }
252}
253
#[cfg(test)]
mod test {
    #[expect(unused_imports)]
    use itertools::Itertools;
    use rand::RngExt;
    use rand::SeedableRng;
    use rand::distr::Distribution;
    use rand::distr::StandardUniform;
    use rand::prelude::StdRng;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;
    use vortex_error::vortex_panic;
    use vortex_mask::AllOr;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::ChunkedArray;
    use crate::arrays::DictArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::dtype::DType;
    use crate::dtype::NativePType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType;
    use crate::dtype::UnsignedPType;
    use crate::validity::Validity;

    /// Codes `[0, 1, 2, 2, 1]` where only positions 0, 2 and 4 are valid.
    fn alternating_nullable_codes() -> ArrayRef {
        PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
        )
        .into_array()
    }

    /// Returns the positions that the dictionary's validity mask marks as valid.
    ///
    /// Panics unless the mask is a partial one (neither all-valid nor all-invalid).
    fn valid_positions(dict: DictArray) -> Vec<usize> {
        let mask = dict
            .as_ref()
            .validity()
            .unwrap()
            .to_mask(
                dict.as_ref().len(),
                &mut LEGACY_SESSION.create_execution_ctx(),
            )
            .unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        indices.to_vec()
    }

    /// Null codes alone decide validity when every value is valid.
    #[test]
    fn nullable_codes_validity() {
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array();
        let dict = DictArray::try_new(alternating_nullable_codes(), values).unwrap();

        assert_eq!(valid_positions(dict), [0, 2, 4]);
    }

    /// With non-nullable codes, an element is valid only if the value it points at is valid.
    #[test]
    fn nullable_values_validity() {
        let codes = buffer![0u32, 1, 2, 2, 1].into_array();
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            Validity::from(BitBuffer::from(vec![true, false, false])),
        )
        .into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        assert_eq!(valid_positions(dict), [0]);
    }

    /// An element is valid only when both its code and the referenced value are valid.
    #[test]
    fn nullable_codes_and_values() {
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            Validity::from(BitBuffer::from(vec![false, true, true])),
        )
        .into_array();
        let dict = DictArray::try_new(alternating_nullable_codes(), values).unwrap();

        assert_eq!(valid_positions(dict), [2, 4]);
    }

    /// Nullable codes over non-nullable values: validity follows the codes.
    #[test]
    fn nullable_codes_and_non_null_values() {
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::NonNullable).into_array();
        let dict = DictArray::try_new(alternating_nullable_codes(), values).unwrap();

        assert_eq!(valid_positions(dict), [0, 2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dictionary chunks from a fixed-seed RNG.
    ///
    /// Each chunk gets `unique_values` random values followed by `len` random codes.
    fn make_dict_primitive_chunks<T: NativePType, Code: UnsignedPType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);

        // Values are drawn before codes so the RNG sequence stays fixed per chunk.
        let next_chunk = || {
            let values = (0..unique_values)
                .map(|_| rng.random::<T>())
                .collect::<PrimitiveArray>();
            let codes = (0..len)
                .map(|_| {
                    Code::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                })
                .collect::<PrimitiveArray>();

            DictArray::try_new(codes.into_array(), values.into_array())
                .vortex_expect("DictArray creation should succeed in arbitrary impl")
                .into_array()
        };

        std::iter::repeat_with(next_chunk)
            .take(chunk_count)
            .collect::<ChunkedArray>()
            .into_array()
    }

    /// Appending chunked dictionaries to a builder matches canonicalizing them directly.
    #[test]
    fn test_dict_array_from_primitive_chunks() -> VortexResult<()> {
        let (len, chunk_count) = (2, 2);
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let dtype = DType::Primitive(PType::U64, NonNullable);
        let mut builder = builder_with_capacity(&dtype, len * chunk_count);
        array.append_to_builder(builder.as_mut(), &mut LEGACY_SESSION.create_execution_ctx())?;

        let into_prim = array.to_primitive();
        let prim_into = builder.finish_into_canonical().into_primitive();

        assert_arrays_eq!(into_prim, prim_into);
        Ok(())
    }

    /// Passes an encoded `DictMetadata` to `check_metadata` under the name "dict.metadata".
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        use prost::Message;

        use super::DictMetadata;
        use crate::test_harness::check_metadata;

        let metadata = DictMetadata {
            codes_ptype: PType::U64 as i32,
            values_len: u32::MAX,
            is_nullable_codes: None,
            all_values_referenced: None,
        };

        check_metadata("dict.metadata", &metadata.encode_to_vec());
    }
}