Skip to main content

vortex_sequence/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::hash::Hasher;
8
9use num_traits::cast::FromPrimitive;
10use prost::Message;
11use smallvec::smallvec;
12use vortex_array::Array;
13use vortex_array::ArrayEq;
14use vortex_array::ArrayHash;
15use vortex_array::ArrayId;
16use vortex_array::ArrayParts;
17use vortex_array::ArrayRef;
18use vortex_array::ArrayView;
19use vortex_array::ExecutionCtx;
20use vortex_array::ExecutionResult;
21use vortex_array::Precision;
22use vortex_array::buffer::BufferHandle;
23use vortex_array::dtype::DType;
24use vortex_array::dtype::NativePType;
25use vortex_array::dtype::Nullability;
26use vortex_array::dtype::Nullability::NonNullable;
27use vortex_array::dtype::PType;
28use vortex_array::expr::stats::Precision as StatPrecision;
29use vortex_array::expr::stats::Stat;
30use vortex_array::match_each_integer_ptype;
31use vortex_array::match_each_native_ptype;
32use vortex_array::match_each_pvalue;
33use vortex_array::scalar::PValue;
34use vortex_array::scalar::Scalar;
35use vortex_array::scalar::ScalarValue;
36use vortex_array::serde::ArrayChildren;
37use vortex_array::stats::StatsSet;
38use vortex_array::validity::Validity;
39use vortex_array::vtable::OperationsVTable;
40use vortex_array::vtable::VTable;
41use vortex_array::vtable::ValidityVTable;
42use vortex_error::VortexExpect;
43use vortex_error::VortexResult;
44use vortex_error::vortex_bail;
45use vortex_error::vortex_ensure;
46use vortex_error::vortex_err;
47use vortex_error::vortex_panic;
48use vortex_session::VortexSession;
49use vortex_session::registry::CachedId;
50
51use crate::compress::sequence_decompress;
52use crate::kernel::PARENT_KERNELS;
53use crate::rules::RULES;
54
55/// A [`Sequence`]-encoded Vortex array.
56pub type SequenceArray = Array<Sequence>;
57
58#[derive(Clone, prost::Message)]
59pub struct SequenceMetadata {
60    #[prost(message, tag = "1")]
61    base: Option<vortex_proto::scalar::ScalarValue>,
62    #[prost(message, tag = "2")]
63    multiplier: Option<vortex_proto::scalar::ScalarValue>,
64}
65
66pub(super) const SLOT_NAMES: [&str; 0] = [];
67
68#[derive(Clone, Debug)]
69/// An array representing the equation `A[i] = base + i * multiplier`.
70pub struct SequenceData {
71    base: PValue,
72    multiplier: PValue,
73}
74
75impl Display for SequenceData {
76    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
77        write!(f, "base: {}, multiplier: {}", self.base, self.multiplier)
78    }
79}
80
81pub struct SequenceDataParts {
82    pub base: PValue,
83    pub multiplier: PValue,
84    pub ptype: PType,
85}
86
87impl SequenceData {
88    pub(crate) fn try_new_typed<T: NativePType + Into<PValue>>(
89        base: T,
90        multiplier: T,
91        nullability: Nullability,
92        length: usize,
93    ) -> VortexResult<Self> {
94        Self::try_new(
95            base.into(),
96            multiplier.into(),
97            T::PTYPE,
98            nullability,
99            length,
100        )
101    }
102
103    /// Constructs a sequence array using two integer values (with the same ptype).
104    pub(crate) fn try_new(
105        base: PValue,
106        multiplier: PValue,
107        ptype: PType,
108        nullability: Nullability,
109        length: usize,
110    ) -> VortexResult<Self> {
111        let dtype = DType::Primitive(ptype, nullability);
112        Self::validate(base, multiplier, &dtype, length)?;
113        let (base, multiplier) = Self::normalize(base, multiplier, ptype)?;
114
115        Ok(unsafe { Self::new_unchecked(base, multiplier) })
116    }
117
118    pub fn validate(
119        base: PValue,
120        multiplier: PValue,
121        dtype: &DType,
122        length: usize,
123    ) -> VortexResult<()> {
124        let DType::Primitive(ptype, _) = dtype else {
125            vortex_bail!("only primitive dtypes are supported in SequenceArray currently");
126        };
127
128        if !ptype.is_int() {
129            vortex_bail!("only integer ptype are supported in SequenceArray currently")
130        }
131
132        vortex_ensure!(length > 0, "SequenceArray length must be greater than zero");
133        Self::try_last(base, multiplier, *ptype, length).map_err(|e| {
134            e.with_context(format!(
135                "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ",
136            ))
137        })?;
138
139        Ok(())
140    }
141
142    fn normalize(base: PValue, multiplier: PValue, ptype: PType) -> VortexResult<(PValue, PValue)> {
143        match_each_integer_ptype!(ptype, |P| {
144            Ok((
145                PValue::from(base.cast::<P>()?),
146                PValue::from(multiplier.cast::<P>()?),
147            ))
148        })
149    }
150
151    /// Constructs a [`SequenceArray`] payload without validation.
152    ///
153    /// # Safety
154    ///
155    /// The caller must ensure that:
156    /// - `base` and `multiplier` are both normalized to the same integer `ptype`.
157    /// - they are logically compatible with the outer dtype and len.
158    pub(crate) unsafe fn new_unchecked(base: PValue, multiplier: PValue) -> Self {
159        Self { base, multiplier }
160    }
161
162    pub fn ptype(&self) -> PType {
163        self.base.ptype()
164    }
165
166    pub fn base(&self) -> PValue {
167        self.base
168    }
169
170    pub fn multiplier(&self) -> PValue {
171        self.multiplier
172    }
173
174    pub fn into_parts(self) -> SequenceDataParts {
175        SequenceDataParts {
176            base: self.base,
177            multiplier: self.multiplier,
178            ptype: self.base.ptype(),
179        }
180    }
181
182    pub(crate) fn try_last(
183        base: PValue,
184        multiplier: PValue,
185        ptype: PType,
186        length: usize,
187    ) -> VortexResult<PValue> {
188        match_each_integer_ptype!(ptype, |P| {
189            let len_t = <P>::from_usize(length - 1)
190                .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
191
192            let base = base.cast::<P>()?;
193            let multiplier = multiplier.cast::<P>()?;
194            let last = len_t
195                .checked_mul(multiplier)
196                .and_then(|offset| offset.checked_add(base))
197                .ok_or_else(|| vortex_err!("last value computation overflows"))?;
198            Ok(PValue::from(last))
199        })
200    }
201
202    pub(crate) fn index_value(&self, idx: usize) -> PValue {
203        match_each_native_ptype!(self.ptype(), |P| {
204            let base = self.base.cast::<P>().vortex_expect("must be able to cast");
205            let multiplier = self
206                .multiplier
207                .cast::<P>()
208                .vortex_expect("must be able to cast");
209            let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
210
211            PValue::from(value)
212        })
213    }
214}
215
216impl ArrayHash for SequenceData {
217    fn array_hash<H: Hasher>(&self, state: &mut H, _precision: Precision) {
218        self.base.hash(state);
219        self.multiplier.hash(state);
220    }
221}
222
223impl ArrayEq for SequenceData {
224    fn array_eq(&self, other: &Self, _precision: Precision) -> bool {
225        self.base == other.base && self.multiplier == other.multiplier
226    }
227}
228
229impl VTable for Sequence {
230    type TypedArrayData = SequenceData;
231
232    type OperationsVTable = Self;
233    type ValidityVTable = Self;
234
235    fn id(&self) -> ArrayId {
236        static ID: CachedId = CachedId::new("vortex.sequence");
237        *ID
238    }
239
240    fn validate(
241        &self,
242        data: &Self::TypedArrayData,
243        dtype: &DType,
244        len: usize,
245        _slots: &[Option<ArrayRef>],
246    ) -> VortexResult<()> {
247        SequenceData::validate(data.base, data.multiplier, dtype, len)
248    }
249
250    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
251        0
252    }
253
254    fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
255        vortex_panic!("SequenceArray buffer index {idx} out of bounds")
256    }
257
258    fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
259        vortex_panic!("SequenceArray buffer_name index {idx} out of bounds")
260    }
261
262    fn serialize(
263        array: ArrayView<'_, Self>,
264        _session: &VortexSession,
265    ) -> VortexResult<Option<Vec<u8>>> {
266        let metadata = SequenceMetadata {
267            base: Some((&array.base()).into()),
268            multiplier: Some((&array.multiplier()).into()),
269        };
270
271        Ok(Some(metadata.encode_to_vec()))
272    }
273
274    fn deserialize(
275        &self,
276        dtype: &DType,
277        len: usize,
278        metadata: &[u8],
279        buffers: &[BufferHandle],
280        children: &dyn ArrayChildren,
281        session: &VortexSession,
282    ) -> VortexResult<ArrayParts<Self>> {
283        vortex_ensure!(
284            buffers.is_empty(),
285            "SequenceArray expects 0 buffers, got {}",
286            buffers.len()
287        );
288        vortex_ensure!(
289            children.is_empty(),
290            "SequenceArray expects 0 children, got {}",
291            children.len()
292        );
293        let metadata = SequenceMetadata::decode(metadata)?;
294
295        let ptype = dtype.as_ptype();
296
297        // We go via Scalar to validate that the value is valid for the ptype.
298        let base = Scalar::from_proto_value(
299            metadata
300                .base
301                .as_ref()
302                .ok_or_else(|| vortex_err!("base required"))?,
303            &DType::Primitive(ptype, NonNullable),
304            session,
305        )?
306        .as_primitive()
307        .pvalue()
308        .vortex_expect("sequence array base should be a non-nullable primitive");
309
310        let multiplier = Scalar::from_proto_value(
311            metadata
312                .multiplier
313                .as_ref()
314                .ok_or_else(|| vortex_err!("multiplier required"))?,
315            &DType::Primitive(ptype, NonNullable),
316            session,
317        )?
318        .as_primitive()
319        .pvalue()
320        .vortex_expect("sequence array multiplier should be a non-nullable primitive");
321
322        let data = SequenceData::try_new(base, multiplier, ptype, dtype.nullability(), len)?;
323        Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data))
324    }
325
326    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
327        SLOT_NAMES[idx].to_string()
328    }
329
330    fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
331        sequence_decompress(&array).map(ExecutionResult::done)
332    }
333
334    fn execute_parent(
335        array: ArrayView<'_, Self>,
336        parent: &ArrayRef,
337        child_idx: usize,
338        ctx: &mut ExecutionCtx,
339    ) -> VortexResult<Option<ArrayRef>> {
340        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
341    }
342
343    fn reduce_parent(
344        array: ArrayView<'_, Self>,
345        parent: &ArrayRef,
346        child_idx: usize,
347    ) -> VortexResult<Option<ArrayRef>> {
348        RULES.evaluate(array, parent, child_idx)
349    }
350}
351
352impl OperationsVTable<Sequence> for Sequence {
353    fn scalar_at(
354        array: ArrayView<'_, Sequence>,
355        index: usize,
356        _ctx: &mut ExecutionCtx,
357    ) -> VortexResult<Scalar> {
358        Scalar::try_new(
359            array.dtype().clone(),
360            Some(ScalarValue::Primitive(array.index_value(index))),
361        )
362    }
363}
364
365impl ValidityVTable<Sequence> for Sequence {
366    fn validity(_array: ArrayView<'_, Sequence>) -> VortexResult<Validity> {
367        Ok(Validity::AllValid)
368    }
369}
370
/// Marker type for the `vortex.sequence` encoding; the vtable traits are implemented on it.
#[derive(Clone, Debug)]
pub struct Sequence;
373
374impl Sequence {
375    fn stats(multiplier: PValue) -> StatsSet {
376        // A sequence A[i] = base + i * multiplier is sorted iff multiplier >= 0,
377        // and strictly sorted iff multiplier > 0.
378        let (is_sorted, is_strict_sorted) = match_each_pvalue!(
379            multiplier,
380            uint: |v| { (true, v > 0) },
381            int: |v| { (v >= 0, v > 0) },
382            float: |_v| { unreachable!("float multiplier not supported") }
383        );
384
385        // SAFETY: we don't have duplicate stats.
386        unsafe {
387            StatsSet::new_unchecked(smallvec![
388                (Stat::IsSorted, StatPrecision::Exact(is_sorted.into())),
389                (
390                    Stat::IsStrictSorted,
391                    StatPrecision::Exact(is_strict_sorted.into()),
392                ),
393            ])
394        }
395    }
396
397    /// Construct a new [`SequenceArray`] from pre-validated parts.
398    ///
399    /// # Safety
400    ///
401    /// Caller must ensure the sequence is logically compatible with the provided dtype and len.
402    pub(crate) unsafe fn new_unchecked(
403        base: PValue,
404        multiplier: PValue,
405        ptype: PType,
406        nullability: Nullability,
407        length: usize,
408    ) -> SequenceArray {
409        let dtype = DType::Primitive(ptype, nullability);
410        let (base, multiplier) = SequenceData::normalize(base, multiplier, ptype)
411            .vortex_expect("SequenceArray parts must be normalized to the target ptype");
412        let stats = Self::stats(multiplier);
413        let data = unsafe { SequenceData::new_unchecked(base, multiplier) };
414        unsafe { Array::from_parts_unchecked(ArrayParts::new(Sequence, dtype, length, data)) }
415            .with_stats_set(stats)
416    }
417
418    /// Construct a new [`SequenceArray`] from its components.
419    pub fn try_new(
420        base: PValue,
421        multiplier: PValue,
422        ptype: PType,
423        nullability: Nullability,
424        length: usize,
425    ) -> VortexResult<SequenceArray> {
426        let dtype = DType::Primitive(ptype, nullability);
427        let data = SequenceData::try_new(base, multiplier, ptype, nullability, length)?;
428        let stats = Self::stats(data.multiplier());
429        Ok(
430            unsafe { Array::from_parts_unchecked(ArrayParts::new(Sequence, dtype, length, data)) }
431                .with_stats_set(stats),
432        )
433    }
434
435    /// Construct a new typed [`SequenceArray`] from base/multiplier values.
436    pub fn try_new_typed<T: NativePType + Into<PValue>>(
437        base: T,
438        multiplier: T,
439        nullability: Nullability,
440        length: usize,
441    ) -> VortexResult<SequenceArray> {
442        let ptype = T::PTYPE;
443        let dtype = DType::Primitive(ptype, nullability);
444        let data = SequenceData::try_new_typed(base, multiplier, nullability, length)?;
445        let stats = Self::stats(data.multiplier());
446        Ok(
447            unsafe { Array::from_parts_unchecked(ArrayParts::new(Sequence, dtype, length, data)) }
448                .with_stats_set(stats),
449        )
450    }
451}
452
#[cfg(test)]
mod tests {
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::expr::stats::Precision as StatPrecision;
    use vortex_array::expr::stats::Stat;
    use vortex_array::expr::stats::StatsProviderExt;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar::ScalarValue;
    use vortex_error::VortexResult;

    use crate::Sequence;

    /// Asserts that the array carries exactly the given `IsSorted` / `IsStrictSorted` statistics.
    macro_rules! assert_sortedness {
        ($array:expr, sorted: $sorted:expr, strict: $strict:expr) => {{
            let array = &$array;

            let is_sorted = array
                .statistics()
                .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted));
            let is_strict_sorted = array
                .statistics()
                .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted));

            assert_eq!(is_sorted, StatPrecision::Exact($sorted));
            assert_eq!(is_strict_sorted, StatPrecision::Exact($strict));
        }};
    }

    /// The whole sequence `2, 5, 8, 11` matches the equivalent primitive array.
    #[test]
    fn test_sequence_canonical() {
        let arr = Sequence::try_new_typed(2i64, 3, Nullability::NonNullable, 4).unwrap();

        let canon = PrimitiveArray::from_iter((0i64..4).map(|step| 2 + 3 * step));

        assert_arrays_eq!(arr, canon);
    }

    /// Slicing out index 2 yields the single value `8`.
    #[test]
    fn test_sequence_slice_canonical() {
        let full = Sequence::try_new_typed(2i64, 3, Nullability::NonNullable, 4).unwrap();
        let arr = full.slice(2..3).unwrap();

        let canon = PrimitiveArray::from_iter((2i64..3).map(|step| 2 + 3 * step));

        assert_arrays_eq!(arr, canon);
    }

    /// The scalar at index 2 is `2 + 2 * 3 = 8`.
    #[test]
    fn test_sequence_scalar_at() {
        let arr = Sequence::try_new_typed(2i64, 3, Nullability::NonNullable, 4).unwrap();
        let scalar = arr
            .execute_scalar(2, &mut LEGACY_SESSION.create_execution_ctx())
            .unwrap();

        let expected =
            Scalar::try_new(scalar.dtype().clone(), Some(ScalarValue::from(8i64))).unwrap();
        assert_eq!(scalar, expected)
    }

    /// Sequences whose last value still fits into `i8` are accepted.
    #[test]
    fn test_sequence_min_max() {
        for (base, multiplier) in [(-127i8, -1i8), (126i8, -1i8)] {
            assert!(
                Sequence::try_new_typed(base, multiplier, Nullability::NonNullable, 2).is_ok()
            );
        }
    }

    /// Sequences whose last value leaves the `i8` range are rejected.
    #[test]
    fn test_sequence_too_big() {
        for (base, multiplier) in [(127i8, 1i8), (-128i8, -1i8)] {
            assert!(
                Sequence::try_new_typed(base, multiplier, Nullability::NonNullable, 2).is_err()
            );
        }
    }

    #[test]
    fn positive_multiplier_is_strict_sorted() -> VortexResult<()> {
        let arr = Sequence::try_new_typed(0i64, 3, Nullability::NonNullable, 4)?;

        assert_sortedness!(arr, sorted: true, strict: true);
        Ok(())
    }

    #[test]
    fn zero_multiplier_is_sorted_not_strict() -> VortexResult<()> {
        let arr = Sequence::try_new_typed(5i64, 0, Nullability::NonNullable, 4)?;

        assert_sortedness!(arr, sorted: true, strict: false);
        Ok(())
    }

    #[test]
    fn negative_multiplier_not_sorted() -> VortexResult<()> {
        let arr = Sequence::try_new_typed(10i64, -1, Nullability::NonNullable, 4)?;

        assert_sortedness!(arr, sorted: false, strict: false);
        Ok(())
    }

    // This is a regression test for an issue caught by the fuzzer, where SequenceArrays with
    // multiplier > i64::MAX were unable to be constructed.
    #[test]
    fn test_large_multiplier_sorted() -> VortexResult<()> {
        let large_multiplier = (i64::MAX as u64) + 1;
        let arr = Sequence::try_new_typed(0, large_multiplier, Nullability::NonNullable, 2)?;

        assert_sortedness!(arr, sorted: true, strict: true);

        Ok(())
    }
}