1use std::fmt::Debug;
2
3use arrow_buffer::BooleanBuffer;
4use vortex_array::builders::ArrayBuilder;
5use vortex_array::compute::{scalar_at, take, take_into, try_cast};
6use vortex_array::stats::{ArrayStats, StatsSetRef};
7use vortex_array::variants::PrimitiveArrayTrait;
8use vortex_array::vtable::{EncodingVTable, VTableRef};
9use vortex_array::{
10 Array, ArrayCanonicalImpl, ArrayImpl, ArrayRef, ArrayStatisticsImpl, ArrayValidityImpl,
11 Canonical, Encoding, EncodingId, IntoArray, RkyvMetadata, ToCanonical,
12};
13use vortex_dtype::{DType, match_each_integer_ptype};
14use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
15use vortex_mask::{AllOr, Mask};
16
17use crate::serde::DictMetadata;
18
19#[derive(Debug, Clone)]
20pub struct DictArray {
21 codes: ArrayRef,
22 values: ArrayRef,
23 stats_set: ArrayStats,
24}
25
26pub struct DictEncoding;
27impl Encoding for DictEncoding {
28 type Array = DictArray;
29 type Metadata = RkyvMetadata<DictMetadata>;
30}
31
32impl EncodingVTable for DictEncoding {
33 fn id(&self) -> EncodingId {
34 EncodingId::new_ref("vortex.dict")
35 }
36}
37
38impl DictArray {
39 pub fn try_new(mut codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
40 if !codes.dtype().is_unsigned_int() {
41 vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
42 }
43
44 let dtype = values.dtype();
45 if dtype.is_nullable() {
46 codes = try_cast(&codes, &codes.dtype().as_nullable())?;
48 } else {
49 if codes.dtype().is_nullable() {
51 vortex_bail!("Cannot have nullable codes for non-nullable dict array");
52 }
53 }
54 assert_eq!(
55 codes.dtype().nullability(),
56 values.dtype().nullability(),
57 "Mismatched nullability between codes and values"
58 );
59
60 Ok(Self {
61 codes,
62 values,
63 stats_set: Default::default(),
64 })
65 }
66
67 #[inline]
68 pub fn codes(&self) -> &ArrayRef {
69 &self.codes
70 }
71
72 #[inline]
73 pub fn values(&self) -> &ArrayRef {
74 &self.values
75 }
76}
77
78impl ArrayImpl for DictArray {
79 type Encoding = DictEncoding;
80
81 fn _len(&self) -> usize {
82 self.codes.len()
83 }
84
85 fn _dtype(&self) -> &DType {
86 self.values.dtype()
87 }
88
89 fn _vtable(&self) -> VTableRef {
90 VTableRef::new_ref(&DictEncoding)
91 }
92}
93
94impl ArrayCanonicalImpl for DictArray {
95 fn _to_canonical(&self) -> VortexResult<Canonical> {
96 match self.dtype() {
97 DType::Utf8(_) | DType::Binary(_) => {
102 let canonical_values: ArrayRef = self.values().to_canonical()?.into_array();
103 take(&canonical_values, self.codes())?.to_canonical()
104 }
105 _ => take(self.values(), self.codes())?.to_canonical(),
106 }
107 }
108
109 fn _append_to_builder(&self, builder: &mut dyn ArrayBuilder) -> VortexResult<()> {
110 match self.dtype() {
111 DType::Utf8(_) | DType::Binary(_) => {
117 let canonical_values: ArrayRef = self.values().to_canonical()?.into_array();
118 take_into(&canonical_values, self.codes(), builder)
119 }
120 _ => take_into(self.values(), self.codes(), builder),
122 }
123 }
124}
125
126impl ArrayValidityImpl for DictArray {
127 fn _is_valid(&self, index: usize) -> VortexResult<bool> {
128 let scalar = scalar_at(self.codes(), index).map_err(|err| {
129 err.with_context(format!(
130 "Failed to get index {} from DictArray codes",
131 index
132 ))
133 })?;
134
135 if scalar.is_null() {
136 return Ok(false);
137 };
138 let values_index: usize = scalar
139 .as_ref()
140 .try_into()
141 .vortex_expect("Failed to convert dictionary code to usize");
142 self.values().is_valid(values_index)
143 }
144
145 fn _all_valid(&self) -> VortexResult<bool> {
146 if !self.dtype().is_nullable() {
147 return Ok(true);
148 }
149
150 Ok(self.codes().all_valid()? && self.values().all_valid()?)
151 }
152
153 fn _all_invalid(&self) -> VortexResult<bool> {
154 if !self.dtype().is_nullable() {
155 return Ok(false);
156 }
157
158 Ok(self.codes().all_invalid()? || self.values().all_invalid()?)
159 }
160
161 fn _validity_mask(&self) -> VortexResult<Mask> {
162 let codes_validity = self.codes().validity_mask()?;
163 match codes_validity.boolean_buffer() {
164 AllOr::All => {
165 let primitive_codes = self.codes().to_primitive()?;
166 let values_mask = self.values().validity_mask()?;
167 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |$P| {
168 let codes_slice = primitive_codes.as_slice::<$P>();
169 BooleanBuffer::collect_bool(self.len(), |idx| {
170 values_mask.value(codes_slice[idx] as usize)
171 })
172 });
173 Ok(Mask::from_buffer(is_valid_buffer))
174 }
175 AllOr::None => Ok(Mask::AllFalse(self.len())),
176 AllOr::Some(validity_buff) => {
177 let primitive_codes = self.codes().to_primitive()?;
178 let values_mask = self.values().validity_mask()?;
179 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |$P| {
180 let codes_slice = primitive_codes.as_slice::<$P>();
181 BooleanBuffer::collect_bool(self.len(), |idx| {
182 validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
183 })
184 });
185 Ok(Mask::from_buffer(is_valid_buffer))
186 }
187 }
188 }
189}
190
191impl ArrayStatisticsImpl for DictArray {
192 fn _stats_ref(&self) -> StatsSetRef<'_> {
193 self.stats_set.to_ref(self)
194 }
195}
196
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rand::distr::{Distribution, StandardUniform};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};
    use vortex_array::arrays::{ChunkedArray, PrimitiveArray};
    use vortex_array::builders::builder_with_capacity;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, NativePType, PType};
    use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
    use vortex_mask::AllOr;

    use crate::DictArray;

    /// Null codes (positions 1 and 3) must produce null elements even though every
    /// dictionary value is valid.
    #[test]
    fn nullable_codes_validity() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0, 2, 4]);
    }

    /// With all codes valid, only elements whose code references a valid value
    /// (value 0 here) are valid.
    #[test]
    fn nullable_values_validity() {
        let dict = DictArray::try_new(
            buffer![0u32, 1, 2, 2, 1].into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BooleanBuffer::from(vec![true, false, false])),
            )
            .into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0]);
    }

    /// Both sources of nulls combine: position 0 references the null value 0, and
    /// positions 1 and 3 have null codes.
    #[test]
    fn nullable_codes_and_values() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BooleanBuffer::from(vec![false, true, true])),
            )
            .into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [2, 4]);
    }

    /// Builds a `ChunkedArray` of `chunk_count` dictionary chunks.
    ///
    /// Each chunk has `unique_values` random values of type `T` and `len` random codes of
    /// type `U`, drawn from one fixed-seed RNG. Per chunk, the values are drawn before the
    /// codes, so the draw order determines the output.
    fn make_dict_primitive_chunks<T: NativePType, U: NativePType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);

        (0..chunk_count)
            .map(|_| {
                let values = (0..unique_values)
                    .map(|_| rng.random::<T>())
                    .collect::<PrimitiveArray>();
                let codes = (0..len)
                    .map(|_| {
                        U::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                    })
                    .collect::<PrimitiveArray>();

                DictArray::try_new(codes.into_array(), values.into_array())
                    .vortex_unwrap()
                    .into_array()
            })
            .collect::<ChunkedArray>()
            .into_array()
    }

    /// Appending dictionary chunks to a builder must give the same values and validity as
    /// canonicalizing them directly with `to_primitive`.
    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let len = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let mut builder = builder_with_capacity(
            &DType::Primitive(PType::U64, NonNullable),
            len * chunk_count,
        );
        array
            .clone()
            .append_to_builder(builder.as_mut())
            .vortex_unwrap();

        let into_prim = array.to_primitive().unwrap();
        let prim_into = builder.finish().to_primitive().unwrap();

        assert_eq!(into_prim.as_slice::<u64>(), prim_into.as_slice::<u64>());
        assert_eq!(
            into_prim.validity_mask().unwrap().boolean_buffer(),
            prim_into.validity_mask().unwrap().boolean_buffer()
        )
    }
}