Skip to main content

vortex_array/arrays/varbinview/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::iter;
5
6use num_traits::AsPrimitive;
7use vortex_buffer::Buffer;
8use vortex_dtype::match_each_integer_ptype;
9use vortex_error::VortexResult;
10use vortex_mask::AllOr;
11use vortex_mask::Mask;
12
13use crate::Array;
14use crate::ArrayRef;
15use crate::IntoArray;
16use crate::ToCanonical;
17use crate::arrays::BinaryView;
18use crate::arrays::TakeExecute;
19use crate::arrays::VarBinViewArray;
20use crate::arrays::VarBinViewVTable;
21use crate::buffer::BufferHandle;
22use crate::executor::ExecutionCtx;
23use crate::vtable::ValidityHelper;
24
25impl TakeExecute for VarBinViewVTable {
26    /// Take involves creating a new array that references the old array, just with the given set of views.
27    fn take(
28        array: &VarBinViewArray,
29        indices: &dyn Array,
30        _ctx: &mut ExecutionCtx,
31    ) -> VortexResult<Option<ArrayRef>> {
32        let validity = array.validity().take(indices)?;
33        let indices = indices.to_primitive();
34
35        let indices_mask = indices.validity_mask()?;
36        let views_buffer = match_each_integer_ptype!(indices.ptype(), |I| {
37            take_views(array.views(), indices.as_slice::<I>(), &indices_mask)
38        });
39
40        // SAFETY: taking all components at same indices maintains invariants
41        unsafe {
42            Ok(Some(
43                VarBinViewArray::new_handle_unchecked(
44                    BufferHandle::new_host(views_buffer.into_byte_buffer()),
45                    array.buffers().clone(),
46                    array
47                        .dtype()
48                        .union_nullability(indices.dtype().nullability()),
49                    validity,
50                )
51                .into_array(),
52            ))
53        }
54    }
55}
56
57fn take_views<I: AsPrimitive<usize>>(
58    views_ref: &[BinaryView],
59    indices: &[I],
60    mask: &Mask,
61) -> Buffer<BinaryView> {
62    // NOTE(ngates): this deref is not actually trivial, so we run it once.
63    // We do not use iter_bools directly, since the resulting dyn iterator cannot
64    // implement TrustedLen.
65    match mask.bit_buffer() {
66        AllOr::All => {
67            Buffer::<BinaryView>::from_trusted_len_iter(indices.iter().map(|i| views_ref[i.as_()]))
68        }
69        AllOr::None => Buffer::<BinaryView>::from_trusted_len_iter(iter::repeat_n(
70            BinaryView::default(),
71            indices.len(),
72        )),
73        AllOr::Some(buffer) => Buffer::<BinaryView>::from_trusted_len_iter(
74            buffer.iter().zip(indices.iter()).map(|(valid, idx)| {
75                if valid {
76                    views_ref[idx.as_()]
77                } else {
78                    BinaryView::default()
79                }
80            }),
81        ),
82    }
83}
84
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::NonNullable;

    use crate::IntoArray;
    use crate::accessor::ArrayAccessor;
    use crate::array::Array;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinViewArray;
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::validity::Validity;

    /// Collects every element of `array` as an owned `String`, keeping nulls as `None`.
    fn utf8_values(array: &VarBinViewArray) -> Vec<Option<String>> {
        array.with_iterator(|values| {
            values
                .map(|value| {
                    // SAFETY: the tests below only pass arrays built from `&str` literals.
                    value.map(|bytes| unsafe { String::from_utf8_unchecked(bytes.to_vec()) })
                })
                .collect::<Vec<_>>()
        })
    }

    /// Taking valid positions out of a nullable array keeps the result nullable.
    #[test]
    fn take_nullable() {
        let source = VarBinViewArray::from_iter_nullable_str([
            Some("one"),
            None,
            Some("three"),
            Some("four"),
            None,
            Some("six"),
        ]);

        let taken = source.take(buffer![0, 3].into_array()).unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(
            utf8_values(&taken.to_varbinview()),
            [Some("one".to_string()), Some("four".to_string())]
        );
    }

    /// Null indices produce null outputs, even when the source array is non-nullable.
    #[test]
    fn take_nullable_indices() {
        let source =
            VarBinViewArray::from_iter(["one", "two"].map(Some), DType::Utf8(NonNullable));

        // Verify that garbage values at NULL indices are ignored.
        let index_values = buffer![1u64, 999];
        let index_validity = Validity::from(BitBuffer::from(vec![true, false]));
        let indices = PrimitiveArray::new(index_values, index_validity);

        let taken = source.take(indices.to_array()).unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(
            utf8_values(&taken.to_varbinview()),
            [Some("two".to_string()), None]
        );
    }

    /// Runs the shared take conformance suite over several array shapes.
    #[rstest]
    #[case(VarBinViewArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter_nullable_str([
        Some("hello"),
        None,
        Some("test"),
        Some("data"),
        None,
    ]))]
    #[case(VarBinViewArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter(["single"].map(Some), DType::Utf8(NonNullable)))]
    fn test_take_varbinview_conformance(#[case] array: VarBinViewArray) {
        test_take_conformance(array.as_ref());
    }
}
165}