vortex_array/arrays/varbinview/compute/take.rs

1use std::ops::Deref;
2
3use num_traits::AsPrimitive;
4use vortex_buffer::Buffer;
5use vortex_dtype::match_each_integer_ptype;
6use vortex_error::VortexResult;
7
8use crate::arrays::{BinaryView, VarBinViewArray, VarBinViewVTable};
9use crate::compute::{TakeKernel, TakeKernelAdapter};
10use crate::vtable::ValidityHelper;
11use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
12
13/// Take involves creating a new array that references the old array, just with the given set of views.
14impl TakeKernel for VarBinViewVTable {
15    fn take(&self, array: &VarBinViewArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
16        // Compute the new validity
17
18        // This is valid since all elements (of all arrays) even null values are inside must be the
19        // min-max valid range.
20        let validity = array.validity().take(indices)?;
21        let indices = indices.to_primitive()?;
22
23        let views_buffer = match_each_integer_ptype!(indices.ptype(), |I| {
24            // This is valid since all elements even null values are inside the min-max valid range.
25            take_views(array.views(), indices.as_slice::<I>())
26        });
27
28        Ok(VarBinViewArray::try_new(
29            views_buffer,
30            array.buffers().to_vec(),
31            array
32                .dtype()
33                .union_nullability(indices.dtype().nullability()),
34            validity,
35        )?
36        .into_array())
37    }
38}
39
40register_kernel!(TakeKernelAdapter(VarBinViewVTable).lift());
41
42fn take_views<I: AsPrimitive<usize>>(
43    views: &Buffer<BinaryView>,
44    indices: &[I],
45) -> Buffer<BinaryView> {
46    // NOTE(ngates): this deref is not actually trivial, so we run it once.
47    let views_ref = views.deref();
48    Buffer::<BinaryView>::from_iter(indices.iter().map(|i| views_ref[i.as_()]))
49}
50
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::NonNullable;

    use crate::accessor::ArrayAccessor;
    use crate::array::Array;
    use crate::arrays::{PrimitiveArray, VarBinViewArray};
    use crate::canonical::ToCanonical;
    use crate::compute::take;
    use crate::{ArrayRef, IntoArray};

    /// Canonicalizes `array` to VarBinView and decodes every element as a UTF-8 `String`.
    ///
    /// Decoding is checked. If a view or buffer is wrong, the test fails with a panic instead of
    /// building an invalid `String` through `from_utf8_unchecked`.
    fn decode_strings(array: &ArrayRef) -> Vec<Option<String>> {
        array
            .to_varbinview()
            .unwrap()
            .with_iterator(|it| {
                it.map(|v| {
                    v.map(|b| String::from_utf8(b.to_vec()).expect("test data is valid UTF-8"))
                })
                .collect::<Vec<_>>()
            })
            .unwrap()
    }

    #[test]
    fn take_nullable() {
        let arr = VarBinViewArray::from_iter_nullable_str([
            Some("one"),
            None,
            Some("three"),
            Some("four"),
            None,
            Some("six"),
        ]);

        let taken = take(arr.as_ref(), &buffer![0, 3].into_array()).unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(
            decode_strings(&taken),
            [Some("one".to_string()), Some("four".to_string())]
        );
    }

    #[test]
    fn take_nullable_indices() {
        let arr = VarBinViewArray::from_iter(["one", "two"].map(Some), DType::Utf8(NonNullable));

        let taken = take(
            arr.as_ref(),
            PrimitiveArray::from_option_iter(vec![Some(1), None]).as_ref(),
        )
        .unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(decode_strings(&taken), [Some("two".to_string()), None]);
    }

    #[test]
    fn take_non_nullable_repeated_indices() {
        // Mix short strings with a long one. This assumes long strings are stored out of line
        // in the data buffers rather than inside the view, so both kinds of view are exercised.
        let long = "a string that is comfortably longer than a dozen bytes";
        let arr = VarBinViewArray::from_iter(
            ["a", long, "ccc"].map(Some),
            DType::Utf8(NonNullable),
        );

        // Use unsigned 64-bit indices, including a repeated index, to cover a different index
        // width than the tests above.
        let taken = take(arr.as_ref(), &buffer![1u64, 0, 1, 2].into_array()).unwrap();

        assert!(!taken.dtype().is_nullable());
        assert_eq!(
            decode_strings(&taken),
            [
                Some(long.to_string()),
                Some("a".to_string()),
                Some(long.to_string()),
                Some("ccc".to_string()),
            ]
        );
    }
}