1mod cast;
5mod filter;
6mod mask;
7
8use itertools::Itertools;
9use vortex_dtype::Nullability::NonNullable;
10use vortex_error::VortexResult;
11use vortex_scalar::Scalar;
12
13use crate::arrays::StructVTable;
14use crate::arrays::struct_::StructArray;
15use crate::compute::{
16 IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts, MinMaxKernel, MinMaxKernelAdapter,
17 MinMaxResult, TakeKernel, TakeKernelAdapter, fill_null, is_constant_opts, take,
18};
19use crate::validity::Validity;
20use crate::vtable::ValidityHelper;
21use crate::{Array, ArrayRef, IntoArray, register_kernel};
22
23impl TakeKernel for StructVTable {
24 fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
25 if array.is_empty() {
28 return StructArray::try_new_with_dtype(
29 array.fields().to_vec(),
30 array.struct_fields().clone(),
31 indices.len(),
32 Validity::AllInvalid,
33 )
34 .map(StructArray::into_array);
35 }
36 let inner_indices = &fill_null(
38 indices,
39 &Scalar::default_value(indices.dtype().with_nullability(NonNullable)),
40 )?;
41 StructArray::try_new_with_dtype(
42 array
43 .fields()
44 .iter()
45 .map(|field| take(field, inner_indices))
46 .try_collect()?,
47 array.struct_fields().clone(),
48 indices.len(),
49 array.validity().take(indices)?,
50 )
51 .map(|a| a.into_array())
52 }
53}
54
55register_kernel!(TakeKernelAdapter(StructVTable).lift());
56
57impl MinMaxKernel for StructVTable {
58 fn min_max(&self, _array: &StructArray) -> VortexResult<Option<MinMaxResult>> {
59 Ok(None)
61 }
62}
63
64register_kernel!(MinMaxKernelAdapter(StructVTable).lift());
65
66impl IsConstantKernel for StructVTable {
67 fn is_constant(
68 &self,
69 array: &StructArray,
70 opts: &IsConstantOpts,
71 ) -> VortexResult<Option<bool>> {
72 let children = array.children();
73 if children.is_empty() {
74 return Ok(Some(true));
75 }
76
77 for child in children.iter() {
78 match is_constant_opts(child, opts)? {
79 None => return Ok(None),
81 Some(false) => return Ok(Some(false)),
82 Some(true) => {}
83 }
84 }
85
86 Ok(Some(true))
87 }
88}
89
90register_kernel!(IsConstantKernelAdapter(StructVTable).lift());
91
#[cfg(test)]
mod tests {
    use Nullability::{NonNullable, Nullable};
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_error::VortexUnwrap;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::arrays::{BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinArray};
    use crate::compute::conformance::consistency::test_array_consistency;
    use crate::compute::conformance::filter::test_filter_conformance;
    use crate::compute::conformance::mask::test_mask_conformance;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::{cast, filter, is_constant, take};
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _};

    /// Builds a non-nullable struct array that has no fields but `len` rows.
    fn fieldless_struct(len: usize) -> StructArray {
        StructArray::try_new(vec![].into(), vec![], len, Validity::NonNullable).unwrap()
    }

    /// Returns the dtype of a struct without any fields, with the requested nullability.
    fn fieldless_struct_dtype(nullability: Nullability) -> DType {
        DType::Struct(StructFields::new(FieldNames::default(), vec![]), nullability)
    }

    /// Builds a five-row struct `{xs: {left, right}, ys, zs}` where `ys` and `zs` hold nulls.
    fn nested_three_field_struct() -> StructArray {
        let numbers = buffer![0i64, 1, 2, 3, 4].into_array();
        let strings = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullable),
        )
        .into_array();
        let flags =
            BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array();

        let pair = StructArray::try_new(
            ["left", "right"].into(),
            vec![numbers.clone(), numbers],
            5,
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();

        StructArray::try_new(
            ["xs", "ys", "zs"].into(),
            vec![pair, strings, flags],
            5,
            Validity::NonNullable,
        )
        .unwrap()
    }

    /// Builds a five-row, non-nullable struct `{xs, ys}` whose fields both contain nulls.
    fn struct_with_nullable_fields() -> StructArray {
        let numbers = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]);
        let strings = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullable),
        );
        StructArray::try_new(
            ["xs", "ys"].into(),
            vec![numbers.into_array(), strings.into_array()],
            5,
            Validity::NonNullable,
        )
        .unwrap()
    }

    #[test]
    fn filter_empty_struct() {
        let source = fieldless_struct(10);
        // Keep every odd position: false, true, false, true, ...
        let keep_odd = Mask::from_iter((0..10).map(|i| i % 2 == 1));
        let filtered = filter(source.as_ref(), &keep_odd).unwrap();
        assert_eq!(filtered.len(), 5);
    }

    #[test]
    fn take_empty_struct() {
        let source = fieldless_struct(10);
        let indices = PrimitiveArray::from_option_iter([Some(1), None]);
        let taken = take(source.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(taken.len(), 2);

        // A valid index yields an (empty) struct scalar, a null index yields a null scalar.
        let nullable_dtype = fieldless_struct_dtype(Nullable);
        assert_eq!(
            taken.scalar_at(0),
            Scalar::struct_(nullable_dtype.clone(), vec![])
        );
        assert_eq!(taken.scalar_at(1), Scalar::null(nullable_dtype));
    }

    #[test]
    fn take_field_struct() {
        let source =
            StructArray::from_fields(&[("a", PrimitiveArray::from_iter(0..10).to_array())])
                .unwrap();
        let indices = PrimitiveArray::from_option_iter([Some(1), None]);
        let taken = take(source.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(taken.len(), 2);

        let nullable_dtype = source.dtype().union_nullability(Nullable);
        assert_eq!(
            taken.scalar_at(0),
            Scalar::struct_(
                nullable_dtype.clone(),
                vec![Scalar::primitive(1, NonNullable)],
            )
        );
        assert_eq!(taken.scalar_at(1), Scalar::null(nullable_dtype));
    }

    #[test]
    fn filter_empty_struct_with_empty_filter() {
        let source = fieldless_struct(0);
        let nothing = Mask::from_iter::<[bool; 0]>([]);
        let filtered = filter(source.as_ref(), &nothing).unwrap();
        assert_eq!(filtered.len(), 0);
    }

    #[test]
    fn test_mask_empty_struct() {
        test_mask_conformance(fieldless_struct(5).as_ref());
    }

    #[test]
    fn test_mask_complex_struct() {
        test_mask_conformance(nested_three_field_struct().as_ref());
    }

    #[test]
    fn test_filter_empty_struct() {
        test_filter_conformance(fieldless_struct(5).as_ref());
    }

    #[test]
    fn test_filter_complex_struct() {
        test_filter_conformance(nested_three_field_struct().as_ref());
    }

    #[test]
    fn test_cast_empty_struct() {
        let array = StructArray::try_new(FieldNames::default(), vec![], 5, Validity::NonNullable)
            .unwrap()
            .into_array();

        // Casting must succeed and land on exactly the requested dtype for both nullabilities.
        for target in [
            fieldless_struct_dtype(NonNullable),
            fieldless_struct_dtype(Nullable),
        ] {
            let casted = cast(&array, &target).unwrap();
            assert_eq!(casted.dtype(), &target);
        }
    }

    #[test]
    fn test_cast_cannot_change_name_order() {
        let one_u8 = || buffer![1u8].into_array();
        let array = StructArray::try_new(
            ["xs", "ys", "zs"].into(),
            vec![one_u8(), one_u8(), one_u8()],
            1,
            Validity::NonNullable,
        )
        .unwrap();

        // Same field names and dtypes, but with the first two names swapped.
        let tu8 = DType::Primitive(PType::U8, NonNullable);
        let reordered = DType::Struct(
            StructFields::new(FieldNames::from(["ys", "xs", "zs"]), vec![tu8; 3]),
            NonNullable,
        );

        let result = cast(array.as_ref(), &reordered);
        assert!(
            result.as_ref().is_err_and(|err| {
                err.to_string()
                    .contains("cannot cast {xs=u8, ys=u8, zs=u8} to {ys=u8, xs=u8, zs=u8}")
            }),
            "{result:?}"
        );
    }

    #[test]
    fn test_cast_complex_struct() {
        let numbers =
            PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]);
        let strings = VarBinArray::from_vec(vec!["a", "b", "c", "d", "e"], DType::Utf8(Nullable));
        let flags = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, false, true]),
            Validity::AllValid,
        );
        let pair = StructArray::try_new(
            ["left", "right"].into(),
            vec![numbers.to_array(), numbers.to_array()],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();
        let fully_nullable_array = StructArray::try_new(
            ["xs", "ys", "zs"].into(),
            vec![pair, strings.into_array(), flags.into_array()],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        // Nullable outer struct whose `xs.left` and `xs` nullability are chosen by the caller;
        // everything else stays nullable.
        let nested_target = |left: Nullability, inner: Nullability| {
            DType::Struct(
                StructFields::new(
                    ["xs", "ys", "zs"].into(),
                    vec![
                        DType::Struct(
                            StructFields::new(
                                ["left", "right"].into(),
                                vec![
                                    DType::Primitive(PType::I64, left),
                                    DType::Primitive(PType::I64, Nullable),
                                ],
                            ),
                            inner,
                        ),
                        DType::Utf8(Nullable),
                        DType::Bool(Nullable),
                    ],
                ),
                Nullable,
            )
        };

        let targets = [
            fully_nullable_array.dtype().as_nonnullable(),
            nested_target(NonNullable, Nullable),
            nested_target(Nullable, NonNullable),
        ];
        for target in targets {
            let casted = cast(&fully_nullable_array, &target).unwrap();
            assert_eq!(casted.dtype(), &target);
        }
    }

    #[test]
    fn test_empty_struct_is_constant() {
        let array = StructArray::new_with_len(2);
        let verdict = is_constant(array.as_ref()).vortex_unwrap();
        assert_eq!(verdict, Some(true));
    }

    #[test]
    fn test_take_empty_struct_conformance() {
        test_take_conformance(fieldless_struct(5).as_ref());
    }

    #[test]
    fn test_take_simple_struct_conformance() {
        let numbers = buffer![1i64, 2, 3, 4, 5].into_array();
        let strings = VarBinArray::from_iter(
            ["a", "b", "c", "d", "e"].map(Some),
            DType::Utf8(NonNullable),
        )
        .into_array();

        let simple = StructArray::try_new(
            ["xs", "ys"].into(),
            vec![numbers, strings],
            5,
            Validity::NonNullable,
        )
        .unwrap();
        test_take_conformance(simple.as_ref());
    }

    #[test]
    fn test_take_nullable_struct_conformance() {
        test_take_conformance(struct_with_nullable_fields().as_ref());
    }

    #[test]
    fn test_take_nested_struct_conformance() {
        let inner = StructArray::try_new(
            ["x", "y"].into(),
            vec![
                buffer![10i32, 20, 30, 40, 50].into_array(),
                buffer![100i32, 200, 300, 400, 500].into_array(),
            ],
            5,
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();
        let flags = BoolArray::from_iter([true, false, true, false, true]).into_array();

        let outer = StructArray::try_new(
            ["inner", "z"].into(),
            vec![inner, flags],
            5,
            Validity::NonNullable,
        )
        .unwrap();
        test_take_conformance(outer.as_ref());
    }

    #[test]
    fn test_take_single_element_struct_conformance() {
        let numbers = buffer![42i64].into_array();
        let strings =
            VarBinArray::from_iter([Some("hello")], DType::Utf8(NonNullable)).into_array();

        let single = StructArray::try_new(
            ["xs", "ys"].into(),
            vec![numbers, strings],
            1,
            Validity::NonNullable,
        )
        .unwrap();
        test_take_conformance(single.as_ref());
    }

    #[test]
    fn test_take_large_struct_conformance() {
        let numbers = PrimitiveArray::from_iter(0i64..100).into_array();
        let strings = VarBinArray::from_iter(
            (0..100).map(|i| Some(format!("str_{i}"))),
            DType::Utf8(NonNullable),
        )
        .into_array();
        let flags = BoolArray::from_iter((0..100).map(|i| i % 2 == 0)).into_array();

        let large = StructArray::try_new(
            ["xs", "ys", "zs"].into(),
            vec![numbers, strings, flags],
            100,
            Validity::NonNullable,
        )
        .unwrap();
        test_take_conformance(large.as_ref());
    }

    #[rstest]
    #[case::struct_simple({
        let numbers = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
        let strings = VarBinArray::from_iter(
            ["a", "b", "c", "d", "e"].map(Some),
            DType::Utf8(NonNullable),
        );
        StructArray::try_new(
            ["xs", "ys"].into(),
            vec![numbers.into_array(), strings.into_array()],
            5,
            Validity::NonNullable,
        )
        .unwrap()
    })]
    #[case::struct_nullable(struct_with_nullable_fields())]
    #[case::empty_struct(fieldless_struct(5))]
    #[case::single_field({
        let numbers = buffer![42i64].into_array();
        StructArray::try_new(["xs"].into(), vec![numbers], 1, Validity::NonNullable).unwrap()
    })]
    #[case::large_struct({
        let numbers = PrimitiveArray::from_iter(0..100i64).into_array();
        let strings = VarBinArray::from_iter(
            (0..100).map(|i| Some(format!("value_{i}"))),
            DType::Utf8(NonNullable),
        )
        .into_array();
        StructArray::try_new(
            ["xs", "ys"].into(),
            vec![numbers, strings],
            100,
            Validity::NonNullable,
        )
        .unwrap()
    })]
    fn test_struct_consistency(#[case] array: StructArray) {
        test_array_consistency(array.as_ref());
    }
}