vortex_array/arrays/varbinview/compute/
take.rs1use std::iter;
5use std::sync::Arc;
6
7use num_traits::AsPrimitive;
8use vortex_buffer::Buffer;
9use vortex_error::VortexResult;
10use vortex_mask::AllOr;
11use vortex_mask::Mask;
12
13use crate::ArrayRef;
14use crate::IntoArray;
15use crate::array::ArrayView;
16use crate::arrays::PrimitiveArray;
17use crate::arrays::VarBinView;
18use crate::arrays::VarBinViewArray;
19use crate::arrays::dict::TakeExecute;
20use crate::arrays::varbinview::BinaryView;
21use crate::buffer::BufferHandle;
22use crate::executor::ExecutionCtx;
23use crate::match_each_integer_ptype;
24
25impl TakeExecute for VarBinView {
26 fn take(
28 array: ArrayView<'_, VarBinView>,
29 indices: &ArrayRef,
30 ctx: &mut ExecutionCtx,
31 ) -> VortexResult<Option<ArrayRef>> {
32 let validity = array.validity()?.take(indices)?;
33 let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
34
35 let indices_mask = indices
36 .as_ref()
37 .validity()?
38 .to_mask(indices.as_ref().len(), ctx)?;
39 let views_buffer = match_each_integer_ptype!(indices.ptype(), |I| {
40 take_views(array.views(), indices.as_slice::<I>(), &indices_mask)
41 });
42
43 unsafe {
45 Ok(Some(
46 VarBinViewArray::new_handle_unchecked(
47 BufferHandle::new_host(views_buffer.into_byte_buffer()),
48 Arc::clone(array.data_buffers()),
49 array
50 .dtype()
51 .union_nullability(indices.dtype().nullability()),
52 validity,
53 )
54 .into_array(),
55 ))
56 }
57 }
58}
59
60fn take_views<I: AsPrimitive<usize>>(
61 views_ref: &[BinaryView],
62 indices: &[I],
63 mask: &Mask,
64) -> Buffer<BinaryView> {
65 match mask.bit_buffer() {
69 AllOr::All => {
70 Buffer::<BinaryView>::from_trusted_len_iter(indices.iter().map(|i| views_ref[i.as_()]))
71 }
72 AllOr::None => Buffer::<BinaryView>::from_trusted_len_iter(iter::repeat_n(
73 BinaryView::default(),
74 indices.len(),
75 )),
76 AllOr::Some(buffer) => Buffer::<BinaryView>::from_trusted_len_iter(
77 buffer.iter().zip(indices.iter()).map(|(valid, idx)| {
78 if valid {
79 views_ref[idx.as_()]
80 } else {
81 BinaryView::default()
82 }
83 }),
84 ),
85 }
86}
87
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::accessor::ArrayAccessor;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinViewArray;
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability::NonNullable;
    use crate::validity::Validity;

    /// Taking valid entries out of a nullable array keeps the result nullable.
    #[test]
    fn take_nullable() {
        let source = VarBinViewArray::from_iter_nullable_str([
            Some("one"),
            None,
            Some("three"),
            Some("four"),
            None,
            Some("six"),
        ]);

        let picked = source.take(buffer![0, 3].into_array()).unwrap();
        assert!(picked.dtype().is_nullable());

        let strings = picked.to_varbinview().with_iterator(|values| {
            values
                .map(|value| {
                    // SAFETY: the source array was built from `&str` values, so every
                    // non-null entry is valid UTF-8.
                    value.map(|bytes| unsafe { String::from_utf8_unchecked(bytes.to_vec()) })
                })
                .collect::<Vec<_>>()
        });
        assert_eq!(
            strings,
            [Some(String::from("one")), Some(String::from("four"))]
        );
    }

    /// A null index yields a null output element, even when the index value stored at that
    /// position (999) is far out of bounds, and makes the result dtype nullable.
    #[test]
    fn take_nullable_indices() {
        let source =
            VarBinViewArray::from_iter(["one", "two"].map(Some), DType::Utf8(NonNullable));

        let index_validity = Validity::from(BitBuffer::from(vec![true, false]));
        let nullable_indices = PrimitiveArray::new(buffer![1u64, 999], index_validity);

        let picked = source.take(nullable_indices.into_array()).unwrap();
        assert!(picked.dtype().is_nullable());

        let strings = picked.to_varbinview().with_iterator(|values| {
            values
                .map(|value| {
                    // SAFETY: the source array was built from `&str` values, so every
                    // non-null entry is valid UTF-8.
                    value.map(|bytes| unsafe { String::from_utf8_unchecked(bytes.to_vec()) })
                })
                .collect::<Vec<_>>()
        });
        assert_eq!(strings, [Some(String::from("two")), None]);
    }

    /// Runs the shared take conformance suite over non-nullable UTF-8, nullable UTF-8,
    /// binary, and single-element arrays.
    #[rstest]
    #[case(VarBinViewArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter_nullable_str([
        Some("hello"),
        None,
        Some("test"),
        Some("data"),
        None,
    ]))]
    #[case(VarBinViewArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter(["single"].map(Some), DType::Utf8(NonNullable)))]
    fn test_take_varbinview_conformance(#[case] array: VarBinViewArray) {
        let as_array = array.into_array();
        test_take_conformance(&as_array);
    }
}