Skip to main content

vortex_sequence/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5
6use num_traits::cast::FromPrimitive;
7use vortex_array::ArrayRef;
8use vortex_array::DeserializeMetadata;
9use vortex_array::ExecutionCtx;
10use vortex_array::IntoArray;
11use vortex_array::Precision;
12use vortex_array::ProstMetadata;
13use vortex_array::SerializeMetadata;
14use vortex_array::arrays::PrimitiveArray;
15use vortex_array::buffer::BufferHandle;
16use vortex_array::dtype::DType;
17use vortex_array::dtype::NativePType;
18use vortex_array::dtype::Nullability;
19use vortex_array::dtype::Nullability::NonNullable;
20use vortex_array::dtype::PType;
21use vortex_array::expr::stats::Precision as StatPrecision;
22use vortex_array::expr::stats::Stat;
23use vortex_array::match_each_integer_ptype;
24use vortex_array::match_each_native_ptype;
25use vortex_array::match_each_pvalue;
26use vortex_array::scalar::PValue;
27use vortex_array::scalar::Scalar;
28use vortex_array::scalar::ScalarValue;
29use vortex_array::serde::ArrayChildren;
30use vortex_array::stats::ArrayStats;
31use vortex_array::stats::StatsSet;
32use vortex_array::stats::StatsSetRef;
33use vortex_array::validity::Validity;
34use vortex_array::vtable;
35use vortex_array::vtable::ArrayId;
36use vortex_array::vtable::OperationsVTable;
37use vortex_array::vtable::VTable;
38use vortex_array::vtable::ValidityVTable;
39use vortex_buffer::BufferMut;
40use vortex_error::VortexExpect;
41use vortex_error::VortexResult;
42use vortex_error::vortex_bail;
43use vortex_error::vortex_ensure;
44use vortex_error::vortex_err;
45use vortex_error::vortex_panic;
46use vortex_session::VortexSession;
47
48use crate::kernel::PARENT_KERNELS;
49use crate::rules::RULES;
50
// Invokes the `vtable!` macro from `vortex_array` for the `Sequence` encoding.
// NOTE(review): presumably this generates the glue that ties `SequenceVTable` to
// `SequenceArray` (e.g. the `as_ref()` used in `stats`) — confirm against the macro.
vtable!(Sequence);
52
/// In-memory metadata of a [`SequenceArray`]: the two parameters of the equation
/// `A[i] = base + i * multiplier`.
///
/// Produced by `VTable::metadata` from an existing array and consumed by
/// `VTable::build` to reconstruct one.
#[derive(Debug, Clone, Copy)]
pub struct SequenceMetadata {
    /// Value of the first element of the sequence.
    base: PValue,
    /// Difference between two consecutive elements.
    multiplier: PValue,
}
58
/// Protobuf representation of [`SequenceMetadata`] used for (de)serialization.
///
/// Both fields are optional at the protobuf level; `VTable::deserialize` rejects a
/// message in which either one is missing.
#[derive(Clone, prost::Message)]
pub struct ProstSequenceMetadata {
    /// Encoded `base` scalar value (tag 1).
    #[prost(message, tag = "1")]
    base: Option<vortex_proto::scalar::ScalarValue>,
    /// Encoded `multiplier` scalar value (tag 2).
    #[prost(message, tag = "2")]
    multiplier: Option<vortex_proto::scalar::ScalarValue>,
}
66
/// Components of [`SequenceArray`].
///
/// Returned by [`SequenceArray::into_parts`], which splits the array's dtype into
/// its `ptype` and `nullability`.
pub struct SequenceArrayParts {
    /// Value of the first element of the sequence.
    pub base: PValue,
    /// Difference between two consecutive elements.
    pub multiplier: PValue,
    /// Number of elements in the sequence.
    pub len: usize,
    /// Primitive type of the elements.
    pub ptype: PType,
    /// Nullability recorded in the array's dtype.
    pub nullability: Nullability,
}
75
/// An array representing the equation `A[i] = base + i * multiplier`.
///
/// No buffers or children are stored; every element is computed from `base`,
/// `multiplier` and its index.
#[derive(Clone, Debug)]
pub struct SequenceArray {
    /// Value of the first element.
    base: PValue,
    /// Difference between two consecutive elements.
    multiplier: PValue,
    /// Primitive dtype (ptype and nullability) of the elements.
    dtype: DType,
    /// Number of elements in the sequence.
    pub(crate) len: usize,
    /// Statistics for the array; seeded at construction with `IsSorted` and
    /// `IsStrictSorted`, which follow directly from the sign of `multiplier`.
    stats_set: ArrayStats,
}
85
86impl SequenceArray {
87    pub fn try_new_typed<T: NativePType + Into<PValue>>(
88        base: T,
89        multiplier: T,
90        nullability: Nullability,
91        length: usize,
92    ) -> VortexResult<Self> {
93        Self::try_new(
94            base.into(),
95            multiplier.into(),
96            T::PTYPE,
97            nullability,
98            length,
99        )
100    }
101
102    /// Constructs a sequence array using two integer values (with the same ptype).
103    pub fn try_new(
104        base: PValue,
105        multiplier: PValue,
106        ptype: PType,
107        nullability: Nullability,
108        length: usize,
109    ) -> VortexResult<Self> {
110        if !ptype.is_int() {
111            vortex_bail!("only integer ptype are supported in SequenceArray currently")
112        }
113
114        Self::try_last(base, multiplier, ptype, length).map_err(|e| {
115            e.with_context(format!(
116                "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ",
117            ))
118        })?;
119
120        // SAFETY: we just validated that `ptype` is an integer and that the final
121        // element is representable via `try_last`.
122        Ok(unsafe { Self::new_unchecked(base, multiplier, ptype, nullability, length) })
123    }
124
125    /// Constructs a [`SequenceArray`] without validating that the `ptype` is an integer
126    /// type or that the final element is representable.
127    ///
128    /// # Safety
129    ///
130    /// The caller must ensure that:
131    /// - `ptype` is an integer type (i.e., `ptype.is_int()` returns `true`).
132    /// - `base + (length - 1) * multiplier` does not overflow the range of `ptype`.
133    ///
134    /// Violating the first invariant will cause a panic. Violating the second will
135    /// cause silent wraparound when materializing elements, producing incorrect values.
136    pub(crate) unsafe fn new_unchecked(
137        base: PValue,
138        multiplier: PValue,
139        ptype: PType,
140        nullability: Nullability,
141        length: usize,
142    ) -> Self {
143        let dtype = DType::Primitive(ptype, nullability);
144
145        // A sequence A[i] = base + i * multiplier is sorted iff multiplier >= 0,
146        // and strictly sorted iff multiplier > 0.
147
148        let (is_sorted, is_strict_sorted) = match_each_pvalue!(
149            multiplier,
150            uint: |v| { (true, v> 0) },
151            int: |v| { (v >= 0, v > 0) },
152            float: |_v| { unreachable!("float multiplier not supported") }
153        );
154
155        // SAFETY: we don't have duplicate stats
156        let stats_set = unsafe {
157            StatsSet::new_unchecked(vec![
158                (Stat::IsSorted, StatPrecision::Exact(is_sorted.into())),
159                (
160                    Stat::IsStrictSorted,
161                    StatPrecision::Exact(is_strict_sorted.into()),
162                ),
163            ])
164        };
165
166        Self {
167            base,
168            multiplier,
169            dtype,
170            len: length,
171            stats_set: ArrayStats::from(stats_set),
172        }
173    }
174
175    pub fn ptype(&self) -> PType {
176        self.dtype.as_ptype()
177    }
178
179    pub fn base(&self) -> PValue {
180        self.base
181    }
182
183    pub fn multiplier(&self) -> PValue {
184        self.multiplier
185    }
186
187    pub(crate) fn try_last(
188        base: PValue,
189        multiplier: PValue,
190        ptype: PType,
191        length: usize,
192    ) -> VortexResult<PValue> {
193        match_each_integer_ptype!(ptype, |P| {
194            let len_t = <P>::from_usize(length - 1)
195                .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
196
197            let base = base.cast::<P>()?;
198            let multiplier = multiplier.cast::<P>()?;
199            let last = len_t
200                .checked_mul(multiplier)
201                .and_then(|offset| offset.checked_add(base))
202                .ok_or_else(|| vortex_err!("last value computation overflows"))?;
203            Ok(PValue::from(last))
204        })
205    }
206
207    pub(crate) fn index_value(&self, idx: usize) -> PValue {
208        assert!(idx < self.len, "index_value({idx}): index out of bounds");
209
210        match_each_native_ptype!(self.ptype(), |P| {
211            let base = self.base.cast::<P>().vortex_expect("must be able to cast");
212            let multiplier = self
213                .multiplier
214                .cast::<P>()
215                .vortex_expect("must be able to cast");
216            let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
217
218            PValue::from(value)
219        })
220    }
221
222    /// Returns the validated final value of a sequence array
223    pub fn last(&self) -> PValue {
224        Self::try_last(self.base, self.multiplier, self.ptype(), self.len)
225            .vortex_expect("validated array")
226    }
227
228    pub fn into_parts(self) -> SequenceArrayParts {
229        SequenceArrayParts {
230            base: self.base,
231            multiplier: self.multiplier,
232            len: self.len,
233            ptype: self.dtype.as_ptype(),
234            nullability: self.dtype.nullability(),
235        }
236    }
237}
238
239impl VTable for SequenceVTable {
240    type Array = SequenceArray;
241
242    type Metadata = SequenceMetadata;
243    type OperationsVTable = Self;
244    type ValidityVTable = Self;
245
246    fn id(_array: &Self::Array) -> ArrayId {
247        Self::ID
248    }
249
250    fn len(array: &SequenceArray) -> usize {
251        array.len
252    }
253
254    fn dtype(array: &SequenceArray) -> &DType {
255        &array.dtype
256    }
257
258    fn stats(array: &SequenceArray) -> StatsSetRef<'_> {
259        array.stats_set.to_ref(array.as_ref())
260    }
261
262    fn array_hash<H: std::hash::Hasher>(
263        array: &SequenceArray,
264        state: &mut H,
265        _precision: Precision,
266    ) {
267        array.base.hash(state);
268        array.multiplier.hash(state);
269        array.dtype.hash(state);
270        array.len.hash(state);
271    }
272
273    fn array_eq(array: &SequenceArray, other: &SequenceArray, _precision: Precision) -> bool {
274        array.base == other.base
275            && array.multiplier == other.multiplier
276            && array.dtype == other.dtype
277            && array.len == other.len
278    }
279
280    fn nbuffers(_array: &SequenceArray) -> usize {
281        0
282    }
283
284    fn buffer(_array: &SequenceArray, idx: usize) -> BufferHandle {
285        vortex_panic!("SequenceArray buffer index {idx} out of bounds")
286    }
287
288    fn buffer_name(_array: &SequenceArray, idx: usize) -> Option<String> {
289        vortex_panic!("SequenceArray buffer_name index {idx} out of bounds")
290    }
291
292    fn nchildren(_array: &SequenceArray) -> usize {
293        0
294    }
295
296    fn child(_array: &SequenceArray, idx: usize) -> ArrayRef {
297        vortex_panic!("SequenceArray child index {idx} out of bounds")
298    }
299
300    fn child_name(_array: &SequenceArray, idx: usize) -> String {
301        vortex_panic!("SequenceArray child_name index {idx} out of bounds")
302    }
303
304    fn metadata(array: &SequenceArray) -> VortexResult<Self::Metadata> {
305        Ok(SequenceMetadata {
306            base: array.base(),
307            multiplier: array.multiplier(),
308        })
309    }
310
311    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
312        let prost = ProstMetadata(ProstSequenceMetadata {
313            base: Some((&metadata.base).into()),
314            multiplier: Some((&metadata.multiplier).into()),
315        });
316
317        Ok(Some(prost.serialize()))
318    }
319
320    fn deserialize(
321        bytes: &[u8],
322        dtype: &DType,
323        _len: usize,
324        _buffers: &[BufferHandle],
325        _session: &VortexSession,
326    ) -> VortexResult<Self::Metadata> {
327        let prost =
328            <ProstMetadata<ProstSequenceMetadata> as DeserializeMetadata>::deserialize(bytes)?;
329
330        let ptype = dtype.as_ptype();
331
332        // We go via Scalar to validate that the value is valid for the ptype.
333        let base = Scalar::from_proto_value(
334            prost
335                .base
336                .as_ref()
337                .ok_or_else(|| vortex_err!("base required"))?,
338            &DType::Primitive(ptype, NonNullable),
339        )?
340        .as_primitive()
341        .pvalue()
342        .vortex_expect("sequence array base should be a non-nullable primitive");
343
344        let multiplier = Scalar::from_proto_value(
345            prost
346                .multiplier
347                .as_ref()
348                .ok_or_else(|| vortex_err!("multiplier required"))?,
349            &DType::Primitive(ptype, NonNullable),
350        )?
351        .as_primitive()
352        .pvalue()
353        .vortex_expect("sequence array multiplier should be a non-nullable primitive");
354
355        Ok(SequenceMetadata { base, multiplier })
356    }
357
358    fn build(
359        dtype: &DType,
360        len: usize,
361        metadata: &Self::Metadata,
362        _buffers: &[BufferHandle],
363        _children: &dyn ArrayChildren,
364    ) -> VortexResult<SequenceArray> {
365        SequenceArray::try_new(
366            metadata.base,
367            metadata.multiplier,
368            dtype.as_ptype(),
369            dtype.nullability(),
370            len,
371        )
372    }
373
374    fn with_children(_array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
375        vortex_ensure!(
376            children.is_empty(),
377            "SequenceArray expects 0 children, got {}",
378            children.len()
379        );
380        Ok(())
381    }
382
383    fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
384        let prim = match_each_native_ptype!(array.ptype(), |P| {
385            let base = array.base().cast::<P>()?;
386            let multiplier = array.multiplier().cast::<P>()?;
387            let values = BufferMut::from_iter(
388                (0..array.len())
389                    .map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
390            );
391            PrimitiveArray::new(values, array.dtype.nullability().into())
392        });
393
394        Ok(prim.into_array())
395    }
396
397    fn execute_parent(
398        array: &Self::Array,
399        parent: &ArrayRef,
400        child_idx: usize,
401        ctx: &mut ExecutionCtx,
402    ) -> VortexResult<Option<ArrayRef>> {
403        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
404    }
405
406    fn reduce_parent(
407        array: &SequenceArray,
408        parent: &ArrayRef,
409        child_idx: usize,
410    ) -> VortexResult<Option<ArrayRef>> {
411        RULES.evaluate(array, parent, child_idx)
412    }
413}
414
415impl OperationsVTable<SequenceVTable> for SequenceVTable {
416    fn scalar_at(array: &SequenceArray, index: usize) -> VortexResult<Scalar> {
417        Scalar::try_new(
418            array.dtype().clone(),
419            Some(ScalarValue::Primitive(array.index_value(index))),
420        )
421    }
422}
423
impl ValidityVTable<SequenceVTable> for SequenceVTable {
    /// Reports every element as valid, regardless of the nullability recorded in
    /// the array's dtype.
    fn validity(_array: &SequenceArray) -> VortexResult<Validity> {
        Ok(Validity::AllValid)
    }
}
429
/// Marker type carrying the vtable implementations for [`SequenceArray`].
#[derive(Debug)]
pub struct SequenceVTable;
432
impl SequenceVTable {
    /// Identifier of the sequence encoding, returned by `VTable::id`.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.sequence");
}
436
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::expr::stats::Precision as StatPrecision;
    use vortex_array::expr::stats::Stat;
    use vortex_array::expr::stats::StatsProviderExt;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar::ScalarValue;
    use vortex_error::VortexResult;

    use crate::array::SequenceArray;

    /// Reads the `IsSorted` and `IsStrictSorted` statistics recorded on `arr`,
    /// returned in that order.
    fn sortedness(
        arr: &SequenceArray,
    ) -> (Option<StatPrecision<bool>>, Option<StatPrecision<bool>>) {
        let is_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted));
        let is_strict_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted));
        (is_sorted, is_strict_sorted)
    }

    #[test]
    fn test_sequence_canonical() {
        let arr = SequenceArray::try_new_typed(2i64, 3, Nullability::NonNullable, 4).unwrap();

        let canon = PrimitiveArray::from_iter((0..4).map(|i| 2i64 + i * 3));

        assert_arrays_eq!(arr, canon);
    }

    #[test]
    fn test_sequence_slice_canonical() {
        let full = SequenceArray::try_new_typed(2i64, 3, Nullability::NonNullable, 4).unwrap();
        let arr = full.slice(2..3).unwrap();

        let canon = PrimitiveArray::from_iter((2..3).map(|i| 2i64 + i * 3));

        assert_arrays_eq!(arr, canon);
    }

    #[test]
    fn test_sequence_scalar_at() {
        let arr = SequenceArray::try_new_typed(2i64, 3, Nullability::NonNullable, 4).unwrap();
        let scalar = arr.scalar_at(2).unwrap();

        let expected =
            Scalar::try_new(scalar.dtype().clone(), Some(ScalarValue::from(8i64))).unwrap();
        assert_eq!(scalar, expected)
    }

    #[test]
    fn test_sequence_min_max() {
        // Sequences whose last element lands inside the i8 range.
        assert!(SequenceArray::try_new_typed(-127i8, -1i8, Nullability::NonNullable, 2).is_ok());
        assert!(SequenceArray::try_new_typed(126i8, -1i8, Nullability::NonNullable, 2).is_ok());
    }

    #[test]
    fn test_sequence_too_big() {
        // Sequences whose last element falls just outside the i8 range.
        assert!(SequenceArray::try_new_typed(127i8, 1i8, Nullability::NonNullable, 2).is_err());
        assert!(SequenceArray::try_new_typed(-128i8, -1i8, Nullability::NonNullable, 2).is_err());
    }

    #[test]
    fn positive_multiplier_is_strict_sorted() -> VortexResult<()> {
        let arr = SequenceArray::try_new_typed(0i64, 3, Nullability::NonNullable, 4)?;

        let (is_sorted, is_strict_sorted) = sortedness(&arr);
        assert_eq!(is_sorted, Some(StatPrecision::Exact(true)));
        assert_eq!(is_strict_sorted, Some(StatPrecision::Exact(true)));
        Ok(())
    }

    #[test]
    fn zero_multiplier_is_sorted_not_strict() -> VortexResult<()> {
        let arr = SequenceArray::try_new_typed(5i64, 0, Nullability::NonNullable, 4)?;

        let (is_sorted, is_strict_sorted) = sortedness(&arr);
        assert_eq!(is_sorted, Some(StatPrecision::Exact(true)));
        assert_eq!(is_strict_sorted, Some(StatPrecision::Exact(false)));
        Ok(())
    }

    #[test]
    fn negative_multiplier_not_sorted() -> VortexResult<()> {
        let arr = SequenceArray::try_new_typed(10i64, -1, Nullability::NonNullable, 4)?;

        let (is_sorted, is_strict_sorted) = sortedness(&arr);
        assert_eq!(is_sorted, Some(StatPrecision::Exact(false)));
        assert_eq!(is_strict_sorted, Some(StatPrecision::Exact(false)));
        Ok(())
    }

    // This is regression test for an issue caught by the fuzzer, where SequenceArrays with
    // multiplier > i64::MAX were unable to be constructed.
    #[test]
    fn test_large_multiplier_sorted() -> VortexResult<()> {
        let large_multiplier = (i64::MAX as u64) + 1;
        let arr = SequenceArray::try_new_typed(0, large_multiplier, Nullability::NonNullable, 2)?;

        let (is_sorted, is_strict_sorted) = sortedness(&arr);
        assert_eq!(is_sorted, Some(StatPrecision::Exact(true)));
        assert_eq!(is_strict_sorted, Some(StatPrecision::Exact(true)));

        Ok(())
    }
}