1use vortex_array::compute::{CastKernel, CastKernelAdapter};
5use vortex_array::{ArrayRef, IntoArray, register_kernel};
6use vortex_dtype::{DType, Nullability};
7use vortex_error::{VortexResult, vortex_err};
8use vortex_scalar::{Scalar, ScalarValue};
9
10use crate::{SequenceArray, SequenceVTable};
11
12impl CastKernel for SequenceVTable {
13 fn cast(&self, array: &SequenceArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
14 let DType::Primitive(target_ptype, target_nullability) = dtype else {
18 return Ok(None);
19 };
20
21 if !target_ptype.is_int() {
22 return Ok(None);
23 }
24
25 if array.ptype() == *target_ptype && array.dtype().nullability() != *target_nullability {
27 return Ok(Some(
30 SequenceArray::new(
31 array.base(),
32 array.multiplier(),
33 *target_ptype,
34 *target_nullability,
35 array.len(),
36 )?
37 .into_array(),
38 ));
39 }
40
41 if array.ptype() != *target_ptype {
43 let base_scalar = Scalar::new(
45 DType::Primitive(array.ptype(), Nullability::NonNullable),
46 ScalarValue::from(array.base()),
47 );
48 let multiplier_scalar = Scalar::new(
49 DType::Primitive(array.ptype(), Nullability::NonNullable),
50 ScalarValue::from(array.multiplier()),
51 );
52
53 let new_base_scalar =
54 base_scalar.cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
55 let new_multiplier_scalar = multiplier_scalar
56 .cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
57
58 let new_base = new_base_scalar
60 .as_primitive()
61 .pvalue()
62 .ok_or_else(|| vortex_err!("Cast resulted in null base value"))?;
63 let new_multiplier = new_multiplier_scalar
64 .as_primitive()
65 .pvalue()
66 .ok_or_else(|| vortex_err!("Cast resulted in null multiplier value"))?;
67
68 return Ok(Some(
69 SequenceArray::new(
70 new_base,
71 new_multiplier,
72 *target_ptype,
73 *target_nullability,
74 array.len(),
75 )?
76 .into_array(),
77 ));
78 }
79
80 Ok(None)
81 }
82}
83
84register_kernel!(CastKernelAdapter(SequenceVTable).lift());
85
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::{ToCanonical, assert_arrays_eq};
    use vortex_dtype::{DType, Nullability, PType};

    use crate::SequenceArray;

    /// Changing only the nullability yields exactly the requested dtype.
    #[test]
    fn test_cast_sequence_nullability() {
        let target = DType::Primitive(PType::U32, Nullability::Nullable);
        let sequence = SequenceArray::typed_new(0u32, 1u32, Nullability::NonNullable, 4).unwrap();

        let result = cast(sequence.as_ref(), &target).unwrap();

        assert_eq!(result.dtype(), &target);
    }

    /// Widening an unsigned sequence to a signed type preserves every element.
    #[test]
    fn test_cast_sequence_u32_to_i64() {
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);
        let sequence =
            SequenceArray::typed_new(100u32, 10u32, Nullability::NonNullable, 4).unwrap();

        let result = cast(sequence.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);

        let decoded = result.to_primitive();
        let expected = PrimitiveArray::from_iter([100i64, 110, 120, 130]);
        assert_arrays_eq!(decoded, expected);
    }

    /// Changing the ptype and the nullability at once keeps the values, now all `Some`.
    #[test]
    fn test_cast_sequence_i16_to_i32_nullable() {
        let target = DType::Primitive(PType::I32, Nullability::Nullable);
        let sequence = SequenceArray::typed_new(5i16, 3i16, Nullability::NonNullable, 3).unwrap();

        let result = cast(sequence.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);

        let decoded = result.to_primitive();
        let expected = PrimitiveArray::from_option_iter([Some(5i32), Some(8), Some(11)]);
        assert_arrays_eq!(decoded, expected);
    }

    /// Float targets are not handled by the kernel but must still cast correctly.
    #[test]
    fn test_cast_sequence_to_float_delegates_to_canonical() {
        let target = DType::Primitive(PType::F32, Nullability::NonNullable);
        let sequence = SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 5).unwrap();

        let result = cast(sequence.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);

        let decoded = result.to_primitive();
        let expected = PrimitiveArray::from_iter([0.0f32, 1.0, 2.0, 3.0, 4.0]);
        assert_arrays_eq!(decoded, expected);
    }

    /// Runs the shared cast conformance suite over a few representative sequences.
    #[rstest]
    #[case::i32(SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 5).unwrap())]
    #[case::u64(SequenceArray::typed_new(1000u64, 100u64, Nullability::NonNullable, 4).unwrap())]
    #[case::single(SequenceArray::typed_new(42i64, 0i64, Nullability::NonNullable, 1).unwrap())]
    #[case::constant(SequenceArray::typed_new(100i32, 0i32, Nullability::NonNullable, 5).unwrap())]
    fn test_cast_sequence_conformance(#[case] sequence: SequenceArray) {
        test_cast_conformance(sequence.as_ref());
    }
}