vortex_array/compute/conformance/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::buffer;
5use vortex_dtype::Nullability;
6use vortex_error::VortexUnwrap;
7
8use crate::Array;
9use crate::Canonical;
10use crate::IntoArray as _;
11use crate::arrays::PrimitiveArray;
12use crate::compute::take;
13
14/// Test conformance of the take compute function for an array.
15///
16/// This function tests various scenarios including:
17/// - Taking all elements
18/// - Taking no elements
19/// - Taking selective elements
20/// - Taking with out-of-bounds indices (should panic)
21/// - Taking with nullable indices
22/// - Edge cases like empty arrays
23pub fn test_take_conformance(array: &dyn Array) {
24    let len = array.len();
25
26    if len > 0 {
27        test_take_all(array);
28        test_take_none(array);
29        test_take_selective(array);
30        test_take_first_and_last(array);
31        test_take_with_nullable_indices(array);
32        test_take_repeated_indices(array);
33    }
34
35    test_empty_indices(array);
36
37    // Additional edge cases for non-empty arrays
38    if len > 0 {
39        test_take_reverse(array);
40        test_take_single_middle(array);
41    }
42
43    if len > 3 {
44        test_take_random_unsorted(array);
45        test_take_contiguous_range(array);
46        test_take_mixed_repeated(array);
47    }
48
49    // Test for larger arrays
50    if len >= 1024 {
51        test_take_large_indices(array);
52    }
53}
54
55fn test_take_all(array: &dyn Array) {
56    let len = array.len();
57    let indices = PrimitiveArray::from_iter(0..len as u64);
58    let result = take(array, indices.as_ref()).vortex_unwrap();
59
60    assert_eq!(result.len(), len);
61    assert_eq!(result.dtype(), array.dtype());
62
63    // Verify elements match
64    match (&array.to_canonical(), &result.to_canonical()) {
65        (Canonical::Primitive(orig_prim), Canonical::Primitive(result_prim)) => {
66            assert_eq!(orig_prim.byte_buffer(), result_prim.byte_buffer());
67        }
68        _ => {
69            // For non-primitive types, check scalar values
70            for i in 0..len {
71                assert_eq!(array.scalar_at(i), result.scalar_at(i));
72            }
73        }
74    }
75}
76
77fn test_take_none(array: &dyn Array) {
78    let indices: PrimitiveArray = PrimitiveArray::from_iter::<[u64; 0]>([]);
79    let result = take(array, indices.as_ref()).vortex_unwrap();
80
81    assert_eq!(result.len(), 0);
82    assert_eq!(result.dtype(), array.dtype());
83}
84
85#[allow(clippy::cast_possible_truncation)]
86fn test_take_selective(array: &dyn Array) {
87    let len = array.len();
88
89    // Take every other element
90    let indices: Vec<u64> = (0..len as u64).step_by(2).collect();
91    let expected_len = indices.len();
92    let indices_array = PrimitiveArray::from_iter(indices.clone());
93
94    let result = take(array, indices_array.as_ref()).vortex_unwrap();
95    assert_eq!(result.len(), expected_len);
96
97    // Verify the taken elements
98    for (result_idx, &original_idx) in indices.iter().enumerate() {
99        assert_eq!(
100            array.scalar_at(original_idx as usize),
101            result.scalar_at(result_idx)
102        );
103    }
104}
105
106fn test_take_first_and_last(array: &dyn Array) {
107    let len = array.len();
108    let indices = PrimitiveArray::from_iter([0u64, (len - 1) as u64]);
109    let result = take(array, indices.as_ref()).vortex_unwrap();
110
111    assert_eq!(result.len(), 2);
112    assert_eq!(array.scalar_at(0), result.scalar_at(0));
113    assert_eq!(array.scalar_at(len - 1), result.scalar_at(1));
114}
115
116#[allow(clippy::cast_possible_truncation)]
117fn test_take_with_nullable_indices(array: &dyn Array) {
118    let len = array.len();
119
120    // Create indices with some null values
121    let indices_vec: Vec<Option<u64>> = if len >= 3 {
122        vec![Some(0), None, Some((len - 1) as u64)]
123    } else if len >= 2 {
124        vec![Some(0), None]
125    } else {
126        vec![None]
127    };
128
129    let indices = PrimitiveArray::from_option_iter(indices_vec.clone());
130    let result = take(array, indices.as_ref()).vortex_unwrap();
131
132    assert_eq!(result.len(), indices_vec.len());
133    assert_eq!(
134        result.dtype(),
135        &array.dtype().with_nullability(Nullability::Nullable)
136    );
137
138    // Verify values
139    for (i, idx_opt) in indices_vec.iter().enumerate() {
140        match idx_opt {
141            Some(idx) => {
142                let expected = array.scalar_at(*idx as usize);
143                let actual = result.scalar_at(i);
144                assert_eq!(expected, actual);
145            }
146            None => {
147                assert!(result.scalar_at(i).is_null());
148            }
149        }
150    }
151}
152
153fn test_take_repeated_indices(array: &dyn Array) {
154    if array.is_empty() {
155        return;
156    }
157
158    // Take the first element multiple times
159    let indices = buffer![0u64, 0, 0].into_array();
160    let result = take(array, indices.as_ref()).vortex_unwrap();
161
162    assert_eq!(result.len(), 3);
163    let first_elem = array.scalar_at(0);
164    for i in 0..3 {
165        assert_eq!(result.scalar_at(i), first_elem);
166    }
167}
168
169fn test_empty_indices(array: &dyn Array) {
170    let indices = PrimitiveArray::empty::<u64>(Nullability::NonNullable);
171    let result = take(array, indices.as_ref()).vortex_unwrap();
172
173    assert_eq!(result.len(), 0);
174    assert_eq!(result.dtype(), array.dtype());
175}
176
177fn test_take_reverse(array: &dyn Array) {
178    let len = array.len();
179    // Take elements in reverse order
180    let indices = PrimitiveArray::from_iter((0..len as u64).rev());
181    let result = take(array, indices.as_ref()).vortex_unwrap();
182
183    assert_eq!(result.len(), len);
184
185    // Verify elements are in reverse order
186    for i in 0..len {
187        assert_eq!(array.scalar_at(len - 1 - i), result.scalar_at(i));
188    }
189}
190
191fn test_take_single_middle(array: &dyn Array) {
192    let len = array.len();
193    let middle_idx = len / 2;
194
195    let indices = PrimitiveArray::from_iter([middle_idx as u64]);
196    let result = take(array, indices.as_ref()).vortex_unwrap();
197
198    assert_eq!(result.len(), 1);
199    assert_eq!(array.scalar_at(middle_idx), result.scalar_at(0));
200}
201
202#[allow(clippy::cast_possible_truncation)]
203fn test_take_random_unsorted(array: &dyn Array) {
204    let len = array.len();
205
206    // Create a pseudo-random but deterministic pattern
207    let mut indices = Vec::new();
208    let mut idx = 1u64;
209    for _ in 0..len.min(10) {
210        indices.push((idx * 7 + 3) % len as u64);
211        idx = (idx * 3 + 1) % len as u64;
212    }
213
214    let indices_array = PrimitiveArray::from_iter(indices.clone());
215    let result = take(array, indices_array.as_ref()).vortex_unwrap();
216
217    assert_eq!(result.len(), indices.len());
218
219    // Verify elements match
220    for (i, &idx) in indices.iter().enumerate() {
221        assert_eq!(array.scalar_at(idx as usize), result.scalar_at(i));
222    }
223}
224
225fn test_take_contiguous_range(array: &dyn Array) {
226    let len = array.len();
227    let start = len / 4;
228    let end = len / 2;
229
230    // Take a contiguous range from the middle
231    let indices = PrimitiveArray::from_iter(start as u64..end as u64);
232    let result = take(array, indices.as_ref()).vortex_unwrap();
233
234    assert_eq!(result.len(), end - start);
235
236    // Verify elements
237    for i in 0..(end - start) {
238        assert_eq!(array.scalar_at(start + i), result.scalar_at(i));
239    }
240}
241
242#[allow(clippy::cast_possible_truncation)]
243fn test_take_mixed_repeated(array: &dyn Array) {
244    let len = array.len();
245
246    // Create pattern with some repeated indices
247    let indices = vec![
248        0u64,
249        0,
250        1,
251        1,
252        len as u64 / 2,
253        len as u64 / 2,
254        len as u64 / 2,
255        (len - 1) as u64,
256    ];
257
258    let indices_array = PrimitiveArray::from_iter(indices.clone());
259    let result = take(array, indices_array.as_ref()).vortex_unwrap();
260
261    assert_eq!(result.len(), indices.len());
262
263    // Verify elements
264    for (i, &idx) in indices.iter().enumerate() {
265        assert_eq!(array.scalar_at(idx as usize), result.scalar_at(i));
266    }
267}
268
269#[allow(clippy::cast_possible_truncation)]
270fn test_take_large_indices(array: &dyn Array) {
271    // Test with a large number of indices to stress test performance
272    let len = array.len();
273    let num_indices = 10000.min(len * 3);
274
275    // Create many indices with a pattern
276    let indices: Vec<u64> = (0..num_indices)
277        .map(|i| ((i * 17 + 5) % len) as u64)
278        .collect();
279
280    let indices_array = PrimitiveArray::from_iter(indices.clone());
281    let result = take(array, indices_array.as_ref()).vortex_unwrap();
282
283    assert_eq!(result.len(), num_indices);
284
285    // Spot check a few elements
286    for i in (0..num_indices).step_by(1000) {
287        let expected_idx = indices[i] as usize;
288        assert_eq!(array.scalar_at(expected_idx), result.scalar_at(i));
289    }
290}