vortex_compute/take/vector/
struct_.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use vortex_dtype::UnsignedPType;
7use vortex_vector::Vector;
8use vortex_vector::VectorOps;
9use vortex_vector::primitive::PVector;
10use vortex_vector::struct_::StructVector;
11
12use crate::take::Take;
13
14impl<I: UnsignedPType> Take<PVector<I>> for &StructVector {
15    type Output = StructVector;
16
17    fn take(self, indices: &PVector<I>) -> StructVector {
18        if indices.validity().all_true() {
19            self.take(indices.elements().as_slice())
20        } else {
21            take_nullable(self, indices)
22        }
23    }
24}
25
26impl<I: UnsignedPType> Take<[I]> for &StructVector {
27    type Output = StructVector;
28
29    fn take(self, indices: &[I]) -> StructVector {
30        let taken_fields: Box<[Vector]> = self
31            .fields()
32            .iter()
33            .map(|field| field.take(indices))
34            .collect();
35        let taken_validity = self.validity().take(indices);
36
37        // SAFETY: We called take on all fields and validity with the same indices, so all fields
38        // must have the same length as each other and as the validity.
39        unsafe { StructVector::new_unchecked(Arc::new(taken_fields), taken_validity) }
40    }
41}
42
43fn take_nullable<I: UnsignedPType>(svector: &StructVector, indices: &PVector<I>) -> StructVector {
44    // We ignore nullability when taking the fields since we can let the `Mask` implementation
45    // determine which elements are null.
46    let taken_fields: Box<[Vector]> = svector
47        .fields()
48        .iter()
49        .map(|field| field.take(indices.elements().as_slice()))
50        .collect();
51
52    // NB: This is the nullable version of `take`, so this is not the same as the `take`
53    // implementation `indices: &[I]` above.
54    let taken_validity = svector.validity().take(indices);
55
56    // SAFETY: We called take on all fields and validity with the same indices, so all fields must
57    // have the same length as each other and as the validity.
58    unsafe { StructVector::new_unchecked(Arc::new(taken_fields), taken_validity) }
59}