vortex_array/arrays/varbinview/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Deref;
5
6use num_traits::AsPrimitive;
7use vortex_buffer::Buffer;
8use vortex_dtype::match_each_integer_ptype;
9use vortex_error::VortexResult;
10
11use crate::arrays::{BinaryView, VarBinViewArray, VarBinViewVTable};
12use crate::compute::{TakeKernel, TakeKernelAdapter};
13use crate::vtable::ValidityHelper;
14use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
15
16/// Take involves creating a new array that references the old array, just with the given set of views.
17impl TakeKernel for VarBinViewVTable {
18    fn take(&self, array: &VarBinViewArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
19        // Compute the new validity
20
21        // This is valid since all elements (of all arrays) even null values are inside must be the
22        // min-max valid range.
23        let validity = array.validity().take(indices)?;
24        let indices = indices.to_primitive()?;
25
26        let views_buffer = match_each_integer_ptype!(indices.ptype(), |I| {
27            // This is valid since all elements even null values are inside the min-max valid range.
28            take_views(array.views(), indices.as_slice::<I>())
29        });
30
31        Ok(VarBinViewArray::try_new(
32            views_buffer,
33            array.buffers().to_vec(),
34            array
35                .dtype()
36                .union_nullability(indices.dtype().nullability()),
37            validity,
38        )?
39        .into_array())
40    }
41}
42
43register_kernel!(TakeKernelAdapter(VarBinViewVTable).lift());
44
45fn take_views<I: AsPrimitive<usize>>(
46    views: &Buffer<BinaryView>,
47    indices: &[I],
48) -> Buffer<BinaryView> {
49    // NOTE(ngates): this deref is not actually trivial, so we run it once.
50    let views_ref = views.deref();
51    Buffer::<BinaryView>::from_iter(indices.iter().map(|i| views_ref[i.as_()]))
52}
53
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::NonNullable;

    use crate::IntoArray;
    use crate::accessor::ArrayAccessor;
    use crate::array::Array;
    use crate::arrays::{PrimitiveArray, VarBinViewArray};
    use crate::canonical::ToCanonical;
    use crate::compute::take;

    /// Converts `array` to a `VarBinViewArray` and decodes every non-null element as UTF-8.
    ///
    /// Uses the checked `String::from_utf8` so that no `unsafe` is needed; the tests only ever
    /// store valid UTF-8 strings.
    fn utf8_values(array: &dyn Array) -> Vec<Option<String>> {
        array
            .to_varbinview()
            .unwrap()
            .with_iterator(|values| {
                values
                    .map(|value| {
                        value.map(|bytes| {
                            String::from_utf8(bytes.to_vec()).expect("test values are valid UTF-8")
                        })
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap()
    }

    /// Taking valid positions out of a nullable array keeps the result nullable.
    #[test]
    fn take_nullable() {
        let arr = VarBinViewArray::from_iter_nullable_str([
            Some("one"),
            None,
            Some("three"),
            Some("four"),
            None,
            Some("six"),
        ]);

        let taken = take(arr.as_ref(), &buffer![0, 3].into_array()).unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(
            utf8_values(taken.as_ref()),
            [Some("one".to_string()), Some("four".to_string())]
        );
    }

    /// Taking from a non-nullable array with nullable indices yields a nullable result with a
    /// null wherever the index was null.
    #[test]
    fn take_nullable_indices() {
        let arr = VarBinViewArray::from_iter(["one", "two"].map(Some), DType::Utf8(NonNullable));

        let taken = take(
            arr.as_ref(),
            PrimitiveArray::from_option_iter(vec![Some(1), None]).as_ref(),
        )
        .unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(utf8_values(taken.as_ref()), [Some("two".to_string()), None]);
    }
}
115}