vortex_array/compute/conformance/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::buffer;
5use vortex_dtype::Nullability;
6use vortex_error::VortexExpect;
7
8use crate::Array;
9use crate::Canonical;
10use crate::IntoArray as _;
11use crate::arrays::PrimitiveArray;
12use crate::compute::take;
13
14/// Test conformance of the take compute function for an array.
15///
16/// This function tests various scenarios including:
17/// - Taking all elements
18/// - Taking no elements
19/// - Taking selective elements
20/// - Taking with out-of-bounds indices (should panic)
21/// - Taking with nullable indices
22/// - Edge cases like empty arrays
23pub fn test_take_conformance(array: &dyn Array) {
24    let len = array.len();
25
26    if len > 0 {
27        test_take_all(array);
28        test_take_none(array);
29        test_take_selective(array);
30        test_take_first_and_last(array);
31        test_take_with_nullable_indices(array);
32        test_take_repeated_indices(array);
33    }
34
35    test_empty_indices(array);
36
37    // Additional edge cases for non-empty arrays
38    if len > 0 {
39        test_take_reverse(array);
40        test_take_single_middle(array);
41    }
42
43    if len > 3 {
44        test_take_random_unsorted(array);
45        test_take_contiguous_range(array);
46        test_take_mixed_repeated(array);
47    }
48
49    // Test for larger arrays
50    if len >= 1024 {
51        test_take_large_indices(array);
52    }
53}
54
55fn test_take_all(array: &dyn Array) {
56    let len = array.len();
57    let indices = PrimitiveArray::from_iter(0..len as u64);
58    let result =
59        take(array, indices.as_ref()).vortex_expect("take should succeed in conformance test");
60
61    assert_eq!(result.len(), len);
62    assert_eq!(result.dtype(), array.dtype());
63
64    // Verify elements match
65    match (&array.to_canonical(), &result.to_canonical()) {
66        (Canonical::Primitive(orig_prim), Canonical::Primitive(result_prim)) => {
67            assert_eq!(orig_prim.byte_buffer(), result_prim.byte_buffer());
68        }
69        _ => {
70            // For non-primitive types, check scalar values
71            for i in 0..len {
72                assert_eq!(array.scalar_at(i), result.scalar_at(i));
73            }
74        }
75    }
76}
77
78fn test_take_none(array: &dyn Array) {
79    let indices: PrimitiveArray = PrimitiveArray::from_iter::<[u64; 0]>([]);
80    let result =
81        take(array, indices.as_ref()).vortex_expect("take should succeed in conformance test");
82
83    assert_eq!(result.len(), 0);
84    assert_eq!(result.dtype(), array.dtype());
85}
86
87#[allow(clippy::cast_possible_truncation)]
88fn test_take_selective(array: &dyn Array) {
89    let len = array.len();
90
91    // Take every other element
92    let indices: Vec<u64> = (0..len as u64).step_by(2).collect();
93    let expected_len = indices.len();
94    let indices_array = PrimitiveArray::from_iter(indices.clone());
95
96    let result = take(array, indices_array.as_ref())
97        .vortex_expect("take should succeed in conformance test");
98    assert_eq!(result.len(), expected_len);
99
100    // Verify the taken elements
101    for (result_idx, &original_idx) in indices.iter().enumerate() {
102        assert_eq!(
103            array.scalar_at(original_idx as usize),
104            result.scalar_at(result_idx)
105        );
106    }
107}
108
109fn test_take_first_and_last(array: &dyn Array) {
110    let len = array.len();
111    let indices = PrimitiveArray::from_iter([0u64, (len - 1) as u64]);
112    let result =
113        take(array, indices.as_ref()).vortex_expect("take should succeed in conformance test");
114
115    assert_eq!(result.len(), 2);
116    assert_eq!(array.scalar_at(0), result.scalar_at(0));
117    assert_eq!(array.scalar_at(len - 1), result.scalar_at(1));
118}
119
120#[allow(clippy::cast_possible_truncation)]
121fn test_take_with_nullable_indices(array: &dyn Array) {
122    let len = array.len();
123
124    // Create indices with some null values
125    let indices_vec: Vec<Option<u64>> = if len >= 3 {
126        vec![Some(0), None, Some((len - 1) as u64)]
127    } else if len >= 2 {
128        vec![Some(0), None]
129    } else {
130        vec![None]
131    };
132
133    let indices = PrimitiveArray::from_option_iter(indices_vec.clone());
134    let result =
135        take(array, indices.as_ref()).vortex_expect("take should succeed in conformance test");
136
137    assert_eq!(result.len(), indices_vec.len());
138    assert_eq!(
139        result.dtype(),
140        &array.dtype().with_nullability(Nullability::Nullable)
141    );
142
143    // Verify values
144    for (i, idx_opt) in indices_vec.iter().enumerate() {
145        match idx_opt {
146            Some(idx) => {
147                let expected = array.scalar_at(*idx as usize);
148                let actual = result.scalar_at(i);
149                assert_eq!(expected, actual);
150            }
151            None => {
152                assert!(result.scalar_at(i).is_null());
153            }
154        }
155    }
156}
157
158fn test_take_repeated_indices(array: &dyn Array) {
159    if array.is_empty() {
160        return;
161    }
162
163    // Take the first element multiple times
164    let indices = buffer![0u64, 0, 0].into_array();
165    let result =
166        take(array, indices.as_ref()).vortex_expect("take should succeed in conformance test");
167
168    assert_eq!(result.len(), 3);
169    let first_elem = array.scalar_at(0);
170    for i in 0..3 {
171        assert_eq!(result.scalar_at(i), first_elem);
172    }
173}
174
175fn test_empty_indices(array: &dyn Array) {
176    let indices = PrimitiveArray::empty::<u64>(Nullability::NonNullable);
177    let result =
178        take(array, indices.as_ref()).vortex_expect("take should succeed in conformance test");
179
180    assert_eq!(result.len(), 0);
181    assert_eq!(result.dtype(), array.dtype());
182}
183
184fn test_take_reverse(array: &dyn Array) {
185    let len = array.len();
186    // Take elements in reverse order
187    let indices = PrimitiveArray::from_iter((0..len as u64).rev());
188    let result =
189        take(array, indices.as_ref()).vortex_expect("take should succeed in conformance test");
190
191    assert_eq!(result.len(), len);
192
193    // Verify elements are in reverse order
194    for i in 0..len {
195        assert_eq!(array.scalar_at(len - 1 - i), result.scalar_at(i));
196    }
197}
198
199fn test_take_single_middle(array: &dyn Array) {
200    let len = array.len();
201    let middle_idx = len / 2;
202
203    let indices = PrimitiveArray::from_iter([middle_idx as u64]);
204    let result =
205        take(array, indices.as_ref()).vortex_expect("take should succeed in conformance test");
206
207    assert_eq!(result.len(), 1);
208    assert_eq!(array.scalar_at(middle_idx), result.scalar_at(0));
209}
210
211#[allow(clippy::cast_possible_truncation)]
212fn test_take_random_unsorted(array: &dyn Array) {
213    let len = array.len();
214
215    // Create a pseudo-random but deterministic pattern
216    let mut indices = Vec::new();
217    let mut idx = 1u64;
218    for _ in 0..len.min(10) {
219        indices.push((idx * 7 + 3) % len as u64);
220        idx = (idx * 3 + 1) % len as u64;
221    }
222
223    let indices_array = PrimitiveArray::from_iter(indices.clone());
224    let result = take(array, indices_array.as_ref())
225        .vortex_expect("take should succeed in conformance test");
226
227    assert_eq!(result.len(), indices.len());
228
229    // Verify elements match
230    for (i, &idx) in indices.iter().enumerate() {
231        assert_eq!(array.scalar_at(idx as usize), result.scalar_at(i));
232    }
233}
234
235fn test_take_contiguous_range(array: &dyn Array) {
236    let len = array.len();
237    let start = len / 4;
238    let end = len / 2;
239
240    // Take a contiguous range from the middle
241    let indices = PrimitiveArray::from_iter(start as u64..end as u64);
242    let result =
243        take(array, indices.as_ref()).vortex_expect("take should succeed in conformance test");
244
245    assert_eq!(result.len(), end - start);
246
247    // Verify elements
248    for i in 0..(end - start) {
249        assert_eq!(array.scalar_at(start + i), result.scalar_at(i));
250    }
251}
252
253#[allow(clippy::cast_possible_truncation)]
254fn test_take_mixed_repeated(array: &dyn Array) {
255    let len = array.len();
256
257    // Create pattern with some repeated indices
258    let indices = vec![
259        0u64,
260        0,
261        1,
262        1,
263        len as u64 / 2,
264        len as u64 / 2,
265        len as u64 / 2,
266        (len - 1) as u64,
267    ];
268
269    let indices_array = PrimitiveArray::from_iter(indices.clone());
270    let result = take(array, indices_array.as_ref())
271        .vortex_expect("take should succeed in conformance test");
272
273    assert_eq!(result.len(), indices.len());
274
275    // Verify elements
276    for (i, &idx) in indices.iter().enumerate() {
277        assert_eq!(array.scalar_at(idx as usize), result.scalar_at(i));
278    }
279}
280
281#[allow(clippy::cast_possible_truncation)]
282fn test_take_large_indices(array: &dyn Array) {
283    // Test with a large number of indices to stress test performance
284    let len = array.len();
285    let num_indices = 10000.min(len * 3);
286
287    // Create many indices with a pattern
288    let indices: Vec<u64> = (0..num_indices)
289        .map(|i| ((i * 17 + 5) % len) as u64)
290        .collect();
291
292    let indices_array = PrimitiveArray::from_iter(indices.clone());
293    let result = take(array, indices_array.as_ref())
294        .vortex_expect("take should succeed in conformance test");
295
296    assert_eq!(result.len(), num_indices);
297
298    // Spot check a few elements
299    for i in (0..num_indices).step_by(1000) {
300        let expected_idx = indices[i] as usize;
301        assert_eq!(array.scalar_at(expected_idx), result.scalar_at(i));
302    }
303}