vortex_array/arrays/struct_/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod cast;
5mod filter;
6mod mask;
7
8use itertools::Itertools;
9use vortex_dtype::Nullability::NonNullable;
10use vortex_error::VortexResult;
11use vortex_scalar::Scalar;
12
13use crate::arrays::StructVTable;
14use crate::arrays::struct_::StructArray;
15use crate::compute::{
16    IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts, MinMaxKernel, MinMaxKernelAdapter,
17    MinMaxResult, TakeKernel, TakeKernelAdapter, fill_null, is_constant_opts, take,
18};
19use crate::validity::Validity;
20use crate::vtable::ValidityHelper;
21use crate::{Array, ArrayRef, IntoArray, register_kernel};
22
23impl TakeKernel for StructVTable {
24    fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
25        // If the struct array is empty then the indices must be all null, otherwise it will access
26        // an out of bounds element
27        if array.is_empty() {
28            return StructArray::try_new_with_dtype(
29                array.fields().to_vec(),
30                array.struct_fields().clone(),
31                indices.len(),
32                Validity::AllInvalid,
33            )
34            .map(StructArray::into_array);
35        }
36        // The validity is applied to the struct validity,
37        let inner_indices = &fill_null(
38            indices,
39            &Scalar::default_value(indices.dtype().with_nullability(NonNullable)),
40        )?;
41        StructArray::try_new_with_dtype(
42            array
43                .fields()
44                .iter()
45                .map(|field| take(field, inner_indices))
46                .try_collect()?,
47            array.struct_fields().clone(),
48            indices.len(),
49            array.validity().take(indices)?,
50        )
51        .map(|a| a.into_array())
52    }
53}
54
55register_kernel!(TakeKernelAdapter(StructVTable).lift());
56
impl MinMaxKernel for StructVTable {
    /// Min/max kernel for struct arrays; currently an unimplemented stub.
    ///
    /// The array argument is ignored and `Ok(None)` is always returned, i.e. this kernel never
    /// produces a [`MinMaxResult`] and never fails.
    fn min_max(&self, _array: &StructArray) -> VortexResult<Option<MinMaxResult>> {
        // TODO(joe): Implement struct min max
        Ok(None)
    }
}

// Registers the min/max kernel above for `StructVTable`.
register_kernel!(MinMaxKernelAdapter(StructVTable).lift());
65
66impl IsConstantKernel for StructVTable {
67    fn is_constant(
68        &self,
69        array: &StructArray,
70        opts: &IsConstantOpts,
71    ) -> VortexResult<Option<bool>> {
72        let children = array.children();
73        if children.is_empty() {
74            return Ok(Some(true));
75        }
76
77        for child in children.iter() {
78            match is_constant_opts(child, opts)? {
79                // Un-determined
80                None => return Ok(None),
81                Some(false) => return Ok(Some(false)),
82                Some(true) => {}
83            }
84        }
85
86        Ok(Some(true))
87    }
88}
89
90register_kernel!(IsConstantKernelAdapter(StructVTable).lift());
91
92#[cfg(test)]
93mod tests {
94    use Nullability::{NonNullable, Nullable};
95    use rstest::rstest;
96    use vortex_buffer::buffer;
97    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
98    use vortex_error::VortexUnwrap;
99    use vortex_mask::Mask;
100    use vortex_scalar::Scalar;
101
102    use crate::arrays::{BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinArray};
103    use crate::compute::conformance::consistency::test_array_consistency;
104    use crate::compute::conformance::filter::test_filter_conformance;
105    use crate::compute::conformance::mask::test_mask_conformance;
106    use crate::compute::conformance::take::test_take_conformance;
107    use crate::compute::{cast, filter, is_constant, take};
108    use crate::validity::Validity;
109    use crate::{Array, IntoArray as _};
110
111    #[test]
112    fn filter_empty_struct() {
113        let struct_arr =
114            StructArray::try_new(FieldNames::empty(), vec![], 10, Validity::NonNullable).unwrap();
115        let mask = vec![
116            false, true, false, true, false, true, false, true, false, true,
117        ];
118        let filtered = filter(struct_arr.as_ref(), &Mask::from_iter(mask)).unwrap();
119        assert_eq!(filtered.len(), 5);
120    }
121
122    #[test]
123    fn take_empty_struct() {
124        let struct_arr =
125            StructArray::try_new(FieldNames::empty(), vec![], 10, Validity::NonNullable).unwrap();
126        let indices = PrimitiveArray::from_option_iter([Some(1), None]);
127        let taken = take(struct_arr.as_ref(), indices.as_ref()).unwrap();
128        assert_eq!(taken.len(), 2);
129
130        assert_eq!(
131            taken.scalar_at(0),
132            Scalar::struct_(
133                DType::Struct(StructFields::new(FieldNames::default(), vec![]), Nullable),
134                vec![]
135            )
136        );
137        assert_eq!(
138            taken.scalar_at(1),
139            Scalar::null(DType::Struct(
140                StructFields::new(FieldNames::default(), vec![]),
141                Nullable
142            ))
143        );
144    }
145
146    #[test]
147    fn take_field_struct() {
148        let struct_arr = StructArray::from_fields(&[("a", buffer![0..10].into_array())]).unwrap();
149        let indices = PrimitiveArray::from_option_iter([Some(1), None]);
150        let taken = take(struct_arr.as_ref(), indices.as_ref()).unwrap();
151        assert_eq!(taken.len(), 2);
152
153        assert_eq!(
154            taken.scalar_at(0),
155            Scalar::struct_(
156                struct_arr.dtype().union_nullability(Nullable),
157                vec![Scalar::primitive(1, NonNullable)],
158            )
159        );
160        assert_eq!(
161            taken.scalar_at(1),
162            Scalar::null(struct_arr.dtype().union_nullability(Nullable),)
163        );
164    }
165
166    #[test]
167    fn filter_empty_struct_with_empty_filter() {
168        let struct_arr =
169            StructArray::try_new(FieldNames::empty(), vec![], 0, Validity::NonNullable).unwrap();
170        let filtered = filter(struct_arr.as_ref(), &Mask::from_iter::<[bool; 0]>([])).unwrap();
171        assert_eq!(filtered.len(), 0);
172    }
173
174    #[test]
175    fn test_mask_empty_struct() {
176        test_mask_conformance(
177            StructArray::try_new(FieldNames::empty(), vec![], 5, Validity::NonNullable)
178                .unwrap()
179                .as_ref(),
180        );
181    }
182
183    #[test]
184    fn test_mask_complex_struct() {
185        let xs = buffer![0i64, 1, 2, 3, 4].into_array();
186        let ys = VarBinArray::from_iter(
187            [Some("a"), Some("b"), None, Some("d"), None],
188            DType::Utf8(Nullable),
189        )
190        .into_array();
191        let zs =
192            BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array();
193
194        test_mask_conformance(
195            StructArray::try_new(
196                ["xs", "ys", "zs"].into(),
197                vec![
198                    StructArray::try_new(
199                        ["left", "right"].into(),
200                        vec![xs.clone(), xs],
201                        5,
202                        Validity::NonNullable,
203                    )
204                    .unwrap()
205                    .into_array(),
206                    ys,
207                    zs,
208                ],
209                5,
210                Validity::NonNullable,
211            )
212            .unwrap()
213            .as_ref(),
214        );
215    }
216
217    #[test]
218    fn test_filter_empty_struct() {
219        test_filter_conformance(
220            StructArray::try_new(FieldNames::empty(), vec![], 5, Validity::NonNullable)
221                .unwrap()
222                .as_ref(),
223        );
224    }
225
226    #[test]
227    fn test_filter_complex_struct() {
228        let xs = buffer![0i64, 1, 2, 3, 4].into_array();
229        let ys = VarBinArray::from_iter(
230            [Some("a"), Some("b"), None, Some("d"), None],
231            DType::Utf8(Nullable),
232        )
233        .into_array();
234        let zs =
235            BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array();
236
237        test_filter_conformance(
238            StructArray::try_new(
239                ["xs", "ys", "zs"].into(),
240                vec![
241                    StructArray::try_new(
242                        ["left", "right"].into(),
243                        vec![xs.clone(), xs],
244                        5,
245                        Validity::NonNullable,
246                    )
247                    .unwrap()
248                    .into_array(),
249                    ys,
250                    zs,
251                ],
252                5,
253                Validity::NonNullable,
254            )
255            .unwrap()
256            .as_ref(),
257        );
258    }
259
260    #[test]
261    fn test_cast_empty_struct() {
262        let array = StructArray::try_new(FieldNames::default(), vec![], 5, Validity::NonNullable)
263            .unwrap()
264            .into_array();
265        let non_nullable_dtype = DType::Struct(
266            StructFields::new(FieldNames::default(), vec![]),
267            NonNullable,
268        );
269        let casted = cast(&array, &non_nullable_dtype).unwrap();
270        assert_eq!(casted.dtype(), &non_nullable_dtype);
271
272        let nullable_dtype =
273            DType::Struct(StructFields::new(FieldNames::default(), vec![]), Nullable);
274        let casted = cast(&array, &nullable_dtype).unwrap();
275        assert_eq!(casted.dtype(), &nullable_dtype);
276    }
277
278    #[test]
279    fn test_cast_cannot_change_name_order() {
280        let array = StructArray::try_new(
281            ["xs", "ys", "zs"].into(),
282            vec![
283                buffer![1u8].into_array(),
284                buffer![1u8].into_array(),
285                buffer![1u8].into_array(),
286            ],
287            1,
288            Validity::NonNullable,
289        )
290        .unwrap();
291
292        let tu8 = DType::Primitive(PType::U8, NonNullable);
293
294        let result = cast(
295            array.as_ref(),
296            &DType::Struct(
297                StructFields::new(
298                    FieldNames::from(["ys", "xs", "zs"]),
299                    vec![tu8.clone(), tu8.clone(), tu8],
300                ),
301                NonNullable,
302            ),
303        );
304        assert!(
305            result.as_ref().is_err_and(|err| {
306                err.to_string()
307                    .contains("cannot cast {xs=u8, ys=u8, zs=u8} to {ys=u8, xs=u8, zs=u8}")
308            }),
309            "{result:?}"
310        );
311    }
312
313    #[test]
314    fn test_cast_complex_struct() {
315        let xs = PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]);
316        let ys = VarBinArray::from_vec(vec!["a", "b", "c", "d", "e"], DType::Utf8(Nullable));
317        let zs = BoolArray::from_bool_buffer(
318            BooleanBuffer::from_iter([true, true, false, false, true]),
319            Validity::AllValid,
320        );
321        let fully_nullable_array = StructArray::try_new(
322            ["xs", "ys", "zs"].into(),
323            vec![
324                StructArray::try_new(
325                    ["left", "right"].into(),
326                    vec![xs.to_array(), xs.to_array()],
327                    5,
328                    Validity::AllValid,
329                )
330                .unwrap()
331                .into_array(),
332                ys.into_array(),
333                zs.into_array(),
334            ],
335            5,
336            Validity::AllValid,
337        )
338        .unwrap()
339        .into_array();
340
341        let top_level_non_nullable = fully_nullable_array.dtype().as_nonnullable();
342        let casted = cast(&fully_nullable_array, &top_level_non_nullable).unwrap();
343        assert_eq!(casted.dtype(), &top_level_non_nullable);
344
345        let non_null_xs_right = DType::Struct(
346            StructFields::new(
347                ["xs", "ys", "zs"].into(),
348                vec![
349                    DType::Struct(
350                        StructFields::new(
351                            ["left", "right"].into(),
352                            vec![
353                                DType::Primitive(PType::I64, NonNullable),
354                                DType::Primitive(PType::I64, Nullable),
355                            ],
356                        ),
357                        Nullable,
358                    ),
359                    DType::Utf8(Nullable),
360                    DType::Bool(Nullable),
361                ],
362            ),
363            Nullable,
364        );
365        let casted = cast(&fully_nullable_array, &non_null_xs_right).unwrap();
366        assert_eq!(casted.dtype(), &non_null_xs_right);
367
368        let non_null_xs = DType::Struct(
369            StructFields::new(
370                ["xs", "ys", "zs"].into(),
371                vec![
372                    DType::Struct(
373                        StructFields::new(
374                            ["left", "right"].into(),
375                            vec![
376                                DType::Primitive(PType::I64, Nullable),
377                                DType::Primitive(PType::I64, Nullable),
378                            ],
379                        ),
380                        NonNullable,
381                    ),
382                    DType::Utf8(Nullable),
383                    DType::Bool(Nullable),
384                ],
385            ),
386            Nullable,
387        );
388        let casted = cast(&fully_nullable_array, &non_null_xs).unwrap();
389        assert_eq!(casted.dtype(), &non_null_xs);
390    }
391
392    #[test]
393    fn test_empty_struct_is_constant() {
394        let array = StructArray::new_fieldless_with_len(2);
395        let is_constant = is_constant(array.as_ref()).vortex_unwrap();
396        assert_eq!(is_constant, Some(true));
397    }
398
399    #[test]
400    fn test_take_empty_struct_conformance() {
401        test_take_conformance(
402            StructArray::try_new(FieldNames::empty(), vec![], 5, Validity::NonNullable)
403                .unwrap()
404                .as_ref(),
405        );
406    }
407
408    #[test]
409    fn test_take_simple_struct_conformance() {
410        let xs = buffer![1i64, 2, 3, 4, 5].into_array();
411        let ys = VarBinArray::from_iter(
412            ["a", "b", "c", "d", "e"].map(Some),
413            DType::Utf8(NonNullable),
414        )
415        .into_array();
416
417        test_take_conformance(
418            StructArray::try_new(["xs", "ys"].into(), vec![xs, ys], 5, Validity::NonNullable)
419                .unwrap()
420                .as_ref(),
421        );
422    }
423
424    #[test]
425    fn test_take_nullable_struct_conformance() {
426        // Test struct with nullable fields
427        let xs = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]);
428        let ys = VarBinArray::from_iter(
429            [Some("a"), Some("b"), None, Some("d"), None],
430            DType::Utf8(Nullable),
431        );
432
433        test_take_conformance(
434            StructArray::try_new(
435                ["xs", "ys"].into(),
436                vec![xs.into_array(), ys.into_array()],
437                5,
438                Validity::NonNullable,
439            )
440            .unwrap()
441            .as_ref(),
442        );
443    }
444
445    #[test]
446    fn test_take_nested_struct_conformance() {
447        // Test nested struct
448        let inner_xs = buffer![10i32, 20, 30, 40, 50].into_array();
449        let inner_ys = buffer![100i32, 200, 300, 400, 500].into_array();
450        let inner_struct = StructArray::try_new(
451            ["x", "y"].into(),
452            vec![inner_xs, inner_ys],
453            5,
454            Validity::NonNullable,
455        )
456        .unwrap()
457        .into_array();
458
459        let outer_zs = BoolArray::from_iter([true, false, true, false, true]).into_array();
460
461        test_take_conformance(
462            StructArray::try_new(
463                ["inner", "z"].into(),
464                vec![inner_struct, outer_zs],
465                5,
466                Validity::NonNullable,
467            )
468            .unwrap()
469            .as_ref(),
470        );
471    }
472
473    #[test]
474    fn test_take_single_element_struct_conformance() {
475        let xs = buffer![42i64].into_array();
476        let ys = VarBinArray::from_iter(["hello"].map(Some), DType::Utf8(NonNullable)).into_array();
477
478        test_take_conformance(
479            StructArray::try_new(["xs", "ys"].into(), vec![xs, ys], 1, Validity::NonNullable)
480                .unwrap()
481                .as_ref(),
482        );
483    }
484
485    #[test]
486    fn test_take_large_struct_conformance() {
487        // Test with larger array for additional edge cases
488        let xs = buffer![0i64..100].into_array();
489        let ys = VarBinArray::from_iter(
490            (0..100).map(|i| format!("str_{i}")).map(Some),
491            DType::Utf8(NonNullable),
492        )
493        .into_array();
494        let zs = BoolArray::from_iter((0..100).map(|i| i % 2 == 0)).into_array();
495
496        test_take_conformance(
497            StructArray::try_new(
498                ["xs", "ys", "zs"].into(),
499                vec![xs, ys, zs],
500                100,
501                Validity::NonNullable,
502            )
503            .unwrap()
504            .as_ref(),
505        );
506    }
507
508    // Consistency tests
509    #[rstest]
510    // From test_all_consistency
511    #[case::struct_simple({
512        let xs = buffer![1i32, 2, 3, 4, 5].into_array();
513        let ys = VarBinArray::from_iter(
514            ["a", "b", "c", "d", "e"].map(Some),
515            DType::Utf8(NonNullable),
516        );
517        StructArray::try_new(
518            ["xs", "ys"].into(),
519            vec![xs.into_array(), ys.into_array()],
520            5,
521            Validity::NonNullable,
522        )
523        .unwrap()
524    })]
525    #[case::struct_nullable({
526        let xs = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]);
527        let ys = VarBinArray::from_iter(
528            [Some("a"), Some("b"), None, Some("d"), None],
529            DType::Utf8(Nullable),
530        );
531        StructArray::try_new(
532            ["xs", "ys"].into(),
533            vec![xs.into_array(), ys.into_array()],
534            5,
535            Validity::NonNullable,
536        )
537        .unwrap()
538    })]
539    // Additional test cases
540    #[case::empty_struct(StructArray::try_new(FieldNames::empty(), vec![], 5, Validity::NonNullable).unwrap())]
541    #[case::single_field({
542        let xs = buffer![42i64].into_array();
543        StructArray::try_new(["xs"].into(), vec![xs], 1, Validity::NonNullable).unwrap()
544    })]
545    #[case::large_struct({
546        let xs = buffer![0..100i64].into_array();
547        let ys = VarBinArray::from_iter(
548            (0..100).map(|i| format!("value_{i}")).map(Some),
549            DType::Utf8(NonNullable),
550        ).into_array();
551        StructArray::try_new(["xs", "ys"].into(), vec![xs, ys], 100, Validity::NonNullable).unwrap()
552    })]
553    fn test_struct_consistency(#[case] array: StructArray) {
554        test_array_consistency(array.as_ref());
555    }
556}