vortex_sequence/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5use std::ops::Range;
6
7use num_traits::cast::FromPrimitive;
8use vortex_array::arrays::PrimitiveArray;
9use vortex_array::stats::{ArrayStats, StatsSetRef};
10use vortex_array::vtable::{
11    ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityVTable,
12    VisitorVTable,
13};
14use vortex_array::{
15    ArrayBufferVisitor, ArrayChildVisitor, ArrayRef, Canonical, EncodingId, EncodingRef, Precision,
16    vtable,
17};
18use vortex_buffer::BufferMut;
19use vortex_dtype::{
20    DType, NativePType, Nullability, PType, match_each_integer_ptype, match_each_native_ptype,
21};
22use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
23use vortex_mask::Mask;
24use vortex_scalar::{PValue, Scalar, ScalarValue};
25
// NOTE(review): `vtable!` is imported from `vortex_array`; presumably it declares the
// `SequenceVTable` type (and related glue) that the trait impls in this file attach to —
// confirm against the macro definition.
26vtable!(Sequence);
27
28#[derive(Clone, Debug)]
29/// An array representing the equation `A[i] = base + i * multiplier`.
30pub struct SequenceArray {
31    base: PValue,
32    multiplier: PValue,
33    dtype: DType,
34    pub(crate) length: usize,
35    stats_set: ArrayStats,
36}
37
38impl SequenceArray {
39    pub fn typed_new<T: NativePType + Into<PValue>>(
40        base: T,
41        multiplier: T,
42        nullability: Nullability,
43        length: usize,
44    ) -> VortexResult<Self> {
45        Self::new(
46            base.into(),
47            multiplier.into(),
48            T::PTYPE,
49            nullability,
50            length,
51        )
52    }
53
54    /// Constructs a sequence array using two integer values (with the same ptype).
55    pub fn new(
56        base: PValue,
57        multiplier: PValue,
58        ptype: PType,
59        nullability: Nullability,
60        length: usize,
61    ) -> VortexResult<Self> {
62        if !ptype.is_int() {
63            vortex_bail!("only integer ptype are supported in SequenceArray currently")
64        }
65
66        Self::try_last(base, multiplier, ptype, length).map_err(|e| {
67            e.with_context(format!(
68                "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ",
69            ))
70        })?;
71
72        Ok(Self::unchecked_new(
73            base,
74            multiplier,
75            ptype,
76            nullability,
77            length,
78        ))
79    }
80
81    pub(crate) fn unchecked_new(
82        base: PValue,
83        multiplier: PValue,
84        ptype: PType,
85        nullability: Nullability,
86        length: usize,
87    ) -> Self {
88        let dtype = DType::Primitive(ptype, nullability);
89        Self {
90            base,
91            multiplier,
92            dtype,
93            length,
94            // TODO(joe): add stats, on construct or on use?
95            stats_set: Default::default(),
96        }
97    }
98
99    pub fn ptype(&self) -> PType {
100        self.dtype.as_ptype()
101    }
102
103    pub fn base(&self) -> PValue {
104        self.base
105    }
106
107    pub fn multiplier(&self) -> PValue {
108        self.multiplier
109    }
110
111    pub(crate) fn try_last(
112        base: PValue,
113        multiplier: PValue,
114        ptype: PType,
115        length: usize,
116    ) -> VortexResult<PValue> {
117        match_each_integer_ptype!(ptype, |P| {
118            let len_t = <P>::from_usize(length - 1)
119                .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
120
121            let base = base.cast::<P>();
122            let multiplier = multiplier.cast::<P>();
123
124            let last = len_t
125                .checked_mul(multiplier)
126                .and_then(|offset| offset.checked_add(base))
127                .ok_or_else(|| vortex_err!("last value computation overflows"))?;
128            Ok(PValue::from(last))
129        })
130    }
131
132    pub(crate) fn index_value(&self, idx: usize) -> PValue {
133        assert!(idx < self.length, "index_value({idx}): index out of bounds");
134
135        match_each_native_ptype!(self.ptype(), |P| {
136            let base = self.base.cast::<P>();
137            let multiplier = self.multiplier.cast::<P>();
138            let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
139
140            PValue::from(value)
141        })
142    }
143
144    /// Returns the validated final value of a sequence array
145    pub fn last(&self) -> PValue {
146        Self::try_last(self.base, self.multiplier, self.ptype(), self.length)
147            .vortex_expect("validated array")
148    }
149}
150
151impl VTable for SequenceVTable {
152    type Array = SequenceArray;
153    type Encoding = SequenceEncoding;
154
155    type ArrayVTable = Self;
156    type CanonicalVTable = Self;
157    type OperationsVTable = Self;
158    type ValidityVTable = Self;
159    type VisitorVTable = Self;
160    type ComputeVTable = NotSupported;
161    type EncodeVTable = Self;
162    type SerdeVTable = Self;
163    type OperatorVTable = Self;
164
165    fn id(_encoding: &Self::Encoding) -> EncodingId {
166        EncodingId::new_ref("vortex.sequence")
167    }
168
169    fn encoding(_array: &Self::Array) -> EncodingRef {
170        EncodingRef::new_ref(SequenceEncoding.as_ref())
171    }
172}
173
174impl ArrayVTable<SequenceVTable> for SequenceVTable {
175    fn len(array: &SequenceArray) -> usize {
176        array.length
177    }
178
179    fn dtype(array: &SequenceArray) -> &DType {
180        &array.dtype
181    }
182
183    fn stats(array: &SequenceArray) -> StatsSetRef<'_> {
184        array.stats_set.to_ref(array.as_ref())
185    }
186
187    fn array_hash<H: std::hash::Hasher>(
188        array: &SequenceArray,
189        state: &mut H,
190        _precision: Precision,
191    ) {
192        array.base.hash(state);
193        array.multiplier.hash(state);
194        array.dtype.hash(state);
195        array.length.hash(state);
196    }
197
198    fn array_eq(array: &SequenceArray, other: &SequenceArray, _precision: Precision) -> bool {
199        array.base == other.base
200            && array.multiplier == other.multiplier
201            && array.dtype == other.dtype
202            && array.length == other.length
203    }
204}
205
206impl CanonicalVTable<SequenceVTable> for SequenceVTable {
207    fn canonicalize(array: &SequenceArray) -> Canonical {
208        let prim = match_each_native_ptype!(array.ptype(), |P| {
209            let base = array.base().cast::<P>();
210            let multiplier = array.multiplier().cast::<P>();
211            let values = BufferMut::from_iter(
212                (0..array.len())
213                    .map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
214            );
215            PrimitiveArray::new(values, array.dtype.nullability().into())
216        });
217
218        Canonical::Primitive(prim)
219    }
220}
221
222impl OperationsVTable<SequenceVTable> for SequenceVTable {
223    fn slice(array: &SequenceArray, range: Range<usize>) -> ArrayRef {
224        SequenceArray::unchecked_new(
225            array.index_value(range.start),
226            array.multiplier,
227            array.ptype(),
228            array.dtype().nullability(),
229            range.len(),
230        )
231        .to_array()
232    }
233
234    fn scalar_at(array: &SequenceArray, index: usize) -> Scalar {
235        Scalar::new(
236            array.dtype().clone(),
237            ScalarValue::from(array.index_value(index)),
238        )
239    }
240}
241
242impl ValidityVTable<SequenceVTable> for SequenceVTable {
243    fn is_valid(_array: &SequenceArray, _index: usize) -> bool {
244        true
245    }
246
247    fn all_valid(_array: &SequenceArray) -> bool {
248        true
249    }
250
251    fn all_invalid(_array: &SequenceArray) -> bool {
252        false
253    }
254
255    fn validity_mask(array: &SequenceArray) -> Mask {
256        Mask::AllTrue(array.len())
257    }
258}
259
260impl VisitorVTable<SequenceVTable> for SequenceVTable {
261    fn visit_buffers(_array: &SequenceArray, _visitor: &mut dyn ArrayBufferVisitor) {
262        // TODO(joe): expose scalar values
263    }
264
265    fn visit_children(_array: &SequenceArray, _visitor: &mut dyn ArrayChildVisitor) {}
266}
267
/// Unit type naming the sequence encoding; it is the `Encoding` associated type of
/// `SequenceVTable`.
#[derive(Clone, Debug)]
pub struct SequenceEncoding;
270
#[cfg(test)]
mod tests {
    use vortex_array::ToCanonical;
    use vortex_dtype::Nullability;
    use vortex_scalar::{Scalar, ScalarValue};

    use crate::array::SequenceArray;

    /// Fixture shared by several tests: the `i64` sequence `2 + 3 * i` with four elements.
    fn two_plus_three_i() -> SequenceArray {
        SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4).unwrap()
    }

    /// Canonicalising the whole sequence yields every value of the equation.
    #[test]
    fn test_sequence_canonical() {
        let expected: Vec<i64> = (0..4).map(|i| 2i64 + i * 3).collect();

        let actual = two_plus_three_i().to_primitive();

        assert_eq!(actual.as_slice::<i64>(), expected.as_slice());
    }

    /// Canonicalising a slice yields only the values inside the sliced range.
    #[test]
    fn test_sequence_slice_canonical() {
        let expected: Vec<i64> = (2..3).map(|i| 2i64 + i * 3).collect();

        let sliced = two_plus_three_i().slice(2..3);
        let actual = sliced.to_primitive();

        assert_eq!(actual.as_slice::<i64>(), expected.as_slice());
    }

    /// `scalar_at` evaluates the equation at the requested index.
    #[test]
    fn test_sequence_scalar_at() {
        let scalar = two_plus_three_i().scalar_at(2);

        let expected = Scalar::new(scalar.dtype().clone(), ScalarValue::from(8i64));
        assert_eq!(scalar, expected);
    }

    /// Sequences whose last value lands exactly on, or just inside, the `i8` range are accepted.
    #[test]
    fn test_sequence_min_max() {
        for (base, multiplier) in [(-127i8, -1i8), (126i8, -1i8)] {
            let built = SequenceArray::typed_new(base, multiplier, Nullability::NonNullable, 2);
            assert!(built.is_ok());
        }
    }

    /// Sequences whose last value falls outside the `i8` range are rejected.
    #[test]
    fn test_sequence_too_big() {
        for (base, multiplier) in [(127i8, 1i8), (-128i8, -1i8)] {
            let built = SequenceArray::typed_new(base, multiplier, Nullability::NonNullable, 2);
            assert!(built.is_err());
        }
    }
}
329}