1use std::hash::Hash;
5use std::ops::Range;
6
7use num_traits::cast::FromPrimitive;
8use vortex_array::arrays::PrimitiveArray;
9use vortex_array::stats::{ArrayStats, StatsSetRef};
10use vortex_array::vtable::{
11 ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityVTable,
12 VisitorVTable,
13};
14use vortex_array::{
15 ArrayBufferVisitor, ArrayChildVisitor, ArrayRef, Canonical, EncodingId, EncodingRef, Precision,
16 vtable,
17};
18use vortex_buffer::BufferMut;
19use vortex_dtype::{
20 DType, NativePType, Nullability, PType, match_each_integer_ptype, match_each_native_ptype,
21};
22use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
23use vortex_mask::Mask;
24use vortex_scalar::{PValue, Scalar, ScalarValue};
25
// NOTE(review): `SequenceVTable` is used throughout this file but not declared in it,
// so this macro presumably generates it (and related glue) — confirm against the
// `vtable!` definition in `vortex_array`.
vtable!(Sequence);
27
/// An array described by an arithmetic progression instead of stored values.
///
/// The element at index `i` is computed as `base + i * multiplier` in the array's
/// primitive type.
#[derive(Clone, Debug)]
pub struct SequenceArray {
    /// Value of the element at index 0.
    base: PValue,
    /// Difference between two consecutive elements.
    multiplier: PValue,
    /// Always built as `DType::Primitive(ptype, nullability)` by the constructors.
    dtype: DType,
    /// Number of elements in the sequence.
    pub(crate) length: usize,
    /// Statistics storage for this array; starts out as the default value.
    stats_set: ArrayStats,
}
37
38impl SequenceArray {
39 pub fn typed_new<T: NativePType + Into<PValue>>(
40 base: T,
41 multiplier: T,
42 nullability: Nullability,
43 length: usize,
44 ) -> VortexResult<Self> {
45 Self::new(
46 base.into(),
47 multiplier.into(),
48 T::PTYPE,
49 nullability,
50 length,
51 )
52 }
53
54 pub fn new(
56 base: PValue,
57 multiplier: PValue,
58 ptype: PType,
59 nullability: Nullability,
60 length: usize,
61 ) -> VortexResult<Self> {
62 if !ptype.is_int() {
63 vortex_bail!("only integer ptype are supported in SequenceArray currently")
64 }
65
66 Self::try_last(base, multiplier, ptype, length).map_err(|e| {
67 e.with_context(format!(
68 "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ",
69 ))
70 })?;
71
72 Ok(Self::unchecked_new(
73 base,
74 multiplier,
75 ptype,
76 nullability,
77 length,
78 ))
79 }
80
81 pub(crate) fn unchecked_new(
82 base: PValue,
83 multiplier: PValue,
84 ptype: PType,
85 nullability: Nullability,
86 length: usize,
87 ) -> Self {
88 let dtype = DType::Primitive(ptype, nullability);
89 Self {
90 base,
91 multiplier,
92 dtype,
93 length,
94 stats_set: Default::default(),
96 }
97 }
98
99 pub fn ptype(&self) -> PType {
100 self.dtype.as_ptype()
101 }
102
103 pub fn base(&self) -> PValue {
104 self.base
105 }
106
107 pub fn multiplier(&self) -> PValue {
108 self.multiplier
109 }
110
111 pub(crate) fn try_last(
112 base: PValue,
113 multiplier: PValue,
114 ptype: PType,
115 length: usize,
116 ) -> VortexResult<PValue> {
117 match_each_integer_ptype!(ptype, |P| {
118 let len_t = <P>::from_usize(length - 1)
119 .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
120
121 let base = base.cast::<P>();
122 let multiplier = multiplier.cast::<P>();
123
124 let last = len_t
125 .checked_mul(multiplier)
126 .and_then(|offset| offset.checked_add(base))
127 .ok_or_else(|| vortex_err!("last value computation overflows"))?;
128 Ok(PValue::from(last))
129 })
130 }
131
132 pub(crate) fn index_value(&self, idx: usize) -> PValue {
133 assert!(idx < self.length, "index_value({idx}): index out of bounds");
134
135 match_each_native_ptype!(self.ptype(), |P| {
136 let base = self.base.cast::<P>();
137 let multiplier = self.multiplier.cast::<P>();
138 let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
139
140 PValue::from(value)
141 })
142 }
143
144 pub fn last(&self) -> PValue {
146 Self::try_last(self.base, self.multiplier, self.ptype(), self.length)
147 .vortex_expect("validated array")
148 }
149}
150
151impl VTable for SequenceVTable {
152 type Array = SequenceArray;
153 type Encoding = SequenceEncoding;
154
155 type ArrayVTable = Self;
156 type CanonicalVTable = Self;
157 type OperationsVTable = Self;
158 type ValidityVTable = Self;
159 type VisitorVTable = Self;
160 type ComputeVTable = NotSupported;
161 type EncodeVTable = Self;
162 type SerdeVTable = Self;
163 type OperatorVTable = Self;
164
165 fn id(_encoding: &Self::Encoding) -> EncodingId {
166 EncodingId::new_ref("vortex.sequence")
167 }
168
169 fn encoding(_array: &Self::Array) -> EncodingRef {
170 EncodingRef::new_ref(SequenceEncoding.as_ref())
171 }
172}
173
174impl ArrayVTable<SequenceVTable> for SequenceVTable {
175 fn len(array: &SequenceArray) -> usize {
176 array.length
177 }
178
179 fn dtype(array: &SequenceArray) -> &DType {
180 &array.dtype
181 }
182
183 fn stats(array: &SequenceArray) -> StatsSetRef<'_> {
184 array.stats_set.to_ref(array.as_ref())
185 }
186
187 fn array_hash<H: std::hash::Hasher>(
188 array: &SequenceArray,
189 state: &mut H,
190 _precision: Precision,
191 ) {
192 array.base.hash(state);
193 array.multiplier.hash(state);
194 array.dtype.hash(state);
195 array.length.hash(state);
196 }
197
198 fn array_eq(array: &SequenceArray, other: &SequenceArray, _precision: Precision) -> bool {
199 array.base == other.base
200 && array.multiplier == other.multiplier
201 && array.dtype == other.dtype
202 && array.length == other.length
203 }
204}
205
206impl CanonicalVTable<SequenceVTable> for SequenceVTable {
207 fn canonicalize(array: &SequenceArray) -> Canonical {
208 let prim = match_each_native_ptype!(array.ptype(), |P| {
209 let base = array.base().cast::<P>();
210 let multiplier = array.multiplier().cast::<P>();
211 let values = BufferMut::from_iter(
212 (0..array.len())
213 .map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
214 );
215 PrimitiveArray::new(values, array.dtype.nullability().into())
216 });
217
218 Canonical::Primitive(prim)
219 }
220}
221
222impl OperationsVTable<SequenceVTable> for SequenceVTable {
223 fn slice(array: &SequenceArray, range: Range<usize>) -> ArrayRef {
224 SequenceArray::unchecked_new(
225 array.index_value(range.start),
226 array.multiplier,
227 array.ptype(),
228 array.dtype().nullability(),
229 range.len(),
230 )
231 .to_array()
232 }
233
234 fn scalar_at(array: &SequenceArray, index: usize) -> Scalar {
235 Scalar::new(
236 array.dtype().clone(),
237 ScalarValue::from(array.index_value(index)),
238 )
239 }
240}
241
242impl ValidityVTable<SequenceVTable> for SequenceVTable {
243 fn is_valid(_array: &SequenceArray, _index: usize) -> bool {
244 true
245 }
246
247 fn all_valid(_array: &SequenceArray) -> bool {
248 true
249 }
250
251 fn all_invalid(_array: &SequenceArray) -> bool {
252 false
253 }
254
255 fn validity_mask(array: &SequenceArray) -> Mask {
256 Mask::AllTrue(array.len())
257 }
258}
259
260impl VisitorVTable<SequenceVTable> for SequenceVTable {
261 fn visit_buffers(_array: &SequenceArray, _visitor: &mut dyn ArrayBufferVisitor) {
262 }
264
265 fn visit_children(_array: &SequenceArray, _visitor: &mut dyn ArrayChildVisitor) {}
266}
267
/// Marker type for the sequence encoding; it is the `Encoding` associated type of
/// `SequenceVTable` and carries no data.
#[derive(Clone, Debug)]
pub struct SequenceEncoding;
270
#[cfg(test)]
mod tests {
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_dtype::Nullability;
    use vortex_scalar::{Scalar, ScalarValue};

    use crate::array::SequenceArray;

    /// Canonicalizing yields `base + i * multiplier` for every index.
    #[test]
    fn test_sequence_canonical() {
        let arr = SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4).unwrap();

        let canon = PrimitiveArray::from_iter((0..4).map(|i| 2i64 + i * 3));

        assert_eq!(
            arr.to_primitive().as_slice::<i64>(),
            canon.as_slice::<i64>()
        )
    }

    /// A slice canonicalizes to the matching window of the original sequence.
    #[test]
    fn test_sequence_slice_canonical() {
        let arr = SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4)
            .unwrap()
            .slice(2..3);

        let canon = PrimitiveArray::from_iter((2..3).map(|i| 2i64 + i * 3));

        assert_eq!(
            arr.to_primitive().as_slice::<i64>(),
            canon.as_slice::<i64>()
        )
    }

    /// `scalar_at(2)` on `2, 5, 8, 11` is `8`.
    #[test]
    fn test_sequence_scalar_at() {
        let scalar = SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4)
            .unwrap()
            .scalar_at(2);

        assert_eq!(
            scalar,
            Scalar::new(scalar.dtype().clone(), ScalarValue::from(8i64))
        )
    }

    /// Sequences whose last value sits on or inside the type's bounds are accepted.
    #[test]
    fn test_sequence_min_max() {
        // Descending sequence ending exactly at `i8::MIN`.
        assert!(SequenceArray::typed_new(-127i8, -1i8, Nullability::NonNullable, 2).is_ok());
        // Descending sequence that stays well inside the range.
        assert!(SequenceArray::typed_new(126i8, -1i8, Nullability::NonNullable, 2).is_ok());
        // Ascending sequence ending exactly at `i8::MAX`; without this case the upper
        // boundary is never exercised.
        assert!(SequenceArray::typed_new(126i8, 1i8, Nullability::NonNullable, 2).is_ok());
    }

    /// Sequences whose last value falls outside the type's bounds are rejected.
    #[test]
    fn test_sequence_too_big() {
        assert!(SequenceArray::typed_new(127i8, 1i8, Nullability::NonNullable, 2).is_err());
        assert!(SequenceArray::typed_new(-128i8, -1i8, Nullability::NonNullable, 2).is_err());
    }
}