1use vortex_array::compute::{CastKernel, CastKernelAdapter};
5use vortex_array::{ArrayRef, IntoArray, register_kernel};
6use vortex_dtype::{DType, Nullability};
7use vortex_error::{VortexResult, vortex_err};
8use vortex_scalar::{Scalar, ScalarValue};
9
10use crate::{SequenceArray, SequenceVTable};
11
12impl CastKernel for SequenceVTable {
13 fn cast(&self, array: &SequenceArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
14 let DType::Primitive(target_ptype, target_nullability) = dtype else {
18 return Ok(None);
19 };
20
21 if !target_ptype.is_int() {
22 return Ok(None);
23 }
24
25 if array.ptype() == *target_ptype && array.dtype().nullability() != *target_nullability {
27 return Ok(Some(
30 SequenceArray::new(
31 array.base(),
32 array.multiplier(),
33 *target_ptype,
34 *target_nullability,
35 array.len(),
36 )?
37 .into_array(),
38 ));
39 }
40
41 if array.ptype() != *target_ptype {
43 let base_scalar = Scalar::new(
45 DType::Primitive(array.ptype(), Nullability::NonNullable),
46 ScalarValue::from(array.base()),
47 );
48 let multiplier_scalar = Scalar::new(
49 DType::Primitive(array.ptype(), Nullability::NonNullable),
50 ScalarValue::from(array.multiplier()),
51 );
52
53 let new_base_scalar =
54 base_scalar.cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
55 let new_multiplier_scalar = multiplier_scalar
56 .cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
57
58 let new_base = new_base_scalar
60 .as_primitive()
61 .pvalue()
62 .ok_or_else(|| vortex_err!("Cast resulted in null base value"))?;
63 let new_multiplier = new_multiplier_scalar
64 .as_primitive()
65 .pvalue()
66 .ok_or_else(|| vortex_err!("Cast resulted in null multiplier value"))?;
67
68 return Ok(Some(
69 SequenceArray::new(
70 new_base,
71 new_multiplier,
72 *target_ptype,
73 *target_nullability,
74 array.len(),
75 )?
76 .into_array(),
77 ));
78 }
79
80 Ok(None)
81 }
82}
83
// Register the `CastKernel` implementation for `SequenceVTable`, wrapped in
// `CastKernelAdapter`, via the `register_kernel!` macro so the cast compute function can find it.
register_kernel!(CastKernelAdapter(SequenceVTable).lift());
85
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ToCanonical;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::SequenceArray;

    /// Changing only the nullability (u32 -> nullable u32) yields the requested dtype.
    #[test]
    fn test_cast_sequence_nullability() {
        let input = SequenceArray::typed_new(0u32, 1u32, Nullability::NonNullable, 4).unwrap();
        let wanted = DType::Primitive(PType::U32, Nullability::Nullable);

        let output = cast(input.as_ref(), &wanted).unwrap();

        assert_eq!(output.dtype(), &wanted);
    }

    /// Widening u32 -> i64 keeps every element of the sequence.
    #[test]
    fn test_cast_sequence_u32_to_i64() {
        let input =
            SequenceArray::typed_new(100u32, 10u32, Nullability::NonNullable, 4).unwrap();
        let wanted = DType::Primitive(PType::I64, Nullability::NonNullable);

        let output = cast(input.as_ref(), &wanted).unwrap();
        assert_eq!(output.dtype(), &wanted);

        let values = output.to_primitive();
        assert_eq!(values.as_slice::<i64>(), &[100i64, 110, 120, 130]);
    }

    /// Changing ptype and nullability together (i16 -> nullable i32) keeps the values.
    #[test]
    fn test_cast_sequence_i16_to_i32_nullable() {
        let input = SequenceArray::typed_new(5i16, 3i16, Nullability::NonNullable, 3).unwrap();
        let wanted = DType::Primitive(PType::I32, Nullability::Nullable);

        let output = cast(input.as_ref(), &wanted).unwrap();
        assert_eq!(output.dtype(), &wanted);

        let values = output.to_primitive();
        assert_eq!(values.as_slice::<i32>(), &[5i32, 8, 11]);
    }

    /// A float target is not handled by the sequence kernel but must still succeed.
    #[test]
    fn test_cast_sequence_to_float_delegates_to_canonical() {
        let input = SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 5).unwrap();
        let wanted = DType::Primitive(PType::F32, Nullability::NonNullable);

        let output = cast(input.as_ref(), &wanted).unwrap();
        assert_eq!(output.dtype(), &wanted);

        let values = output.to_primitive();
        assert_eq!(values.as_slice::<f32>(), &[0.0f32, 1.0, 2.0, 3.0, 4.0]);
    }

    #[rstest]
    #[case::i32(SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 5).unwrap())]
    #[case::u64(SequenceArray::typed_new(1000u64, 100u64, Nullability::NonNullable, 4).unwrap())]
    #[case::negative_step(
        SequenceArray::typed_new(100i32, -10i32, Nullability::NonNullable, 5).unwrap()
    )]
    #[case::single(SequenceArray::typed_new(42i64, 0i64, Nullability::NonNullable, 1).unwrap())]
    #[case::constant(
        SequenceArray::typed_new(100i32, 0i32, Nullability::NonNullable, 5).unwrap()
    )]
    fn test_cast_sequence_conformance(#[case] sequence: SequenceArray) {
        test_cast_conformance(sequence.as_ref());
    }
}