vortex_array/arrays/dict/
array.rs1use vortex_buffer::BitBuffer;
5use vortex_dtype::DType;
6use vortex_dtype::PType;
7use vortex_dtype::match_each_integer_ptype;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_ensure;
12use vortex_mask::AllOr;
13
14use crate::Array;
15use crate::ArrayRef;
16use crate::ToCanonical;
17use crate::stats::ArrayStats;
18
19#[derive(Clone, prost::Message)]
20pub struct DictMetadata {
21 #[prost(uint32, tag = "1")]
22 pub(super) values_len: u32,
23 #[prost(enumeration = "PType", tag = "2")]
24 pub(super) codes_ptype: i32,
25 #[prost(optional, bool, tag = "3")]
27 pub(super) is_nullable_codes: Option<bool>,
28 #[prost(optional, bool, tag = "4")]
32 pub(super) all_values_referenced: Option<bool>,
33}
34
35#[derive(Debug, Clone)]
36pub struct DictArray {
37 pub(super) codes: ArrayRef,
38 pub(super) values: ArrayRef,
39 pub(super) stats_set: ArrayStats,
40 pub(super) dtype: DType,
41 pub(super) all_values_referenced: bool,
47}
48
49impl DictArray {
50 pub unsafe fn new_unchecked(codes: ArrayRef, values: ArrayRef) -> Self {
57 let dtype = values
58 .dtype()
59 .union_nullability(codes.dtype().nullability());
60 Self {
61 codes,
62 values,
63 stats_set: Default::default(),
64 dtype,
65 all_values_referenced: false,
66 }
67 }
68
69 pub unsafe fn set_all_values_referenced(mut self, all_values_referenced: bool) -> Self {
79 self.all_values_referenced = all_values_referenced;
80
81 #[cfg(debug_assertions)]
82 {
83 use vortex_error::VortexUnwrap;
84 self.validate_all_values_referenced().vortex_unwrap()
85 }
86
87 self
88 }
89
90 pub fn new(codes: ArrayRef, values: ArrayRef) -> Self {
95 Self::try_new(codes, values).vortex_expect("DictArray new")
96 }
97
98 pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
110 if !codes.dtype().is_unsigned_int() {
111 vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
112 }
113
114 Ok(unsafe { Self::new_unchecked(codes, values) })
115 }
116
117 #[inline]
118 pub fn codes(&self) -> &ArrayRef {
119 &self.codes
120 }
121
122 #[inline]
123 pub fn values(&self) -> &ArrayRef {
124 &self.values
125 }
126
127 #[inline]
133 pub fn has_all_values_referenced(&self) -> bool {
134 self.all_values_referenced
135 }
136
137 pub fn validate_all_values_referenced(&self) -> VortexResult<()> {
144 if self.all_values_referenced {
145 let referenced_mask = self.compute_referenced_values_mask(true)?;
146 let all_referenced = referenced_mask.iter().all(|v| v);
147
148 vortex_ensure!(all_referenced, "value in dict not referenced");
149 }
150
151 Ok(())
152 }
153
154 pub fn compute_referenced_values_mask(&self, referenced: bool) -> VortexResult<BitBuffer> {
165 let codes_validity = self.codes().validity_mask();
166 let codes_primitive = self.codes().to_primitive();
167 let values_len = self.values().len();
168
169 let init_value = !referenced;
171 let referenced_value = referenced;
173
174 let mut values_vec = vec![init_value; values_len];
175 match codes_validity.bit_buffer() {
176 AllOr::All => {
177 match_each_integer_ptype!(codes_primitive.ptype(), |P| {
178 #[allow(clippy::cast_possible_truncation)]
179 for &code in codes_primitive.as_slice::<P>().iter() {
180 values_vec[code as usize] = referenced_value;
181 }
182 });
183 }
184 AllOr::None => {}
185 AllOr::Some(buf) => {
186 match_each_integer_ptype!(codes_primitive.ptype(), |P| {
187 let codes = codes_primitive.as_slice::<P>();
188
189 #[allow(clippy::cast_possible_truncation)]
190 buf.set_indices().for_each(|idx| {
191 values_vec[codes[idx] as usize] = referenced_value;
192 })
193 });
194 }
195 }
196
197 Ok(BitBuffer::collect_bool(values_len, |idx| values_vec[idx]))
198 }
199}
200
#[cfg(test)]
mod test {
    #[allow(unused_imports)]
    use itertools::Itertools;
    use rand::Rng;
    use rand::SeedableRng;
    use rand::distr::Distribution;
    use rand::distr::StandardUniform;
    use rand::prelude::StdRng;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::NativePType;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType;
    use vortex_dtype::UnsignedPType;
    use vortex_error::VortexExpect;
    use vortex_error::VortexUnwrap;
    use vortex_error::vortex_panic;
    use vortex_mask::AllOr;

    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::ChunkedArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::dict::DictArray;
    use crate::assert_arrays_eq;
    use crate::builders::builder_with_capacity;
    use crate::validity::Validity;

    /// Turns a list of per-element flags into an explicit [`Validity`].
    fn validity_of(bits: Vec<bool>) -> Validity {
        Validity::from(BitBuffer::from(bits))
    }

    /// Codes `[0, 1, 2, 2, 1]` in which only positions 0, 2 and 4 are valid.
    fn alternating_null_codes() -> ArrayRef {
        PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            validity_of(vec![true, false, true, false, true]),
        )
        .into_array()
    }

    /// Null codes combined with all-valid (but nullable) values: validity follows the codes.
    #[test]
    fn nullable_codes_validity() {
        let codes = alternating_null_codes();
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let validity = dict.validity_mask();
        let AllOr::Some(valid_positions) = validity.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(valid_positions, [0, 2, 4]);
    }

    /// Non-null codes pointing at partially null values: validity follows the values.
    #[test]
    fn nullable_values_validity() {
        let codes = buffer![0u32, 1, 2, 2, 1].into_array();
        let values =
            PrimitiveArray::new(buffer![3, 6, 9], validity_of(vec![true, false, false]))
                .into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let validity = dict.validity_mask();
        let AllOr::Some(valid_positions) = validity.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(valid_positions, [0]);
    }

    /// Null codes and null values: a position is valid only if both its code and value are.
    #[test]
    fn nullable_codes_and_values() {
        let codes = alternating_null_codes();
        let values =
            PrimitiveArray::new(buffer![3, 6, 9], validity_of(vec![false, true, true]))
                .into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let validity = dict.validity_mask();
        let AllOr::Some(valid_positions) = validity.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(valid_positions, [2, 4]);
    }

    /// Null codes combined with non-nullable values: validity follows the codes.
    #[test]
    fn nullable_codes_and_non_null_values() {
        let codes = alternating_null_codes();
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::NonNullable).into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let validity = dict.validity_mask();
        let AllOr::Some(valid_positions) = validity.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(valid_positions, [0, 2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dictionary arrays, each with `unique_values`
    /// random values and `len` random codes, drawn from a generator with a fixed seed.
    fn make_dict_primitive_chunks<T: NativePType, Code: UnsignedPType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);
        let mut chunks: Vec<ArrayRef> = Vec::with_capacity(chunk_count);

        for _ in 0..chunk_count {
            // Draw the values before the codes so the random stream is consumed in a fixed order.
            let values: PrimitiveArray = (0..unique_values).map(|_| rng.random::<T>()).collect();
            let codes: PrimitiveArray = (0..len)
                .map(|_| {
                    Code::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                })
                .collect();

            let dict = DictArray::try_new(codes.into_array(), values.into_array()).vortex_unwrap();
            chunks.push(dict.into_array());
        }

        chunks.into_iter().collect::<ChunkedArray>().into_array()
    }

    /// Canonicalising directly and appending through a builder must give the same result.
    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let len = 2;
        let chunk_count = 2;
        let chunked = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let target_dtype = DType::Primitive(PType::U64, NonNullable);
        let mut builder = builder_with_capacity(&target_dtype, len * chunk_count);
        chunked.clone().append_to_builder(builder.as_mut());

        let via_canonical = chunked.to_primitive();
        let via_builder = builder.finish_into_canonical().into_primitive();

        assert_arrays_eq!(via_canonical, via_builder);
    }

    /// Checks the serialized metadata through the `check_metadata` test harness.
    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_dict_metadata() {
        use super::DictMetadata;
        use crate::ProstMetadata;
        use crate::test_harness::check_metadata;

        let metadata = DictMetadata {
            codes_ptype: PType::U64 as i32,
            values_len: u32::MAX,
            is_nullable_codes: None,
            all_values_referenced: None,
        };

        check_metadata("dict.metadata", ProstMetadata(metadata));
    }
}