vortex_array/arrays/varbinview/compute/
take.rs1use std::ops::Deref;
5
6use num_traits::AsPrimitive;
7use vortex_buffer::Buffer;
8use vortex_dtype::match_each_integer_ptype;
9use vortex_error::VortexResult;
10
11use crate::arrays::{BinaryView, VarBinViewArray, VarBinViewVTable};
12use crate::compute::{TakeKernel, TakeKernelAdapter};
13use crate::vtable::ValidityHelper;
14use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
15
16impl TakeKernel for VarBinViewVTable {
18 fn take(&self, array: &VarBinViewArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
19 let validity = array.validity().take(indices)?;
24 let indices = indices.to_primitive()?;
25
26 let views_buffer = match_each_integer_ptype!(indices.ptype(), |I| {
27 take_views(array.views(), indices.as_slice::<I>())
29 });
30
31 unsafe {
33 Ok(VarBinViewArray::new_unchecked(
34 views_buffer,
35 array.buffers().clone(),
36 array
37 .dtype()
38 .union_nullability(indices.dtype().nullability()),
39 validity,
40 )
41 .into_array())
42 }
43 }
44}
45
46register_kernel!(TakeKernelAdapter(VarBinViewVTable).lift());
47
48fn take_views<I: AsPrimitive<usize>>(
49 views: &Buffer<BinaryView>,
50 indices: &[I],
51) -> Buffer<BinaryView> {
52 let views_ref = views.deref();
54 Buffer::<BinaryView>::from_iter(indices.iter().map(|i| views_ref[i.as_()]))
55}
56
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability::NonNullable;

    use crate::IntoArray;
    use crate::accessor::ArrayAccessor;
    use crate::array::Array;
    use crate::arrays::{PrimitiveArray, VarBinViewArray};
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;

    /// Collect every element of `array` as an owned `Option<String>`.
    ///
    /// Uses checked UTF-8 decoding, so a malformed value fails the test instead of causing UB.
    fn collect_strings(array: &VarBinViewArray) -> Vec<Option<String>> {
        array
            .with_iterator(|it| {
                it.map(|v| v.map(|b| String::from_utf8(b.to_vec()).unwrap()))
                    .collect::<Vec<_>>()
            })
            .unwrap()
    }

    #[test]
    fn take_nullable() {
        let arr = VarBinViewArray::from_iter_nullable_str([
            Some("one"),
            None,
            Some("three"),
            Some("four"),
            None,
            Some("six"),
        ]);

        let taken = take(arr.as_ref(), &buffer![0, 3].into_array()).unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(
            collect_strings(&taken.to_varbinview().unwrap()),
            [Some("one".to_string()), Some("four".to_string())]
        );
    }

    #[test]
    fn take_nullable_indices() {
        let arr = VarBinViewArray::from_iter(["one", "two"].map(Some), DType::Utf8(NonNullable));

        let taken = take(
            arr.as_ref(),
            PrimitiveArray::from_option_iter(vec![Some(1), None]).as_ref(),
        )
        .unwrap();

        assert!(taken.dtype().is_nullable());
        assert_eq!(
            collect_strings(&taken.to_varbinview().unwrap()),
            [Some("two".to_string()), None]
        );
    }

    #[rstest]
    #[case(VarBinViewArray::from_iter(
        ["hello", "world", "test", "data", "array"].map(Some),
        DType::Utf8(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter_nullable_str([
        Some("hello"),
        None,
        Some("test"),
        Some("data"),
        None,
    ]))]
    #[case(VarBinViewArray::from_iter(
        [b"hello".as_slice(), b"world", b"test", b"data", b"array"].map(Some),
        DType::Binary(NonNullable),
    ))]
    #[case(VarBinViewArray::from_iter(["single"].map(Some), DType::Utf8(NonNullable)))]
    fn test_take_varbinview_conformance(#[case] array: VarBinViewArray) {
        test_take_conformance(array.as_ref());
    }
}