1use vortex_array::compute::{CastKernel, CastKernelAdapter};
5use vortex_array::{ArrayRef, IntoArray, register_kernel};
6use vortex_dtype::{DType, Nullability};
7use vortex_error::{VortexResult, vortex_err};
8use vortex_scalar::{Scalar, ScalarValue};
9
10use crate::{SequenceArray, SequenceVTable};
11
12impl CastKernel for SequenceVTable {
13 fn cast(&self, array: &SequenceArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
14 let DType::Primitive(target_ptype, target_nullability) = dtype else {
18 return Ok(None);
19 };
20
21 if !target_ptype.is_int() {
22 return Ok(None);
23 }
24
25 if array.ptype() == *target_ptype && array.dtype().nullability() != *target_nullability {
27 return Ok(Some(
30 SequenceArray::new(
31 array.base(),
32 array.multiplier(),
33 *target_ptype,
34 *target_nullability,
35 array.len(),
36 )?
37 .into_array(),
38 ));
39 }
40
41 if array.ptype() != *target_ptype {
43 let base_scalar = Scalar::new(
45 DType::Primitive(array.ptype(), Nullability::NonNullable),
46 ScalarValue::from(array.base()),
47 );
48 let multiplier_scalar = Scalar::new(
49 DType::Primitive(array.ptype(), Nullability::NonNullable),
50 ScalarValue::from(array.multiplier()),
51 );
52
53 let new_base_scalar =
54 base_scalar.cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
55 let new_multiplier_scalar = multiplier_scalar
56 .cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
57
58 let new_base = new_base_scalar
60 .as_primitive()
61 .pvalue()
62 .ok_or_else(|| vortex_err!("Cast resulted in null base value"))?;
63 let new_multiplier = new_multiplier_scalar
64 .as_primitive()
65 .pvalue()
66 .ok_or_else(|| vortex_err!("Cast resulted in null multiplier value"))?;
67
68 return Ok(Some(
69 SequenceArray::new(
70 new_base,
71 new_multiplier,
72 *target_ptype,
73 *target_nullability,
74 array.len(),
75 )?
76 .into_array(),
77 ));
78 }
79
80 Ok(None)
81 }
82}
83
// Registers the `CastKernel` implementation for `SequenceVTable` above, wrapped in
// `CastKernelAdapter` and lifted, via `register_kernel!`. Presumably this is how
// `vortex_array::compute::cast` discovers the kernel at runtime — confirm against the
// `register_kernel!` definition.
register_kernel!(CastKernelAdapter(SequenceVTable).lift());
85
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::SequenceArray;

    /// Changing only the nullability must yield the nullable dtype of the same ptype.
    #[test]
    fn test_cast_sequence_nullability() {
        let seq = SequenceArray::typed_new(0u32, 1u32, Nullability::NonNullable, 4).unwrap();
        let want = DType::Primitive(PType::U32, Nullability::Nullable);

        let got = cast(seq.as_ref(), &want).unwrap();

        assert_eq!(got.dtype(), &want);
    }

    /// Widening u32 -> i64 preserves every element of the sequence.
    #[test]
    fn test_cast_sequence_u32_to_i64() {
        let seq = SequenceArray::typed_new(100u32, 10u32, Nullability::NonNullable, 4).unwrap();
        let want = DType::Primitive(PType::I64, Nullability::NonNullable);

        let got = cast(seq.as_ref(), &want).unwrap();
        assert_eq!(got.dtype(), &want);

        let values = got.to_canonical().unwrap().into_primitive().unwrap();
        assert_eq!(values.as_slice::<i64>(), &[100i64, 110, 120, 130]);
    }

    /// Changing ptype and nullability together still decodes to the expected values.
    #[test]
    fn test_cast_sequence_i16_to_i32_nullable() {
        let seq = SequenceArray::typed_new(5i16, 3i16, Nullability::NonNullable, 3).unwrap();
        let want = DType::Primitive(PType::I32, Nullability::Nullable);

        let got = cast(seq.as_ref(), &want).unwrap();
        assert_eq!(got.dtype(), &want);

        let values = got.to_canonical().unwrap().into_primitive().unwrap();
        assert_eq!(values.as_slice::<i32>(), &[5i32, 8, 11]);
    }

    /// A float target is not handled by the sequence kernel but must still cast correctly.
    #[test]
    fn test_cast_sequence_to_float_delegates_to_canonical() {
        let seq = SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 5).unwrap();
        let want = DType::Primitive(PType::F32, Nullability::NonNullable);

        let got = cast(seq.as_ref(), &want).unwrap();
        assert_eq!(got.dtype(), &want);

        let floats = got.to_canonical().unwrap().into_primitive().unwrap();
        assert_eq!(floats.as_slice::<f32>(), &[0.0f32, 1.0, 2.0, 3.0, 4.0]);
    }

    /// Runs the shared cast conformance suite over a spread of sequence shapes.
    #[rstest]
    #[case::i32(SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 5).unwrap())]
    #[case::u64(SequenceArray::typed_new(1000u64, 100u64, Nullability::NonNullable, 4).unwrap())]
    #[case::negative_step(
        SequenceArray::typed_new(100i32, -10i32, Nullability::NonNullable, 5).unwrap()
    )]
    #[case::single(SequenceArray::typed_new(42i64, 0i64, Nullability::NonNullable, 1).unwrap())]
    #[case::constant(
        SequenceArray::typed_new(100i32, 0i32, Nullability::NonNullable, 5).unwrap()
    )]
    fn test_cast_sequence_conformance(#[case] sequence: SequenceArray) {
        test_cast_conformance(sequence.as_ref());
    }
}
187}