vortex_array/arrays/varbinview/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Deref;
5
6use num_traits::AsPrimitive;
7use vortex_buffer::Buffer;
8use vortex_dtype::match_each_integer_ptype;
9use vortex_error::VortexResult;
10
11use crate::arrays::binary_view::BinaryView;
12use crate::arrays::{VarBinViewArray, VarBinViewVTable};
13use crate::compute::{TakeKernel, TakeKernelAdapter};
14use crate::vtable::ValidityHelper;
15use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
16
17/// Take involves creating a new array that references the old array, just with the given set of views.
18impl TakeKernel for VarBinViewVTable {
19    fn take(&self, array: &VarBinViewArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
20        // Compute the new validity
21
22        // This is valid since all elements (of all arrays) even null values must be inside
23        // min-max valid range.
24        let validity = array.validity().take(indices)?;
25        let indices = indices.to_primitive();
26
27        let views_buffer = match_each_integer_ptype!(indices.ptype(), |I| {
28            // This is valid since all elements even null values are inside the min-max valid range.
29            take_views(array.views(), indices.as_slice::<I>())
30        });
31
32        // SAFETY: taking all components at same indices maintains invariants
33        unsafe {
34            Ok(VarBinViewArray::new_unchecked(
35                views_buffer,
36                array.buffers().clone(),
37                array
38                    .dtype()
39                    .union_nullability(indices.dtype().nullability()),
40                validity,
41            )
42            .into_array())
43        }
44    }
45}
46
47register_kernel!(TakeKernelAdapter(VarBinViewVTable).lift());
48
49fn take_views<I: AsPrimitive<usize>>(
50    views: &Buffer<BinaryView>,
51    indices: &[I],
52) -> Buffer<BinaryView> {
53    // NOTE(ngates): this deref is not actually trivial, so we run it once.
54    let views_ref = views.deref();
55    Buffer::<BinaryView>::from_trusted_len_iter(indices.iter().map(|i| views_ref[i.as_()]))
56}
57
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::NonNullable;

    use crate::IntoArray;
    use crate::accessor::ArrayAccessor;
    use crate::array::Array;
    use crate::arrays::{PrimitiveArray, VarBinViewArray};
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;

    /// Collects every element of `array` as an owned `String`, keeping nulls as `None`.
    ///
    /// Uses the checked UTF-8 conversion so that a kernel bug producing invalid bytes fails
    /// the test with a panic instead of invoking undefined behaviour.
    fn collect_strings(array: &VarBinViewArray) -> Vec<Option<String>> {
        array
            .with_iterator(|values| {
                values
                    .map(|value| {
                        value.map(|bytes| {
                            String::from_utf8(bytes.to_vec())
                                .expect("taken values must be valid UTF-8")
                        })
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap()
    }

    /// Taking non-null positions out of a nullable array keeps the nullable dtype.
    #[test]
    fn take_nullable() {
        let arr = VarBinViewArray::from_iter_nullable_str([
            Some("one"),
            None,
            Some("three"),
            Some("four"),
            None,
            Some("six"),
        ]);

        let taken = take(arr.as_ref(), &buffer![0, 3].into_array()).unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(
            collect_strings(&taken.to_varbinview()),
            [Some("one".to_string()), Some("four".to_string())]
        );
    }

    /// Nullable indices make the result nullable even when the source array is not, and a
    /// null index yields a null element.
    #[test]
    fn take_nullable_indices() {
        let arr = VarBinViewArray::from_iter(["one", "two"].map(Some), DType::Utf8(NonNullable));

        let taken = take(
            arr.as_ref(),
            PrimitiveArray::from_option_iter(vec![Some(1), None]).as_ref(),
        )
        .unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(
            collect_strings(&taken.to_varbinview()),
            [Some("two".to_string()), None]
        );
    }

    /// Runs the shared take conformance suite over UTF-8, nullable, binary, and
    /// single-element view arrays.
    #[rstest]
    #[case(VarBinViewArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter_nullable_str([
        Some("hello"),
        None,
        Some("test"),
        Some("data"),
        None,
    ]))]
    #[case(VarBinViewArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter(["single"].map(Some), DType::Utf8(NonNullable)))]
    fn test_take_varbinview_conformance(#[case] array: VarBinViewArray) {
        test_take_conformance(array.as_ref());
    }
}