1use std::hash::Hash;
5use std::ops::Range;
6
7use num_traits::cast::FromPrimitive;
8use vortex_array::ArrayBufferVisitor;
9use vortex_array::ArrayChildVisitor;
10use vortex_array::ArrayRef;
11use vortex_array::Canonical;
12use vortex_array::DeserializeMetadata;
13use vortex_array::ExecutionCtx;
14use vortex_array::Precision;
15use vortex_array::ProstMetadata;
16use vortex_array::SerializeMetadata;
17use vortex_array::arrays::FilterVTable;
18use vortex_array::arrays::PrimitiveArray;
19use vortex_array::buffer::BufferHandle;
20use vortex_array::serde::ArrayChildren;
21use vortex_array::stats::ArrayStats;
22use vortex_array::stats::StatsSetRef;
23use vortex_array::validity::Validity;
24use vortex_array::vtable;
25use vortex_array::vtable::ArrayId;
26use vortex_array::vtable::ArrayVTable;
27use vortex_array::vtable::ArrayVTableExt;
28use vortex_array::vtable::BaseArrayVTable;
29use vortex_array::vtable::CanonicalVTable;
30use vortex_array::vtable::EncodeVTable;
31use vortex_array::vtable::NotSupported;
32use vortex_array::vtable::OperationsVTable;
33use vortex_array::vtable::VTable;
34use vortex_array::vtable::ValidityVTable;
35use vortex_array::vtable::VisitorVTable;
36use vortex_buffer::BufferMut;
37use vortex_dtype::DType;
38use vortex_dtype::NativePType;
39use vortex_dtype::Nullability;
40use vortex_dtype::Nullability::NonNullable;
41use vortex_dtype::PType;
42use vortex_dtype::match_each_integer_ptype;
43use vortex_dtype::match_each_native_ptype;
44use vortex_error::VortexExpect;
45use vortex_error::VortexResult;
46use vortex_error::vortex_bail;
47use vortex_error::vortex_ensure;
48use vortex_error::vortex_err;
49use vortex_mask::AllOr;
50use vortex_mask::Mask;
51use vortex_scalar::PValue;
52use vortex_scalar::Scalar;
53use vortex_scalar::ScalarValue;
54use vortex_vector::Vector;
55use vortex_vector::VectorMut;
56use vortex_vector::VectorMutOps;
57use vortex_vector::primitive::PVector;
58
// NOTE(review): presumably this macro generates the glue between `SequenceVTable` and
// `SequenceArray` (conversions such as `to_array`/`as_ref` used below) — confirm against the
// `vortex_array::vtable!` definition.
vtable!(Sequence);
60
61#[derive(Clone, prost::Message)]
62pub struct SequenceMetadata {
63 #[prost(message, tag = "1")]
64 base: Option<vortex_proto::scalar::ScalarValue>,
65 #[prost(message, tag = "2")]
66 multiplier: Option<vortex_proto::scalar::ScalarValue>,
67}
68
69#[derive(Clone, Debug)]
70pub struct SequenceArray {
72 base: PValue,
73 multiplier: PValue,
74 dtype: DType,
75 pub(crate) length: usize,
76 stats_set: ArrayStats,
77}
78
79impl SequenceArray {
80 pub fn typed_new<T: NativePType + Into<PValue>>(
81 base: T,
82 multiplier: T,
83 nullability: Nullability,
84 length: usize,
85 ) -> VortexResult<Self> {
86 Self::new(
87 base.into(),
88 multiplier.into(),
89 T::PTYPE,
90 nullability,
91 length,
92 )
93 }
94
95 pub fn new(
97 base: PValue,
98 multiplier: PValue,
99 ptype: PType,
100 nullability: Nullability,
101 length: usize,
102 ) -> VortexResult<Self> {
103 if !ptype.is_int() {
104 vortex_bail!("only integer ptype are supported in SequenceArray currently")
105 }
106
107 Self::try_last(base, multiplier, ptype, length).map_err(|e| {
108 e.with_context(format!(
109 "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ",
110 ))
111 })?;
112
113 Ok(Self::unchecked_new(
114 base,
115 multiplier,
116 ptype,
117 nullability,
118 length,
119 ))
120 }
121
122 pub(crate) fn unchecked_new(
123 base: PValue,
124 multiplier: PValue,
125 ptype: PType,
126 nullability: Nullability,
127 length: usize,
128 ) -> Self {
129 let dtype = DType::Primitive(ptype, nullability);
130 Self {
131 base,
132 multiplier,
133 dtype,
134 length,
135 stats_set: Default::default(),
137 }
138 }
139
140 pub fn ptype(&self) -> PType {
141 self.dtype.as_ptype()
142 }
143
144 pub fn base(&self) -> PValue {
145 self.base
146 }
147
148 pub fn multiplier(&self) -> PValue {
149 self.multiplier
150 }
151
152 pub(crate) fn try_last(
153 base: PValue,
154 multiplier: PValue,
155 ptype: PType,
156 length: usize,
157 ) -> VortexResult<PValue> {
158 match_each_integer_ptype!(ptype, |P| {
159 let len_t = <P>::from_usize(length - 1)
160 .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
161
162 let base = base.cast::<P>();
163 let multiplier = multiplier.cast::<P>();
164
165 let last = len_t
166 .checked_mul(multiplier)
167 .and_then(|offset| offset.checked_add(base))
168 .ok_or_else(|| vortex_err!("last value computation overflows"))?;
169 Ok(PValue::from(last))
170 })
171 }
172
173 pub(crate) fn index_value(&self, idx: usize) -> PValue {
174 assert!(idx < self.length, "index_value({idx}): index out of bounds");
175
176 match_each_native_ptype!(self.ptype(), |P| {
177 let base = self.base.cast::<P>();
178 let multiplier = self.multiplier.cast::<P>();
179 let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
180
181 PValue::from(value)
182 })
183 }
184
185 pub fn last(&self) -> PValue {
187 Self::try_last(self.base, self.multiplier, self.ptype(), self.length)
188 .vortex_expect("validated array")
189 }
190}
191
192impl VTable for SequenceVTable {
193 type Array = SequenceArray;
194
195 type Metadata = ProstMetadata<SequenceMetadata>;
196
197 type ArrayVTable = Self;
198 type CanonicalVTable = Self;
199 type OperationsVTable = Self;
200 type ValidityVTable = Self;
201 type VisitorVTable = Self;
202 type ComputeVTable = NotSupported;
203 type EncodeVTable = Self;
204
205 fn id(&self) -> ArrayId {
206 ArrayId::new_ref("vortex.sequence")
207 }
208
209 fn encoding(_array: &Self::Array) -> ArrayVTable {
210 SequenceVTable.as_vtable()
211 }
212
213 fn metadata(array: &SequenceArray) -> VortexResult<Self::Metadata> {
214 Ok(ProstMetadata(SequenceMetadata {
215 base: Some((&array.base()).into()),
216 multiplier: Some((&array.multiplier()).into()),
217 }))
218 }
219
220 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
221 Ok(Some(metadata.serialize()))
222 }
223
224 fn deserialize(buffer: &[u8]) -> VortexResult<Self::Metadata> {
225 Ok(ProstMetadata(
226 <ProstMetadata<SequenceMetadata> as DeserializeMetadata>::deserialize(buffer)?,
227 ))
228 }
229
230 fn build(
231 &self,
232 dtype: &DType,
233 len: usize,
234 metadata: &Self::Metadata,
235 _buffers: &[BufferHandle],
236 _children: &dyn ArrayChildren,
237 ) -> VortexResult<SequenceArray> {
238 let ptype = dtype.as_ptype();
239
240 let base = Scalar::new(
242 DType::Primitive(ptype, NonNullable),
243 metadata
244 .0
245 .base
246 .as_ref()
247 .ok_or_else(|| vortex_err!("base required"))?
248 .try_into()?,
249 )
250 .as_primitive()
251 .pvalue()
252 .vortex_expect("non-nullable primitive");
253
254 let multiplier = Scalar::new(
255 DType::Primitive(ptype, NonNullable),
256 metadata
257 .0
258 .multiplier
259 .as_ref()
260 .ok_or_else(|| vortex_err!("base required"))?
261 .try_into()?,
262 )
263 .as_primitive()
264 .pvalue()
265 .vortex_expect("non-nullable primitive");
266
267 Ok(SequenceArray::unchecked_new(
268 base,
269 multiplier,
270 ptype,
271 dtype.nullability(),
272 len,
273 ))
274 }
275
276 fn with_children(_array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
277 vortex_ensure!(
278 children.is_empty(),
279 "SequenceArray expects 0 children, got {}",
280 children.len()
281 );
282 Ok(())
283 }
284
285 fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<Vector> {
286 let array = array.clone();
287
288 Ok(match_each_native_ptype!(array.ptype(), |P| {
289 let base = array.base().cast::<P>();
290 let multiplier = array.multiplier().cast::<P>();
291
292 execute_iter(base, multiplier, 0..array.len(), array.len()).into()
293 }))
294 }
295
296 fn execute_parent(
297 array: &Self::Array,
298 parent: &ArrayRef,
299 _child_idx: usize,
300 _ctx: &mut ExecutionCtx,
301 ) -> VortexResult<Option<Vector>> {
302 let Some(filter) = parent.as_opt::<FilterVTable>() else {
304 return Ok(None);
305 };
306
307 match filter.filter_mask().indices() {
308 AllOr::All => Ok(None),
309 AllOr::None => Ok(Some(VectorMut::with_capacity(array.dtype(), 0).freeze())),
310 AllOr::Some(indices) => Ok(Some(match_each_native_ptype!(array.ptype(), |P| {
311 let base = array.base().cast::<P>();
312 let multiplier = array.multiplier().cast::<P>();
313 execute_iter(base, multiplier, indices.iter().copied(), indices.len()).into()
314 }))),
315 }
316 }
317}
318
319fn execute_iter<P: NativePType, I: Iterator<Item = usize>>(
320 base: P,
321 multiplier: P,
322 iter: I,
323 len: usize,
324) -> PVector<P> {
325 let values = if multiplier == <P>::one() {
326 BufferMut::from_iter(iter.map(|i| base + <P>::from_usize(i).vortex_expect("must fit")))
327 } else {
328 BufferMut::from_iter(
329 iter.map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
330 )
331 };
332
333 PVector::<P>::new(values.freeze(), Mask::new_true(len))
334}
335
336impl BaseArrayVTable<SequenceVTable> for SequenceVTable {
337 fn len(array: &SequenceArray) -> usize {
338 array.length
339 }
340
341 fn dtype(array: &SequenceArray) -> &DType {
342 &array.dtype
343 }
344
345 fn stats(array: &SequenceArray) -> StatsSetRef<'_> {
346 array.stats_set.to_ref(array.as_ref())
347 }
348
349 fn array_hash<H: std::hash::Hasher>(
350 array: &SequenceArray,
351 state: &mut H,
352 _precision: Precision,
353 ) {
354 array.base.hash(state);
355 array.multiplier.hash(state);
356 array.dtype.hash(state);
357 array.length.hash(state);
358 }
359
360 fn array_eq(array: &SequenceArray, other: &SequenceArray, _precision: Precision) -> bool {
361 array.base == other.base
362 && array.multiplier == other.multiplier
363 && array.dtype == other.dtype
364 && array.length == other.length
365 }
366}
367
368impl CanonicalVTable<SequenceVTable> for SequenceVTable {
369 fn canonicalize(array: &SequenceArray) -> Canonical {
370 let prim = match_each_native_ptype!(array.ptype(), |P| {
371 let base = array.base().cast::<P>();
372 let multiplier = array.multiplier().cast::<P>();
373 let values = BufferMut::from_iter(
374 (0..array.len())
375 .map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
376 );
377 PrimitiveArray::new(values, array.dtype.nullability().into())
378 });
379
380 Canonical::Primitive(prim)
381 }
382}
383
384impl OperationsVTable<SequenceVTable> for SequenceVTable {
385 fn slice(array: &SequenceArray, range: Range<usize>) -> ArrayRef {
386 SequenceArray::unchecked_new(
387 array.index_value(range.start),
388 array.multiplier,
389 array.ptype(),
390 array.dtype().nullability(),
391 range.len(),
392 )
393 .to_array()
394 }
395
396 fn scalar_at(array: &SequenceArray, index: usize) -> Scalar {
397 Scalar::new(
398 array.dtype().clone(),
399 ScalarValue::from(array.index_value(index)),
400 )
401 }
402}
403
404impl ValidityVTable<SequenceVTable> for SequenceVTable {
405 fn is_valid(_array: &SequenceArray, _index: usize) -> bool {
406 true
407 }
408
409 fn all_valid(_array: &SequenceArray) -> bool {
410 true
411 }
412
413 fn all_invalid(_array: &SequenceArray) -> bool {
414 false
415 }
416
417 fn validity(_array: &SequenceArray) -> VortexResult<Validity> {
418 Ok(Validity::AllValid)
419 }
420
421 fn validity_mask(array: &SequenceArray) -> Mask {
422 Mask::AllTrue(array.len())
423 }
424}
425
426impl VisitorVTable<SequenceVTable> for SequenceVTable {
427 fn visit_buffers(_array: &SequenceArray, _visitor: &mut dyn ArrayBufferVisitor) {
428 }
430
431 fn visit_children(_array: &SequenceArray, _visitor: &mut dyn ArrayChildVisitor) {}
432}
433
/// Marker type carrying the vtable implementations for [`SequenceArray`]
/// (encoding id `vortex.sequence`).
#[derive(Debug)]
pub struct SequenceVTable;
436
437impl EncodeVTable<SequenceVTable> for SequenceVTable {
438 fn encode(
439 _vtable: &SequenceVTable,
440 _canonical: &Canonical,
441 _like: Option<&SequenceArray>,
442 ) -> VortexResult<Option<SequenceArray>> {
443 Ok(None)
445 }
446}
447
#[cfg(test)]
mod tests {
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;
    use vortex_scalar::ScalarValue;

    use crate::array::SequenceArray;

    /// Canonicalizing yields `base + i * multiplier` for every index.
    #[test]
    fn test_sequence_canonical() {
        let arr = SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4).unwrap();

        let canon = PrimitiveArray::from_iter((0..4).map(|i| 2i64 + i * 3));

        assert_eq!(
            arr.to_primitive().as_slice::<i64>(),
            canon.as_slice::<i64>()
        )
    }

    /// A slice starts at the value found at the slice's first index.
    #[test]
    fn test_sequence_slice_canonical() {
        let arr = SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4)
            .unwrap()
            .slice(2..3);

        let canon = PrimitiveArray::from_iter((2..3).map(|i| 2i64 + i * 3));

        assert_eq!(
            arr.to_primitive().as_slice::<i64>(),
            canon.as_slice::<i64>()
        )
    }

    /// `scalar_at(2)` on base 2, multiplier 3 is 8.
    #[test]
    fn test_sequence_scalar_at() {
        let scalar = SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4)
            .unwrap()
            .scalar_at(2);

        assert_eq!(
            scalar,
            Scalar::new(scalar.dtype().clone(), ScalarValue::from(8i64))
        )
    }

    /// Sequences that end exactly on the type's bounds are accepted.
    #[test]
    fn test_sequence_min_max() {
        // Descending to exactly i8::MIN.
        assert!(SequenceArray::typed_new(-127i8, -1i8, Nullability::NonNullable, 2).is_ok());
        assert!(SequenceArray::typed_new(126i8, -1i8, Nullability::NonNullable, 2).is_ok());
        // Ascending to exactly i8::MAX.
        assert!(SequenceArray::typed_new(126i8, 1i8, Nullability::NonNullable, 2).is_ok());
    }

    /// Sequences whose last value falls outside the type's range are rejected.
    #[test]
    fn test_sequence_too_big() {
        assert!(SequenceArray::typed_new(127i8, 1i8, Nullability::NonNullable, 2).is_err());
        assert!(SequenceArray::typed_new(-128i8, -1i8, Nullability::NonNullable, 2).is_err());
    }

    /// Only integer primitive types are supported.
    #[test]
    fn test_sequence_rejects_non_integer() {
        assert!(SequenceArray::typed_new(1.0f32, 1.0f32, Nullability::NonNullable, 3).is_err());
    }
}