vortex_array/arrays/varbinview/compute/
take.rs1use std::iter;
5
6use num_traits::AsPrimitive;
7use vortex_buffer::Buffer;
8use vortex_error::VortexResult;
9use vortex_mask::AllOr;
10use vortex_mask::Mask;
11
12use crate::ArrayRef;
13use crate::IntoArray;
14use crate::arrays::BinaryView;
15use crate::arrays::PrimitiveArray;
16use crate::arrays::TakeExecute;
17use crate::arrays::VarBinViewArray;
18use crate::arrays::VarBinViewVTable;
19use crate::buffer::BufferHandle;
20use crate::executor::ExecutionCtx;
21use crate::match_each_integer_ptype;
22use crate::vtable::ValidityHelper;
23
24impl TakeExecute for VarBinViewVTable {
25 fn take(
27 array: &VarBinViewArray,
28 indices: &ArrayRef,
29 ctx: &mut ExecutionCtx,
30 ) -> VortexResult<Option<ArrayRef>> {
31 let validity = array.validity().take(indices)?;
32 let indices = indices.to_array().execute::<PrimitiveArray>(ctx)?;
33
34 let indices_mask = indices.validity_mask()?;
35 let views_buffer = match_each_integer_ptype!(indices.ptype(), |I| {
36 take_views(array.views(), indices.as_slice::<I>(), &indices_mask)
37 });
38
39 unsafe {
41 Ok(Some(
42 VarBinViewArray::new_handle_unchecked(
43 BufferHandle::new_host(views_buffer.into_byte_buffer()),
44 array.buffers().clone(),
45 array
46 .dtype()
47 .union_nullability(indices.dtype().nullability()),
48 validity,
49 )
50 .into_array(),
51 ))
52 }
53 }
54}
55
56fn take_views<I: AsPrimitive<usize>>(
57 views_ref: &[BinaryView],
58 indices: &[I],
59 mask: &Mask,
60) -> Buffer<BinaryView> {
61 match mask.bit_buffer() {
65 AllOr::All => {
66 Buffer::<BinaryView>::from_trusted_len_iter(indices.iter().map(|i| views_ref[i.as_()]))
67 }
68 AllOr::None => Buffer::<BinaryView>::from_trusted_len_iter(iter::repeat_n(
69 BinaryView::default(),
70 indices.len(),
71 )),
72 AllOr::Some(buffer) => Buffer::<BinaryView>::from_trusted_len_iter(
73 buffer.iter().zip(indices.iter()).map(|(valid, idx)| {
74 if valid {
75 views_ref[idx.as_()]
76 } else {
77 BinaryView::default()
78 }
79 }),
80 ),
81 }
82}
83
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::accessor::ArrayAccessor;
    use crate::array::Array;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinViewArray;
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability::NonNullable;
    use crate::validity::Validity;

    /// Canonicalizes `array` to a `VarBinViewArray` and decodes every element as UTF-8.
    ///
    /// Decoding is checked, so a regression that yields invalid bytes fails the test
    /// instead of invoking undefined behaviour.
    fn collect_strings(array: &crate::ArrayRef) -> Vec<Option<String>> {
        array.to_varbinview().with_iterator(|it| {
            it.map(|v| {
                v.map(|b| {
                    String::from_utf8(b.to_vec()).expect("taken value is not valid UTF-8")
                })
            })
            .collect::<Vec<_>>()
        })
    }

    #[test]
    fn take_nullable() {
        let arr = VarBinViewArray::from_iter_nullable_str([
            Some("one"),
            None,
            Some("three"),
            Some("four"),
            None,
            Some("six"),
        ]);

        let taken = arr.take(buffer![0, 3].into_array()).unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(
            collect_strings(&taken),
            [Some("one".to_string()), Some("four".to_string())]
        );
    }

    #[test]
    fn take_nullable_indices() {
        let arr = VarBinViewArray::from_iter(["one", "two"].map(Some), DType::Utf8(NonNullable));

        // The out-of-range index 999 sits in a null slot and must never be dereferenced.
        let indices = PrimitiveArray::new(
            buffer![1u64, 999],
            Validity::from(BitBuffer::from(vec![true, false])),
        );

        let taken = arr.take(indices.to_array()).unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(collect_strings(&taken), [Some("two".to_string()), None]);
    }

    #[test]
    fn take_all_null_indices() {
        let arr = VarBinViewArray::from_iter(["one", "two"].map(Some), DType::Utf8(NonNullable));

        // Every index slot is null, exercising the all-invalid mask path.
        let indices = PrimitiveArray::new(buffer![0u32, 1], Validity::AllInvalid);

        let taken = arr.take(indices.to_array()).unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(collect_strings(&taken), [None::<String>, None]);
    }

    #[rstest]
    #[case(VarBinViewArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter_nullable_str([
        Some("hello"),
        None,
        Some("test"),
        Some("data"),
        None,
    ]))]
    #[case(VarBinViewArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter(["single"].map(Some), DType::Utf8(NonNullable)))]
    fn test_take_varbinview_conformance(#[case] array: VarBinViewArray) {
        test_take_conformance(&array.to_array());
    }
}