vortex_array/compute/conformance/
take.rs1use vortex_buffer::buffer;
5use vortex_dtype::Nullability;
6use vortex_error::VortexUnwrap;
7
8use crate::Array;
9use crate::Canonical;
10use crate::IntoArray as _;
11use crate::arrays::PrimitiveArray;
12use crate::compute::take;
13
14pub fn test_take_conformance(array: &dyn Array) {
24 let len = array.len();
25
26 if len > 0 {
27 test_take_all(array);
28 test_take_none(array);
29 test_take_selective(array);
30 test_take_first_and_last(array);
31 test_take_with_nullable_indices(array);
32 test_take_repeated_indices(array);
33 }
34
35 test_empty_indices(array);
36
37 if len > 0 {
39 test_take_reverse(array);
40 test_take_single_middle(array);
41 }
42
43 if len > 3 {
44 test_take_random_unsorted(array);
45 test_take_contiguous_range(array);
46 test_take_mixed_repeated(array);
47 }
48
49 if len >= 1024 {
51 test_take_large_indices(array);
52 }
53}
54
55fn test_take_all(array: &dyn Array) {
56 let len = array.len();
57 let indices = PrimitiveArray::from_iter(0..len as u64);
58 let result = take(array, indices.as_ref()).vortex_unwrap();
59
60 assert_eq!(result.len(), len);
61 assert_eq!(result.dtype(), array.dtype());
62
63 match (&array.to_canonical(), &result.to_canonical()) {
65 (Canonical::Primitive(orig_prim), Canonical::Primitive(result_prim)) => {
66 assert_eq!(orig_prim.byte_buffer(), result_prim.byte_buffer());
67 }
68 _ => {
69 for i in 0..len {
71 assert_eq!(array.scalar_at(i), result.scalar_at(i));
72 }
73 }
74 }
75}
76
77fn test_take_none(array: &dyn Array) {
78 let indices: PrimitiveArray = PrimitiveArray::from_iter::<[u64; 0]>([]);
79 let result = take(array, indices.as_ref()).vortex_unwrap();
80
81 assert_eq!(result.len(), 0);
82 assert_eq!(result.dtype(), array.dtype());
83}
84
85#[allow(clippy::cast_possible_truncation)]
86fn test_take_selective(array: &dyn Array) {
87 let len = array.len();
88
89 let indices: Vec<u64> = (0..len as u64).step_by(2).collect();
91 let expected_len = indices.len();
92 let indices_array = PrimitiveArray::from_iter(indices.clone());
93
94 let result = take(array, indices_array.as_ref()).vortex_unwrap();
95 assert_eq!(result.len(), expected_len);
96
97 for (result_idx, &original_idx) in indices.iter().enumerate() {
99 assert_eq!(
100 array.scalar_at(original_idx as usize),
101 result.scalar_at(result_idx)
102 );
103 }
104}
105
106fn test_take_first_and_last(array: &dyn Array) {
107 let len = array.len();
108 let indices = PrimitiveArray::from_iter([0u64, (len - 1) as u64]);
109 let result = take(array, indices.as_ref()).vortex_unwrap();
110
111 assert_eq!(result.len(), 2);
112 assert_eq!(array.scalar_at(0), result.scalar_at(0));
113 assert_eq!(array.scalar_at(len - 1), result.scalar_at(1));
114}
115
116#[allow(clippy::cast_possible_truncation)]
117fn test_take_with_nullable_indices(array: &dyn Array) {
118 let len = array.len();
119
120 let indices_vec: Vec<Option<u64>> = if len >= 3 {
122 vec![Some(0), None, Some((len - 1) as u64)]
123 } else if len >= 2 {
124 vec![Some(0), None]
125 } else {
126 vec![None]
127 };
128
129 let indices = PrimitiveArray::from_option_iter(indices_vec.clone());
130 let result = take(array, indices.as_ref()).vortex_unwrap();
131
132 assert_eq!(result.len(), indices_vec.len());
133 assert_eq!(
134 result.dtype(),
135 &array.dtype().with_nullability(Nullability::Nullable)
136 );
137
138 for (i, idx_opt) in indices_vec.iter().enumerate() {
140 match idx_opt {
141 Some(idx) => {
142 let expected = array.scalar_at(*idx as usize);
143 let actual = result.scalar_at(i);
144 assert_eq!(expected, actual);
145 }
146 None => {
147 assert!(result.scalar_at(i).is_null());
148 }
149 }
150 }
151}
152
153fn test_take_repeated_indices(array: &dyn Array) {
154 if array.is_empty() {
155 return;
156 }
157
158 let indices = buffer![0u64, 0, 0].into_array();
160 let result = take(array, indices.as_ref()).vortex_unwrap();
161
162 assert_eq!(result.len(), 3);
163 let first_elem = array.scalar_at(0);
164 for i in 0..3 {
165 assert_eq!(result.scalar_at(i), first_elem);
166 }
167}
168
169fn test_empty_indices(array: &dyn Array) {
170 let indices = PrimitiveArray::empty::<u64>(Nullability::NonNullable);
171 let result = take(array, indices.as_ref()).vortex_unwrap();
172
173 assert_eq!(result.len(), 0);
174 assert_eq!(result.dtype(), array.dtype());
175}
176
177fn test_take_reverse(array: &dyn Array) {
178 let len = array.len();
179 let indices = PrimitiveArray::from_iter((0..len as u64).rev());
181 let result = take(array, indices.as_ref()).vortex_unwrap();
182
183 assert_eq!(result.len(), len);
184
185 for i in 0..len {
187 assert_eq!(array.scalar_at(len - 1 - i), result.scalar_at(i));
188 }
189}
190
191fn test_take_single_middle(array: &dyn Array) {
192 let len = array.len();
193 let middle_idx = len / 2;
194
195 let indices = PrimitiveArray::from_iter([middle_idx as u64]);
196 let result = take(array, indices.as_ref()).vortex_unwrap();
197
198 assert_eq!(result.len(), 1);
199 assert_eq!(array.scalar_at(middle_idx), result.scalar_at(0));
200}
201
202#[allow(clippy::cast_possible_truncation)]
203fn test_take_random_unsorted(array: &dyn Array) {
204 let len = array.len();
205
206 let mut indices = Vec::new();
208 let mut idx = 1u64;
209 for _ in 0..len.min(10) {
210 indices.push((idx * 7 + 3) % len as u64);
211 idx = (idx * 3 + 1) % len as u64;
212 }
213
214 let indices_array = PrimitiveArray::from_iter(indices.clone());
215 let result = take(array, indices_array.as_ref()).vortex_unwrap();
216
217 assert_eq!(result.len(), indices.len());
218
219 for (i, &idx) in indices.iter().enumerate() {
221 assert_eq!(array.scalar_at(idx as usize), result.scalar_at(i));
222 }
223}
224
225fn test_take_contiguous_range(array: &dyn Array) {
226 let len = array.len();
227 let start = len / 4;
228 let end = len / 2;
229
230 let indices = PrimitiveArray::from_iter(start as u64..end as u64);
232 let result = take(array, indices.as_ref()).vortex_unwrap();
233
234 assert_eq!(result.len(), end - start);
235
236 for i in 0..(end - start) {
238 assert_eq!(array.scalar_at(start + i), result.scalar_at(i));
239 }
240}
241
242#[allow(clippy::cast_possible_truncation)]
243fn test_take_mixed_repeated(array: &dyn Array) {
244 let len = array.len();
245
246 let indices = vec![
248 0u64,
249 0,
250 1,
251 1,
252 len as u64 / 2,
253 len as u64 / 2,
254 len as u64 / 2,
255 (len - 1) as u64,
256 ];
257
258 let indices_array = PrimitiveArray::from_iter(indices.clone());
259 let result = take(array, indices_array.as_ref()).vortex_unwrap();
260
261 assert_eq!(result.len(), indices.len());
262
263 for (i, &idx) in indices.iter().enumerate() {
265 assert_eq!(array.scalar_at(idx as usize), result.scalar_at(i));
266 }
267}
268
269#[allow(clippy::cast_possible_truncation)]
270fn test_take_large_indices(array: &dyn Array) {
271 let len = array.len();
273 let num_indices = 10000.min(len * 3);
274
275 let indices: Vec<u64> = (0..num_indices)
277 .map(|i| ((i * 17 + 5) % len) as u64)
278 .collect();
279
280 let indices_array = PrimitiveArray::from_iter(indices.clone());
281 let result = take(array, indices_array.as_ref()).vortex_unwrap();
282
283 assert_eq!(result.len(), num_indices);
284
285 for i in (0..num_indices).step_by(1000) {
287 let expected_idx = indices[i] as usize;
288 assert_eq!(array.scalar_at(expected_idx), result.scalar_at(i));
289 }
290}