Skip to main content

vortex_array/arrays/varbinview/compute/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::iter;
5use std::sync::Arc;
6
7use num_traits::AsPrimitive;
8use vortex_buffer::Buffer;
9use vortex_error::VortexResult;
10use vortex_mask::AllOr;
11use vortex_mask::Mask;
12
13use crate::ArrayRef;
14use crate::IntoArray;
15use crate::array::ArrayView;
16use crate::arrays::PrimitiveArray;
17use crate::arrays::VarBinView;
18use crate::arrays::VarBinViewArray;
19use crate::arrays::dict::TakeExecute;
20use crate::arrays::varbinview::BinaryView;
21use crate::buffer::BufferHandle;
22use crate::executor::ExecutionCtx;
23use crate::match_each_integer_ptype;
24
25impl TakeExecute for VarBinView {
26    /// Take involves creating a new array that references the old array, just with the given set of views.
27    fn take(
28        array: ArrayView<'_, VarBinView>,
29        indices: &ArrayRef,
30        ctx: &mut ExecutionCtx,
31    ) -> VortexResult<Option<ArrayRef>> {
32        let validity = array.validity()?.take(indices)?;
33        let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
34
35        let indices_mask = indices
36            .as_ref()
37            .validity()?
38            .execute_mask(indices.as_ref().len(), ctx)?;
39        let views_buffer = match_each_integer_ptype!(indices.ptype(), |I| {
40            take_views(array.views(), indices.as_slice::<I>(), &indices_mask)
41        });
42
43        // SAFETY: taking all components at same indices maintains invariants
44        unsafe {
45            Ok(Some(
46                VarBinViewArray::new_handle_unchecked(
47                    BufferHandle::new_host(views_buffer.into_byte_buffer()),
48                    Arc::clone(array.data_buffers()),
49                    array
50                        .dtype()
51                        .union_nullability(indices.dtype().nullability()),
52                    validity,
53                )
54                .into_array(),
55            ))
56        }
57    }
58}
59
60fn take_views<I: AsPrimitive<usize>>(
61    views_ref: &[BinaryView],
62    indices: &[I],
63    mask: &Mask,
64) -> Buffer<BinaryView> {
65    // NOTE(ngates): this deref is not actually trivial, so we run it once.
66    // We do not use iter_bools directly, since the resulting dyn iterator cannot
67    // implement TrustedLen.
68    match mask.bit_buffer() {
69        AllOr::All => {
70            Buffer::<BinaryView>::from_trusted_len_iter(indices.iter().map(|i| views_ref[i.as_()]))
71        }
72        AllOr::None => Buffer::<BinaryView>::from_trusted_len_iter(iter::repeat_n(
73            BinaryView::default(),
74            indices.len(),
75        )),
76        AllOr::Some(buffer) => Buffer::<BinaryView>::from_trusted_len_iter(
77            buffer.iter().zip(indices.iter()).map(|(valid, idx)| {
78                if valid {
79                    views_ref[idx.as_()]
80                } else {
81                    BinaryView::default()
82                }
83            }),
84        ),
85    }
86}
87
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinViewArray;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability::NonNullable;
    use crate::validity::Validity;

    /// Decodes every position of `array` as a string, yielding `None` wherever `mask`
    /// is unset.
    ///
    /// Lossy decoding is used instead of `from_utf8_unchecked`. If a take ever produced
    /// invalid UTF-8, the U+FFFD replacement characters make the comparison fail,
    /// rather than causing undefined behaviour inside the test.
    fn to_nullable_strings(array: &VarBinViewArray, mask: &Mask) -> Vec<Option<String>> {
        (0..array.len())
            .map(|i| {
                mask.value(i).then(|| {
                    let bytes = array.bytes_at(i).to_vec();
                    String::from_utf8_lossy(&bytes).into_owned()
                })
            })
            .collect()
    }

    #[test]
    fn take_nullable() -> VortexResult<()> {
        let arr = VarBinViewArray::from_iter_nullable_str([
            Some("one"),
            None,
            Some("three"),
            Some("four"),
            None,
            Some("six"),
        ]);

        let taken = arr.take(buffer![0, 3].into_array())?;

        assert!(taken.dtype().is_nullable());
        let mut ctx = array_session().create_execution_ctx();
        let taken = taken.execute::<VarBinViewArray>(&mut ctx)?;
        let mask = taken.validity()?.execute_mask(taken.len(), &mut ctx)?;
        assert_eq!(
            to_nullable_strings(&taken, &mask),
            [Some("one".to_string()), Some("four".to_string())]
        );
        Ok(())
    }

    #[test]
    fn take_nullable_indices() -> VortexResult<()> {
        let arr = VarBinViewArray::from_iter(["one", "two"].map(Some), DType::Utf8(NonNullable));

        let indices = PrimitiveArray::new(
            // Verify that garbage values at NULL indices are ignored.
            buffer![1u64, 999],
            Validity::from(BitBuffer::from(vec![true, false])),
        );

        let taken = arr.take(indices.into_array())?;

        assert!(taken.dtype().is_nullable());
        let mut ctx = array_session().create_execution_ctx();
        let taken = taken.execute::<VarBinViewArray>(&mut ctx)?;
        let mask = taken.validity()?.execute_mask(taken.len(), &mut ctx)?;
        assert_eq!(
            to_nullable_strings(&taken, &mask),
            [Some("two".to_string()), None]
        );
        Ok(())
    }

    #[test]
    fn take_all_null_indices() -> VortexResult<()> {
        let arr = VarBinViewArray::from_iter(["one", "two"].map(Some), DType::Utf8(NonNullable));

        // Every index is null, so every output position must be null.
        let indices = PrimitiveArray::new(buffer![0u32, 1], Validity::AllInvalid);

        let taken = arr.take(indices.into_array())?;

        assert!(taken.dtype().is_nullable());
        let mut ctx = array_session().create_execution_ctx();
        let taken = taken.execute::<VarBinViewArray>(&mut ctx)?;
        let mask = taken.validity()?.execute_mask(taken.len(), &mut ctx)?;
        assert_eq!(to_nullable_strings(&taken, &mask), [None::<String>, None]);
        Ok(())
    }

    #[rstest]
    #[case(VarBinViewArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter_nullable_str([
        Some("hello"),
        None,
        Some("test"),
        Some("data"),
        None,
    ]))]
    #[case(VarBinViewArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter(["single"].map(Some), DType::Utf8(NonNullable)))]
    fn test_take_varbinview_conformance(#[case] array: VarBinViewArray) {
        test_take_conformance(&array.into_array());
    }
}