Skip to main content

vortex_sequence/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5
6use num_traits::cast::FromPrimitive;
7use vortex_array::ArrayRef;
8use vortex_array::DeserializeMetadata;
9use vortex_array::ExecutionCtx;
10use vortex_array::ExecutionStep;
11use vortex_array::Precision;
12use vortex_array::ProstMetadata;
13use vortex_array::SerializeMetadata;
14use vortex_array::buffer::BufferHandle;
15use vortex_array::dtype::DType;
16use vortex_array::dtype::NativePType;
17use vortex_array::dtype::Nullability;
18use vortex_array::dtype::Nullability::NonNullable;
19use vortex_array::dtype::PType;
20use vortex_array::expr::stats::Precision as StatPrecision;
21use vortex_array::expr::stats::Stat;
22use vortex_array::match_each_integer_ptype;
23use vortex_array::match_each_native_ptype;
24use vortex_array::match_each_pvalue;
25use vortex_array::scalar::PValue;
26use vortex_array::scalar::Scalar;
27use vortex_array::scalar::ScalarValue;
28use vortex_array::serde::ArrayChildren;
29use vortex_array::stats::ArrayStats;
30use vortex_array::stats::StatsSet;
31use vortex_array::stats::StatsSetRef;
32use vortex_array::validity::Validity;
33use vortex_array::vtable;
34use vortex_array::vtable::ArrayId;
35use vortex_array::vtable::OperationsVTable;
36use vortex_array::vtable::VTable;
37use vortex_array::vtable::ValidityVTable;
38use vortex_error::VortexExpect;
39use vortex_error::VortexResult;
40use vortex_error::vortex_bail;
41use vortex_error::vortex_ensure;
42use vortex_error::vortex_err;
43use vortex_error::vortex_panic;
44use vortex_session::VortexSession;
45
46use crate::compress::sequence_decompress;
47use crate::kernel::PARENT_KERNELS;
48use crate::rules::RULES;
49
// Invokes the `vtable!` macro imported from `vortex_array` for the `Sequence` encoding.
// NOTE(review): presumably this generates the glue between `Sequence`, `SequenceArray` and
// the `VTable` impl below — confirm against the macro definition.
vtable!(Sequence);
51
/// In-memory metadata of a [`SequenceArray`]: the two values that define the sequence
/// `A[i] = base + i * multiplier`.
///
/// It is converted to and from [`ProstSequenceMetadata`] for serialization.
#[derive(Debug, Clone, Copy)]
pub struct SequenceMetadata {
    /// The first element of the sequence (`A[0]`).
    base: PValue,
    /// The difference between two consecutive elements.
    multiplier: PValue,
}
57
/// Protobuf (prost) representation of [`SequenceMetadata`].
///
/// Both fields are optional on the wire; deserialization in this file rejects
/// messages where either one is missing.
#[derive(Clone, prost::Message)]
pub struct ProstSequenceMetadata {
    /// Serialized `base` value of the sequence.
    #[prost(message, tag = "1")]
    base: Option<vortex_proto::scalar::ScalarValue>,
    /// Serialized `multiplier` value of the sequence.
    #[prost(message, tag = "2")]
    multiplier: Option<vortex_proto::scalar::ScalarValue>,
}
65
/// Components of [`SequenceArray`], as returned by [`SequenceArray::into_parts`].
pub struct SequenceArrayParts {
    /// The first element of the sequence (`A[0]`).
    pub base: PValue,
    /// The difference between two consecutive elements.
    pub multiplier: PValue,
    /// The number of elements in the sequence.
    pub len: usize,
    /// The primitive type of the elements.
    pub ptype: PType,
    /// The nullability recorded in the array's dtype.
    pub nullability: Nullability,
}
74
/// An array representing the equation `A[i] = base + i * multiplier`.
///
/// No element data is stored: each value is computed on demand from `base`,
/// `multiplier` and the element index.
#[derive(Clone, Debug)]
pub struct SequenceArray {
    /// The first element of the sequence (`A[0]`).
    base: PValue,
    /// The difference between two consecutive elements.
    multiplier: PValue,
    /// The primitive dtype (ptype and nullability) of the array.
    dtype: DType,
    /// The number of elements in the sequence.
    pub(crate) len: usize,
    /// Statistics for this array; seeded with the sortedness stats at construction.
    stats_set: ArrayStats,
}
84
85impl SequenceArray {
86    pub fn try_new_typed<T: NativePType + Into<PValue>>(
87        base: T,
88        multiplier: T,
89        nullability: Nullability,
90        length: usize,
91    ) -> VortexResult<Self> {
92        Self::try_new(
93            base.into(),
94            multiplier.into(),
95            T::PTYPE,
96            nullability,
97            length,
98        )
99    }
100
101    /// Constructs a sequence array using two integer values (with the same ptype).
102    pub fn try_new(
103        base: PValue,
104        multiplier: PValue,
105        ptype: PType,
106        nullability: Nullability,
107        length: usize,
108    ) -> VortexResult<Self> {
109        if !ptype.is_int() {
110            vortex_bail!("only integer ptype are supported in SequenceArray currently")
111        }
112
113        Self::try_last(base, multiplier, ptype, length).map_err(|e| {
114            e.with_context(format!(
115                "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ",
116            ))
117        })?;
118
119        // SAFETY: we just validated that `ptype` is an integer and that the final
120        // element is representable via `try_last`.
121        Ok(unsafe { Self::new_unchecked(base, multiplier, ptype, nullability, length) })
122    }
123
124    /// Constructs a [`SequenceArray`] without validating that the `ptype` is an integer
125    /// type or that the final element is representable.
126    ///
127    /// # Safety
128    ///
129    /// The caller must ensure that:
130    /// - `ptype` is an integer type (i.e., `ptype.is_int()` returns `true`).
131    /// - `base + (length - 1) * multiplier` does not overflow the range of `ptype`.
132    ///
133    /// Violating the first invariant will cause a panic. Violating the second will
134    /// cause silent wraparound when materializing elements, producing incorrect values.
135    pub(crate) unsafe fn new_unchecked(
136        base: PValue,
137        multiplier: PValue,
138        ptype: PType,
139        nullability: Nullability,
140        length: usize,
141    ) -> Self {
142        let dtype = DType::Primitive(ptype, nullability);
143
144        // A sequence A[i] = base + i * multiplier is sorted iff multiplier >= 0,
145        // and strictly sorted iff multiplier > 0.
146
147        let (is_sorted, is_strict_sorted) = match_each_pvalue!(
148            multiplier,
149            uint: |v| { (true, v> 0) },
150            int: |v| { (v >= 0, v > 0) },
151            float: |_v| { unreachable!("float multiplier not supported") }
152        );
153
154        // SAFETY: we don't have duplicate stats
155        let stats_set = unsafe {
156            StatsSet::new_unchecked(vec![
157                (Stat::IsSorted, StatPrecision::Exact(is_sorted.into())),
158                (
159                    Stat::IsStrictSorted,
160                    StatPrecision::Exact(is_strict_sorted.into()),
161                ),
162            ])
163        };
164
165        Self {
166            base,
167            multiplier,
168            dtype,
169            len: length,
170            stats_set: ArrayStats::from(stats_set),
171        }
172    }
173
174    pub fn ptype(&self) -> PType {
175        self.dtype.as_ptype()
176    }
177
178    pub fn base(&self) -> PValue {
179        self.base
180    }
181
182    pub fn multiplier(&self) -> PValue {
183        self.multiplier
184    }
185
186    pub(crate) fn try_last(
187        base: PValue,
188        multiplier: PValue,
189        ptype: PType,
190        length: usize,
191    ) -> VortexResult<PValue> {
192        match_each_integer_ptype!(ptype, |P| {
193            let len_t = <P>::from_usize(length - 1)
194                .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
195
196            let base = base.cast::<P>()?;
197            let multiplier = multiplier.cast::<P>()?;
198            let last = len_t
199                .checked_mul(multiplier)
200                .and_then(|offset| offset.checked_add(base))
201                .ok_or_else(|| vortex_err!("last value computation overflows"))?;
202            Ok(PValue::from(last))
203        })
204    }
205
206    pub(crate) fn index_value(&self, idx: usize) -> PValue {
207        assert!(idx < self.len, "index_value({idx}): index out of bounds");
208
209        match_each_native_ptype!(self.ptype(), |P| {
210            let base = self.base.cast::<P>().vortex_expect("must be able to cast");
211            let multiplier = self
212                .multiplier
213                .cast::<P>()
214                .vortex_expect("must be able to cast");
215            let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
216
217            PValue::from(value)
218        })
219    }
220
221    /// Returns the validated final value of a sequence array
222    pub fn last(&self) -> PValue {
223        Self::try_last(self.base, self.multiplier, self.ptype(), self.len)
224            .vortex_expect("validated array")
225    }
226
227    pub fn into_parts(self) -> SequenceArrayParts {
228        SequenceArrayParts {
229            base: self.base,
230            multiplier: self.multiplier,
231            len: self.len,
232            ptype: self.dtype.as_ptype(),
233            nullability: self.dtype.nullability(),
234        }
235    }
236}
237
238impl VTable for Sequence {
239    type Array = SequenceArray;
240
241    type Metadata = SequenceMetadata;
242    type OperationsVTable = Self;
243    type ValidityVTable = Self;
244
245    fn id(_array: &Self::Array) -> ArrayId {
246        Self::ID
247    }
248
249    fn len(array: &SequenceArray) -> usize {
250        array.len
251    }
252
253    fn dtype(array: &SequenceArray) -> &DType {
254        &array.dtype
255    }
256
257    fn stats(array: &SequenceArray) -> StatsSetRef<'_> {
258        array.stats_set.to_ref(array.as_ref())
259    }
260
261    fn array_hash<H: std::hash::Hasher>(
262        array: &SequenceArray,
263        state: &mut H,
264        _precision: Precision,
265    ) {
266        array.base.hash(state);
267        array.multiplier.hash(state);
268        array.dtype.hash(state);
269        array.len.hash(state);
270    }
271
272    fn array_eq(array: &SequenceArray, other: &SequenceArray, _precision: Precision) -> bool {
273        array.base == other.base
274            && array.multiplier == other.multiplier
275            && array.dtype == other.dtype
276            && array.len == other.len
277    }
278
279    fn nbuffers(_array: &SequenceArray) -> usize {
280        0
281    }
282
283    fn buffer(_array: &SequenceArray, idx: usize) -> BufferHandle {
284        vortex_panic!("SequenceArray buffer index {idx} out of bounds")
285    }
286
287    fn buffer_name(_array: &SequenceArray, idx: usize) -> Option<String> {
288        vortex_panic!("SequenceArray buffer_name index {idx} out of bounds")
289    }
290
291    fn nchildren(_array: &SequenceArray) -> usize {
292        0
293    }
294
295    fn child(_array: &SequenceArray, idx: usize) -> ArrayRef {
296        vortex_panic!("SequenceArray child index {idx} out of bounds")
297    }
298
299    fn child_name(_array: &SequenceArray, idx: usize) -> String {
300        vortex_panic!("SequenceArray child_name index {idx} out of bounds")
301    }
302
303    fn metadata(array: &SequenceArray) -> VortexResult<Self::Metadata> {
304        Ok(SequenceMetadata {
305            base: array.base(),
306            multiplier: array.multiplier(),
307        })
308    }
309
310    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
311        let prost = ProstMetadata(ProstSequenceMetadata {
312            base: Some((&metadata.base).into()),
313            multiplier: Some((&metadata.multiplier).into()),
314        });
315
316        Ok(Some(prost.serialize()))
317    }
318
319    fn deserialize(
320        bytes: &[u8],
321        dtype: &DType,
322        _len: usize,
323        _buffers: &[BufferHandle],
324        session: &VortexSession,
325    ) -> VortexResult<Self::Metadata> {
326        let prost =
327            <ProstMetadata<ProstSequenceMetadata> as DeserializeMetadata>::deserialize(bytes)?;
328
329        let ptype = dtype.as_ptype();
330
331        // We go via Scalar to validate that the value is valid for the ptype.
332        let base = Scalar::from_proto_value(
333            prost
334                .base
335                .as_ref()
336                .ok_or_else(|| vortex_err!("base required"))?,
337            &DType::Primitive(ptype, NonNullable),
338            session,
339        )?
340        .as_primitive()
341        .pvalue()
342        .vortex_expect("sequence array base should be a non-nullable primitive");
343
344        let multiplier = Scalar::from_proto_value(
345            prost
346                .multiplier
347                .as_ref()
348                .ok_or_else(|| vortex_err!("multiplier required"))?,
349            &DType::Primitive(ptype, NonNullable),
350            session,
351        )?
352        .as_primitive()
353        .pvalue()
354        .vortex_expect("sequence array multiplier should be a non-nullable primitive");
355
356        Ok(SequenceMetadata { base, multiplier })
357    }
358
359    fn build(
360        dtype: &DType,
361        len: usize,
362        metadata: &Self::Metadata,
363        _buffers: &[BufferHandle],
364        _children: &dyn ArrayChildren,
365    ) -> VortexResult<SequenceArray> {
366        SequenceArray::try_new(
367            metadata.base,
368            metadata.multiplier,
369            dtype.as_ptype(),
370            dtype.nullability(),
371            len,
372        )
373    }
374
375    fn with_children(_array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
376        vortex_ensure!(
377            children.is_empty(),
378            "SequenceArray expects 0 children, got {}",
379            children.len()
380        );
381        Ok(())
382    }
383
384    fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionStep> {
385        sequence_decompress(array).map(ExecutionStep::Done)
386    }
387
388    fn execute_parent(
389        array: &Self::Array,
390        parent: &ArrayRef,
391        child_idx: usize,
392        ctx: &mut ExecutionCtx,
393    ) -> VortexResult<Option<ArrayRef>> {
394        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
395    }
396
397    fn reduce_parent(
398        array: &SequenceArray,
399        parent: &ArrayRef,
400        child_idx: usize,
401    ) -> VortexResult<Option<ArrayRef>> {
402        RULES.evaluate(array, parent, child_idx)
403    }
404}
405
406impl OperationsVTable<Sequence> for Sequence {
407    fn scalar_at(array: &SequenceArray, index: usize) -> VortexResult<Scalar> {
408        Scalar::try_new(
409            array.dtype().clone(),
410            Some(ScalarValue::Primitive(array.index_value(index))),
411        )
412    }
413}
414
impl ValidityVTable<Sequence> for Sequence {
    /// A sequence array never contains nulls, so every element is reported as valid
    /// regardless of the nullability recorded in the dtype.
    fn validity(_array: &SequenceArray) -> VortexResult<Validity> {
        Ok(Validity::AllValid)
    }
}
420
/// Marker type for the sequence encoding; the vtable traits for [`SequenceArray`] are
/// implemented on it.
#[derive(Debug)]
pub struct Sequence;
423
impl Sequence {
    /// The identifier of the sequence encoding.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.sequence");
}
427
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::expr::stats::Precision as StatPrecision;
    use vortex_array::expr::stats::Stat;
    use vortex_array::expr::stats::StatsProviderExt;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar::ScalarValue;
    use vortex_error::VortexResult;

    use crate::array::SequenceArray;

    /// A sequence must compare equal to the primitive array holding the same values.
    #[test]
    fn test_sequence_canonical() {
        let sequence = SequenceArray::try_new_typed(2i64, 3, Nullability::NonNullable, 4).unwrap();

        let expected = PrimitiveArray::from_iter((0..4).map(|i| 2i64 + i * 3));

        assert_arrays_eq!(sequence, expected);
    }

    /// Slicing a sequence must yield the matching window of values.
    #[test]
    fn test_sequence_slice_canonical() {
        let sequence = SequenceArray::try_new_typed(2i64, 3, Nullability::NonNullable, 4).unwrap();
        let sliced = sequence.slice(2..3).unwrap();

        let expected = PrimitiveArray::from_iter((2..3).map(|i| 2i64 + i * 3));

        assert_arrays_eq!(sliced, expected);
    }

    /// `scalar_at(2)` on `2 + 3 * i` must be 8.
    #[test]
    fn test_sequence_scalar_at() {
        let sequence = SequenceArray::try_new_typed(2i64, 3, Nullability::NonNullable, 4).unwrap();
        let scalar = sequence.scalar_at(2).unwrap();

        let expected =
            Scalar::try_new(scalar.dtype().clone(), Some(ScalarValue::from(8i64))).unwrap();
        assert_eq!(scalar, expected)
    }

    /// Sequences whose last value stays within the `i8` range are accepted.
    #[test]
    fn test_sequence_min_max() {
        for (base, multiplier) in [(-127i8, -1i8), (126i8, -1i8)] {
            assert!(
                SequenceArray::try_new_typed(base, multiplier, Nullability::NonNullable, 2)
                    .is_ok()
            );
        }
    }

    /// Sequences whose last value leaves the `i8` range are rejected.
    #[test]
    fn test_sequence_too_big() {
        for (base, multiplier) in [(127i8, 1i8), (-128i8, -1i8)] {
            assert!(
                SequenceArray::try_new_typed(base, multiplier, Nullability::NonNullable, 2)
                    .is_err()
            );
        }
    }

    /// A positive multiplier yields a sorted and strictly sorted array.
    #[test]
    fn positive_multiplier_is_strict_sorted() -> VortexResult<()> {
        let arr = SequenceArray::try_new_typed(0i64, 3, Nullability::NonNullable, 4)?;

        assert_eq!(
            arr.statistics()
                .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted)),
            Some(StatPrecision::Exact(true))
        );
        assert_eq!(
            arr.statistics()
                .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted)),
            Some(StatPrecision::Exact(true))
        );
        Ok(())
    }

    /// A zero multiplier yields a sorted, but not strictly sorted, array.
    #[test]
    fn zero_multiplier_is_sorted_not_strict() -> VortexResult<()> {
        let arr = SequenceArray::try_new_typed(5i64, 0, Nullability::NonNullable, 4)?;

        assert_eq!(
            arr.statistics()
                .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted)),
            Some(StatPrecision::Exact(true))
        );
        assert_eq!(
            arr.statistics()
                .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted)),
            Some(StatPrecision::Exact(false))
        );
        Ok(())
    }

    /// A negative multiplier yields an array that is neither sorted nor strictly sorted.
    #[test]
    fn negative_multiplier_not_sorted() -> VortexResult<()> {
        let arr = SequenceArray::try_new_typed(10i64, -1, Nullability::NonNullable, 4)?;

        assert_eq!(
            arr.statistics()
                .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted)),
            Some(StatPrecision::Exact(false))
        );
        assert_eq!(
            arr.statistics()
                .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted)),
            Some(StatPrecision::Exact(false))
        );
        Ok(())
    }

    // This is a regression test for an issue caught by the fuzzer, where SequenceArrays
    // with multiplier > i64::MAX were unable to be constructed.
    #[test]
    fn test_large_multiplier_sorted() -> VortexResult<()> {
        let large_multiplier = (i64::MAX as u64) + 1;
        let arr = SequenceArray::try_new_typed(0, large_multiplier, Nullability::NonNullable, 2)?;

        let sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted));
        let strict_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted));

        assert_eq!(sorted, Some(StatPrecision::Exact(true)));
        assert_eq!(strict_sorted, Some(StatPrecision::Exact(true)));

        Ok(())
    }
}
556}