vortex_sequence/
array.rs

1use num_traits::cast::FromPrimitive;
2use vortex_array::arrays::PrimitiveArray;
3use vortex_array::stats::{ArrayStats, StatsSetRef};
4use vortex_array::vtable::{
5    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityVTable,
6    VisitorVTable,
7};
8use vortex_array::{
9    ArrayBufferVisitor, ArrayChildVisitor, ArrayRef, Canonical, EncodingId, EncodingRef, vtable,
10};
11use vortex_dtype::Nullability::NonNullable;
12use vortex_dtype::{DType, NativePType, PType, match_each_integer_ptype, match_each_native_ptype};
13use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
14use vortex_mask::Mask;
15use vortex_scalar::{PValue, Scalar, ScalarValue};
16
17vtable!(Sequence);
18
19#[derive(Clone, Debug)]
20/// An array representing the equation `A[i] = base + i * multiplier`.
21pub struct SequenceArray {
22    base: PValue,
23    multiplier: PValue,
24    dtype: DType,
25    length: usize,
26    stats_set: ArrayStats,
27}
28
29impl SequenceArray {
30    pub fn typed_new<T: NativePType + Into<PValue>>(
31        base: T,
32        multiplier: T,
33        length: usize,
34    ) -> VortexResult<Self> {
35        Self::new(base.into(), multiplier.into(), T::PTYPE, length)
36    }
37
38    /// Constructs a sequence array using two integer values (with the same ptype).
39    pub fn new(
40        base: PValue,
41        multiplier: PValue,
42        ptype: PType,
43        length: usize,
44    ) -> VortexResult<Self> {
45        if !ptype.is_int() {
46            vortex_bail!("only integer ptype are supported in SequenceArray currently")
47        }
48
49        Self::try_last(base, multiplier, ptype, length).map_err(|e| {
50            e.with_context(format!(
51                "final value not expressible, base = {:?}, multiplier = {:?}, len = {} ",
52                base, multiplier, length
53            ))
54        })?;
55
56        Ok(Self::unchecked_new(base, multiplier, ptype, length))
57    }
58
59    pub(crate) fn unchecked_new(
60        base: PValue,
61        multiplier: PValue,
62        ptype: PType,
63        length: usize,
64    ) -> Self {
65        let dtype = DType::Primitive(ptype, NonNullable);
66        Self {
67            base,
68            multiplier,
69            dtype,
70            length,
71            // TODO(joe): add stats, on construct or on use?
72            stats_set: Default::default(),
73        }
74    }
75
76    pub fn ptype(&self) -> PType {
77        self.dtype.as_ptype()
78    }
79
80    pub fn base(&self) -> PValue {
81        self.base
82    }
83
84    pub fn multiplier(&self) -> PValue {
85        self.multiplier
86    }
87
88    fn try_last(
89        base: PValue,
90        multiplier: PValue,
91        ptype: PType,
92        length: usize,
93    ) -> VortexResult<PValue> {
94        match_each_integer_ptype!(ptype, |P| {
95            let len_t = <P>::from_usize(length - 1)
96                .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
97
98            let base = base.as_primitive::<P>()?;
99            let multiplier = multiplier.as_primitive::<P>()?;
100
101            let last = len_t
102                .checked_mul(multiplier)
103                .and_then(|offset| offset.checked_add(base))
104                .ok_or_else(|| vortex_err!("last value computation overflows"))?;
105            Ok(PValue::from(last))
106        })
107    }
108
109    fn index_value(&self, idx: usize) -> VortexResult<PValue> {
110        if idx > self.length {
111            vortex_bail!("out of bounds")
112        }
113        match_each_native_ptype!(self.ptype(), |P| {
114            let base = self.base.as_primitive::<P>()?;
115            let multiplier = self.multiplier.as_primitive::<P>()?;
116            let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
117
118            Ok(PValue::from(value))
119        })
120    }
121
122    /// Returns the validated final value of a sequence array
123    pub fn last(&self) -> PValue {
124        Self::try_last(self.base, self.multiplier, self.ptype(), self.length)
125            .vortex_expect("validated array")
126    }
127}
128
129impl VTable for SequenceVTable {
130    type Array = SequenceArray;
131    type Encoding = SequenceEncoding;
132
133    type ArrayVTable = Self;
134    type CanonicalVTable = Self;
135    type OperationsVTable = Self;
136    type ValidityVTable = Self;
137    type VisitorVTable = Self;
138    type ComputeVTable = NotSupported;
139    type EncodeVTable = Self;
140    type SerdeVTable = Self;
141
142    fn id(_encoding: &Self::Encoding) -> EncodingId {
143        EncodingId::new_ref("vortex.sequence")
144    }
145
146    fn encoding(_array: &Self::Array) -> EncodingRef {
147        EncodingRef::new_ref(SequenceEncoding.as_ref())
148    }
149}
150
151impl ArrayVTable<SequenceVTable> for SequenceVTable {
152    fn len(array: &SequenceArray) -> usize {
153        array.length
154    }
155
156    fn dtype(array: &SequenceArray) -> &DType {
157        &array.dtype
158    }
159
160    fn stats(array: &SequenceArray) -> StatsSetRef<'_> {
161        array.stats_set.to_ref(array.as_ref())
162    }
163}
164
165impl CanonicalVTable<SequenceVTable> for SequenceVTable {
166    fn canonicalize(array: &SequenceArray) -> VortexResult<Canonical> {
167        let prim = match_each_native_ptype!(array.ptype(), |P| {
168            let base = array.base().as_primitive::<P>()?;
169            let multiplier = array.multiplier().as_primitive::<P>()?;
170            PrimitiveArray::from_iter(
171                (0..array.len())
172                    .map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
173            )
174        });
175
176        Ok(Canonical::Primitive(prim))
177    }
178}
179
180impl OperationsVTable<SequenceVTable> for SequenceVTable {
181    fn slice(array: &SequenceArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
182        Ok(SequenceArray::unchecked_new(
183            array.index_value(start)?,
184            array.multiplier,
185            array.ptype(),
186            stop - start,
187        )
188        .to_array())
189    }
190
191    fn scalar_at(array: &SequenceArray, index: usize) -> VortexResult<Scalar> {
192        // Ok(Scalar::from(array.index_value(index)))
193        Ok(Scalar::new(
194            array.dtype().clone(),
195            ScalarValue::from(array.index_value(index)?),
196        ))
197    }
198}
199
200impl ValidityVTable<SequenceVTable> for SequenceVTable {
201    fn is_valid(_array: &SequenceArray, _index: usize) -> VortexResult<bool> {
202        Ok(true)
203    }
204
205    fn all_valid(_array: &SequenceArray) -> VortexResult<bool> {
206        Ok(true)
207    }
208
209    fn all_invalid(_array: &SequenceArray) -> VortexResult<bool> {
210        Ok(false)
211    }
212
213    fn validity_mask(array: &SequenceArray) -> VortexResult<Mask> {
214        Ok(Mask::AllTrue(array.len()))
215    }
216}
217
218impl VisitorVTable<SequenceVTable> for SequenceVTable {
219    fn visit_buffers(_array: &SequenceArray, _visitor: &mut dyn ArrayBufferVisitor) {
220        // TODO(joe): expose scalar values
221    }
222
223    fn visit_children(_array: &SequenceArray, _visitor: &mut dyn ArrayChildVisitor) {}
224}
225
/// Encoding marker for [`SequenceArray`], identified by `vortex.sequence`.
#[derive(Debug, Clone)]
pub struct SequenceEncoding;
228
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_scalar::{Scalar, ScalarValue};

    use crate::array::SequenceArray;

    /// `2, 5, 8, 11` canonicalizes to the same values as a hand-built primitive array.
    #[test]
    fn test_sequence_canonical() {
        let sequence = SequenceArray::typed_new(2i64, 3, 4).unwrap();
        let expected = PrimitiveArray::from_iter((0..4).map(|i| 2i64 + i * 3));

        let actual = sequence
            .to_canonical()
            .unwrap()
            .into_primitive()
            .unwrap();
        assert_eq!(actual.as_slice::<i64>(), expected.as_slice::<i64>())
    }

    /// Slicing `[2, 3)` keeps only the element at original index 2.
    #[test]
    fn test_sequence_slice_canonical() {
        let sliced = SequenceArray::typed_new(2i64, 3, 4)
            .unwrap()
            .slice(2, 3)
            .unwrap();
        let expected = PrimitiveArray::from_iter((2..3).map(|i| 2i64 + i * 3));

        let actual = sliced.to_canonical().unwrap().into_primitive().unwrap();
        assert_eq!(actual.as_slice::<i64>(), expected.as_slice::<i64>())
    }

    /// Element 2 of `2 + 3i` is 8.
    #[test]
    fn test_sequence_scalar_at() {
        let sequence = SequenceArray::typed_new(2i64, 3, 4).unwrap();
        let scalar = sequence.scalar_at(2).unwrap();

        let expected = Scalar::new(scalar.dtype().clone(), ScalarValue::from(8i64));
        assert_eq!(scalar, expected)
    }

    /// Last values -128 and 125 both fit in `i8`.
    #[test]
    fn test_sequence_min_max() {
        let reaches_min = SequenceArray::typed_new(-127i8, -1i8, 2);
        let descends_from_near_max = SequenceArray::typed_new(126i8, -1i8, 2);

        assert!(reaches_min.is_ok());
        assert!(descends_from_near_max.is_ok());
    }

    /// Last values 128 and -129 fall outside `i8`.
    #[test]
    fn test_sequence_too_big() {
        let past_max = SequenceArray::typed_new(127i8, 1i8, 2);
        let past_min = SequenceArray::typed_new(-128i8, -1i8, 2);

        assert!(past_max.is_err());
        assert!(past_min.is_err());
    }
}