Skip to main content

vortex_sequence/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::hash::Hasher;
8
9use num_traits::cast::FromPrimitive;
10use prost::Message;
11use vortex_array::Array;
12use vortex_array::ArrayEq;
13use vortex_array::ArrayHash;
14use vortex_array::ArrayId;
15use vortex_array::ArrayParts;
16use vortex_array::ArrayRef;
17use vortex_array::ArrayView;
18use vortex_array::ExecutionCtx;
19use vortex_array::ExecutionResult;
20use vortex_array::Precision;
21use vortex_array::buffer::BufferHandle;
22use vortex_array::dtype::DType;
23use vortex_array::dtype::NativePType;
24use vortex_array::dtype::Nullability;
25use vortex_array::dtype::Nullability::NonNullable;
26use vortex_array::dtype::PType;
27use vortex_array::expr::stats::Precision as StatPrecision;
28use vortex_array::expr::stats::Stat;
29use vortex_array::match_each_integer_ptype;
30use vortex_array::match_each_native_ptype;
31use vortex_array::match_each_pvalue;
32use vortex_array::scalar::PValue;
33use vortex_array::scalar::Scalar;
34use vortex_array::scalar::ScalarValue;
35use vortex_array::serde::ArrayChildren;
36use vortex_array::stats::StatsSet;
37use vortex_array::validity::Validity;
38use vortex_array::vtable::OperationsVTable;
39use vortex_array::vtable::VTable;
40use vortex_array::vtable::ValidityVTable;
41use vortex_error::VortexExpect;
42use vortex_error::VortexResult;
43use vortex_error::vortex_bail;
44use vortex_error::vortex_ensure;
45use vortex_error::vortex_err;
46use vortex_error::vortex_panic;
47use vortex_session::VortexSession;
48
49use crate::compress::sequence_decompress;
50use crate::kernel::PARENT_KERNELS;
51use crate::rules::RULES;
52
/// A [`Sequence`]-encoded Vortex array, where element `i` is `base + i * multiplier`.
pub type SequenceArray = Array<Sequence>;
55
/// Serialized metadata of a [`Sequence`] array: its `base` and `multiplier` as protobuf scalars.
///
/// prost represents message-typed fields as `Option`; deserialization reports an error
/// ("base required" / "multiplier required") when either one is absent.
#[derive(Clone, prost::Message)]
pub struct SequenceMetadata {
    /// First element of the sequence, `A[0]`.
    #[prost(message, tag = "1")]
    base: Option<vortex_proto::scalar::ScalarValue>,
    /// Step between consecutive elements.
    #[prost(message, tag = "2")]
    multiplier: Option<vortex_proto::scalar::ScalarValue>,
}
63
/// Names of the child slots of a [`Sequence`] array; sequences have no children, so it is empty.
pub(super) const SLOT_NAMES: [&str; 0] = [];
65
/// Payload of a [`SequenceArray`], representing the equation `A[i] = base + i * multiplier`.
///
/// Both values are stored already cast to the array's integer ptype
/// (see [`SequenceData::try_new`]).
#[derive(Clone, Debug)]
pub struct SequenceData {
    /// Value of the first element, `A[0]`.
    base: PValue,
    /// Difference between consecutive elements.
    multiplier: PValue,
}
72
73impl Display for SequenceData {
74    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
75        write!(f, "base: {}, multiplier: {}", self.base, self.multiplier)
76    }
77}
78
/// Owned decomposition of a [`SequenceData`], as produced by [`SequenceData::into_parts`].
pub struct SequenceDataParts {
    /// First element of the sequence.
    pub base: PValue,
    /// Step between consecutive elements.
    pub multiplier: PValue,
    /// Primitive type of the sequence; `into_parts` takes it from `base`.
    pub ptype: PType,
}
84
85impl SequenceData {
86    pub(crate) fn try_new_typed<T: NativePType + Into<PValue>>(
87        base: T,
88        multiplier: T,
89        nullability: Nullability,
90        length: usize,
91    ) -> VortexResult<Self> {
92        Self::try_new(
93            base.into(),
94            multiplier.into(),
95            T::PTYPE,
96            nullability,
97            length,
98        )
99    }
100
101    /// Constructs a sequence array using two integer values (with the same ptype).
102    pub(crate) fn try_new(
103        base: PValue,
104        multiplier: PValue,
105        ptype: PType,
106        nullability: Nullability,
107        length: usize,
108    ) -> VortexResult<Self> {
109        let dtype = DType::Primitive(ptype, nullability);
110        Self::validate(base, multiplier, &dtype, length)?;
111        let (base, multiplier) = Self::normalize(base, multiplier, ptype)?;
112
113        Ok(unsafe { Self::new_unchecked(base, multiplier) })
114    }
115
116    pub fn validate(
117        base: PValue,
118        multiplier: PValue,
119        dtype: &DType,
120        length: usize,
121    ) -> VortexResult<()> {
122        let DType::Primitive(ptype, _) = dtype else {
123            vortex_bail!("only primitive dtypes are supported in SequenceArray currently");
124        };
125
126        if !ptype.is_int() {
127            vortex_bail!("only integer ptype are supported in SequenceArray currently")
128        }
129
130        vortex_ensure!(length > 0, "SequenceArray length must be greater than zero");
131        Self::try_last(base, multiplier, *ptype, length).map_err(|e| {
132            e.with_context(format!(
133                "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ",
134            ))
135        })?;
136
137        Ok(())
138    }
139
140    fn normalize(base: PValue, multiplier: PValue, ptype: PType) -> VortexResult<(PValue, PValue)> {
141        match_each_integer_ptype!(ptype, |P| {
142            Ok((
143                PValue::from(base.cast::<P>()?),
144                PValue::from(multiplier.cast::<P>()?),
145            ))
146        })
147    }
148
149    /// Constructs a [`SequenceArray`] payload without validation.
150    ///
151    /// # Safety
152    ///
153    /// The caller must ensure that:
154    /// - `base` and `multiplier` are both normalized to the same integer `ptype`.
155    /// - they are logically compatible with the outer dtype and len.
156    pub(crate) unsafe fn new_unchecked(base: PValue, multiplier: PValue) -> Self {
157        Self { base, multiplier }
158    }
159
160    pub fn ptype(&self) -> PType {
161        self.base.ptype()
162    }
163
164    pub fn base(&self) -> PValue {
165        self.base
166    }
167
168    pub fn multiplier(&self) -> PValue {
169        self.multiplier
170    }
171
172    pub fn into_parts(self) -> SequenceDataParts {
173        SequenceDataParts {
174            base: self.base,
175            multiplier: self.multiplier,
176            ptype: self.base.ptype(),
177        }
178    }
179
180    pub(crate) fn try_last(
181        base: PValue,
182        multiplier: PValue,
183        ptype: PType,
184        length: usize,
185    ) -> VortexResult<PValue> {
186        match_each_integer_ptype!(ptype, |P| {
187            let len_t = <P>::from_usize(length - 1)
188                .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
189
190            let base = base.cast::<P>()?;
191            let multiplier = multiplier.cast::<P>()?;
192            let last = len_t
193                .checked_mul(multiplier)
194                .and_then(|offset| offset.checked_add(base))
195                .ok_or_else(|| vortex_err!("last value computation overflows"))?;
196            Ok(PValue::from(last))
197        })
198    }
199
200    pub(crate) fn index_value(&self, idx: usize) -> PValue {
201        match_each_native_ptype!(self.ptype(), |P| {
202            let base = self.base.cast::<P>().vortex_expect("must be able to cast");
203            let multiplier = self
204                .multiplier
205                .cast::<P>()
206                .vortex_expect("must be able to cast");
207            let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
208
209            PValue::from(value)
210        })
211    }
212}
213
214impl ArrayHash for SequenceData {
215    fn array_hash<H: Hasher>(&self, state: &mut H, _precision: Precision) {
216        self.base.hash(state);
217        self.multiplier.hash(state);
218    }
219}
220
221impl ArrayEq for SequenceData {
222    fn array_eq(&self, other: &Self, _precision: Precision) -> bool {
223        self.base == other.base && self.multiplier == other.multiplier
224    }
225}
226
227impl VTable for Sequence {
228    type ArrayData = SequenceData;
229
230    type OperationsVTable = Self;
231    type ValidityVTable = Self;
232
233    fn id(&self) -> ArrayId {
234        Self::ID
235    }
236
237    fn validate(
238        &self,
239        data: &Self::ArrayData,
240        dtype: &DType,
241        len: usize,
242        _slots: &[Option<ArrayRef>],
243    ) -> VortexResult<()> {
244        SequenceData::validate(data.base, data.multiplier, dtype, len)
245    }
246
247    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
248        0
249    }
250
251    fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
252        vortex_panic!("SequenceArray buffer index {idx} out of bounds")
253    }
254
255    fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
256        vortex_panic!("SequenceArray buffer_name index {idx} out of bounds")
257    }
258
259    fn serialize(
260        array: ArrayView<'_, Self>,
261        _session: &VortexSession,
262    ) -> VortexResult<Option<Vec<u8>>> {
263        let metadata = SequenceMetadata {
264            base: Some((&array.base()).into()),
265            multiplier: Some((&array.multiplier()).into()),
266        };
267
268        Ok(Some(metadata.encode_to_vec()))
269    }
270
271    fn deserialize(
272        &self,
273        dtype: &DType,
274        len: usize,
275        metadata: &[u8],
276        buffers: &[BufferHandle],
277        children: &dyn ArrayChildren,
278        session: &VortexSession,
279    ) -> VortexResult<ArrayParts<Self>> {
280        vortex_ensure!(
281            buffers.is_empty(),
282            "SequenceArray expects 0 buffers, got {}",
283            buffers.len()
284        );
285        vortex_ensure!(
286            children.is_empty(),
287            "SequenceArray expects 0 children, got {}",
288            children.len()
289        );
290        let metadata = SequenceMetadata::decode(metadata)?;
291
292        let ptype = dtype.as_ptype();
293
294        // We go via Scalar to validate that the value is valid for the ptype.
295        let base = Scalar::from_proto_value(
296            metadata
297                .base
298                .as_ref()
299                .ok_or_else(|| vortex_err!("base required"))?,
300            &DType::Primitive(ptype, NonNullable),
301            session,
302        )?
303        .as_primitive()
304        .pvalue()
305        .vortex_expect("sequence array base should be a non-nullable primitive");
306
307        let multiplier = Scalar::from_proto_value(
308            metadata
309                .multiplier
310                .as_ref()
311                .ok_or_else(|| vortex_err!("multiplier required"))?,
312            &DType::Primitive(ptype, NonNullable),
313            session,
314        )?
315        .as_primitive()
316        .pvalue()
317        .vortex_expect("sequence array multiplier should be a non-nullable primitive");
318
319        let data = SequenceData::try_new(base, multiplier, ptype, dtype.nullability(), len)?;
320        Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data))
321    }
322
323    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
324        SLOT_NAMES[idx].to_string()
325    }
326
327    fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
328        sequence_decompress(&array).map(ExecutionResult::done)
329    }
330
331    fn execute_parent(
332        array: ArrayView<'_, Self>,
333        parent: &ArrayRef,
334        child_idx: usize,
335        ctx: &mut ExecutionCtx,
336    ) -> VortexResult<Option<ArrayRef>> {
337        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
338    }
339
340    fn reduce_parent(
341        array: ArrayView<'_, Self>,
342        parent: &ArrayRef,
343        child_idx: usize,
344    ) -> VortexResult<Option<ArrayRef>> {
345        RULES.evaluate(array, parent, child_idx)
346    }
347}
348
349impl OperationsVTable<Sequence> for Sequence {
350    fn scalar_at(
351        array: ArrayView<'_, Sequence>,
352        index: usize,
353        _ctx: &mut ExecutionCtx,
354    ) -> VortexResult<Scalar> {
355        Scalar::try_new(
356            array.dtype().clone(),
357            Some(ScalarValue::Primitive(array.index_value(index))),
358        )
359    }
360}
361
362impl ValidityVTable<Sequence> for Sequence {
363    fn validity(_array: ArrayView<'_, Sequence>) -> VortexResult<Validity> {
364        Ok(Validity::AllValid)
365    }
366}
367
/// Encoding marker for sequence arrays, identified by [`Sequence::ID`] (`"vortex.sequence"`).
#[derive(Clone, Debug)]
pub struct Sequence;
370
371impl Sequence {
372    pub const ID: ArrayId = ArrayId::new_ref("vortex.sequence");
373
374    fn stats(multiplier: PValue) -> StatsSet {
375        // A sequence A[i] = base + i * multiplier is sorted iff multiplier >= 0,
376        // and strictly sorted iff multiplier > 0.
377        let (is_sorted, is_strict_sorted) = match_each_pvalue!(
378            multiplier,
379            uint: |v| { (true, v > 0) },
380            int: |v| { (v >= 0, v > 0) },
381            float: |_v| { unreachable!("float multiplier not supported") }
382        );
383
384        // SAFETY: we don't have duplicate stats.
385        unsafe {
386            StatsSet::new_unchecked(vec![
387                (Stat::IsSorted, StatPrecision::Exact(is_sorted.into())),
388                (
389                    Stat::IsStrictSorted,
390                    StatPrecision::Exact(is_strict_sorted.into()),
391                ),
392            ])
393        }
394    }
395
396    /// Construct a new [`SequenceArray`] from pre-validated parts.
397    ///
398    /// # Safety
399    ///
400    /// Caller must ensure the sequence is logically compatible with the provided dtype and len.
401    pub(crate) unsafe fn new_unchecked(
402        base: PValue,
403        multiplier: PValue,
404        ptype: PType,
405        nullability: Nullability,
406        length: usize,
407    ) -> SequenceArray {
408        let dtype = DType::Primitive(ptype, nullability);
409        let (base, multiplier) = SequenceData::normalize(base, multiplier, ptype)
410            .vortex_expect("SequenceArray parts must be normalized to the target ptype");
411        let stats = Self::stats(multiplier);
412        let data = unsafe { SequenceData::new_unchecked(base, multiplier) };
413        unsafe { Array::from_parts_unchecked(ArrayParts::new(Sequence, dtype, length, data)) }
414            .with_stats_set(stats)
415    }
416
417    /// Construct a new [`SequenceArray`] from its components.
418    pub fn try_new(
419        base: PValue,
420        multiplier: PValue,
421        ptype: PType,
422        nullability: Nullability,
423        length: usize,
424    ) -> VortexResult<SequenceArray> {
425        let dtype = DType::Primitive(ptype, nullability);
426        let data = SequenceData::try_new(base, multiplier, ptype, nullability, length)?;
427        let stats = Self::stats(data.multiplier());
428        Ok(
429            unsafe { Array::from_parts_unchecked(ArrayParts::new(Sequence, dtype, length, data)) }
430                .with_stats_set(stats),
431        )
432    }
433
434    /// Construct a new typed [`SequenceArray`] from base/multiplier values.
435    pub fn try_new_typed<T: NativePType + Into<PValue>>(
436        base: T,
437        multiplier: T,
438        nullability: Nullability,
439        length: usize,
440    ) -> VortexResult<SequenceArray> {
441        let ptype = T::PTYPE;
442        let dtype = DType::Primitive(ptype, nullability);
443        let data = SequenceData::try_new_typed(base, multiplier, nullability, length)?;
444        let stats = Self::stats(data.multiplier());
445        Ok(
446            unsafe { Array::from_parts_unchecked(ArrayParts::new(Sequence, dtype, length, data)) }
447                .with_stats_set(stats),
448        )
449    }
450}
451
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::expr::stats::Precision as StatPrecision;
    use vortex_array::expr::stats::Stat;
    use vortex_array::expr::stats::StatsProviderExt;
    use vortex_array::scalar::PValue;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar::ScalarValue;
    use vortex_error::VortexResult;

    use crate::Sequence;

    #[test]
    fn test_sequence_canonical() {
        let arr = Sequence::try_new_typed(2i64, 3, Nullability::NonNullable, 4).unwrap();

        let canon = PrimitiveArray::from_iter((0..4).map(|i| 2i64 + i * 3));

        assert_arrays_eq!(arr, canon);
    }

    #[test]
    fn test_sequence_slice_canonical() {
        let arr = Sequence::try_new_typed(2i64, 3, Nullability::NonNullable, 4)
            .unwrap()
            .slice(2..3)
            .unwrap();

        let canon = PrimitiveArray::from_iter((2..3).map(|i| 2i64 + i * 3));

        assert_arrays_eq!(arr, canon);
    }

    #[test]
    fn test_sequence_scalar_at() {
        let scalar = Sequence::try_new_typed(2i64, 3, Nullability::NonNullable, 4)
            .unwrap()
            .scalar_at(2)
            .unwrap();

        assert_eq!(
            scalar,
            Scalar::try_new(scalar.dtype().clone(), Some(ScalarValue::from(8i64))).unwrap()
        )
    }

    // The last element of this sequence is exactly i8::MIN; reading it must not overflow.
    #[test]
    fn test_sequence_scalar_at_extreme_last_value() -> VortexResult<()> {
        let scalar =
            Sequence::try_new_typed(-127i8, -1i8, Nullability::NonNullable, 2)?.scalar_at(1)?;

        assert_eq!(
            scalar,
            Scalar::try_new(
                scalar.dtype().clone(),
                Some(ScalarValue::Primitive(PValue::from(-128i8)))
            )?
        );
        Ok(())
    }

    #[test]
    fn test_sequence_min_max() {
        assert!(Sequence::try_new_typed(-127i8, -1i8, Nullability::NonNullable, 2).is_ok());
        assert!(Sequence::try_new_typed(126i8, -1i8, Nullability::NonNullable, 2).is_ok());
    }

    #[test]
    fn test_sequence_too_big() {
        assert!(Sequence::try_new_typed(127i8, 1i8, Nullability::NonNullable, 2).is_err());
        assert!(Sequence::try_new_typed(-128i8, -1i8, Nullability::NonNullable, 2).is_err());
    }

    #[test]
    fn test_sequence_rejects_zero_length() {
        assert!(Sequence::try_new_typed(0i32, 1, Nullability::NonNullable, 0).is_err());
    }

    #[test]
    fn test_sequence_rejects_float_ptype() {
        assert!(Sequence::try_new_typed(0.0f32, 1.0f32, Nullability::NonNullable, 3).is_err());
    }

    #[test]
    fn positive_multiplier_is_strict_sorted() -> VortexResult<()> {
        let arr = Sequence::try_new_typed(0i64, 3, Nullability::NonNullable, 4)?;

        let is_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted));
        assert_eq!(is_sorted, Some(StatPrecision::Exact(true)));

        let is_strict_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted));
        assert_eq!(is_strict_sorted, Some(StatPrecision::Exact(true)));
        Ok(())
    }

    #[test]
    fn zero_multiplier_is_sorted_not_strict() -> VortexResult<()> {
        let arr = Sequence::try_new_typed(5i64, 0, Nullability::NonNullable, 4)?;

        let is_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted));
        assert_eq!(is_sorted, Some(StatPrecision::Exact(true)));

        let is_strict_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted));
        assert_eq!(is_strict_sorted, Some(StatPrecision::Exact(false)));
        Ok(())
    }

    #[test]
    fn negative_multiplier_not_sorted() -> VortexResult<()> {
        let arr = Sequence::try_new_typed(10i64, -1, Nullability::NonNullable, 4)?;

        let is_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted));
        assert_eq!(is_sorted, Some(StatPrecision::Exact(false)));

        let is_strict_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted));
        assert_eq!(is_strict_sorted, Some(StatPrecision::Exact(false)));
        Ok(())
    }

    // This is regression test for an issue caught by the fuzzer, where SequenceArrays with
    // multiplier > i64::MAX were unable to be constructed.
    #[test]
    fn test_large_multiplier_sorted() -> VortexResult<()> {
        let large_multiplier = (i64::MAX as u64) + 1;
        let arr = Sequence::try_new_typed(0, large_multiplier, Nullability::NonNullable, 2)?;

        let is_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted));

        let is_strict_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted));

        assert_eq!(is_sorted, Some(StatPrecision::Exact(true)));
        assert_eq!(is_strict_sorted, Some(StatPrecision::Exact(true)));

        Ok(())
    }
}