vortex_array/arrays/varbinview/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::iter;
5use std::ops::Deref;
6
7use num_traits::AsPrimitive;
8use vortex_buffer::Buffer;
9use vortex_dtype::match_each_integer_ptype;
10use vortex_error::VortexResult;
11use vortex_mask::AllOr;
12use vortex_mask::Mask;
13use vortex_vector::binaryview::BinaryView;
14
15use crate::Array;
16use crate::ArrayRef;
17use crate::IntoArray;
18use crate::ToCanonical;
19use crate::arrays::VarBinViewArray;
20use crate::arrays::VarBinViewVTable;
21use crate::compute::TakeKernel;
22use crate::compute::TakeKernelAdapter;
23use crate::register_kernel;
24use crate::vtable::ValidityHelper;
25
26/// Take involves creating a new array that references the old array, just with the given set of views.
27impl TakeKernel for VarBinViewVTable {
28    fn take(&self, array: &VarBinViewArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
29        // Compute the new validity.
30        let validity = array.validity().take(indices)?;
31        let indices = indices.to_primitive();
32
33        let views_buffer = match_each_integer_ptype!(indices.ptype(), |I| {
34            take_views(
35                array.views(),
36                indices.as_slice::<I>(),
37                &indices.validity_mask(),
38            )
39        });
40
41        // SAFETY: taking all components at same indices maintains invariants
42        unsafe {
43            Ok(VarBinViewArray::new_unchecked(
44                views_buffer,
45                array.buffers().clone(),
46                array
47                    .dtype()
48                    .union_nullability(indices.dtype().nullability()),
49                validity,
50            )
51            .into_array())
52        }
53    }
54}
55
// Registers the `TakeKernel` implementation above, wrapped in `TakeKernelAdapter`.
// NOTE(review): presumably this is what lets the generic `take` compute function find
// this kernel for `VarBinViewVTable` — confirm against the `register_kernel!` macro.
register_kernel!(TakeKernelAdapter(VarBinViewVTable).lift());
57
58fn take_views<I: AsPrimitive<usize>>(
59    views: &Buffer<BinaryView>,
60    indices: &[I],
61    mask: &Mask,
62) -> Buffer<BinaryView> {
63    // NOTE(ngates): this deref is not actually trivial, so we run it once.
64    let views_ref = views.deref();
65    // We do not use iter_bools directly, since the resulting dyn iterator cannot
66    // implement TrustedLen.
67    match mask.bit_buffer() {
68        AllOr::All => {
69            Buffer::<BinaryView>::from_trusted_len_iter(indices.iter().map(|i| views_ref[i.as_()]))
70        }
71        AllOr::None => Buffer::<BinaryView>::from_trusted_len_iter(iter::repeat_n(
72            BinaryView::default(),
73            indices.len(),
74        )),
75        AllOr::Some(buffer) => Buffer::<BinaryView>::from_trusted_len_iter(
76            buffer.iter().zip(indices.iter()).map(|(valid, idx)| {
77                if valid {
78                    views_ref[idx.as_()]
79                } else {
80                    BinaryView::default()
81                }
82            }),
83        ),
84    }
85}
86
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::NonNullable;

    use crate::IntoArray;
    use crate::accessor::ArrayAccessor;
    use crate::array::Array;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinViewArray;
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;

    /// Decodes every element of `array` into an owned `String`, keeping nulls as `None`.
    ///
    /// Uses the checked `String::from_utf8`, so a test that produced invalid UTF-8 fails
    /// with a panic instead of constructing an invalid `String`.
    fn utf8_values(array: &VarBinViewArray) -> Vec<Option<String>> {
        array.with_iterator(|it| {
            it.map(|v| v.map(|b| String::from_utf8(b.to_vec()).unwrap()))
                .collect::<Vec<_>>()
        })
    }

    /// Taking valid positions out of a nullable array keeps the result nullable.
    #[test]
    fn take_nullable() {
        let arr = VarBinViewArray::from_iter_nullable_str([
            Some("one"),
            None,
            Some("three"),
            Some("four"),
            None,
            Some("six"),
        ]);

        let taken = take(arr.as_ref(), &buffer![0, 3].into_array()).unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(
            utf8_values(&taken.to_varbinview()),
            [Some("one".to_string()), Some("four".to_string())]
        );
    }

    /// Null indices produce null output values, even when the source array is non-nullable.
    #[test]
    fn take_nullable_indices() {
        let arr = VarBinViewArray::from_iter(["one", "two"].map(Some), DType::Utf8(NonNullable));

        let indices = PrimitiveArray::new(
            // Verify that garbage values at NULL indices are ignored.
            buffer![1u64, 999],
            Validity::from(BitBuffer::from(vec![true, false])),
        );

        let taken = take(arr.as_ref(), indices.as_ref()).unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(
            utf8_values(&taken.to_varbinview()),
            [Some("two".to_string()), None]
        );
    }

    /// Runs the shared take conformance suite over several `VarBinViewArray` inputs.
    #[rstest]
    #[case(VarBinViewArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter_nullable_str([
        Some("hello"),
        None,
        Some("test"),
        Some("data"),
        None,
    ]))]
    #[case(VarBinViewArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter(["single"].map(Some), DType::Utf8(NonNullable)))]
    fn test_take_varbinview_conformance(#[case] array: VarBinViewArray) {
        test_take_conformance(array.as_ref());
    }
}