1use std::fmt::Debug;
2
3use arrow_buffer::BooleanBuffer;
4use vortex_array::builders::ArrayBuilder;
5use vortex_array::compute::{scalar_at, take, take_into, try_cast};
6use vortex_array::stats::{ArrayStats, StatsSetRef};
7use vortex_array::variants::PrimitiveArrayTrait;
8use vortex_array::vtable::VTableRef;
9use vortex_array::{
10 Array, ArrayCanonicalImpl, ArrayImpl, ArrayRef, ArrayStatisticsImpl, ArrayValidityImpl,
11 Canonical, Encoding, IntoArray, RkyvMetadata, ToCanonical,
12};
13use vortex_dtype::{DType, match_each_integer_ptype};
14use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
15use vortex_mask::{AllOr, Mask};
16
17use crate::serde::DictMetadata;
18
/// A dictionary-encoded array: each element is stored as an entry of `codes` that indexes
/// into the `values` dictionary array.
#[derive(Debug, Clone)]
pub struct DictArray {
    /// Per-element indices into `values`. `DictArray::try_new` requires an unsigned integer dtype.
    codes: ArrayRef,
    /// The dictionary itself; its dtype is also the dtype of the whole array.
    values: ArrayRef,
    /// Statistics storage for this array, exposed through `_stats_ref`.
    stats_set: ArrayStats,
}
25
/// Encoding marker for [`DictArray`]; its serialized metadata is `DictMetadata` wrapped in
/// `RkyvMetadata`.
pub struct DictEncoding;
impl Encoding for DictEncoding {
    type Array = DictArray;
    type Metadata = RkyvMetadata<DictMetadata>;
}
31
32impl DictArray {
33 pub fn try_new(mut codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
34 if !codes.dtype().is_unsigned_int() {
35 vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
36 }
37
38 let dtype = values.dtype();
39 if dtype.is_nullable() {
40 codes = try_cast(&codes, &codes.dtype().as_nullable())?;
42 } else {
43 if codes.dtype().is_nullable() {
45 vortex_bail!("Cannot have nullable codes for non-nullable dict array");
46 }
47 }
48 assert_eq!(
49 codes.dtype().nullability(),
50 values.dtype().nullability(),
51 "Mismatched nullability between codes and values"
52 );
53
54 Ok(Self {
55 codes,
56 values,
57 stats_set: Default::default(),
58 })
59 }
60
61 #[inline]
62 pub fn codes(&self) -> &ArrayRef {
63 &self.codes
64 }
65
66 #[inline]
67 pub fn values(&self) -> &ArrayRef {
68 &self.values
69 }
70}
71
72impl ArrayImpl for DictArray {
73 type Encoding = DictEncoding;
74
75 fn _len(&self) -> usize {
76 self.codes.len()
77 }
78
79 fn _dtype(&self) -> &DType {
80 self.values.dtype()
81 }
82
83 fn _vtable(&self) -> VTableRef {
84 VTableRef::new_ref(&DictEncoding)
85 }
86
87 fn _with_children(&self, children: &[ArrayRef]) -> VortexResult<Self> {
88 let codes = children[0].clone();
89 let values = children[1].clone();
90
91 Self::try_new(codes, values)
92 }
93}
94
95impl ArrayCanonicalImpl for DictArray {
96 fn _to_canonical(&self) -> VortexResult<Canonical> {
97 match self.dtype() {
98 DType::Utf8(_) | DType::Binary(_) => {
103 let canonical_values: ArrayRef = self.values().to_canonical()?.into_array();
104 take(&canonical_values, self.codes())?.to_canonical()
105 }
106 _ => take(self.values(), self.codes())?.to_canonical(),
107 }
108 }
109
110 fn _append_to_builder(&self, builder: &mut dyn ArrayBuilder) -> VortexResult<()> {
111 match self.dtype() {
112 DType::Utf8(_) | DType::Binary(_) => {
118 let canonical_values: ArrayRef = self.values().to_canonical()?.into_array();
119 take_into(&canonical_values, self.codes(), builder)
120 }
121 _ => take_into(self.values(), self.codes(), builder),
123 }
124 }
125}
126
127impl ArrayValidityImpl for DictArray {
128 fn _is_valid(&self, index: usize) -> VortexResult<bool> {
129 let scalar = scalar_at(self.codes(), index).map_err(|err| {
130 err.with_context(format!(
131 "Failed to get index {} from DictArray codes",
132 index
133 ))
134 })?;
135
136 if scalar.is_null() {
137 return Ok(false);
138 };
139 let values_index: usize = scalar
140 .as_ref()
141 .try_into()
142 .vortex_expect("Failed to convert dictionary code to usize");
143 self.values().is_valid(values_index)
144 }
145
146 fn _all_valid(&self) -> VortexResult<bool> {
147 if !self.dtype().is_nullable() {
148 return Ok(true);
149 }
150
151 Ok(self.codes().all_valid()? && self.values().all_valid()?)
152 }
153
154 fn _all_invalid(&self) -> VortexResult<bool> {
155 if !self.dtype().is_nullable() {
156 return Ok(false);
157 }
158
159 Ok(self.codes().all_invalid()? || self.values().all_invalid()?)
160 }
161
162 fn _validity_mask(&self) -> VortexResult<Mask> {
163 let codes_validity = self.codes().validity_mask()?;
164 match codes_validity.boolean_buffer() {
165 AllOr::All => {
166 let primitive_codes = self.codes().to_primitive()?;
167 let values_mask = self.values().validity_mask()?;
168 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |$P| {
169 let codes_slice = primitive_codes.as_slice::<$P>();
170 BooleanBuffer::collect_bool(self.len(), |idx| {
171 values_mask.value(codes_slice[idx] as usize)
172 })
173 });
174 Ok(Mask::from_buffer(is_valid_buffer))
175 }
176 AllOr::None => Ok(Mask::AllFalse(self.len())),
177 AllOr::Some(validity_buff) => {
178 let primitive_codes = self.codes().to_primitive()?;
179 let values_mask = self.values().validity_mask()?;
180 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |$P| {
181 let codes_slice = primitive_codes.as_slice::<$P>();
182 BooleanBuffer::collect_bool(self.len(), |idx| {
183 validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
184 })
185 });
186 Ok(Mask::from_buffer(is_valid_buffer))
187 }
188 }
189 }
190}
191
192impl ArrayStatisticsImpl for DictArray {
193 fn _stats_ref(&self) -> StatsSetRef<'_> {
194 self.stats_set.to_ref(self)
195 }
196}
197
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rand::distr::{Distribution, StandardUniform};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};
    use vortex_array::arrays::{ChunkedArray, PrimitiveArray};
    use vortex_array::builders::builder_with_capacity;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, NativePType, PType};
    use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
    use vortex_mask::AllOr;

    use crate::DictArray;

    /// Null codes (positions 1 and 3) make their elements invalid even though every
    /// dictionary value is valid.
    #[test]
    fn nullable_codes_validity() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0, 2, 4]);
    }

    /// With non-null codes, an element is invalid when its code points at a null dictionary
    /// value; only index 0 references the single valid value (code 0).
    #[test]
    fn nullable_values_validity() {
        let dict = DictArray::try_new(
            buffer![0u32, 1, 2, 2, 1].into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BooleanBuffer::from(vec![true, false, false])),
            )
            .into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0]);
    }

    /// Both sources of nullness combine: index 0 references null value 0, indices 1 and 3
    /// have null codes, leaving only indices 2 and 4 valid.
    #[test]
    fn nullable_codes_and_values() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BooleanBuffer::from(vec![false, true, true])),
            )
            .into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [2, 4]);
    }

    /// Builds a `ChunkedArray` of `chunk_count` `DictArray` chunks. Each chunk has
    /// `unique_values` random values of type `T` and `len` random codes of type `U` in
    /// `0..unique_values`. The RNG is seeded with 0 so the fixture is deterministic; the
    /// order of the `random` calls therefore matters.
    fn make_dict_primitive_chunks<T: NativePType, U: NativePType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);

        (0..chunk_count)
            .map(|_| {
                let values = (0..unique_values)
                    .map(|_| rng.random::<T>())
                    .collect::<PrimitiveArray>();
                let codes = (0..len)
                    .map(|_| {
                        U::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                    })
                    .collect::<PrimitiveArray>();

                DictArray::try_new(codes.into_array(), values.into_array())
                    .vortex_unwrap()
                    .into_array()
            })
            .collect::<ChunkedArray>()
            .into_array()
    }

    /// Appending chunked dict arrays into a builder must yield the same values and validity
    /// as canonicalizing them with `to_primitive`.
    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let len = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let mut builder = builder_with_capacity(
            &DType::Primitive(PType::U64, NonNullable),
            len * chunk_count,
        );
        array
            .clone()
            .append_to_builder(builder.as_mut())
            .vortex_unwrap();

        let into_prim = array.to_primitive().unwrap();
        let prim_into = builder.finish().to_primitive().unwrap();

        assert_eq!(into_prim.as_slice::<u64>(), prim_into.as_slice::<u64>());
        assert_eq!(
            into_prim.validity_mask().unwrap().boolean_buffer(),
            prim_into.validity_mask().unwrap().boolean_buffer()
        )
    }
}