vortex_sequence/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::ops::Range;
5
6use num_traits::cast::FromPrimitive;
7use vortex_array::arrays::PrimitiveArray;
8use vortex_array::stats::{ArrayStats, StatsSetRef};
9use vortex_array::vtable::{
10    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityVTable,
11    VisitorVTable,
12};
13use vortex_array::{
14    ArrayBufferVisitor, ArrayChildVisitor, ArrayRef, Canonical, EncodingId, EncodingRef, vtable,
15};
16use vortex_buffer::BufferMut;
17use vortex_dtype::{
18    DType, NativePType, Nullability, PType, match_each_integer_ptype, match_each_native_ptype,
19};
20use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
21use vortex_mask::Mask;
22use vortex_scalar::{PValue, Scalar, ScalarValue};
23
24vtable!(Sequence);
25
26#[derive(Clone, Debug)]
27/// An array representing the equation `A[i] = base + i * multiplier`.
28pub struct SequenceArray {
29    base: PValue,
30    multiplier: PValue,
31    dtype: DType,
32    length: usize,
33    stats_set: ArrayStats,
34}
35
36impl SequenceArray {
37    pub fn typed_new<T: NativePType + Into<PValue>>(
38        base: T,
39        multiplier: T,
40        nullability: Nullability,
41        length: usize,
42    ) -> VortexResult<Self> {
43        Self::new(
44            base.into(),
45            multiplier.into(),
46            T::PTYPE,
47            nullability,
48            length,
49        )
50    }
51
52    /// Constructs a sequence array using two integer values (with the same ptype).
53    pub fn new(
54        base: PValue,
55        multiplier: PValue,
56        ptype: PType,
57        nullability: Nullability,
58        length: usize,
59    ) -> VortexResult<Self> {
60        if !ptype.is_int() {
61            vortex_bail!("only integer ptype are supported in SequenceArray currently")
62        }
63
64        Self::try_last(base, multiplier, ptype, length).map_err(|e| {
65            e.with_context(format!(
66                "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ",
67            ))
68        })?;
69
70        Ok(Self::unchecked_new(
71            base,
72            multiplier,
73            ptype,
74            nullability,
75            length,
76        ))
77    }
78
79    pub(crate) fn unchecked_new(
80        base: PValue,
81        multiplier: PValue,
82        ptype: PType,
83        nullability: Nullability,
84        length: usize,
85    ) -> Self {
86        let dtype = DType::Primitive(ptype, nullability);
87        Self {
88            base,
89            multiplier,
90            dtype,
91            length,
92            // TODO(joe): add stats, on construct or on use?
93            stats_set: Default::default(),
94        }
95    }
96
97    pub fn ptype(&self) -> PType {
98        self.dtype.as_ptype()
99    }
100
101    pub fn base(&self) -> PValue {
102        self.base
103    }
104
105    pub fn multiplier(&self) -> PValue {
106        self.multiplier
107    }
108
109    pub(crate) fn try_last(
110        base: PValue,
111        multiplier: PValue,
112        ptype: PType,
113        length: usize,
114    ) -> VortexResult<PValue> {
115        match_each_integer_ptype!(ptype, |P| {
116            let len_t = <P>::from_usize(length - 1)
117                .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
118
119            let base = base.as_primitive::<P>();
120            let multiplier = multiplier.as_primitive::<P>();
121
122            let last = len_t
123                .checked_mul(multiplier)
124                .and_then(|offset| offset.checked_add(base))
125                .ok_or_else(|| vortex_err!("last value computation overflows"))?;
126            Ok(PValue::from(last))
127        })
128    }
129
130    fn index_value(&self, idx: usize) -> PValue {
131        assert!(idx < self.length, "index_value({idx}): index out of bounds");
132
133        match_each_native_ptype!(self.ptype(), |P| {
134            let base = self.base.as_primitive::<P>();
135            let multiplier = self.multiplier.as_primitive::<P>();
136            let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
137
138            PValue::from(value)
139        })
140    }
141
142    /// Returns the validated final value of a sequence array
143    pub fn last(&self) -> PValue {
144        Self::try_last(self.base, self.multiplier, self.ptype(), self.length)
145            .vortex_expect("validated array")
146    }
147
148    pub fn dtype(&self) -> &DType {
149        &self.dtype
150    }
151}
152
153impl VTable for SequenceVTable {
154    type Array = SequenceArray;
155    type Encoding = SequenceEncoding;
156
157    type ArrayVTable = Self;
158    type CanonicalVTable = Self;
159    type OperationsVTable = Self;
160    type ValidityVTable = Self;
161    type VisitorVTable = Self;
162    type ComputeVTable = NotSupported;
163    type EncodeVTable = Self;
164    type SerdeVTable = Self;
165    type PipelineVTable = NotSupported;
166
167    fn id(_encoding: &Self::Encoding) -> EncodingId {
168        EncodingId::new_ref("vortex.sequence")
169    }
170
171    fn encoding(_array: &Self::Array) -> EncodingRef {
172        EncodingRef::new_ref(SequenceEncoding.as_ref())
173    }
174}
175
176impl ArrayVTable<SequenceVTable> for SequenceVTable {
177    fn len(array: &SequenceArray) -> usize {
178        array.length
179    }
180
181    fn dtype(array: &SequenceArray) -> &DType {
182        &array.dtype
183    }
184
185    fn stats(array: &SequenceArray) -> StatsSetRef<'_> {
186        array.stats_set.to_ref(array.as_ref())
187    }
188}
189
190impl CanonicalVTable<SequenceVTable> for SequenceVTable {
191    fn canonicalize(array: &SequenceArray) -> VortexResult<Canonical> {
192        let prim = match_each_native_ptype!(array.ptype(), |P| {
193            let base = array.base().as_primitive::<P>();
194            let multiplier = array.multiplier().as_primitive::<P>();
195            let values = BufferMut::from_iter(
196                (0..array.len())
197                    .map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
198            );
199            PrimitiveArray::new(values, array.dtype.nullability().into())
200        });
201
202        Ok(Canonical::Primitive(prim))
203    }
204}
205
206impl OperationsVTable<SequenceVTable> for SequenceVTable {
207    fn slice(array: &SequenceArray, range: Range<usize>) -> ArrayRef {
208        SequenceArray::unchecked_new(
209            array.index_value(range.start),
210            array.multiplier,
211            array.ptype(),
212            array.dtype().nullability(),
213            range.len(),
214        )
215        .to_array()
216    }
217
218    fn scalar_at(array: &SequenceArray, index: usize) -> Scalar {
219        Scalar::new(
220            array.dtype().clone(),
221            ScalarValue::from(array.index_value(index)),
222        )
223    }
224}
225
226impl ValidityVTable<SequenceVTable> for SequenceVTable {
227    fn is_valid(_array: &SequenceArray, _index: usize) -> bool {
228        true
229    }
230
231    fn all_valid(_array: &SequenceArray) -> bool {
232        true
233    }
234
235    fn all_invalid(_array: &SequenceArray) -> bool {
236        false
237    }
238
239    fn validity_mask(array: &SequenceArray) -> Mask {
240        Mask::AllTrue(array.len())
241    }
242}
243
244impl VisitorVTable<SequenceVTable> for SequenceVTable {
245    fn visit_buffers(_array: &SequenceArray, _visitor: &mut dyn ArrayBufferVisitor) {
246        // TODO(joe): expose scalar values
247    }
248
249    fn visit_children(_array: &SequenceArray, _visitor: &mut dyn ArrayChildVisitor) {}
250}
251
/// Encoding marker for [`SequenceArray`], identified as `"vortex.sequence"`.
#[derive(Debug, Clone)]
pub struct SequenceEncoding;
254
#[cfg(test)]
mod tests {
    use vortex_dtype::Nullability;
    use vortex_scalar::{Scalar, ScalarValue};

    use crate::array::SequenceArray;

    /// The four-element sequence `2, 5, 8, 11` (`base = 2`, `multiplier = 3`).
    fn two_step_three() -> SequenceArray {
        SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4).unwrap()
    }

    /// Reference values `2 + i * 3` for every `i` in `indices`.
    fn expected(indices: std::ops::Range<i64>) -> Vec<i64> {
        indices.map(|i| 2 + i * 3).collect()
    }

    #[test]
    fn test_sequence_canonical() {
        let primitive = two_step_three()
            .to_canonical()
            .unwrap()
            .into_primitive()
            .unwrap();

        assert_eq!(primitive.as_slice::<i64>(), expected(0..4).as_slice())
    }

    #[test]
    fn test_sequence_slice_canonical() {
        let sliced = two_step_three().slice(2..3);
        let primitive = sliced.to_canonical().unwrap().into_primitive().unwrap();

        assert_eq!(primitive.as_slice::<i64>(), expected(2..3).as_slice())
    }

    #[test]
    fn test_sequence_scalar_at() {
        let scalar = two_step_three().scalar_at(2);
        let want = Scalar::new(scalar.dtype().clone(), ScalarValue::from(8i64));

        assert_eq!(scalar, want)
    }

    #[test]
    fn test_sequence_min_max() {
        // Final elements land exactly on the i8 bounds (-128 and 125), which must be accepted.
        for (base, multiplier) in [(-127i8, -1i8), (126, -1)] {
            let built = SequenceArray::typed_new(base, multiplier, Nullability::NonNullable, 2);
            assert!(built.is_ok());
        }
    }

    #[test]
    fn test_sequence_too_big() {
        // Final elements would be 128 and -129, both outside the i8 range.
        for (base, multiplier) in [(127i8, 1i8), (-128, -1)] {
            let built = SequenceArray::typed_new(base, multiplier, Nullability::NonNullable, 2);
            assert!(built.is_err());
        }
    }
}