1mod cast;
5mod filter;
6mod mask;
7
8use itertools::Itertools;
9use vortex_dtype::Nullability::NonNullable;
10use vortex_error::VortexResult;
11use vortex_scalar::Scalar;
12
13use crate::arrays::StructVTable;
14use crate::arrays::struct_::StructArray;
15use crate::compute::{
16 IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts, MinMaxKernel, MinMaxKernelAdapter,
17 MinMaxResult, TakeKernel, TakeKernelAdapter, fill_null, is_constant_opts, take,
18};
19use crate::validity::Validity;
20use crate::vtable::ValidityHelper;
21use crate::{Array, ArrayRef, IntoArray, register_kernel};
22
23impl TakeKernel for StructVTable {
24 fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
25 if array.is_empty() {
28 return StructArray::try_new_with_dtype(
29 array.fields().to_vec(),
30 array.struct_fields().clone(),
31 indices.len(),
32 Validity::AllInvalid,
33 )
34 .map(StructArray::into_array);
35 }
36 let inner_indices = &fill_null(
38 indices,
39 &Scalar::default_value(indices.dtype().with_nullability(NonNullable)),
40 )?;
41 StructArray::try_new_with_dtype(
42 array
43 .fields()
44 .iter()
45 .map(|field| take(field, inner_indices))
46 .try_collect()?,
47 array.struct_fields().clone(),
48 indices.len(),
49 array.validity().take(indices)?,
50 )
51 .map(|a| a.into_array())
52 }
53}
54
55register_kernel!(TakeKernelAdapter(StructVTable).lift());
56
57impl MinMaxKernel for StructVTable {
58 fn min_max(&self, _array: &StructArray) -> VortexResult<Option<MinMaxResult>> {
59 Ok(None)
61 }
62}
63
64register_kernel!(MinMaxKernelAdapter(StructVTable).lift());
65
66impl IsConstantKernel for StructVTable {
67 fn is_constant(
68 &self,
69 array: &StructArray,
70 opts: &IsConstantOpts,
71 ) -> VortexResult<Option<bool>> {
72 let children = array.children();
73 if children.is_empty() {
74 return Ok(Some(true));
75 }
76
77 for child in children.iter() {
78 match is_constant_opts(child, opts)? {
79 None => return Ok(None),
81 Some(false) => return Ok(Some(false)),
82 Some(true) => {}
83 }
84 }
85
86 Ok(Some(true))
87 }
88}
89
90register_kernel!(IsConstantKernelAdapter(StructVTable).lift());
91
#[cfg(test)]
mod tests {
    use Nullability::{NonNullable, Nullable};
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_error::VortexUnwrap;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::arrays::{BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinArray};
    use crate::compute::conformance::consistency::test_array_consistency;
    use crate::compute::conformance::filter::test_filter_conformance;
    use crate::compute::conformance::mask::test_mask_conformance;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::{cast, filter, is_constant, take};
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _};

    /// Filtering a fieldless 10-row struct keeps one row per `true` in the mask (5 here).
    #[test]
    fn filter_empty_struct() {
        let struct_arr =
            StructArray::try_new(FieldNames::empty(), vec![], 10, Validity::NonNullable).unwrap();
        let mask = vec![
            false, true, false, true, false, true, false, true, false, true,
        ];
        let filtered = filter(struct_arr.as_ref(), &Mask::from_iter(mask)).unwrap();
        assert_eq!(filtered.len(), 5);
    }

    /// Taking from a fieldless struct with a nullable index array: the valid index yields a
    /// valid empty struct scalar, the null index yields a null scalar, and both carry a
    /// `Nullable` struct dtype.
    #[test]
    fn take_empty_struct() {
        let struct_arr =
            StructArray::try_new(FieldNames::empty(), vec![], 10, Validity::NonNullable).unwrap();
        let indices = PrimitiveArray::from_option_iter([Some(1), None]);
        let taken = take(struct_arr.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(taken.len(), 2);

        assert_eq!(
            taken.scalar_at(0),
            Scalar::struct_(
                DType::Struct(StructFields::new(FieldNames::default(), vec![]), Nullable),
                vec![]
            )
        );
        assert_eq!(
            taken.scalar_at(1),
            Scalar::null(DType::Struct(
                StructFields::new(FieldNames::default(), vec![]),
                Nullable
            ))
        );
    }

    /// Taking from a single-field struct: index 1 selects the field value `1`, and the null
    /// index produces a null struct row. The result dtype is the input dtype made nullable.
    #[test]
    fn take_field_struct() {
        let struct_arr = StructArray::from_fields(&[("a", buffer![0..10].into_array())]).unwrap();
        let indices = PrimitiveArray::from_option_iter([Some(1), None]);
        let taken = take(struct_arr.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(taken.len(), 2);

        assert_eq!(
            taken.scalar_at(0),
            Scalar::struct_(
                struct_arr.dtype().union_nullability(Nullable),
                vec![Scalar::primitive(1, NonNullable)],
            )
        );
        assert_eq!(
            taken.scalar_at(1),
            Scalar::null(struct_arr.dtype().union_nullability(Nullable),)
        );
    }

    /// Filtering a zero-length fieldless struct with an empty mask yields an empty result.
    #[test]
    fn filter_empty_struct_with_empty_filter() {
        let struct_arr =
            StructArray::try_new(FieldNames::empty(), vec![], 0, Validity::NonNullable).unwrap();
        let filtered = filter(struct_arr.as_ref(), &Mask::from_iter::<[bool; 0]>([])).unwrap();
        assert_eq!(filtered.len(), 0);
    }

    /// Runs the shared mask conformance suite on a fieldless 5-row struct.
    #[test]
    fn test_mask_empty_struct() {
        test_mask_conformance(
            StructArray::try_new(FieldNames::empty(), vec![], 5, Validity::NonNullable)
                .unwrap()
                .as_ref(),
        );
    }

    /// Runs the shared mask conformance suite on a struct containing:
    /// - a nested struct (`xs`);
    /// - a nullable UTF-8 field (`ys`);
    /// - a nullable bool field (`zs`).
    #[test]
    fn test_mask_complex_struct() {
        let xs = buffer![0i64, 1, 2, 3, 4].into_array();
        let ys = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullable),
        )
        .into_array();
        let zs =
            BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array();

        test_mask_conformance(
            StructArray::try_new(
                ["xs", "ys", "zs"].into(),
                vec![
                    StructArray::try_new(
                        ["left", "right"].into(),
                        vec![xs.clone(), xs],
                        5,
                        Validity::NonNullable,
                    )
                    .unwrap()
                    .into_array(),
                    ys,
                    zs,
                ],
                5,
                Validity::NonNullable,
            )
            .unwrap()
            .as_ref(),
        );
    }

    /// Runs the shared filter conformance suite on a fieldless 5-row struct.
    #[test]
    fn test_filter_empty_struct() {
        test_filter_conformance(
            StructArray::try_new(FieldNames::empty(), vec![], 5, Validity::NonNullable)
                .unwrap()
                .as_ref(),
        );
    }

    /// Runs the shared filter conformance suite on the same nested/nullable struct shape used
    /// by `test_mask_complex_struct`.
    #[test]
    fn test_filter_complex_struct() {
        let xs = buffer![0i64, 1, 2, 3, 4].into_array();
        let ys = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullable),
        )
        .into_array();
        let zs =
            BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array();

        test_filter_conformance(
            StructArray::try_new(
                ["xs", "ys", "zs"].into(),
                vec![
                    StructArray::try_new(
                        ["left", "right"].into(),
                        vec![xs.clone(), xs],
                        5,
                        Validity::NonNullable,
                    )
                    .unwrap()
                    .into_array(),
                    ys,
                    zs,
                ],
                5,
                Validity::NonNullable,
            )
            .unwrap()
            .as_ref(),
        );
    }

    /// A fieldless non-nullable struct can be cast to both the non-nullable and the nullable
    /// fieldless struct dtype, and the result carries exactly the requested dtype.
    #[test]
    fn test_cast_empty_struct() {
        let array = StructArray::try_new(FieldNames::default(), vec![], 5, Validity::NonNullable)
            .unwrap()
            .into_array();
        let non_nullable_dtype = DType::Struct(
            StructFields::new(FieldNames::default(), vec![]),
            NonNullable,
        );
        let casted = cast(&array, &non_nullable_dtype).unwrap();
        assert_eq!(casted.dtype(), &non_nullable_dtype);

        let nullable_dtype =
            DType::Struct(StructFields::new(FieldNames::default(), vec![]), Nullable);
        let casted = cast(&array, &nullable_dtype).unwrap();
        assert_eq!(casted.dtype(), &nullable_dtype);
    }

    /// Casting to a struct dtype with the same field types but a different field order must
    /// fail. The test checks that the error message contains the text below verbatim.
    #[test]
    fn test_cast_cannot_change_name_order() {
        let array = StructArray::try_new(
            ["xs", "ys", "zs"].into(),
            vec![
                buffer![1u8].into_array(),
                buffer![1u8].into_array(),
                buffer![1u8].into_array(),
            ],
            1,
            Validity::NonNullable,
        )
        .unwrap();

        let tu8 = DType::Primitive(PType::U8, NonNullable);

        let result = cast(
            array.as_ref(),
            &DType::Struct(
                StructFields::new(
                    FieldNames::from(["ys", "xs", "zs"]),
                    vec![tu8.clone(), tu8.clone(), tu8],
                ),
                NonNullable,
            ),
        );
        assert!(
            result.as_ref().is_err_and(|err| {
                err.to_string()
                    .contains("cannot cast {xs=u8, ys=u8, zs=u8} to {ys=u8, xs=u8, zs=u8}")
            }),
            "{result:?}"
        );
    }

    /// Starting from a nested struct whose levels are all nullable (`AllValid`) and contain no
    /// null values, checks that each of these casts succeeds and yields exactly the target
    /// dtype:
    /// 1. the top level made non-nullable;
    /// 2. only `xs.left` made non-nullable;
    /// 3. only the `xs` struct made non-nullable.
    #[test]
    fn test_cast_complex_struct() {
        let xs = PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]);
        let ys = VarBinArray::from_vec(vec!["a", "b", "c", "d", "e"], DType::Utf8(Nullable));
        let zs = BoolArray::from_bool_buffer(
            BooleanBuffer::from_iter([true, true, false, false, true]),
            Validity::AllValid,
        );
        let fully_nullable_array = StructArray::try_new(
            ["xs", "ys", "zs"].into(),
            vec![
                StructArray::try_new(
                    ["left", "right"].into(),
                    vec![xs.to_array(), xs.to_array()],
                    5,
                    Validity::AllValid,
                )
                .unwrap()
                .into_array(),
                ys.into_array(),
                zs.into_array(),
            ],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        let top_level_non_nullable = fully_nullable_array.dtype().as_nonnullable();
        let casted = cast(&fully_nullable_array, &top_level_non_nullable).unwrap();
        assert_eq!(casted.dtype(), &top_level_non_nullable);

        let non_null_xs_right = DType::Struct(
            StructFields::new(
                ["xs", "ys", "zs"].into(),
                vec![
                    DType::Struct(
                        StructFields::new(
                            ["left", "right"].into(),
                            vec![
                                DType::Primitive(PType::I64, NonNullable),
                                DType::Primitive(PType::I64, Nullable),
                            ],
                        ),
                        Nullable,
                    ),
                    DType::Utf8(Nullable),
                    DType::Bool(Nullable),
                ],
            ),
            Nullable,
        );
        let casted = cast(&fully_nullable_array, &non_null_xs_right).unwrap();
        assert_eq!(casted.dtype(), &non_null_xs_right);

        let non_null_xs = DType::Struct(
            StructFields::new(
                ["xs", "ys", "zs"].into(),
                vec![
                    DType::Struct(
                        StructFields::new(
                            ["left", "right"].into(),
                            vec![
                                DType::Primitive(PType::I64, Nullable),
                                DType::Primitive(PType::I64, Nullable),
                            ],
                        ),
                        NonNullable,
                    ),
                    DType::Utf8(Nullable),
                    DType::Bool(Nullable),
                ],
            ),
            Nullable,
        );
        let casted = cast(&fully_nullable_array, &non_null_xs).unwrap();
        assert_eq!(casted.dtype(), &non_null_xs);
    }

    /// A fieldless struct of length 2 is reported as constant (`Some(true)`).
    #[test]
    fn test_empty_struct_is_constant() {
        let array = StructArray::new_fieldless_with_len(2);
        let is_constant = is_constant(array.as_ref()).vortex_unwrap();
        assert_eq!(is_constant, Some(true));
    }

    /// Runs the shared take conformance suite on a fieldless 5-row struct.
    #[test]
    fn test_take_empty_struct_conformance() {
        test_take_conformance(
            StructArray::try_new(FieldNames::empty(), vec![], 5, Validity::NonNullable)
                .unwrap()
                .as_ref(),
        );
    }

    /// Runs the take conformance suite on a 5-row struct with a non-nullable i64 field and a
    /// non-nullable UTF-8 field.
    #[test]
    fn test_take_simple_struct_conformance() {
        let xs = buffer![1i64, 2, 3, 4, 5].into_array();
        let ys = VarBinArray::from_iter(
            ["a", "b", "c", "d", "e"].map(Some),
            DType::Utf8(NonNullable),
        )
        .into_array();

        test_take_conformance(
            StructArray::try_new(["xs", "ys"].into(), vec![xs, ys], 5, Validity::NonNullable)
                .unwrap()
                .as_ref(),
        );
    }

    /// Runs the take conformance suite on a non-nullable struct whose fields (i32 and UTF-8)
    /// contain nulls.
    #[test]
    fn test_take_nullable_struct_conformance() {
        let xs = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]);
        let ys = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullable),
        );

        test_take_conformance(
            StructArray::try_new(
                ["xs", "ys"].into(),
                vec![xs.into_array(), ys.into_array()],
                5,
                Validity::NonNullable,
            )
            .unwrap()
            .as_ref(),
        );
    }

    /// Runs the take conformance suite on a struct that nests another two-field struct
    /// (`inner`) next to a bool field (`z`).
    #[test]
    fn test_take_nested_struct_conformance() {
        let inner_xs = buffer![10i32, 20, 30, 40, 50].into_array();
        let inner_ys = buffer![100i32, 200, 300, 400, 500].into_array();
        let inner_struct = StructArray::try_new(
            ["x", "y"].into(),
            vec![inner_xs, inner_ys],
            5,
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        let outer_zs = BoolArray::from_iter([true, false, true, false, true]).into_array();

        test_take_conformance(
            StructArray::try_new(
                ["inner", "z"].into(),
                vec![inner_struct, outer_zs],
                5,
                Validity::NonNullable,
            )
            .unwrap()
            .as_ref(),
        );
    }

    /// Runs the take conformance suite on a one-row struct.
    #[test]
    fn test_take_single_element_struct_conformance() {
        let xs = buffer![42i64].into_array();
        let ys = VarBinArray::from_iter(["hello"].map(Some), DType::Utf8(NonNullable)).into_array();

        test_take_conformance(
            StructArray::try_new(["xs", "ys"].into(), vec![xs, ys], 1, Validity::NonNullable)
                .unwrap()
                .as_ref(),
        );
    }

    /// Runs the take conformance suite on a 100-row struct with i64, UTF-8 and bool fields.
    #[test]
    fn test_take_large_struct_conformance() {
        let xs = buffer![0i64..100].into_array();
        let ys = VarBinArray::from_iter(
            (0..100).map(|i| format!("str_{i}")).map(Some),
            DType::Utf8(NonNullable),
        )
        .into_array();
        let zs = BoolArray::from_iter((0..100).map(|i| i % 2 == 0)).into_array();

        test_take_conformance(
            StructArray::try_new(
                ["xs", "ys", "zs"].into(),
                vec![xs, ys, zs],
                100,
                Validity::NonNullable,
            )
            .unwrap()
            .as_ref(),
        );
    }

    /// Runs the shared cross-operation consistency suite over several struct shapes:
    /// - simple non-nullable fields;
    /// - fields containing nulls;
    /// - no fields;
    /// - a single row;
    /// - 100 rows.
    #[rstest]
    // Non-nullable i32 and UTF-8 fields.
    #[case::struct_simple({
        let xs = buffer![1i32, 2, 3, 4, 5].into_array();
        let ys = VarBinArray::from_iter(
            ["a", "b", "c", "d", "e"].map(Some),
            DType::Utf8(NonNullable),
        );
        StructArray::try_new(
            ["xs", "ys"].into(),
            vec![xs.into_array(), ys.into_array()],
            5,
            Validity::NonNullable,
        )
        .unwrap()
    })]
    // Fields containing nulls inside a non-nullable struct.
    #[case::struct_nullable({
        let xs = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]);
        let ys = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullable),
        );
        StructArray::try_new(
            ["xs", "ys"].into(),
            vec![xs.into_array(), ys.into_array()],
            5,
            Validity::NonNullable,
        )
        .unwrap()
    })]
    // Fieldless struct with 5 rows.
    #[case::empty_struct(StructArray::try_new(FieldNames::empty(), vec![], 5, Validity::NonNullable).unwrap())]
    // One field, one row.
    #[case::single_field({
        let xs = buffer![42i64].into_array();
        StructArray::try_new(["xs"].into(), vec![xs], 1, Validity::NonNullable).unwrap()
    })]
    // 100 rows of i64 and UTF-8 values.
    #[case::large_struct({
        let xs = buffer![0..100i64].into_array();
        let ys = VarBinArray::from_iter(
            (0..100).map(|i| format!("value_{i}")).map(Some),
            DType::Utf8(NonNullable),
        ).into_array();
        StructArray::try_new(["xs", "ys"].into(), vec![xs, ys], 100, Validity::NonNullable).unwrap()
    })]
    fn test_struct_consistency(#[case] array: StructArray) {
        test_array_consistency(array.as_ref());
    }
}
556}