1use std::hash::Hash;
5
6use num_traits::cast::FromPrimitive;
7use vortex_array::ArrayRef;
8use vortex_array::DeserializeMetadata;
9use vortex_array::ExecutionCtx;
10use vortex_array::IntoArray;
11use vortex_array::Precision;
12use vortex_array::ProstMetadata;
13use vortex_array::SerializeMetadata;
14use vortex_array::arrays::PrimitiveArray;
15use vortex_array::buffer::BufferHandle;
16use vortex_array::dtype::DType;
17use vortex_array::dtype::NativePType;
18use vortex_array::dtype::Nullability;
19use vortex_array::dtype::Nullability::NonNullable;
20use vortex_array::dtype::PType;
21use vortex_array::expr::stats::Precision as StatPrecision;
22use vortex_array::expr::stats::Stat;
23use vortex_array::match_each_integer_ptype;
24use vortex_array::match_each_native_ptype;
25use vortex_array::match_each_pvalue;
26use vortex_array::scalar::PValue;
27use vortex_array::scalar::Scalar;
28use vortex_array::scalar::ScalarValue;
29use vortex_array::serde::ArrayChildren;
30use vortex_array::stats::ArrayStats;
31use vortex_array::stats::StatsSet;
32use vortex_array::stats::StatsSetRef;
33use vortex_array::validity::Validity;
34use vortex_array::vtable;
35use vortex_array::vtable::ArrayId;
36use vortex_array::vtable::OperationsVTable;
37use vortex_array::vtable::VTable;
38use vortex_array::vtable::ValidityVTable;
39use vortex_buffer::BufferMut;
40use vortex_error::VortexExpect;
41use vortex_error::VortexResult;
42use vortex_error::vortex_bail;
43use vortex_error::vortex_ensure;
44use vortex_error::vortex_err;
45use vortex_error::vortex_panic;
46use vortex_session::VortexSession;
47
48use crate::kernel::PARENT_KERNELS;
49use crate::rules::RULES;
50
// Generates the shared array glue for the `Sequence` encoding via `vortex_array::vtable!`.
// NOTE(review): the exact items emitted (e.g. the trait wiring that makes `array.as_ref()` and
// `.statistics()` available on `SequenceArray`) are defined in vortex_array — confirm there.
vtable!(Sequence);
52
/// In-memory metadata of a [`SequenceArray`].
///
/// Together with the dtype and length (which `build` receives separately) this is enough to
/// reconstruct the array.
#[derive(Debug, Clone, Copy)]
pub struct SequenceMetadata {
    /// First element of the sequence.
    base: PValue,
    /// Difference between consecutive elements.
    multiplier: PValue,
}
58
/// Protobuf wire form of [`SequenceMetadata`].
///
/// Both fields are modelled as optional by prost but are required when deserializing.
#[derive(Clone, prost::Message)]
pub struct ProstSequenceMetadata {
    /// Serialized `SequenceMetadata::base`.
    #[prost(message, tag = "1")]
    base: Option<vortex_proto::scalar::ScalarValue>,
    /// Serialized `SequenceMetadata::multiplier`.
    #[prost(message, tag = "2")]
    multiplier: Option<vortex_proto::scalar::ScalarValue>,
}
66
/// Owned components of a [`SequenceArray`], as returned by [`SequenceArray::into_parts`].
pub struct SequenceArrayParts {
    /// First element of the sequence.
    pub base: PValue,
    /// Difference between consecutive elements.
    pub multiplier: PValue,
    /// Number of elements.
    pub len: usize,
    /// Physical type of the elements.
    pub ptype: PType,
    /// Nullability of the array's dtype.
    pub nullability: Nullability,
}
75
/// An array whose `i`-th element is `base + i * multiplier`, represented by its parameters
/// rather than by materialized values.
#[derive(Clone, Debug)]
pub struct SequenceArray {
    /// First element of the sequence.
    base: PValue,
    /// Difference between consecutive elements.
    multiplier: PValue,
    /// Always a `DType::Primitive` (see `new_unchecked`).
    dtype: DType,
    /// Number of elements.
    pub(crate) len: usize,
    /// Statistics cache, seeded with sortedness statistics at construction.
    stats_set: ArrayStats,
}
85
86impl SequenceArray {
87 pub fn try_new_typed<T: NativePType + Into<PValue>>(
88 base: T,
89 multiplier: T,
90 nullability: Nullability,
91 length: usize,
92 ) -> VortexResult<Self> {
93 Self::try_new(
94 base.into(),
95 multiplier.into(),
96 T::PTYPE,
97 nullability,
98 length,
99 )
100 }
101
102 pub fn try_new(
104 base: PValue,
105 multiplier: PValue,
106 ptype: PType,
107 nullability: Nullability,
108 length: usize,
109 ) -> VortexResult<Self> {
110 if !ptype.is_int() {
111 vortex_bail!("only integer ptype are supported in SequenceArray currently")
112 }
113
114 Self::try_last(base, multiplier, ptype, length).map_err(|e| {
115 e.with_context(format!(
116 "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ",
117 ))
118 })?;
119
120 Ok(unsafe { Self::new_unchecked(base, multiplier, ptype, nullability, length) })
123 }
124
125 pub(crate) unsafe fn new_unchecked(
137 base: PValue,
138 multiplier: PValue,
139 ptype: PType,
140 nullability: Nullability,
141 length: usize,
142 ) -> Self {
143 let dtype = DType::Primitive(ptype, nullability);
144
145 let (is_sorted, is_strict_sorted) = match_each_pvalue!(
149 multiplier,
150 uint: |v| { (true, v> 0) },
151 int: |v| { (v >= 0, v > 0) },
152 float: |_v| { unreachable!("float multiplier not supported") }
153 );
154
155 let stats_set = unsafe {
157 StatsSet::new_unchecked(vec![
158 (Stat::IsSorted, StatPrecision::Exact(is_sorted.into())),
159 (
160 Stat::IsStrictSorted,
161 StatPrecision::Exact(is_strict_sorted.into()),
162 ),
163 ])
164 };
165
166 Self {
167 base,
168 multiplier,
169 dtype,
170 len: length,
171 stats_set: ArrayStats::from(stats_set),
172 }
173 }
174
175 pub fn ptype(&self) -> PType {
176 self.dtype.as_ptype()
177 }
178
179 pub fn base(&self) -> PValue {
180 self.base
181 }
182
183 pub fn multiplier(&self) -> PValue {
184 self.multiplier
185 }
186
187 pub(crate) fn try_last(
188 base: PValue,
189 multiplier: PValue,
190 ptype: PType,
191 length: usize,
192 ) -> VortexResult<PValue> {
193 match_each_integer_ptype!(ptype, |P| {
194 let len_t = <P>::from_usize(length - 1)
195 .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
196
197 let base = base.cast::<P>()?;
198 let multiplier = multiplier.cast::<P>()?;
199 let last = len_t
200 .checked_mul(multiplier)
201 .and_then(|offset| offset.checked_add(base))
202 .ok_or_else(|| vortex_err!("last value computation overflows"))?;
203 Ok(PValue::from(last))
204 })
205 }
206
207 pub(crate) fn index_value(&self, idx: usize) -> PValue {
208 assert!(idx < self.len, "index_value({idx}): index out of bounds");
209
210 match_each_native_ptype!(self.ptype(), |P| {
211 let base = self.base.cast::<P>().vortex_expect("must be able to cast");
212 let multiplier = self
213 .multiplier
214 .cast::<P>()
215 .vortex_expect("must be able to cast");
216 let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
217
218 PValue::from(value)
219 })
220 }
221
222 pub fn last(&self) -> PValue {
224 Self::try_last(self.base, self.multiplier, self.ptype(), self.len)
225 .vortex_expect("validated array")
226 }
227
228 pub fn into_parts(self) -> SequenceArrayParts {
229 SequenceArrayParts {
230 base: self.base,
231 multiplier: self.multiplier,
232 len: self.len,
233 ptype: self.dtype.as_ptype(),
234 nullability: self.dtype.nullability(),
235 }
236 }
237}
238
239impl VTable for SequenceVTable {
240 type Array = SequenceArray;
241
242 type Metadata = SequenceMetadata;
243 type OperationsVTable = Self;
244 type ValidityVTable = Self;
245
246 fn id(_array: &Self::Array) -> ArrayId {
247 Self::ID
248 }
249
250 fn len(array: &SequenceArray) -> usize {
251 array.len
252 }
253
254 fn dtype(array: &SequenceArray) -> &DType {
255 &array.dtype
256 }
257
258 fn stats(array: &SequenceArray) -> StatsSetRef<'_> {
259 array.stats_set.to_ref(array.as_ref())
260 }
261
262 fn array_hash<H: std::hash::Hasher>(
263 array: &SequenceArray,
264 state: &mut H,
265 _precision: Precision,
266 ) {
267 array.base.hash(state);
268 array.multiplier.hash(state);
269 array.dtype.hash(state);
270 array.len.hash(state);
271 }
272
273 fn array_eq(array: &SequenceArray, other: &SequenceArray, _precision: Precision) -> bool {
274 array.base == other.base
275 && array.multiplier == other.multiplier
276 && array.dtype == other.dtype
277 && array.len == other.len
278 }
279
280 fn nbuffers(_array: &SequenceArray) -> usize {
281 0
282 }
283
284 fn buffer(_array: &SequenceArray, idx: usize) -> BufferHandle {
285 vortex_panic!("SequenceArray buffer index {idx} out of bounds")
286 }
287
288 fn buffer_name(_array: &SequenceArray, idx: usize) -> Option<String> {
289 vortex_panic!("SequenceArray buffer_name index {idx} out of bounds")
290 }
291
292 fn nchildren(_array: &SequenceArray) -> usize {
293 0
294 }
295
296 fn child(_array: &SequenceArray, idx: usize) -> ArrayRef {
297 vortex_panic!("SequenceArray child index {idx} out of bounds")
298 }
299
300 fn child_name(_array: &SequenceArray, idx: usize) -> String {
301 vortex_panic!("SequenceArray child_name index {idx} out of bounds")
302 }
303
304 fn metadata(array: &SequenceArray) -> VortexResult<Self::Metadata> {
305 Ok(SequenceMetadata {
306 base: array.base(),
307 multiplier: array.multiplier(),
308 })
309 }
310
311 fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
312 let prost = ProstMetadata(ProstSequenceMetadata {
313 base: Some((&metadata.base).into()),
314 multiplier: Some((&metadata.multiplier).into()),
315 });
316
317 Ok(Some(prost.serialize()))
318 }
319
320 fn deserialize(
321 bytes: &[u8],
322 dtype: &DType,
323 _len: usize,
324 _buffers: &[BufferHandle],
325 _session: &VortexSession,
326 ) -> VortexResult<Self::Metadata> {
327 let prost =
328 <ProstMetadata<ProstSequenceMetadata> as DeserializeMetadata>::deserialize(bytes)?;
329
330 let ptype = dtype.as_ptype();
331
332 let base = Scalar::from_proto_value(
334 prost
335 .base
336 .as_ref()
337 .ok_or_else(|| vortex_err!("base required"))?,
338 &DType::Primitive(ptype, NonNullable),
339 )?
340 .as_primitive()
341 .pvalue()
342 .vortex_expect("sequence array base should be a non-nullable primitive");
343
344 let multiplier = Scalar::from_proto_value(
345 prost
346 .multiplier
347 .as_ref()
348 .ok_or_else(|| vortex_err!("multiplier required"))?,
349 &DType::Primitive(ptype, NonNullable),
350 )?
351 .as_primitive()
352 .pvalue()
353 .vortex_expect("sequence array multiplier should be a non-nullable primitive");
354
355 Ok(SequenceMetadata { base, multiplier })
356 }
357
358 fn build(
359 dtype: &DType,
360 len: usize,
361 metadata: &Self::Metadata,
362 _buffers: &[BufferHandle],
363 _children: &dyn ArrayChildren,
364 ) -> VortexResult<SequenceArray> {
365 SequenceArray::try_new(
366 metadata.base,
367 metadata.multiplier,
368 dtype.as_ptype(),
369 dtype.nullability(),
370 len,
371 )
372 }
373
374 fn with_children(_array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
375 vortex_ensure!(
376 children.is_empty(),
377 "SequenceArray expects 0 children, got {}",
378 children.len()
379 );
380 Ok(())
381 }
382
383 fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
384 let prim = match_each_native_ptype!(array.ptype(), |P| {
385 let base = array.base().cast::<P>()?;
386 let multiplier = array.multiplier().cast::<P>()?;
387 let values = BufferMut::from_iter(
388 (0..array.len())
389 .map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
390 );
391 PrimitiveArray::new(values, array.dtype.nullability().into())
392 });
393
394 Ok(prim.into_array())
395 }
396
397 fn execute_parent(
398 array: &Self::Array,
399 parent: &ArrayRef,
400 child_idx: usize,
401 ctx: &mut ExecutionCtx,
402 ) -> VortexResult<Option<ArrayRef>> {
403 PARENT_KERNELS.execute(array, parent, child_idx, ctx)
404 }
405
406 fn reduce_parent(
407 array: &SequenceArray,
408 parent: &ArrayRef,
409 child_idx: usize,
410 ) -> VortexResult<Option<ArrayRef>> {
411 RULES.evaluate(array, parent, child_idx)
412 }
413}
414
415impl OperationsVTable<SequenceVTable> for SequenceVTable {
416 fn scalar_at(array: &SequenceArray, index: usize) -> VortexResult<Scalar> {
417 Scalar::try_new(
418 array.dtype().clone(),
419 Some(ScalarValue::Primitive(array.index_value(index))),
420 )
421 }
422}
423
424impl ValidityVTable<SequenceVTable> for SequenceVTable {
425 fn validity(_array: &SequenceArray) -> VortexResult<Validity> {
426 Ok(Validity::AllValid)
427 }
428}
429
/// Encoding vtable for [`SequenceArray`], identified by [`SequenceVTable::ID`].
#[derive(Debug)]
pub struct SequenceVTable;
432
impl SequenceVTable {
    /// Identifier of the sequence encoding, returned by `VTable::id` for every sequence array.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.sequence");
}
436
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::expr::stats::Precision as StatPrecision;
    use vortex_array::expr::stats::Stat;
    use vortex_array::expr::stats::StatsProviderExt;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar::ScalarValue;
    use vortex_error::VortexResult;

    use crate::array::SequenceArray;

    /// The sequence `2, 5, 8, 11` equals the same values stored in a primitive array.
    #[test]
    fn test_sequence_canonical() {
        let arr = SequenceArray::try_new_typed(2i64, 3, Nullability::NonNullable, 4).unwrap();

        let canon = PrimitiveArray::from_iter((0..4).map(|i| 2i64 + i * 3));

        assert_arrays_eq!(arr, canon);
    }

    /// Slicing keeps the original indexing: element 2 of `2, 5, 8, 11` is `8`.
    #[test]
    fn test_sequence_slice_canonical() {
        let arr = SequenceArray::try_new_typed(2i64, 3, Nullability::NonNullable, 4)
            .unwrap()
            .slice(2..3)
            .unwrap();

        let canon = PrimitiveArray::from_iter((2..3).map(|i| 2i64 + i * 3));

        assert_arrays_eq!(arr, canon);
    }

    /// `scalar_at(2)` on `2, 5, 8, 11` yields `8` with the array's dtype.
    #[test]
    fn test_sequence_scalar_at() {
        let scalar = SequenceArray::try_new_typed(2i64, 3, Nullability::NonNullable, 4)
            .unwrap()
            .scalar_at(2)
            .unwrap();

        assert_eq!(
            scalar,
            Scalar::try_new(scalar.dtype().clone(), Some(ScalarValue::from(8i64))).unwrap()
        )
    }

    /// Sequences that stay within `i8`, including one ending exactly at `i8::MIN`, are accepted.
    #[test]
    fn test_sequence_min_max() {
        assert!(SequenceArray::try_new_typed(-127i8, -1i8, Nullability::NonNullable, 2).is_ok());
        assert!(SequenceArray::try_new_typed(126i8, -1i8, Nullability::NonNullable, 2).is_ok());
    }

    /// Sequences whose last element would overflow `i8` in either direction are rejected.
    #[test]
    fn test_sequence_too_big() {
        assert!(SequenceArray::try_new_typed(127i8, 1i8, Nullability::NonNullable, 2).is_err());
        assert!(SequenceArray::try_new_typed(-128i8, -1i8, Nullability::NonNullable, 2).is_err());
    }

    /// A positive step yields exact `IsSorted == true` and `IsStrictSorted == true`.
    #[test]
    fn positive_multiplier_is_strict_sorted() -> VortexResult<()> {
        let arr = SequenceArray::try_new_typed(0i64, 3, Nullability::NonNullable, 4)?;

        let is_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted));
        assert_eq!(is_sorted, Some(StatPrecision::Exact(true)));

        let is_strict_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted));
        assert_eq!(is_strict_sorted, Some(StatPrecision::Exact(true)));
        Ok(())
    }

    /// A zero step over several elements is sorted but not strictly sorted.
    #[test]
    fn zero_multiplier_is_sorted_not_strict() -> VortexResult<()> {
        let arr = SequenceArray::try_new_typed(5i64, 0, Nullability::NonNullable, 4)?;

        let is_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted));
        assert_eq!(is_sorted, Some(StatPrecision::Exact(true)));

        let is_strict_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted));
        assert_eq!(is_strict_sorted, Some(StatPrecision::Exact(false)));
        Ok(())
    }

    /// A negative step over several elements is neither sorted nor strictly sorted.
    #[test]
    fn negative_multiplier_not_sorted() -> VortexResult<()> {
        let arr = SequenceArray::try_new_typed(10i64, -1, Nullability::NonNullable, 4)?;

        let is_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted));
        assert_eq!(is_sorted, Some(StatPrecision::Exact(false)));

        let is_strict_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted));
        assert_eq!(is_strict_sorted, Some(StatPrecision::Exact(false)));
        Ok(())
    }

    /// A `u64` step above `i64::MAX` must still count as positive (it is unsigned), so the
    /// two-element sequence `0, 2^63` is strictly sorted.
    #[test]
    fn test_large_multiplier_sorted() -> VortexResult<()> {
        let large_multiplier = (i64::MAX as u64) + 1;
        let arr = SequenceArray::try_new_typed(0, large_multiplier, Nullability::NonNullable, 2)?;

        let is_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted));

        let is_strict_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted));

        assert_eq!(is_sorted, Some(StatPrecision::Exact(true)));
        assert_eq!(is_strict_sorted, Some(StatPrecision::Exact(true)));

        Ok(())
    }
}