vortex_sequence/compute/
take.rs1use num_traits::cast::NumCast;
5use vortex_array::Array;
6use vortex_array::ArrayRef;
7use vortex_array::ExecutionCtx;
8use vortex_array::IntoArray;
9use vortex_array::ToCanonical;
10use vortex_array::arrays::ConstantArray;
11use vortex_array::arrays::PrimitiveArray;
12use vortex_array::arrays::TakeExecute;
13use vortex_array::dtype::DType;
14use vortex_array::dtype::IntegerPType;
15use vortex_array::dtype::NativePType;
16use vortex_array::dtype::Nullability;
17use vortex_array::match_each_integer_ptype;
18use vortex_array::match_each_native_ptype;
19use vortex_array::scalar::Scalar;
20use vortex_array::validity::Validity;
21use vortex_buffer::Buffer;
22use vortex_error::VortexExpect;
23use vortex_error::VortexResult;
24use vortex_error::vortex_panic;
25use vortex_mask::AllOr;
26use vortex_mask::Mask;
27
28use crate::SequenceArray;
29use crate::SequenceVTable;
30
31fn take_inner<T: IntegerPType, S: NativePType>(
32 mul: S,
33 base: S,
34 indices: &[T],
35 indices_mask: Mask,
36 result_nullability: Nullability,
37 len: usize,
38) -> ArrayRef {
39 match indices_mask.bit_buffer() {
40 AllOr::All => PrimitiveArray::new(
41 Buffer::from_trusted_len_iter(indices.iter().map(|i| {
42 if i.as_() >= len {
43 vortex_panic!(OutOfBounds: i.as_(), 0, len);
44 }
45 let i = <S as NumCast>::from::<T>(*i).vortex_expect("all indices fit");
46 base + i * mul
47 })),
48 Validity::from(result_nullability),
49 )
50 .into_array(),
51 AllOr::None => ConstantArray::new(
52 Scalar::null(DType::Primitive(S::PTYPE, Nullability::Nullable)),
53 indices.len(),
54 )
55 .into_array(),
56 AllOr::Some(b) => {
57 let buffer =
58 Buffer::from_trusted_len_iter(indices.iter().enumerate().map(|(mask_index, i)| {
59 if b.value(mask_index) {
60 if i.as_() >= len {
61 vortex_panic!(OutOfBounds: i.as_(), 0, len);
62 }
63
64 let i =
65 <S as NumCast>::from::<T>(*i).vortex_expect("all valid indices fit");
66 base + i * mul
67 } else {
68 S::zero()
69 }
70 }));
71 PrimitiveArray::new(buffer, Validity::from(b.clone())).into_array()
72 }
73 }
74}
75
76impl TakeExecute for SequenceVTable {
77 fn take(
78 array: &SequenceArray,
79 indices: &ArrayRef,
80 _ctx: &mut ExecutionCtx,
81 ) -> VortexResult<Option<ArrayRef>> {
82 let mask = indices.validity_mask()?;
83 let indices = indices.to_primitive();
84 let result_nullability = array.dtype().nullability() | indices.dtype().nullability();
85
86 match_each_integer_ptype!(indices.ptype(), |T| {
87 let indices = indices.as_slice::<T>();
88 match_each_native_ptype!(array.ptype(), |S| {
89 let mul = array.multiplier().cast::<S>()?;
90 let base = array.base().cast::<S>()?;
91 Ok(Some(take_inner(
92 mul,
93 base,
94 indices,
95 mask,
96 result_nullability,
97 array.len(),
98 )))
99 })
100 })
101 }
102}
103
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::Canonical;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::conformance::take::test_take_conformance as check_take_conformance;
    use vortex_array::dtype::Nullability;

    use crate::SequenceArray;

    /// Runs the shared take conformance suite against sequences that vary in ptype,
    /// multiplier sign, nullability and length.
    #[rstest]
    #[case::basic_sequence(SequenceArray::try_new_typed(
        0i32,
        1i32,
        Nullability::NonNullable,
        10
    ).unwrap())]
    #[case::sequence_with_multiplier(SequenceArray::try_new_typed(
        10i32,
        5i32,
        Nullability::Nullable,
        20
    ).unwrap())]
    #[case::sequence_i64(SequenceArray::try_new_typed(
        100i64,
        10i64,
        Nullability::NonNullable,
        50
    ).unwrap())]
    #[case::sequence_u32(SequenceArray::try_new_typed(
        0u32,
        2u32,
        Nullability::NonNullable,
        100
    ).unwrap())]
    #[case::sequence_negative_step(SequenceArray::try_new_typed(
        1000i32,
        -10i32,
        Nullability::Nullable,
        30
    ).unwrap())]
    #[case::sequence_constant(SequenceArray::try_new_typed(
        42i32,
        0i32,
        Nullability::Nullable,
        15
    ).unwrap())]
    #[case::sequence_i16(SequenceArray::try_new_typed(
        -100i16,
        3i16,
        Nullability::NonNullable,
        25
    ).unwrap())]
    #[case::sequence_large(SequenceArray::try_new_typed(
        0i64,
        1i64,
        Nullability::Nullable,
        1000
    ).unwrap())]
    fn test_take_conformance(#[case] sequence: SequenceArray) {
        check_take_conformance(&sequence.to_array());
    }

    /// Executing a take with an index past the end of the sequence must panic.
    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_bounds_check() {
        let sequence =
            SequenceArray::try_new_typed(0i32, 1i32, Nullability::NonNullable, 10).unwrap();
        // Index 20 lies beyond the end of this length-10 sequence.
        let out_of_range = PrimitiveArray::from_iter([0i32, 20]);
        let taken = sequence.take(out_of_range.to_array()).unwrap();
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let _canonical = taken.execute::<Canonical>(&mut ctx).unwrap();
    }
}