// vortex_array/arrays/struct_/compute/mod.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

mod cast;
mod filter;
mod mask;

use itertools::Itertools;
use vortex_dtype::Nullability::NonNullable;
use vortex_error::VortexResult;
use vortex_scalar::Scalar;

use crate::arrays::ConstantArray;
use crate::arrays::StructVTable;
use crate::arrays::struct_::StructArray;
use crate::compute::{
    IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts, MinMaxKernel, MinMaxKernelAdapter,
    MinMaxResult, TakeKernel, TakeKernelAdapter, fill_null, is_constant_opts, take,
};
use crate::validity::Validity;
use crate::vtable::ValidityHelper;
use crate::{Array, ArrayRef, IntoArray, register_kernel};

23impl TakeKernel for StructVTable {
24    fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
25        // If the struct array is empty then the indices must be all null, otherwise it will access
26        // an out of bounds element
27        if array.is_empty() {
28            return StructArray::try_new_with_dtype(
29                array.fields().to_vec(),
30                array.struct_fields().clone(),
31                indices.len(),
32                Validity::AllInvalid,
33            )
34            .map(StructArray::into_array);
35        }
36        // The validity is applied to the struct validity,
37        let inner_indices = &fill_null(
38            indices,
39            &Scalar::default_value(indices.dtype().with_nullability(NonNullable)),
40        )?;
41        StructArray::try_new_with_dtype(
42            array
43                .fields()
44                .iter()
45                .map(|field| take(field, inner_indices))
46                .try_collect()?,
47            array.struct_fields().clone(),
48            indices.len(),
49            array.validity().take(indices)?,
50        )
51        .map(|a| a.into_array())
52    }
53}
54
55register_kernel!(TakeKernelAdapter(StructVTable).lift());
56
57impl MinMaxKernel for StructVTable {
58    fn min_max(&self, _array: &StructArray) -> VortexResult<Option<MinMaxResult>> {
59        // TODO(joe): Implement struct min max
60        Ok(None)
61    }
62}
63
64register_kernel!(MinMaxKernelAdapter(StructVTable).lift());
65
66impl IsConstantKernel for StructVTable {
67    fn is_constant(
68        &self,
69        array: &StructArray,
70        opts: &IsConstantOpts,
71    ) -> VortexResult<Option<bool>> {
72        let children = array.children();
73        if children.is_empty() {
74            return Ok(Some(true));
75        }
76
77        for child in children.iter() {
78            match is_constant_opts(child, opts)? {
79                // Un-determined
80                None => return Ok(None),
81                Some(false) => return Ok(Some(false)),
82                Some(true) => {}
83            }
84        }
85
86        Ok(Some(true))
87    }
88}
89
90register_kernel!(IsConstantKernelAdapter(StructVTable).lift());
91
#[cfg(test)]
mod tests {
    use Nullability::{NonNullable, Nullable};
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_error::VortexUnwrap;
    use vortex_mask::Mask;
    use vortex_scalar::Scalar;

    use crate::arrays::{BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinArray};
    use crate::compute::conformance::consistency::test_array_consistency;
    use crate::compute::conformance::filter::test_filter_conformance;
    use crate::compute::conformance::mask::test_mask_conformance;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::{cast, filter, is_constant, take};
    use crate::validity::Validity;
    use crate::{Array, IntoArray as _};

    /// Builds a non-nullable struct array with no fields and `len` rows.
    fn fieldless_struct(len: usize) -> StructArray {
        StructArray::try_new(vec![].into(), vec![], len, Validity::NonNullable).unwrap()
    }

    /// The dtype of a struct with no fields, at the requested top-level nullability.
    fn fieldless_struct_dtype(nullability: Nullability) -> DType {
        DType::Struct(StructFields::new(FieldNames::default(), vec![]), nullability)
    }

    /// Five-row, non-nullable struct holding a nested `{left, right}` i64 struct, a nullable
    /// UTF-8 column and a nullable boolean column. Shared by the mask and filter tests.
    fn mixed_struct() -> StructArray {
        let xs = buffer![0i64, 1, 2, 3, 4].into_array();
        let nested = StructArray::try_new(
            ["left", "right"].into(),
            vec![xs.clone(), xs],
            5,
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();
        let ys = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullable),
        )
        .into_array();
        let zs =
            BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array();

        StructArray::try_new(
            ["xs", "ys", "zs"].into(),
            vec![nested, ys, zs],
            5,
            Validity::NonNullable,
        )
        .unwrap()
    }

    #[test]
    fn filter_empty_struct() {
        let struct_arr = fieldless_struct(10);
        // Keep the odd-indexed rows: five out of ten.
        let mask = Mask::from_iter((0..10).map(|row| row % 2 == 1));
        let filtered = filter(struct_arr.as_ref(), &mask).unwrap();
        assert_eq!(filtered.len(), 5);
    }

    #[test]
    fn take_empty_struct() {
        let struct_arr = fieldless_struct(10);
        let indices = PrimitiveArray::from_option_iter([Some(1), None]);
        let taken = take(struct_arr.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(taken.len(), 2);

        // The null index makes the result nullable.
        let expected_dtype = fieldless_struct_dtype(Nullable);
        assert_eq!(
            taken.scalar_at(0),
            Scalar::struct_(expected_dtype.clone(), vec![])
        );
        assert_eq!(taken.scalar_at(1), Scalar::null(expected_dtype));
    }

    #[test]
    fn take_field_struct() {
        let struct_arr =
            StructArray::from_fields(&[("a", PrimitiveArray::from_iter(0..10).to_array())])
                .unwrap();
        let indices = PrimitiveArray::from_option_iter([Some(1), None]);
        let taken = take(struct_arr.as_ref(), indices.as_ref()).unwrap();
        assert_eq!(taken.len(), 2);

        let nullable_dtype = struct_arr.dtype().union_nullability(Nullable);
        assert_eq!(
            taken.scalar_at(0),
            Scalar::struct_(
                nullable_dtype.clone(),
                vec![Scalar::primitive(1, NonNullable)],
            )
        );
        assert_eq!(taken.scalar_at(1), Scalar::null(nullable_dtype));
    }

    #[test]
    fn filter_empty_struct_with_empty_filter() {
        let struct_arr = fieldless_struct(0);
        let no_rows = Mask::from_iter::<[bool; 0]>([]);
        let filtered = filter(struct_arr.as_ref(), &no_rows).unwrap();
        assert_eq!(filtered.len(), 0);
    }

    #[test]
    fn test_mask_empty_struct() {
        test_mask_conformance(fieldless_struct(5).as_ref());
    }

    #[test]
    fn test_mask_complex_struct() {
        test_mask_conformance(mixed_struct().as_ref());
    }

    #[test]
    fn test_filter_empty_struct() {
        test_filter_conformance(fieldless_struct(5).as_ref());
    }

    #[test]
    fn test_filter_complex_struct() {
        test_filter_conformance(mixed_struct().as_ref());
    }

    #[test]
    fn test_cast_empty_struct() {
        let array = StructArray::try_new(FieldNames::default(), vec![], 5, Validity::NonNullable)
            .unwrap()
            .into_array();

        // Casting to either top-level nullability must succeed and report the target dtype.
        for nullability in [NonNullable, Nullable] {
            let target = fieldless_struct_dtype(nullability);
            let casted = cast(&array, &target).unwrap();
            assert_eq!(casted.dtype(), &target);
        }
    }

    #[test]
    fn test_cast_cannot_change_name_order() {
        let one_byte = || buffer![1u8].into_array();
        let array = StructArray::try_new(
            ["xs", "ys", "zs"].into(),
            vec![one_byte(), one_byte(), one_byte()],
            1,
            Validity::NonNullable,
        )
        .unwrap();

        let tu8 = DType::Primitive(PType::U8, NonNullable);
        let reordered = DType::Struct(
            StructFields::new(
                FieldNames::from(["ys", "xs", "zs"]),
                vec![tu8.clone(), tu8.clone(), tu8],
            ),
            NonNullable,
        );

        let result = cast(array.as_ref(), &reordered);
        assert!(
            result.as_ref().is_err_and(|err| {
                err.to_string()
                    .contains("cannot cast {xs=u8, ys=u8, zs=u8} to {ys=u8, xs=u8, zs=u8}")
            }),
            "{result:?}"
        );
    }

    #[test]
    fn test_cast_complex_struct() {
        let xs = PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]);
        let ys = VarBinArray::from_vec(vec!["a", "b", "c", "d", "e"], DType::Utf8(Nullable));
        let zs = BoolArray::new(
            BooleanBuffer::from_iter([true, true, false, false, true]),
            Validity::AllValid,
        );
        let nested = StructArray::try_new(
            ["left", "right"].into(),
            vec![xs.to_array(), xs.to_array()],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();
        let fully_nullable_array = StructArray::try_new(
            ["xs", "ys", "zs"].into(),
            vec![nested, ys.into_array(), zs.into_array()],
            5,
            Validity::AllValid,
        )
        .unwrap()
        .into_array();

        // Target dtype that differs from the source only in the nullability of the nested
        // `left` field and of the nested struct itself.
        let with_nested = |left: Nullability, inner: Nullability| {
            DType::Struct(
                StructFields::new(
                    ["xs", "ys", "zs"].into(),
                    vec![
                        DType::Struct(
                            StructFields::new(
                                ["left", "right"].into(),
                                vec![
                                    DType::Primitive(PType::I64, left),
                                    DType::Primitive(PType::I64, Nullable),
                                ],
                            ),
                            inner,
                        ),
                        DType::Utf8(Nullable),
                        DType::Bool(Nullable),
                    ],
                ),
                Nullable,
            )
        };

        let targets = [
            fully_nullable_array.dtype().as_nonnullable(),
            with_nested(NonNullable, Nullable),
            with_nested(Nullable, NonNullable),
        ];
        for target in targets {
            let casted = cast(&fully_nullable_array, &target).unwrap();
            assert_eq!(casted.dtype(), &target);
        }
    }

    #[test]
    fn test_empty_struct_is_constant() {
        let array = StructArray::new_with_len(2);
        let is_constant = is_constant(array.as_ref()).vortex_unwrap();
        assert_eq!(is_constant, Some(true));
    }

    #[test]
    fn test_take_empty_struct_conformance() {
        test_take_conformance(fieldless_struct(5).as_ref());
    }

    #[test]
    fn test_take_simple_struct_conformance() {
        let xs = buffer![1i64, 2, 3, 4, 5].into_array();
        let ys = VarBinArray::from_iter(
            ["a", "b", "c", "d", "e"].map(Some),
            DType::Utf8(NonNullable),
        )
        .into_array();
        let array =
            StructArray::try_new(["xs", "ys"].into(), vec![xs, ys], 5, Validity::NonNullable)
                .unwrap();

        test_take_conformance(array.as_ref());
    }

    #[test]
    fn test_take_nullable_struct_conformance() {
        // Struct whose fields contain nulls, while the struct itself is non-nullable.
        let xs = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]);
        let ys = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullable),
        );
        let array = StructArray::try_new(
            ["xs", "ys"].into(),
            vec![xs.into_array(), ys.into_array()],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        test_take_conformance(array.as_ref());
    }

    #[test]
    fn test_take_nested_struct_conformance() {
        // Struct containing another struct as one of its fields.
        let inner_xs = buffer![10i32, 20, 30, 40, 50].into_array();
        let inner_ys = buffer![100i32, 200, 300, 400, 500].into_array();
        let inner_struct = StructArray::try_new(
            ["x", "y"].into(),
            vec![inner_xs, inner_ys],
            5,
            Validity::NonNullable,
        )
        .unwrap()
        .into_array();
        let outer_zs = BoolArray::from_iter([true, false, true, false, true]).into_array();
        let array = StructArray::try_new(
            ["inner", "z"].into(),
            vec![inner_struct, outer_zs],
            5,
            Validity::NonNullable,
        )
        .unwrap();

        test_take_conformance(array.as_ref());
    }

    #[test]
    fn test_take_single_element_struct_conformance() {
        let xs = buffer![42i64].into_array();
        let ys = VarBinArray::from_iter(["hello"].map(Some), DType::Utf8(NonNullable)).into_array();
        let array =
            StructArray::try_new(["xs", "ys"].into(), vec![xs, ys], 1, Validity::NonNullable)
                .unwrap();

        test_take_conformance(array.as_ref());
    }

    #[test]
    fn test_take_large_struct_conformance() {
        // A longer array to exercise additional edge cases.
        let xs = PrimitiveArray::from_iter(0i64..100).into_array();
        let ys = VarBinArray::from_iter(
            (0..100).map(|i| format!("str_{i}")).map(Some),
            DType::Utf8(NonNullable),
        )
        .into_array();
        let zs = BoolArray::from_iter((0..100).map(|i| i % 2 == 0)).into_array();
        let array = StructArray::try_new(
            ["xs", "ys", "zs"].into(),
            vec![xs, ys, zs],
            100,
            Validity::NonNullable,
        )
        .unwrap();

        test_take_conformance(array.as_ref());
    }

    // Consistency tests
    #[rstest]
    // From test_all_consistency
    #[case::struct_simple({
        let xs = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]);
        let ys = VarBinArray::from_iter(
            ["a", "b", "c", "d", "e"].map(Some),
            DType::Utf8(NonNullable),
        );
        StructArray::try_new(
            ["xs", "ys"].into(),
            vec![xs.into_array(), ys.into_array()],
            5,
            Validity::NonNullable,
        )
        .unwrap()
    })]
    #[case::struct_nullable({
        let xs = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None]);
        let ys = VarBinArray::from_iter(
            [Some("a"), Some("b"), None, Some("d"), None],
            DType::Utf8(Nullable),
        );
        StructArray::try_new(
            ["xs", "ys"].into(),
            vec![xs.into_array(), ys.into_array()],
            5,
            Validity::NonNullable,
        )
        .unwrap()
    })]
    // Additional test cases
    #[case::empty_struct(fieldless_struct(5))]
    #[case::single_field({
        let xs = buffer![42i64].into_array();
        StructArray::try_new(["xs"].into(), vec![xs], 1, Validity::NonNullable).unwrap()
    })]
    #[case::large_struct({
        let xs = PrimitiveArray::from_iter(0..100i64).into_array();
        let ys = VarBinArray::from_iter(
            (0..100).map(|i| format!("value_{i}")).map(Some),
            DType::Utf8(NonNullable),
        ).into_array();
        StructArray::try_new(["xs", "ys"].into(), vec![xs, ys], 100, Validity::NonNullable).unwrap()
    })]
    fn test_struct_consistency(#[case] array: StructArray) {
        test_array_consistency(array.as_ref());
    }
}