Skip to main content

vortex_array/compute/conformance/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::buffer;
5use vortex_error::VortexExpect;
6
7use crate::ArrayRef;
8use crate::Canonical;
9use crate::IntoArray as _;
10use crate::LEGACY_SESSION;
11use crate::VortexSessionExecute;
12use crate::arrays::PrimitiveArray;
13use crate::dtype::Nullability;
14
15/// Test conformance of the take compute function for an array.
16///
17/// This function tests various scenarios including:
18/// - Taking all elements
19/// - Taking no elements
20/// - Taking selective elements
21/// - Taking with out-of-bounds indices (should panic)
22/// - Taking with nullable indices
23/// - Edge cases like empty arrays
24pub fn test_take_conformance(array: &ArrayRef) {
25    let len = array.len();
26
27    if len > 0 {
28        test_take_all(array);
29        test_take_none(array);
30        test_take_selective(array);
31        test_take_first_and_last(array);
32        test_take_with_nullable_indices(array);
33        test_take_repeated_indices(array);
34    }
35
36    test_empty_indices(array);
37
38    // Additional edge cases for non-empty arrays
39    if len > 0 {
40        test_take_reverse(array);
41        test_take_single_middle(array);
42    }
43
44    if len > 3 {
45        test_take_random_unsorted(array);
46        test_take_contiguous_range(array);
47        test_take_mixed_repeated(array);
48    }
49
50    // Test for larger arrays
51    if len >= 1024 {
52        test_take_large_indices(array);
53    }
54}
55
56fn test_take_all(array: &ArrayRef) {
57    let len = array.len();
58    let indices = PrimitiveArray::from_iter(0..len as u64);
59    let result = array
60        .take(indices.into_array())
61        .vortex_expect("take should succeed in conformance test");
62
63    assert_eq!(result.len(), len);
64    assert_eq!(result.dtype(), array.dtype());
65
66    // Verify elements match
67    #[expect(deprecated)]
68    let array_canonical = array
69        .to_canonical()
70        .vortex_expect("to_canonical failed on array");
71    #[expect(deprecated)]
72    let result_canonical = result
73        .to_canonical()
74        .vortex_expect("to_canonical failed on result");
75    match (array_canonical, result_canonical) {
76        (Canonical::Primitive(orig_prim), Canonical::Primitive(result_prim)) => {
77            assert_eq!(
78                orig_prim.buffer_handle().to_host_sync(),
79                result_prim.buffer_handle().to_host_sync()
80            );
81        }
82        _ => {
83            // For non-primitive types, check scalar values
84            for i in 0..len {
85                assert_eq!(
86                    array
87                        .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
88                        .vortex_expect("scalar_at should succeed in conformance test"),
89                    result
90                        .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
91                        .vortex_expect("scalar_at should succeed in conformance test")
92                );
93            }
94        }
95    }
96}
97
98fn test_take_none(array: &ArrayRef) {
99    let indices: PrimitiveArray = PrimitiveArray::from_iter::<[u64; 0]>([]);
100    let result = array
101        .take(indices.into_array())
102        .vortex_expect("take should succeed in conformance test");
103
104    assert_eq!(result.len(), 0);
105    assert_eq!(result.dtype(), array.dtype());
106}
107
108#[expect(clippy::cast_possible_truncation)]
109fn test_take_selective(array: &ArrayRef) {
110    let len = array.len();
111
112    // Take every other element
113    let indices: Vec<u64> = (0..len as u64).step_by(2).collect();
114    let expected_len = indices.len();
115    let indices_array = PrimitiveArray::from_iter(indices.clone());
116
117    let result = array
118        .take(indices_array.into_array())
119        .vortex_expect("take should succeed in conformance test");
120    assert_eq!(result.len(), expected_len);
121
122    // Verify the taken elements
123    for (result_idx, &original_idx) in indices.iter().enumerate() {
124        assert_eq!(
125            array
126                .execute_scalar(
127                    original_idx as usize,
128                    &mut LEGACY_SESSION.create_execution_ctx()
129                )
130                .vortex_expect("scalar_at should succeed in conformance test"),
131            result
132                .execute_scalar(result_idx, &mut LEGACY_SESSION.create_execution_ctx())
133                .vortex_expect("scalar_at should succeed in conformance test")
134        );
135    }
136}
137
138fn test_take_first_and_last(array: &ArrayRef) {
139    let len = array.len();
140    let indices = PrimitiveArray::from_iter([0u64, (len - 1) as u64]);
141    let result = array
142        .take(indices.into_array())
143        .vortex_expect("take should succeed in conformance test");
144
145    assert_eq!(result.len(), 2);
146    assert_eq!(
147        array
148            .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
149            .vortex_expect("scalar_at should succeed in conformance test"),
150        result
151            .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
152            .vortex_expect("scalar_at should succeed in conformance test")
153    );
154    assert_eq!(
155        array
156            .execute_scalar(len - 1, &mut LEGACY_SESSION.create_execution_ctx())
157            .vortex_expect("scalar_at should succeed in conformance test"),
158        result
159            .execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())
160            .vortex_expect("scalar_at should succeed in conformance test")
161    );
162}
163
164#[expect(clippy::cast_possible_truncation)]
165fn test_take_with_nullable_indices(array: &ArrayRef) {
166    let len = array.len();
167
168    // Create indices with some null values
169    let indices_vec: Vec<Option<u64>> = if len >= 3 {
170        vec![Some(0), None, Some((len - 1) as u64)]
171    } else if len >= 2 {
172        vec![Some(0), None]
173    } else {
174        vec![None]
175    };
176
177    let indices = PrimitiveArray::from_option_iter(indices_vec.clone());
178    let result = array
179        .take(indices.into_array())
180        .vortex_expect("take should succeed in conformance test");
181
182    assert_eq!(result.len(), indices_vec.len());
183    assert_eq!(
184        result.dtype(),
185        &array.dtype().with_nullability(Nullability::Nullable)
186    );
187
188    // Verify values
189    for (i, idx_opt) in indices_vec.iter().enumerate() {
190        match idx_opt {
191            Some(idx) => {
192                let expected = array
193                    .execute_scalar(*idx as usize, &mut LEGACY_SESSION.create_execution_ctx())
194                    .vortex_expect("scalar_at should succeed in conformance test");
195                let actual = result
196                    .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
197                    .vortex_expect("scalar_at should succeed in conformance test");
198                assert_eq!(expected, actual);
199            }
200            None => {
201                assert!(
202                    result
203                        .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
204                        .vortex_expect("scalar_at should succeed in conformance test")
205                        .is_null()
206                );
207            }
208        }
209    }
210}
211
212fn test_take_repeated_indices(array: &ArrayRef) {
213    if array.is_empty() {
214        return;
215    }
216
217    // Take the first element multiple times
218    let indices = buffer![0u64, 0, 0].into_array();
219    let result = array
220        .take(indices)
221        .vortex_expect("take should succeed in conformance test");
222
223    assert_eq!(result.len(), 3);
224    let first_elem = array
225        .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
226        .vortex_expect("scalar_at should succeed in conformance test");
227    for i in 0..3 {
228        assert_eq!(
229            result
230                .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
231                .vortex_expect("scalar_at should succeed in conformance test"),
232            first_elem
233        );
234    }
235}
236
237fn test_empty_indices(array: &ArrayRef) {
238    let indices = PrimitiveArray::empty::<u64>(Nullability::NonNullable);
239    let result = array
240        .take(indices.into_array())
241        .vortex_expect("take should succeed in conformance test");
242
243    assert_eq!(result.len(), 0);
244    assert_eq!(result.dtype(), array.dtype());
245}
246
247fn test_take_reverse(array: &ArrayRef) {
248    let len = array.len();
249    // Take elements in reverse order
250    let indices = PrimitiveArray::from_iter((0..len as u64).rev());
251    let result = array
252        .take(indices.into_array())
253        .vortex_expect("take should succeed in conformance test");
254
255    assert_eq!(result.len(), len);
256
257    // Verify elements are in reverse order
258    for i in 0..len {
259        assert_eq!(
260            array
261                .execute_scalar(len - 1 - i, &mut LEGACY_SESSION.create_execution_ctx())
262                .vortex_expect("scalar_at should succeed in conformance test"),
263            result
264                .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
265                .vortex_expect("scalar_at should succeed in conformance test")
266        );
267    }
268}
269
270fn test_take_single_middle(array: &ArrayRef) {
271    let len = array.len();
272    let middle_idx = len / 2;
273
274    let indices = PrimitiveArray::from_iter([middle_idx as u64]);
275    let result = array
276        .take(indices.into_array())
277        .vortex_expect("take should succeed in conformance test");
278
279    assert_eq!(result.len(), 1);
280    assert_eq!(
281        array
282            .execute_scalar(middle_idx, &mut LEGACY_SESSION.create_execution_ctx())
283            .vortex_expect("scalar_at should succeed in conformance test"),
284        result
285            .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
286            .vortex_expect("scalar_at should succeed in conformance test")
287    );
288}
289
290#[expect(clippy::cast_possible_truncation)]
291fn test_take_random_unsorted(array: &ArrayRef) {
292    let len = array.len();
293
294    // Create a pseudo-random but deterministic pattern
295    let mut indices = Vec::new();
296    let mut idx = 1u64;
297    for _ in 0..len.min(10) {
298        indices.push((idx * 7 + 3) % len as u64);
299        idx = (idx * 3 + 1) % len as u64;
300    }
301
302    let indices_array = PrimitiveArray::from_iter(indices.clone());
303    let result = array
304        .take(indices_array.into_array())
305        .vortex_expect("take should succeed in conformance test");
306
307    assert_eq!(result.len(), indices.len());
308
309    // Verify elements match
310    for (i, &idx) in indices.iter().enumerate() {
311        assert_eq!(
312            array
313                .execute_scalar(idx as usize, &mut LEGACY_SESSION.create_execution_ctx())
314                .vortex_expect("scalar_at should succeed in conformance test"),
315            result
316                .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
317                .vortex_expect("scalar_at should succeed in conformance test")
318        );
319    }
320}
321
322fn test_take_contiguous_range(array: &ArrayRef) {
323    let len = array.len();
324    let start = len / 4;
325    let end = len / 2;
326
327    // Take a contiguous range from the middle
328    let indices = PrimitiveArray::from_iter(start as u64..end as u64);
329    let result = array
330        .take(indices.into_array())
331        .vortex_expect("take should succeed in conformance test");
332
333    assert_eq!(result.len(), end - start);
334
335    // Verify elements
336    for i in 0..(end - start) {
337        assert_eq!(
338            array
339                .execute_scalar(start + i, &mut LEGACY_SESSION.create_execution_ctx())
340                .vortex_expect("scalar_at should succeed in conformance test"),
341            result
342                .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
343                .vortex_expect("scalar_at should succeed in conformance test")
344        );
345    }
346}
347
348#[expect(clippy::cast_possible_truncation)]
349fn test_take_mixed_repeated(array: &ArrayRef) {
350    let len = array.len();
351
352    // Create pattern with some repeated indices
353    let indices = vec![
354        0u64,
355        0,
356        1,
357        1,
358        len as u64 / 2,
359        len as u64 / 2,
360        len as u64 / 2,
361        (len - 1) as u64,
362    ];
363
364    let indices_array = PrimitiveArray::from_iter(indices.clone());
365    let result = array
366        .take(indices_array.into_array())
367        .vortex_expect("take should succeed in conformance test");
368
369    assert_eq!(result.len(), indices.len());
370
371    // Verify elements
372    for (i, &idx) in indices.iter().enumerate() {
373        assert_eq!(
374            array
375                .execute_scalar(idx as usize, &mut LEGACY_SESSION.create_execution_ctx())
376                .vortex_expect("scalar_at should succeed in conformance test"),
377            result
378                .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
379                .vortex_expect("scalar_at should succeed in conformance test")
380        );
381    }
382}
383
384#[expect(clippy::cast_possible_truncation)]
385fn test_take_large_indices(array: &ArrayRef) {
386    // Test with a large number of indices to stress test performance
387    let len = array.len();
388    let num_indices = 10000.min(len * 3);
389
390    // Create many indices with a pattern
391    let indices: Vec<u64> = (0..num_indices)
392        .map(|i| ((i * 17 + 5) % len) as u64)
393        .collect();
394
395    let indices_array = PrimitiveArray::from_iter(indices.clone());
396    let result = array
397        .take(indices_array.into_array())
398        .vortex_expect("take should succeed in conformance test");
399
400    assert_eq!(result.len(), num_indices);
401
402    // Spot check a few elements
403    for i in (0..num_indices).step_by(1000) {
404        let expected_idx = indices[i] as usize;
405        assert_eq!(
406            array
407                .execute_scalar(expected_idx, &mut LEGACY_SESSION.create_execution_ctx())
408                .vortex_expect("scalar_at should succeed in conformance test"),
409            result
410                .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
411                .vortex_expect("scalar_at should succeed in conformance test")
412        );
413    }
414}