1use num_traits::cast::FromPrimitive;
2use vortex_array::arrays::PrimitiveArray;
3use vortex_array::stats::{ArrayStats, StatsSetRef};
4use vortex_array::vtable::{
5 ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityVTable,
6 VisitorVTable,
7};
8use vortex_array::{
9 ArrayBufferVisitor, ArrayChildVisitor, ArrayRef, Canonical, EncodingId, EncodingRef, vtable,
10};
11use vortex_dtype::Nullability::NonNullable;
12use vortex_dtype::{DType, NativePType, PType, match_each_integer_ptype, match_each_native_ptype};
13use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
14use vortex_mask::Mask;
15use vortex_scalar::{PValue, Scalar, ScalarValue};
16
17vtable!(Sequence);
18
19#[derive(Clone, Debug)]
20pub struct SequenceArray {
22 base: PValue,
23 multiplier: PValue,
24 dtype: DType,
25 length: usize,
26 stats_set: ArrayStats,
27}
28
29impl SequenceArray {
30 pub fn typed_new<T: NativePType + Into<PValue>>(
31 base: T,
32 multiplier: T,
33 length: usize,
34 ) -> VortexResult<Self> {
35 Self::new(base.into(), multiplier.into(), T::PTYPE, length)
36 }
37
38 pub fn new(
40 base: PValue,
41 multiplier: PValue,
42 ptype: PType,
43 length: usize,
44 ) -> VortexResult<Self> {
45 if !ptype.is_int() {
46 vortex_bail!("only integer ptype are supported in SequenceArray currently")
47 }
48
49 Self::try_last(base, multiplier, ptype, length).map_err(|e| {
50 e.with_context(format!(
51 "final value not expressible, base = {:?}, multiplier = {:?}, len = {} ",
52 base, multiplier, length
53 ))
54 })?;
55
56 Ok(Self::unchecked_new(base, multiplier, ptype, length))
57 }
58
59 pub(crate) fn unchecked_new(
60 base: PValue,
61 multiplier: PValue,
62 ptype: PType,
63 length: usize,
64 ) -> Self {
65 let dtype = DType::Primitive(ptype, NonNullable);
66 Self {
67 base,
68 multiplier,
69 dtype,
70 length,
71 stats_set: Default::default(),
73 }
74 }
75
76 pub fn ptype(&self) -> PType {
77 self.dtype.as_ptype()
78 }
79
80 pub fn base(&self) -> PValue {
81 self.base
82 }
83
84 pub fn multiplier(&self) -> PValue {
85 self.multiplier
86 }
87
88 fn try_last(
89 base: PValue,
90 multiplier: PValue,
91 ptype: PType,
92 length: usize,
93 ) -> VortexResult<PValue> {
94 match_each_integer_ptype!(ptype, |P| {
95 let len_t = <P>::from_usize(length - 1)
96 .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
97
98 let base = base.as_primitive::<P>()?;
99 let multiplier = multiplier.as_primitive::<P>()?;
100
101 let last = len_t
102 .checked_mul(multiplier)
103 .and_then(|offset| offset.checked_add(base))
104 .ok_or_else(|| vortex_err!("last value computation overflows"))?;
105 Ok(PValue::from(last))
106 })
107 }
108
109 fn index_value(&self, idx: usize) -> VortexResult<PValue> {
110 if idx > self.length {
111 vortex_bail!("out of bounds")
112 }
113 match_each_native_ptype!(self.ptype(), |P| {
114 let base = self.base.as_primitive::<P>()?;
115 let multiplier = self.multiplier.as_primitive::<P>()?;
116 let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
117
118 Ok(PValue::from(value))
119 })
120 }
121
122 pub fn last(&self) -> PValue {
124 Self::try_last(self.base, self.multiplier, self.ptype(), self.length)
125 .vortex_expect("validated array")
126 }
127}
128
129impl VTable for SequenceVTable {
130 type Array = SequenceArray;
131 type Encoding = SequenceEncoding;
132
133 type ArrayVTable = Self;
134 type CanonicalVTable = Self;
135 type OperationsVTable = Self;
136 type ValidityVTable = Self;
137 type VisitorVTable = Self;
138 type ComputeVTable = NotSupported;
139 type EncodeVTable = Self;
140 type SerdeVTable = Self;
141
142 fn id(_encoding: &Self::Encoding) -> EncodingId {
143 EncodingId::new_ref("vortex.sequence")
144 }
145
146 fn encoding(_array: &Self::Array) -> EncodingRef {
147 EncodingRef::new_ref(SequenceEncoding.as_ref())
148 }
149}
150
151impl ArrayVTable<SequenceVTable> for SequenceVTable {
152 fn len(array: &SequenceArray) -> usize {
153 array.length
154 }
155
156 fn dtype(array: &SequenceArray) -> &DType {
157 &array.dtype
158 }
159
160 fn stats(array: &SequenceArray) -> StatsSetRef<'_> {
161 array.stats_set.to_ref(array.as_ref())
162 }
163}
164
165impl CanonicalVTable<SequenceVTable> for SequenceVTable {
166 fn canonicalize(array: &SequenceArray) -> VortexResult<Canonical> {
167 let prim = match_each_native_ptype!(array.ptype(), |P| {
168 let base = array.base().as_primitive::<P>()?;
169 let multiplier = array.multiplier().as_primitive::<P>()?;
170 PrimitiveArray::from_iter(
171 (0..array.len())
172 .map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
173 )
174 });
175
176 Ok(Canonical::Primitive(prim))
177 }
178}
179
180impl OperationsVTable<SequenceVTable> for SequenceVTable {
181 fn slice(array: &SequenceArray, start: usize, stop: usize) -> VortexResult<ArrayRef> {
182 Ok(SequenceArray::unchecked_new(
183 array.index_value(start)?,
184 array.multiplier,
185 array.ptype(),
186 stop - start,
187 )
188 .to_array())
189 }
190
191 fn scalar_at(array: &SequenceArray, index: usize) -> VortexResult<Scalar> {
192 Ok(Scalar::new(
194 array.dtype().clone(),
195 ScalarValue::from(array.index_value(index)?),
196 ))
197 }
198}
199
200impl ValidityVTable<SequenceVTable> for SequenceVTable {
201 fn is_valid(_array: &SequenceArray, _index: usize) -> VortexResult<bool> {
202 Ok(true)
203 }
204
205 fn all_valid(_array: &SequenceArray) -> VortexResult<bool> {
206 Ok(true)
207 }
208
209 fn all_invalid(_array: &SequenceArray) -> VortexResult<bool> {
210 Ok(false)
211 }
212
213 fn validity_mask(array: &SequenceArray) -> VortexResult<Mask> {
214 Ok(Mask::AllTrue(array.len()))
215 }
216}
217
218impl VisitorVTable<SequenceVTable> for SequenceVTable {
219 fn visit_buffers(_array: &SequenceArray, _visitor: &mut dyn ArrayBufferVisitor) {
220 }
222
223 fn visit_children(_array: &SequenceArray, _visitor: &mut dyn ArrayChildVisitor) {}
224}
225
/// Stateless marker type for the `vortex.sequence` encoding.
#[derive(Clone, Debug)]
pub struct SequenceEncoding;
228
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_scalar::{Scalar, ScalarValue};

    use crate::array::SequenceArray;

    /// Canonicalizing `2 + 3 * i` for `i` in `0..4` yields the materialized values.
    #[test]
    fn test_sequence_canonical() {
        let sequence = SequenceArray::typed_new(2i64, 3, 4).unwrap();
        let expected = PrimitiveArray::from_iter((0..4).map(|i| 2i64 + i * 3));

        let actual = sequence.to_canonical().unwrap().into_primitive().unwrap();

        assert_eq!(actual.as_slice::<i64>(), expected.as_slice::<i64>());
    }

    /// Slicing `2..3` keeps exactly the element that sat at index 2 of the parent.
    #[test]
    fn test_sequence_slice_canonical() {
        let sequence = SequenceArray::typed_new(2i64, 3, 4).unwrap();
        let sliced = sequence.slice(2, 3).unwrap();
        let expected = PrimitiveArray::from_iter((2..3).map(|i| 2i64 + i * 3));

        let actual = sliced.to_canonical().unwrap().into_primitive().unwrap();

        assert_eq!(actual.as_slice::<i64>(), expected.as_slice::<i64>());
    }

    /// The element at index 2 of `2 + 3 * i` is 8.
    #[test]
    fn test_sequence_scalar_at() {
        let sequence = SequenceArray::typed_new(2i64, 3, 4).unwrap();

        let actual = sequence.scalar_at(2).unwrap();
        let expected = Scalar::new(actual.dtype().clone(), ScalarValue::from(8i64));

        assert_eq!(actual, expected);
    }

    /// Sequences whose final value lands inside the `i8` range are accepted.
    #[test]
    fn test_sequence_min_max() {
        let reaches_min = SequenceArray::typed_new(-127i8, -1i8, 2);
        assert!(reaches_min.is_ok());

        let stays_in_range = SequenceArray::typed_new(126i8, -1i8, 2);
        assert!(stays_in_range.is_ok());
    }

    /// Sequences whose final value falls outside the `i8` range are rejected.
    #[test]
    fn test_sequence_too_big() {
        let above_max = SequenceArray::typed_new(127i8, 1i8, 2);
        assert!(above_max.is_err());

        let below_min = SequenceArray::typed_new(-128i8, -1i8, 2);
        assert!(below_min.is_err());
    }
}