// File: vortex_array/compute/conformance/take.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use vortex_buffer::buffer;
5use vortex_dtype::Nullability;
6use vortex_error::VortexUnwrap;
7
8use crate::arrays::PrimitiveArray;
9use crate::compute::take;
10use crate::{Array, Canonical, IntoArray as _};
11
12/// Test conformance of the take compute function for an array.
13///
14/// This function tests various scenarios including:
15/// - Taking all elements
16/// - Taking no elements
17/// - Taking selective elements
18/// - Taking with out-of-bounds indices (should panic)
19/// - Taking with nullable indices
20/// - Edge cases like empty arrays
21pub fn test_take_conformance(array: &dyn Array) {
22    let len = array.len();
23
24    if len > 0 {
25        test_take_all(array);
26        test_take_none(array);
27        test_take_selective(array);
28        test_take_first_and_last(array);
29        test_take_with_nullable_indices(array);
30        test_take_repeated_indices(array);
31    }
32
33    test_empty_indices(array);
34
35    // Additional edge cases for non-empty arrays
36    if len > 0 {
37        test_take_reverse(array);
38        test_take_single_middle(array);
39    }
40
41    if len > 3 {
42        test_take_random_unsorted(array);
43        test_take_contiguous_range(array);
44        test_take_mixed_repeated(array);
45    }
46
47    // Test for larger arrays
48    if len >= 1024 {
49        test_take_large_indices(array);
50    }
51}
52
53fn test_take_all(array: &dyn Array) {
54    let len = array.len();
55    let indices = PrimitiveArray::from_iter(0..len as u64);
56    let result = take(array, indices.as_ref()).vortex_unwrap();
57
58    assert_eq!(result.len(), len);
59    assert_eq!(result.dtype(), array.dtype());
60
61    // Verify elements match
62    match (&array.to_canonical(), &result.to_canonical()) {
63        (Canonical::Primitive(orig_prim), Canonical::Primitive(result_prim)) => {
64            assert_eq!(orig_prim.byte_buffer(), result_prim.byte_buffer());
65        }
66        _ => {
67            // For non-primitive types, check scalar values
68            for i in 0..len {
69                assert_eq!(array.scalar_at(i), result.scalar_at(i));
70            }
71        }
72    }
73}
74
75fn test_take_none(array: &dyn Array) {
76    let indices: PrimitiveArray = PrimitiveArray::from_iter::<[u64; 0]>([]);
77    let result = take(array, indices.as_ref()).vortex_unwrap();
78
79    assert_eq!(result.len(), 0);
80    assert_eq!(result.dtype(), array.dtype());
81}
82
83#[allow(clippy::cast_possible_truncation)]
84fn test_take_selective(array: &dyn Array) {
85    let len = array.len();
86
87    // Take every other element
88    let indices: Vec<u64> = (0..len as u64).step_by(2).collect();
89    let expected_len = indices.len();
90    let indices_array = PrimitiveArray::from_iter(indices.clone());
91
92    let result = take(array, indices_array.as_ref()).vortex_unwrap();
93    assert_eq!(result.len(), expected_len);
94
95    // Verify the taken elements
96    for (result_idx, &original_idx) in indices.iter().enumerate() {
97        assert_eq!(
98            array.scalar_at(original_idx as usize),
99            result.scalar_at(result_idx)
100        );
101    }
102}
103
104fn test_take_first_and_last(array: &dyn Array) {
105    let len = array.len();
106    let indices = PrimitiveArray::from_iter([0u64, (len - 1) as u64]);
107    let result = take(array, indices.as_ref()).vortex_unwrap();
108
109    assert_eq!(result.len(), 2);
110    assert_eq!(array.scalar_at(0), result.scalar_at(0));
111    assert_eq!(array.scalar_at(len - 1), result.scalar_at(1));
112}
113
114#[allow(clippy::cast_possible_truncation)]
115fn test_take_with_nullable_indices(array: &dyn Array) {
116    let len = array.len();
117
118    // Create indices with some null values
119    let indices_vec: Vec<Option<u64>> = if len >= 3 {
120        vec![Some(0), None, Some((len - 1) as u64)]
121    } else if len >= 2 {
122        vec![Some(0), None]
123    } else {
124        vec![None]
125    };
126
127    let indices = PrimitiveArray::from_option_iter(indices_vec.clone());
128    let result = take(array, indices.as_ref()).vortex_unwrap();
129
130    assert_eq!(result.len(), indices_vec.len());
131    assert_eq!(
132        result.dtype(),
133        &array.dtype().with_nullability(Nullability::Nullable)
134    );
135
136    // Verify values
137    for (i, idx_opt) in indices_vec.iter().enumerate() {
138        match idx_opt {
139            Some(idx) => {
140                let expected = array.scalar_at(*idx as usize);
141                let actual = result.scalar_at(i);
142                assert_eq!(expected, actual);
143            }
144            None => {
145                assert!(result.scalar_at(i).is_null());
146            }
147        }
148    }
149}
150
151fn test_take_repeated_indices(array: &dyn Array) {
152    if array.is_empty() {
153        return;
154    }
155
156    // Take the first element multiple times
157    let indices = buffer![0u64, 0, 0].into_array();
158    let result = take(array, indices.as_ref()).vortex_unwrap();
159
160    assert_eq!(result.len(), 3);
161    let first_elem = array.scalar_at(0);
162    for i in 0..3 {
163        assert_eq!(result.scalar_at(i), first_elem);
164    }
165}
166
167fn test_empty_indices(array: &dyn Array) {
168    let indices = PrimitiveArray::empty::<u64>(Nullability::NonNullable);
169    let result = take(array, indices.as_ref()).vortex_unwrap();
170
171    assert_eq!(result.len(), 0);
172    assert_eq!(result.dtype(), array.dtype());
173}
174
175fn test_take_reverse(array: &dyn Array) {
176    let len = array.len();
177    // Take elements in reverse order
178    let indices = PrimitiveArray::from_iter((0..len as u64).rev());
179    let result = take(array, indices.as_ref()).vortex_unwrap();
180
181    assert_eq!(result.len(), len);
182
183    // Verify elements are in reverse order
184    for i in 0..len {
185        assert_eq!(array.scalar_at(len - 1 - i), result.scalar_at(i));
186    }
187}
188
189fn test_take_single_middle(array: &dyn Array) {
190    let len = array.len();
191    let middle_idx = len / 2;
192
193    let indices = PrimitiveArray::from_iter([middle_idx as u64]);
194    let result = take(array, indices.as_ref()).vortex_unwrap();
195
196    assert_eq!(result.len(), 1);
197    assert_eq!(array.scalar_at(middle_idx), result.scalar_at(0));
198}
199
200#[allow(clippy::cast_possible_truncation)]
201fn test_take_random_unsorted(array: &dyn Array) {
202    let len = array.len();
203
204    // Create a pseudo-random but deterministic pattern
205    let mut indices = Vec::new();
206    let mut idx = 1u64;
207    for _ in 0..len.min(10) {
208        indices.push((idx * 7 + 3) % len as u64);
209        idx = (idx * 3 + 1) % len as u64;
210    }
211
212    let indices_array = PrimitiveArray::from_iter(indices.clone());
213    let result = take(array, indices_array.as_ref()).vortex_unwrap();
214
215    assert_eq!(result.len(), indices.len());
216
217    // Verify elements match
218    for (i, &idx) in indices.iter().enumerate() {
219        assert_eq!(array.scalar_at(idx as usize), result.scalar_at(i));
220    }
221}
222
223fn test_take_contiguous_range(array: &dyn Array) {
224    let len = array.len();
225    let start = len / 4;
226    let end = len / 2;
227
228    // Take a contiguous range from the middle
229    let indices = PrimitiveArray::from_iter(start as u64..end as u64);
230    let result = take(array, indices.as_ref()).vortex_unwrap();
231
232    assert_eq!(result.len(), end - start);
233
234    // Verify elements
235    for i in 0..(end - start) {
236        assert_eq!(array.scalar_at(start + i), result.scalar_at(i));
237    }
238}
239
240#[allow(clippy::cast_possible_truncation)]
241fn test_take_mixed_repeated(array: &dyn Array) {
242    let len = array.len();
243
244    // Create pattern with some repeated indices
245    let indices = vec![
246        0u64,
247        0,
248        1,
249        1,
250        len as u64 / 2,
251        len as u64 / 2,
252        len as u64 / 2,
253        (len - 1) as u64,
254    ];
255
256    let indices_array = PrimitiveArray::from_iter(indices.clone());
257    let result = take(array, indices_array.as_ref()).vortex_unwrap();
258
259    assert_eq!(result.len(), indices.len());
260
261    // Verify elements
262    for (i, &idx) in indices.iter().enumerate() {
263        assert_eq!(array.scalar_at(idx as usize), result.scalar_at(i));
264    }
265}
266
267#[allow(clippy::cast_possible_truncation)]
268fn test_take_large_indices(array: &dyn Array) {
269    // Test with a large number of indices to stress test performance
270    let len = array.len();
271    let num_indices = 10000.min(len * 3);
272
273    // Create many indices with a pattern
274    let indices: Vec<u64> = (0..num_indices)
275        .map(|i| ((i * 17 + 5) % len) as u64)
276        .collect();
277
278    let indices_array = PrimitiveArray::from_iter(indices.clone());
279    let result = take(array, indices_array.as_ref()).vortex_unwrap();
280
281    assert_eq!(result.len(), num_indices);
282
283    // Spot check a few elements
284    for i in (0..num_indices).step_by(1000) {
285        let expected_idx = indices[i] as usize;
286        assert_eq!(array.scalar_at(expected_idx), result.scalar_at(i));
287    }
288}