Skip to main content

vortex_array/arrays/struct_/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::Array;
7use crate::ArrayRef;
8use crate::ExecutionCtx;
9use crate::IntoArray;
10use crate::arrays::StructArray;
11use crate::arrays::StructVTable;
12use crate::arrays::TakeExecute;
13use crate::builtins::ArrayBuiltins;
14use crate::scalar::Scalar;
15use crate::validity::Validity;
16use crate::vtable::ValidityHelper;
17
18impl TakeExecute for StructVTable {
19    fn take(
20        array: &StructArray,
21        indices: &dyn Array,
22        _ctx: &mut ExecutionCtx,
23    ) -> VortexResult<Option<ArrayRef>> {
24        // If the struct array is empty then the indices must be all null, otherwise it will access
25        // an out of bounds element.
26        if array.is_empty() {
27            return StructArray::try_new_with_dtype(
28                array.unmasked_fields().clone(),
29                array.struct_fields().clone(),
30                indices.len(),
31                Validity::AllInvalid,
32            )
33            .map(StructArray::into_array)
34            .map(Some);
35        }
36
37        // TODO(connor): This could be bad for cache locality...
38
39        // Fill null indices with zero so they point at a valid row.
40        // Note that we strip nullability so that `Take::return_dtype` doesn't union nullable into
41        // each field's dtype (the struct-level validity already captures which rows are null).
42        let fill_scalar = Scalar::zero_value(&indices.dtype().as_nonnullable());
43        let inner_indices = &indices.to_array().fill_null(fill_scalar)?;
44
45        StructArray::try_new_with_dtype(
46            array
47                .unmasked_fields()
48                .iter()
49                .map(|field| field.take(inner_indices.to_array()))
50                .collect::<Result<Vec<_>, _>>()?,
51            array.struct_fields().clone(),
52            indices.len(),
53            array.validity().take(indices)?,
54        )
55        .map(|a| a.into_array())
56        .map(Some)
57    }
58}