vortex_sequence/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5use std::ops::Range;
6
7use num_traits::cast::FromPrimitive;
8use vortex_array::ArrayBufferVisitor;
9use vortex_array::ArrayChildVisitor;
10use vortex_array::ArrayRef;
11use vortex_array::Canonical;
12use vortex_array::DeserializeMetadata;
13use vortex_array::ExecutionCtx;
14use vortex_array::Precision;
15use vortex_array::ProstMetadata;
16use vortex_array::SerializeMetadata;
17use vortex_array::arrays::FilterVTable;
18use vortex_array::arrays::PrimitiveArray;
19use vortex_array::buffer::BufferHandle;
20use vortex_array::serde::ArrayChildren;
21use vortex_array::stats::ArrayStats;
22use vortex_array::stats::StatsSetRef;
23use vortex_array::validity::Validity;
24use vortex_array::vtable;
25use vortex_array::vtable::ArrayId;
26use vortex_array::vtable::ArrayVTable;
27use vortex_array::vtable::ArrayVTableExt;
28use vortex_array::vtable::BaseArrayVTable;
29use vortex_array::vtable::CanonicalVTable;
30use vortex_array::vtable::EncodeVTable;
31use vortex_array::vtable::NotSupported;
32use vortex_array::vtable::OperationsVTable;
33use vortex_array::vtable::VTable;
34use vortex_array::vtable::ValidityVTable;
35use vortex_array::vtable::VisitorVTable;
36use vortex_buffer::BufferMut;
37use vortex_dtype::DType;
38use vortex_dtype::NativePType;
39use vortex_dtype::Nullability;
40use vortex_dtype::Nullability::NonNullable;
41use vortex_dtype::PType;
42use vortex_dtype::match_each_integer_ptype;
43use vortex_dtype::match_each_native_ptype;
44use vortex_error::VortexExpect;
45use vortex_error::VortexResult;
46use vortex_error::vortex_bail;
47use vortex_error::vortex_ensure;
48use vortex_error::vortex_err;
49use vortex_mask::AllOr;
50use vortex_mask::Mask;
51use vortex_scalar::PValue;
52use vortex_scalar::Scalar;
53use vortex_scalar::ScalarValue;
54use vortex_vector::Vector;
55use vortex_vector::VectorMut;
56use vortex_vector::VectorMutOps;
57use vortex_vector::primitive::PVector;
58
// Invokes the `vortex_array::vtable!` macro for the `Sequence` encoding.
// NOTE(review): presumably this declares the glue between `SequenceArray` and
// `SequenceVTable` — confirm against the macro definition.
vtable!(Sequence);
60
/// Protobuf metadata stored alongside a serialized [`SequenceArray`].
///
/// Both fields are optional on the wire; `SequenceVTable::build` returns an error when
/// either one is absent.
#[derive(Clone, prost::Message)]
pub struct SequenceMetadata {
    /// Scalar value of the sequence's `base` (the element at index 0).
    #[prost(message, tag = "1")]
    base: Option<vortex_proto::scalar::ScalarValue>,
    /// Scalar value of the sequence's `multiplier` (the step between consecutive elements).
    #[prost(message, tag = "2")]
    multiplier: Option<vortex_proto::scalar::ScalarValue>,
}
68
69#[derive(Clone, Debug)]
70/// An array representing the equation `A[i] = base + i * multiplier`.
71pub struct SequenceArray {
72    base: PValue,
73    multiplier: PValue,
74    dtype: DType,
75    pub(crate) length: usize,
76    stats_set: ArrayStats,
77}
78
79impl SequenceArray {
80    pub fn typed_new<T: NativePType + Into<PValue>>(
81        base: T,
82        multiplier: T,
83        nullability: Nullability,
84        length: usize,
85    ) -> VortexResult<Self> {
86        Self::new(
87            base.into(),
88            multiplier.into(),
89            T::PTYPE,
90            nullability,
91            length,
92        )
93    }
94
95    /// Constructs a sequence array using two integer values (with the same ptype).
96    pub fn new(
97        base: PValue,
98        multiplier: PValue,
99        ptype: PType,
100        nullability: Nullability,
101        length: usize,
102    ) -> VortexResult<Self> {
103        if !ptype.is_int() {
104            vortex_bail!("only integer ptype are supported in SequenceArray currently")
105        }
106
107        Self::try_last(base, multiplier, ptype, length).map_err(|e| {
108            e.with_context(format!(
109                "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ",
110            ))
111        })?;
112
113        Ok(Self::unchecked_new(
114            base,
115            multiplier,
116            ptype,
117            nullability,
118            length,
119        ))
120    }
121
122    pub(crate) fn unchecked_new(
123        base: PValue,
124        multiplier: PValue,
125        ptype: PType,
126        nullability: Nullability,
127        length: usize,
128    ) -> Self {
129        let dtype = DType::Primitive(ptype, nullability);
130        Self {
131            base,
132            multiplier,
133            dtype,
134            length,
135            // TODO(joe): add stats, on construct or on use?
136            stats_set: Default::default(),
137        }
138    }
139
140    pub fn ptype(&self) -> PType {
141        self.dtype.as_ptype()
142    }
143
144    pub fn base(&self) -> PValue {
145        self.base
146    }
147
148    pub fn multiplier(&self) -> PValue {
149        self.multiplier
150    }
151
152    pub(crate) fn try_last(
153        base: PValue,
154        multiplier: PValue,
155        ptype: PType,
156        length: usize,
157    ) -> VortexResult<PValue> {
158        match_each_integer_ptype!(ptype, |P| {
159            let len_t = <P>::from_usize(length - 1)
160                .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
161
162            let base = base.cast::<P>();
163            let multiplier = multiplier.cast::<P>();
164
165            let last = len_t
166                .checked_mul(multiplier)
167                .and_then(|offset| offset.checked_add(base))
168                .ok_or_else(|| vortex_err!("last value computation overflows"))?;
169            Ok(PValue::from(last))
170        })
171    }
172
173    pub(crate) fn index_value(&self, idx: usize) -> PValue {
174        assert!(idx < self.length, "index_value({idx}): index out of bounds");
175
176        match_each_native_ptype!(self.ptype(), |P| {
177            let base = self.base.cast::<P>();
178            let multiplier = self.multiplier.cast::<P>();
179            let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
180
181            PValue::from(value)
182        })
183    }
184
185    /// Returns the validated final value of a sequence array
186    pub fn last(&self) -> PValue {
187        Self::try_last(self.base, self.multiplier, self.ptype(), self.length)
188            .vortex_expect("validated array")
189    }
190}
191
192impl VTable for SequenceVTable {
193    type Array = SequenceArray;
194
195    type Metadata = ProstMetadata<SequenceMetadata>;
196
197    type ArrayVTable = Self;
198    type CanonicalVTable = Self;
199    type OperationsVTable = Self;
200    type ValidityVTable = Self;
201    type VisitorVTable = Self;
202    type ComputeVTable = NotSupported;
203    type EncodeVTable = Self;
204
205    fn id(&self) -> ArrayId {
206        ArrayId::new_ref("vortex.sequence")
207    }
208
209    fn encoding(_array: &Self::Array) -> ArrayVTable {
210        SequenceVTable.as_vtable()
211    }
212
213    fn metadata(array: &SequenceArray) -> VortexResult<Self::Metadata> {
214        Ok(ProstMetadata(SequenceMetadata {
215            base: Some((&array.base()).into()),
216            multiplier: Some((&array.multiplier()).into()),
217        }))
218    }
219
220    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
221        Ok(Some(metadata.serialize()))
222    }
223
224    fn deserialize(buffer: &[u8]) -> VortexResult<Self::Metadata> {
225        Ok(ProstMetadata(
226            <ProstMetadata<SequenceMetadata> as DeserializeMetadata>::deserialize(buffer)?,
227        ))
228    }
229
230    fn build(
231        &self,
232        dtype: &DType,
233        len: usize,
234        metadata: &Self::Metadata,
235        _buffers: &[BufferHandle],
236        _children: &dyn ArrayChildren,
237    ) -> VortexResult<SequenceArray> {
238        let ptype = dtype.as_ptype();
239
240        // We go via scalar to cast the scalar values into the correct PType
241        let base = Scalar::new(
242            DType::Primitive(ptype, NonNullable),
243            metadata
244                .0
245                .base
246                .as_ref()
247                .ok_or_else(|| vortex_err!("base required"))?
248                .try_into()?,
249        )
250        .as_primitive()
251        .pvalue()
252        .vortex_expect("non-nullable primitive");
253
254        let multiplier = Scalar::new(
255            DType::Primitive(ptype, NonNullable),
256            metadata
257                .0
258                .multiplier
259                .as_ref()
260                .ok_or_else(|| vortex_err!("base required"))?
261                .try_into()?,
262        )
263        .as_primitive()
264        .pvalue()
265        .vortex_expect("non-nullable primitive");
266
267        Ok(SequenceArray::unchecked_new(
268            base,
269            multiplier,
270            ptype,
271            dtype.nullability(),
272            len,
273        ))
274    }
275
276    fn with_children(_array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
277        vortex_ensure!(
278            children.is_empty(),
279            "SequenceArray expects 0 children, got {}",
280            children.len()
281        );
282        Ok(())
283    }
284
285    fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<Vector> {
286        let array = array.clone();
287
288        Ok(match_each_native_ptype!(array.ptype(), |P| {
289            let base = array.base().cast::<P>();
290            let multiplier = array.multiplier().cast::<P>();
291
292            execute_iter(base, multiplier, 0..array.len(), array.len()).into()
293        }))
294    }
295
296    fn execute_parent(
297        array: &Self::Array,
298        parent: &ArrayRef,
299        _child_idx: usize,
300        _ctx: &mut ExecutionCtx,
301    ) -> VortexResult<Option<Vector>> {
302        // Special-case filtered execution.
303        let Some(filter) = parent.as_opt::<FilterVTable>() else {
304            return Ok(None);
305        };
306
307        match filter.filter_mask().indices() {
308            AllOr::All => Ok(None),
309            AllOr::None => Ok(Some(VectorMut::with_capacity(array.dtype(), 0).freeze())),
310            AllOr::Some(indices) => Ok(Some(match_each_native_ptype!(array.ptype(), |P| {
311                let base = array.base().cast::<P>();
312                let multiplier = array.multiplier().cast::<P>();
313                execute_iter(base, multiplier, indices.iter().copied(), indices.len()).into()
314            }))),
315        }
316    }
317}
318
319fn execute_iter<P: NativePType, I: Iterator<Item = usize>>(
320    base: P,
321    multiplier: P,
322    iter: I,
323    len: usize,
324) -> PVector<P> {
325    let values = if multiplier == <P>::one() {
326        BufferMut::from_iter(iter.map(|i| base + <P>::from_usize(i).vortex_expect("must fit")))
327    } else {
328        BufferMut::from_iter(
329            iter.map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
330        )
331    };
332
333    PVector::<P>::new(values.freeze(), Mask::new_true(len))
334}
335
336impl BaseArrayVTable<SequenceVTable> for SequenceVTable {
337    fn len(array: &SequenceArray) -> usize {
338        array.length
339    }
340
341    fn dtype(array: &SequenceArray) -> &DType {
342        &array.dtype
343    }
344
345    fn stats(array: &SequenceArray) -> StatsSetRef<'_> {
346        array.stats_set.to_ref(array.as_ref())
347    }
348
349    fn array_hash<H: std::hash::Hasher>(
350        array: &SequenceArray,
351        state: &mut H,
352        _precision: Precision,
353    ) {
354        array.base.hash(state);
355        array.multiplier.hash(state);
356        array.dtype.hash(state);
357        array.length.hash(state);
358    }
359
360    fn array_eq(array: &SequenceArray, other: &SequenceArray, _precision: Precision) -> bool {
361        array.base == other.base
362            && array.multiplier == other.multiplier
363            && array.dtype == other.dtype
364            && array.length == other.length
365    }
366}
367
368impl CanonicalVTable<SequenceVTable> for SequenceVTable {
369    fn canonicalize(array: &SequenceArray) -> Canonical {
370        let prim = match_each_native_ptype!(array.ptype(), |P| {
371            let base = array.base().cast::<P>();
372            let multiplier = array.multiplier().cast::<P>();
373            let values = BufferMut::from_iter(
374                (0..array.len())
375                    .map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
376            );
377            PrimitiveArray::new(values, array.dtype.nullability().into())
378        });
379
380        Canonical::Primitive(prim)
381    }
382}
383
384impl OperationsVTable<SequenceVTable> for SequenceVTable {
385    fn slice(array: &SequenceArray, range: Range<usize>) -> ArrayRef {
386        SequenceArray::unchecked_new(
387            array.index_value(range.start),
388            array.multiplier,
389            array.ptype(),
390            array.dtype().nullability(),
391            range.len(),
392        )
393        .to_array()
394    }
395
396    fn scalar_at(array: &SequenceArray, index: usize) -> Scalar {
397        Scalar::new(
398            array.dtype().clone(),
399            ScalarValue::from(array.index_value(index)),
400        )
401    }
402}
403
404impl ValidityVTable<SequenceVTable> for SequenceVTable {
405    fn is_valid(_array: &SequenceArray, _index: usize) -> bool {
406        true
407    }
408
409    fn all_valid(_array: &SequenceArray) -> bool {
410        true
411    }
412
413    fn all_invalid(_array: &SequenceArray) -> bool {
414        false
415    }
416
417    fn validity(_array: &SequenceArray) -> VortexResult<Validity> {
418        Ok(Validity::AllValid)
419    }
420
421    fn validity_mask(array: &SequenceArray) -> Mask {
422        Mask::AllTrue(array.len())
423    }
424}
425
426impl VisitorVTable<SequenceVTable> for SequenceVTable {
427    fn visit_buffers(_array: &SequenceArray, _visitor: &mut dyn ArrayBufferVisitor) {
428        // TODO(joe): expose scalar values
429    }
430
431    fn visit_children(_array: &SequenceArray, _visitor: &mut dyn ArrayChildVisitor) {}
432}
433
/// Zero-sized vtable type for the `vortex.sequence` encoding.
#[derive(Debug)]
pub struct SequenceVTable;
436
437impl EncodeVTable<SequenceVTable> for SequenceVTable {
438    fn encode(
439        _vtable: &SequenceVTable,
440        _canonical: &Canonical,
441        _like: Option<&SequenceArray>,
442    ) -> VortexResult<Option<SequenceArray>> {
443        // TODO(joe): hook up compressor
444        Ok(None)
445    }
446}
447
#[cfg(test)]
mod tests {
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;
    use vortex_scalar::ScalarValue;

    use crate::array::SequenceArray;

    /// The four-element `i64` sequence `2, 5, 8, 11` shared by several tests.
    fn sample() -> SequenceArray {
        SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4).unwrap()
    }

    #[test]
    fn test_sequence_canonical() {
        let expected = PrimitiveArray::from_iter((0..4).map(|i| 2i64 + i * 3));
        let actual = sample().to_primitive();

        assert_eq!(actual.as_slice::<i64>(), expected.as_slice::<i64>())
    }

    #[test]
    fn test_sequence_slice_canonical() {
        let sliced = sample().slice(2..3);

        let expected = PrimitiveArray::from_iter((2..3).map(|i| 2i64 + i * 3));
        let actual = sliced.to_primitive();

        assert_eq!(actual.as_slice::<i64>(), expected.as_slice::<i64>())
    }

    #[test]
    fn test_sequence_scalar_at() {
        let scalar = sample().scalar_at(2);
        let expected = Scalar::new(scalar.dtype().clone(), ScalarValue::from(8i64));

        assert_eq!(scalar, expected)
    }

    #[test]
    fn test_sequence_min_max() {
        // Last values land exactly on, or just inside, the `i8` bounds.
        let down_to_min = SequenceArray::typed_new(-127i8, -1i8, Nullability::NonNullable, 2);
        let down_from_near_max = SequenceArray::typed_new(126i8, -1i8, Nullability::NonNullable, 2);

        assert!(down_to_min.is_ok());
        assert!(down_from_near_max.is_ok());
    }

    #[test]
    fn test_sequence_too_big() {
        // Last values step one past the `i8` bounds.
        let past_max = SequenceArray::typed_new(127i8, 1i8, Nullability::NonNullable, 2);
        let past_min = SequenceArray::typed_new(-128i8, -1i8, Nullability::NonNullable, 2);

        assert!(past_max.is_err());
        assert!(past_min.is_err());
    }
}
507}