vortex_sequence/compute/
take.rs1use num_traits::cast::NumCast;
5use vortex_array::ArrayRef;
6use vortex_array::ArrayView;
7use vortex_array::ExecutionCtx;
8use vortex_array::IntoArray;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::arrays::PrimitiveArray;
11use vortex_array::arrays::dict::TakeExecute;
12use vortex_array::dtype::DType;
13use vortex_array::dtype::IntegerPType;
14use vortex_array::dtype::NativePType;
15use vortex_array::dtype::Nullability;
16use vortex_array::match_each_integer_ptype;
17use vortex_array::match_each_native_ptype;
18use vortex_array::scalar::Scalar;
19use vortex_array::validity::Validity;
20use vortex_buffer::Buffer;
21use vortex_error::VortexExpect;
22use vortex_error::VortexResult;
23use vortex_error::vortex_panic;
24use vortex_mask::AllOr;
25use vortex_mask::Mask;
26
27use crate::Sequence;
28
29fn take_inner<T: IntegerPType, S: NativePType>(
30 mul: S,
31 base: S,
32 indices: &[T],
33 indices_mask: Mask,
34 result_nullability: Nullability,
35 len: usize,
36) -> ArrayRef {
37 match indices_mask.bit_buffer() {
38 AllOr::All => PrimitiveArray::new(
39 Buffer::from_trusted_len_iter(indices.iter().map(|i| {
40 if i.as_() >= len {
41 vortex_panic!(OutOfBounds: i.as_(), 0, len);
42 }
43 let i = <S as NumCast>::from::<T>(*i).vortex_expect("all indices fit");
44 base + i * mul
45 })),
46 Validity::from(result_nullability),
47 )
48 .into_array(),
49 AllOr::None => ConstantArray::new(
50 Scalar::null(DType::Primitive(S::PTYPE, Nullability::Nullable)),
51 indices.len(),
52 )
53 .into_array(),
54 AllOr::Some(b) => {
55 let buffer =
56 Buffer::from_trusted_len_iter(indices.iter().enumerate().map(|(mask_index, i)| {
57 if b.value(mask_index) {
58 if i.as_() >= len {
59 vortex_panic!(OutOfBounds: i.as_(), 0, len);
60 }
61
62 let i =
63 <S as NumCast>::from::<T>(*i).vortex_expect("all valid indices fit");
64 base + i * mul
65 } else {
66 S::zero()
67 }
68 }));
69 PrimitiveArray::new(buffer, Validity::from(b.clone())).into_array()
70 }
71 }
72}
73
74impl TakeExecute for Sequence {
75 fn take(
76 array: ArrayView<'_, Self>,
77 indices: &ArrayRef,
78 ctx: &mut ExecutionCtx,
79 ) -> VortexResult<Option<ArrayRef>> {
80 let mask = indices.validity_mask()?;
81 let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
82 let result_nullability = array.dtype().nullability() | indices.dtype().nullability();
83
84 match_each_integer_ptype!(indices.ptype(), |T| {
85 let indices = indices.as_slice::<T>();
86 match_each_native_ptype!(array.ptype(), |S| {
87 let mul = array.multiplier().cast::<S>()?;
88 let base = array.base().cast::<S>()?;
89 Ok(Some(take_inner(
90 mul,
91 base,
92 indices,
93 mask,
94 result_nullability,
95 array.len(),
96 )))
97 })
98 })
99 }
100}
101
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::Canonical;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::dtype::Nullability;

    use crate::Sequence;
    use crate::SequenceArray;

    /// Runs the shared `take` conformance suite over sequences that differ in element type,
    /// base, multiplier, nullability and length.
    #[rstest]
    #[case::basic_sequence(
        Sequence::try_new_typed(0i32, 1i32, Nullability::NonNullable, 10).unwrap()
    )]
    #[case::sequence_with_multiplier(
        Sequence::try_new_typed(10i32, 5i32, Nullability::Nullable, 20).unwrap()
    )]
    #[case::sequence_i64(
        Sequence::try_new_typed(100i64, 10i64, Nullability::NonNullable, 50).unwrap()
    )]
    #[case::sequence_u32(
        Sequence::try_new_typed(0u32, 2u32, Nullability::NonNullable, 100).unwrap()
    )]
    #[case::sequence_negative_step(
        Sequence::try_new_typed(1000i32, -10i32, Nullability::Nullable, 30).unwrap()
    )]
    #[case::sequence_constant(
        Sequence::try_new_typed(42i32, 0i32, Nullability::Nullable, 15).unwrap()
    )]
    #[case::sequence_i16(
        Sequence::try_new_typed(-100i16, 3i16, Nullability::NonNullable, 25).unwrap()
    )]
    #[case::sequence_large(
        Sequence::try_new_typed(0i64, 1i64, Nullability::Nullable, 1000).unwrap()
    )]
    fn test_take_conformance(#[case] sequence: SequenceArray) {
        // Called by full path because the helper has the same name as this test function.
        vortex_array::compute::conformance::take::test_take_conformance(&sequence.into_array());
    }

    /// Taking index 20 from a 10-element sequence must panic with an out-of-bounds message.
    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_bounds_check() {
        let sequence = Sequence::try_new_typed(0i32, 1i32, Nullability::NonNullable, 10).unwrap();
        // The second index lies past the end of the sequence.
        let out_of_range = PrimitiveArray::from_iter([0i32, 20]);

        let taken = sequence.take(out_of_range.into_array()).unwrap();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let _canonical = taken.execute::<Canonical>(&mut ctx).unwrap();
    }
}