vortex_array/compute/conformance/
take.rs1use vortex_buffer::buffer;
5use vortex_dtype::Nullability;
6use vortex_error::VortexExpect;
7
8use crate::Array;
9use crate::Canonical;
10use crate::IntoArray as _;
11use crate::arrays::PrimitiveArray;
12use crate::compute::take;
13
14pub fn test_take_conformance(array: &dyn Array) {
24 let len = array.len();
25
26 if len > 0 {
27 test_take_all(array);
28 test_take_none(array);
29 test_take_selective(array);
30 test_take_first_and_last(array);
31 test_take_with_nullable_indices(array);
32 test_take_repeated_indices(array);
33 }
34
35 test_empty_indices(array);
36
37 if len > 0 {
39 test_take_reverse(array);
40 test_take_single_middle(array);
41 }
42
43 if len > 3 {
44 test_take_random_unsorted(array);
45 test_take_contiguous_range(array);
46 test_take_mixed_repeated(array);
47 }
48
49 if len >= 1024 {
51 test_take_large_indices(array);
52 }
53}
54
55fn test_take_all(array: &dyn Array) {
56 let len = array.len();
57 let indices = PrimitiveArray::from_iter(0..len as u64);
58 let result =
59 take(array, indices.as_ref()).vortex_expect("take should succeed in conformance test");
60
61 assert_eq!(result.len(), len);
62 assert_eq!(result.dtype(), array.dtype());
63
64 match (&array.to_canonical(), &result.to_canonical()) {
66 (Canonical::Primitive(orig_prim), Canonical::Primitive(result_prim)) => {
67 assert_eq!(orig_prim.byte_buffer(), result_prim.byte_buffer());
68 }
69 _ => {
70 for i in 0..len {
72 assert_eq!(array.scalar_at(i), result.scalar_at(i));
73 }
74 }
75 }
76}
77
78fn test_take_none(array: &dyn Array) {
79 let indices: PrimitiveArray = PrimitiveArray::from_iter::<[u64; 0]>([]);
80 let result =
81 take(array, indices.as_ref()).vortex_expect("take should succeed in conformance test");
82
83 assert_eq!(result.len(), 0);
84 assert_eq!(result.dtype(), array.dtype());
85}
86
87#[allow(clippy::cast_possible_truncation)]
88fn test_take_selective(array: &dyn Array) {
89 let len = array.len();
90
91 let indices: Vec<u64> = (0..len as u64).step_by(2).collect();
93 let expected_len = indices.len();
94 let indices_array = PrimitiveArray::from_iter(indices.clone());
95
96 let result = take(array, indices_array.as_ref())
97 .vortex_expect("take should succeed in conformance test");
98 assert_eq!(result.len(), expected_len);
99
100 for (result_idx, &original_idx) in indices.iter().enumerate() {
102 assert_eq!(
103 array.scalar_at(original_idx as usize),
104 result.scalar_at(result_idx)
105 );
106 }
107}
108
109fn test_take_first_and_last(array: &dyn Array) {
110 let len = array.len();
111 let indices = PrimitiveArray::from_iter([0u64, (len - 1) as u64]);
112 let result =
113 take(array, indices.as_ref()).vortex_expect("take should succeed in conformance test");
114
115 assert_eq!(result.len(), 2);
116 assert_eq!(array.scalar_at(0), result.scalar_at(0));
117 assert_eq!(array.scalar_at(len - 1), result.scalar_at(1));
118}
119
120#[allow(clippy::cast_possible_truncation)]
121fn test_take_with_nullable_indices(array: &dyn Array) {
122 let len = array.len();
123
124 let indices_vec: Vec<Option<u64>> = if len >= 3 {
126 vec![Some(0), None, Some((len - 1) as u64)]
127 } else if len >= 2 {
128 vec![Some(0), None]
129 } else {
130 vec![None]
131 };
132
133 let indices = PrimitiveArray::from_option_iter(indices_vec.clone());
134 let result =
135 take(array, indices.as_ref()).vortex_expect("take should succeed in conformance test");
136
137 assert_eq!(result.len(), indices_vec.len());
138 assert_eq!(
139 result.dtype(),
140 &array.dtype().with_nullability(Nullability::Nullable)
141 );
142
143 for (i, idx_opt) in indices_vec.iter().enumerate() {
145 match idx_opt {
146 Some(idx) => {
147 let expected = array.scalar_at(*idx as usize);
148 let actual = result.scalar_at(i);
149 assert_eq!(expected, actual);
150 }
151 None => {
152 assert!(result.scalar_at(i).is_null());
153 }
154 }
155 }
156}
157
158fn test_take_repeated_indices(array: &dyn Array) {
159 if array.is_empty() {
160 return;
161 }
162
163 let indices = buffer![0u64, 0, 0].into_array();
165 let result =
166 take(array, indices.as_ref()).vortex_expect("take should succeed in conformance test");
167
168 assert_eq!(result.len(), 3);
169 let first_elem = array.scalar_at(0);
170 for i in 0..3 {
171 assert_eq!(result.scalar_at(i), first_elem);
172 }
173}
174
175fn test_empty_indices(array: &dyn Array) {
176 let indices = PrimitiveArray::empty::<u64>(Nullability::NonNullable);
177 let result =
178 take(array, indices.as_ref()).vortex_expect("take should succeed in conformance test");
179
180 assert_eq!(result.len(), 0);
181 assert_eq!(result.dtype(), array.dtype());
182}
183
184fn test_take_reverse(array: &dyn Array) {
185 let len = array.len();
186 let indices = PrimitiveArray::from_iter((0..len as u64).rev());
188 let result =
189 take(array, indices.as_ref()).vortex_expect("take should succeed in conformance test");
190
191 assert_eq!(result.len(), len);
192
193 for i in 0..len {
195 assert_eq!(array.scalar_at(len - 1 - i), result.scalar_at(i));
196 }
197}
198
199fn test_take_single_middle(array: &dyn Array) {
200 let len = array.len();
201 let middle_idx = len / 2;
202
203 let indices = PrimitiveArray::from_iter([middle_idx as u64]);
204 let result =
205 take(array, indices.as_ref()).vortex_expect("take should succeed in conformance test");
206
207 assert_eq!(result.len(), 1);
208 assert_eq!(array.scalar_at(middle_idx), result.scalar_at(0));
209}
210
211#[allow(clippy::cast_possible_truncation)]
212fn test_take_random_unsorted(array: &dyn Array) {
213 let len = array.len();
214
215 let mut indices = Vec::new();
217 let mut idx = 1u64;
218 for _ in 0..len.min(10) {
219 indices.push((idx * 7 + 3) % len as u64);
220 idx = (idx * 3 + 1) % len as u64;
221 }
222
223 let indices_array = PrimitiveArray::from_iter(indices.clone());
224 let result = take(array, indices_array.as_ref())
225 .vortex_expect("take should succeed in conformance test");
226
227 assert_eq!(result.len(), indices.len());
228
229 for (i, &idx) in indices.iter().enumerate() {
231 assert_eq!(array.scalar_at(idx as usize), result.scalar_at(i));
232 }
233}
234
235fn test_take_contiguous_range(array: &dyn Array) {
236 let len = array.len();
237 let start = len / 4;
238 let end = len / 2;
239
240 let indices = PrimitiveArray::from_iter(start as u64..end as u64);
242 let result =
243 take(array, indices.as_ref()).vortex_expect("take should succeed in conformance test");
244
245 assert_eq!(result.len(), end - start);
246
247 for i in 0..(end - start) {
249 assert_eq!(array.scalar_at(start + i), result.scalar_at(i));
250 }
251}
252
253#[allow(clippy::cast_possible_truncation)]
254fn test_take_mixed_repeated(array: &dyn Array) {
255 let len = array.len();
256
257 let indices = vec![
259 0u64,
260 0,
261 1,
262 1,
263 len as u64 / 2,
264 len as u64 / 2,
265 len as u64 / 2,
266 (len - 1) as u64,
267 ];
268
269 let indices_array = PrimitiveArray::from_iter(indices.clone());
270 let result = take(array, indices_array.as_ref())
271 .vortex_expect("take should succeed in conformance test");
272
273 assert_eq!(result.len(), indices.len());
274
275 for (i, &idx) in indices.iter().enumerate() {
277 assert_eq!(array.scalar_at(idx as usize), result.scalar_at(i));
278 }
279}
280
281#[allow(clippy::cast_possible_truncation)]
282fn test_take_large_indices(array: &dyn Array) {
283 let len = array.len();
285 let num_indices = 10000.min(len * 3);
286
287 let indices: Vec<u64> = (0..num_indices)
289 .map(|i| ((i * 17 + 5) % len) as u64)
290 .collect();
291
292 let indices_array = PrimitiveArray::from_iter(indices.clone());
293 let result = take(array, indices_array.as_ref())
294 .vortex_expect("take should succeed in conformance test");
295
296 assert_eq!(result.len(), num_indices);
297
298 for i in (0..num_indices).step_by(1000) {
300 let expected_idx = indices[i] as usize;
301 assert_eq!(array.scalar_at(expected_idx), result.scalar_at(i));
302 }
303}