Skip to main content

vortex_array/compute/conformance/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::buffer;
5use vortex_dtype::Nullability;
6use vortex_error::VortexExpect;
7
8use crate::Array;
9use crate::Canonical;
10use crate::IntoArray as _;
11use crate::arrays::PrimitiveArray;
12
13/// Test conformance of the take compute function for an array.
14///
15/// This function tests various scenarios including:
16/// - Taking all elements
17/// - Taking no elements
18/// - Taking selective elements
19/// - Taking with out-of-bounds indices (should panic)
20/// - Taking with nullable indices
21/// - Edge cases like empty arrays
22pub fn test_take_conformance(array: &dyn Array) {
23    let len = array.len();
24
25    if len > 0 {
26        test_take_all(array);
27        test_take_none(array);
28        test_take_selective(array);
29        test_take_first_and_last(array);
30        test_take_with_nullable_indices(array);
31        test_take_repeated_indices(array);
32    }
33
34    test_empty_indices(array);
35
36    // Additional edge cases for non-empty arrays
37    if len > 0 {
38        test_take_reverse(array);
39        test_take_single_middle(array);
40    }
41
42    if len > 3 {
43        test_take_random_unsorted(array);
44        test_take_contiguous_range(array);
45        test_take_mixed_repeated(array);
46    }
47
48    // Test for larger arrays
49    if len >= 1024 {
50        test_take_large_indices(array);
51    }
52}
53
54fn test_take_all(array: &dyn Array) {
55    let len = array.len();
56    let indices = PrimitiveArray::from_iter(0..len as u64);
57    let result = array
58        .take(indices.to_array())
59        .vortex_expect("take should succeed in conformance test");
60
61    assert_eq!(result.len(), len);
62    assert_eq!(result.dtype(), array.dtype());
63
64    // Verify elements match
65    match (
66        array
67            .to_canonical()
68            .vortex_expect("to_canonical failed on array"),
69        result
70            .to_canonical()
71            .vortex_expect("to_canonical failed on result"),
72    ) {
73        (Canonical::Primitive(orig_prim), Canonical::Primitive(result_prim)) => {
74            assert_eq!(
75                orig_prim.buffer_handle().to_host_sync(),
76                result_prim.buffer_handle().to_host_sync()
77            );
78        }
79        _ => {
80            // For non-primitive types, check scalar values
81            for i in 0..len {
82                assert_eq!(
83                    array
84                        .scalar_at(i)
85                        .vortex_expect("scalar_at should succeed in conformance test"),
86                    result
87                        .scalar_at(i)
88                        .vortex_expect("scalar_at should succeed in conformance test")
89                );
90            }
91        }
92    }
93}
94
95fn test_take_none(array: &dyn Array) {
96    let indices: PrimitiveArray = PrimitiveArray::from_iter::<[u64; 0]>([]);
97    let result = array
98        .take(indices.to_array())
99        .vortex_expect("take should succeed in conformance test");
100
101    assert_eq!(result.len(), 0);
102    assert_eq!(result.dtype(), array.dtype());
103}
104
105#[allow(clippy::cast_possible_truncation)]
106fn test_take_selective(array: &dyn Array) {
107    let len = array.len();
108
109    // Take every other element
110    let indices: Vec<u64> = (0..len as u64).step_by(2).collect();
111    let expected_len = indices.len();
112    let indices_array = PrimitiveArray::from_iter(indices.clone());
113
114    let result = array
115        .take(indices_array.to_array())
116        .vortex_expect("take should succeed in conformance test");
117    assert_eq!(result.len(), expected_len);
118
119    // Verify the taken elements
120    for (result_idx, &original_idx) in indices.iter().enumerate() {
121        assert_eq!(
122            array
123                .scalar_at(original_idx as usize)
124                .vortex_expect("scalar_at should succeed in conformance test"),
125            result
126                .scalar_at(result_idx)
127                .vortex_expect("scalar_at should succeed in conformance test")
128        );
129    }
130}
131
132fn test_take_first_and_last(array: &dyn Array) {
133    let len = array.len();
134    let indices = PrimitiveArray::from_iter([0u64, (len - 1) as u64]);
135    let result = array
136        .take(indices.to_array())
137        .vortex_expect("take should succeed in conformance test");
138
139    assert_eq!(result.len(), 2);
140    assert_eq!(
141        array
142            .scalar_at(0)
143            .vortex_expect("scalar_at should succeed in conformance test"),
144        result
145            .scalar_at(0)
146            .vortex_expect("scalar_at should succeed in conformance test")
147    );
148    assert_eq!(
149        array
150            .scalar_at(len - 1)
151            .vortex_expect("scalar_at should succeed in conformance test"),
152        result
153            .scalar_at(1)
154            .vortex_expect("scalar_at should succeed in conformance test")
155    );
156}
157
158#[allow(clippy::cast_possible_truncation)]
159fn test_take_with_nullable_indices(array: &dyn Array) {
160    let len = array.len();
161
162    // Create indices with some null values
163    let indices_vec: Vec<Option<u64>> = if len >= 3 {
164        vec![Some(0), None, Some((len - 1) as u64)]
165    } else if len >= 2 {
166        vec![Some(0), None]
167    } else {
168        vec![None]
169    };
170
171    let indices = PrimitiveArray::from_option_iter(indices_vec.clone());
172    let result = array
173        .take(indices.to_array())
174        .vortex_expect("take should succeed in conformance test");
175
176    assert_eq!(result.len(), indices_vec.len());
177    assert_eq!(
178        result.dtype(),
179        &array.dtype().with_nullability(Nullability::Nullable)
180    );
181
182    // Verify values
183    for (i, idx_opt) in indices_vec.iter().enumerate() {
184        match idx_opt {
185            Some(idx) => {
186                let expected = array
187                    .scalar_at(*idx as usize)
188                    .vortex_expect("scalar_at should succeed in conformance test");
189                let actual = result
190                    .scalar_at(i)
191                    .vortex_expect("scalar_at should succeed in conformance test");
192                assert_eq!(expected, actual);
193            }
194            None => {
195                assert!(
196                    result
197                        .scalar_at(i)
198                        .vortex_expect("scalar_at should succeed in conformance test")
199                        .is_null()
200                );
201            }
202        }
203    }
204}
205
206fn test_take_repeated_indices(array: &dyn Array) {
207    if array.is_empty() {
208        return;
209    }
210
211    // Take the first element multiple times
212    let indices = buffer![0u64, 0, 0].into_array();
213    let result = array
214        .take(indices.to_array())
215        .vortex_expect("take should succeed in conformance test");
216
217    assert_eq!(result.len(), 3);
218    let first_elem = array
219        .scalar_at(0)
220        .vortex_expect("scalar_at should succeed in conformance test");
221    for i in 0..3 {
222        assert_eq!(
223            result
224                .scalar_at(i)
225                .vortex_expect("scalar_at should succeed in conformance test"),
226            first_elem
227        );
228    }
229}
230
231fn test_empty_indices(array: &dyn Array) {
232    let indices = PrimitiveArray::empty::<u64>(Nullability::NonNullable);
233    let result = array
234        .take(indices.to_array())
235        .vortex_expect("take should succeed in conformance test");
236
237    assert_eq!(result.len(), 0);
238    assert_eq!(result.dtype(), array.dtype());
239}
240
241fn test_take_reverse(array: &dyn Array) {
242    let len = array.len();
243    // Take elements in reverse order
244    let indices = PrimitiveArray::from_iter((0..len as u64).rev());
245    let result = array
246        .take(indices.to_array())
247        .vortex_expect("take should succeed in conformance test");
248
249    assert_eq!(result.len(), len);
250
251    // Verify elements are in reverse order
252    for i in 0..len {
253        assert_eq!(
254            array
255                .scalar_at(len - 1 - i)
256                .vortex_expect("scalar_at should succeed in conformance test"),
257            result
258                .scalar_at(i)
259                .vortex_expect("scalar_at should succeed in conformance test")
260        );
261    }
262}
263
264fn test_take_single_middle(array: &dyn Array) {
265    let len = array.len();
266    let middle_idx = len / 2;
267
268    let indices = PrimitiveArray::from_iter([middle_idx as u64]);
269    let result = array
270        .take(indices.to_array())
271        .vortex_expect("take should succeed in conformance test");
272
273    assert_eq!(result.len(), 1);
274    assert_eq!(
275        array
276            .scalar_at(middle_idx)
277            .vortex_expect("scalar_at should succeed in conformance test"),
278        result
279            .scalar_at(0)
280            .vortex_expect("scalar_at should succeed in conformance test")
281    );
282}
283
284#[allow(clippy::cast_possible_truncation)]
285fn test_take_random_unsorted(array: &dyn Array) {
286    let len = array.len();
287
288    // Create a pseudo-random but deterministic pattern
289    let mut indices = Vec::new();
290    let mut idx = 1u64;
291    for _ in 0..len.min(10) {
292        indices.push((idx * 7 + 3) % len as u64);
293        idx = (idx * 3 + 1) % len as u64;
294    }
295
296    let indices_array = PrimitiveArray::from_iter(indices.clone());
297    let result = array
298        .take(indices_array.to_array())
299        .vortex_expect("take should succeed in conformance test");
300
301    assert_eq!(result.len(), indices.len());
302
303    // Verify elements match
304    for (i, &idx) in indices.iter().enumerate() {
305        assert_eq!(
306            array
307                .scalar_at(idx as usize)
308                .vortex_expect("scalar_at should succeed in conformance test"),
309            result
310                .scalar_at(i)
311                .vortex_expect("scalar_at should succeed in conformance test")
312        );
313    }
314}
315
316fn test_take_contiguous_range(array: &dyn Array) {
317    let len = array.len();
318    let start = len / 4;
319    let end = len / 2;
320
321    // Take a contiguous range from the middle
322    let indices = PrimitiveArray::from_iter(start as u64..end as u64);
323    let result = array
324        .take(indices.to_array())
325        .vortex_expect("take should succeed in conformance test");
326
327    assert_eq!(result.len(), end - start);
328
329    // Verify elements
330    for i in 0..(end - start) {
331        assert_eq!(
332            array
333                .scalar_at(start + i)
334                .vortex_expect("scalar_at should succeed in conformance test"),
335            result
336                .scalar_at(i)
337                .vortex_expect("scalar_at should succeed in conformance test")
338        );
339    }
340}
341
342#[allow(clippy::cast_possible_truncation)]
343fn test_take_mixed_repeated(array: &dyn Array) {
344    let len = array.len();
345
346    // Create pattern with some repeated indices
347    let indices = vec![
348        0u64,
349        0,
350        1,
351        1,
352        len as u64 / 2,
353        len as u64 / 2,
354        len as u64 / 2,
355        (len - 1) as u64,
356    ];
357
358    let indices_array = PrimitiveArray::from_iter(indices.clone());
359    let result = array
360        .take(indices_array.to_array())
361        .vortex_expect("take should succeed in conformance test");
362
363    assert_eq!(result.len(), indices.len());
364
365    // Verify elements
366    for (i, &idx) in indices.iter().enumerate() {
367        assert_eq!(
368            array
369                .scalar_at(idx as usize)
370                .vortex_expect("scalar_at should succeed in conformance test"),
371            result
372                .scalar_at(i)
373                .vortex_expect("scalar_at should succeed in conformance test")
374        );
375    }
376}
377
378#[allow(clippy::cast_possible_truncation)]
379fn test_take_large_indices(array: &dyn Array) {
380    // Test with a large number of indices to stress test performance
381    let len = array.len();
382    let num_indices = 10000.min(len * 3);
383
384    // Create many indices with a pattern
385    let indices: Vec<u64> = (0..num_indices)
386        .map(|i| ((i * 17 + 5) % len) as u64)
387        .collect();
388
389    let indices_array = PrimitiveArray::from_iter(indices.clone());
390    let result = array
391        .take(indices_array.to_array())
392        .vortex_expect("take should succeed in conformance test");
393
394    assert_eq!(result.len(), num_indices);
395
396    // Spot check a few elements
397    for i in (0..num_indices).step_by(1000) {
398        let expected_idx = indices[i] as usize;
399        assert_eq!(
400            array
401                .scalar_at(expected_idx)
402                .vortex_expect("scalar_at should succeed in conformance test"),
403            result
404                .scalar_at(i)
405                .vortex_expect("scalar_at should succeed in conformance test")
406        );
407    }
408}