1use std::fmt::Debug;
2
3use arrow_buffer::BooleanBuffer;
4use vortex_array::compute::{cast, take};
5use vortex_array::stats::{ArrayStats, StatsSetRef};
6use vortex_array::vtable::{ArrayVTable, CanonicalVTable, NotSupported, VTable, ValidityVTable};
7use vortex_array::{
8 Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable,
9};
10use vortex_dtype::{DType, match_each_integer_ptype};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
12use vortex_mask::{AllOr, Mask};
13
// Invokes `vortex_array`'s `vtable!` macro for the `Dict` encoding. `DictVTable`, which is used
// below but not defined anywhere in this view, is presumably generated by this macro — confirm
// against the macro definition.
vtable!(Dict);
15
16impl VTable for DictVTable {
17 type Array = DictArray;
18 type Encoding = DictEncoding;
19
20 type ArrayVTable = Self;
21 type CanonicalVTable = Self;
22 type OperationsVTable = Self;
23 type ValidityVTable = Self;
24 type VisitorVTable = Self;
25 type ComputeVTable = NotSupported;
26 type EncodeVTable = Self;
27 type SerdeVTable = Self;
28
29 fn id(_encoding: &Self::Encoding) -> EncodingId {
30 EncodingId::new_ref("vortex.dict")
31 }
32
33 fn encoding(_array: &Self::Array) -> EncodingRef {
34 EncodingRef::new_ref(DictEncoding.as_ref())
35 }
36}
37
/// A dictionary-encoded array: element `i` is the entry of `values` selected by `codes[i]`
/// (canonicalization performs `take(values, codes)`).
#[derive(Debug, Clone)]
pub struct DictArray {
    /// Unsigned-integer indices into `values`, one per logical element (`try_new` rejects
    /// non-unsigned codes). A null code makes the corresponding element null.
    codes: ArrayRef,
    /// Dictionary entries referenced by `codes`; their dtype is the dtype of the whole array.
    values: ArrayRef,
    /// Statistics storage for this array, exposed through `ArrayVTable::stats`.
    stats_set: ArrayStats,
}
44
/// Marker type for the dictionary encoding (encoding id `"vortex.dict"`).
#[derive(Clone, Debug)]
pub struct DictEncoding;
47
48impl DictArray {
49 pub fn try_new(mut codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
50 if !codes.dtype().is_unsigned_int() {
51 vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
52 }
53
54 let dtype = values.dtype();
55 if dtype.is_nullable() {
56 codes = cast(&codes, &codes.dtype().as_nullable())?;
58 } else {
59 if codes.dtype().is_nullable() {
61 vortex_bail!("Cannot have nullable codes for non-nullable dict array");
62 }
63 }
64 assert_eq!(
65 codes.dtype().nullability(),
66 values.dtype().nullability(),
67 "Mismatched nullability between codes and values"
68 );
69
70 Ok(Self {
71 codes,
72 values,
73 stats_set: Default::default(),
74 })
75 }
76
77 #[inline]
78 pub fn codes(&self) -> &ArrayRef {
79 &self.codes
80 }
81
82 #[inline]
83 pub fn values(&self) -> &ArrayRef {
84 &self.values
85 }
86}
87
88impl ArrayVTable<DictVTable> for DictVTable {
89 fn len(array: &DictArray) -> usize {
90 array.codes.len()
91 }
92
93 fn dtype(array: &DictArray) -> &DType {
94 array.values.dtype()
95 }
96
97 fn stats(array: &DictArray) -> StatsSetRef<'_> {
98 array.stats_set.to_ref(array.as_ref())
99 }
100}
101
102impl CanonicalVTable<DictVTable> for DictVTable {
103 fn canonicalize(array: &DictArray) -> VortexResult<Canonical> {
104 match array.dtype() {
105 DType::Utf8(_) | DType::Binary(_) => {
110 let canonical_values: ArrayRef = array.values().to_canonical()?.into_array();
111 take(&canonical_values, array.codes())?.to_canonical()
112 }
113 _ => take(array.values(), array.codes())?.to_canonical(),
114 }
115 }
116}
117
118impl ValidityVTable<DictVTable> for DictVTable {
119 fn is_valid(array: &DictArray, index: usize) -> VortexResult<bool> {
120 let scalar = array.codes().scalar_at(index).map_err(|err| {
121 err.with_context(format!("Failed to get index {index} from DictArray codes"))
122 })?;
123
124 if scalar.is_null() {
125 return Ok(false);
126 };
127 let values_index: usize = scalar
128 .as_ref()
129 .try_into()
130 .vortex_expect("Failed to convert dictionary code to usize");
131 array.values().is_valid(values_index)
132 }
133
134 fn all_valid(array: &DictArray) -> VortexResult<bool> {
135 Ok(array.codes().all_valid()? && array.values().all_valid()?)
136 }
137
138 fn all_invalid(array: &DictArray) -> VortexResult<bool> {
139 Ok(array.codes().all_invalid()? || array.values().all_invalid()?)
140 }
141
142 fn validity_mask(array: &DictArray) -> VortexResult<Mask> {
143 let codes_validity = array.codes().validity_mask()?;
144 match codes_validity.boolean_buffer() {
145 AllOr::All => {
146 let primitive_codes = array.codes().to_primitive()?;
147 let values_mask = array.values().validity_mask()?;
148 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
149 let codes_slice = primitive_codes.as_slice::<P>();
150 BooleanBuffer::collect_bool(array.len(), |idx| {
151 #[allow(clippy::cast_possible_truncation)]
152 values_mask.value(codes_slice[idx] as usize)
153 })
154 });
155 Ok(Mask::from_buffer(is_valid_buffer))
156 }
157 AllOr::None => Ok(Mask::AllFalse(array.len())),
158 AllOr::Some(validity_buff) => {
159 let primitive_codes = array.codes().to_primitive()?;
160 let values_mask = array.values().validity_mask()?;
161 let is_valid_buffer = match_each_integer_ptype!(primitive_codes.ptype(), |P| {
162 let codes_slice = primitive_codes.as_slice::<P>();
163 #[allow(clippy::cast_possible_truncation)]
164 BooleanBuffer::collect_bool(array.len(), |idx| {
165 validity_buff.value(idx) && values_mask.value(codes_slice[idx] as usize)
166 })
167 });
168 Ok(Mask::from_buffer(is_valid_buffer))
169 }
170 }
171 }
172}
173
// Unit tests for `DictArray` validity and canonicalization.
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rand::distr::{Distribution, StandardUniform};
    use rand::prelude::StdRng;
    use rand::{Rng, SeedableRng};
    use vortex_array::arrays::{ChunkedArray, PrimitiveArray};
    use vortex_array::builders::builder_with_capacity;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::{DType, NativePType, PType};
    use vortex_error::{VortexExpect, VortexUnwrap, vortex_panic};
    use vortex_mask::AllOr;

    use crate::DictArray;

    // Null codes make their elements null even though every value is valid:
    // codes validity [T, F, T, F, T] -> valid indices [0, 2, 4].
    #[test]
    fn nullable_codes_validity() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(buffer![3, 6, 9], Validity::AllValid).into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0, 2, 4]);
    }

    // All codes are valid, but only value 0 is valid, so only elements whose code is 0
    // (index 0) are valid.
    #[test]
    fn nullable_values_validity() {
        let dict = DictArray::try_new(
            buffer![0u32, 1, 2, 2, 1].into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BooleanBuffer::from(vec![true, false, false])),
            )
            .into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [0]);
    }

    // Both children have nulls: index 0 selects null value 0, and indices 1 and 3 have null
    // codes, leaving [2, 4] valid.
    #[test]
    fn nullable_codes_and_values() {
        let dict = DictArray::try_new(
            PrimitiveArray::new(
                buffer![0u32, 1, 2, 2, 1],
                Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])),
            )
            .into_array(),
            PrimitiveArray::new(
                buffer![3, 6, 9],
                Validity::from(BooleanBuffer::from(vec![false, true, true])),
            )
            .into_array(),
        )
        .unwrap();
        let mask = dict.validity_mask().unwrap();
        let AllOr::Some(indices) = mask.indices() else {
            vortex_panic!("Expected indices from mask")
        };
        assert_eq!(indices, [2, 4]);
    }

    // Builds a ChunkedArray of `chunk_count` DictArrays, each with `unique_values` random
    // values of type `T` and `len` random codes of type `U`. It uses a fixed seed, so the
    // output is deterministic; the order of RNG draws (values, then codes, per chunk) matters.
    fn make_dict_primitive_chunks<T: NativePType, U: NativePType>(
        len: usize,
        unique_values: usize,
        chunk_count: usize,
    ) -> ArrayRef
    where
        StandardUniform: Distribution<T>,
    {
        let mut rng = StdRng::seed_from_u64(0);

        (0..chunk_count)
            .map(|_| {
                let values = (0..unique_values)
                    .map(|_| rng.random::<T>())
                    .collect::<PrimitiveArray>();
                let codes = (0..len)
                    .map(|_| {
                        U::from(rng.random_range(0..unique_values)).vortex_expect("valid value")
                    })
                    .collect::<PrimitiveArray>();

                DictArray::try_new(codes.into_array(), values.into_array())
                    .vortex_unwrap()
                    .into_array()
            })
            .collect::<ChunkedArray>()
            .into_array()
    }

    // Canonicalizing via `to_primitive` must agree with appending into a builder, both in the
    // data and in the validity.
    #[test]
    fn test_dict_array_from_primitive_chunks() {
        let len = 2;
        let chunk_count = 2;
        let array = make_dict_primitive_chunks::<u64, u64>(len, 2, chunk_count);

        let mut builder = builder_with_capacity(
            &DType::Primitive(PType::U64, NonNullable),
            len * chunk_count,
        );
        array
            .clone()
            .append_to_builder(builder.as_mut())
            .vortex_unwrap();

        let into_prim = array.to_primitive().unwrap();
        let prim_into = builder.finish().to_primitive().unwrap();

        assert_eq!(into_prim.as_slice::<u64>(), prim_into.as_slice::<u64>());
        assert_eq!(
            into_prim.validity_mask().unwrap().boolean_buffer(),
            prim_into.validity_mask().unwrap().boolean_buffer()
        )
    }
}
303}