vortex_array/arrays/struct_/compute/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod cast;
5mod filter;
6mod mask;
7
8use itertools::Itertools;
9use vortex_dtype::Nullability::NonNullable;
10use vortex_error::VortexResult;
11use vortex_scalar::Scalar;
12
13use crate::arrays::StructVTable;
14use crate::arrays::struct_::StructArray;
15use crate::compute::{
16    IsConstantKernel, IsConstantKernelAdapter, IsConstantOpts, MinMaxKernel, MinMaxKernelAdapter,
17    MinMaxResult, TakeKernel, TakeKernelAdapter, fill_null, is_constant_opts, take,
18};
19use crate::validity::Validity;
20use crate::vtable::ValidityHelper;
21use crate::{Array, ArrayRef, IntoArray, register_kernel};
22
23impl TakeKernel for StructVTable {
24    fn take(&self, array: &StructArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
25        // If the struct array is empty then the indices must be all null, otherwise it will access
26        // an out of bounds element
27        if array.is_empty() {
28            return StructArray::try_new_with_dtype(
29                array.fields().to_vec(),
30                array.struct_fields().clone(),
31                indices.len(),
32                Validity::AllInvalid,
33            )
34            .map(StructArray::into_array);
35        }
36        // The validity is applied to the struct validity,
37        let inner_indices = &fill_null(
38            indices,
39            &Scalar::default_value(indices.dtype().with_nullability(NonNullable)),
40        )?;
41        StructArray::try_new_with_dtype(
42            array
43                .fields()
44                .iter()
45                .map(|field| take(field, inner_indices))
46                .try_collect()?,
47            array.struct_fields().clone(),
48            indices.len(),
49            array.validity().take(indices)?,
50        )
51        .map(|a| a.into_array())
52    }
53}
54
55register_kernel!(TakeKernelAdapter(StructVTable).lift());
56
57impl MinMaxKernel for StructVTable {
58    fn min_max(&self, _array: &StructArray) -> VortexResult<Option<MinMaxResult>> {
59        // TODO(joe): Implement struct min max
60        Ok(None)
61    }
62}
63
64register_kernel!(MinMaxKernelAdapter(StructVTable).lift());
65
66impl IsConstantKernel for StructVTable {
67    fn is_constant(
68        &self,
69        array: &StructArray,
70        opts: &IsConstantOpts,
71    ) -> VortexResult<Option<bool>> {
72        let children = array.children();
73        if children.is_empty() {
74            return Ok(Some(true));
75        }
76
77        for child in children.iter() {
78            match is_constant_opts(child, opts)? {
79                // Un-determined
80                None => return Ok(None),
81                Some(false) => return Ok(Some(false)),
82                Some(true) => {}
83            }
84        }
85
86        Ok(Some(true))
87    }
88}
89
90register_kernel!(IsConstantKernelAdapter(StructVTable).lift());
91
92#[cfg(test)]
93mod tests {
94    use Nullability::{NonNullable, Nullable};
95    use vortex_buffer::buffer;
96    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
97    use vortex_error::VortexUnwrap;
98    use vortex_mask::Mask;
99    use vortex_scalar::Scalar;
100
101    use crate::arrays::{BoolArray, BooleanBuffer, PrimitiveArray, StructArray, VarBinArray};
102    use crate::compute::conformance::mask::test_mask;
103    use crate::compute::{cast, filter, is_constant, take};
104    use crate::validity::Validity;
105    use crate::{Array, IntoArray as _};
106
107    #[test]
108    fn filter_empty_struct() {
109        let struct_arr =
110            StructArray::try_new(vec![].into(), vec![], 10, Validity::NonNullable).unwrap();
111        let mask = vec![
112            false, true, false, true, false, true, false, true, false, true,
113        ];
114        let filtered = filter(struct_arr.as_ref(), &Mask::from_iter(mask)).unwrap();
115        assert_eq!(filtered.len(), 5);
116    }
117
118    #[test]
119    fn take_empty_struct() {
120        let struct_arr =
121            StructArray::try_new(vec![].into(), vec![], 10, Validity::NonNullable).unwrap();
122        let indices = PrimitiveArray::from_option_iter([Some(1), None]);
123        let taken = take(struct_arr.as_ref(), indices.as_ref()).unwrap();
124        assert_eq!(taken.len(), 2);
125
126        assert_eq!(
127            taken.scalar_at(0).unwrap(),
128            Scalar::struct_(
129                DType::Struct(StructFields::new(FieldNames::default(), vec![]), Nullable),
130                vec![]
131            )
132        );
133        assert_eq!(
134            taken.scalar_at(1).unwrap(),
135            Scalar::null(DType::Struct(
136                StructFields::new(FieldNames::default(), vec![]),
137                Nullable
138            ))
139        );
140    }
141
142    #[test]
143    fn take_field_struct() {
144        let struct_arr =
145            StructArray::from_fields(&[("a", PrimitiveArray::from_iter(0..10).to_array())])
146                .unwrap();
147        let indices = PrimitiveArray::from_option_iter([Some(1), None]);
148        let taken = take(struct_arr.as_ref(), indices.as_ref()).unwrap();
149        assert_eq!(taken.len(), 2);
150
151        assert_eq!(
152            taken.scalar_at(0).unwrap(),
153            Scalar::struct_(
154                struct_arr.dtype().union_nullability(Nullable),
155                vec![Scalar::primitive(1, NonNullable)],
156            )
157        );
158        assert_eq!(
159            taken.scalar_at(1).unwrap(),
160            Scalar::null(struct_arr.dtype().union_nullability(Nullable),)
161        );
162    }
163
164    #[test]
165    fn filter_empty_struct_with_empty_filter() {
166        let struct_arr =
167            StructArray::try_new(vec![].into(), vec![], 0, Validity::NonNullable).unwrap();
168        let filtered = filter(struct_arr.as_ref(), &Mask::from_iter::<[bool; 0]>([])).unwrap();
169        assert_eq!(filtered.len(), 0);
170    }
171
172    #[test]
173    fn test_mask_empty_struct() {
174        test_mask(
175            StructArray::try_new(vec![].into(), vec![], 5, Validity::NonNullable)
176                .unwrap()
177                .as_ref(),
178        );
179    }
180
181    #[test]
182    fn test_mask_complex_struct() {
183        let xs = buffer![0i64, 1, 2, 3, 4].into_array();
184        let ys = VarBinArray::from_iter(
185            [Some("a"), Some("b"), None, Some("d"), None],
186            DType::Utf8(Nullable),
187        )
188        .into_array();
189        let zs =
190            BoolArray::from_iter([Some(true), Some(true), None, None, Some(false)]).into_array();
191
192        test_mask(
193            StructArray::try_new(
194                ["xs", "ys", "zs"].into(),
195                vec![
196                    StructArray::try_new(
197                        ["left", "right"].into(),
198                        vec![xs.clone(), xs],
199                        5,
200                        Validity::NonNullable,
201                    )
202                    .unwrap()
203                    .into_array(),
204                    ys,
205                    zs,
206                ],
207                5,
208                Validity::NonNullable,
209            )
210            .unwrap()
211            .as_ref(),
212        );
213    }
214
215    #[test]
216    fn test_cast_empty_struct() {
217        let array = StructArray::try_new(FieldNames::default(), vec![], 5, Validity::NonNullable)
218            .unwrap()
219            .into_array();
220        let non_nullable_dtype = DType::Struct(
221            StructFields::new(FieldNames::default(), vec![]),
222            NonNullable,
223        );
224        let casted = cast(&array, &non_nullable_dtype).unwrap();
225        assert_eq!(casted.dtype(), &non_nullable_dtype);
226
227        let nullable_dtype =
228            DType::Struct(StructFields::new(FieldNames::default(), vec![]), Nullable);
229        let casted = cast(&array, &nullable_dtype).unwrap();
230        assert_eq!(casted.dtype(), &nullable_dtype);
231    }
232
233    #[test]
234    fn test_cast_cannot_change_name_order() {
235        let array = StructArray::try_new(
236            ["xs", "ys", "zs"].into(),
237            vec![
238                buffer![1u8].into_array(),
239                buffer![1u8].into_array(),
240                buffer![1u8].into_array(),
241            ],
242            1,
243            Validity::NonNullable,
244        )
245        .unwrap();
246
247        let tu8 = DType::Primitive(PType::U8, NonNullable);
248
249        let result = cast(
250            array.as_ref(),
251            &DType::Struct(
252                StructFields::new(
253                    FieldNames::from(["ys", "xs", "zs"]),
254                    vec![tu8.clone(), tu8.clone(), tu8],
255                ),
256                NonNullable,
257            ),
258        );
259        assert!(
260            result.as_ref().is_err_and(|err| {
261                err.to_string()
262                    .contains("cannot cast {xs=u8, ys=u8, zs=u8} to {ys=u8, xs=u8, zs=u8}")
263            }),
264            "{result:?}"
265        );
266    }
267
268    #[test]
269    fn test_cast_complex_struct() {
270        let xs = PrimitiveArray::from_option_iter([Some(0i64), Some(1), Some(2), Some(3), Some(4)]);
271        let ys = VarBinArray::from_vec(vec!["a", "b", "c", "d", "e"], DType::Utf8(Nullable));
272        let zs = BoolArray::new(
273            BooleanBuffer::from_iter([true, true, false, false, true]),
274            Validity::AllValid,
275        );
276        let fully_nullable_array = StructArray::try_new(
277            ["xs", "ys", "zs"].into(),
278            vec![
279                StructArray::try_new(
280                    ["left", "right"].into(),
281                    vec![xs.to_array(), xs.to_array()],
282                    5,
283                    Validity::AllValid,
284                )
285                .unwrap()
286                .into_array(),
287                ys.into_array(),
288                zs.into_array(),
289            ],
290            5,
291            Validity::AllValid,
292        )
293        .unwrap()
294        .into_array();
295
296        let top_level_non_nullable = fully_nullable_array.dtype().as_nonnullable();
297        let casted = cast(&fully_nullable_array, &top_level_non_nullable).unwrap();
298        assert_eq!(casted.dtype(), &top_level_non_nullable);
299
300        let non_null_xs_right = DType::Struct(
301            StructFields::new(
302                ["xs", "ys", "zs"].into(),
303                vec![
304                    DType::Struct(
305                        StructFields::new(
306                            ["left", "right"].into(),
307                            vec![
308                                DType::Primitive(PType::I64, NonNullable),
309                                DType::Primitive(PType::I64, Nullable),
310                            ],
311                        ),
312                        Nullable,
313                    ),
314                    DType::Utf8(Nullable),
315                    DType::Bool(Nullable),
316                ],
317            ),
318            Nullable,
319        );
320        let casted = cast(&fully_nullable_array, &non_null_xs_right).unwrap();
321        assert_eq!(casted.dtype(), &non_null_xs_right);
322
323        let non_null_xs = DType::Struct(
324            StructFields::new(
325                ["xs", "ys", "zs"].into(),
326                vec![
327                    DType::Struct(
328                        StructFields::new(
329                            ["left", "right"].into(),
330                            vec![
331                                DType::Primitive(PType::I64, Nullable),
332                                DType::Primitive(PType::I64, Nullable),
333                            ],
334                        ),
335                        NonNullable,
336                    ),
337                    DType::Utf8(Nullable),
338                    DType::Bool(Nullable),
339                ],
340            ),
341            Nullable,
342        );
343        let casted = cast(&fully_nullable_array, &non_null_xs).unwrap();
344        assert_eq!(casted.dtype(), &non_null_xs);
345    }
346
347    #[test]
348    fn test_empty_struct_is_constant() {
349        let array = StructArray::new_with_len(2);
350        let is_constant = is_constant(array.as_ref()).vortex_unwrap();
351        assert_eq!(is_constant, Some(true));
352    }
353}