// vortex_sequence/compute/take.rs
use num_traits::cast::NumCast;
use vortex_array::arrays::{ConstantArray, PrimitiveArray};
use vortex_array::compute::{TakeKernel, TakeKernelAdapter};
use vortex_array::validity::Validity;
use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
use vortex_buffer::Buffer;
use vortex_dtype::{
    DType, IntegerPType, NativePType, Nullability, match_each_integer_ptype,
    match_each_native_ptype,
};
use vortex_error::{VortexExpect, VortexResult, vortex_bail};
use vortex_mask::{AllOr, Mask};
use vortex_scalar::Scalar;

use crate::{SequenceArray, SequenceVTable};

20impl TakeKernel for SequenceVTable {
21 fn take(&self, array: &SequenceArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
22 let mask = indices.validity_mask();
23 let indices = indices.to_primitive();
24 let result_nullability = array.dtype().nullability() | indices.dtype().nullability();
25
26 Ok(match_each_integer_ptype!(indices.ptype(), |T| {
27 let indices = indices.as_slice::<T>();
28 match_each_native_ptype!(array.ptype(), |S| {
29 let mul = array.multiplier().as_primitive::<S>();
30 let base = array.base().as_primitive::<S>();
31 take(mul, base, indices, mask, result_nullability)
32 })
33 }))
34 }
35}
36
37fn take<T: IntegerPType, S: NativePType>(
38 mul: S,
39 base: S,
40 indices: &[T],
41 indices_mask: Mask,
42 result_nullability: Nullability,
43) -> ArrayRef {
44 match indices_mask.boolean_buffer() {
45 AllOr::All => PrimitiveArray::new(
46 Buffer::from_trusted_len_iter(indices.iter().map(|i| {
47 let i = <S as NumCast>::from::<T>(*i).vortex_expect("all indices fit");
48 base + i * mul
49 })),
50 Validity::from(result_nullability),
51 )
52 .into_array(),
53 AllOr::None => ConstantArray::new(
54 Scalar::null(DType::Primitive(S::PTYPE, Nullability::Nullable)),
55 indices.len(),
56 )
57 .into_array(),
58 AllOr::Some(b) => {
59 let buffer =
60 Buffer::from_trusted_len_iter(indices.iter().enumerate().map(|(mask_index, i)| {
61 if b.value(mask_index) {
62 let i =
63 <S as NumCast>::from::<T>(*i).vortex_expect("all valid indices fit");
64 base + i * mul
65 } else {
66 S::zero()
67 }
68 }));
69 PrimitiveArray::new(buffer, Validity::from(b.clone())).into_array()
70 }
71 }
72}
73
74register_kernel!(TakeKernelAdapter(SequenceVTable).lift());
75
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_dtype::Nullability;

    use crate::SequenceArray;

    /// Runs the shared take conformance suite over a spread of sequences.
    ///
    /// The cases cover signed and unsigned integers of several widths; positive,
    /// negative and zero multipliers; and both nullable and non-nullable dtypes.
    #[rstest]
    #[case::basic_sequence(
        SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 10).unwrap()
    )]
    #[case::sequence_with_multiplier(
        SequenceArray::typed_new(10i32, 5i32, Nullability::Nullable, 20).unwrap()
    )]
    #[case::sequence_i64(
        SequenceArray::typed_new(100i64, 10i64, Nullability::NonNullable, 50).unwrap()
    )]
    #[case::sequence_u32(
        SequenceArray::typed_new(0u32, 2u32, Nullability::NonNullable, 100).unwrap()
    )]
    #[case::sequence_negative_step(
        SequenceArray::typed_new(1000i32, -10i32, Nullability::Nullable, 30).unwrap()
    )]
    #[case::sequence_constant(
        SequenceArray::typed_new(42i32, 0i32, Nullability::Nullable, 15).unwrap()
    )]
    #[case::sequence_i16(
        SequenceArray::typed_new(-100i16, 3i16, Nullability::NonNullable, 25).unwrap()
    )]
    #[case::sequence_large(
        SequenceArray::typed_new(0i64, 1i64, Nullability::Nullable, 1000).unwrap()
    )]
    fn test_take_conformance(#[case] array: SequenceArray) {
        // Fully qualified to avoid clashing with this test's own name.
        vortex_array::compute::conformance::take::test_take_conformance(array.as_ref());
    }
}