vortex_array/arrays/varbinview/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Deref;
5
6use num_traits::AsPrimitive;
7use vortex_buffer::Buffer;
8use vortex_dtype::match_each_integer_ptype;
9use vortex_error::VortexResult;
10use vortex_vector::binaryview::BinaryView;
11
12use crate::Array;
13use crate::ArrayRef;
14use crate::IntoArray;
15use crate::ToCanonical;
16use crate::arrays::VarBinViewArray;
17use crate::arrays::VarBinViewVTable;
18use crate::compute::TakeKernel;
19use crate::compute::TakeKernelAdapter;
20use crate::register_kernel;
21use crate::vtable::ValidityHelper;
22
23/// Take involves creating a new array that references the old array, just with the given set of views.
24impl TakeKernel for VarBinViewVTable {
25    fn take(&self, array: &VarBinViewArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
26        // Compute the new validity
27
28        // This is valid since all elements (of all arrays) even null values must be inside
29        // min-max valid range.
30        let validity = array.validity().take(indices)?;
31        let indices = indices.to_primitive();
32
33        let views_buffer = match_each_integer_ptype!(indices.ptype(), |I| {
34            // This is valid since all elements even null values are inside the min-max valid range.
35            take_views(array.views(), indices.as_slice::<I>())
36        });
37
38        // SAFETY: taking all components at same indices maintains invariants
39        unsafe {
40            Ok(VarBinViewArray::new_unchecked(
41                views_buffer,
42                array.buffers().clone(),
43                array
44                    .dtype()
45                    .union_nullability(indices.dtype().nullability()),
46                validity,
47            )
48            .into_array())
49        }
50    }
51}
52
// Hand the `TakeKernel` impl for `VarBinViewVTable`, wrapped in `TakeKernelAdapter` and lifted,
// to `register_kernel!`.
register_kernel!(TakeKernelAdapter(VarBinViewVTable).lift());
54
55fn take_views<I: AsPrimitive<usize>>(
56    views: &Buffer<BinaryView>,
57    indices: &[I],
58) -> Buffer<BinaryView> {
59    // NOTE(ngates): this deref is not actually trivial, so we run it once.
60    let views_ref = views.deref();
61    Buffer::<BinaryView>::from_trusted_len_iter(indices.iter().map(|i| views_ref[i.as_()]))
62}
63
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::NonNullable;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::accessor::ArrayAccessor;
    use crate::array::Array;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinViewArray;
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;

    /// Canonicalizes `array` to a view array and decodes every non-null value as a `String`,
    /// keeping nulls as `None`.
    ///
    /// Uses the checked `String::from_utf8` so that unexpected bytes fail the test with a panic
    /// rather than producing an invalid `String` through an unchecked conversion.
    fn utf8_values(array: &ArrayRef) -> Vec<Option<String>> {
        array.to_varbinview().with_iterator(|it| {
            it.map(|v| v.map(|b| String::from_utf8(b.to_vec()).unwrap()))
                .collect::<Vec<_>>()
        })
    }

    /// Taking valid positions out of a nullable array keeps the result nullable.
    #[test]
    fn take_nullable() {
        let arr = VarBinViewArray::from_iter_nullable_str([
            Some("one"),
            None,
            Some("three"),
            Some("four"),
            None,
            Some("six"),
        ]);

        let taken = take(arr.as_ref(), &buffer![0, 3].into_array()).unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(
            utf8_values(&taken),
            [Some("one".to_string()), Some("four".to_string())]
        );
    }

    /// Nullable indices make the result nullable even when the source array is not, and a null
    /// index produces a null output slot.
    #[test]
    fn take_nullable_indices() {
        let arr = VarBinViewArray::from_iter(["one", "two"].map(Some), DType::Utf8(NonNullable));

        let taken = take(
            arr.as_ref(),
            PrimitiveArray::from_option_iter(vec![Some(1), None]).as_ref(),
        )
        .unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(utf8_values(&taken), [Some("two".to_string()), None]);
    }

    /// Runs the shared take conformance suite over non-nullable UTF-8, nullable UTF-8,
    /// non-nullable binary, and single-element inputs.
    #[rstest]
    #[case(VarBinViewArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter_nullable_str([
        Some("hello"),
        None,
        Some("test"),
        Some("data"),
        None,
    ]))]
    #[case(VarBinViewArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter(["single"].map(Some), DType::Utf8(NonNullable)))]
    fn test_take_varbinview_conformance(#[case] array: VarBinViewArray) {
        test_take_conformance(array.as_ref());
    }
}
141}