vortex_array/compute/conformance/
take.rs1use vortex_buffer::buffer;
5use vortex_dtype::Nullability;
6use vortex_error::VortexUnwrap;
7
8use crate::arrays::PrimitiveArray;
9use crate::compute::take;
10use crate::{Array, Canonical, IntoArray as _};
11
12pub fn test_take_conformance(array: &dyn Array) {
22 let len = array.len();
23
24 if len > 0 {
25 test_take_all(array);
26 test_take_none(array);
27 test_take_selective(array);
28 test_take_first_and_last(array);
29 test_take_with_nullable_indices(array);
30 test_take_repeated_indices(array);
31 }
32
33 test_empty_indices(array);
34
35 if len > 0 {
37 test_take_reverse(array);
38 test_take_single_middle(array);
39 }
40
41 if len > 3 {
42 test_take_random_unsorted(array);
43 test_take_contiguous_range(array);
44 test_take_mixed_repeated(array);
45 }
46
47 if len >= 1024 {
49 test_take_large_indices(array);
50 }
51}
52
53fn test_take_all(array: &dyn Array) {
54 let len = array.len();
55 let indices = PrimitiveArray::from_iter(0..len as u64);
56 let result = take(array, indices.as_ref()).vortex_unwrap();
57
58 assert_eq!(result.len(), len);
59 assert_eq!(result.dtype(), array.dtype());
60
61 match (&array.to_canonical(), &result.to_canonical()) {
63 (Canonical::Primitive(orig_prim), Canonical::Primitive(result_prim)) => {
64 assert_eq!(orig_prim.byte_buffer(), result_prim.byte_buffer());
65 }
66 _ => {
67 for i in 0..len {
69 assert_eq!(array.scalar_at(i), result.scalar_at(i));
70 }
71 }
72 }
73}
74
75fn test_take_none(array: &dyn Array) {
76 let indices: PrimitiveArray = PrimitiveArray::from_iter::<[u64; 0]>([]);
77 let result = take(array, indices.as_ref()).vortex_unwrap();
78
79 assert_eq!(result.len(), 0);
80 assert_eq!(result.dtype(), array.dtype());
81}
82
83#[allow(clippy::cast_possible_truncation)]
84fn test_take_selective(array: &dyn Array) {
85 let len = array.len();
86
87 let indices: Vec<u64> = (0..len as u64).step_by(2).collect();
89 let expected_len = indices.len();
90 let indices_array = PrimitiveArray::from_iter(indices.clone());
91
92 let result = take(array, indices_array.as_ref()).vortex_unwrap();
93 assert_eq!(result.len(), expected_len);
94
95 for (result_idx, &original_idx) in indices.iter().enumerate() {
97 assert_eq!(
98 array.scalar_at(original_idx as usize),
99 result.scalar_at(result_idx)
100 );
101 }
102}
103
104fn test_take_first_and_last(array: &dyn Array) {
105 let len = array.len();
106 let indices = PrimitiveArray::from_iter([0u64, (len - 1) as u64]);
107 let result = take(array, indices.as_ref()).vortex_unwrap();
108
109 assert_eq!(result.len(), 2);
110 assert_eq!(array.scalar_at(0), result.scalar_at(0));
111 assert_eq!(array.scalar_at(len - 1), result.scalar_at(1));
112}
113
114#[allow(clippy::cast_possible_truncation)]
115fn test_take_with_nullable_indices(array: &dyn Array) {
116 let len = array.len();
117
118 let indices_vec: Vec<Option<u64>> = if len >= 3 {
120 vec![Some(0), None, Some((len - 1) as u64)]
121 } else if len >= 2 {
122 vec![Some(0), None]
123 } else {
124 vec![None]
125 };
126
127 let indices = PrimitiveArray::from_option_iter(indices_vec.clone());
128 let result = take(array, indices.as_ref()).vortex_unwrap();
129
130 assert_eq!(result.len(), indices_vec.len());
131 assert_eq!(
132 result.dtype(),
133 &array.dtype().with_nullability(Nullability::Nullable)
134 );
135
136 for (i, idx_opt) in indices_vec.iter().enumerate() {
138 match idx_opt {
139 Some(idx) => {
140 let expected = array.scalar_at(*idx as usize);
141 let actual = result.scalar_at(i);
142 assert_eq!(expected, actual);
143 }
144 None => {
145 assert!(result.scalar_at(i).is_null());
146 }
147 }
148 }
149}
150
151fn test_take_repeated_indices(array: &dyn Array) {
152 if array.is_empty() {
153 return;
154 }
155
156 let indices = buffer![0u64, 0, 0].into_array();
158 let result = take(array, indices.as_ref()).vortex_unwrap();
159
160 assert_eq!(result.len(), 3);
161 let first_elem = array.scalar_at(0);
162 for i in 0..3 {
163 assert_eq!(result.scalar_at(i), first_elem);
164 }
165}
166
167fn test_empty_indices(array: &dyn Array) {
168 let indices = PrimitiveArray::empty::<u64>(Nullability::NonNullable);
169 let result = take(array, indices.as_ref()).vortex_unwrap();
170
171 assert_eq!(result.len(), 0);
172 assert_eq!(result.dtype(), array.dtype());
173}
174
175fn test_take_reverse(array: &dyn Array) {
176 let len = array.len();
177 let indices = PrimitiveArray::from_iter((0..len as u64).rev());
179 let result = take(array, indices.as_ref()).vortex_unwrap();
180
181 assert_eq!(result.len(), len);
182
183 for i in 0..len {
185 assert_eq!(array.scalar_at(len - 1 - i), result.scalar_at(i));
186 }
187}
188
189fn test_take_single_middle(array: &dyn Array) {
190 let len = array.len();
191 let middle_idx = len / 2;
192
193 let indices = PrimitiveArray::from_iter([middle_idx as u64]);
194 let result = take(array, indices.as_ref()).vortex_unwrap();
195
196 assert_eq!(result.len(), 1);
197 assert_eq!(array.scalar_at(middle_idx), result.scalar_at(0));
198}
199
200#[allow(clippy::cast_possible_truncation)]
201fn test_take_random_unsorted(array: &dyn Array) {
202 let len = array.len();
203
204 let mut indices = Vec::new();
206 let mut idx = 1u64;
207 for _ in 0..len.min(10) {
208 indices.push((idx * 7 + 3) % len as u64);
209 idx = (idx * 3 + 1) % len as u64;
210 }
211
212 let indices_array = PrimitiveArray::from_iter(indices.clone());
213 let result = take(array, indices_array.as_ref()).vortex_unwrap();
214
215 assert_eq!(result.len(), indices.len());
216
217 for (i, &idx) in indices.iter().enumerate() {
219 assert_eq!(array.scalar_at(idx as usize), result.scalar_at(i));
220 }
221}
222
223fn test_take_contiguous_range(array: &dyn Array) {
224 let len = array.len();
225 let start = len / 4;
226 let end = len / 2;
227
228 let indices = PrimitiveArray::from_iter(start as u64..end as u64);
230 let result = take(array, indices.as_ref()).vortex_unwrap();
231
232 assert_eq!(result.len(), end - start);
233
234 for i in 0..(end - start) {
236 assert_eq!(array.scalar_at(start + i), result.scalar_at(i));
237 }
238}
239
240#[allow(clippy::cast_possible_truncation)]
241fn test_take_mixed_repeated(array: &dyn Array) {
242 let len = array.len();
243
244 let indices = vec![
246 0u64,
247 0,
248 1,
249 1,
250 len as u64 / 2,
251 len as u64 / 2,
252 len as u64 / 2,
253 (len - 1) as u64,
254 ];
255
256 let indices_array = PrimitiveArray::from_iter(indices.clone());
257 let result = take(array, indices_array.as_ref()).vortex_unwrap();
258
259 assert_eq!(result.len(), indices.len());
260
261 for (i, &idx) in indices.iter().enumerate() {
263 assert_eq!(array.scalar_at(idx as usize), result.scalar_at(i));
264 }
265}
266
267#[allow(clippy::cast_possible_truncation)]
268fn test_take_large_indices(array: &dyn Array) {
269 let len = array.len();
271 let num_indices = 10000.min(len * 3);
272
273 let indices: Vec<u64> = (0..num_indices)
275 .map(|i| ((i * 17 + 5) % len) as u64)
276 .collect();
277
278 let indices_array = PrimitiveArray::from_iter(indices.clone());
279 let result = take(array, indices_array.as_ref()).vortex_unwrap();
280
281 assert_eq!(result.len(), num_indices);
282
283 for i in (0..num_indices).step_by(1000) {
285 let expected_idx = indices[i] as usize;
286 assert_eq!(array.scalar_at(expected_idx), result.scalar_at(i));
287 }
288}