vortex_sequence/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::cast::FromPrimitive;
5use vortex_array::arrays::PrimitiveArray;
6use vortex_array::stats::{ArrayStats, StatsSetRef};
7use vortex_array::vtable::{
8    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityVTable,
9    VisitorVTable,
10};
11use vortex_array::{
12    ArrayBufferVisitor, ArrayChildVisitor, ArrayRef, Canonical, EncodingId, EncodingRef, vtable,
13};
14use vortex_buffer::BufferMut;
15use vortex_dtype::{
16    DType, NativePType, Nullability, PType, match_each_integer_ptype, match_each_native_ptype,
17};
18use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
19use vortex_mask::Mask;
20use vortex_scalar::{PValue, Scalar, ScalarValue};
21
22vtable!(Sequence);
23
24#[derive(Clone, Debug)]
25/// An array representing the equation `A[i] = base + i * multiplier`.
26pub struct SequenceArray {
27    base: PValue,
28    multiplier: PValue,
29    dtype: DType,
30    length: usize,
31    stats_set: ArrayStats,
32}
33
34impl SequenceArray {
35    pub fn typed_new<T: NativePType + Into<PValue>>(
36        base: T,
37        multiplier: T,
38        nullability: Nullability,
39        length: usize,
40    ) -> VortexResult<Self> {
41        Self::new(
42            base.into(),
43            multiplier.into(),
44            T::PTYPE,
45            nullability,
46            length,
47        )
48    }
49
50    /// Constructs a sequence array using two integer values (with the same ptype).
51    pub fn new(
52        base: PValue,
53        multiplier: PValue,
54        ptype: PType,
55        nullability: Nullability,
56        length: usize,
57    ) -> VortexResult<Self> {
58        if !ptype.is_int() {
59            vortex_bail!("only integer ptype are supported in SequenceArray currently")
60        }
61
62        Self::try_last(base, multiplier, ptype, length).map_err(|e| {
63            e.with_context(format!(
64                "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ",
65            ))
66        })?;
67
68        Ok(Self::unchecked_new(
69            base,
70            multiplier,
71            ptype,
72            nullability,
73            length,
74        ))
75    }
76
77    pub(crate) fn unchecked_new(
78        base: PValue,
79        multiplier: PValue,
80        ptype: PType,
81        nullability: Nullability,
82        length: usize,
83    ) -> Self {
84        let dtype = DType::Primitive(ptype, nullability);
85        Self {
86            base,
87            multiplier,
88            dtype,
89            length,
90            // TODO(joe): add stats, on construct or on use?
91            stats_set: Default::default(),
92        }
93    }
94
95    pub fn ptype(&self) -> PType {
96        self.dtype.as_ptype()
97    }
98
99    pub fn base(&self) -> PValue {
100        self.base
101    }
102
103    pub fn multiplier(&self) -> PValue {
104        self.multiplier
105    }
106
107    pub(crate) fn try_last(
108        base: PValue,
109        multiplier: PValue,
110        ptype: PType,
111        length: usize,
112    ) -> VortexResult<PValue> {
113        match_each_integer_ptype!(ptype, |P| {
114            let len_t = <P>::from_usize(length - 1)
115                .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
116
117            let base = base.as_primitive::<P>()?;
118            let multiplier = multiplier.as_primitive::<P>()?;
119
120            let last = len_t
121                .checked_mul(multiplier)
122                .and_then(|offset| offset.checked_add(base))
123                .ok_or_else(|| vortex_err!("last value computation overflows"))?;
124            Ok(PValue::from(last))
125        })
126    }
127
128    fn index_value(&self, idx: usize) -> VortexResult<PValue> {
129        if idx > self.length {
130            vortex_bail!("out of bounds")
131        }
132        match_each_native_ptype!(self.ptype(), |P| {
133            let base = self.base.as_primitive::<P>()?;
134            let multiplier = self.multiplier.as_primitive::<P>()?;
135            let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
136
137            Ok(PValue::from(value))
138        })
139    }
140
141    /// Returns the validated final value of a sequence array
142    pub fn last(&self) -> PValue {
143        Self::try_last(self.base, self.multiplier, self.ptype(), self.length)
144            .vortex_expect("validated array")
145    }
146}
147
148impl VTable for SequenceVTable {
149    type Array = SequenceArray;
150    type Encoding = SequenceEncoding;
151
152    type ArrayVTable = Self;
153    type CanonicalVTable = Self;
154    type OperationsVTable = Self;
155    type ValidityVTable = Self;
156    type VisitorVTable = Self;
157    type ComputeVTable = NotSupported;
158    type EncodeVTable = Self;
159    type SerdeVTable = Self;
160
161    fn id(_encoding: &Self::Encoding) -> EncodingId {
162        EncodingId::new_ref("vortex.sequence")
163    }
164
165    fn encoding(_array: &Self::Array) -> EncodingRef {
166        EncodingRef::new_ref(SequenceEncoding.as_ref())
167    }
168}
169
170impl ArrayVTable<SequenceVTable> for SequenceVTable {
171    fn len(array: &SequenceArray) -> usize {
172        array.length
173    }
174
175    fn dtype(array: &SequenceArray) -> &DType {
176        &array.dtype
177    }
178
179    fn stats(array: &SequenceArray) -> StatsSetRef<'_> {
180        array.stats_set.to_ref(array.as_ref())
181    }
182}
183
184impl CanonicalVTable<SequenceVTable> for SequenceVTable {
185    fn canonicalize(array: &SequenceArray) -> VortexResult<Canonical> {
186        let prim = match_each_native_ptype!(array.ptype(), |P| {
187            let base = array.base().as_primitive::<P>()?;
188            let multiplier = array.multiplier().as_primitive::<P>()?;
189            let values = BufferMut::from_iter(
190                (0..array.len())
191                    .map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
192            );
193            PrimitiveArray::new(values, array.dtype.nullability().into())
194        });
195
196        Ok(Canonical::Primitive(prim))
197    }
198}
199
200impl OperationsVTable<SequenceVTable> for SequenceVTable {
201    fn slice(array: &SequenceArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
202        Ok(SequenceArray::unchecked_new(
203            array.index_value(start)?,
204            array.multiplier,
205            array.ptype(),
206            array.dtype().nullability(),
207            stop - start,
208        )
209        .to_array())
210    }
211
212    fn scalar_at(array: &SequenceArray, index: usize) -> VortexResult<Scalar> {
213        // Ok(Scalar::from(array.index_value(index)))
214        Ok(Scalar::new(
215            array.dtype().clone(),
216            ScalarValue::from(array.index_value(index)?),
217        ))
218    }
219}
220
221impl ValidityVTable<SequenceVTable> for SequenceVTable {
222    fn is_valid(_array: &SequenceArray, _index: usize) -> VortexResult<bool> {
223        Ok(true)
224    }
225
226    fn all_valid(_array: &SequenceArray) -> VortexResult<bool> {
227        Ok(true)
228    }
229
230    fn all_invalid(_array: &SequenceArray) -> VortexResult<bool> {
231        Ok(false)
232    }
233
234    fn validity_mask(array: &SequenceArray) -> VortexResult<Mask> {
235        Ok(Mask::AllTrue(array.len()))
236    }
237}
238
239impl VisitorVTable<SequenceVTable> for SequenceVTable {
240    fn visit_buffers(_array: &SequenceArray, _visitor: &mut dyn ArrayBufferVisitor) {
241        // TODO(joe): expose scalar values
242    }
243
244    fn visit_children(_array: &SequenceArray, _visitor: &mut dyn ArrayChildVisitor) {}
245}
246
/// Stateless encoding marker for [`SequenceArray`].
#[derive(Debug, Clone)]
pub struct SequenceEncoding;
249
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_dtype::Nullability;
    use vortex_scalar::{Scalar, ScalarValue};

    use crate::array::SequenceArray;

    /// The non-nullable i64 sequence `2, 5, 8, 11` (base 2, step 3, length 4).
    fn base2_step3_len4() -> SequenceArray {
        SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4).unwrap()
    }

    #[test]
    fn test_sequence_canonical() {
        let expected = PrimitiveArray::from_iter([2i64, 5, 8, 11]);
        let actual = base2_step3_len4()
            .to_canonical()
            .unwrap()
            .into_primitive()
            .unwrap();

        assert_eq!(actual.as_slice::<i64>(), expected.as_slice::<i64>())
    }

    #[test]
    fn test_sequence_slice_canonical() {
        // Elements [2, 3) of `2, 5, 8, 11` is just `8`.
        let expected = PrimitiveArray::from_iter([8i64]);
        let sliced = base2_step3_len4().slice(2, 3).unwrap();
        let actual = sliced.to_canonical().unwrap().into_primitive().unwrap();

        assert_eq!(actual.as_slice::<i64>(), expected.as_slice::<i64>())
    }

    #[test]
    fn test_sequence_scalar_at() {
        let third = base2_step3_len4().scalar_at(2).unwrap();
        let expected = Scalar::new(third.dtype().clone(), ScalarValue::from(8i64));

        assert_eq!(third, expected)
    }

    #[test]
    fn test_sequence_min_max() {
        // Final values -128 and 125 are both representable in i8.
        for (base, multiplier) in [(-127i8, -1i8), (126, -1)] {
            let built = SequenceArray::typed_new(base, multiplier, Nullability::NonNullable, 2);
            assert!(built.is_ok());
        }
    }

    #[test]
    fn test_sequence_too_big() {
        // Final values 128 and -129 fall outside i8.
        for (base, multiplier) in [(127i8, 1i8), (-128, -1)] {
            let built = SequenceArray::typed_new(base, multiplier, Nullability::NonNullable, 2);
            assert!(built.is_err());
        }
    }
}