vortex_sequence/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5use std::ops::Range;
6
7use num_traits::One;
8use num_traits::cast::FromPrimitive;
9use vortex_array::ArrayBufferVisitor;
10use vortex_array::ArrayChildVisitor;
11use vortex_array::ArrayRef;
12use vortex_array::Canonical;
13use vortex_array::DeserializeMetadata;
14use vortex_array::Precision;
15use vortex_array::ProstMetadata;
16use vortex_array::SerializeMetadata;
17use vortex_array::arrays::PrimitiveArray;
18use vortex_array::execution::ExecutionCtx;
19use vortex_array::serde::ArrayChildren;
20use vortex_array::stats::ArrayStats;
21use vortex_array::stats::StatsSetRef;
22use vortex_array::vtable;
23use vortex_array::vtable::ArrayId;
24use vortex_array::vtable::ArrayVTable;
25use vortex_array::vtable::ArrayVTableExt;
26use vortex_array::vtable::BaseArrayVTable;
27use vortex_array::vtable::CanonicalVTable;
28use vortex_array::vtable::EncodeVTable;
29use vortex_array::vtable::NotSupported;
30use vortex_array::vtable::OperationsVTable;
31use vortex_array::vtable::VTable;
32use vortex_array::vtable::ValidityVTable;
33use vortex_array::vtable::VisitorVTable;
34use vortex_buffer::BufferHandle;
35use vortex_buffer::BufferMut;
36use vortex_dtype::DType;
37use vortex_dtype::NativePType;
38use vortex_dtype::Nullability;
39use vortex_dtype::Nullability::NonNullable;
40use vortex_dtype::PType;
41use vortex_dtype::match_each_integer_ptype;
42use vortex_dtype::match_each_native_ptype;
43use vortex_error::VortexExpect;
44use vortex_error::VortexResult;
45use vortex_error::vortex_bail;
46use vortex_error::vortex_err;
47use vortex_mask::Mask;
48use vortex_scalar::PValue;
49use vortex_scalar::Scalar;
50use vortex_scalar::ScalarValue;
51use vortex_vector::Vector;
52use vortex_vector::primitive::PVector;
53
// Invokes the `vortex_array::vtable!` macro for the `Sequence` encoding.
// NOTE(review): presumably this generates the glue between `SequenceArray` and `SequenceVTable`
// (e.g. the `to_array`/`as_ref` conversions used further down) — confirm against the macro.
vtable!(Sequence);
55
/// Serialized parameters of a [`SequenceArray`].
///
/// The dtype and length are not stored here; `build` receives them separately. Both fields are
/// `Option` because they are prost message fields, and `build` rejects a missing value.
#[derive(Clone, prost::Message)]
pub struct SequenceMetadata {
    /// The first element of the sequence (`A[0]`).
    #[prost(message, tag = "1")]
    base: Option<vortex_proto::scalar::ScalarValue>,
    /// The difference between consecutive elements.
    #[prost(message, tag = "2")]
    multiplier: Option<vortex_proto::scalar::ScalarValue>,
}
63
/// An array representing the equation `A[i] = base + i * multiplier`.
///
/// No element buffer is stored: values are computed from `base` and `multiplier` when accessed
/// or canonicalized.
#[derive(Clone, Debug)]
pub struct SequenceArray {
    /// `A[0]`. Not normalised to the dtype's ptype; readers `cast` it to the ptype on use.
    base: PValue,
    /// The step added between consecutive elements; also `cast` to the ptype on use.
    multiplier: PValue,
    /// Always `DType::Primitive(ptype, nullability)` (constructed in `unchecked_new`).
    dtype: DType,
    /// The number of elements in the sequence.
    pub(crate) length: usize,
    /// Statistics cache; starts out empty (`Default`) when the array is constructed.
    stats_set: ArrayStats,
}
73
74impl SequenceArray {
75    pub fn typed_new<T: NativePType + Into<PValue>>(
76        base: T,
77        multiplier: T,
78        nullability: Nullability,
79        length: usize,
80    ) -> VortexResult<Self> {
81        Self::new(
82            base.into(),
83            multiplier.into(),
84            T::PTYPE,
85            nullability,
86            length,
87        )
88    }
89
90    /// Constructs a sequence array using two integer values (with the same ptype).
91    pub fn new(
92        base: PValue,
93        multiplier: PValue,
94        ptype: PType,
95        nullability: Nullability,
96        length: usize,
97    ) -> VortexResult<Self> {
98        if !ptype.is_int() {
99            vortex_bail!("only integer ptype are supported in SequenceArray currently")
100        }
101
102        Self::try_last(base, multiplier, ptype, length).map_err(|e| {
103            e.with_context(format!(
104                "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ",
105            ))
106        })?;
107
108        Ok(Self::unchecked_new(
109            base,
110            multiplier,
111            ptype,
112            nullability,
113            length,
114        ))
115    }
116
117    pub(crate) fn unchecked_new(
118        base: PValue,
119        multiplier: PValue,
120        ptype: PType,
121        nullability: Nullability,
122        length: usize,
123    ) -> Self {
124        let dtype = DType::Primitive(ptype, nullability);
125        Self {
126            base,
127            multiplier,
128            dtype,
129            length,
130            // TODO(joe): add stats, on construct or on use?
131            stats_set: Default::default(),
132        }
133    }
134
135    pub fn ptype(&self) -> PType {
136        self.dtype.as_ptype()
137    }
138
139    pub fn base(&self) -> PValue {
140        self.base
141    }
142
143    pub fn multiplier(&self) -> PValue {
144        self.multiplier
145    }
146
147    pub(crate) fn try_last(
148        base: PValue,
149        multiplier: PValue,
150        ptype: PType,
151        length: usize,
152    ) -> VortexResult<PValue> {
153        match_each_integer_ptype!(ptype, |P| {
154            let len_t = <P>::from_usize(length - 1)
155                .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
156
157            let base = base.cast::<P>();
158            let multiplier = multiplier.cast::<P>();
159
160            let last = len_t
161                .checked_mul(multiplier)
162                .and_then(|offset| offset.checked_add(base))
163                .ok_or_else(|| vortex_err!("last value computation overflows"))?;
164            Ok(PValue::from(last))
165        })
166    }
167
168    pub(crate) fn index_value(&self, idx: usize) -> PValue {
169        assert!(idx < self.length, "index_value({idx}): index out of bounds");
170
171        match_each_native_ptype!(self.ptype(), |P| {
172            let base = self.base.cast::<P>();
173            let multiplier = self.multiplier.cast::<P>();
174            let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
175
176            PValue::from(value)
177        })
178    }
179
180    /// Returns the validated final value of a sequence array
181    pub fn last(&self) -> PValue {
182        Self::try_last(self.base, self.multiplier, self.ptype(), self.length)
183            .vortex_expect("validated array")
184    }
185}
186
187impl VTable for SequenceVTable {
188    type Array = SequenceArray;
189
190    type Metadata = ProstMetadata<SequenceMetadata>;
191
192    type ArrayVTable = Self;
193    type CanonicalVTable = Self;
194    type OperationsVTable = Self;
195    type ValidityVTable = Self;
196    type VisitorVTable = Self;
197    type ComputeVTable = NotSupported;
198    type EncodeVTable = Self;
199
200    fn id(&self) -> ArrayId {
201        ArrayId::new_ref("vortex.sequence")
202    }
203
204    fn encoding(_array: &Self::Array) -> ArrayVTable {
205        SequenceVTable.as_vtable()
206    }
207
208    fn metadata(array: &SequenceArray) -> VortexResult<Self::Metadata> {
209        Ok(ProstMetadata(SequenceMetadata {
210            base: Some((&array.base()).into()),
211            multiplier: Some((&array.multiplier()).into()),
212        }))
213    }
214
215    fn serialize(metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
216        Ok(Some(metadata.serialize()))
217    }
218
219    fn deserialize(buffer: &[u8]) -> VortexResult<Self::Metadata> {
220        Ok(ProstMetadata(
221            <ProstMetadata<SequenceMetadata> as DeserializeMetadata>::deserialize(buffer)?,
222        ))
223    }
224
225    fn build(
226        &self,
227        dtype: &DType,
228        len: usize,
229        metadata: &Self::Metadata,
230        _buffers: &[BufferHandle],
231        _children: &dyn ArrayChildren,
232    ) -> VortexResult<SequenceArray> {
233        let ptype = dtype.as_ptype();
234
235        // We go via scalar to cast the scalar values into the correct PType
236        let base = Scalar::new(
237            DType::Primitive(ptype, NonNullable),
238            metadata
239                .0
240                .base
241                .as_ref()
242                .ok_or_else(|| vortex_err!("base required"))?
243                .try_into()?,
244        )
245        .as_primitive()
246        .pvalue()
247        .vortex_expect("non-nullable primitive");
248
249        let multiplier = Scalar::new(
250            DType::Primitive(ptype, NonNullable),
251            metadata
252                .0
253                .multiplier
254                .as_ref()
255                .ok_or_else(|| vortex_err!("base required"))?
256                .try_into()?,
257        )
258        .as_primitive()
259        .pvalue()
260        .vortex_expect("non-nullable primitive");
261
262        Ok(SequenceArray::unchecked_new(
263            base,
264            multiplier,
265            ptype,
266            dtype.nullability(),
267            len,
268        ))
269    }
270
271    fn batch_execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<Vector> {
272        Ok(match_each_native_ptype!(array.ptype(), |P| {
273            let base = array.base().cast::<P>();
274            let multiplier = array.multiplier().cast::<P>();
275
276            let values = if multiplier == <P>::one() {
277                BufferMut::from_iter(
278                    (0..array.len()).map(|i| base + <P>::from_usize(i).vortex_expect("must fit")),
279                )
280            } else {
281                BufferMut::from_iter(
282                    (0..array.len())
283                        .map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
284                )
285            };
286
287            PVector::<P>::new(values.freeze(), Mask::new_true(array.len())).into()
288        }))
289    }
290}
291
292impl BaseArrayVTable<SequenceVTable> for SequenceVTable {
293    fn len(array: &SequenceArray) -> usize {
294        array.length
295    }
296
297    fn dtype(array: &SequenceArray) -> &DType {
298        &array.dtype
299    }
300
301    fn stats(array: &SequenceArray) -> StatsSetRef<'_> {
302        array.stats_set.to_ref(array.as_ref())
303    }
304
305    fn array_hash<H: std::hash::Hasher>(
306        array: &SequenceArray,
307        state: &mut H,
308        _precision: Precision,
309    ) {
310        array.base.hash(state);
311        array.multiplier.hash(state);
312        array.dtype.hash(state);
313        array.length.hash(state);
314    }
315
316    fn array_eq(array: &SequenceArray, other: &SequenceArray, _precision: Precision) -> bool {
317        array.base == other.base
318            && array.multiplier == other.multiplier
319            && array.dtype == other.dtype
320            && array.length == other.length
321    }
322}
323
324impl CanonicalVTable<SequenceVTable> for SequenceVTable {
325    fn canonicalize(array: &SequenceArray) -> Canonical {
326        let prim = match_each_native_ptype!(array.ptype(), |P| {
327            let base = array.base().cast::<P>();
328            let multiplier = array.multiplier().cast::<P>();
329            let values = BufferMut::from_iter(
330                (0..array.len())
331                    .map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
332            );
333            PrimitiveArray::new(values, array.dtype.nullability().into())
334        });
335
336        Canonical::Primitive(prim)
337    }
338}
339
340impl OperationsVTable<SequenceVTable> for SequenceVTable {
341    fn slice(array: &SequenceArray, range: Range<usize>) -> ArrayRef {
342        SequenceArray::unchecked_new(
343            array.index_value(range.start),
344            array.multiplier,
345            array.ptype(),
346            array.dtype().nullability(),
347            range.len(),
348        )
349        .to_array()
350    }
351
352    fn scalar_at(array: &SequenceArray, index: usize) -> Scalar {
353        Scalar::new(
354            array.dtype().clone(),
355            ScalarValue::from(array.index_value(index)),
356        )
357    }
358}
359
360impl ValidityVTable<SequenceVTable> for SequenceVTable {
361    fn is_valid(_array: &SequenceArray, _index: usize) -> bool {
362        true
363    }
364
365    fn all_valid(_array: &SequenceArray) -> bool {
366        true
367    }
368
369    fn all_invalid(_array: &SequenceArray) -> bool {
370        false
371    }
372
373    fn validity_mask(array: &SequenceArray) -> Mask {
374        Mask::AllTrue(array.len())
375    }
376}
377
378impl VisitorVTable<SequenceVTable> for SequenceVTable {
379    fn visit_buffers(_array: &SequenceArray, _visitor: &mut dyn ArrayBufferVisitor) {
380        // TODO(joe): expose scalar values
381    }
382
383    fn visit_children(_array: &SequenceArray, _visitor: &mut dyn ArrayChildVisitor) {}
384}
385
/// The vtable type for [`SequenceArray`]; its encoding id is `vortex.sequence`.
#[derive(Debug)]
pub struct SequenceVTable;
388
389impl EncodeVTable<SequenceVTable> for SequenceVTable {
390    fn encode(
391        _vtable: &SequenceVTable,
392        _canonical: &Canonical,
393        _like: Option<&SequenceArray>,
394    ) -> VortexResult<Option<SequenceArray>> {
395        // TODO(joe): hook up compressor
396        Ok(None)
397    }
398}
399
#[cfg(test)]
mod tests {
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;
    use vortex_scalar::ScalarValue;

    use crate::array::SequenceArray;

    /// The four-element sequence `2, 5, 8, 11` (`base = 2`, `multiplier = 3`).
    fn sample_sequence() -> SequenceArray {
        SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4).unwrap()
    }

    #[test]
    fn test_sequence_canonical() {
        let expected = PrimitiveArray::from_iter((0..4).map(|i| 2i64 + i * 3));
        let actual = sample_sequence().to_primitive();

        assert_eq!(actual.as_slice::<i64>(), expected.as_slice::<i64>())
    }

    #[test]
    fn test_sequence_slice_canonical() {
        // Only index 2 survives the slice, i.e. the single value 8.
        let sliced = sample_sequence().slice(2..3);
        let expected = PrimitiveArray::from_iter((2..3).map(|i| 2i64 + i * 3));

        assert_eq!(
            sliced.to_primitive().as_slice::<i64>(),
            expected.as_slice::<i64>()
        )
    }

    #[test]
    fn test_sequence_scalar_at() {
        let scalar = sample_sequence().scalar_at(2);
        let expected = Scalar::new(scalar.dtype().clone(), ScalarValue::from(8i64));

        assert_eq!(scalar, expected)
    }

    #[test]
    fn test_sequence_min_max() {
        // Final values -128 and 125 both stay inside the i8 range.
        for (base, multiplier) in [(-127i8, -1i8), (126i8, -1i8)] {
            assert!(
                SequenceArray::typed_new(base, multiplier, Nullability::NonNullable, 2).is_ok()
            );
        }
    }

    #[test]
    fn test_sequence_too_big() {
        // 127 + 1 and -128 - 1 both leave the i8 range.
        for (base, multiplier) in [(127i8, 1i8), (-128i8, -1i8)] {
            assert!(
                SequenceArray::typed_new(base, multiplier, Nullability::NonNullable, 2).is_err()
            );
        }
    }
}