Skip to main content

vortex_sequence/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::hash::Hasher;
8
9use num_traits::cast::FromPrimitive;
10use prost::Message;
11use vortex_array::Array;
12use vortex_array::ArrayEq;
13use vortex_array::ArrayHash;
14use vortex_array::ArrayId;
15use vortex_array::ArrayParts;
16use vortex_array::ArrayRef;
17use vortex_array::ArrayView;
18use vortex_array::ExecutionCtx;
19use vortex_array::ExecutionResult;
20use vortex_array::Precision;
21use vortex_array::buffer::BufferHandle;
22use vortex_array::dtype::DType;
23use vortex_array::dtype::NativePType;
24use vortex_array::dtype::Nullability;
25use vortex_array::dtype::Nullability::NonNullable;
26use vortex_array::dtype::PType;
27use vortex_array::expr::stats::Precision as StatPrecision;
28use vortex_array::expr::stats::Stat;
29use vortex_array::match_each_integer_ptype;
30use vortex_array::match_each_native_ptype;
31use vortex_array::match_each_pvalue;
32use vortex_array::scalar::PValue;
33use vortex_array::scalar::Scalar;
34use vortex_array::scalar::ScalarValue;
35use vortex_array::serde::ArrayChildren;
36use vortex_array::stats::StatsSet;
37use vortex_array::validity::Validity;
38use vortex_array::vtable::OperationsVTable;
39use vortex_array::vtable::VTable;
40use vortex_array::vtable::ValidityVTable;
41use vortex_error::VortexExpect;
42use vortex_error::VortexResult;
43use vortex_error::vortex_bail;
44use vortex_error::vortex_ensure;
45use vortex_error::vortex_err;
46use vortex_error::vortex_panic;
47use vortex_session::VortexSession;
48use vortex_session::registry::CachedId;
49
50use crate::compress::sequence_decompress;
51use crate::kernel::PARENT_KERNELS;
52use crate::rules::RULES;
53
/// A Vortex array whose elements are stored with the [`Sequence`] encoding.
///
/// This is an alias for [`Array`] specialised to the [`Sequence`] vtable.
pub type SequenceArray = Array<Sequence>;
56
/// Protobuf-encoded metadata persisted for a [`SequenceArray`].
///
/// Both fields are written as `Some(..)` by the serializer in this file and are treated as
/// required when the metadata is decoded again.
#[derive(Clone, prost::Message)]
pub struct SequenceMetadata {
    /// First value of the sequence, i.e. the element at index 0.
    #[prost(message, tag = "1")]
    base: Option<vortex_proto::scalar::ScalarValue>,
    /// Difference between two consecutive elements of the sequence.
    #[prost(message, tag = "2")]
    multiplier: Option<vortex_proto::scalar::ScalarValue>,
}
64
/// Slot names used by `slot_name` lookups. The list is empty: no slot names are defined for
/// this encoding.
pub(super) const SLOT_NAMES: [&str; 0] = [];
66
/// Payload of a sequence array, describing the equation `A[i] = base + i * multiplier`.
///
/// Only the two coefficients are stored here; the constructors in this file receive the dtype
/// and the length separately.
#[derive(Clone, Debug)]
pub struct SequenceData {
    /// Value of the element at index 0.
    base: PValue,
    /// Amount added for every step of the index.
    multiplier: PValue,
}
73
74impl Display for SequenceData {
75    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
76        write!(f, "base: {}, multiplier: {}", self.base, self.multiplier)
77    }
78}
79
/// The components of a [`SequenceData`], as produced by [`SequenceData::into_parts`].
pub struct SequenceDataParts {
    /// First value of the sequence.
    pub base: PValue,
    /// Step between two consecutive values.
    pub multiplier: PValue,
    /// Primitive type of `base`.
    pub ptype: PType,
}
85
86impl SequenceData {
87    pub(crate) fn try_new_typed<T: NativePType + Into<PValue>>(
88        base: T,
89        multiplier: T,
90        nullability: Nullability,
91        length: usize,
92    ) -> VortexResult<Self> {
93        Self::try_new(
94            base.into(),
95            multiplier.into(),
96            T::PTYPE,
97            nullability,
98            length,
99        )
100    }
101
102    /// Constructs a sequence array using two integer values (with the same ptype).
103    pub(crate) fn try_new(
104        base: PValue,
105        multiplier: PValue,
106        ptype: PType,
107        nullability: Nullability,
108        length: usize,
109    ) -> VortexResult<Self> {
110        let dtype = DType::Primitive(ptype, nullability);
111        Self::validate(base, multiplier, &dtype, length)?;
112        let (base, multiplier) = Self::normalize(base, multiplier, ptype)?;
113
114        Ok(unsafe { Self::new_unchecked(base, multiplier) })
115    }
116
117    pub fn validate(
118        base: PValue,
119        multiplier: PValue,
120        dtype: &DType,
121        length: usize,
122    ) -> VortexResult<()> {
123        let DType::Primitive(ptype, _) = dtype else {
124            vortex_bail!("only primitive dtypes are supported in SequenceArray currently");
125        };
126
127        if !ptype.is_int() {
128            vortex_bail!("only integer ptype are supported in SequenceArray currently")
129        }
130
131        vortex_ensure!(length > 0, "SequenceArray length must be greater than zero");
132        Self::try_last(base, multiplier, *ptype, length).map_err(|e| {
133            e.with_context(format!(
134                "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ",
135            ))
136        })?;
137
138        Ok(())
139    }
140
141    fn normalize(base: PValue, multiplier: PValue, ptype: PType) -> VortexResult<(PValue, PValue)> {
142        match_each_integer_ptype!(ptype, |P| {
143            Ok((
144                PValue::from(base.cast::<P>()?),
145                PValue::from(multiplier.cast::<P>()?),
146            ))
147        })
148    }
149
150    /// Constructs a [`SequenceArray`] payload without validation.
151    ///
152    /// # Safety
153    ///
154    /// The caller must ensure that:
155    /// - `base` and `multiplier` are both normalized to the same integer `ptype`.
156    /// - they are logically compatible with the outer dtype and len.
157    pub(crate) unsafe fn new_unchecked(base: PValue, multiplier: PValue) -> Self {
158        Self { base, multiplier }
159    }
160
161    pub fn ptype(&self) -> PType {
162        self.base.ptype()
163    }
164
165    pub fn base(&self) -> PValue {
166        self.base
167    }
168
169    pub fn multiplier(&self) -> PValue {
170        self.multiplier
171    }
172
173    pub fn into_parts(self) -> SequenceDataParts {
174        SequenceDataParts {
175            base: self.base,
176            multiplier: self.multiplier,
177            ptype: self.base.ptype(),
178        }
179    }
180
181    pub(crate) fn try_last(
182        base: PValue,
183        multiplier: PValue,
184        ptype: PType,
185        length: usize,
186    ) -> VortexResult<PValue> {
187        match_each_integer_ptype!(ptype, |P| {
188            let len_t = <P>::from_usize(length - 1)
189                .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
190
191            let base = base.cast::<P>()?;
192            let multiplier = multiplier.cast::<P>()?;
193            let last = len_t
194                .checked_mul(multiplier)
195                .and_then(|offset| offset.checked_add(base))
196                .ok_or_else(|| vortex_err!("last value computation overflows"))?;
197            Ok(PValue::from(last))
198        })
199    }
200
201    pub(crate) fn index_value(&self, idx: usize) -> PValue {
202        match_each_native_ptype!(self.ptype(), |P| {
203            let base = self.base.cast::<P>().vortex_expect("must be able to cast");
204            let multiplier = self
205                .multiplier
206                .cast::<P>()
207                .vortex_expect("must be able to cast");
208            let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
209
210            PValue::from(value)
211        })
212    }
213}
214
215impl ArrayHash for SequenceData {
216    fn array_hash<H: Hasher>(&self, state: &mut H, _precision: Precision) {
217        self.base.hash(state);
218        self.multiplier.hash(state);
219    }
220}
221
222impl ArrayEq for SequenceData {
223    fn array_eq(&self, other: &Self, _precision: Precision) -> bool {
224        self.base == other.base && self.multiplier == other.multiplier
225    }
226}
227
228impl VTable for Sequence {
229    type ArrayData = SequenceData;
230
231    type OperationsVTable = Self;
232    type ValidityVTable = Self;
233
234    fn id(&self) -> ArrayId {
235        static ID: CachedId = CachedId::new("vortex.sequence");
236        *ID
237    }
238
239    fn validate(
240        &self,
241        data: &Self::ArrayData,
242        dtype: &DType,
243        len: usize,
244        _slots: &[Option<ArrayRef>],
245    ) -> VortexResult<()> {
246        SequenceData::validate(data.base, data.multiplier, dtype, len)
247    }
248
249    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
250        0
251    }
252
253    fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
254        vortex_panic!("SequenceArray buffer index {idx} out of bounds")
255    }
256
257    fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
258        vortex_panic!("SequenceArray buffer_name index {idx} out of bounds")
259    }
260
261    fn serialize(
262        array: ArrayView<'_, Self>,
263        _session: &VortexSession,
264    ) -> VortexResult<Option<Vec<u8>>> {
265        let metadata = SequenceMetadata {
266            base: Some((&array.base()).into()),
267            multiplier: Some((&array.multiplier()).into()),
268        };
269
270        Ok(Some(metadata.encode_to_vec()))
271    }
272
273    fn deserialize(
274        &self,
275        dtype: &DType,
276        len: usize,
277        metadata: &[u8],
278        buffers: &[BufferHandle],
279        children: &dyn ArrayChildren,
280        session: &VortexSession,
281    ) -> VortexResult<ArrayParts<Self>> {
282        vortex_ensure!(
283            buffers.is_empty(),
284            "SequenceArray expects 0 buffers, got {}",
285            buffers.len()
286        );
287        vortex_ensure!(
288            children.is_empty(),
289            "SequenceArray expects 0 children, got {}",
290            children.len()
291        );
292        let metadata = SequenceMetadata::decode(metadata)?;
293
294        let ptype = dtype.as_ptype();
295
296        // We go via Scalar to validate that the value is valid for the ptype.
297        let base = Scalar::from_proto_value(
298            metadata
299                .base
300                .as_ref()
301                .ok_or_else(|| vortex_err!("base required"))?,
302            &DType::Primitive(ptype, NonNullable),
303            session,
304        )?
305        .as_primitive()
306        .pvalue()
307        .vortex_expect("sequence array base should be a non-nullable primitive");
308
309        let multiplier = Scalar::from_proto_value(
310            metadata
311                .multiplier
312                .as_ref()
313                .ok_or_else(|| vortex_err!("multiplier required"))?,
314            &DType::Primitive(ptype, NonNullable),
315            session,
316        )?
317        .as_primitive()
318        .pvalue()
319        .vortex_expect("sequence array multiplier should be a non-nullable primitive");
320
321        let data = SequenceData::try_new(base, multiplier, ptype, dtype.nullability(), len)?;
322        Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data))
323    }
324
325    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
326        SLOT_NAMES[idx].to_string()
327    }
328
329    fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
330        sequence_decompress(&array).map(ExecutionResult::done)
331    }
332
333    fn execute_parent(
334        array: ArrayView<'_, Self>,
335        parent: &ArrayRef,
336        child_idx: usize,
337        ctx: &mut ExecutionCtx,
338    ) -> VortexResult<Option<ArrayRef>> {
339        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
340    }
341
342    fn reduce_parent(
343        array: ArrayView<'_, Self>,
344        parent: &ArrayRef,
345        child_idx: usize,
346    ) -> VortexResult<Option<ArrayRef>> {
347        RULES.evaluate(array, parent, child_idx)
348    }
349}
350
351impl OperationsVTable<Sequence> for Sequence {
352    fn scalar_at(
353        array: ArrayView<'_, Sequence>,
354        index: usize,
355        _ctx: &mut ExecutionCtx,
356    ) -> VortexResult<Scalar> {
357        Scalar::try_new(
358            array.dtype().clone(),
359            Some(ScalarValue::Primitive(array.index_value(index))),
360        )
361    }
362}
363
impl ValidityVTable<Sequence> for Sequence {
    /// Reports every element as valid; the array argument is not consulted.
    fn validity(_array: ArrayView<'_, Sequence>) -> VortexResult<Validity> {
        Ok(Validity::AllValid)
    }
}
369
/// Marker type that implements the vtables of the sequence encoding (`vortex.sequence`) and
/// hosts the constructors for [`SequenceArray`].
#[derive(Clone, Debug)]
pub struct Sequence;
372
373impl Sequence {
374    fn stats(multiplier: PValue) -> StatsSet {
375        // A sequence A[i] = base + i * multiplier is sorted iff multiplier >= 0,
376        // and strictly sorted iff multiplier > 0.
377        let (is_sorted, is_strict_sorted) = match_each_pvalue!(
378            multiplier,
379            uint: |v| { (true, v > 0) },
380            int: |v| { (v >= 0, v > 0) },
381            float: |_v| { unreachable!("float multiplier not supported") }
382        );
383
384        // SAFETY: we don't have duplicate stats.
385        unsafe {
386            StatsSet::new_unchecked(vec![
387                (Stat::IsSorted, StatPrecision::Exact(is_sorted.into())),
388                (
389                    Stat::IsStrictSorted,
390                    StatPrecision::Exact(is_strict_sorted.into()),
391                ),
392            ])
393        }
394    }
395
396    /// Construct a new [`SequenceArray`] from pre-validated parts.
397    ///
398    /// # Safety
399    ///
400    /// Caller must ensure the sequence is logically compatible with the provided dtype and len.
401    pub(crate) unsafe fn new_unchecked(
402        base: PValue,
403        multiplier: PValue,
404        ptype: PType,
405        nullability: Nullability,
406        length: usize,
407    ) -> SequenceArray {
408        let dtype = DType::Primitive(ptype, nullability);
409        let (base, multiplier) = SequenceData::normalize(base, multiplier, ptype)
410            .vortex_expect("SequenceArray parts must be normalized to the target ptype");
411        let stats = Self::stats(multiplier);
412        let data = unsafe { SequenceData::new_unchecked(base, multiplier) };
413        unsafe { Array::from_parts_unchecked(ArrayParts::new(Sequence, dtype, length, data)) }
414            .with_stats_set(stats)
415    }
416
417    /// Construct a new [`SequenceArray`] from its components.
418    pub fn try_new(
419        base: PValue,
420        multiplier: PValue,
421        ptype: PType,
422        nullability: Nullability,
423        length: usize,
424    ) -> VortexResult<SequenceArray> {
425        let dtype = DType::Primitive(ptype, nullability);
426        let data = SequenceData::try_new(base, multiplier, ptype, nullability, length)?;
427        let stats = Self::stats(data.multiplier());
428        Ok(
429            unsafe { Array::from_parts_unchecked(ArrayParts::new(Sequence, dtype, length, data)) }
430                .with_stats_set(stats),
431        )
432    }
433
434    /// Construct a new typed [`SequenceArray`] from base/multiplier values.
435    pub fn try_new_typed<T: NativePType + Into<PValue>>(
436        base: T,
437        multiplier: T,
438        nullability: Nullability,
439        length: usize,
440    ) -> VortexResult<SequenceArray> {
441        let ptype = T::PTYPE;
442        let dtype = DType::Primitive(ptype, nullability);
443        let data = SequenceData::try_new_typed(base, multiplier, nullability, length)?;
444        let stats = Self::stats(data.multiplier());
445        Ok(
446            unsafe { Array::from_parts_unchecked(ArrayParts::new(Sequence, dtype, length, data)) }
447                .with_stats_set(stats),
448        )
449    }
450}
451
#[cfg(test)]
mod tests {
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::expr::stats::Precision as StatPrecision;
    use vortex_array::expr::stats::Stat;
    use vortex_array::expr::stats::StatsProviderExt;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar::ScalarValue;
    use vortex_error::VortexResult;

    use crate::Sequence;

    /// A whole sequence compares equal to the primitive array holding the same values.
    #[test]
    fn test_sequence_canonical() {
        let arr = Sequence::try_new_typed(2i64, 3, Nullability::NonNullable, 4).unwrap();

        let canon = PrimitiveArray::from_iter((0..4).map(|i| 2i64 + i * 3));

        assert_arrays_eq!(arr, canon);
    }

    /// A slice of a sequence compares equal to the matching slice of values.
    #[test]
    fn test_sequence_slice_canonical() {
        let arr = Sequence::try_new_typed(2i64, 3, Nullability::NonNullable, 4)
            .unwrap()
            .slice(2..3)
            .unwrap();

        let canon = PrimitiveArray::from_iter((2..3).map(|i| 2i64 + i * 3));

        assert_arrays_eq!(arr, canon);
    }

    /// The scalar at index 2 of `2 + 3 * i` is 8.
    #[test]
    fn test_sequence_scalar_at() {
        let scalar = Sequence::try_new_typed(2i64, 3, Nullability::NonNullable, 4)
            .unwrap()
            .execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())
            .unwrap();

        assert_eq!(
            scalar,
            Scalar::try_new(scalar.dtype().clone(), Some(ScalarValue::from(8i64))).unwrap()
        )
    }

    /// Sequences whose last value lands exactly on a type bound are accepted.
    #[test]
    fn test_sequence_min_max() {
        // Last value is -128 == i8::MIN.
        assert!(Sequence::try_new_typed(-127i8, -1i8, Nullability::NonNullable, 2).is_ok());
        assert!(Sequence::try_new_typed(126i8, -1i8, Nullability::NonNullable, 2).is_ok());
        // Last value is 127 == i8::MAX; without this case the upper bound is never reached.
        assert!(Sequence::try_new_typed(126i8, 1i8, Nullability::NonNullable, 2).is_ok());
    }

    /// Sequences whose last value would leave the type's range are rejected.
    #[test]
    fn test_sequence_too_big() {
        assert!(Sequence::try_new_typed(127i8, 1i8, Nullability::NonNullable, 2).is_err());
        assert!(Sequence::try_new_typed(-128i8, -1i8, Nullability::NonNullable, 2).is_err());
    }

    /// A positive multiplier yields both `IsSorted` and `IsStrictSorted`.
    #[test]
    fn positive_multiplier_is_strict_sorted() -> VortexResult<()> {
        let arr = Sequence::try_new_typed(0i64, 3, Nullability::NonNullable, 4)?;

        let is_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted));
        assert_eq!(is_sorted, Some(StatPrecision::Exact(true)));

        let is_strict_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted));
        assert_eq!(is_strict_sorted, Some(StatPrecision::Exact(true)));
        Ok(())
    }

    /// A zero multiplier yields `IsSorted` but not `IsStrictSorted`.
    #[test]
    fn zero_multiplier_is_sorted_not_strict() -> VortexResult<()> {
        let arr = Sequence::try_new_typed(5i64, 0, Nullability::NonNullable, 4)?;

        let is_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted));
        assert_eq!(is_sorted, Some(StatPrecision::Exact(true)));

        let is_strict_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted));
        assert_eq!(is_strict_sorted, Some(StatPrecision::Exact(false)));
        Ok(())
    }

    /// A negative multiplier yields neither `IsSorted` nor `IsStrictSorted`.
    #[test]
    fn negative_multiplier_not_sorted() -> VortexResult<()> {
        let arr = Sequence::try_new_typed(10i64, -1, Nullability::NonNullable, 4)?;

        let is_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted));
        assert_eq!(is_sorted, Some(StatPrecision::Exact(false)));

        let is_strict_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted));
        assert_eq!(is_strict_sorted, Some(StatPrecision::Exact(false)));
        Ok(())
    }

    // This is a regression test for an issue caught by the fuzzer, where SequenceArrays with
    // a multiplier greater than i64::MAX could not be constructed.
    #[test]
    fn test_large_multiplier_sorted() -> VortexResult<()> {
        let large_multiplier = (i64::MAX as u64) + 1;
        let arr = Sequence::try_new_typed(0, large_multiplier, Nullability::NonNullable, 2)?;

        let is_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted));

        let is_strict_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted));

        assert_eq!(is_sorted, Some(StatPrecision::Exact(true)));
        assert_eq!(is_strict_sorted, Some(StatPrecision::Exact(true)));

        Ok(())
    }
}