Skip to main content

vortex_array/arrays/varbinview/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::iter;
5use std::sync::Arc;
6
7use num_traits::AsPrimitive;
8use vortex_buffer::Buffer;
9use vortex_error::VortexResult;
10use vortex_mask::AllOr;
11use vortex_mask::Mask;
12
13use crate::ArrayRef;
14use crate::IntoArray;
15use crate::array::ArrayView;
16use crate::arrays::PrimitiveArray;
17use crate::arrays::VarBinView;
18use crate::arrays::VarBinViewArray;
19use crate::arrays::dict::TakeExecute;
20use crate::arrays::varbinview::BinaryView;
21use crate::buffer::BufferHandle;
22use crate::executor::ExecutionCtx;
23use crate::match_each_integer_ptype;
24
25impl TakeExecute for VarBinView {
26    /// Take involves creating a new array that references the old array, just with the given set of views.
27    fn take(
28        array: ArrayView<'_, VarBinView>,
29        indices: &ArrayRef,
30        ctx: &mut ExecutionCtx,
31    ) -> VortexResult<Option<ArrayRef>> {
32        let validity = array.validity()?.take(indices)?;
33        let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
34
35        let indices_mask = indices.validity_mask()?;
36        let views_buffer = match_each_integer_ptype!(indices.ptype(), |I| {
37            take_views(array.views(), indices.as_slice::<I>(), &indices_mask)
38        });
39
40        // SAFETY: taking all components at same indices maintains invariants
41        unsafe {
42            Ok(Some(
43                VarBinViewArray::new_handle_unchecked(
44                    BufferHandle::new_host(views_buffer.into_byte_buffer()),
45                    Arc::clone(array.data_buffers()),
46                    array
47                        .dtype()
48                        .union_nullability(indices.dtype().nullability()),
49                    validity,
50                )
51                .into_array(),
52            ))
53        }
54    }
55}
56
57fn take_views<I: AsPrimitive<usize>>(
58    views_ref: &[BinaryView],
59    indices: &[I],
60    mask: &Mask,
61) -> Buffer<BinaryView> {
62    // NOTE(ngates): this deref is not actually trivial, so we run it once.
63    // We do not use iter_bools directly, since the resulting dyn iterator cannot
64    // implement TrustedLen.
65    match mask.bit_buffer() {
66        AllOr::All => {
67            Buffer::<BinaryView>::from_trusted_len_iter(indices.iter().map(|i| views_ref[i.as_()]))
68        }
69        AllOr::None => Buffer::<BinaryView>::from_trusted_len_iter(iter::repeat_n(
70            BinaryView::default(),
71            indices.len(),
72        )),
73        AllOr::Some(buffer) => Buffer::<BinaryView>::from_trusted_len_iter(
74            buffer.iter().zip(indices.iter()).map(|(valid, idx)| {
75                if valid {
76                    views_ref[idx.as_()]
77                } else {
78                    BinaryView::default()
79                }
80            }),
81        ),
82    }
83}
84
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::accessor::ArrayAccessor;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinViewArray;
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability::NonNullable;
    use crate::validity::Validity;

    /// Decodes every element of `array` into an owned string, keeping nulls as `None`.
    ///
    /// Uses checked UTF-8 decoding so that a take bug producing invalid bytes fails the test
    /// loudly instead of triggering undefined behaviour.
    fn utf8_values(array: &VarBinViewArray) -> Vec<Option<String>> {
        array.with_iterator(|it| {
            it.map(|v| v.map(|b| String::from_utf8(b.to_vec()).expect("value is valid UTF-8")))
                .collect()
        })
    }

    #[test]
    fn take_nullable() {
        let arr = VarBinViewArray::from_iter_nullable_str([
            Some("one"),
            None,
            Some("three"),
            Some("four"),
            None,
            Some("six"),
        ]);

        let taken = arr.take(buffer![0, 3].into_array()).unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(
            utf8_values(&taken.to_varbinview()),
            [Some("one".to_string()), Some("four".to_string())]
        );
    }

    #[test]
    fn take_nullable_indices() {
        let arr = VarBinViewArray::from_iter(["one", "two"].map(Some), DType::Utf8(NonNullable));

        let indices = PrimitiveArray::new(
            // Verify that garbage values at NULL indices are ignored.
            buffer![1u64, 999],
            Validity::from(BitBuffer::from(vec![true, false])),
        );

        let taken = arr.take(indices.into_array()).unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(
            utf8_values(&taken.to_varbinview()),
            [Some("two".to_string()), None]
        );
    }

    #[test]
    fn take_all_null_indices() {
        let arr = VarBinViewArray::from_iter(["one", "two"].map(Some), DType::Utf8(NonNullable));

        // Every index is null, so every output position must be null regardless of the
        // (garbage) index values.
        let indices = PrimitiveArray::new(buffer![0u32, 999], Validity::AllInvalid);

        let taken = arr.take(indices.into_array()).unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(utf8_values(&taken.to_varbinview()), [None::<String>, None]);
    }

    #[rstest]
    #[case(VarBinViewArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter_nullable_str([
        Some("hello"),
        None,
        Some("test"),
        Some("data"),
        None,
    ]))]
    #[case(VarBinViewArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter(["single"].map(Some), DType::Utf8(NonNullable)))]
    fn test_take_varbinview_conformance(#[case] array: VarBinViewArray) {
        test_take_conformance(&array.into_array());
    }
}