Skip to main content

vortex_array/arrays/varbinview/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::iter;
5
6use num_traits::AsPrimitive;
7use vortex_buffer::Buffer;
8use vortex_error::VortexResult;
9use vortex_mask::AllOr;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::IntoArray;
14use crate::arrays::BinaryView;
15use crate::arrays::PrimitiveArray;
16use crate::arrays::TakeExecute;
17use crate::arrays::VarBinViewArray;
18use crate::arrays::VarBinViewVTable;
19use crate::buffer::BufferHandle;
20use crate::executor::ExecutionCtx;
21use crate::match_each_integer_ptype;
22use crate::vtable::ValidityHelper;
23
24impl TakeExecute for VarBinViewVTable {
25    /// Take involves creating a new array that references the old array, just with the given set of views.
26    fn take(
27        array: &VarBinViewArray,
28        indices: &ArrayRef,
29        ctx: &mut ExecutionCtx,
30    ) -> VortexResult<Option<ArrayRef>> {
31        let validity = array.validity().take(indices)?;
32        let indices = indices.to_array().execute::<PrimitiveArray>(ctx)?;
33
34        let indices_mask = indices.validity_mask()?;
35        let views_buffer = match_each_integer_ptype!(indices.ptype(), |I| {
36            take_views(array.views(), indices.as_slice::<I>(), &indices_mask)
37        });
38
39        // SAFETY: taking all components at same indices maintains invariants
40        unsafe {
41            Ok(Some(
42                VarBinViewArray::new_handle_unchecked(
43                    BufferHandle::new_host(views_buffer.into_byte_buffer()),
44                    array.buffers().clone(),
45                    array
46                        .dtype()
47                        .union_nullability(indices.dtype().nullability()),
48                    validity,
49                )
50                .into_array(),
51            ))
52        }
53    }
54}
55
56fn take_views<I: AsPrimitive<usize>>(
57    views_ref: &[BinaryView],
58    indices: &[I],
59    mask: &Mask,
60) -> Buffer<BinaryView> {
61    // NOTE(ngates): this deref is not actually trivial, so we run it once.
62    // We do not use iter_bools directly, since the resulting dyn iterator cannot
63    // implement TrustedLen.
64    match mask.bit_buffer() {
65        AllOr::All => {
66            Buffer::<BinaryView>::from_trusted_len_iter(indices.iter().map(|i| views_ref[i.as_()]))
67        }
68        AllOr::None => Buffer::<BinaryView>::from_trusted_len_iter(iter::repeat_n(
69            BinaryView::default(),
70            indices.len(),
71        )),
72        AllOr::Some(buffer) => Buffer::<BinaryView>::from_trusted_len_iter(
73            buffer.iter().zip(indices.iter()).map(|(valid, idx)| {
74                if valid {
75                    views_ref[idx.as_()]
76                } else {
77                    BinaryView::default()
78                }
79            }),
80        ),
81    }
82}
83
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::accessor::ArrayAccessor;
    use crate::array::Array;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinViewArray;
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability::NonNullable;
    use crate::validity::Validity;

    /// Taking valid positions out of a nullable array keeps the nullable dtype and
    /// yields the selected strings.
    #[test]
    fn take_nullable() {
        let source = VarBinViewArray::from_iter_nullable_str([
            Some("one"),
            None,
            Some("three"),
            Some("four"),
            None,
            Some("six"),
        ]);

        let result = source.take(buffer![0, 3].into_array()).unwrap();
        assert!(result.dtype().is_nullable());

        let strings = result.to_varbinview().with_iterator(|values| {
            values
                .map(|value| {
                    value.map(|bytes| unsafe { String::from_utf8_unchecked(bytes.to_vec()) })
                })
                .collect::<Vec<_>>()
        });
        assert_eq!(strings, [Some("one".to_string()), Some("four".to_string())]);
    }

    /// Null entries in the indices produce null outputs, and the index value stored at a
    /// null position is not dereferenced.
    #[test]
    fn take_nullable_indices() {
        let source = VarBinViewArray::from_iter(["one", "two"].map(Some), DType::Utf8(NonNullable));

        // Verify that garbage values at NULL indices are ignored.
        let index_values = buffer![1u64, 999];
        let index_validity = Validity::from(BitBuffer::from(vec![true, false]));
        let indices = PrimitiveArray::new(index_values, index_validity);

        let result = source.take(indices.to_array()).unwrap();
        assert!(result.dtype().is_nullable());

        let strings = result.to_varbinview().with_iterator(|values| {
            values
                .map(|value| {
                    value.map(|bytes| unsafe { String::from_utf8_unchecked(bytes.to_vec()) })
                })
                .collect::<Vec<_>>()
        });
        assert_eq!(strings, [Some("two".to_string()), None]);
    }

    /// Runs the shared take conformance suite over several array shapes: non-nullable
    /// UTF-8, nullable UTF-8, non-nullable binary, and a single-element array.
    #[rstest]
    #[case(VarBinViewArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter_nullable_str([
        Some("hello"),
        None,
        Some("test"),
        Some("data"),
        None,
    ]))]
    #[case(VarBinViewArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter(["single"].map(Some), DType::Utf8(NonNullable)))]
    fn test_take_varbinview_conformance(#[case] array: VarBinViewArray) {
        let under_test = array.to_array();
        test_take_conformance(&under_test);
    }
}