Skip to main content

vortex_sequence/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5use std::sync::Arc;
6
7use num_traits::cast::FromPrimitive;
8use vortex_array::ArrayRef;
9use vortex_array::DeserializeMetadata;
10use vortex_array::ExecutionCtx;
11use vortex_array::ExecutionResult;
12use vortex_array::Precision;
13use vortex_array::ProstMetadata;
14use vortex_array::SerializeMetadata;
15use vortex_array::buffer::BufferHandle;
16use vortex_array::dtype::DType;
17use vortex_array::dtype::NativePType;
18use vortex_array::dtype::Nullability;
19use vortex_array::dtype::Nullability::NonNullable;
20use vortex_array::dtype::PType;
21use vortex_array::expr::stats::Precision as StatPrecision;
22use vortex_array::expr::stats::Stat;
23use vortex_array::match_each_integer_ptype;
24use vortex_array::match_each_native_ptype;
25use vortex_array::match_each_pvalue;
26use vortex_array::scalar::PValue;
27use vortex_array::scalar::Scalar;
28use vortex_array::scalar::ScalarValue;
29use vortex_array::serde::ArrayChildren;
30use vortex_array::stats::ArrayStats;
31use vortex_array::stats::StatsSet;
32use vortex_array::stats::StatsSetRef;
33use vortex_array::validity::Validity;
34use vortex_array::vtable;
35use vortex_array::vtable::ArrayId;
36use vortex_array::vtable::OperationsVTable;
37use vortex_array::vtable::VTable;
38use vortex_array::vtable::ValidityVTable;
39use vortex_error::VortexExpect;
40use vortex_error::VortexResult;
41use vortex_error::vortex_bail;
42use vortex_error::vortex_ensure;
43use vortex_error::vortex_err;
44use vortex_error::vortex_panic;
45use vortex_session::VortexSession;
46
47use crate::compress::sequence_decompress;
48use crate::kernel::PARENT_KERNELS;
49use crate::rules::RULES;
50
// Invokes the `vtable!` macro (from `vortex_array`) for the `Sequence` encoding.
// NOTE(review): presumably this generates the glue tying `SequenceArray` to the
// `VTable` impl for `Sequence` in this file — confirm against `vortex_array::vtable!`.
vtable!(Sequence);
52
/// In-memory metadata for a [`SequenceArray`]: the two values that, together with the
/// dtype and length, fully define its contents.
#[derive(Debug, Clone, Copy)]
pub struct SequenceMetadata {
    /// Value of the first element (`A[0]`).
    base: PValue,
    /// Step added for each subsequent element.
    multiplier: PValue,
}
58
/// Protobuf wire form of [`SequenceMetadata`].
///
/// Both fields are `Option` at the protobuf level, but deserialization rejects a
/// message in which either is missing.
#[derive(Clone, prost::Message)]
pub struct ProstSequenceMetadata {
    /// Serialized `base` value; decoded against the array's primitive ptype.
    #[prost(message, tag = "1")]
    base: Option<vortex_proto::scalar::ScalarValue>,
    /// Serialized `multiplier` value; decoded against the array's primitive ptype.
    #[prost(message, tag = "2")]
    multiplier: Option<vortex_proto::scalar::ScalarValue>,
}
66
/// Components of [`SequenceArray`], as returned by `SequenceArray::into_parts`.
pub struct SequenceArrayParts {
    /// Value of the first element.
    pub base: PValue,
    /// Step between consecutive elements.
    pub multiplier: PValue,
    /// Number of elements.
    pub len: usize,
    /// Primitive type of the elements.
    pub ptype: PType,
    /// Nullability of the array's dtype.
    pub nullability: Nullability,
}
75
/// An array representing the equation `A[i] = base + i * multiplier`.
///
/// No element data is stored: every value is derived from `base`, `multiplier`
/// and the element's index.
#[derive(Clone, Debug)]
pub struct SequenceArray {
    /// Value of the element at index 0.
    base: PValue,
    /// Difference between consecutive elements.
    multiplier: PValue,
    /// Always `DType::Primitive(ptype, nullability)` (built in `new_unchecked`).
    dtype: DType,
    /// Number of elements.
    pub(crate) len: usize,
    /// Statistics cache, seeded with exact `IsSorted` / `IsStrictSorted` at construction.
    stats_set: ArrayStats,
}
85
86impl SequenceArray {
87    pub fn try_new_typed<T: NativePType + Into<PValue>>(
88        base: T,
89        multiplier: T,
90        nullability: Nullability,
91        length: usize,
92    ) -> VortexResult<Self> {
93        Self::try_new(
94            base.into(),
95            multiplier.into(),
96            T::PTYPE,
97            nullability,
98            length,
99        )
100    }
101
102    /// Constructs a sequence array using two integer values (with the same ptype).
103    pub fn try_new(
104        base: PValue,
105        multiplier: PValue,
106        ptype: PType,
107        nullability: Nullability,
108        length: usize,
109    ) -> VortexResult<Self> {
110        if !ptype.is_int() {
111            vortex_bail!("only integer ptype are supported in SequenceArray currently")
112        }
113
114        Self::try_last(base, multiplier, ptype, length).map_err(|e| {
115            e.with_context(format!(
116                "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ",
117            ))
118        })?;
119
120        // SAFETY: we just validated that `ptype` is an integer and that the final
121        // element is representable via `try_last`.
122        Ok(unsafe { Self::new_unchecked(base, multiplier, ptype, nullability, length) })
123    }
124
125    /// Constructs a [`SequenceArray`] without validating that the `ptype` is an integer
126    /// type or that the final element is representable.
127    ///
128    /// # Safety
129    ///
130    /// The caller must ensure that:
131    /// - `ptype` is an integer type (i.e., `ptype.is_int()` returns `true`).
132    /// - `base + (length - 1) * multiplier` does not overflow the range of `ptype`.
133    ///
134    /// Violating the first invariant will cause a panic. Violating the second will
135    /// cause silent wraparound when materializing elements, producing incorrect values.
136    pub(crate) unsafe fn new_unchecked(
137        base: PValue,
138        multiplier: PValue,
139        ptype: PType,
140        nullability: Nullability,
141        length: usize,
142    ) -> Self {
143        let dtype = DType::Primitive(ptype, nullability);
144
145        // A sequence A[i] = base + i * multiplier is sorted iff multiplier >= 0,
146        // and strictly sorted iff multiplier > 0.
147
148        let (is_sorted, is_strict_sorted) = match_each_pvalue!(
149            multiplier,
150            uint: |v| { (true, v> 0) },
151            int: |v| { (v >= 0, v > 0) },
152            float: |_v| { unreachable!("float multiplier not supported") }
153        );
154
155        // SAFETY: we don't have duplicate stats
156        let stats_set = unsafe {
157            StatsSet::new_unchecked(vec![
158                (Stat::IsSorted, StatPrecision::Exact(is_sorted.into())),
159                (
160                    Stat::IsStrictSorted,
161                    StatPrecision::Exact(is_strict_sorted.into()),
162                ),
163            ])
164        };
165
166        Self {
167            base,
168            multiplier,
169            dtype,
170            len: length,
171            stats_set: ArrayStats::from(stats_set),
172        }
173    }
174
175    pub fn ptype(&self) -> PType {
176        self.dtype.as_ptype()
177    }
178
179    pub fn base(&self) -> PValue {
180        self.base
181    }
182
183    pub fn multiplier(&self) -> PValue {
184        self.multiplier
185    }
186
187    pub(crate) fn try_last(
188        base: PValue,
189        multiplier: PValue,
190        ptype: PType,
191        length: usize,
192    ) -> VortexResult<PValue> {
193        match_each_integer_ptype!(ptype, |P| {
194            let len_t = <P>::from_usize(length - 1)
195                .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
196
197            let base = base.cast::<P>()?;
198            let multiplier = multiplier.cast::<P>()?;
199            let last = len_t
200                .checked_mul(multiplier)
201                .and_then(|offset| offset.checked_add(base))
202                .ok_or_else(|| vortex_err!("last value computation overflows"))?;
203            Ok(PValue::from(last))
204        })
205    }
206
207    pub(crate) fn index_value(&self, idx: usize) -> PValue {
208        assert!(idx < self.len, "index_value({idx}): index out of bounds");
209
210        match_each_native_ptype!(self.ptype(), |P| {
211            let base = self.base.cast::<P>().vortex_expect("must be able to cast");
212            let multiplier = self
213                .multiplier
214                .cast::<P>()
215                .vortex_expect("must be able to cast");
216            let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
217
218            PValue::from(value)
219        })
220    }
221
222    /// Returns the validated final value of a sequence array
223    pub fn last(&self) -> PValue {
224        Self::try_last(self.base, self.multiplier, self.ptype(), self.len)
225            .vortex_expect("validated array")
226    }
227
228    pub fn into_parts(self) -> SequenceArrayParts {
229        SequenceArrayParts {
230            base: self.base,
231            multiplier: self.multiplier,
232            len: self.len,
233            ptype: self.dtype.as_ptype(),
234            nullability: self.dtype.nullability(),
235        }
236    }
237}
238
239impl VTable for Sequence {
240    type Array = SequenceArray;
241
242    type Metadata = SequenceMetadata;
243    type OperationsVTable = Self;
244    type ValidityVTable = Self;
245
246    fn vtable(_array: &Self::Array) -> &Self {
247        &Sequence
248    }
249
250    fn id(&self) -> ArrayId {
251        Self::ID
252    }
253
254    fn len(array: &SequenceArray) -> usize {
255        array.len
256    }
257
258    fn dtype(array: &SequenceArray) -> &DType {
259        &array.dtype
260    }
261
262    fn stats(array: &SequenceArray) -> StatsSetRef<'_> {
263        array.stats_set.to_ref(array.as_ref())
264    }
265
266    fn array_hash<H: std::hash::Hasher>(
267        array: &SequenceArray,
268        state: &mut H,
269        _precision: Precision,
270    ) {
271        array.base.hash(state);
272        array.multiplier.hash(state);
273        array.dtype.hash(state);
274        array.len.hash(state);
275    }
276
277    fn array_eq(array: &SequenceArray, other: &SequenceArray, _precision: Precision) -> bool {
278        array.base == other.base
279            && array.multiplier == other.multiplier
280            && array.dtype == other.dtype
281            && array.len == other.len
282    }
283
284    fn nbuffers(_array: &SequenceArray) -> usize {
285        0
286    }
287
288    fn buffer(_array: &SequenceArray, idx: usize) -> BufferHandle {
289        vortex_panic!("SequenceArray buffer index {idx} out of bounds")
290    }
291
292    fn buffer_name(_array: &SequenceArray, idx: usize) -> Option<String> {
293        vortex_panic!("SequenceArray buffer_name index {idx} out of bounds")
294    }
295
296    fn nchildren(_array: &SequenceArray) -> usize {
297        0
298    }
299
300    fn child(_array: &SequenceArray, idx: usize) -> ArrayRef {
301        vortex_panic!("SequenceArray child index {idx} out of bounds")
302    }
303
304    fn child_name(_array: &SequenceArray, idx: usize) -> String {
305        vortex_panic!("SequenceArray child_name index {idx} out of bounds")
306    }
307
308    fn metadata(array: &SequenceArray) -> VortexResult<Self::Metadata> {
309        Ok(SequenceMetadata {
310            base: array.base(),
311            multiplier: array.multiplier(),
312        })
313    }
314
315    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
316        let prost = ProstMetadata(ProstSequenceMetadata {
317            base: Some((&metadata.base).into()),
318            multiplier: Some((&metadata.multiplier).into()),
319        });
320
321        Ok(Some(prost.serialize()))
322    }
323
324    fn deserialize(
325        bytes: &[u8],
326        dtype: &DType,
327        _len: usize,
328        _buffers: &[BufferHandle],
329        session: &VortexSession,
330    ) -> VortexResult<Self::Metadata> {
331        let prost =
332            <ProstMetadata<ProstSequenceMetadata> as DeserializeMetadata>::deserialize(bytes)?;
333
334        let ptype = dtype.as_ptype();
335
336        // We go via Scalar to validate that the value is valid for the ptype.
337        let base = Scalar::from_proto_value(
338            prost
339                .base
340                .as_ref()
341                .ok_or_else(|| vortex_err!("base required"))?,
342            &DType::Primitive(ptype, NonNullable),
343            session,
344        )?
345        .as_primitive()
346        .pvalue()
347        .vortex_expect("sequence array base should be a non-nullable primitive");
348
349        let multiplier = Scalar::from_proto_value(
350            prost
351                .multiplier
352                .as_ref()
353                .ok_or_else(|| vortex_err!("multiplier required"))?,
354            &DType::Primitive(ptype, NonNullable),
355            session,
356        )?
357        .as_primitive()
358        .pvalue()
359        .vortex_expect("sequence array multiplier should be a non-nullable primitive");
360
361        Ok(SequenceMetadata { base, multiplier })
362    }
363
364    fn build(
365        dtype: &DType,
366        len: usize,
367        metadata: &Self::Metadata,
368        _buffers: &[BufferHandle],
369        _children: &dyn ArrayChildren,
370    ) -> VortexResult<SequenceArray> {
371        SequenceArray::try_new(
372            metadata.base,
373            metadata.multiplier,
374            dtype.as_ptype(),
375            dtype.nullability(),
376            len,
377        )
378    }
379
380    fn with_children(_array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
381        vortex_ensure!(
382            children.is_empty(),
383            "SequenceArray expects 0 children, got {}",
384            children.len()
385        );
386        Ok(())
387    }
388
389    fn execute(array: Arc<Self::Array>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
390        sequence_decompress(&array).map(ExecutionResult::done)
391    }
392
393    fn execute_parent(
394        array: &Self::Array,
395        parent: &ArrayRef,
396        child_idx: usize,
397        ctx: &mut ExecutionCtx,
398    ) -> VortexResult<Option<ArrayRef>> {
399        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
400    }
401
402    fn reduce_parent(
403        array: &SequenceArray,
404        parent: &ArrayRef,
405        child_idx: usize,
406    ) -> VortexResult<Option<ArrayRef>> {
407        RULES.evaluate(array, parent, child_idx)
408    }
409}
410
411impl OperationsVTable<Sequence> for Sequence {
412    fn scalar_at(array: &SequenceArray, index: usize) -> VortexResult<Scalar> {
413        Scalar::try_new(
414            array.dtype().clone(),
415            Some(ScalarValue::Primitive(array.index_value(index))),
416        )
417    }
418}
419
420impl ValidityVTable<Sequence> for Sequence {
421    fn validity(_array: &SequenceArray) -> VortexResult<Validity> {
422        Ok(Validity::AllValid)
423    }
424}
425
/// Marker type for the sequence encoding; the vtable impls for [`SequenceArray`]
/// are implemented on it.
#[derive(Clone, Debug)]
pub struct Sequence;

impl Sequence {
    /// Encoding identifier, returned by `VTable::id` for sequence arrays.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.sequence");
}
432
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::expr::stats::Precision as StatPrecision;
    use vortex_array::expr::stats::Stat;
    use vortex_array::expr::stats::StatsProviderExt;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar::ScalarValue;
    use vortex_error::VortexResult;

    use crate::array::SequenceArray;

    /// The non-nullable i64 sequence `2, 5, 8, 11` shared by several tests.
    fn two_step_three() -> SequenceArray {
        SequenceArray::try_new_typed(2i64, 3, Nullability::NonNullable, 4).unwrap()
    }

    /// Reads a boolean statistic from the array's cached stats.
    fn bool_stat(arr: &SequenceArray, stat: Stat) -> Option<StatPrecision<bool>> {
        arr.statistics()
            .with_typed_stats_set(|stats| stats.get_as::<bool>(stat))
    }

    /// Asserts the exact values of the `IsSorted` and `IsStrictSorted` statistics.
    fn assert_sortedness(arr: &SequenceArray, sorted: bool, strict: bool) {
        assert_eq!(
            bool_stat(arr, Stat::IsSorted),
            Some(StatPrecision::Exact(sorted))
        );
        assert_eq!(
            bool_stat(arr, Stat::IsStrictSorted),
            Some(StatPrecision::Exact(strict))
        );
    }

    #[test]
    fn test_sequence_canonical() {
        let arr = two_step_three();
        let expected = PrimitiveArray::from_iter((0i64..4).map(|step| 2 + step * 3));

        assert_arrays_eq!(arr, expected);
    }

    #[test]
    fn test_sequence_slice_canonical() {
        let sliced = two_step_three().slice(2..3).unwrap();
        let expected = PrimitiveArray::from_iter((2i64..3).map(|step| 2 + step * 3));

        assert_arrays_eq!(sliced, expected);
    }

    #[test]
    fn test_sequence_scalar_at() {
        let scalar = two_step_three().scalar_at(2).unwrap();
        let expected =
            Scalar::try_new(scalar.dtype().clone(), Some(ScalarValue::from(8i64))).unwrap();

        assert_eq!(scalar, expected)
    }

    #[test]
    fn test_sequence_min_max() {
        for (base, step) in [(-127i8, -1i8), (126i8, -1i8)] {
            assert!(SequenceArray::try_new_typed(base, step, Nullability::NonNullable, 2).is_ok());
        }
    }

    #[test]
    fn test_sequence_too_big() {
        for (base, step) in [(127i8, 1i8), (-128i8, -1i8)] {
            assert!(
                SequenceArray::try_new_typed(base, step, Nullability::NonNullable, 2).is_err()
            );
        }
    }

    #[test]
    fn positive_multiplier_is_strict_sorted() -> VortexResult<()> {
        let arr = SequenceArray::try_new_typed(0i64, 3, Nullability::NonNullable, 4)?;
        assert_sortedness(&arr, true, true);
        Ok(())
    }

    #[test]
    fn zero_multiplier_is_sorted_not_strict() -> VortexResult<()> {
        let arr = SequenceArray::try_new_typed(5i64, 0, Nullability::NonNullable, 4)?;
        assert_sortedness(&arr, true, false);
        Ok(())
    }

    #[test]
    fn negative_multiplier_not_sorted() -> VortexResult<()> {
        let arr = SequenceArray::try_new_typed(10i64, -1, Nullability::NonNullable, 4)?;
        assert_sortedness(&arr, false, false);
        Ok(())
    }

    // This is regression test for an issue caught by the fuzzer, where SequenceArrays with
    // multiplier > i64::MAX were unable to be constructed.
    #[test]
    fn test_large_multiplier_sorted() -> VortexResult<()> {
        let large_multiplier = (i64::MAX as u64) + 1;
        let arr = SequenceArray::try_new_typed(0, large_multiplier, Nullability::NonNullable, 2)?;
        assert_sortedness(&arr, true, true);
        Ok(())
    }
}