vortex_array/arrays/varbinview/compute/
take.rs1use std::ops::Deref;
5
6use num_traits::AsPrimitive;
7use vortex_buffer::Buffer;
8use vortex_dtype::match_each_integer_ptype;
9use vortex_error::VortexResult;
10
11use crate::arrays::binary_view::BinaryView;
12use crate::arrays::{VarBinViewArray, VarBinViewVTable};
13use crate::compute::{TakeKernel, TakeKernelAdapter};
14use crate::vtable::ValidityHelper;
15use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
16
17impl TakeKernel for VarBinViewVTable {
19 fn take(&self, array: &VarBinViewArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
20 let validity = array.validity().take(indices)?;
25 let indices = indices.to_primitive();
26
27 let views_buffer = match_each_integer_ptype!(indices.ptype(), |I| {
28 take_views(array.views(), indices.as_slice::<I>())
30 });
31
32 unsafe {
34 Ok(VarBinViewArray::new_unchecked(
35 views_buffer,
36 array.buffers().clone(),
37 array
38 .dtype()
39 .union_nullability(indices.dtype().nullability()),
40 validity,
41 )
42 .into_array())
43 }
44 }
45}
46
47register_kernel!(TakeKernelAdapter(VarBinViewVTable).lift());
48
49fn take_views<I: AsPrimitive<usize>>(
50 views: &Buffer<BinaryView>,
51 indices: &[I],
52) -> Buffer<BinaryView> {
53 let views_ref = views.deref();
55 Buffer::<BinaryView>::from_trusted_len_iter(indices.iter().map(|i| views_ref[i.as_()]))
56}
57
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::NonNullable;

    use crate::IntoArray;
    use crate::accessor::ArrayAccessor;
    use crate::array::Array;
    use crate::arrays::{PrimitiveArray, VarBinViewArray};
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;

    /// Collects every element of `array` as an optional owned `String` (`None` for nulls).
    fn utf8_values(array: &VarBinViewArray) -> Vec<Option<String>> {
        array
            .with_iterator(|values| {
                values
                    .map(|value| {
                        // SAFETY: this helper is only called on arrays built from `&str`
                        // literals in this module, so the bytes are valid UTF-8.
                        value.map(|bytes| unsafe { String::from_utf8_unchecked(bytes.to_vec()) })
                    })
                    .collect::<Vec<_>>()
            })
            .unwrap()
    }

    /// Taking valid positions out of a nullable array keeps the result nullable.
    #[test]
    fn take_nullable() {
        let arr = VarBinViewArray::from_iter_nullable_str([
            Some("one"),
            None,
            Some("three"),
            Some("four"),
            None,
            Some("six"),
        ]);
        let positions = buffer![0, 3].into_array();

        let taken = take(arr.as_ref(), &positions).unwrap();

        assert!(taken.dtype().is_nullable());
        let actual = utf8_values(&taken.to_varbinview());
        assert_eq!(actual, [Some("one".to_string()), Some("four".to_string())]);
    }

    /// A null index yields a null element, and the result becomes nullable even though the
    /// source array is non-nullable.
    #[test]
    fn take_nullable_indices() {
        let arr = VarBinViewArray::from_iter(["one", "two"].map(Some), DType::Utf8(NonNullable));
        let positions = PrimitiveArray::from_option_iter(vec![Some(1), None]);

        let taken = take(arr.as_ref(), positions.as_ref()).unwrap();

        assert!(taken.dtype().is_nullable());
        let actual = utf8_values(&taken.to_varbinview());
        assert_eq!(actual, [Some("two".to_string()), None]);
    }

    /// Runs the shared take conformance suite over several `VarBinViewArray` inputs.
    #[rstest]
    #[case(VarBinViewArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter_nullable_str([
        Some("hello"),
        None,
        Some("test"),
        Some("data"),
        None,
    ]))]
    #[case(VarBinViewArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter(["single"].map(Some), DType::Utf8(NonNullable)))]
    fn test_take_varbinview_conformance(#[case] array: VarBinViewArray) {
        test_take_conformance(array.as_ref());
    }
}