vortex_array/arrays/dict/
array.rs1use vortex_buffer::BitBuffer;
5use vortex_dtype::DType;
6use vortex_dtype::PType;
7use vortex_dtype::match_each_integer_ptype;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_ensure;
12use vortex_mask::AllOr;
13
14use crate::ArrayRef;
15use crate::ToCanonical;
16use crate::stats::ArrayStats;
17
/// Serialized metadata for a dictionary-encoded array.
#[derive(Clone, prost::Message)]
pub struct DictMetadata {
    /// Number of entries in the dictionary values (presumably `values().len()` — confirm
    /// against the serializer).
    #[prost(uint32, tag = "1")]
    pub(super) values_len: u32,
    /// Primitive type of the codes, stored as the integer discriminant of [`PType`].
    #[prost(enumeration = "PType", tag = "2")]
    pub(super) codes_ptype: i32,
    /// Whether the codes are nullable. Optional on the wire; `None` presumably means the
    /// writer did not record it — confirm how the deserializer interprets the absence.
    #[prost(optional, bool, tag = "3")]
    pub(super) is_nullable_codes: Option<bool>,
    /// Persisted form of the array's `all_values_referenced` flag; `None` when not recorded.
    #[prost(optional, bool, tag = "4")]
    pub(super) all_values_referenced: Option<bool>,
}
33
/// A dictionary-encoded array: elements are looked up in `values` using integer `codes`.
#[derive(Debug, Clone)]
pub struct DictArray {
    /// Integer indices into `values` (`try_new` rejects non-integer dtypes).
    pub(super) codes: ArrayRef,
    /// The dictionary entries referenced by `codes`.
    pub(super) values: ArrayRef,
    /// Statistics storage for this array; constructors initialize it to its default.
    pub(super) stats_set: ArrayStats,
    /// Dtype of `values` with its nullability unioned with the nullability of `codes`.
    pub(super) dtype: DType,
    /// When `true`, every entry of `values` is claimed to be referenced by at least one
    /// valid code. Checked in debug builds by `set_all_values_referenced`.
    pub(super) all_values_referenced: bool,
}
47
48pub struct DictArrayParts {
49 pub codes: ArrayRef,
50 pub values: ArrayRef,
51 pub dtype: DType,
52}
53
54impl DictArray {
55 pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
62 let dtype = values
63 .dtype()
64 .union_nullability(codes.dtype().nullability());
65 Self {
66 codes,
67 values,
68 stats_set: Default::default(),
69 dtype,
70 all_values_referenced: false,
71 }
72 }
73
74 pub unsafe fn set_all_values_referenced(mut self, all_values_referenced: bool) -> Self {
84 self.all_values_referenced = all_values_referenced;
85
86 #[cfg(debug_assertions)]
87 {
88 use vortex_error::VortexExpect;
89 self.validate_all_values_referenced()
90 .vortex_expect("validation should succeed when all values are referenced")
91 }
92
93 self
94 }
95
96 pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
101 Self::try_new(codes, values).vortex_expect("DictArray new")
102 }
103
104 pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
116 if !codes.dtype().is_int() {
117 vortex_bail!(MismatchedTypes: "int", codes.dtype());
118 }
119
120 Ok(unsafe { Self::new_unchecked(codes, values) })
121 }
122
123 pub fn into_parts(self) -> DictArrayParts {
124 DictArrayParts {
125 codes: self.codes,
126 values: self.values,
127 dtype: self.dtype,
128 }
129 }
130
131 #[inline]
132 pub fn codes(&self) -> &ArrayRef {
133 &self.codes
134 }
135
136 #[inline]
137 pub fn values(&self) -> &ArrayRef {
138 &self.values
139 }
140
141 #[inline]
147 pub fn has_all_values_referenced(&self) -> bool {
148 self.all_values_referenced
149 }
150
151 pub fn validate_all_values_referenced(&self) -> VortexResult<()> {
158 if self.all_values_referenced {
159 if !self.codes().is_host() {
161 return Ok(());
162 }
163
164 let referenced_mask = self.compute_referenced_values_mask(true)?;
165 let all_referenced = referenced_mask.iter().all(|v| v);
166
167 vortex_ensure!(all_referenced, "value in dict not referenced");
168 }
169
170 Ok(())
171 }
172
173 pub fn compute_referenced_values_mask(&self, referenced: bool) -> VortexResult<BitBuffer> {
184 let codes_validity = self.codes().validity_mask()?;
185 let codes_primitive = self.codes().to_primitive();
186 let values_len = self.values().len();
187
188 let init_value = !referenced;
190 let referenced_value = referenced;
192
193 let mut values_vec = vec![init_value; values_len];
194 match codes_validity.bit_buffer() {
195 AllOr::All => {
196 match_each_integer_ptype!(codes_primitive.ptype(), |P| {
197 #[allow(
198 clippy::cast_possible_truncation,
199 clippy::cast_sign_loss,
200 reason = "codes are non-negative indices; a negative signed code would wrap to a large usize and panic on the bounds-checked array index"
201 )]
202 for &code in codes_primitive.as_slice::<P>().iter() {
203 values_vec[code as usize] = referenced_value;
204 }
205 });
206 }
207 AllOr::None => {}
208 AllOr::Some(buf) => {
209 match_each_integer_ptype!(codes_primitive.ptype(), |P| {
210 let codes = codes_primitive.as_slice::<P>();
211
212 #[allow(
213 clippy::cast_possible_truncation,
214 clippy::cast_sign_loss,
215 reason = "codes are non-negative indices; a negative signed code would wrap to a large usize and panic on the bounds-checked array index"
216 )]
217 buf.set_indices().for_each(|idx| {
218 values_vec[codes[idx] as usize] = referenced_value;
219 })
220 });
221 }
222 }
223
224 Ok(BitBuffer::collect_bool(values_len, |idx| values_vec[idx]))
225 }
226}
227
#[cfg(test)]
mod test {
    #[allow(unused_imports)]
    use itertools::Itertools;
    use rand::Rng;
    use rand::SeedableRng;
    use rand::distr::Distribution;
    use rand::distr::StandardUniform;
    use rand::prelude::StdRng;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::NativePType;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType;
    use vortex_dtype::UnsignedPType;
    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;
    use vortex_error::vortex_panic;
    use vortex_mask::AllOr;

    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::ChunkedArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictArray;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::validity::Validity;

    /// Codes `[0, 1, 2, 2, 1]` whose entries at positions 1 and 3 are null.
    fn codes_with_nulls() -> ArrayRef {
        let validity = Validity::from(BitBuffer::from(vec![true, false, true, false, true]));
        PrimitiveArray::new(buffer![0u32, 1, 2, 2, 1], validity).into_array()
    }

    /// Dictionary values `[3, 6, 9]` carrying the given validity.
    fn values_with(validity: Validity) -> ArrayRef {
        PrimitiveArray::new(buffer![3, 6, 9], validity).into_array()
    }

    /// The valid positions of `dict`; panics unless the mask is partially valid.
    fn valid_positions(dict: &DictArray) -> Vec<usize> {
        let mask = dict.validity_mask().unwrap();
        match mask.indices() {
            AllOr::Some(positions) => positions.to_vec(),
            AllOr::All | AllOr::None => vortex_panic!("Expected indices from mask"),
        }
    }

    #[test]
    fn nullable_codes_validity() {
        let dict =
            DictArray::try_new(codes_with_nulls(), values_with(Validity::AllValid)).unwrap();
        assert_eq!(valid_positions(&dict), [0, 2, 4]);
    }

    #[test]
    fn nullable_values_validity() {
        let codes = buffer![0u32, 1, 2, 2, 1].into_array();
        let values = values_with(Validity::from(BitBuffer::from(vec![true, false, false])));
        let dict = DictArray::try_new(codes, values).unwrap();
        assert_eq!(valid_positions(&dict), [0]);
    }

    #[test]
    fn nullable_codes_and_values() {
        let values = values_with(Validity::from(BitBuffer::from(vec![false, true, true])));
        let dict = DictArray::try_new(codes_with_nulls(), values).unwrap();
        assert_eq!(valid_positions(&dict), [2, 4]);
    }

    #[test]
    fn nullable_codes_and_non_null_values() {
        let dict =
            DictArray::try_new(codes_with_nulls(), values_with(Validity::NonNullable)).unwrap();
        assert_eq!(valid_positions(&dict), [0, 2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dictionary chunks, each with `len` random codes
    /// over `unique_values` random values, from a fixed seed.
    fn make_dict_primitive_chunks<T: NativePType, Code: UnsignedPType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);
        let mut chunks: Vec<ArrayRef> = Vec::with_capacity(chunk_count);

        for _ in 0..chunk_count {
            // Values are drawn before codes so the RNG sequence is deterministic per chunk.
            let values: PrimitiveArray =
                (0..unique_values).map(|_| rng.random::<T>()).collect();
            let codes: PrimitiveArray = (0..len)
                .map(|_| {
                    Code::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                })
                .collect();

            let chunk = DictArray::try_new(codes.into_array(), values.into_array())
                .vortex_expect("DictArray creation should succeed in arbitrary impl");
            chunks.push(chunk.into_array());
        }

        chunks.into_iter().collect::<ChunkedArray>().into_array()
    }

    #[test]
    fn test_dict_array_from_primitive_chunks() -> VortexResult<()> {
        const LEN: usize = 2;
        const CHUNK_COUNT: usize = 2;
        let chunked = make_dict_primitive_chunks::<u64, u64>(LEN, 2, CHUNK_COUNT);

        let dtype = DType::Primitive(PType::U64, NonNullable);
        let mut builder = builder_with_capacity(&dtype, LEN * CHUNK_COUNT);
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        chunked
            .clone()
            .append_to_builder(builder.as_mut(), &mut ctx)?;

        let canonical = chunked.to_primitive();
        let built = builder.finish_into_canonical().into_primitive();

        assert_arrays_eq!(canonical, built);
        Ok(())
    }

    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        use super::DictMetadata;
        use crate::ProstMetadata;
        use crate::test_harness::check_metadata;

        let metadata = DictMetadata {
            codes_ptype: PType::U64 as i32,
            values_len: u32::MAX,
            is_nullable_codes: None,
            all_values_referenced: None,
        };
        check_metadata("dict.metadata", ProstMetadata(metadata));
    }
}