Skip to main content

vortex_array/compute/conformance/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::buffer;
5use vortex_error::VortexExpect;
6
7use crate::ArrayRef;
8use crate::Canonical;
9use crate::IntoArray as _;
10use crate::LEGACY_SESSION;
11use crate::VortexSessionExecute;
12use crate::arrays::PrimitiveArray;
13use crate::dtype::Nullability;
14
15/// Test conformance of the take compute function for an array.
16///
17/// This function tests various scenarios including:
18/// - Taking all elements
19/// - Taking no elements
20/// - Taking selective elements
21/// - Taking with out-of-bounds indices (should panic)
22/// - Taking with nullable indices
23/// - Edge cases like empty arrays
24pub fn test_take_conformance(array: &ArrayRef) {
25    let len = array.len();
26
27    if len > 0 {
28        test_take_all(array);
29        test_take_none(array);
30        test_take_selective(array);
31        test_take_first_and_last(array);
32        test_take_with_nullable_indices(array);
33        test_take_repeated_indices(array);
34    }
35
36    test_empty_indices(array);
37
38    // Additional edge cases for non-empty arrays
39    if len > 0 {
40        test_take_reverse(array);
41        test_take_single_middle(array);
42    }
43
44    if len > 3 {
45        test_take_random_unsorted(array);
46        test_take_contiguous_range(array);
47        test_take_mixed_repeated(array);
48    }
49
50    // Test for larger arrays
51    if len >= 1024 {
52        test_take_large_indices(array);
53    }
54}
55
56fn test_take_all(array: &ArrayRef) {
57    let len = array.len();
58    let indices = PrimitiveArray::from_iter(0..len as u64);
59    let result = array
60        .take(indices.into_array())
61        .vortex_expect("take should succeed in conformance test");
62
63    assert_eq!(result.len(), len);
64    assert_eq!(result.dtype(), array.dtype());
65
66    // Verify elements match
67    match (
68        array
69            .to_canonical()
70            .vortex_expect("to_canonical failed on array"),
71        result
72            .to_canonical()
73            .vortex_expect("to_canonical failed on result"),
74    ) {
75        (Canonical::Primitive(orig_prim), Canonical::Primitive(result_prim)) => {
76            assert_eq!(
77                orig_prim.buffer_handle().to_host_sync(),
78                result_prim.buffer_handle().to_host_sync()
79            );
80        }
81        _ => {
82            // For non-primitive types, check scalar values
83            for i in 0..len {
84                assert_eq!(
85                    array
86                        .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
87                        .vortex_expect("scalar_at should succeed in conformance test"),
88                    result
89                        .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
90                        .vortex_expect("scalar_at should succeed in conformance test")
91                );
92            }
93        }
94    }
95}
96
97fn test_take_none(array: &ArrayRef) {
98    let indices: PrimitiveArray = PrimitiveArray::from_iter::<[u64; 0]>([]);
99    let result = array
100        .take(indices.into_array())
101        .vortex_expect("take should succeed in conformance test");
102
103    assert_eq!(result.len(), 0);
104    assert_eq!(result.dtype(), array.dtype());
105}
106
107#[expect(clippy::cast_possible_truncation)]
108fn test_take_selective(array: &ArrayRef) {
109    let len = array.len();
110
111    // Take every other element
112    let indices: Vec<u64> = (0..len as u64).step_by(2).collect();
113    let expected_len = indices.len();
114    let indices_array = PrimitiveArray::from_iter(indices.clone());
115
116    let result = array
117        .take(indices_array.into_array())
118        .vortex_expect("take should succeed in conformance test");
119    assert_eq!(result.len(), expected_len);
120
121    // Verify the taken elements
122    for (result_idx, &original_idx) in indices.iter().enumerate() {
123        assert_eq!(
124            array
125                .execute_scalar(
126                    original_idx as usize,
127                    &mut LEGACY_SESSION.create_execution_ctx()
128                )
129                .vortex_expect("scalar_at should succeed in conformance test"),
130            result
131                .execute_scalar(result_idx, &mut LEGACY_SESSION.create_execution_ctx())
132                .vortex_expect("scalar_at should succeed in conformance test")
133        );
134    }
135}
136
137fn test_take_first_and_last(array: &ArrayRef) {
138    let len = array.len();
139    let indices = PrimitiveArray::from_iter([0u64, (len - 1) as u64]);
140    let result = array
141        .take(indices.into_array())
142        .vortex_expect("take should succeed in conformance test");
143
144    assert_eq!(result.len(), 2);
145    assert_eq!(
146        array
147            .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
148            .vortex_expect("scalar_at should succeed in conformance test"),
149        result
150            .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
151            .vortex_expect("scalar_at should succeed in conformance test")
152    );
153    assert_eq!(
154        array
155            .execute_scalar(len - 1, &mut LEGACY_SESSION.create_execution_ctx())
156            .vortex_expect("scalar_at should succeed in conformance test"),
157        result
158            .execute_scalar(1, &mut LEGACY_SESSION.create_execution_ctx())
159            .vortex_expect("scalar_at should succeed in conformance test")
160    );
161}
162
163#[expect(clippy::cast_possible_truncation)]
164fn test_take_with_nullable_indices(array: &ArrayRef) {
165    let len = array.len();
166
167    // Create indices with some null values
168    let indices_vec: Vec<Option<u64>> = if len >= 3 {
169        vec![Some(0), None, Some((len - 1) as u64)]
170    } else if len >= 2 {
171        vec![Some(0), None]
172    } else {
173        vec![None]
174    };
175
176    let indices = PrimitiveArray::from_option_iter(indices_vec.clone());
177    let result = array
178        .take(indices.into_array())
179        .vortex_expect("take should succeed in conformance test");
180
181    assert_eq!(result.len(), indices_vec.len());
182    assert_eq!(
183        result.dtype(),
184        &array.dtype().with_nullability(Nullability::Nullable)
185    );
186
187    // Verify values
188    for (i, idx_opt) in indices_vec.iter().enumerate() {
189        match idx_opt {
190            Some(idx) => {
191                let expected = array
192                    .execute_scalar(*idx as usize, &mut LEGACY_SESSION.create_execution_ctx())
193                    .vortex_expect("scalar_at should succeed in conformance test");
194                let actual = result
195                    .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
196                    .vortex_expect("scalar_at should succeed in conformance test");
197                assert_eq!(expected, actual);
198            }
199            None => {
200                assert!(
201                    result
202                        .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
203                        .vortex_expect("scalar_at should succeed in conformance test")
204                        .is_null()
205                );
206            }
207        }
208    }
209}
210
211fn test_take_repeated_indices(array: &ArrayRef) {
212    if array.is_empty() {
213        return;
214    }
215
216    // Take the first element multiple times
217    let indices = buffer![0u64, 0, 0].into_array();
218    let result = array
219        .take(indices)
220        .vortex_expect("take should succeed in conformance test");
221
222    assert_eq!(result.len(), 3);
223    let first_elem = array
224        .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
225        .vortex_expect("scalar_at should succeed in conformance test");
226    for i in 0..3 {
227        assert_eq!(
228            result
229                .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
230                .vortex_expect("scalar_at should succeed in conformance test"),
231            first_elem
232        );
233    }
234}
235
236fn test_empty_indices(array: &ArrayRef) {
237    let indices = PrimitiveArray::empty::<u64>(Nullability::NonNullable);
238    let result = array
239        .take(indices.into_array())
240        .vortex_expect("take should succeed in conformance test");
241
242    assert_eq!(result.len(), 0);
243    assert_eq!(result.dtype(), array.dtype());
244}
245
246fn test_take_reverse(array: &ArrayRef) {
247    let len = array.len();
248    // Take elements in reverse order
249    let indices = PrimitiveArray::from_iter((0..len as u64).rev());
250    let result = array
251        .take(indices.into_array())
252        .vortex_expect("take should succeed in conformance test");
253
254    assert_eq!(result.len(), len);
255
256    // Verify elements are in reverse order
257    for i in 0..len {
258        assert_eq!(
259            array
260                .execute_scalar(len - 1 - i, &mut LEGACY_SESSION.create_execution_ctx())
261                .vortex_expect("scalar_at should succeed in conformance test"),
262            result
263                .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
264                .vortex_expect("scalar_at should succeed in conformance test")
265        );
266    }
267}
268
269fn test_take_single_middle(array: &ArrayRef) {
270    let len = array.len();
271    let middle_idx = len / 2;
272
273    let indices = PrimitiveArray::from_iter([middle_idx as u64]);
274    let result = array
275        .take(indices.into_array())
276        .vortex_expect("take should succeed in conformance test");
277
278    assert_eq!(result.len(), 1);
279    assert_eq!(
280        array
281            .execute_scalar(middle_idx, &mut LEGACY_SESSION.create_execution_ctx())
282            .vortex_expect("scalar_at should succeed in conformance test"),
283        result
284            .execute_scalar(0, &mut LEGACY_SESSION.create_execution_ctx())
285            .vortex_expect("scalar_at should succeed in conformance test")
286    );
287}
288
289#[expect(clippy::cast_possible_truncation)]
290fn test_take_random_unsorted(array: &ArrayRef) {
291    let len = array.len();
292
293    // Create a pseudo-random but deterministic pattern
294    let mut indices = Vec::new();
295    let mut idx = 1u64;
296    for _ in 0..len.min(10) {
297        indices.push((idx * 7 + 3) % len as u64);
298        idx = (idx * 3 + 1) % len as u64;
299    }
300
301    let indices_array = PrimitiveArray::from_iter(indices.clone());
302    let result = array
303        .take(indices_array.into_array())
304        .vortex_expect("take should succeed in conformance test");
305
306    assert_eq!(result.len(), indices.len());
307
308    // Verify elements match
309    for (i, &idx) in indices.iter().enumerate() {
310        assert_eq!(
311            array
312                .execute_scalar(idx as usize, &mut LEGACY_SESSION.create_execution_ctx())
313                .vortex_expect("scalar_at should succeed in conformance test"),
314            result
315                .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
316                .vortex_expect("scalar_at should succeed in conformance test")
317        );
318    }
319}
320
321fn test_take_contiguous_range(array: &ArrayRef) {
322    let len = array.len();
323    let start = len / 4;
324    let end = len / 2;
325
326    // Take a contiguous range from the middle
327    let indices = PrimitiveArray::from_iter(start as u64..end as u64);
328    let result = array
329        .take(indices.into_array())
330        .vortex_expect("take should succeed in conformance test");
331
332    assert_eq!(result.len(), end - start);
333
334    // Verify elements
335    for i in 0..(end - start) {
336        assert_eq!(
337            array
338                .execute_scalar(start + i, &mut LEGACY_SESSION.create_execution_ctx())
339                .vortex_expect("scalar_at should succeed in conformance test"),
340            result
341                .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
342                .vortex_expect("scalar_at should succeed in conformance test")
343        );
344    }
345}
346
347#[expect(clippy::cast_possible_truncation)]
348fn test_take_mixed_repeated(array: &ArrayRef) {
349    let len = array.len();
350
351    // Create pattern with some repeated indices
352    let indices = vec![
353        0u64,
354        0,
355        1,
356        1,
357        len as u64 / 2,
358        len as u64 / 2,
359        len as u64 / 2,
360        (len - 1) as u64,
361    ];
362
363    let indices_array = PrimitiveArray::from_iter(indices.clone());
364    let result = array
365        .take(indices_array.into_array())
366        .vortex_expect("take should succeed in conformance test");
367
368    assert_eq!(result.len(), indices.len());
369
370    // Verify elements
371    for (i, &idx) in indices.iter().enumerate() {
372        assert_eq!(
373            array
374                .execute_scalar(idx as usize, &mut LEGACY_SESSION.create_execution_ctx())
375                .vortex_expect("scalar_at should succeed in conformance test"),
376            result
377                .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
378                .vortex_expect("scalar_at should succeed in conformance test")
379        );
380    }
381}
382
383#[expect(clippy::cast_possible_truncation)]
384fn test_take_large_indices(array: &ArrayRef) {
385    // Test with a large number of indices to stress test performance
386    let len = array.len();
387    let num_indices = 10000.min(len * 3);
388
389    // Create many indices with a pattern
390    let indices: Vec<u64> = (0..num_indices)
391        .map(|i| ((i * 17 + 5) % len) as u64)
392        .collect();
393
394    let indices_array = PrimitiveArray::from_iter(indices.clone());
395    let result = array
396        .take(indices_array.into_array())
397        .vortex_expect("take should succeed in conformance test");
398
399    assert_eq!(result.len(), num_indices);
400
401    // Spot check a few elements
402    for i in (0..num_indices).step_by(1000) {
403        let expected_idx = indices[i] as usize;
404        assert_eq!(
405            array
406                .execute_scalar(expected_idx, &mut LEGACY_SESSION.create_execution_ctx())
407                .vortex_expect("scalar_at should succeed in conformance test"),
408            result
409                .execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())
410                .vortex_expect("scalar_at should succeed in conformance test")
411        );
412    }
413}