vortex_array/arrays/varbinview/compute/
take.rs1use std::iter;
5use std::ops::Deref;
6
7use num_traits::AsPrimitive;
8use vortex_buffer::Buffer;
9use vortex_dtype::match_each_integer_ptype;
10use vortex_error::VortexResult;
11use vortex_mask::AllOr;
12use vortex_mask::Mask;
13use vortex_vector::binaryview::BinaryView;
14
15use crate::Array;
16use crate::ArrayRef;
17use crate::IntoArray;
18use crate::ToCanonical;
19use crate::arrays::VarBinViewArray;
20use crate::arrays::VarBinViewVTable;
21use crate::compute::TakeKernel;
22use crate::compute::TakeKernelAdapter;
23use crate::register_kernel;
24use crate::vtable::ValidityHelper;
25
26impl TakeKernel for VarBinViewVTable {
28 fn take(&self, array: &VarBinViewArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
29 let validity = array.validity().take(indices)?;
31 let indices = indices.to_primitive();
32
33 let views_buffer = match_each_integer_ptype!(indices.ptype(), |I| {
34 take_views(
35 array.views(),
36 indices.as_slice::<I>(),
37 &indices.validity_mask(),
38 )
39 });
40
41 unsafe {
43 Ok(VarBinViewArray::new_unchecked(
44 views_buffer,
45 array.buffers().clone(),
46 array
47 .dtype()
48 .union_nullability(indices.dtype().nullability()),
49 validity,
50 )
51 .into_array())
52 }
53 }
54}
55
// Registers the VarBinView take kernel, wrapped in `TakeKernelAdapter` and `lift`ed, via
// `register_kernel!`. The dispatch mechanism lives in the macro and is not visible here.
register_kernel!(TakeKernelAdapter(VarBinViewVTable).lift());
57
58fn take_views<I: AsPrimitive<usize>>(
59 views: &Buffer<BinaryView>,
60 indices: &[I],
61 mask: &Mask,
62) -> Buffer<BinaryView> {
63 let views_ref = views.deref();
65 match mask.bit_buffer() {
68 AllOr::All => {
69 Buffer::<BinaryView>::from_trusted_len_iter(indices.iter().map(|i| views_ref[i.as_()]))
70 }
71 AllOr::None => Buffer::<BinaryView>::from_trusted_len_iter(iter::repeat_n(
72 BinaryView::default(),
73 indices.len(),
74 )),
75 AllOr::Some(buffer) => Buffer::<BinaryView>::from_trusted_len_iter(
76 buffer.iter().zip(indices.iter()).map(|(valid, idx)| {
77 if valid {
78 views_ref[idx.as_()]
79 } else {
80 BinaryView::default()
81 }
82 }),
83 ),
84 }
85}
86
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::NonNullable;

    use crate::IntoArray;
    use crate::accessor::ArrayAccessor;
    use crate::array::Array;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::VarBinViewArray;
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;

    /// Decodes each element of `array` into an owned `String`, mapping null slots to `None`.
    ///
    /// Decoding is checked (`String::from_utf8`) instead of `from_utf8_unchecked`. If a take
    /// ever produced a view pointing at the wrong bytes, the test fails loudly instead of
    /// constructing an invalid `String`, which would be undefined behaviour.
    fn collect_strings(array: &VarBinViewArray) -> Vec<Option<String>> {
        array.with_iterator(|it| {
            it.map(|value| {
                value.map(|bytes| {
                    String::from_utf8(bytes.to_vec()).expect("taken bytes must be valid UTF-8")
                })
            })
            .collect::<Vec<_>>()
        })
    }

    #[test]
    fn take_nullable() {
        let arr = VarBinViewArray::from_iter_nullable_str([
            Some("one"),
            None,
            Some("three"),
            Some("four"),
            None,
            Some("six"),
        ]);

        let taken = take(arr.as_ref(), &buffer![0, 3].into_array()).unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(
            collect_strings(&taken.to_varbinview()),
            [Some("one".to_string()), Some("four".to_string())]
        );
    }

    #[test]
    fn take_nullable_null_values() {
        let arr = VarBinViewArray::from_iter_nullable_str([Some("one"), None, Some("three")]);

        // Null slots of the source array must stay null after the take, including repeats.
        let taken = take(arr.as_ref(), &buffer![1, 2, 1].into_array()).unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(
            collect_strings(&taken.to_varbinview()),
            [None, Some("three".to_string()), None]
        );
    }

    #[test]
    fn take_nullable_indices() {
        let arr = VarBinViewArray::from_iter(["one", "two"].map(Some), DType::Utf8(NonNullable));

        // Index 999 is out of bounds, but it sits behind a null so it must never be read.
        let indices = PrimitiveArray::new(
            buffer![1u64, 999],
            Validity::from(BitBuffer::from(vec![true, false])),
        );

        let taken = take(arr.as_ref(), indices.as_ref()).unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(
            collect_strings(&taken.to_varbinview()),
            [Some("two".to_string()), None]
        );
    }

    #[test]
    fn take_all_null_indices() {
        let arr = VarBinViewArray::from_iter(["one", "two"].map(Some), DType::Utf8(NonNullable));
        let indices = PrimitiveArray::new(buffer![0u32, 1], Validity::AllInvalid);

        let taken = take(arr.as_ref(), indices.as_ref()).unwrap();

        // A non-nullable array taken with all-null indices yields a nullable, all-null result.
        assert!(taken.dtype().is_nullable());
        assert_eq!(collect_strings(&taken.to_varbinview()), [None::<String>, None]);
    }

    #[rstest]
    #[case(VarBinViewArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter_nullable_str([
        Some("hello"),
        None,
        Some("test"),
        Some("data"),
        None,
    ]))]
    #[case(VarBinViewArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter(["single"].map(Some), DType::Utf8(NonNullable)))]
    fn test_take_varbinview_conformance(#[case] array: VarBinViewArray) {
        test_take_conformance(array.as_ref());
    }
}