Skip to main content

vortex_array/compute/conformance/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::buffer;
5use vortex_error::VortexExpect;
6
7use crate::Array;
8use crate::ArrayRef;
9use crate::Canonical;
10use crate::IntoArray as _;
11use crate::arrays::PrimitiveArray;
12use crate::dtype::Nullability;
13
14/// Test conformance of the take compute function for an array.
15///
16/// This function tests various scenarios including:
17/// - Taking all elements
18/// - Taking no elements
19/// - Taking selective elements
20/// - Taking with out-of-bounds indices (should panic)
21/// - Taking with nullable indices
22/// - Edge cases like empty arrays
23pub fn test_take_conformance(array: &ArrayRef) {
24    let len = array.len();
25
26    if len > 0 {
27        test_take_all(array);
28        test_take_none(array);
29        test_take_selective(array);
30        test_take_first_and_last(array);
31        test_take_with_nullable_indices(array);
32        test_take_repeated_indices(array);
33    }
34
35    test_empty_indices(array);
36
37    // Additional edge cases for non-empty arrays
38    if len > 0 {
39        test_take_reverse(array);
40        test_take_single_middle(array);
41    }
42
43    if len > 3 {
44        test_take_random_unsorted(array);
45        test_take_contiguous_range(array);
46        test_take_mixed_repeated(array);
47    }
48
49    // Test for larger arrays
50    if len >= 1024 {
51        test_take_large_indices(array);
52    }
53}
54
55fn test_take_all(array: &ArrayRef) {
56    let len = array.len();
57    let indices = PrimitiveArray::from_iter(0..len as u64);
58    let result = array
59        .take(indices.to_array())
60        .vortex_expect("take should succeed in conformance test");
61
62    assert_eq!(result.len(), len);
63    assert_eq!(result.dtype(), array.dtype());
64
65    // Verify elements match
66    match (
67        array
68            .to_canonical()
69            .vortex_expect("to_canonical failed on array"),
70        result
71            .to_canonical()
72            .vortex_expect("to_canonical failed on result"),
73    ) {
74        (Canonical::Primitive(orig_prim), Canonical::Primitive(result_prim)) => {
75            assert_eq!(
76                orig_prim.buffer_handle().to_host_sync(),
77                result_prim.buffer_handle().to_host_sync()
78            );
79        }
80        _ => {
81            // For non-primitive types, check scalar values
82            for i in 0..len {
83                assert_eq!(
84                    array
85                        .scalar_at(i)
86                        .vortex_expect("scalar_at should succeed in conformance test"),
87                    result
88                        .scalar_at(i)
89                        .vortex_expect("scalar_at should succeed in conformance test")
90                );
91            }
92        }
93    }
94}
95
96fn test_take_none(array: &ArrayRef) {
97    let indices: PrimitiveArray = PrimitiveArray::from_iter::<[u64; 0]>([]);
98    let result = array
99        .take(indices.to_array())
100        .vortex_expect("take should succeed in conformance test");
101
102    assert_eq!(result.len(), 0);
103    assert_eq!(result.dtype(), array.dtype());
104}
105
106#[allow(clippy::cast_possible_truncation)]
107fn test_take_selective(array: &ArrayRef) {
108    let len = array.len();
109
110    // Take every other element
111    let indices: Vec<u64> = (0..len as u64).step_by(2).collect();
112    let expected_len = indices.len();
113    let indices_array = PrimitiveArray::from_iter(indices.clone());
114
115    let result = array
116        .take(indices_array.to_array())
117        .vortex_expect("take should succeed in conformance test");
118    assert_eq!(result.len(), expected_len);
119
120    // Verify the taken elements
121    for (result_idx, &original_idx) in indices.iter().enumerate() {
122        assert_eq!(
123            array
124                .scalar_at(original_idx as usize)
125                .vortex_expect("scalar_at should succeed in conformance test"),
126            result
127                .scalar_at(result_idx)
128                .vortex_expect("scalar_at should succeed in conformance test")
129        );
130    }
131}
132
133fn test_take_first_and_last(array: &ArrayRef) {
134    let len = array.len();
135    let indices = PrimitiveArray::from_iter([0u64, (len - 1) as u64]);
136    let result = array
137        .take(indices.to_array())
138        .vortex_expect("take should succeed in conformance test");
139
140    assert_eq!(result.len(), 2);
141    assert_eq!(
142        array
143            .scalar_at(0)
144            .vortex_expect("scalar_at should succeed in conformance test"),
145        result
146            .scalar_at(0)
147            .vortex_expect("scalar_at should succeed in conformance test")
148    );
149    assert_eq!(
150        array
151            .scalar_at(len - 1)
152            .vortex_expect("scalar_at should succeed in conformance test"),
153        result
154            .scalar_at(1)
155            .vortex_expect("scalar_at should succeed in conformance test")
156    );
157}
158
159#[allow(clippy::cast_possible_truncation)]
160fn test_take_with_nullable_indices(array: &ArrayRef) {
161    let len = array.len();
162
163    // Create indices with some null values
164    let indices_vec: Vec<Option<u64>> = if len >= 3 {
165        vec![Some(0), None, Some((len - 1) as u64)]
166    } else if len >= 2 {
167        vec![Some(0), None]
168    } else {
169        vec![None]
170    };
171
172    let indices = PrimitiveArray::from_option_iter(indices_vec.clone());
173    let result = array
174        .take(indices.to_array())
175        .vortex_expect("take should succeed in conformance test");
176
177    assert_eq!(result.len(), indices_vec.len());
178    assert_eq!(
179        result.dtype(),
180        &array.dtype().with_nullability(Nullability::Nullable)
181    );
182
183    // Verify values
184    for (i, idx_opt) in indices_vec.iter().enumerate() {
185        match idx_opt {
186            Some(idx) => {
187                let expected = array
188                    .scalar_at(*idx as usize)
189                    .vortex_expect("scalar_at should succeed in conformance test");
190                let actual = result
191                    .scalar_at(i)
192                    .vortex_expect("scalar_at should succeed in conformance test");
193                assert_eq!(expected, actual);
194            }
195            None => {
196                assert!(
197                    result
198                        .scalar_at(i)
199                        .vortex_expect("scalar_at should succeed in conformance test")
200                        .is_null()
201                );
202            }
203        }
204    }
205}
206
207fn test_take_repeated_indices(array: &ArrayRef) {
208    if array.is_empty() {
209        return;
210    }
211
212    // Take the first element multiple times
213    let indices = buffer![0u64, 0, 0].into_array();
214    let result = array
215        .take(indices.to_array())
216        .vortex_expect("take should succeed in conformance test");
217
218    assert_eq!(result.len(), 3);
219    let first_elem = array
220        .scalar_at(0)
221        .vortex_expect("scalar_at should succeed in conformance test");
222    for i in 0..3 {
223        assert_eq!(
224            result
225                .scalar_at(i)
226                .vortex_expect("scalar_at should succeed in conformance test"),
227            first_elem
228        );
229    }
230}
231
232fn test_empty_indices(array: &ArrayRef) {
233    let indices = PrimitiveArray::empty::<u64>(Nullability::NonNullable);
234    let result = array
235        .take(indices.to_array())
236        .vortex_expect("take should succeed in conformance test");
237
238    assert_eq!(result.len(), 0);
239    assert_eq!(result.dtype(), array.dtype());
240}
241
242fn test_take_reverse(array: &ArrayRef) {
243    let len = array.len();
244    // Take elements in reverse order
245    let indices = PrimitiveArray::from_iter((0..len as u64).rev());
246    let result = array
247        .take(indices.to_array())
248        .vortex_expect("take should succeed in conformance test");
249
250    assert_eq!(result.len(), len);
251
252    // Verify elements are in reverse order
253    for i in 0..len {
254        assert_eq!(
255            array
256                .scalar_at(len - 1 - i)
257                .vortex_expect("scalar_at should succeed in conformance test"),
258            result
259                .scalar_at(i)
260                .vortex_expect("scalar_at should succeed in conformance test")
261        );
262    }
263}
264
265fn test_take_single_middle(array: &ArrayRef) {
266    let len = array.len();
267    let middle_idx = len / 2;
268
269    let indices = PrimitiveArray::from_iter([middle_idx as u64]);
270    let result = array
271        .take(indices.to_array())
272        .vortex_expect("take should succeed in conformance test");
273
274    assert_eq!(result.len(), 1);
275    assert_eq!(
276        array
277            .scalar_at(middle_idx)
278            .vortex_expect("scalar_at should succeed in conformance test"),
279        result
280            .scalar_at(0)
281            .vortex_expect("scalar_at should succeed in conformance test")
282    );
283}
284
285#[allow(clippy::cast_possible_truncation)]
286fn test_take_random_unsorted(array: &ArrayRef) {
287    let len = array.len();
288
289    // Create a pseudo-random but deterministic pattern
290    let mut indices = Vec::new();
291    let mut idx = 1u64;
292    for _ in 0..len.min(10) {
293        indices.push((idx * 7 + 3) % len as u64);
294        idx = (idx * 3 + 1) % len as u64;
295    }
296
297    let indices_array = PrimitiveArray::from_iter(indices.clone());
298    let result = array
299        .take(indices_array.to_array())
300        .vortex_expect("take should succeed in conformance test");
301
302    assert_eq!(result.len(), indices.len());
303
304    // Verify elements match
305    for (i, &idx) in indices.iter().enumerate() {
306        assert_eq!(
307            array
308                .scalar_at(idx as usize)
309                .vortex_expect("scalar_at should succeed in conformance test"),
310            result
311                .scalar_at(i)
312                .vortex_expect("scalar_at should succeed in conformance test")
313        );
314    }
315}
316
317fn test_take_contiguous_range(array: &ArrayRef) {
318    let len = array.len();
319    let start = len / 4;
320    let end = len / 2;
321
322    // Take a contiguous range from the middle
323    let indices = PrimitiveArray::from_iter(start as u64..end as u64);
324    let result = array
325        .take(indices.to_array())
326        .vortex_expect("take should succeed in conformance test");
327
328    assert_eq!(result.len(), end - start);
329
330    // Verify elements
331    for i in 0..(end - start) {
332        assert_eq!(
333            array
334                .scalar_at(start + i)
335                .vortex_expect("scalar_at should succeed in conformance test"),
336            result
337                .scalar_at(i)
338                .vortex_expect("scalar_at should succeed in conformance test")
339        );
340    }
341}
342
343#[allow(clippy::cast_possible_truncation)]
344fn test_take_mixed_repeated(array: &ArrayRef) {
345    let len = array.len();
346
347    // Create pattern with some repeated indices
348    let indices = vec![
349        0u64,
350        0,
351        1,
352        1,
353        len as u64 / 2,
354        len as u64 / 2,
355        len as u64 / 2,
356        (len - 1) as u64,
357    ];
358
359    let indices_array = PrimitiveArray::from_iter(indices.clone());
360    let result = array
361        .take(indices_array.to_array())
362        .vortex_expect("take should succeed in conformance test");
363
364    assert_eq!(result.len(), indices.len());
365
366    // Verify elements
367    for (i, &idx) in indices.iter().enumerate() {
368        assert_eq!(
369            array
370                .scalar_at(idx as usize)
371                .vortex_expect("scalar_at should succeed in conformance test"),
372            result
373                .scalar_at(i)
374                .vortex_expect("scalar_at should succeed in conformance test")
375        );
376    }
377}
378
379#[allow(clippy::cast_possible_truncation)]
380fn test_take_large_indices(array: &ArrayRef) {
381    // Test with a large number of indices to stress test performance
382    let len = array.len();
383    let num_indices = 10000.min(len * 3);
384
385    // Create many indices with a pattern
386    let indices: Vec<u64> = (0..num_indices)
387        .map(|i| ((i * 17 + 5) % len) as u64)
388        .collect();
389
390    let indices_array = PrimitiveArray::from_iter(indices.clone());
391    let result = array
392        .take(indices_array.to_array())
393        .vortex_expect("take should succeed in conformance test");
394
395    assert_eq!(result.len(), num_indices);
396
397    // Spot check a few elements
398    for i in (0..num_indices).step_by(1000) {
399        let expected_idx = indices[i] as usize;
400        assert_eq!(
401            array
402                .scalar_at(expected_idx)
403                .vortex_expect("scalar_at should succeed in conformance test"),
404            result
405                .scalar_at(i)
406                .vortex_expect("scalar_at should succeed in conformance test")
407        );
408    }
409}