Skip to main content

vortex_sequence/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5use std::sync::Arc;
6
7use num_traits::cast::FromPrimitive;
8use vortex_array::ArrayRef;
9use vortex_array::DeserializeMetadata;
10use vortex_array::ExecutionCtx;
11use vortex_array::ExecutionResult;
12use vortex_array::Precision;
13use vortex_array::ProstMetadata;
14use vortex_array::SerializeMetadata;
15use vortex_array::buffer::BufferHandle;
16use vortex_array::dtype::DType;
17use vortex_array::dtype::NativePType;
18use vortex_array::dtype::Nullability;
19use vortex_array::dtype::Nullability::NonNullable;
20use vortex_array::dtype::PType;
21use vortex_array::expr::stats::Precision as StatPrecision;
22use vortex_array::expr::stats::Stat;
23use vortex_array::match_each_integer_ptype;
24use vortex_array::match_each_native_ptype;
25use vortex_array::match_each_pvalue;
26use vortex_array::scalar::PValue;
27use vortex_array::scalar::Scalar;
28use vortex_array::scalar::ScalarValue;
29use vortex_array::serde::ArrayChildren;
30use vortex_array::stats::ArrayStats;
31use vortex_array::stats::StatsSet;
32use vortex_array::stats::StatsSetRef;
33use vortex_array::validity::Validity;
34use vortex_array::vtable;
35use vortex_array::vtable::Array;
36use vortex_array::vtable::ArrayId;
37use vortex_array::vtable::OperationsVTable;
38use vortex_array::vtable::VTable;
39use vortex_array::vtable::ValidityVTable;
40use vortex_error::VortexExpect;
41use vortex_error::VortexResult;
42use vortex_error::vortex_bail;
43use vortex_error::vortex_ensure;
44use vortex_error::vortex_err;
45use vortex_error::vortex_panic;
46use vortex_session::VortexSession;
47
48use crate::compress::sequence_decompress;
49use crate::kernel::PARENT_KERNELS;
50use crate::rules::RULES;
51
// Invokes the `vtable!` macro (imported from `vortex_array`) for the `Sequence` encoding.
// NOTE(review): presumably this generates the glue between `Sequence`, `SequenceArray`
// and the vtable traits implemented below — confirm against the macro definition.
vtable!(Sequence);
53
54#[derive(Debug, Clone, Copy)]
55pub struct SequenceMetadata {
56    base: PValue,
57    multiplier: PValue,
58}
59
60#[derive(Clone, prost::Message)]
61pub struct ProstSequenceMetadata {
62    #[prost(message, tag = "1")]
63    base: Option<vortex_proto::scalar::ScalarValue>,
64    #[prost(message, tag = "2")]
65    multiplier: Option<vortex_proto::scalar::ScalarValue>,
66}
67
68/// Components of [`SequenceArray`].
69pub struct SequenceArrayParts {
70    pub base: PValue,
71    pub multiplier: PValue,
72    pub len: usize,
73    pub ptype: PType,
74    pub nullability: Nullability,
75}
76
77#[derive(Clone, Debug)]
78/// An array representing the equation `A[i] = base + i * multiplier`.
79pub struct SequenceArray {
80    base: PValue,
81    multiplier: PValue,
82    dtype: DType,
83    pub(crate) len: usize,
84    stats_set: ArrayStats,
85}
86
87impl SequenceArray {
88    pub fn try_new_typed<T: NativePType + Into<PValue>>(
89        base: T,
90        multiplier: T,
91        nullability: Nullability,
92        length: usize,
93    ) -> VortexResult<Self> {
94        Self::try_new(
95            base.into(),
96            multiplier.into(),
97            T::PTYPE,
98            nullability,
99            length,
100        )
101    }
102
103    /// Constructs a sequence array using two integer values (with the same ptype).
104    pub fn try_new(
105        base: PValue,
106        multiplier: PValue,
107        ptype: PType,
108        nullability: Nullability,
109        length: usize,
110    ) -> VortexResult<Self> {
111        if !ptype.is_int() {
112            vortex_bail!("only integer ptype are supported in SequenceArray currently")
113        }
114
115        Self::try_last(base, multiplier, ptype, length).map_err(|e| {
116            e.with_context(format!(
117                "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ",
118            ))
119        })?;
120
121        // SAFETY: we just validated that `ptype` is an integer and that the final
122        // element is representable via `try_last`.
123        Ok(unsafe { Self::new_unchecked(base, multiplier, ptype, nullability, length) })
124    }
125
126    /// Constructs a [`SequenceArray`] without validating that the `ptype` is an integer
127    /// type or that the final element is representable.
128    ///
129    /// # Safety
130    ///
131    /// The caller must ensure that:
132    /// - `ptype` is an integer type (i.e., `ptype.is_int()` returns `true`).
133    /// - `base + (length - 1) * multiplier` does not overflow the range of `ptype`.
134    ///
135    /// Violating the first invariant will cause a panic. Violating the second will
136    /// cause silent wraparound when materializing elements, producing incorrect values.
137    pub(crate) unsafe fn new_unchecked(
138        base: PValue,
139        multiplier: PValue,
140        ptype: PType,
141        nullability: Nullability,
142        length: usize,
143    ) -> Self {
144        let dtype = DType::Primitive(ptype, nullability);
145
146        // A sequence A[i] = base + i * multiplier is sorted iff multiplier >= 0,
147        // and strictly sorted iff multiplier > 0.
148
149        let (is_sorted, is_strict_sorted) = match_each_pvalue!(
150            multiplier,
151            uint: |v| { (true, v> 0) },
152            int: |v| { (v >= 0, v > 0) },
153            float: |_v| { unreachable!("float multiplier not supported") }
154        );
155
156        // SAFETY: we don't have duplicate stats
157        let stats_set = unsafe {
158            StatsSet::new_unchecked(vec![
159                (Stat::IsSorted, StatPrecision::Exact(is_sorted.into())),
160                (
161                    Stat::IsStrictSorted,
162                    StatPrecision::Exact(is_strict_sorted.into()),
163                ),
164            ])
165        };
166
167        Self {
168            base,
169            multiplier,
170            dtype,
171            len: length,
172            stats_set: ArrayStats::from(stats_set),
173        }
174    }
175
176    pub fn ptype(&self) -> PType {
177        self.dtype.as_ptype()
178    }
179
180    pub fn base(&self) -> PValue {
181        self.base
182    }
183
184    pub fn multiplier(&self) -> PValue {
185        self.multiplier
186    }
187
188    pub(crate) fn try_last(
189        base: PValue,
190        multiplier: PValue,
191        ptype: PType,
192        length: usize,
193    ) -> VortexResult<PValue> {
194        match_each_integer_ptype!(ptype, |P| {
195            let len_t = <P>::from_usize(length - 1)
196                .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
197
198            let base = base.cast::<P>()?;
199            let multiplier = multiplier.cast::<P>()?;
200            let last = len_t
201                .checked_mul(multiplier)
202                .and_then(|offset| offset.checked_add(base))
203                .ok_or_else(|| vortex_err!("last value computation overflows"))?;
204            Ok(PValue::from(last))
205        })
206    }
207
208    pub(crate) fn index_value(&self, idx: usize) -> PValue {
209        assert!(idx < self.len, "index_value({idx}): index out of bounds");
210
211        match_each_native_ptype!(self.ptype(), |P| {
212            let base = self.base.cast::<P>().vortex_expect("must be able to cast");
213            let multiplier = self
214                .multiplier
215                .cast::<P>()
216                .vortex_expect("must be able to cast");
217            let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
218
219            PValue::from(value)
220        })
221    }
222
223    /// Returns the validated final value of a sequence array
224    pub fn last(&self) -> PValue {
225        Self::try_last(self.base, self.multiplier, self.ptype(), self.len)
226            .vortex_expect("validated array")
227    }
228
229    pub fn into_parts(self) -> SequenceArrayParts {
230        SequenceArrayParts {
231            base: self.base,
232            multiplier: self.multiplier,
233            len: self.len,
234            ptype: self.dtype.as_ptype(),
235            nullability: self.dtype.nullability(),
236        }
237    }
238}
239
240impl VTable for Sequence {
241    type Array = SequenceArray;
242
243    type Metadata = SequenceMetadata;
244    type OperationsVTable = Self;
245    type ValidityVTable = Self;
246
247    fn vtable(_array: &Self::Array) -> &Self {
248        &Sequence
249    }
250
251    fn id(&self) -> ArrayId {
252        Self::ID
253    }
254
255    fn len(array: &SequenceArray) -> usize {
256        array.len
257    }
258
259    fn dtype(array: &SequenceArray) -> &DType {
260        &array.dtype
261    }
262
263    fn stats(array: &SequenceArray) -> StatsSetRef<'_> {
264        array.stats_set.to_ref(array.as_ref())
265    }
266
267    fn array_hash<H: std::hash::Hasher>(
268        array: &SequenceArray,
269        state: &mut H,
270        _precision: Precision,
271    ) {
272        array.base.hash(state);
273        array.multiplier.hash(state);
274        array.dtype.hash(state);
275        array.len.hash(state);
276    }
277
278    fn array_eq(array: &SequenceArray, other: &SequenceArray, _precision: Precision) -> bool {
279        array.base == other.base
280            && array.multiplier == other.multiplier
281            && array.dtype == other.dtype
282            && array.len == other.len
283    }
284
285    fn nbuffers(_array: &SequenceArray) -> usize {
286        0
287    }
288
289    fn buffer(_array: &SequenceArray, idx: usize) -> BufferHandle {
290        vortex_panic!("SequenceArray buffer index {idx} out of bounds")
291    }
292
293    fn buffer_name(_array: &SequenceArray, idx: usize) -> Option<String> {
294        vortex_panic!("SequenceArray buffer_name index {idx} out of bounds")
295    }
296
297    fn nchildren(_array: &SequenceArray) -> usize {
298        0
299    }
300
301    fn child(_array: &SequenceArray, idx: usize) -> ArrayRef {
302        vortex_panic!("SequenceArray child index {idx} out of bounds")
303    }
304
305    fn child_name(_array: &SequenceArray, idx: usize) -> String {
306        vortex_panic!("SequenceArray child_name index {idx} out of bounds")
307    }
308
309    fn metadata(array: &SequenceArray) -> VortexResult<Self::Metadata> {
310        Ok(SequenceMetadata {
311            base: array.base(),
312            multiplier: array.multiplier(),
313        })
314    }
315
316    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
317        let prost = ProstMetadata(ProstSequenceMetadata {
318            base: Some((&metadata.base).into()),
319            multiplier: Some((&metadata.multiplier).into()),
320        });
321
322        Ok(Some(prost.serialize()))
323    }
324
325    fn deserialize(
326        bytes: &[u8],
327        dtype: &DType,
328        _len: usize,
329        _buffers: &[BufferHandle],
330        session: &VortexSession,
331    ) -> VortexResult<Self::Metadata> {
332        let prost =
333            <ProstMetadata<ProstSequenceMetadata> as DeserializeMetadata>::deserialize(bytes)?;
334
335        let ptype = dtype.as_ptype();
336
337        // We go via Scalar to validate that the value is valid for the ptype.
338        let base = Scalar::from_proto_value(
339            prost
340                .base
341                .as_ref()
342                .ok_or_else(|| vortex_err!("base required"))?,
343            &DType::Primitive(ptype, NonNullable),
344            session,
345        )?
346        .as_primitive()
347        .pvalue()
348        .vortex_expect("sequence array base should be a non-nullable primitive");
349
350        let multiplier = Scalar::from_proto_value(
351            prost
352                .multiplier
353                .as_ref()
354                .ok_or_else(|| vortex_err!("multiplier required"))?,
355            &DType::Primitive(ptype, NonNullable),
356            session,
357        )?
358        .as_primitive()
359        .pvalue()
360        .vortex_expect("sequence array multiplier should be a non-nullable primitive");
361
362        Ok(SequenceMetadata { base, multiplier })
363    }
364
365    fn build(
366        dtype: &DType,
367        len: usize,
368        metadata: &Self::Metadata,
369        _buffers: &[BufferHandle],
370        _children: &dyn ArrayChildren,
371    ) -> VortexResult<SequenceArray> {
372        SequenceArray::try_new(
373            metadata.base,
374            metadata.multiplier,
375            dtype.as_ptype(),
376            dtype.nullability(),
377            len,
378        )
379    }
380
381    fn with_children(_array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
382        vortex_ensure!(
383            children.is_empty(),
384            "SequenceArray expects 0 children, got {}",
385            children.len()
386        );
387        Ok(())
388    }
389
390    fn execute(array: Arc<Array<Self>>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
391        sequence_decompress(&array).map(ExecutionResult::done)
392    }
393
394    fn execute_parent(
395        array: &Array<Self>,
396        parent: &ArrayRef,
397        child_idx: usize,
398        ctx: &mut ExecutionCtx,
399    ) -> VortexResult<Option<ArrayRef>> {
400        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
401    }
402
403    fn reduce_parent(
404        array: &Array<Self>,
405        parent: &ArrayRef,
406        child_idx: usize,
407    ) -> VortexResult<Option<ArrayRef>> {
408        RULES.evaluate(array, parent, child_idx)
409    }
410}
411
412impl OperationsVTable<Sequence> for Sequence {
413    fn scalar_at(
414        array: &SequenceArray,
415        index: usize,
416        _ctx: &mut ExecutionCtx,
417    ) -> VortexResult<Scalar> {
418        Scalar::try_new(
419            array.dtype().clone(),
420            Some(ScalarValue::Primitive(array.index_value(index))),
421        )
422    }
423}
424
425impl ValidityVTable<Sequence> for Sequence {
426    fn validity(_array: &SequenceArray) -> VortexResult<Validity> {
427        Ok(Validity::AllValid)
428    }
429}
430
/// Marker type for the sequence encoding; the vtable traits in this file are
/// implemented on it.
#[derive(Debug, Clone)]
pub struct Sequence;
433
434impl Sequence {
435    pub const ID: ArrayId = ArrayId::new_ref("vortex.sequence");
436}
437
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::expr::stats::Precision as StatPrecision;
    use vortex_array::expr::stats::Stat;
    use vortex_array::expr::stats::StatsProviderExt;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar::ScalarValue;
    use vortex_error::VortexResult;

    use crate::array::SequenceArray;

    /// A full sequence compares equal to the primitive array holding the same values.
    #[test]
    fn test_sequence_canonical() {
        let sequence = SequenceArray::try_new_typed(2i64, 3, Nullability::NonNullable, 4).unwrap();

        let expected = PrimitiveArray::from_iter((0..4i64).map(|step| 2 + step * 3));

        assert_arrays_eq!(sequence, expected);
    }

    /// Slicing a sequence yields the matching sub-range of values.
    #[test]
    fn test_sequence_slice_canonical() {
        let sequence = SequenceArray::try_new_typed(2i64, 3, Nullability::NonNullable, 4).unwrap();
        let sliced = sequence.slice(2..3).unwrap();

        let expected = PrimitiveArray::from_iter((2..3i64).map(|step| 2 + step * 3));

        assert_arrays_eq!(sliced, expected);
    }

    /// `scalar_at(2)` on `2 + 3 * i` is `8`.
    #[test]
    fn test_sequence_scalar_at() {
        let sequence = SequenceArray::try_new_typed(2i64, 3, Nullability::NonNullable, 4).unwrap();
        let actual = sequence.scalar_at(2).unwrap();

        let expected =
            Scalar::try_new(actual.dtype().clone(), Some(ScalarValue::from(8i64))).unwrap();

        assert_eq!(actual, expected)
    }

    /// Sequences whose last element stays within the `i8` range are accepted.
    #[test]
    fn test_sequence_min_max() {
        let reaches_min = SequenceArray::try_new_typed(-127i8, -1i8, Nullability::NonNullable, 2);
        assert!(reaches_min.is_ok());

        let stays_inside = SequenceArray::try_new_typed(126i8, -1i8, Nullability::NonNullable, 2);
        assert!(stays_inside.is_ok());
    }

    /// Sequences whose last element leaves the `i8` range are rejected.
    #[test]
    fn test_sequence_too_big() {
        let above_max = SequenceArray::try_new_typed(127i8, 1i8, Nullability::NonNullable, 2);
        assert!(above_max.is_err());

        let below_min = SequenceArray::try_new_typed(-128i8, -1i8, Nullability::NonNullable, 2);
        assert!(below_min.is_err());
    }

    #[test]
    fn positive_multiplier_is_strict_sorted() -> VortexResult<()> {
        let sequence = SequenceArray::try_new_typed(0i64, 3, Nullability::NonNullable, 4)?;

        let sorted = sequence
            .statistics()
            .with_typed_stats_set(|stats| stats.get_as::<bool>(Stat::IsSorted));
        let strict_sorted = sequence
            .statistics()
            .with_typed_stats_set(|stats| stats.get_as::<bool>(Stat::IsStrictSorted));

        assert_eq!(
            (sorted, strict_sorted),
            (
                Some(StatPrecision::Exact(true)),
                Some(StatPrecision::Exact(true))
            )
        );
        Ok(())
    }

    #[test]
    fn zero_multiplier_is_sorted_not_strict() -> VortexResult<()> {
        let sequence = SequenceArray::try_new_typed(5i64, 0, Nullability::NonNullable, 4)?;

        let sorted = sequence
            .statistics()
            .with_typed_stats_set(|stats| stats.get_as::<bool>(Stat::IsSorted));
        let strict_sorted = sequence
            .statistics()
            .with_typed_stats_set(|stats| stats.get_as::<bool>(Stat::IsStrictSorted));

        assert_eq!(
            (sorted, strict_sorted),
            (
                Some(StatPrecision::Exact(true)),
                Some(StatPrecision::Exact(false))
            )
        );
        Ok(())
    }

    #[test]
    fn negative_multiplier_not_sorted() -> VortexResult<()> {
        let sequence = SequenceArray::try_new_typed(10i64, -1, Nullability::NonNullable, 4)?;

        let sorted = sequence
            .statistics()
            .with_typed_stats_set(|stats| stats.get_as::<bool>(Stat::IsSorted));
        let strict_sorted = sequence
            .statistics()
            .with_typed_stats_set(|stats| stats.get_as::<bool>(Stat::IsStrictSorted));

        assert_eq!(
            (sorted, strict_sorted),
            (
                Some(StatPrecision::Exact(false)),
                Some(StatPrecision::Exact(false))
            )
        );
        Ok(())
    }

    // This is a regression test for an issue caught by the fuzzer, where SequenceArrays
    // with multiplier > i64::MAX could not be constructed.
    #[test]
    fn test_large_multiplier_sorted() -> VortexResult<()> {
        let large_multiplier = (i64::MAX as u64) + 1;
        let sequence =
            SequenceArray::try_new_typed(0, large_multiplier, Nullability::NonNullable, 2)?;

        let sorted = sequence
            .statistics()
            .with_typed_stats_set(|stats| stats.get_as::<bool>(Stat::IsSorted));
        let strict_sorted = sequence
            .statistics()
            .with_typed_stats_set(|stats| stats.get_as::<bool>(Stat::IsStrictSorted));

        assert_eq!(
            (sorted, strict_sorted),
            (
                Some(StatPrecision::Exact(true)),
                Some(StatPrecision::Exact(true))
            )
        );

        Ok(())
    }
}
566}