1use std::ops::Range;
5
6use num_traits::cast::FromPrimitive;
7use vortex_array::arrays::PrimitiveArray;
8use vortex_array::stats::{ArrayStats, StatsSetRef};
9use vortex_array::vtable::{
10 ArrayVTable, CanonicalVTable, NotSupported, OperationsVTable, VTable, ValidityVTable,
11 VisitorVTable,
12};
13use vortex_array::{
14 ArrayBufferVisitor, ArrayChildVisitor, ArrayRef, Canonical, EncodingId, EncodingRef, vtable,
15};
16use vortex_buffer::BufferMut;
17use vortex_dtype::{
18 DType, NativePType, Nullability, PType, match_each_integer_ptype, match_each_native_ptype,
19};
20use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
21use vortex_mask::Mask;
22use vortex_scalar::{PValue, Scalar, ScalarValue};
23
// Invokes the `vtable!` macro imported from `vortex_array` for the `Sequence` encoding.
// NOTE(review): presumably this is what declares `SequenceVTable` and wires it to
// `SequenceArray` / `SequenceEncoding` — confirm against the macro definition.
vtable!(Sequence);
25
26#[derive(Clone, Debug)]
27pub struct SequenceArray {
29 base: PValue,
30 multiplier: PValue,
31 dtype: DType,
32 length: usize,
33 stats_set: ArrayStats,
34}
35
36impl SequenceArray {
37 pub fn typed_new<T: NativePType + Into<PValue>>(
38 base: T,
39 multiplier: T,
40 nullability: Nullability,
41 length: usize,
42 ) -> VortexResult<Self> {
43 Self::new(
44 base.into(),
45 multiplier.into(),
46 T::PTYPE,
47 nullability,
48 length,
49 )
50 }
51
52 pub fn new(
54 base: PValue,
55 multiplier: PValue,
56 ptype: PType,
57 nullability: Nullability,
58 length: usize,
59 ) -> VortexResult<Self> {
60 if !ptype.is_int() {
61 vortex_bail!("only integer ptype are supported in SequenceArray currently")
62 }
63
64 Self::try_last(base, multiplier, ptype, length).map_err(|e| {
65 e.with_context(format!(
66 "final value not expressible, base = {base:?}, multiplier = {multiplier:?}, len = {length} ",
67 ))
68 })?;
69
70 Ok(Self::unchecked_new(
71 base,
72 multiplier,
73 ptype,
74 nullability,
75 length,
76 ))
77 }
78
79 pub(crate) fn unchecked_new(
80 base: PValue,
81 multiplier: PValue,
82 ptype: PType,
83 nullability: Nullability,
84 length: usize,
85 ) -> Self {
86 let dtype = DType::Primitive(ptype, nullability);
87 Self {
88 base,
89 multiplier,
90 dtype,
91 length,
92 stats_set: Default::default(),
94 }
95 }
96
97 pub fn ptype(&self) -> PType {
98 self.dtype.as_ptype()
99 }
100
101 pub fn base(&self) -> PValue {
102 self.base
103 }
104
105 pub fn multiplier(&self) -> PValue {
106 self.multiplier
107 }
108
109 pub(crate) fn try_last(
110 base: PValue,
111 multiplier: PValue,
112 ptype: PType,
113 length: usize,
114 ) -> VortexResult<PValue> {
115 match_each_integer_ptype!(ptype, |P| {
116 let len_t = <P>::from_usize(length - 1)
117 .ok_or_else(|| vortex_err!("cannot convert length {} into {}", length, ptype))?;
118
119 let base = base.as_primitive::<P>();
120 let multiplier = multiplier.as_primitive::<P>();
121
122 let last = len_t
123 .checked_mul(multiplier)
124 .and_then(|offset| offset.checked_add(base))
125 .ok_or_else(|| vortex_err!("last value computation overflows"))?;
126 Ok(PValue::from(last))
127 })
128 }
129
130 fn index_value(&self, idx: usize) -> PValue {
131 assert!(idx < self.length, "index_value({idx}): index out of bounds");
132
133 match_each_native_ptype!(self.ptype(), |P| {
134 let base = self.base.as_primitive::<P>();
135 let multiplier = self.multiplier.as_primitive::<P>();
136 let value = base + (multiplier * <P>::from_usize(idx).vortex_expect("must fit"));
137
138 PValue::from(value)
139 })
140 }
141
142 pub fn last(&self) -> PValue {
144 Self::try_last(self.base, self.multiplier, self.ptype(), self.length)
145 .vortex_expect("validated array")
146 }
147
148 pub fn dtype(&self) -> &DType {
149 &self.dtype
150 }
151}
152
153impl VTable for SequenceVTable {
154 type Array = SequenceArray;
155 type Encoding = SequenceEncoding;
156
157 type ArrayVTable = Self;
158 type CanonicalVTable = Self;
159 type OperationsVTable = Self;
160 type ValidityVTable = Self;
161 type VisitorVTable = Self;
162 type ComputeVTable = NotSupported;
163 type EncodeVTable = Self;
164 type SerdeVTable = Self;
165 type PipelineVTable = NotSupported;
166
167 fn id(_encoding: &Self::Encoding) -> EncodingId {
168 EncodingId::new_ref("vortex.sequence")
169 }
170
171 fn encoding(_array: &Self::Array) -> EncodingRef {
172 EncodingRef::new_ref(SequenceEncoding.as_ref())
173 }
174}
175
176impl ArrayVTable<SequenceVTable> for SequenceVTable {
177 fn len(array: &SequenceArray) -> usize {
178 array.length
179 }
180
181 fn dtype(array: &SequenceArray) -> &DType {
182 &array.dtype
183 }
184
185 fn stats(array: &SequenceArray) -> StatsSetRef<'_> {
186 array.stats_set.to_ref(array.as_ref())
187 }
188}
189
190impl CanonicalVTable<SequenceVTable> for SequenceVTable {
191 fn canonicalize(array: &SequenceArray) -> VortexResult<Canonical> {
192 let prim = match_each_native_ptype!(array.ptype(), |P| {
193 let base = array.base().as_primitive::<P>();
194 let multiplier = array.multiplier().as_primitive::<P>();
195 let values = BufferMut::from_iter(
196 (0..array.len())
197 .map(|i| base + <P>::from_usize(i).vortex_expect("must fit") * multiplier),
198 );
199 PrimitiveArray::new(values, array.dtype.nullability().into())
200 });
201
202 Ok(Canonical::Primitive(prim))
203 }
204}
205
206impl OperationsVTable<SequenceVTable> for SequenceVTable {
207 fn slice(array: &SequenceArray, range: Range<usize>) -> ArrayRef {
208 SequenceArray::unchecked_new(
209 array.index_value(range.start),
210 array.multiplier,
211 array.ptype(),
212 array.dtype().nullability(),
213 range.len(),
214 )
215 .to_array()
216 }
217
218 fn scalar_at(array: &SequenceArray, index: usize) -> Scalar {
219 Scalar::new(
220 array.dtype().clone(),
221 ScalarValue::from(array.index_value(index)),
222 )
223 }
224}
225
226impl ValidityVTable<SequenceVTable> for SequenceVTable {
227 fn is_valid(_array: &SequenceArray, _index: usize) -> bool {
228 true
229 }
230
231 fn all_valid(_array: &SequenceArray) -> bool {
232 true
233 }
234
235 fn all_invalid(_array: &SequenceArray) -> bool {
236 false
237 }
238
239 fn validity_mask(array: &SequenceArray) -> Mask {
240 Mask::AllTrue(array.len())
241 }
242}
243
244impl VisitorVTable<SequenceVTable> for SequenceVTable {
245 fn visit_buffers(_array: &SequenceArray, _visitor: &mut dyn ArrayBufferVisitor) {
246 }
248
249 fn visit_children(_array: &SequenceArray, _visitor: &mut dyn ArrayChildVisitor) {}
250}
251
/// Unit type used as the `Encoding` of `SequenceVTable`.
#[derive(Debug, Clone)]
pub struct SequenceEncoding;
254
#[cfg(test)]
mod tests {
    use vortex_array::arrays::PrimitiveArray;
    use vortex_dtype::Nullability;
    use vortex_scalar::{Scalar, ScalarValue};

    use crate::array::SequenceArray;

    /// Canonicalizing a sequence yields `base + i * multiplier` for every index.
    #[test]
    fn test_sequence_canonical() {
        let arr = SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4).unwrap();

        let canon = PrimitiveArray::from_iter((0..4).map(|i| 2i64 + i * 3));

        assert_eq!(
            arr.to_canonical()
                .unwrap()
                .into_primitive()
                .unwrap()
                .as_slice::<i64>(),
            canon.as_slice::<i64>()
        )
    }

    /// A slice of a sequence canonicalizes to the matching window of values.
    #[test]
    fn test_sequence_slice_canonical() {
        let arr = SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4)
            .unwrap()
            .slice(2..3);

        let canon = PrimitiveArray::from_iter((2..3).map(|i| 2i64 + i * 3));

        assert_eq!(
            arr.to_canonical()
                .unwrap()
                .into_primitive()
                .unwrap()
                .as_slice::<i64>(),
            canon.as_slice::<i64>()
        )
    }

    /// `scalar_at(2)` on `2, 5, 8, 11` is `8`.
    #[test]
    fn test_sequence_scalar_at() {
        let scalar = SequenceArray::typed_new(2i64, 3, Nullability::NonNullable, 4)
            .unwrap()
            .scalar_at(2);

        assert_eq!(
            scalar,
            Scalar::new(scalar.dtype().clone(), ScalarValue::from(8i64))
        )
    }

    /// Sequences whose last value lands exactly on the type's bounds are accepted.
    #[test]
    fn test_sequence_min_max() {
        // Last value is exactly `i8::MIN`.
        assert!(SequenceArray::typed_new(-127i8, -1i8, Nullability::NonNullable, 2).is_ok());
        // Last value is exactly `i8::MAX`; mirrors the overflow case in
        // `test_sequence_too_big`.
        assert!(SequenceArray::typed_new(126i8, 1i8, Nullability::NonNullable, 2).is_ok());
    }

    /// Sequences whose last value falls outside the type's bounds are rejected.
    #[test]
    fn test_sequence_too_big() {
        assert!(SequenceArray::typed_new(127i8, 1i8, Nullability::NonNullable, 2).is_err());
        assert!(SequenceArray::typed_new(-128i8, -1i8, Nullability::NonNullable, 2).is_err());
    }
}