vortex_array/arrays/dict/compute/
mod.rs1mod cast;
5mod compare;
6mod fill_null;
7mod is_constant;
8mod is_sorted;
9mod like;
10mod min_max;
11
12use vortex_error::VortexResult;
13use vortex_mask::Mask;
14
15use super::{DictArray, DictVTable};
16use crate::compute::{
17 FilterKernel, FilterKernelAdapter, TakeKernel, TakeKernelAdapter, filter, take,
18};
19use crate::{Array, ArrayRef, IntoArray, register_kernel};
20
21impl TakeKernel for DictVTable {
22 fn take(&self, array: &DictArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
23 let codes = take(array.codes(), indices)?;
24 Ok(unsafe { DictArray::new_unchecked(codes, array.values().clone()) }.into_array())
26 }
27}
28
29register_kernel!(TakeKernelAdapter(DictVTable).lift());
30
31impl FilterKernel for DictVTable {
32 fn filter(&self, array: &DictArray, mask: &Mask) -> VortexResult<ArrayRef> {
33 let codes = filter(array.codes(), mask)?;
34
35 unsafe { Ok(DictArray::new_unchecked(codes, array.values().clone()).into_array()) }
37 }
38}
39
40register_kernel!(FilterKernelAdapter(DictVTable).lift());
41
42#[cfg(test)]
43mod test {
44 #[allow(unused_imports)]
45 use itertools::Itertools;
46 use vortex_buffer::buffer;
47 use vortex_dtype::PType::I32;
48 use vortex_dtype::{DType, Nullability};
49
50 use crate::accessor::ArrayAccessor;
51 use crate::arrays::{ConstantArray, PrimitiveArray, VarBinArray, VarBinViewArray};
52 use crate::builders::dict::dict_encode;
53 use crate::compute::conformance::filter::test_filter_conformance;
54 use crate::compute::conformance::mask::test_mask_conformance;
55 use crate::compute::conformance::take::test_take_conformance;
56 use crate::compute::{Operator, compare, take};
57 use crate::{Array, ArrayRef, IntoArray, ToCanonical, assert_arrays_eq};
58
59 #[test]
60 fn canonicalise_nullable_primitive() {
61 let values: Vec<Option<i32>> = (0..65)
62 .map(|i| match i % 3 {
63 0 => Some(42),
64 1 => Some(-9),
65 2 => None,
66 _ => unreachable!(),
67 })
68 .collect();
69
70 let dict = dict_encode(PrimitiveArray::from_option_iter(values.clone()).as_ref()).unwrap();
71 let actual = dict.to_primitive();
72
73 let expected = PrimitiveArray::from_option_iter(values);
74
75 assert_arrays_eq!(actual, expected);
76 }
77
78 #[test]
79 fn canonicalise_non_nullable_primitive_32_unique_values() {
80 let unique_values: Vec<i32> = (0..32).collect();
81 let expected = PrimitiveArray::from_iter((0..1000).map(|i| unique_values[i % 32]));
82
83 let dict = dict_encode(expected.as_ref()).unwrap();
84 let actual = dict.to_primitive();
85
86 assert_arrays_eq!(actual, expected);
87 }
88
89 #[test]
90 fn canonicalise_non_nullable_primitive_100_unique_values() {
91 let unique_values: Vec<i32> = (0..100).collect();
92 let expected = PrimitiveArray::from_iter((0..1000).map(|i| unique_values[i % 100]));
93
94 let dict = dict_encode(expected.as_ref()).unwrap();
95 let actual = dict.to_primitive();
96
97 assert_arrays_eq!(actual, expected);
98 }
99
100 #[test]
101 fn canonicalise_nullable_varbin() {
102 let reference = VarBinViewArray::from_iter(
103 vec![Some("a"), Some("b"), None, Some("a"), None, Some("b")],
104 DType::Utf8(Nullability::Nullable),
105 );
106 assert_eq!(reference.len(), 6);
107 let dict = dict_encode(reference.as_ref()).unwrap();
108 let flattened_dict = dict.to_varbinview();
109 assert_eq!(
110 flattened_dict.with_iterator(|iter| iter
111 .map(|slice| slice.map(|s| s.to_vec()))
112 .collect::<Vec<_>>()),
113 reference.with_iterator(|iter| iter
114 .map(|slice| slice.map(|s| s.to_vec()))
115 .collect::<Vec<_>>()),
116 );
117 }
118
119 fn sliced_dict_array() -> ArrayRef {
120 let reference = PrimitiveArray::from_option_iter([
121 Some(42),
122 Some(-9),
123 None,
124 Some(42),
125 Some(1),
126 Some(5),
127 ]);
128 let dict = dict_encode(reference.as_ref()).unwrap();
129 dict.slice(1..4)
130 }
131
132 #[test]
133 fn compare_sliced_dict() {
134 use crate::arrays::BoolArray;
135 let sliced = sliced_dict_array();
136 let compared = compare(&sliced, ConstantArray::new(42, 3).as_ref(), Operator::Eq).unwrap();
137
138 let expected = BoolArray::from_iter([Some(false), None, Some(true)]);
139 assert_arrays_eq!(compared, expected.to_array());
140 }
141
142 #[test]
143 fn test_mask_dict_array() {
144 let array = dict_encode(&buffer![2, 0, 2, 0, 10].into_array()).unwrap();
145 test_mask_conformance(array.as_ref());
146
147 let array = dict_encode(
148 PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)]).as_ref(),
149 )
150 .unwrap();
151 test_mask_conformance(array.as_ref());
152
153 let array = dict_encode(
154 &VarBinArray::from_iter(
155 [
156 Some("hello"),
157 None,
158 Some("hello"),
159 Some("good"),
160 Some("good"),
161 ],
162 DType::Utf8(Nullability::Nullable),
163 )
164 .into_array(),
165 )
166 .unwrap();
167 test_mask_conformance(array.as_ref());
168 }
169
170 #[test]
171 fn test_filter_dict_array() {
172 let array = dict_encode(&buffer![2, 0, 2, 0, 10].into_array()).unwrap();
173 test_filter_conformance(array.as_ref());
174
175 let array = dict_encode(
176 PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)]).as_ref(),
177 )
178 .unwrap();
179 test_filter_conformance(array.as_ref());
180
181 let array = dict_encode(
182 &VarBinArray::from_iter(
183 [
184 Some("hello"),
185 None,
186 Some("hello"),
187 Some("good"),
188 Some("good"),
189 ],
190 DType::Utf8(Nullability::Nullable),
191 )
192 .into_array(),
193 )
194 .unwrap();
195 test_filter_conformance(array.as_ref());
196 }
197
198 #[test]
199 fn test_take_dict() {
200 let array = dict_encode(buffer![1, 2].into_array().as_ref()).unwrap();
201
202 assert_eq!(
203 take(
204 array.as_ref(),
205 PrimitiveArray::from_option_iter([Option::<i32>::None]).as_ref()
206 )
207 .unwrap()
208 .dtype(),
209 &DType::Primitive(I32, Nullability::Nullable)
210 );
211 }
212
213 #[test]
214 fn test_take_dict_conformance() {
215 let array = dict_encode(&buffer![2, 0, 2, 0, 10].into_array()).unwrap();
216 test_take_conformance(array.as_ref());
217
218 let array = dict_encode(
219 PrimitiveArray::from_option_iter([Some(2), None, Some(2), Some(0), Some(10)]).as_ref(),
220 )
221 .unwrap();
222 test_take_conformance(array.as_ref());
223
224 let array = dict_encode(
225 &VarBinArray::from_iter(
226 [
227 Some("hello"),
228 None,
229 Some("hello"),
230 Some("good"),
231 Some("good"),
232 ],
233 DType::Utf8(Nullability::Nullable),
234 )
235 .into_array(),
236 )
237 .unwrap();
238 test_take_conformance(array.as_ref());
239 }
240}
241
242#[cfg(test)]
243mod tests {
244 use rstest::rstest;
245 use vortex_buffer::buffer;
246 use vortex_dtype::{DType, Nullability};
247
248 use crate::IntoArray;
249 use crate::arrays::dict::DictArray;
250 use crate::arrays::{PrimitiveArray, VarBinArray};
251 use crate::builders::dict::dict_encode;
252 use crate::compute::conformance::consistency::test_array_consistency;
253
254 #[rstest]
255 #[case::dict_i32(dict_encode(&buffer![1i32, 2, 3, 2, 1].into_array()).unwrap())]
257 #[case::dict_nullable_codes(DictArray::try_new(
258 buffer![0u32, 1, 2, 2, 0].into_array(),
259 PrimitiveArray::from_option_iter([Some(10), Some(20), None]).into_array(),
260 ).unwrap())]
261 #[case::dict_nullable_values(dict_encode(
262 PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None]).as_ref()
263 ).unwrap())]
264 #[case::dict_u64(dict_encode(&buffer![100u64, 200, 100, 300, 200].into_array()).unwrap())]
265 #[case::dict_str(dict_encode(
267 &VarBinArray::from_iter(
268 ["hello", "world", "hello", "test", "world"].map(Some),
269 DType::Utf8(Nullability::NonNullable),
270 ).into_array()
271 ).unwrap())]
272 #[case::dict_nullable_str(dict_encode(
273 &VarBinArray::from_iter(
274 [Some("hello"), None, Some("world"), Some("hello"), None],
275 DType::Utf8(Nullability::Nullable),
276 ).into_array()
277 ).unwrap())]
278 #[case::dict_single(dict_encode(&buffer![42i32].into_array()).unwrap())]
280 #[case::dict_all_same(dict_encode(&buffer![5i32, 5, 5, 5, 5].into_array()).unwrap())]
281 #[case::dict_large(dict_encode(&PrimitiveArray::from_iter((0..1000).map(|i| i % 10)).into_array()).unwrap())]
282 fn test_dict_consistency(#[case] array: DictArray) {
283 test_array_consistency(array.as_ref());
284 }
285}