vortex_array/compute/conformance/
take.rs1use vortex_buffer::buffer;
5use vortex_error::VortexExpect;
6
7use crate::Array;
8use crate::ArrayRef;
9use crate::Canonical;
10use crate::IntoArray as _;
11use crate::arrays::PrimitiveArray;
12use crate::dtype::Nullability;
13
14pub fn test_take_conformance(array: &ArrayRef) {
24 let len = array.len();
25
26 if len > 0 {
27 test_take_all(array);
28 test_take_none(array);
29 test_take_selective(array);
30 test_take_first_and_last(array);
31 test_take_with_nullable_indices(array);
32 test_take_repeated_indices(array);
33 }
34
35 test_empty_indices(array);
36
37 if len > 0 {
39 test_take_reverse(array);
40 test_take_single_middle(array);
41 }
42
43 if len > 3 {
44 test_take_random_unsorted(array);
45 test_take_contiguous_range(array);
46 test_take_mixed_repeated(array);
47 }
48
49 if len >= 1024 {
51 test_take_large_indices(array);
52 }
53}
54
55fn test_take_all(array: &ArrayRef) {
56 let len = array.len();
57 let indices = PrimitiveArray::from_iter(0..len as u64);
58 let result = array
59 .take(indices.to_array())
60 .vortex_expect("take should succeed in conformance test");
61
62 assert_eq!(result.len(), len);
63 assert_eq!(result.dtype(), array.dtype());
64
65 match (
67 array
68 .to_canonical()
69 .vortex_expect("to_canonical failed on array"),
70 result
71 .to_canonical()
72 .vortex_expect("to_canonical failed on result"),
73 ) {
74 (Canonical::Primitive(orig_prim), Canonical::Primitive(result_prim)) => {
75 assert_eq!(
76 orig_prim.buffer_handle().to_host_sync(),
77 result_prim.buffer_handle().to_host_sync()
78 );
79 }
80 _ => {
81 for i in 0..len {
83 assert_eq!(
84 array
85 .scalar_at(i)
86 .vortex_expect("scalar_at should succeed in conformance test"),
87 result
88 .scalar_at(i)
89 .vortex_expect("scalar_at should succeed in conformance test")
90 );
91 }
92 }
93 }
94}
95
96fn test_take_none(array: &ArrayRef) {
97 let indices: PrimitiveArray = PrimitiveArray::from_iter::<[u64; 0]>([]);
98 let result = array
99 .take(indices.to_array())
100 .vortex_expect("take should succeed in conformance test");
101
102 assert_eq!(result.len(), 0);
103 assert_eq!(result.dtype(), array.dtype());
104}
105
106#[allow(clippy::cast_possible_truncation)]
107fn test_take_selective(array: &ArrayRef) {
108 let len = array.len();
109
110 let indices: Vec<u64> = (0..len as u64).step_by(2).collect();
112 let expected_len = indices.len();
113 let indices_array = PrimitiveArray::from_iter(indices.clone());
114
115 let result = array
116 .take(indices_array.to_array())
117 .vortex_expect("take should succeed in conformance test");
118 assert_eq!(result.len(), expected_len);
119
120 for (result_idx, &original_idx) in indices.iter().enumerate() {
122 assert_eq!(
123 array
124 .scalar_at(original_idx as usize)
125 .vortex_expect("scalar_at should succeed in conformance test"),
126 result
127 .scalar_at(result_idx)
128 .vortex_expect("scalar_at should succeed in conformance test")
129 );
130 }
131}
132
133fn test_take_first_and_last(array: &ArrayRef) {
134 let len = array.len();
135 let indices = PrimitiveArray::from_iter([0u64, (len - 1) as u64]);
136 let result = array
137 .take(indices.to_array())
138 .vortex_expect("take should succeed in conformance test");
139
140 assert_eq!(result.len(), 2);
141 assert_eq!(
142 array
143 .scalar_at(0)
144 .vortex_expect("scalar_at should succeed in conformance test"),
145 result
146 .scalar_at(0)
147 .vortex_expect("scalar_at should succeed in conformance test")
148 );
149 assert_eq!(
150 array
151 .scalar_at(len - 1)
152 .vortex_expect("scalar_at should succeed in conformance test"),
153 result
154 .scalar_at(1)
155 .vortex_expect("scalar_at should succeed in conformance test")
156 );
157}
158
159#[allow(clippy::cast_possible_truncation)]
160fn test_take_with_nullable_indices(array: &ArrayRef) {
161 let len = array.len();
162
163 let indices_vec: Vec<Option<u64>> = if len >= 3 {
165 vec![Some(0), None, Some((len - 1) as u64)]
166 } else if len >= 2 {
167 vec![Some(0), None]
168 } else {
169 vec![None]
170 };
171
172 let indices = PrimitiveArray::from_option_iter(indices_vec.clone());
173 let result = array
174 .take(indices.to_array())
175 .vortex_expect("take should succeed in conformance test");
176
177 assert_eq!(result.len(), indices_vec.len());
178 assert_eq!(
179 result.dtype(),
180 &array.dtype().with_nullability(Nullability::Nullable)
181 );
182
183 for (i, idx_opt) in indices_vec.iter().enumerate() {
185 match idx_opt {
186 Some(idx) => {
187 let expected = array
188 .scalar_at(*idx as usize)
189 .vortex_expect("scalar_at should succeed in conformance test");
190 let actual = result
191 .scalar_at(i)
192 .vortex_expect("scalar_at should succeed in conformance test");
193 assert_eq!(expected, actual);
194 }
195 None => {
196 assert!(
197 result
198 .scalar_at(i)
199 .vortex_expect("scalar_at should succeed in conformance test")
200 .is_null()
201 );
202 }
203 }
204 }
205}
206
207fn test_take_repeated_indices(array: &ArrayRef) {
208 if array.is_empty() {
209 return;
210 }
211
212 let indices = buffer![0u64, 0, 0].into_array();
214 let result = array
215 .take(indices.to_array())
216 .vortex_expect("take should succeed in conformance test");
217
218 assert_eq!(result.len(), 3);
219 let first_elem = array
220 .scalar_at(0)
221 .vortex_expect("scalar_at should succeed in conformance test");
222 for i in 0..3 {
223 assert_eq!(
224 result
225 .scalar_at(i)
226 .vortex_expect("scalar_at should succeed in conformance test"),
227 first_elem
228 );
229 }
230}
231
232fn test_empty_indices(array: &ArrayRef) {
233 let indices = PrimitiveArray::empty::<u64>(Nullability::NonNullable);
234 let result = array
235 .take(indices.to_array())
236 .vortex_expect("take should succeed in conformance test");
237
238 assert_eq!(result.len(), 0);
239 assert_eq!(result.dtype(), array.dtype());
240}
241
242fn test_take_reverse(array: &ArrayRef) {
243 let len = array.len();
244 let indices = PrimitiveArray::from_iter((0..len as u64).rev());
246 let result = array
247 .take(indices.to_array())
248 .vortex_expect("take should succeed in conformance test");
249
250 assert_eq!(result.len(), len);
251
252 for i in 0..len {
254 assert_eq!(
255 array
256 .scalar_at(len - 1 - i)
257 .vortex_expect("scalar_at should succeed in conformance test"),
258 result
259 .scalar_at(i)
260 .vortex_expect("scalar_at should succeed in conformance test")
261 );
262 }
263}
264
265fn test_take_single_middle(array: &ArrayRef) {
266 let len = array.len();
267 let middle_idx = len / 2;
268
269 let indices = PrimitiveArray::from_iter([middle_idx as u64]);
270 let result = array
271 .take(indices.to_array())
272 .vortex_expect("take should succeed in conformance test");
273
274 assert_eq!(result.len(), 1);
275 assert_eq!(
276 array
277 .scalar_at(middle_idx)
278 .vortex_expect("scalar_at should succeed in conformance test"),
279 result
280 .scalar_at(0)
281 .vortex_expect("scalar_at should succeed in conformance test")
282 );
283}
284
285#[allow(clippy::cast_possible_truncation)]
286fn test_take_random_unsorted(array: &ArrayRef) {
287 let len = array.len();
288
289 let mut indices = Vec::new();
291 let mut idx = 1u64;
292 for _ in 0..len.min(10) {
293 indices.push((idx * 7 + 3) % len as u64);
294 idx = (idx * 3 + 1) % len as u64;
295 }
296
297 let indices_array = PrimitiveArray::from_iter(indices.clone());
298 let result = array
299 .take(indices_array.to_array())
300 .vortex_expect("take should succeed in conformance test");
301
302 assert_eq!(result.len(), indices.len());
303
304 for (i, &idx) in indices.iter().enumerate() {
306 assert_eq!(
307 array
308 .scalar_at(idx as usize)
309 .vortex_expect("scalar_at should succeed in conformance test"),
310 result
311 .scalar_at(i)
312 .vortex_expect("scalar_at should succeed in conformance test")
313 );
314 }
315}
316
317fn test_take_contiguous_range(array: &ArrayRef) {
318 let len = array.len();
319 let start = len / 4;
320 let end = len / 2;
321
322 let indices = PrimitiveArray::from_iter(start as u64..end as u64);
324 let result = array
325 .take(indices.to_array())
326 .vortex_expect("take should succeed in conformance test");
327
328 assert_eq!(result.len(), end - start);
329
330 for i in 0..(end - start) {
332 assert_eq!(
333 array
334 .scalar_at(start + i)
335 .vortex_expect("scalar_at should succeed in conformance test"),
336 result
337 .scalar_at(i)
338 .vortex_expect("scalar_at should succeed in conformance test")
339 );
340 }
341}
342
343#[allow(clippy::cast_possible_truncation)]
344fn test_take_mixed_repeated(array: &ArrayRef) {
345 let len = array.len();
346
347 let indices = vec![
349 0u64,
350 0,
351 1,
352 1,
353 len as u64 / 2,
354 len as u64 / 2,
355 len as u64 / 2,
356 (len - 1) as u64,
357 ];
358
359 let indices_array = PrimitiveArray::from_iter(indices.clone());
360 let result = array
361 .take(indices_array.to_array())
362 .vortex_expect("take should succeed in conformance test");
363
364 assert_eq!(result.len(), indices.len());
365
366 for (i, &idx) in indices.iter().enumerate() {
368 assert_eq!(
369 array
370 .scalar_at(idx as usize)
371 .vortex_expect("scalar_at should succeed in conformance test"),
372 result
373 .scalar_at(i)
374 .vortex_expect("scalar_at should succeed in conformance test")
375 );
376 }
377}
378
379#[allow(clippy::cast_possible_truncation)]
380fn test_take_large_indices(array: &ArrayRef) {
381 let len = array.len();
383 let num_indices = 10000.min(len * 3);
384
385 let indices: Vec<u64> = (0..num_indices)
387 .map(|i| ((i * 17 + 5) % len) as u64)
388 .collect();
389
390 let indices_array = PrimitiveArray::from_iter(indices.clone());
391 let result = array
392 .take(indices_array.to_array())
393 .vortex_expect("take should succeed in conformance test");
394
395 assert_eq!(result.len(), num_indices);
396
397 for i in (0..num_indices).step_by(1000) {
399 let expected_idx = indices[i] as usize;
400 assert_eq!(
401 array
402 .scalar_at(expected_idx)
403 .vortex_expect("scalar_at should succeed in conformance test"),
404 result
405 .scalar_at(i)
406 .vortex_expect("scalar_at should succeed in conformance test")
407 );
408 }
409}