Skip to main content

vortex_array/arrays/struct_/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::Array;
7use crate::ArrayRef;
8use crate::IntoArray;
9use crate::arrays::StructArray;
10use crate::arrays::StructVTable;
11use crate::arrays::TakeReduce;
12use crate::builtins::ArrayBuiltins;
13use crate::scalar::Scalar;
14use crate::validity::Validity;
15use crate::vtable::ValidityHelper;
16
17impl TakeReduce for StructVTable {
18    fn take(array: &StructArray, indices: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
19        // If the struct array is empty then the indices must be all null, otherwise it will access
20        // an out of bounds element.
21        if array.is_empty() {
22            return StructArray::try_new_with_dtype(
23                array.unmasked_fields().clone(),
24                array.struct_fields().clone(),
25                indices.len(),
26                Validity::AllInvalid,
27            )
28            .map(StructArray::into_array)
29            .map(Some);
30        }
31
32        // TODO(connor): This could be bad for cache locality...
33
34        // Fill null indices with zero so they point at a valid row.
35        // Note that we strip nullability so that `Take::return_dtype` doesn't union nullable into
36        // each field's dtype (the struct-level validity already captures which rows are null).
37        let fill_scalar = Scalar::zero_value(&indices.dtype().as_nonnullable());
38        let inner_indices = indices.to_array().fill_null(fill_scalar)?;
39
40        StructArray::try_new_with_dtype(
41            array
42                .unmasked_fields()
43                .iter()
44                .map(|field| field.take(inner_indices.clone()))
45                .collect::<Result<Vec<_>, _>>()?,
46            array.struct_fields().clone(),
47            indices.len(),
48            array.validity().take(indices)?,
49        )
50        .map(|a| a.into_array())
51        .map(Some)
52    }
53}