1use std::hash::Hash;
5
6use num_traits::cast::FromPrimitive;
7use vortex_array::ArrayBufferVisitor;
8use vortex_array::ArrayChildVisitor;
9use vortex_array::ArrayRef;
10use vortex_array::DeserializeMetadata;
11use vortex_array::ExecutionCtx;
12use vortex_array::IntoArray;
13use vortex_array::Precision;
14use vortex_array::ProstMetadata;
15use vortex_array::SerializeMetadata;
16use vortex_array::arrays::PrimitiveArray;
17use vortex_array::buffer::BufferHandle;
18use vortex_array::expr::stats::Precision as StatPrecision;
19use vortex_array::expr::stats::Stat;
20use vortex_array::match_each_pvalue;
21use vortex_array::scalar::PValue;
22use vortex_array::scalar::Scalar;
23use vortex_array::scalar::ScalarValue;
24use vortex_array::serde::ArrayChildren;
25use vortex_array::stats::ArrayStats;
26use vortex_array::stats::StatsSet;
27use vortex_array::stats::StatsSetRef;
28use vortex_array::validity::Validity;
29use vortex_array::vtable;
30use vortex_array::vtable::ArrayId;
31use vortex_array::vtable::BaseArrayVTable;
32use vortex_array::vtable::OperationsVTable;
33use vortex_array::vtable::VTable;
34use vortex_array::vtable::ValidityVTable;
35use vortex_array::vtable::VisitorVTable;
36use vortex_buffer::BufferMut;
37use vortex_dtype::DType;
38use vortex_dtype::NativePType;
39use vortex_dtype::Nullability;
40use vortex_dtype::Nullability::NonNullable;
41use vortex_dtype::PType;
42use vortex_dtype::match_each_integer_ptype;
43use vortex_dtype::match_each_native_ptype;
44use vortex_error::VortexExpect;
45use vortex_error::VortexResult;
46use vortex_error::vortex_bail;
47use vortex_error::vortex_ensure;
48use vortex_error::vortex_err;
49use vortex_session::VortexSession;
50
51use crate::kernel::PARENT_KERNELS;
52use crate::rules::RULES;
53
// Invokes vortex_array's `vtable!` macro for the `Sequence` encoding.
// NOTE(review): presumably generates the glue between `SequenceArray` and `SequenceVTable`
// (e.g. the `as_ref()` conversion used by `stats` below) — confirm against the macro definition.
vtable!(Sequence);
55
/// Serialized (protobuf) metadata of a [`SequenceArray`].
///
/// Holds the two scalars that define the sequence. Both fields are optional on the wire;
/// `VTable::build` in this file rejects metadata in which either one is missing.
#[derive(Clone, prost::Message)]
pub struct SequenceMetadata {
    /// First value of the sequence, encoded as a protobuf scalar value.
    #[prost(message, tag = "1")]
    base: Option<vortex_proto::scalar::ScalarValue>,
    /// Step between consecutive values, encoded as a protobuf scalar value.
    #[prost(message, tag = "2")]
    multiplier: Option<vortex_proto::scalar::ScalarValue>,
}
63
/// The components of a [`SequenceArray`], as returned by [`SequenceArray::into_parts`].
pub struct SequenceArrayParts {
    /// First value of the sequence.
    pub base: PValue,
    /// Step between consecutive values.
    pub multiplier: PValue,
    /// Number of elements in the sequence.
    pub len: usize,
    /// Primitive type of the elements.
    pub ptype: PType,
    /// Nullability recorded in the array's dtype.
    pub nullability: Nullability,
}
72
/// An array whose values follow `A[i] = base + i * multiplier`.
///
/// No element data is stored: values are computed from `base` and `multiplier` on demand.
#[derive(Clone, Debug)]
pub struct SequenceArray {
    /// Value at index 0.
    base: PValue,
    /// Difference between consecutive values.
    multiplier: PValue,
    /// Primitive dtype (ptype + nullability) of the elements.
    dtype: DType,
    /// Number of elements.
    pub(crate) len: usize,
    /// Statistics; seeded with sortedness information at construction time.
    stats_set: ArrayStats,
}
82
83impl SequenceArray {
84 pub fn typed_new<T: NativePType + Into<PValue>>(
85 base: T,
86 multiplier: T,
87 nullability: Nullability,
88 length: usize,
89 ) -> VortexResult<Self> {
90 Self::new(
91 base.into(),
92 multiplier.into(),
93 T::PTYPE,
94 nullability,
95 length,
96 )
97 }
98
99 pub fn new(
101 base: PValue,
102 multiplier: PValue,
103 ptype: PType,
104 nullability: Nullability,
105 length: usize,
106 ) -> VortexResult<Self> {
107 if !ptype.is_int() {
108 vortex_bail!("only integer ptype are supported in SequenceArray currently")
109 }
110
111 Self::try_last(base, multiplier, ptype, length).map_err(|e| {
112 e.with_context(format!(
113 "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ",
114 ))
115 })?;
116
117 Ok(Self::unchecked_new(
118 base,
119 multiplier,
120 ptype,
121 nullability,
122 length,
123 ))
124 }
125
126 pub(crate) fn unchecked_new(
127 base: PValue,
128 multiplier: PValue,
129 ptype: PType,
130 nullability: Nullability,
131 length: usize,
132 ) -> Self {
133 let dtype = DType::Primitive(ptype, nullability);
134
135 let (is_sorted, is_strict_sorted) = match_each_pvalue!(
139 multiplier,
140 uint: |v| { (true, v> 0) },
141 int: |v| { (v >= 0, v > 0) },
142 float: |_v| { unreachable!("float multiplier not supported") }
143 );
144
145 let stats_set = unsafe {
147 StatsSet::new_unchecked(vec![
148 (Stat::IsSorted, StatPrecision::Exact(is_sorted.into())),
149 (
150 Stat::IsStrictSorted,
151 StatPrecision::Exact(is_strict_sorted.into()),
152 ),
153 ])
154 };
155
156 Self {
157 base,
158 multiplier,
159 dtype,
160 len: length,
161 stats_set: ArrayStats::from(stats_set),
162 }
163 }
164
165 pub fn ptype(&self) -> PType {
166 self.dtype.as_ptype()
167 }
168
169 pub fn base(&self) -> PValue {
170 self.base
171 }
172
173 pub fn multiplier(&self) -> PValue {
174 self.multiplier
175 }
176
177 pub(crate) fn try_last(
178 base: PValue,
179 multiplier: PValue,
180 ptype: PType,
181 length: usize,
182 ) -> VortexResult<PValue> {
183 match_each_integer_ptype!(ptype, |P| {
184 let len_t = <P>::from_usize(length - 1)
185 .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
186
187 let base = base.cast::<P>()?;
188 let multiplier = multiplier.cast::<P>()?;
189 let last = len_t
190 .checked_mul(multiplier)
191 .and_then(|offset| offset.checked_add(base))
192 .ok_or_else(|| vortex_err!("last value computation overflows"))?;
193 Ok(PValue::from(last))
194 })
195 }
196
197 pub(crate) fn index_value(&self, idx: usize) -> PValue {
198 assert!(idx < self.len, "index_value({idx}): index out of bounds");
199
200 match_each_native_ptype!(self.ptype(), |P| {
201 let base = self.base.cast::<P>().vortex_expect("must be able to cast");
202 let multiplier = self
203 .multiplier
204 .cast::<P>()
205 .vortex_expect("must be able to cast");
206 let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
207
208 PValue::from(value)
209 })
210 }
211
212 pub fn last(&self) -> PValue {
214 Self::try_last(self.base, self.multiplier, self.ptype(), self.len)
215 .vortex_expect("validated array")
216 }
217
218 pub fn into_parts(self) -> SequenceArrayParts {
219 SequenceArrayParts {
220 base: self.base,
221 multiplier: self.multiplier,
222 len: self.len,
223 ptype: self.dtype.as_ptype(),
224 nullability: self.dtype.nullability(),
225 }
226 }
227}
228
229impl VTable for SequenceVTable {
230 type Array = SequenceArray;
231
232 type Metadata = ProstMetadata<SequenceMetadata>;
233
234 type ArrayVTable = Self;
235 type OperationsVTable = Self;
236 type ValidityVTable = Self;
237 type VisitorVTable = Self;
238
239 fn id(_array: &Self::Array) -> ArrayId {
240 Self::ID
241 }
242
243 fn metadata(array: &SequenceArray) -> VortexResult<Self::Metadata> {
244 Ok(ProstMetadata(SequenceMetadata {
245 base: Some((&array.base()).into()),
246 multiplier: Some((&array.multiplier()).into()),
247 }))
248 }
249
250 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
251 Ok(Some(metadata.serialize()))
252 }
253
254 fn deserialize(
255 bytes: &[u8],
256 _dtype: &DType,
257 _len: usize,
258 _buffers: &[BufferHandle],
259 _session: &VortexSession,
260 ) -> VortexResult<Self::Metadata> {
261 Ok(ProstMetadata(
262 <ProstMetadata<SequenceMetadata> as DeserializeMetadata>::deserialize(bytes)?,
263 ))
264 }
265
266 fn build(
267 dtype: &DType,
268 len: usize,
269 metadata: &Self::Metadata,
270 _buffers: &[BufferHandle],
271 _children: &dyn ArrayChildren,
272 ) -> VortexResult<SequenceArray> {
273 let ptype = dtype.as_ptype();
274
275 let base = Scalar::from_proto_value(
277 metadata
278 .0
279 .base
280 .as_ref()
281 .ok_or_else(|| vortex_err!("base required"))?,
282 &DType::Primitive(ptype, NonNullable),
283 )?
284 .as_primitive()
285 .pvalue()
286 .vortex_expect("non-nullable primitive");
287
288 let multiplier = Scalar::from_proto_value(
289 metadata
290 .0
291 .multiplier
292 .as_ref()
293 .ok_or_else(|| vortex_err!("multiplier required"))?,
294 &DType::Primitive(ptype, NonNullable),
295 )?
296 .as_primitive()
297 .pvalue()
298 .vortex_expect("non-nullable primitive");
299
300 Ok(SequenceArray::unchecked_new(
301 base,
302 multiplier,
303 ptype,
304 dtype.nullability(),
305 len,
306 ))
307 }
308
309 fn with_children(_array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
310 vortex_ensure!(
311 children.is_empty(),
312 "SequenceArray expects 0 children, got {}",
313 children.len()
314 );
315 Ok(())
316 }
317
318 fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
319 let prim = match_each_native_ptype!(array.ptype(), |P| {
320 let base = array.base().cast::<P>()?;
321 let multiplier = array.multiplier().cast::<P>()?;
322 let values = BufferMut::from_iter(
323 (0..array.len())
324 .map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
325 );
326 PrimitiveArray::new(values, array.dtype.nullability().into())
327 });
328
329 Ok(prim.into_array())
330 }
331
332 fn execute_parent(
333 array: &Self::Array,
334 parent: &ArrayRef,
335 child_idx: usize,
336 ctx: &mut ExecutionCtx,
337 ) -> VortexResult<Option<ArrayRef>> {
338 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
339 }
340
341 fn reduce_parent(
342 array: &SequenceArray,
343 parent: &ArrayRef,
344 child_idx: usize,
345 ) -> VortexResult<Option<ArrayRef>> {
346 RULES.evaluate(array, parent, child_idx)
347 }
348}
349
350impl BaseArrayVTable<SequenceVTable> for SequenceVTable {
351 fn len(array: &SequenceArray) -> usize {
352 array.len
353 }
354
355 fn dtype(array: &SequenceArray) -> &DType {
356 &array.dtype
357 }
358
359 fn stats(array: &SequenceArray) -> StatsSetRef<'_> {
360 array.stats_set.to_ref(array.as_ref())
361 }
362
363 fn array_hash<H: std::hash::Hasher>(
364 array: &SequenceArray,
365 state: &mut H,
366 _precision: Precision,
367 ) {
368 array.base.hash(state);
369 array.multiplier.hash(state);
370 array.dtype.hash(state);
371 array.len.hash(state);
372 }
373
374 fn array_eq(array: &SequenceArray, other: &SequenceArray, _precision: Precision) -> bool {
375 array.base == other.base
376 && array.multiplier == other.multiplier
377 && array.dtype == other.dtype
378 && array.len == other.len
379 }
380}
381
382impl OperationsVTable<SequenceVTable> for SequenceVTable {
383 fn scalar_at(array: &SequenceArray, index: usize) -> VortexResult<Scalar> {
384 Scalar::try_new(
385 array.dtype().clone(),
386 Some(ScalarValue::Primitive(array.index_value(index))),
387 )
388 }
389}
390
impl ValidityVTable<SequenceVTable> for SequenceVTable {
    /// Reports every element as valid: a sequence stores no validity information, and
    /// every value is computed from `base` and `multiplier`.
    fn validity(_array: &SequenceArray) -> VortexResult<Validity> {
        Ok(Validity::AllValid)
    }
}
396
impl VisitorVTable<SequenceVTable> for SequenceVTable {
    /// Visits nothing: the array owns no buffers.
    fn visit_buffers(_array: &SequenceArray, _visitor: &mut dyn ArrayBufferVisitor) {
        // Intentionally empty — a sequence is fully described by its metadata.
    }

    /// Always zero, matching `visit_buffers`.
    fn nbuffers(_array: &SequenceArray) -> usize {
        0
    }

    /// Visits nothing: the array has no child arrays.
    fn visit_children(_array: &SequenceArray, _visitor: &mut dyn ArrayChildVisitor) {}

    /// Always zero, matching `visit_children`.
    fn nchildren(_array: &SequenceArray) -> usize {
        0
    }
}
412
/// Marker type that carries the vtable implementations for [`SequenceArray`].
#[derive(Debug)]
pub struct SequenceVTable;

impl SequenceVTable {
    /// Identifier of the sequence encoding.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.sequence");
}
419
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::expr::stats::Precision as StatPrecision;
    use vortex_array::expr::stats::Stat;
    use vortex_array::expr::stats::StatsProviderExt;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar::ScalarValue;
    use vortex_dtype::Nullability;
    use vortex_error::VortexResult;

    use crate::array::SequenceArray;

    /// A full sequence equals the primitive array of its computed values.
    #[test]
    fn test_sequence_canonical() {
        let sequence = SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4).unwrap();
        let expected = PrimitiveArray::from_iter((0i64..4).map(|step| 2 + step * 3));

        assert_arrays_eq!(sequence, expected);
    }

    /// Slicing a sequence yields the matching window of values.
    #[test]
    fn test_sequence_slice_canonical() {
        let sequence = SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4).unwrap();
        let sliced = sequence.slice(2..3).unwrap();
        let expected = PrimitiveArray::from_iter((2i64..3).map(|step| 2 + step * 3));

        assert_arrays_eq!(sliced, expected);
    }

    /// `scalar_at(2)` on `2 + 3 * i` is 8.
    #[test]
    fn test_sequence_scalar_at() {
        let sequence = SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4).unwrap();
        let actual = sequence.scalar_at(2).unwrap();
        let expected =
            Scalar::try_new(actual.dtype().clone(), Some(ScalarValue::from(8i64))).unwrap();

        assert_eq!(actual, expected)
    }

    /// Sequences whose last value just fits into `i8` are accepted.
    #[test]
    fn test_sequence_min_max() {
        let down_to_min = SequenceArray::typed_new(-127i8, -1i8, Nullability::NonNullable, 2);
        assert!(down_to_min.is_ok());

        let near_max = SequenceArray::typed_new(126i8, -1i8, Nullability::NonNullable, 2);
        assert!(near_max.is_ok());
    }

    /// Sequences whose last value leaves the `i8` range are rejected.
    #[test]
    fn test_sequence_too_big() {
        let past_max = SequenceArray::typed_new(127i8, 1i8, Nullability::NonNullable, 2);
        assert!(past_max.is_err());

        let past_min = SequenceArray::typed_new(-128i8, -1i8, Nullability::NonNullable, 2);
        assert!(past_min.is_err());
    }

    /// A positive step is both sorted and strictly sorted.
    #[test]
    fn positive_multiplier_is_strict_sorted() -> VortexResult<()> {
        let sequence = SequenceArray::typed_new(0i64, 3, Nullability::NonNullable, 4)?;

        assert_eq!(
            sequence
                .statistics()
                .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted)),
            Some(StatPrecision::Exact(true))
        );
        assert_eq!(
            sequence
                .statistics()
                .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted)),
            Some(StatPrecision::Exact(true))
        );
        Ok(())
    }

    /// A zero step is sorted but not strictly sorted.
    #[test]
    fn zero_multiplier_is_sorted_not_strict() -> VortexResult<()> {
        let sequence = SequenceArray::typed_new(5i64, 0, Nullability::NonNullable, 4)?;

        assert_eq!(
            sequence
                .statistics()
                .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted)),
            Some(StatPrecision::Exact(true))
        );
        assert_eq!(
            sequence
                .statistics()
                .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted)),
            Some(StatPrecision::Exact(false))
        );
        Ok(())
    }

    /// A negative step is neither sorted nor strictly sorted.
    #[test]
    fn negative_multiplier_not_sorted() -> VortexResult<()> {
        let sequence = SequenceArray::typed_new(10i64, -1, Nullability::NonNullable, 4)?;

        assert_eq!(
            sequence
                .statistics()
                .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted)),
            Some(StatPrecision::Exact(false))
        );
        assert_eq!(
            sequence
                .statistics()
                .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted)),
            Some(StatPrecision::Exact(false))
        );
        Ok(())
    }

    /// An unsigned step above `i64::MAX` must still count as positive, i.e. it must not be
    /// reinterpreted as a negative signed value when the sortedness stats are derived.
    #[test]
    fn test_large_multiplier_sorted() -> VortexResult<()> {
        let step = (i64::MAX as u64) + 1;
        let sequence = SequenceArray::typed_new(0, step, Nullability::NonNullable, 2)?;

        let sorted = sequence
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted));
        let strict_sorted = sequence
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted));

        assert_eq!(sorted, Some(StatPrecision::Exact(true)));
        assert_eq!(strict_sorted, Some(StatPrecision::Exact(true)));

        Ok(())
    }
}