1use std::hash::Hash;
5
6use num_traits::cast::FromPrimitive;
7use vortex_array::ArrayRef;
8use vortex_array::DeserializeMetadata;
9use vortex_array::ExecutionCtx;
10use vortex_array::ExecutionStep;
11use vortex_array::Precision;
12use vortex_array::ProstMetadata;
13use vortex_array::SerializeMetadata;
14use vortex_array::buffer::BufferHandle;
15use vortex_array::dtype::DType;
16use vortex_array::dtype::NativePType;
17use vortex_array::dtype::Nullability;
18use vortex_array::dtype::Nullability::NonNullable;
19use vortex_array::dtype::PType;
20use vortex_array::expr::stats::Precision as StatPrecision;
21use vortex_array::expr::stats::Stat;
22use vortex_array::match_each_integer_ptype;
23use vortex_array::match_each_native_ptype;
24use vortex_array::match_each_pvalue;
25use vortex_array::scalar::PValue;
26use vortex_array::scalar::Scalar;
27use vortex_array::scalar::ScalarValue;
28use vortex_array::serde::ArrayChildren;
29use vortex_array::stats::ArrayStats;
30use vortex_array::stats::StatsSet;
31use vortex_array::stats::StatsSetRef;
32use vortex_array::validity::Validity;
33use vortex_array::vtable;
34use vortex_array::vtable::ArrayId;
35use vortex_array::vtable::OperationsVTable;
36use vortex_array::vtable::VTable;
37use vortex_array::vtable::ValidityVTable;
38use vortex_error::VortexExpect;
39use vortex_error::VortexResult;
40use vortex_error::vortex_bail;
41use vortex_error::vortex_ensure;
42use vortex_error::vortex_err;
43use vortex_error::vortex_panic;
44use vortex_session::VortexSession;
45
46use crate::compress::sequence_decompress;
47use crate::kernel::PARENT_KERNELS;
48use crate::rules::RULES;
49
// NOTE(review): presumably generates the array-level glue that ties the `Sequence` encoding
// to `SequenceArray` (trait impls such as `as_ref()`/`statistics()` used below) — confirm
// against the `vortex_array::vtable!` macro definition.
vtable!(Sequence);
51
/// In-memory metadata for a [`SequenceArray`]: the two scalars that, together with the
/// dtype and length handed to `build`, fully describe the array.
#[derive(Debug, Clone, Copy)]
pub struct SequenceMetadata {
    // First element of the sequence.
    base: PValue,
    // Difference between consecutive elements.
    multiplier: PValue,
}
57
/// Protobuf wire form of [`SequenceMetadata`].
///
/// Both fields are `Option` because protobuf message fields are optional on the wire;
/// deserialization turns a missing field into an error.
#[derive(Clone, prost::Message)]
pub struct ProstSequenceMetadata {
    // Serialized `SequenceMetadata::base`.
    #[prost(message, tag = "1")]
    base: Option<vortex_proto::scalar::ScalarValue>,
    // Serialized `SequenceMetadata::multiplier`.
    #[prost(message, tag = "2")]
    multiplier: Option<vortex_proto::scalar::ScalarValue>,
}
65
/// The owned components of a [`SequenceArray`], as returned by `SequenceArray::into_parts`.
pub struct SequenceArrayParts {
    /// First element of the sequence.
    pub base: PValue,
    /// Difference between consecutive elements.
    pub multiplier: PValue,
    /// Number of elements.
    pub len: usize,
    /// Element primitive type.
    pub ptype: PType,
    /// Nullability of the array's dtype.
    pub nullability: Nullability,
}
74
/// A buffer-less integer array whose element at index `i` is `base + i * multiplier`.
#[derive(Clone, Debug)]
pub struct SequenceArray {
    // First element of the sequence.
    base: PValue,
    // Difference between consecutive elements.
    multiplier: PValue,
    // Always `DType::Primitive(ptype, nullability)`.
    dtype: DType,
    pub(crate) len: usize,
    // Pre-seeded with exact IsSorted / IsStrictSorted stats at construction.
    stats_set: ArrayStats,
}
84
85impl SequenceArray {
86 pub fn try_new_typed<T: NativePType + Into<PValue>>(
87 base: T,
88 multiplier: T,
89 nullability: Nullability,
90 length: usize,
91 ) -> VortexResult<Self> {
92 Self::try_new(
93 base.into(),
94 multiplier.into(),
95 T::PTYPE,
96 nullability,
97 length,
98 )
99 }
100
101 pub fn try_new(
103 base: PValue,
104 multiplier: PValue,
105 ptype: PType,
106 nullability: Nullability,
107 length: usize,
108 ) -> VortexResult<Self> {
109 if !ptype.is_int() {
110 vortex_bail!("only integer ptype are supported in SequenceArray currently")
111 }
112
113 Self::try_last(base, multiplier, ptype, length).map_err(|e| {
114 e.with_context(format!(
115 "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ",
116 ))
117 })?;
118
119 Ok(unsafe { Self::new_unchecked(base, multiplier, ptype, nullability, length) })
122 }
123
124 pub(crate) unsafe fn new_unchecked(
136 base: PValue,
137 multiplier: PValue,
138 ptype: PType,
139 nullability: Nullability,
140 length: usize,
141 ) -> Self {
142 let dtype = DType::Primitive(ptype, nullability);
143
144 let (is_sorted, is_strict_sorted) = match_each_pvalue!(
148 multiplier,
149 uint: |v| { (true, v> 0) },
150 int: |v| { (v >= 0, v > 0) },
151 float: |_v| { unreachable!("float multiplier not supported") }
152 );
153
154 let stats_set = unsafe {
156 StatsSet::new_unchecked(vec![
157 (Stat::IsSorted, StatPrecision::Exact(is_sorted.into())),
158 (
159 Stat::IsStrictSorted,
160 StatPrecision::Exact(is_strict_sorted.into()),
161 ),
162 ])
163 };
164
165 Self {
166 base,
167 multiplier,
168 dtype,
169 len: length,
170 stats_set: ArrayStats::from(stats_set),
171 }
172 }
173
174 pub fn ptype(&self) -> PType {
175 self.dtype.as_ptype()
176 }
177
178 pub fn base(&self) -> PValue {
179 self.base
180 }
181
182 pub fn multiplier(&self) -> PValue {
183 self.multiplier
184 }
185
186 pub(crate) fn try_last(
187 base: PValue,
188 multiplier: PValue,
189 ptype: PType,
190 length: usize,
191 ) -> VortexResult<PValue> {
192 match_each_integer_ptype!(ptype, |P| {
193 let len_t = <P>::from_usize(length - 1)
194 .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
195
196 let base = base.cast::<P>()?;
197 let multiplier = multiplier.cast::<P>()?;
198 let last = len_t
199 .checked_mul(multiplier)
200 .and_then(|offset| offset.checked_add(base))
201 .ok_or_else(|| vortex_err!("last value computation overflows"))?;
202 Ok(PValue::from(last))
203 })
204 }
205
206 pub(crate) fn index_value(&self, idx: usize) -> PValue {
207 assert!(idx < self.len, "index_value({idx}): index out of bounds");
208
209 match_each_native_ptype!(self.ptype(), |P| {
210 let base = self.base.cast::<P>().vortex_expect("must be able to cast");
211 let multiplier = self
212 .multiplier
213 .cast::<P>()
214 .vortex_expect("must be able to cast");
215 let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
216
217 PValue::from(value)
218 })
219 }
220
221 pub fn last(&self) -> PValue {
223 Self::try_last(self.base, self.multiplier, self.ptype(), self.len)
224 .vortex_expect("validated array")
225 }
226
227 pub fn into_parts(self) -> SequenceArrayParts {
228 SequenceArrayParts {
229 base: self.base,
230 multiplier: self.multiplier,
231 len: self.len,
232 ptype: self.dtype.as_ptype(),
233 nullability: self.dtype.nullability(),
234 }
235 }
236}
237
238impl VTable for Sequence {
239 type Array = SequenceArray;
240
241 type Metadata = SequenceMetadata;
242 type OperationsVTable = Self;
243 type ValidityVTable = Self;
244
245 fn id(_array: &Self::Array) -> ArrayId {
246 Self::ID
247 }
248
249 fn len(array: &SequenceArray) -> usize {
250 array.len
251 }
252
253 fn dtype(array: &SequenceArray) -> &DType {
254 &array.dtype
255 }
256
257 fn stats(array: &SequenceArray) -> StatsSetRef<'_> {
258 array.stats_set.to_ref(array.as_ref())
259 }
260
261 fn array_hash<H: std::hash::Hasher>(
262 array: &SequenceArray,
263 state: &mut H,
264 _precision: Precision,
265 ) {
266 array.base.hash(state);
267 array.multiplier.hash(state);
268 array.dtype.hash(state);
269 array.len.hash(state);
270 }
271
272 fn array_eq(array: &SequenceArray, other: &SequenceArray, _precision: Precision) -> bool {
273 array.base == other.base
274 && array.multiplier == other.multiplier
275 && array.dtype == other.dtype
276 && array.len == other.len
277 }
278
279 fn nbuffers(_array: &SequenceArray) -> usize {
280 0
281 }
282
283 fn buffer(_array: &SequenceArray, idx: usize) -> BufferHandle {
284 vortex_panic!("SequenceArray buffer index {idx} out of bounds")
285 }
286
287 fn buffer_name(_array: &SequenceArray, idx: usize) -> Option<String> {
288 vortex_panic!("SequenceArray buffer_name index {idx} out of bounds")
289 }
290
291 fn nchildren(_array: &SequenceArray) -> usize {
292 0
293 }
294
295 fn child(_array: &SequenceArray, idx: usize) -> ArrayRef {
296 vortex_panic!("SequenceArray child index {idx} out of bounds")
297 }
298
299 fn child_name(_array: &SequenceArray, idx: usize) -> String {
300 vortex_panic!("SequenceArray child_name index {idx} out of bounds")
301 }
302
303 fn metadata(array: &SequenceArray) -> VortexResult<Self::Metadata> {
304 Ok(SequenceMetadata {
305 base: array.base(),
306 multiplier: array.multiplier(),
307 })
308 }
309
310 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
311 let prost = ProstMetadata(ProstSequenceMetadata {
312 base: Some((&metadata.base).into()),
313 multiplier: Some((&metadata.multiplier).into()),
314 });
315
316 Ok(Some(prost.serialize()))
317 }
318
319 fn deserialize(
320 bytes: &[u8],
321 dtype: &DType,
322 _len: usize,
323 _buffers: &[BufferHandle],
324 session: &VortexSession,
325 ) -> VortexResult<Self::Metadata> {
326 let prost =
327 <ProstMetadata<ProstSequenceMetadata> as DeserializeMetadata>::deserialize(bytes)?;
328
329 let ptype = dtype.as_ptype();
330
331 let base = Scalar::from_proto_value(
333 prost
334 .base
335 .as_ref()
336 .ok_or_else(|| vortex_err!("base required"))?,
337 &DType::Primitive(ptype, NonNullable),
338 session,
339 )?
340 .as_primitive()
341 .pvalue()
342 .vortex_expect("sequence array base should be a non-nullable primitive");
343
344 let multiplier = Scalar::from_proto_value(
345 prost
346 .multiplier
347 .as_ref()
348 .ok_or_else(|| vortex_err!("multiplier required"))?,
349 &DType::Primitive(ptype, NonNullable),
350 session,
351 )?
352 .as_primitive()
353 .pvalue()
354 .vortex_expect("sequence array multiplier should be a non-nullable primitive");
355
356 Ok(SequenceMetadata { base, multiplier })
357 }
358
359 fn build(
360 dtype: &DType,
361 len: usize,
362 metadata: &Self::Metadata,
363 _buffers: &[BufferHandle],
364 _children: &dyn ArrayChildren,
365 ) -> VortexResult<SequenceArray> {
366 SequenceArray::try_new(
367 metadata.base,
368 metadata.multiplier,
369 dtype.as_ptype(),
370 dtype.nullability(),
371 len,
372 )
373 }
374
375 fn with_children(_array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
376 vortex_ensure!(
377 children.is_empty(),
378 "SequenceArray expects 0 children, got {}",
379 children.len()
380 );
381 Ok(())
382 }
383
384 fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionStep> {
385 sequence_decompress(array).map(ExecutionStep::Done)
386 }
387
388 fn execute_parent(
389 array: &Self::Array,
390 parent: &ArrayRef,
391 child_idx: usize,
392 ctx: &mut ExecutionCtx,
393 ) -> VortexResult<Option<ArrayRef>> {
394 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
395 }
396
397 fn reduce_parent(
398 array: &SequenceArray,
399 parent: &ArrayRef,
400 child_idx: usize,
401 ) -> VortexResult<Option<ArrayRef>> {
402 RULES.evaluate(array, parent, child_idx)
403 }
404}
405
406impl OperationsVTable<Sequence> for Sequence {
407 fn scalar_at(array: &SequenceArray, index: usize) -> VortexResult<Scalar> {
408 Scalar::try_new(
409 array.dtype().clone(),
410 Some(ScalarValue::Primitive(array.index_value(index))),
411 )
412 }
413}
414
415impl ValidityVTable<Sequence> for Sequence {
416 fn validity(_array: &SequenceArray) -> VortexResult<Validity> {
417 Ok(Validity::AllValid)
418 }
419}
420
/// Encoding marker for [`SequenceArray`]; its `VTable`, `OperationsVTable` and
/// `ValidityVTable` impls define the encoding's behavior.
#[derive(Debug)]
pub struct Sequence;
423
impl Sequence {
    /// Stable encoding identifier, returned by `VTable::id` for every `SequenceArray`.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.sequence");
}
427
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::expr::stats::Precision as StatPrecision;
    use vortex_array::expr::stats::Stat;
    use vortex_array::expr::stats::StatsProviderExt;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar::ScalarValue;
    use vortex_error::VortexResult;

    use crate::array::SequenceArray;

    /// Builds a non-nullable `i64` sequence `base, base + multiplier, ...` of `len` elements.
    fn i64_sequence(base: i64, multiplier: i64, len: usize) -> VortexResult<SequenceArray> {
        SequenceArray::try_new_typed(base, multiplier, Nullability::NonNullable, len)
    }

    /// Reads the `(IsSorted, IsStrictSorted)` statistics recorded on `arr`.
    fn sortedness(
        arr: &SequenceArray,
    ) -> (Option<StatPrecision<bool>>, Option<StatPrecision<bool>>) {
        let sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted));
        let strictly_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted));
        (sorted, strictly_sorted)
    }

    #[test]
    fn test_sequence_canonical() {
        let seq = i64_sequence(2, 3, 4).unwrap();
        let expected = PrimitiveArray::from_iter((0..4).map(|i| 2i64 + i * 3));
        assert_arrays_eq!(seq, expected);
    }

    #[test]
    fn test_sequence_slice_canonical() {
        let sliced = i64_sequence(2, 3, 4).unwrap().slice(2..3).unwrap();
        let expected = PrimitiveArray::from_iter((2..3).map(|i| 2i64 + i * 3));
        assert_arrays_eq!(sliced, expected);
    }

    #[test]
    fn test_sequence_scalar_at() {
        let scalar = i64_sequence(2, 3, 4).unwrap().scalar_at(2).unwrap();
        let expected =
            Scalar::try_new(scalar.dtype().clone(), Some(ScalarValue::from(8i64))).unwrap();
        assert_eq!(scalar, expected)
    }

    #[test]
    fn test_sequence_min_max() {
        assert!(SequenceArray::try_new_typed(-127i8, -1i8, Nullability::NonNullable, 2).is_ok());
        assert!(SequenceArray::try_new_typed(126i8, -1i8, Nullability::NonNullable, 2).is_ok());
    }

    #[test]
    fn test_sequence_too_big() {
        assert!(SequenceArray::try_new_typed(127i8, 1i8, Nullability::NonNullable, 2).is_err());
        assert!(SequenceArray::try_new_typed(-128i8, -1i8, Nullability::NonNullable, 2).is_err());
    }

    #[test]
    fn positive_multiplier_is_strict_sorted() -> VortexResult<()> {
        let (sorted, strictly_sorted) = sortedness(&i64_sequence(0, 3, 4)?);
        assert_eq!(sorted, Some(StatPrecision::Exact(true)));
        assert_eq!(strictly_sorted, Some(StatPrecision::Exact(true)));
        Ok(())
    }

    #[test]
    fn zero_multiplier_is_sorted_not_strict() -> VortexResult<()> {
        let (sorted, strictly_sorted) = sortedness(&i64_sequence(5, 0, 4)?);
        assert_eq!(sorted, Some(StatPrecision::Exact(true)));
        assert_eq!(strictly_sorted, Some(StatPrecision::Exact(false)));
        Ok(())
    }

    #[test]
    fn negative_multiplier_not_sorted() -> VortexResult<()> {
        let (sorted, strictly_sorted) = sortedness(&i64_sequence(10, -1, 4)?);
        assert_eq!(sorted, Some(StatPrecision::Exact(false)));
        assert_eq!(strictly_sorted, Some(StatPrecision::Exact(false)));
        Ok(())
    }

    #[test]
    fn test_large_multiplier_sorted() -> VortexResult<()> {
        // A u64 multiplier above i64::MAX must still be classified as positive.
        let large_multiplier = (i64::MAX as u64) + 1;
        let seq = SequenceArray::try_new_typed(0, large_multiplier, Nullability::NonNullable, 2)?;

        let (sorted, strictly_sorted) = sortedness(&seq);
        assert_eq!(sorted, Some(StatPrecision::Exact(true)));
        assert_eq!(strictly_sorted, Some(StatPrecision::Exact(true)));

        Ok(())
    }
}