vortex_array/compute/conformance/
take.rs1use vortex_buffer::buffer;
5use vortex_dtype::Nullability;
6use vortex_error::VortexExpect;
7
8use crate::Array;
9use crate::Canonical;
10use crate::IntoArray as _;
11use crate::arrays::PrimitiveArray;
12
13pub fn test_take_conformance(array: &dyn Array) {
23 let len = array.len();
24
25 if len > 0 {
26 test_take_all(array);
27 test_take_none(array);
28 test_take_selective(array);
29 test_take_first_and_last(array);
30 test_take_with_nullable_indices(array);
31 test_take_repeated_indices(array);
32 }
33
34 test_empty_indices(array);
35
36 if len > 0 {
38 test_take_reverse(array);
39 test_take_single_middle(array);
40 }
41
42 if len > 3 {
43 test_take_random_unsorted(array);
44 test_take_contiguous_range(array);
45 test_take_mixed_repeated(array);
46 }
47
48 if len >= 1024 {
50 test_take_large_indices(array);
51 }
52}
53
54fn test_take_all(array: &dyn Array) {
55 let len = array.len();
56 let indices = PrimitiveArray::from_iter(0..len as u64);
57 let result = array
58 .take(indices.to_array())
59 .vortex_expect("take should succeed in conformance test");
60
61 assert_eq!(result.len(), len);
62 assert_eq!(result.dtype(), array.dtype());
63
64 match (
66 array
67 .to_canonical()
68 .vortex_expect("to_canonical failed on array"),
69 result
70 .to_canonical()
71 .vortex_expect("to_canonical failed on result"),
72 ) {
73 (Canonical::Primitive(orig_prim), Canonical::Primitive(result_prim)) => {
74 assert_eq!(
75 orig_prim.buffer_handle().to_host_sync(),
76 result_prim.buffer_handle().to_host_sync()
77 );
78 }
79 _ => {
80 for i in 0..len {
82 assert_eq!(
83 array
84 .scalar_at(i)
85 .vortex_expect("scalar_at should succeed in conformance test"),
86 result
87 .scalar_at(i)
88 .vortex_expect("scalar_at should succeed in conformance test")
89 );
90 }
91 }
92 }
93}
94
95fn test_take_none(array: &dyn Array) {
96 let indices: PrimitiveArray = PrimitiveArray::from_iter::<[u64; 0]>([]);
97 let result = array
98 .take(indices.to_array())
99 .vortex_expect("take should succeed in conformance test");
100
101 assert_eq!(result.len(), 0);
102 assert_eq!(result.dtype(), array.dtype());
103}
104
105#[allow(clippy::cast_possible_truncation)]
106fn test_take_selective(array: &dyn Array) {
107 let len = array.len();
108
109 let indices: Vec<u64> = (0..len as u64).step_by(2).collect();
111 let expected_len = indices.len();
112 let indices_array = PrimitiveArray::from_iter(indices.clone());
113
114 let result = array
115 .take(indices_array.to_array())
116 .vortex_expect("take should succeed in conformance test");
117 assert_eq!(result.len(), expected_len);
118
119 for (result_idx, &original_idx) in indices.iter().enumerate() {
121 assert_eq!(
122 array
123 .scalar_at(original_idx as usize)
124 .vortex_expect("scalar_at should succeed in conformance test"),
125 result
126 .scalar_at(result_idx)
127 .vortex_expect("scalar_at should succeed in conformance test")
128 );
129 }
130}
131
132fn test_take_first_and_last(array: &dyn Array) {
133 let len = array.len();
134 let indices = PrimitiveArray::from_iter([0u64, (len - 1) as u64]);
135 let result = array
136 .take(indices.to_array())
137 .vortex_expect("take should succeed in conformance test");
138
139 assert_eq!(result.len(), 2);
140 assert_eq!(
141 array
142 .scalar_at(0)
143 .vortex_expect("scalar_at should succeed in conformance test"),
144 result
145 .scalar_at(0)
146 .vortex_expect("scalar_at should succeed in conformance test")
147 );
148 assert_eq!(
149 array
150 .scalar_at(len - 1)
151 .vortex_expect("scalar_at should succeed in conformance test"),
152 result
153 .scalar_at(1)
154 .vortex_expect("scalar_at should succeed in conformance test")
155 );
156}
157
158#[allow(clippy::cast_possible_truncation)]
159fn test_take_with_nullable_indices(array: &dyn Array) {
160 let len = array.len();
161
162 let indices_vec: Vec<Option<u64>> = if len >= 3 {
164 vec![Some(0), None, Some((len - 1) as u64)]
165 } else if len >= 2 {
166 vec![Some(0), None]
167 } else {
168 vec![None]
169 };
170
171 let indices = PrimitiveArray::from_option_iter(indices_vec.clone());
172 let result = array
173 .take(indices.to_array())
174 .vortex_expect("take should succeed in conformance test");
175
176 assert_eq!(result.len(), indices_vec.len());
177 assert_eq!(
178 result.dtype(),
179 &array.dtype().with_nullability(Nullability::Nullable)
180 );
181
182 for (i, idx_opt) in indices_vec.iter().enumerate() {
184 match idx_opt {
185 Some(idx) => {
186 let expected = array
187 .scalar_at(*idx as usize)
188 .vortex_expect("scalar_at should succeed in conformance test");
189 let actual = result
190 .scalar_at(i)
191 .vortex_expect("scalar_at should succeed in conformance test");
192 assert_eq!(expected, actual);
193 }
194 None => {
195 assert!(
196 result
197 .scalar_at(i)
198 .vortex_expect("scalar_at should succeed in conformance test")
199 .is_null()
200 );
201 }
202 }
203 }
204}
205
206fn test_take_repeated_indices(array: &dyn Array) {
207 if array.is_empty() {
208 return;
209 }
210
211 let indices = buffer![0u64, 0, 0].into_array();
213 let result = array
214 .take(indices.to_array())
215 .vortex_expect("take should succeed in conformance test");
216
217 assert_eq!(result.len(), 3);
218 let first_elem = array
219 .scalar_at(0)
220 .vortex_expect("scalar_at should succeed in conformance test");
221 for i in 0..3 {
222 assert_eq!(
223 result
224 .scalar_at(i)
225 .vortex_expect("scalar_at should succeed in conformance test"),
226 first_elem
227 );
228 }
229}
230
231fn test_empty_indices(array: &dyn Array) {
232 let indices = PrimitiveArray::empty::<u64>(Nullability::NonNullable);
233 let result = array
234 .take(indices.to_array())
235 .vortex_expect("take should succeed in conformance test");
236
237 assert_eq!(result.len(), 0);
238 assert_eq!(result.dtype(), array.dtype());
239}
240
241fn test_take_reverse(array: &dyn Array) {
242 let len = array.len();
243 let indices = PrimitiveArray::from_iter((0..len as u64).rev());
245 let result = array
246 .take(indices.to_array())
247 .vortex_expect("take should succeed in conformance test");
248
249 assert_eq!(result.len(), len);
250
251 for i in 0..len {
253 assert_eq!(
254 array
255 .scalar_at(len - 1 - i)
256 .vortex_expect("scalar_at should succeed in conformance test"),
257 result
258 .scalar_at(i)
259 .vortex_expect("scalar_at should succeed in conformance test")
260 );
261 }
262}
263
264fn test_take_single_middle(array: &dyn Array) {
265 let len = array.len();
266 let middle_idx = len / 2;
267
268 let indices = PrimitiveArray::from_iter([middle_idx as u64]);
269 let result = array
270 .take(indices.to_array())
271 .vortex_expect("take should succeed in conformance test");
272
273 assert_eq!(result.len(), 1);
274 assert_eq!(
275 array
276 .scalar_at(middle_idx)
277 .vortex_expect("scalar_at should succeed in conformance test"),
278 result
279 .scalar_at(0)
280 .vortex_expect("scalar_at should succeed in conformance test")
281 );
282}
283
284#[allow(clippy::cast_possible_truncation)]
285fn test_take_random_unsorted(array: &dyn Array) {
286 let len = array.len();
287
288 let mut indices = Vec::new();
290 let mut idx = 1u64;
291 for _ in 0..len.min(10) {
292 indices.push((idx * 7 + 3) % len as u64);
293 idx = (idx * 3 + 1) % len as u64;
294 }
295
296 let indices_array = PrimitiveArray::from_iter(indices.clone());
297 let result = array
298 .take(indices_array.to_array())
299 .vortex_expect("take should succeed in conformance test");
300
301 assert_eq!(result.len(), indices.len());
302
303 for (i, &idx) in indices.iter().enumerate() {
305 assert_eq!(
306 array
307 .scalar_at(idx as usize)
308 .vortex_expect("scalar_at should succeed in conformance test"),
309 result
310 .scalar_at(i)
311 .vortex_expect("scalar_at should succeed in conformance test")
312 );
313 }
314}
315
316fn test_take_contiguous_range(array: &dyn Array) {
317 let len = array.len();
318 let start = len / 4;
319 let end = len / 2;
320
321 let indices = PrimitiveArray::from_iter(start as u64..end as u64);
323 let result = array
324 .take(indices.to_array())
325 .vortex_expect("take should succeed in conformance test");
326
327 assert_eq!(result.len(), end - start);
328
329 for i in 0..(end - start) {
331 assert_eq!(
332 array
333 .scalar_at(start + i)
334 .vortex_expect("scalar_at should succeed in conformance test"),
335 result
336 .scalar_at(i)
337 .vortex_expect("scalar_at should succeed in conformance test")
338 );
339 }
340}
341
342#[allow(clippy::cast_possible_truncation)]
343fn test_take_mixed_repeated(array: &dyn Array) {
344 let len = array.len();
345
346 let indices = vec![
348 0u64,
349 0,
350 1,
351 1,
352 len as u64 / 2,
353 len as u64 / 2,
354 len as u64 / 2,
355 (len - 1) as u64,
356 ];
357
358 let indices_array = PrimitiveArray::from_iter(indices.clone());
359 let result = array
360 .take(indices_array.to_array())
361 .vortex_expect("take should succeed in conformance test");
362
363 assert_eq!(result.len(), indices.len());
364
365 for (i, &idx) in indices.iter().enumerate() {
367 assert_eq!(
368 array
369 .scalar_at(idx as usize)
370 .vortex_expect("scalar_at should succeed in conformance test"),
371 result
372 .scalar_at(i)
373 .vortex_expect("scalar_at should succeed in conformance test")
374 );
375 }
376}
377
378#[allow(clippy::cast_possible_truncation)]
379fn test_take_large_indices(array: &dyn Array) {
380 let len = array.len();
382 let num_indices = 10000.min(len * 3);
383
384 let indices: Vec<u64> = (0..num_indices)
386 .map(|i| ((i * 17 + 5) % len) as u64)
387 .collect();
388
389 let indices_array = PrimitiveArray::from_iter(indices.clone());
390 let result = array
391 .take(indices_array.to_array())
392 .vortex_expect("take should succeed in conformance test");
393
394 assert_eq!(result.len(), num_indices);
395
396 for i in (0..num_indices).step_by(1000) {
398 let expected_idx = indices[i] as usize;
399 assert_eq!(
400 array
401 .scalar_at(expected_idx)
402 .vortex_expect("scalar_at should succeed in conformance test"),
403 result
404 .scalar_at(i)
405 .vortex_expect("scalar_at should succeed in conformance test")
406 );
407 }
408}