Skip to main content

vortex_sequence/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5
6use num_traits::cast::FromPrimitive;
7use vortex_array::ArrayBufferVisitor;
8use vortex_array::ArrayChildVisitor;
9use vortex_array::ArrayRef;
10use vortex_array::DeserializeMetadata;
11use vortex_array::ExecutionCtx;
12use vortex_array::IntoArray;
13use vortex_array::Precision;
14use vortex_array::ProstMetadata;
15use vortex_array::SerializeMetadata;
16use vortex_array::arrays::PrimitiveArray;
17use vortex_array::buffer::BufferHandle;
18use vortex_array::expr::stats::Precision as StatPrecision;
19use vortex_array::expr::stats::Stat;
20use vortex_array::match_each_pvalue;
21use vortex_array::scalar::PValue;
22use vortex_array::scalar::Scalar;
23use vortex_array::scalar::ScalarValue;
24use vortex_array::serde::ArrayChildren;
25use vortex_array::stats::ArrayStats;
26use vortex_array::stats::StatsSet;
27use vortex_array::stats::StatsSetRef;
28use vortex_array::validity::Validity;
29use vortex_array::vtable;
30use vortex_array::vtable::ArrayId;
31use vortex_array::vtable::BaseArrayVTable;
32use vortex_array::vtable::OperationsVTable;
33use vortex_array::vtable::VTable;
34use vortex_array::vtable::ValidityVTable;
35use vortex_array::vtable::VisitorVTable;
36use vortex_buffer::BufferMut;
37use vortex_dtype::DType;
38use vortex_dtype::NativePType;
39use vortex_dtype::Nullability;
40use vortex_dtype::Nullability::NonNullable;
41use vortex_dtype::PType;
42use vortex_dtype::match_each_integer_ptype;
43use vortex_dtype::match_each_native_ptype;
44use vortex_error::VortexExpect;
45use vortex_error::VortexResult;
46use vortex_error::vortex_bail;
47use vortex_error::vortex_ensure;
48use vortex_error::vortex_err;
49use vortex_session::VortexSession;
50
51use crate::kernel::PARENT_KERNELS;
52use crate::rules::RULES;
53
// Invokes `vortex_array`'s `vtable!` macro for the `Sequence` encoding.
// NOTE(review): presumably this generates the glue tying `SequenceArray` to `SequenceVTable`
// (the macro body is not visible in this file) — confirm against the macro definition.
vtable!(Sequence);
55
/// Protobuf-encoded metadata describing a serialized sequence array.
///
/// It carries the two scalar parameters of the equation `A[i] = base + i * multiplier`.
/// Both fields must be present for the array to be rebuilt: `build` reports an error when
/// either one is missing.
#[derive(Clone, prost::Message)]
pub struct SequenceMetadata {
    /// The `base` term of the sequence, as a protobuf scalar value.
    #[prost(message, tag = "1")]
    base: Option<vortex_proto::scalar::ScalarValue>,
    /// The `multiplier` term of the sequence, as a protobuf scalar value.
    #[prost(message, tag = "2")]
    multiplier: Option<vortex_proto::scalar::ScalarValue>,
}
63
64/// Components of [`SequenceArray`].
65pub struct SequenceArrayParts {
66    pub base: PValue,
67    pub multiplier: PValue,
68    pub len: usize,
69    pub ptype: PType,
70    pub nullability: Nullability,
71}
72
/// An array representing the equation `A[i] = base + i * multiplier`.
///
/// No element values are stored: every element is computed on demand from `base`,
/// `multiplier` and the index.
#[derive(Clone, Debug)]
pub struct SequenceArray {
    /// Value of the first element (`A[0]`).
    base: PValue,
    /// Difference between consecutive elements.
    multiplier: PValue,
    /// Always a `DType::Primitive` built from the element ptype and nullability.
    dtype: DType,
    /// Number of elements in the sequence.
    pub(crate) len: usize,
    /// Statistics for the array; seeded at construction with `IsSorted` and `IsStrictSorted`.
    stats_set: ArrayStats,
}
82
83impl SequenceArray {
84    pub fn typed_new<T: NativePType + Into<PValue>>(
85        base: T,
86        multiplier: T,
87        nullability: Nullability,
88        length: usize,
89    ) -> VortexResult<Self> {
90        Self::new(
91            base.into(),
92            multiplier.into(),
93            T::PTYPE,
94            nullability,
95            length,
96        )
97    }
98
99    /// Constructs a sequence array using two integer values (with the same ptype).
100    pub fn new(
101        base: PValue,
102        multiplier: PValue,
103        ptype: PType,
104        nullability: Nullability,
105        length: usize,
106    ) -> VortexResult<Self> {
107        if !ptype.is_int() {
108            vortex_bail!("only integer ptype are supported in SequenceArray currently")
109        }
110
111        Self::try_last(base, multiplier, ptype, length).map_err(|e| {
112            e.with_context(format!(
113                "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ",
114            ))
115        })?;
116
117        Ok(Self::unchecked_new(
118            base,
119            multiplier,
120            ptype,
121            nullability,
122            length,
123        ))
124    }
125
126    pub(crate) fn unchecked_new(
127        base: PValue,
128        multiplier: PValue,
129        ptype: PType,
130        nullability: Nullability,
131        length: usize,
132    ) -> Self {
133        let dtype = DType::Primitive(ptype, nullability);
134
135        // A sequence A[i] = base + i * multiplier is sorted iff multiplier >= 0,
136        // and strictly sorted iff multiplier > 0.
137
138        let (is_sorted, is_strict_sorted) = match_each_pvalue!(
139            multiplier,
140            uint: |v| { (true, v> 0) },
141            int: |v| { (v >= 0, v > 0) },
142            float: |_v| { unreachable!("float multiplier not supported") }
143        );
144
145        // SAFETY: we don't have duplicate stats
146        let stats_set = unsafe {
147            StatsSet::new_unchecked(vec![
148                (Stat::IsSorted, StatPrecision::Exact(is_sorted.into())),
149                (
150                    Stat::IsStrictSorted,
151                    StatPrecision::Exact(is_strict_sorted.into()),
152                ),
153            ])
154        };
155
156        Self {
157            base,
158            multiplier,
159            dtype,
160            len: length,
161            stats_set: ArrayStats::from(stats_set),
162        }
163    }
164
165    pub fn ptype(&self) -> PType {
166        self.dtype.as_ptype()
167    }
168
169    pub fn base(&self) -> PValue {
170        self.base
171    }
172
173    pub fn multiplier(&self) -> PValue {
174        self.multiplier
175    }
176
177    pub(crate) fn try_last(
178        base: PValue,
179        multiplier: PValue,
180        ptype: PType,
181        length: usize,
182    ) -> VortexResult<PValue> {
183        match_each_integer_ptype!(ptype, |P| {
184            let len_t = <P>::from_usize(length - 1)
185                .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
186
187            let base = base.cast::<P>()?;
188            let multiplier = multiplier.cast::<P>()?;
189            let last = len_t
190                .checked_mul(multiplier)
191                .and_then(|offset| offset.checked_add(base))
192                .ok_or_else(|| vortex_err!("last value computation overflows"))?;
193            Ok(PValue::from(last))
194        })
195    }
196
197    pub(crate) fn index_value(&self, idx: usize) -> PValue {
198        assert!(idx < self.len, "index_value({idx}): index out of bounds");
199
200        match_each_native_ptype!(self.ptype(), |P| {
201            let base = self.base.cast::<P>().vortex_expect("must be able to cast");
202            let multiplier = self
203                .multiplier
204                .cast::<P>()
205                .vortex_expect("must be able to cast");
206            let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
207
208            PValue::from(value)
209        })
210    }
211
212    /// Returns the validated final value of a sequence array
213    pub fn last(&self) -> PValue {
214        Self::try_last(self.base, self.multiplier, self.ptype(), self.len)
215            .vortex_expect("validated array")
216    }
217
218    pub fn into_parts(self) -> SequenceArrayParts {
219        SequenceArrayParts {
220            base: self.base,
221            multiplier: self.multiplier,
222            len: self.len,
223            ptype: self.dtype.as_ptype(),
224            nullability: self.dtype.nullability(),
225        }
226    }
227}
228
229impl VTable for SequenceVTable {
230    type Array = SequenceArray;
231
232    type Metadata = ProstMetadata<SequenceMetadata>;
233
234    type ArrayVTable = Self;
235    type OperationsVTable = Self;
236    type ValidityVTable = Self;
237    type VisitorVTable = Self;
238
239    fn id(_array: &Self::Array) -> ArrayId {
240        Self::ID
241    }
242
243    fn metadata(array: &SequenceArray) -> VortexResult<Self::Metadata> {
244        Ok(ProstMetadata(SequenceMetadata {
245            base: Some((&array.base()).into()),
246            multiplier: Some((&array.multiplier()).into()),
247        }))
248    }
249
250    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
251        Ok(Some(metadata.serialize()))
252    }
253
254    fn deserialize(
255        bytes: &[u8],
256        _dtype: &DType,
257        _len: usize,
258        _buffers: &[BufferHandle],
259        _session: &VortexSession,
260    ) -> VortexResult<Self::Metadata> {
261        Ok(ProstMetadata(
262            <ProstMetadata<SequenceMetadata> as DeserializeMetadata>::deserialize(bytes)?,
263        ))
264    }
265
266    fn build(
267        dtype: &DType,
268        len: usize,
269        metadata: &Self::Metadata,
270        _buffers: &[BufferHandle],
271        _children: &dyn ArrayChildren,
272    ) -> VortexResult<SequenceArray> {
273        let ptype = dtype.as_ptype();
274
275        // We go via scalar to cast the scalar values into the correct PType
276        let base = Scalar::from_proto_value(
277            metadata
278                .0
279                .base
280                .as_ref()
281                .ok_or_else(|| vortex_err!("base required"))?,
282            &DType::Primitive(ptype, NonNullable),
283        )?
284        .as_primitive()
285        .pvalue()
286        .vortex_expect("non-nullable primitive");
287
288        let multiplier = Scalar::from_proto_value(
289            metadata
290                .0
291                .multiplier
292                .as_ref()
293                .ok_or_else(|| vortex_err!("multiplier required"))?,
294            &DType::Primitive(ptype, NonNullable),
295        )?
296        .as_primitive()
297        .pvalue()
298        .vortex_expect("non-nullable primitive");
299
300        Ok(SequenceArray::unchecked_new(
301            base,
302            multiplier,
303            ptype,
304            dtype.nullability(),
305            len,
306        ))
307    }
308
309    fn with_children(_array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
310        vortex_ensure!(
311            children.is_empty(),
312            "SequenceArray expects 0 children, got {}",
313            children.len()
314        );
315        Ok(())
316    }
317
318    fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
319        let prim = match_each_native_ptype!(array.ptype(), |P| {
320            let base = array.base().cast::<P>()?;
321            let multiplier = array.multiplier().cast::<P>()?;
322            let values = BufferMut::from_iter(
323                (0..array.len())
324                    .map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
325            );
326            PrimitiveArray::new(values, array.dtype.nullability().into())
327        });
328
329        Ok(prim.into_array())
330    }
331
332    fn execute_parent(
333        array: &Self::Array,
334        parent: &ArrayRef,
335        child_idx: usize,
336        ctx: &mut ExecutionCtx,
337    ) -> VortexResult<Option<ArrayRef>> {
338        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
339    }
340
341    fn reduce_parent(
342        array: &SequenceArray,
343        parent: &ArrayRef,
344        child_idx: usize,
345    ) -> VortexResult<Option<ArrayRef>> {
346        RULES.evaluate(array, parent, child_idx)
347    }
348}
349
350impl BaseArrayVTable<SequenceVTable> for SequenceVTable {
351    fn len(array: &SequenceArray) -> usize {
352        array.len
353    }
354
355    fn dtype(array: &SequenceArray) -> &DType {
356        &array.dtype
357    }
358
359    fn stats(array: &SequenceArray) -> StatsSetRef<'_> {
360        array.stats_set.to_ref(array.as_ref())
361    }
362
363    fn array_hash<H: std::hash::Hasher>(
364        array: &SequenceArray,
365        state: &mut H,
366        _precision: Precision,
367    ) {
368        array.base.hash(state);
369        array.multiplier.hash(state);
370        array.dtype.hash(state);
371        array.len.hash(state);
372    }
373
374    fn array_eq(array: &SequenceArray, other: &SequenceArray, _precision: Precision) -> bool {
375        array.base == other.base
376            && array.multiplier == other.multiplier
377            && array.dtype == other.dtype
378            && array.len == other.len
379    }
380}
381
382impl OperationsVTable<SequenceVTable> for SequenceVTable {
383    fn scalar_at(array: &SequenceArray, index: usize) -> VortexResult<Scalar> {
384        Scalar::try_new(
385            array.dtype().clone(),
386            Some(ScalarValue::Primitive(array.index_value(index))),
387        )
388    }
389}
390
impl ValidityVTable<SequenceVTable> for SequenceVTable {
    /// Always reports every element as valid; the array itself is not inspected.
    fn validity(_array: &SequenceArray) -> VortexResult<Validity> {
        Ok(Validity::AllValid)
    }
}
396
impl VisitorVTable<SequenceVTable> for SequenceVTable {
    /// Visits nothing: the array exposes no buffers.
    fn visit_buffers(_array: &SequenceArray, _visitor: &mut dyn ArrayBufferVisitor) {
        // TODO(joe): expose scalar values
    }

    /// A sequence array reports zero buffers.
    fn nbuffers(_array: &SequenceArray) -> usize {
        0
    }

    /// Visits nothing: the array has no children.
    fn visit_children(_array: &SequenceArray, _visitor: &mut dyn ArrayChildVisitor) {}

    /// A sequence array reports zero children.
    fn nchildren(_array: &SequenceArray) -> usize {
        0
    }
}
412
/// Marker type on which the vtable traits for [`SequenceArray`] are implemented.
#[derive(Debug)]
pub struct SequenceVTable;

impl SequenceVTable {
    /// Identifier of the sequence encoding.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.sequence");
}
419
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::expr::stats::Precision as StatPrecision;
    use vortex_array::expr::stats::Stat;
    use vortex_array::expr::stats::StatsProviderExt;
    use vortex_array::scalar::Scalar;
    use vortex_array::scalar::ScalarValue;
    use vortex_dtype::Nullability;
    use vortex_error::VortexResult;

    use crate::array::SequenceArray;

    /// Asserts that `arr` reports exactly the given `IsSorted` / `IsStrictSorted` statistics.
    fn assert_sortedness(arr: &SequenceArray, sorted: bool, strict_sorted: bool) {
        let is_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsSorted));
        let is_strict_sorted = arr
            .statistics()
            .with_typed_stats_set(|s| s.get_as::<bool>(Stat::IsStrictSorted));

        assert_eq!(is_sorted, Some(StatPrecision::Exact(sorted)));
        assert_eq!(is_strict_sorted, Some(StatPrecision::Exact(strict_sorted)));
    }

    #[test]
    fn test_sequence_canonical() {
        let sequence = SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4).unwrap();

        let expected = PrimitiveArray::from_iter((0..4).map(|idx| 2i64 + idx * 3));

        assert_arrays_eq!(sequence, expected);
    }

    #[test]
    fn test_sequence_slice_canonical() {
        let sequence = SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4).unwrap();
        let sliced = sequence.slice(2..3).unwrap();

        let expected = PrimitiveArray::from_iter((2..3).map(|idx| 2i64 + idx * 3));

        assert_arrays_eq!(sliced, expected);
    }

    #[test]
    fn test_sequence_scalar_at() {
        let sequence = SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4).unwrap();
        let scalar = sequence.scalar_at(2).unwrap();

        // 2 + 2 * 3 == 8
        let expected =
            Scalar::try_new(scalar.dtype().clone(), Some(ScalarValue::from(8i64))).unwrap();
        assert_eq!(scalar, expected)
    }

    #[test]
    fn test_sequence_min_max() {
        // The last element lands exactly on i8::MIN, or stays within range.
        for (base, multiplier) in [(-127i8, -1i8), (126i8, -1i8)] {
            assert!(
                SequenceArray::typed_new(base, multiplier, Nullability::NonNullable, 2).is_ok()
            );
        }
    }

    #[test]
    fn test_sequence_too_big() {
        // The last element would be one past i8::MAX / i8::MIN.
        for (base, multiplier) in [(127i8, 1i8), (-128i8, -1i8)] {
            assert!(
                SequenceArray::typed_new(base, multiplier, Nullability::NonNullable, 2).is_err()
            );
        }
    }

    #[test]
    fn positive_multiplier_is_strict_sorted() -> VortexResult<()> {
        let arr = SequenceArray::typed_new(0i64, 3, Nullability::NonNullable, 4)?;
        assert_sortedness(&arr, true, true);
        Ok(())
    }

    #[test]
    fn zero_multiplier_is_sorted_not_strict() -> VortexResult<()> {
        let arr = SequenceArray::typed_new(5i64, 0, Nullability::NonNullable, 4)?;
        assert_sortedness(&arr, true, false);
        Ok(())
    }

    #[test]
    fn negative_multiplier_not_sorted() -> VortexResult<()> {
        let arr = SequenceArray::typed_new(10i64, -1, Nullability::NonNullable, 4)?;
        assert_sortedness(&arr, false, false);
        Ok(())
    }

    // This is a regression test for an issue caught by the fuzzer, where SequenceArrays with
    // multiplier > i64::MAX were unable to be constructed.
    #[test]
    fn test_large_multiplier_sorted() -> VortexResult<()> {
        let large_multiplier = (i64::MAX as u64) + 1;
        let arr = SequenceArray::typed_new(0, large_multiplier, Nullability::NonNullable, 2)?;
        assert_sortedness(&arr, true, true);
        Ok(())
    }
}