vortex_array/arrays/dict/
array.rs1use std::fmt::Display;
5use std::fmt::Formatter;
6
7use vortex_buffer::BitBuffer;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_ensure;
12use vortex_mask::AllOr;
13
14use crate::ArrayRef;
15use crate::ToCanonical;
16use crate::array::Array;
17use crate::array::ArrayParts;
18use crate::array::TypedArrayRef;
19use crate::array_slots;
20use crate::arrays::Dict;
21use crate::dtype::DType;
22use crate::dtype::PType;
23use crate::match_each_integer_ptype;
24
25#[derive(Clone, prost::Message)]
26pub struct DictMetadata {
27 #[prost(uint32, tag = "1")]
28 pub(super) values_len: u32,
29 #[prost(enumeration = "PType", tag = "2")]
30 pub(super) codes_ptype: i32,
31 #[prost(optional, bool, tag = "3")]
33 pub(super) is_nullable_codes: Option<bool>,
34 #[prost(optional, bool, tag = "4")]
38 pub(super) all_values_referenced: Option<bool>,
39}
40
41#[array_slots(Dict)]
42pub struct DictSlots {
43 pub codes: ArrayRef,
45 pub values: ArrayRef,
47}
48
49#[derive(Debug, Clone)]
50pub struct DictData {
51 pub(super) all_values_referenced: bool,
57}
58
59impl Display for DictData {
60 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
61 write!(f, "all_values_referenced: {}", self.all_values_referenced)
62 }
63}
64
65impl DictData {
66 pub unsafe fn new_unchecked() -> Self {
73 Self {
74 all_values_referenced: false,
75 }
76 }
77
78 pub unsafe fn set_all_values_referenced(mut self, all_values_referenced: bool) -> Self {
88 self.all_values_referenced = all_values_referenced;
89 self
90 }
91
92 pub fn new(codes_dtype: &DType) -> Self {
97 Self::try_new(codes_dtype).vortex_expect("DictArray new")
98 }
99
100 pub(crate) fn try_new(codes_dtype: &DType) -> VortexResult<Self> {
112 if !codes_dtype.is_int() {
113 vortex_bail!(MismatchedTypes: "int", codes_dtype);
114 }
115
116 Ok(unsafe { Self::new_unchecked() })
117 }
118}
119
120pub trait DictArrayExt: TypedArrayRef<Dict> + DictArraySlotsExt {
121 #[inline]
122 fn has_all_values_referenced(&self) -> bool {
123 self.all_values_referenced
124 }
125
126 fn validate_all_values_referenced(&self) -> VortexResult<()> {
127 if self.has_all_values_referenced() {
128 if !self.codes().is_host() {
129 return Ok(());
130 }
131
132 let referenced_mask = self.compute_referenced_values_mask(true)?;
133 let all_referenced = referenced_mask.iter().all(|v| v);
134
135 vortex_ensure!(all_referenced, "value in dict not referenced");
136 }
137
138 Ok(())
139 }
140
141 #[allow(
142 clippy::cognitive_complexity,
143 reason = "branching depends on validity representation and code type"
144 )]
145 fn compute_referenced_values_mask(&self, referenced: bool) -> VortexResult<BitBuffer> {
146 let codes_validity = self.codes().validity_mask()?;
147 let codes_primitive = self.codes().to_primitive();
148 let values_len = self.values().len();
149
150 let init_value = !referenced;
151 let referenced_value = referenced;
152
153 let mut values_vec = vec![init_value; values_len];
154 match codes_validity.bit_buffer() {
155 AllOr::All => {
156 match_each_integer_ptype!(codes_primitive.ptype(), |P| {
157 #[allow(
158 clippy::cast_possible_truncation,
159 clippy::cast_sign_loss,
160 reason = "codes are non-negative indices; a negative signed code would wrap to a large usize and panic on the bounds-checked array index"
161 )]
162 for &idx in codes_primitive.as_slice::<P>() {
163 values_vec[idx as usize] = referenced_value;
164 }
165 });
166 }
167 AllOr::None => {}
168 AllOr::Some(mask) => {
169 match_each_integer_ptype!(codes_primitive.ptype(), |P| {
170 let codes = codes_primitive.as_slice::<P>();
171
172 #[allow(
173 clippy::cast_possible_truncation,
174 clippy::cast_sign_loss,
175 reason = "codes are non-negative indices; a negative signed code would wrap to a large usize and panic on the bounds-checked array index"
176 )]
177 mask.set_indices().for_each(|idx| {
178 values_vec[codes[idx] as usize] = referenced_value;
179 });
180 });
181 }
182 }
183
184 Ok(BitBuffer::from(values_vec))
185 }
186}
187impl<T: TypedArrayRef<Dict>> DictArrayExt for T {}
188
189impl Array<Dict> {
190 pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
192 Self::try_new(codes, values).vortex_expect("DictArray new")
193 }
194
195 pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
197 let dtype = values
198 .dtype()
199 .union_nullability(codes.dtype().nullability());
200 let len = codes.len();
201 let data = DictData::try_new(codes.dtype())?;
202 Array::try_from_parts(
203 ArrayParts::new(Dict, dtype, len, data).with_slots(vec![Some(codes), Some(values)]),
204 )
205 }
206
207 pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
213 let dtype = values
214 .dtype()
215 .union_nullability(codes.dtype().nullability());
216 let len = codes.len();
217 let data = unsafe { DictData::new_unchecked() };
218 unsafe {
219 Array::from_parts_unchecked(
220 ArrayParts::new(Dict, dtype, len, data).with_slots(vec![Some(codes), Some(values)]),
221 )
222 }
223 }
224
225 pub unsafe fn set_all_values_referenced(self, all_values_referenced: bool) -> Self {
231 let dtype = self.dtype().clone();
232 let len = self.len();
233 let slots = self.slots().to_vec();
234 let data = unsafe {
235 self.into_data()
236 .set_all_values_referenced(all_values_referenced)
237 };
238 let array = unsafe {
239 Array::from_parts_unchecked(ArrayParts::new(Dict, dtype, len, data).with_slots(slots))
240 };
241
242 #[cfg(debug_assertions)]
243 if all_values_referenced {
244 array
245 .validate_all_values_referenced()
246 .vortex_expect("validation should succeed when all values are referenced");
247 }
248
249 array
250 }
251}
252
#[cfg(test)]
mod test {
    #[allow(unused_imports)]
    use itertools::Itertools;
    use rand::RngExt;
    use rand::SeedableRng;
    use rand::distr::Distribution;
    use rand::distr::StandardUniform;
    use rand::prelude::StdRng;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;
    use vortex_error::vortex_panic;
    use vortex_mask::AllOr;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::ChunkedArray;
    use crate::arrays::DictArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::dtype::DType;
    use crate::dtype::NativePType;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType;
    use crate::dtype::UnsignedPType;
    use crate::validity::Validity;

    /// Codes `[0, 1, 2, 2, 1]` (`u32`) where positions 1 and 3 are null.
    fn codes_with_nulls() -> ArrayRef {
        PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BitBuffer::from(vec![true, false, true, false, true])),
        )
        .into_array()
    }

    /// Null codes alone determine the validity when all values are valid.
    #[test]
    fn nullable_codes_validity() {
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array();
        let dict = DictArray::try_new(codes_with_nulls(), values).unwrap();

        let validity = dict.validity_mask().unwrap();
        let AllOr::Some(valid_rows) = validity.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(valid_rows, [0, 2, 4]);
    }

    /// Null values propagate to every row whose code points at them.
    #[test]
    fn nullable_values_validity() {
        let codes = buffer![0u32, 1, 2, 2, 1].into_array();
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            Validity::from(BitBuffer::from(vec![true, false, false])),
        )
        .into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let validity = dict.validity_mask().unwrap();
        let AllOr::Some(valid_rows) = validity.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(valid_rows, [0]);
    }

    /// A row is valid only when both its code and the referenced value are valid.
    #[test]
    fn nullable_codes_and_values() {
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            Validity::from(BitBuffer::from(vec![false, true, true])),
        )
        .into_array();
        let dict = DictArray::try_new(codes_with_nulls(), values).unwrap();

        let validity = dict.validity_mask().unwrap();
        let AllOr::Some(valid_rows) = validity.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(valid_rows, [2, 4]);
    }

    /// Null codes still yield null rows when the values are non-nullable.
    #[test]
    fn nullable_codes_and_non_null_values() {
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::NonNullable).into_array();
        let dict = DictArray::try_new(codes_with_nulls(), values).unwrap();

        let validity = dict.validity_mask().unwrap();
        let AllOr::Some(valid_rows) = validity.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(valid_rows, [0, 2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dictionary arrays.
    ///
    /// Each chunk has `unique_values` random values of type `T` and `len`
    /// random codes of type `Code` in `0..unique_values`. The generator is
    /// seeded with a constant, so the output is deterministic.
    fn make_dict_primitive_chunks<T: NativePType, Code: UnsignedPType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);

        let chunks = (0..chunk_count).map(|_| {
            // Values are drawn before codes; the order fixes the RNG stream.
            let dictionary = (0..unique_values)
                .map(|_| rng.random::<T>())
                .collect::<PrimitiveArray>();
            let indices = (0..len)
                .map(|_| {
                    Code::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                })
                .collect::<PrimitiveArray>();

            DictArray::try_new(indices.into_array(), dictionary.into_array())
                .vortex_expect("DictArray creation should succeed in arbitrary impl")
                .into_array()
        });

        chunks.collect::<ChunkedArray>().into_array()
    }

    /// Appending the chunked dict array to a builder matches direct canonicalization.
    #[test]
    fn test_dict_array_from_primitive_chunks() -> VortexResult<()> {
        let rows_per_chunk = 2;
        let chunks = 2;
        let chunked = make_dict_primitive_chunks::<u64, u64>(rows_per_chunk, 2, chunks);

        let mut builder = builder_with_capacity(
            &DType::Primitive(PType::U64, NonNullable),
            rows_per_chunk * chunks,
        );
        chunked.append_to_builder(builder.as_mut(), &mut LEGACY_SESSION.create_execution_ctx())?;

        let direct = chunked.to_primitive();
        let via_builder = builder.finish_into_canonical().into_primitive();

        assert_arrays_eq!(direct, via_builder);
        Ok(())
    }

    /// Passes the encoded metadata bytes to `check_metadata` under the name
    /// "dict.metadata".
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        use prost::Message;

        use super::DictMetadata;
        use crate::test_harness::check_metadata;

        let metadata = DictMetadata {
            codes_ptype: PType::U64 as i32,
            values_len: u32::MAX,
            is_nullable_codes: None,
            all_values_referenced: None,
        };

        check_metadata("dict.metadata", &metadata.encode_to_vec());
    }
}