vortex_sequence/compute/
take.rs1use num_traits::cast::NumCast;
5use vortex_array::ArrayRef;
6use vortex_array::DynArray;
7use vortex_array::ExecutionCtx;
8use vortex_array::IntoArray;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::arrays::PrimitiveArray;
11use vortex_array::arrays::dict::TakeExecute;
12use vortex_array::dtype::DType;
13use vortex_array::dtype::IntegerPType;
14use vortex_array::dtype::NativePType;
15use vortex_array::dtype::Nullability;
16use vortex_array::match_each_integer_ptype;
17use vortex_array::match_each_native_ptype;
18use vortex_array::scalar::Scalar;
19use vortex_array::validity::Validity;
20use vortex_buffer::Buffer;
21use vortex_error::VortexExpect;
22use vortex_error::VortexResult;
23use vortex_error::vortex_panic;
24use vortex_mask::AllOr;
25use vortex_mask::Mask;
26
27use crate::Sequence;
28use crate::SequenceArray;
29
30fn take_inner<T: IntegerPType, S: NativePType>(
31 mul: S,
32 base: S,
33 indices: &[T],
34 indices_mask: Mask,
35 result_nullability: Nullability,
36 len: usize,
37) -> ArrayRef {
38 match indices_mask.bit_buffer() {
39 AllOr::All => PrimitiveArray::new(
40 Buffer::from_trusted_len_iter(indices.iter().map(|i| {
41 if i.as_() >= len {
42 vortex_panic!(OutOfBounds: i.as_(), 0, len);
43 }
44 let i = <S as NumCast>::from::<T>(*i).vortex_expect("all indices fit");
45 base + i * mul
46 })),
47 Validity::from(result_nullability),
48 )
49 .into_array(),
50 AllOr::None => ConstantArray::new(
51 Scalar::null(DType::Primitive(S::PTYPE, Nullability::Nullable)),
52 indices.len(),
53 )
54 .into_array(),
55 AllOr::Some(b) => {
56 let buffer =
57 Buffer::from_trusted_len_iter(indices.iter().enumerate().map(|(mask_index, i)| {
58 if b.value(mask_index) {
59 if i.as_() >= len {
60 vortex_panic!(OutOfBounds: i.as_(), 0, len);
61 }
62
63 let i =
64 <S as NumCast>::from::<T>(*i).vortex_expect("all valid indices fit");
65 base + i * mul
66 } else {
67 S::zero()
68 }
69 }));
70 PrimitiveArray::new(buffer, Validity::from(b.clone())).into_array()
71 }
72 }
73}
74
75impl TakeExecute for Sequence {
76 fn take(
77 array: &SequenceArray,
78 indices: &ArrayRef,
79 ctx: &mut ExecutionCtx,
80 ) -> VortexResult<Option<ArrayRef>> {
81 let mask = indices.validity_mask()?;
82 let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
83 let result_nullability = array.dtype().nullability() | indices.dtype().nullability();
84
85 match_each_integer_ptype!(indices.ptype(), |T| {
86 let indices = indices.as_slice::<T>();
87 match_each_native_ptype!(array.ptype(), |S| {
88 let mul = array.multiplier().cast::<S>()?;
89 let base = array.base().cast::<S>()?;
90 Ok(Some(take_inner(
91 mul,
92 base,
93 indices,
94 mask,
95 result_nullability,
96 array.len(),
97 )))
98 })
99 })
100 }
101}
102
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::Canonical;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::take::test_take_conformance as check_take_conformance;
    use vortex_array::dtype::Nullability;

    use crate::SequenceArray;

    /// Runs the shared `take` conformance suite over a range of sequences. The cases cover
    /// several integer widths, signed and unsigned value types, positive, negative and zero
    /// multipliers, and both nullabilities.
    #[rstest]
    #[case::basic_sequence(SequenceArray::try_new_typed(
        0i32,
        1i32,
        Nullability::NonNullable,
        10
    ).unwrap())]
    #[case::sequence_with_multiplier(SequenceArray::try_new_typed(
        10i32,
        5i32,
        Nullability::Nullable,
        20
    ).unwrap())]
    #[case::sequence_i64(SequenceArray::try_new_typed(
        100i64,
        10i64,
        Nullability::NonNullable,
        50
    ).unwrap())]
    #[case::sequence_u32(SequenceArray::try_new_typed(
        0u32,
        2u32,
        Nullability::NonNullable,
        100
    ).unwrap())]
    #[case::sequence_negative_step(SequenceArray::try_new_typed(
        1000i32,
        -10i32,
        Nullability::Nullable,
        30
    ).unwrap())]
    #[case::sequence_constant(SequenceArray::try_new_typed(
        42i32,
        0i32,
        Nullability::Nullable,
        15
    ).unwrap())]
    #[case::sequence_i16(SequenceArray::try_new_typed(
        -100i16,
        3i16,
        Nullability::NonNullable,
        25
    ).unwrap())]
    #[case::sequence_large(SequenceArray::try_new_typed(
        0i64,
        1i64,
        Nullability::Nullable,
        1000
    ).unwrap())]
    fn test_take_conformance(#[case] sequence: SequenceArray) {
        let array = sequence.into_array();
        check_take_conformance(&array);
    }

    /// Taking an index past the end of the sequence must panic with an out-of-bounds error
    /// once the result is executed.
    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_bounds_check() {
        let sequence =
            SequenceArray::try_new_typed(0i32, 1i32, Nullability::NonNullable, 10).unwrap();
        // 20 is past the end of a length-10 sequence.
        let out_of_range = PrimitiveArray::from_iter([0i32, 20]).into_array();

        let taken = sequence.take(out_of_range).unwrap();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let _canonical = taken.execute::<Canonical>(&mut ctx).unwrap();
    }
}