1use std::fmt::Debug;
2
3use arrow_buffer::BooleanBuffer;
4use vortex_array::builders::ArrayBuilder;
5use vortex_array::compute::{cast, scalar_at, take, take_into};
6use vortex_array::stats::{ArrayStats, StatsSetRef};
7use vortex_array::variants::PrimitiveArrayTrait;
8use vortex_array::vtable::VTableRef;
9use vortex_array::{
10 Array, ArrayCanonicalImpl, ArrayImpl, ArrayRef, ArrayStatisticsImpl, ArrayValidityImpl,
11 Canonical, Encoding, IntoArray, ProstMetadata, ToCanonical,
12};
13use vortex_dtype::{DType, match_each_integer_ptype};
14use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
15use vortex_mask::{AllOr, Mask};
16
17use crate::serde::DictMetadata;
18
19#[derive(Debug, Clone)]
20pub struct DictArray {
21 codes: ArrayRef,
22 values: ArrayRef,
23 stats_set: ArrayStats,
24}
25
26#[derive(Debug)]
27pub struct DictEncoding;
28impl Encoding for DictEncoding {
29 type Array = DictArray;
30 type Metadata = ProstMetadata<DictMetadata>;
31}
32
33impl DictArray {
34 pub fn try_new(mut codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
35 if !codes.dtype().is_unsigned_int() {
36 vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
37 }
38
39 let dtype = values.dtype();
40 if dtype.is_nullable() {
41 codes = cast(&codes, &codes.dtype().as_nullable())?;
43 } else {
44 if codes.dtype().is_nullable() {
46 vortex_bail!("Cannot have nullable codes for non-nullable dict array");
47 }
48 }
49 assert_eq!(
50 codes.dtype().nullability(),
51 values.dtype().nullability(),
52 "Mismatched nullability between codes and values"
53 );
54
55 Ok(Self {
56 codes,
57 values,
58 stats_set: Default::default(),
59 })
60 }
61
62 #[inline]
63 pub fn codes(&self) -> &ArrayRef {
64 &self.codes
65 }
66
67 #[inline]
68 pub fn values(&self) -> &ArrayRef {
69 &self.values
70 }
71}
72
73impl ArrayImpl for DictArray {
74 type Encoding = DictEncoding;
75
76 fn _len(&self) -> usize {
77 self.codes.len()
78 }
79
80 fn _dtype(&self) -> &DType {
81 self.values.dtype()
82 }
83
84 fn _vtable(&self) -> VTableRef {
85 VTableRef::new_ref(&DictEncoding)
86 }
87
88 fn _with_children(&self, children: &[ArrayRef]) -> VortexResult<Self> {
89 let codes = children[0].clone();
90 let values = children[1].clone();
91
92 Self::try_new(codes, values)
93 }
94}
95
96impl ArrayCanonicalImpl for DictArray {
97 fn _to_canonical(&self) -> VortexResult<Canonical> {
98 match self.dtype() {
99 DType::Utf8(_) | DType::Binary(_) => {
104 let canonical_values: ArrayRef = self.values().to_canonical()?.into_array();
105 take(&canonical_values, self.codes())?.to_canonical()
106 }
107 _ => take(self.values(), self.codes())?.to_canonical(),
108 }
109 }
110
111 fn _append_to_builder(&self, builder: &mut dyn ArrayBuilder) -> VortexResult<()> {
112 match self.dtype() {
113 DType::Utf8(_) | DType::Binary(_) => {
119 let canonical_values: ArrayRef = self.values().to_canonical()?.into_array();
120 take_into(&canonical_values, self.codes(), builder)
121 }
122 _ => take_into(self.values(), self.codes(), builder),
124 }
125 }
126}
127
128impl ArrayValidityImpl for DictArray {
129 fn _is_valid(&self, index: usize) -> VortexResult<bool> {
130 let scalar = scalar_at(self.codes(), index).map_err(|err| {
131 err.with_context(format!(
132 "Failed to get index {} from DictArray codes",
133 index
134 ))
135 })?;
136
137 if scalar.is_null() {
138 return Ok(false);
139 };
140 let values_index: usize = scalar
141 .as_ref()
142 .try_into()
143 .vortex_expect("Failed to convert dictionary code to usize");
144 self.values().is_valid(values_index)
145 }
146
147 fn _all_valid(&self) -> VortexResult<bool> {
148 if !self.dtype().is_nullable() {
149 return Ok(true);
150 }
151
152 Ok(self.codes().all_valid()? && self.values().all_valid()?)
153 }
154
155 fn _all_invalid(&self) -> VortexResult<bool> {
156 if !self.dtype().is_nullable() {
157 return Ok(false);
158 }
159
160 Ok(self.codes().all_invalid()? || self.values().all_invalid()?)
161 }
162
163 fn _validity_mask(&self) -> VortexResult<Mask> {
164 let codes_validity = self.codes().validity_mask()?;
165 match codes_validity.boolean_buffer() {
166 AllOr::All => {
167 let primitive_codes = self.codes().to_primitive()?;
168 let values_mask = self.values().validity_mask()?;
169 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |$P| {
170 let codes_slice = primitive_codes.as_slice::<$P>();
171 BooleanBuffer::collect_bool(self.len(), |idx| {
172 values_mask.value(codes_slice[idx] as usize)
173 })
174 });
175 Ok(Mask::from_buffer(is_valid_buffer))
176 }
177 AllOr::None => Ok(Mask::AllFalse(self.len())),
178 AllOr::Some(validity_buff) => {
179 let primitive_codes = self.codes().to_primitive()?;
180 let values_mask = self.values().validity_mask()?;
181 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |$P| {
182 let codes_slice = primitive_codes.as_slice::<$P>();
183 BooleanBuffer::collect_bool(self.len(), |idx| {
184 validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
185 })
186 });
187 Ok(Mask::from_buffer(is_valid_buffer))
188 }
189 }
190 }
191}
192
193impl ArrayStatisticsImpl for DictArray {
194 fn _stats_ref(&self) -> StatsSetRef<'_> {
195 self.stats_set.to_ref(self)
196 }
197}
198
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rand::distr::{Distribution, StandardUniform};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};
    use vortex_array::arrays::{ChunkedArray, PrimitiveArray};
    use vortex_array::builders::builder_with_capacity;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, NativePType, PType};
    use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
    use vortex_mask::AllOr;

    use crate::DictArray;

    /// Codes `[0, 1, 2, 2, 1]` where positions 1 and 3 are null.
    fn codes_with_nulls() -> ArrayRef {
        PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
        )
        .into_array()
    }

    /// Dictionary values `[3, 6, 9]` carrying the given validity.
    fn values_with(validity: Validity) -> ArrayRef {
        PrimitiveArray::new(buffer![3, 6, 9], validity).into_array()
    }

    /// Null codes alone determine validity when every dictionary value is valid.
    #[test]
    fn nullable_codes_validity() {
        let dict =
            DictArray::try_new(codes_with_nulls(), values_with(Validity::AllValid)).unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0, 2, 4]);
    }

    /// Null dictionary values invalidate every element whose code points at them.
    #[test]
    fn nullable_values_validity() {
        let dict = DictArray::try_new(
            buffer![0u32, 1, 2, 2, 1].into_array(),
            values_with(Validity::from(BooleanBuffer::from(vec![
                true, false, false,
            ]))),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0]);
    }

    /// Validity is the conjunction of code validity and referenced-value validity.
    #[test]
    fn nullable_codes_and_values() {
        let dict = DictArray::try_new(
            codes_with_nulls(),
            values_with(Validity::from(BooleanBuffer::from(vec![
                false, true, true,
            ]))),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dictionary chunks, each with `unique_values`
    /// random values of type `T` and `len` random codes of type `U`, from a fixed seed.
    fn make_dict_primitive_chunks<T: NativePType, U: NativePType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);

        // Draw order matters for reproducibility: values first, then codes, per chunk.
        let mut next_chunk = || {
            let dictionary = (0..unique_values)
                .map(|_| rng.random::<T>())
                .collect::<PrimitiveArray>();
            let indices = (0..len)
                .map(|_| U::from(rng.random_range(0..unique_values)).vortex_expect("valid value"))
                .collect::<PrimitiveArray>();

            DictArray::try_new(indices.into_array(), dictionary.into_array())
                .vortex_unwrap()
                .into_array()
        };

        (0..chunk_count)
            .map(|_| next_chunk())
            .collect::<ChunkedArray>()
            .into_array()
    }

    /// Appending chunked dictionary arrays to a builder must match direct canonicalization.
    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let len = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let target_dtype = DType::Primitive(PType::U64, NonNullable);
        let mut builder = builder_with_capacity(&target_dtype, len * chunk_count);
        array
            .clone()
            .append_to_builder(builder.as_mut())
            .vortex_unwrap();

        let via_canonical = array.to_primitive().unwrap();
        let via_builder = builder.finish().to_primitive().unwrap();

        assert_eq!(
            via_canonical.as_slice::<u64>(),
            via_builder.as_slice::<u64>()
        );
        assert_eq!(
            via_canonical.validity_mask().unwrap().boolean_buffer(),
            via_builder.validity_mask().unwrap().boolean_buffer()
        );
    }
}