1use std::fmt::Debug;
5
6use arrow_buffer::BooleanBuffer;
7use vortex_array::compute::{cast, take};
8use vortex_array::stats::{ArrayStats, StatsSetRef};
9use vortex_array::vtable::{ArrayVTable, CanonicalVTable, NotSupported, VTable, ValidityVTable};
10use vortex_array::{
11 Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable,
12};
13use vortex_dtype::{DType, match_each_integer_ptype};
14use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
15use vortex_mask::{AllOr, Mask};
16
17vtable!(Dict);
18
19impl VTable for DictVTable {
20 type Array = DictArray;
21 type Encoding = DictEncoding;
22
23 type ArrayVTable = Self;
24 type CanonicalVTable = Self;
25 type OperationsVTable = Self;
26 type ValidityVTable = Self;
27 type VisitorVTable = Self;
28 type ComputeVTable = NotSupported;
29 type EncodeVTable = Self;
30 type SerdeVTable = Self;
31
32 fn id(_encoding: &Self::Encoding) -> EncodingId {
33 EncodingId::new_ref("vortex.dict")
34 }
35
36 fn encoding(_array: &Self::Array) -> EncodingRef {
37 EncodingRef::new_ref(DictEncoding.as_ref())
38 }
39}
40
41#[derive(Debug, Clone)]
42pub struct DictArray {
43 codes: ArrayRef,
44 values: ArrayRef,
45 stats_set: ArrayStats,
46}
47
/// Stateless marker type for the dictionary encoding; `DictVTable::encoding`
/// hands out a reference to it.
#[derive(Debug, Clone)]
pub struct DictEncoding;
50
51impl DictArray {
52 pub fn try_new(mut codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
53 if !codes.dtype().is_unsigned_int() {
54 vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
55 }
56
57 let dtype = values.dtype();
58 if dtype.is_nullable() {
59 codes = cast(&codes, &codes.dtype().as_nullable())?;
61 } else {
62 if codes.dtype().is_nullable() {
64 vortex_bail!("Cannot have nullable codes for non-nullable dict array");
65 }
66 }
67 assert_eq!(
68 codes.dtype().nullability(),
69 values.dtype().nullability(),
70 "Mismatched nullability between codes and values"
71 );
72
73 Ok(Self {
74 codes,
75 values,
76 stats_set: Default::default(),
77 })
78 }
79
80 #[inline]
81 pub fn codes(&self) -> &ArrayRef {
82 &self.codes
83 }
84
85 #[inline]
86 pub fn values(&self) -> &ArrayRef {
87 &self.values
88 }
89}
90
91impl ArrayVTable<DictVTable> for DictVTable {
92 fn len(array: &DictArray) -> usize {
93 array.codes.len()
94 }
95
96 fn dtype(array: &DictArray) -> &DType {
97 array.values.dtype()
98 }
99
100 fn stats(array: &DictArray) -> StatsSetRef<'_> {
101 array.stats_set.to_ref(array.as_ref())
102 }
103}
104
105impl CanonicalVTable<DictVTable> for DictVTable {
106 fn canonicalize(array: &DictArray) -> VortexResult<Canonical> {
107 match array.dtype() {
108 DType::Utf8(_) | DType::Binary(_) => {
113 let canonical_values: ArrayRef = array.values().to_canonical()?.into_array();
114 take(&canonical_values, array.codes())?.to_canonical()
115 }
116 _ => take(array.values(), array.codes())?.to_canonical(),
117 }
118 }
119}
120
121impl ValidityVTable<DictVTable> for DictVTable {
122 fn is_valid(array: &DictArray, index: usize) -> VortexResult<bool> {
123 let scalar = array.codes().scalar_at(index).map_err(|err| {
124 err.with_context(format!("Failed to get index {index} from DictArray codes"))
125 })?;
126
127 if scalar.is_null() {
128 return Ok(false);
129 };
130 let values_index: usize = scalar
131 .as_ref()
132 .try_into()
133 .vortex_expect("Failed to convert dictionary code to usize");
134 array.values().is_valid(values_index)
135 }
136
137 fn all_valid(array: &DictArray) -> VortexResult<bool> {
138 Ok(array.codes().all_valid()? && array.values().all_valid()?)
139 }
140
141 fn all_invalid(array: &DictArray) -> VortexResult<bool> {
142 Ok(array.codes().all_invalid()? || array.values().all_invalid()?)
143 }
144
145 fn validity_mask(array: &DictArray) -> VortexResult<Mask> {
146 let codes_validity = array.codes().validity_mask()?;
147 match codes_validity.boolean_buffer() {
148 AllOr::All => {
149 let primitive_codes = array.codes().to_primitive()?;
150 let values_mask = array.values().validity_mask()?;
151 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
152 let codes_slice = primitive_codes.as_slice::<P>();
153 BooleanBuffer::collect_bool(array.len(), |idx| {
154 #[allow(clippy::cast_possible_truncation)]
155 values_mask.value(codes_slice[idx] as usize)
156 })
157 });
158 Ok(Mask::from_buffer(is_valid_buffer))
159 }
160 AllOr::None => Ok(Mask::AllFalse(array.len())),
161 AllOr::Some(validity_buff) => {
162 let primitive_codes = array.codes().to_primitive()?;
163 let values_mask = array.values().validity_mask()?;
164 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
165 let codes_slice = primitive_codes.as_slice::<P>();
166 #[allow(clippy::cast_possible_truncation)]
167 BooleanBuffer::collect_bool(array.len(), |idx| {
168 validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
169 })
170 });
171 Ok(Mask::from_buffer(is_valid_buffer))
172 }
173 }
174 }
175}
176
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rand::distr::{Distribution, StandardUniform};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};
    use vortex_array::arrays::{ChunkedArray, PrimitiveArray};
    use vortex_array::builders::builder_with_capacity;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, NativePType, PType};
    use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
    use vortex_mask::AllOr;

    use crate::DictArray;

    /// Null codes over fully valid values: only positions with a valid code
    /// are reported as valid.
    #[test]
    fn nullable_codes_validity() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
        )
        .into_array();
        let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0, 2, 4]);
    }

    /// Codes without a validity buffer over partially null values: only
    /// positions whose code points at a valid value are reported as valid.
    #[test]
    fn nullable_values_validity() {
        let codes = buffer![0u32, 1, 2, 2, 1].into_array();
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            Validity::from(BooleanBuffer::from(vec![true, false, false])),
        )
        .into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0]);
    }

    /// Nulls in both children: a position is valid only when its code is valid
    /// and the referenced value is valid too.
    #[test]
    fn nullable_codes_and_values() {
        let codes = PrimitiveArray::new(
            buffer![0u32, 1, 2, 2, 1],
            Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
        )
        .into_array();
        let values = PrimitiveArray::new(
            buffer![3, 6, 9],
            Validity::from(BooleanBuffer::from(vec![false, true, true])),
        )
        .into_array();
        let dict = DictArray::try_new(codes, values).unwrap();

        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [2, 4]);
    }

    /// Builds a chunked array of `chunk_count` dictionary arrays, each with
    /// `unique_values` random values of type `T` and `len` random codes of
    /// type `U`, using a fixed RNG seed so the output is reproducible.
    fn make_dict_primitive_chunks<T: NativePType, U: NativePType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);
        let mut chunks: Vec<ArrayRef> = Vec::with_capacity(chunk_count);

        for _ in 0..chunk_count {
            // Values are drawn before codes so the RNG stream is consumed in a
            // fixed order for every chunk.
            let values = (0..unique_values)
                .map(|_| rng.random::<T>())
                .collect::<PrimitiveArray>();
            let codes = (0..len)
                .map(|_| {
                    U::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                })
                .collect::<PrimitiveArray>();

            let dict =
                DictArray::try_new(codes.into_array(), values.into_array()).vortex_unwrap();
            chunks.push(dict.into_array());
        }

        chunks.into_iter().collect::<ChunkedArray>().into_array()
    }

    /// Appending the chunked dictionary array to a builder must produce the
    /// same primitive data and validity as canonicalizing it directly.
    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let len = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let target_dtype = DType::Primitive(PType::U64, NonNullable);
        let mut builder = builder_with_capacity(&target_dtype, len * chunk_count);
        array
            .clone()
            .append_to_builder(builder.as_mut())
            .vortex_unwrap();

        let canonical = array.to_primitive().unwrap();
        let built = builder.finish().to_primitive().unwrap();

        assert_eq!(canonical.as_slice::<u64>(), built.as_slice::<u64>());

        let canonical_mask = canonical.validity_mask().unwrap();
        let built_mask = built.validity_mask().unwrap();
        assert_eq!(canonical_mask.boolean_buffer(), built_mask.boolean_buffer());
    }
}