1use std::hash::Hash;
5use std::ops::Range;
6
7use num_traits::One;
8use num_traits::cast::FromPrimitive;
9use vortex_array::ArrayBufferVisitor;
10use vortex_array::ArrayChildVisitor;
11use vortex_array::ArrayRef;
12use vortex_array::Canonical;
13use vortex_array::DeserializeMetadata;
14use vortex_array::Precision;
15use vortex_array::ProstMetadata;
16use vortex_array::SerializeMetadata;
17use vortex_array::arrays::PrimitiveArray;
18use vortex_array::execution::ExecutionCtx;
19use vortex_array::serde::ArrayChildren;
20use vortex_array::stats::ArrayStats;
21use vortex_array::stats::StatsSetRef;
22use vortex_array::vtable;
23use vortex_array::vtable::ArrayId;
24use vortex_array::vtable::ArrayVTable;
25use vortex_array::vtable::ArrayVTableExt;
26use vortex_array::vtable::BaseArrayVTable;
27use vortex_array::vtable::CanonicalVTable;
28use vortex_array::vtable::EncodeVTable;
29use vortex_array::vtable::NotSupported;
30use vortex_array::vtable::OperationsVTable;
31use vortex_array::vtable::VTable;
32use vortex_array::vtable::ValidityVTable;
33use vortex_array::vtable::VisitorVTable;
34use vortex_buffer::BufferHandle;
35use vortex_buffer::BufferMut;
36use vortex_dtype::DType;
37use vortex_dtype::NativePType;
38use vortex_dtype::Nullability;
39use vortex_dtype::Nullability::NonNullable;
40use vortex_dtype::PType;
41use vortex_dtype::match_each_integer_ptype;
42use vortex_dtype::match_each_native_ptype;
43use vortex_error::VortexExpect;
44use vortex_error::VortexResult;
45use vortex_error::vortex_bail;
46use vortex_error::vortex_err;
47use vortex_mask::Mask;
48use vortex_scalar::PValue;
49use vortex_scalar::Scalar;
50use vortex_scalar::ScalarValue;
51use vortex_vector::Vector;
52use vortex_vector::primitive::PVector;
53
54vtable!(Sequence);
55
56#[derive(Clone, prost::Message)]
57pub struct SequenceMetadata {
58 #[prost(message, tag = "1")]
59 base: Option<vortex_proto::scalar::ScalarValue>,
60 #[prost(message, tag = "2")]
61 multiplier: Option<vortex_proto::scalar::ScalarValue>,
62}
63
64#[derive(Clone, Debug)]
65pub struct SequenceArray {
67 base: PValue,
68 multiplier: PValue,
69 dtype: DType,
70 pub(crate) length: usize,
71 stats_set: ArrayStats,
72}
73
74impl SequenceArray {
75 pub fn typed_new<T: NativePType + Into<PValue>>(
76 base: T,
77 multiplier: T,
78 nullability: Nullability,
79 length: usize,
80 ) -> VortexResult<Self> {
81 Self::new(
82 base.into(),
83 multiplier.into(),
84 T::PTYPE,
85 nullability,
86 length,
87 )
88 }
89
90 pub fn new(
92 base: PValue,
93 multiplier: PValue,
94 ptype: PType,
95 nullability: Nullability,
96 length: usize,
97 ) -> VortexResult<Self> {
98 if !ptype.is_int() {
99 vortex_bail!("only integer ptype are supported in SequenceArray currently")
100 }
101
102 Self::try_last(base, multiplier, ptype, length).map_err(|e| {
103 e.with_context(format!(
104 "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ",
105 ))
106 })?;
107
108 Ok(Self::unchecked_new(
109 base,
110 multiplier,
111 ptype,
112 nullability,
113 length,
114 ))
115 }
116
117 pub(crate) fn unchecked_new(
118 base: PValue,
119 multiplier: PValue,
120 ptype: PType,
121 nullability: Nullability,
122 length: usize,
123 ) -> Self {
124 let dtype = DType::Primitive(ptype, nullability);
125 Self {
126 base,
127 multiplier,
128 dtype,
129 length,
130 stats_set: Default::default(),
132 }
133 }
134
135 pub fn ptype(&self) -> PType {
136 self.dtype.as_ptype()
137 }
138
139 pub fn base(&self) -> PValue {
140 self.base
141 }
142
143 pub fn multiplier(&self) -> PValue {
144 self.multiplier
145 }
146
147 pub(crate) fn try_last(
148 base: PValue,
149 multiplier: PValue,
150 ptype: PType,
151 length: usize,
152 ) -> VortexResult<PValue> {
153 match_each_integer_ptype!(ptype, |P| {
154 let len_t = <P>::from_usize(length - 1)
155 .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
156
157 let base = base.cast::<P>();
158 let multiplier = multiplier.cast::<P>();
159
160 let last = len_t
161 .checked_mul(multiplier)
162 .and_then(|offset| offset.checked_add(base))
163 .ok_or_else(|| vortex_err!("last value computation overflows"))?;
164 Ok(PValue::from(last))
165 })
166 }
167
168 pub(crate) fn index_value(&self, idx: usize) -> PValue {
169 assert!(idx < self.length, "index_value({idx}): index out of bounds");
170
171 match_each_native_ptype!(self.ptype(), |P| {
172 let base = self.base.cast::<P>();
173 let multiplier = self.multiplier.cast::<P>();
174 let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
175
176 PValue::from(value)
177 })
178 }
179
180 pub fn last(&self) -> PValue {
182 Self::try_last(self.base, self.multiplier, self.ptype(), self.length)
183 .vortex_expect("validated array")
184 }
185}
186
187impl VTable for SequenceVTable {
188 type Array = SequenceArray;
189
190 type Metadata = ProstMetadata<SequenceMetadata>;
191
192 type ArrayVTable = Self;
193 type CanonicalVTable = Self;
194 type OperationsVTable = Self;
195 type ValidityVTable = Self;
196 type VisitorVTable = Self;
197 type ComputeVTable = NotSupported;
198 type EncodeVTable = Self;
199
200 fn id(&self) -> ArrayId {
201 ArrayId::new_ref("vortex.sequence")
202 }
203
204 fn encoding(_array: &Self::Array) -> ArrayVTable {
205 SequenceVTable.as_vtable()
206 }
207
208 fn metadata(array: &SequenceArray) -> VortexResult<Self::Metadata> {
209 Ok(ProstMetadata(SequenceMetadata {
210 base: Some((&array.base()).into()),
211 multiplier: Some((&array.multiplier()).into()),
212 }))
213 }
214
215 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
216 Ok(Some(metadata.serialize()))
217 }
218
219 fn deserialize(buffer: &[u8]) -> VortexResult<Self::Metadata> {
220 Ok(ProstMetadata(
221 <ProstMetadata<SequenceMetadata> as DeserializeMetadata>::deserialize(buffer)?,
222 ))
223 }
224
225 fn build(
226 &self,
227 dtype: &DType,
228 len: usize,
229 metadata: &Self::Metadata,
230 _buffers: &[BufferHandle],
231 _children: &dyn ArrayChildren,
232 ) -> VortexResult<SequenceArray> {
233 let ptype = dtype.as_ptype();
234
235 let base = Scalar::new(
237 DType::Primitive(ptype, NonNullable),
238 metadata
239 .0
240 .base
241 .as_ref()
242 .ok_or_else(|| vortex_err!("base required"))?
243 .try_into()?,
244 )
245 .as_primitive()
246 .pvalue()
247 .vortex_expect("non-nullable primitive");
248
249 let multiplier = Scalar::new(
250 DType::Primitive(ptype, NonNullable),
251 metadata
252 .0
253 .multiplier
254 .as_ref()
255 .ok_or_else(|| vortex_err!("base required"))?
256 .try_into()?,
257 )
258 .as_primitive()
259 .pvalue()
260 .vortex_expect("non-nullable primitive");
261
262 Ok(SequenceArray::unchecked_new(
263 base,
264 multiplier,
265 ptype,
266 dtype.nullability(),
267 len,
268 ))
269 }
270
271 fn batch_execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<Vector> {
272 Ok(match_each_native_ptype!(array.ptype(), |P| {
273 let base = array.base().cast::<P>();
274 let multiplier = array.multiplier().cast::<P>();
275
276 let values = if multiplier == <P>::one() {
277 BufferMut::from_iter(
278 (0..array.len()).map(|i| base + <P>::from_usize(i).vortex_expect("must fit")),
279 )
280 } else {
281 BufferMut::from_iter(
282 (0..array.len())
283 .map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
284 )
285 };
286
287 PVector::<P>::new(values.freeze(), Mask::new_true(array.len())).into()
288 }))
289 }
290}
291
292impl BaseArrayVTable<SequenceVTable> for SequenceVTable {
293 fn len(array: &SequenceArray) -> usize {
294 array.length
295 }
296
297 fn dtype(array: &SequenceArray) -> &DType {
298 &array.dtype
299 }
300
301 fn stats(array: &SequenceArray) -> StatsSetRef<'_> {
302 array.stats_set.to_ref(array.as_ref())
303 }
304
305 fn array_hash<H: std::hash::Hasher>(
306 array: &SequenceArray,
307 state: &mut H,
308 _precision: Precision,
309 ) {
310 array.base.hash(state);
311 array.multiplier.hash(state);
312 array.dtype.hash(state);
313 array.length.hash(state);
314 }
315
316 fn array_eq(array: &SequenceArray, other: &SequenceArray, _precision: Precision) -> bool {
317 array.base == other.base
318 && array.multiplier == other.multiplier
319 && array.dtype == other.dtype
320 && array.length == other.length
321 }
322}
323
324impl CanonicalVTable<SequenceVTable> for SequenceVTable {
325 fn canonicalize(array: &SequenceArray) -> Canonical {
326 let prim = match_each_native_ptype!(array.ptype(), |P| {
327 let base = array.base().cast::<P>();
328 let multiplier = array.multiplier().cast::<P>();
329 let values = BufferMut::from_iter(
330 (0..array.len())
331 .map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
332 );
333 PrimitiveArray::new(values, array.dtype.nullability().into())
334 });
335
336 Canonical::Primitive(prim)
337 }
338}
339
340impl OperationsVTable<SequenceVTable> for SequenceVTable {
341 fn slice(array: &SequenceArray, range: Range<usize>) -> ArrayRef {
342 SequenceArray::unchecked_new(
343 array.index_value(range.start),
344 array.multiplier,
345 array.ptype(),
346 array.dtype().nullability(),
347 range.len(),
348 )
349 .to_array()
350 }
351
352 fn scalar_at(array: &SequenceArray, index: usize) -> Scalar {
353 Scalar::new(
354 array.dtype().clone(),
355 ScalarValue::from(array.index_value(index)),
356 )
357 }
358}
359
360impl ValidityVTable<SequenceVTable> for SequenceVTable {
361 fn is_valid(_array: &SequenceArray, _index: usize) -> bool {
362 true
363 }
364
365 fn all_valid(_array: &SequenceArray) -> bool {
366 true
367 }
368
369 fn all_invalid(_array: &SequenceArray) -> bool {
370 false
371 }
372
373 fn validity_mask(array: &SequenceArray) -> Mask {
374 Mask::AllTrue(array.len())
375 }
376}
377
378impl VisitorVTable<SequenceVTable> for SequenceVTable {
379 fn visit_buffers(_array: &SequenceArray, _visitor: &mut dyn ArrayBufferVisitor) {
380 }
382
383 fn visit_children(_array: &SequenceArray, _visitor: &mut dyn ArrayChildVisitor) {}
384}
385
/// Unit type on which all vtable traits for the sequence encoding in this file are
/// implemented.
#[derive(Debug)]
pub struct SequenceVTable;
388
389impl EncodeVTable<SequenceVTable> for SequenceVTable {
390 fn encode(
391 _vtable: &SequenceVTable,
392 _canonical: &Canonical,
393 _like: Option<&SequenceArray>,
394 ) -> VortexResult<Option<SequenceArray>> {
395 Ok(None)
397 }
398}
399
#[cfg(test)]
mod tests {
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;
    use vortex_scalar::ScalarValue;

    use crate::array::SequenceArray;

    /// Canonicalising `2, 5, 8, 11` yields exactly those values.
    #[test]
    fn test_sequence_canonical() {
        let sequence = SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4).unwrap();
        let expected = PrimitiveArray::from_iter((0..4).map(|i| 2i64 + i * 3));

        let actual = sequence.to_primitive();

        assert_eq!(actual.as_slice::<i64>(), expected.as_slice::<i64>())
    }

    /// Slicing `2..3` keeps only the value at index 2.
    #[test]
    fn test_sequence_slice_canonical() {
        let sequence = SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4).unwrap();
        let sliced = sequence.slice(2..3);
        let expected = PrimitiveArray::from_iter((2..3).map(|i| 2i64 + i * 3));

        let actual = sliced.to_primitive();

        assert_eq!(actual.as_slice::<i64>(), expected.as_slice::<i64>())
    }

    /// The scalar at index 2 of `2, 5, 8, 11` is 8.
    #[test]
    fn test_sequence_scalar_at() {
        let sequence = SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4).unwrap();
        let scalar = sequence.scalar_at(2);

        let expected = Scalar::new(scalar.dtype().clone(), ScalarValue::from(8i64));

        assert_eq!(scalar, expected)
    }

    /// Sequences whose last value still fits in `i8` are accepted.
    #[test]
    fn test_sequence_min_max() {
        for base in [-127i8, 126i8] {
            assert!(SequenceArray::typed_new(base, -1i8, Nullability::NonNullable, 2).is_ok());
        }
    }

    /// Sequences whose last value overflows `i8` are rejected.
    #[test]
    fn test_sequence_too_big() {
        for (base, multiplier) in [(127i8, 1i8), (-128i8, -1i8)] {
            assert!(
                SequenceArray::typed_new(base, multiplier, Nullability::NonNullable, 2).is_err()
            );
        }
    }
}