vortex_sequence/compute/
cast.rs1use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::compute::CastKernel;
7use vortex_array::compute::CastKernelAdapter;
8use vortex_array::register_kernel;
9use vortex_dtype::DType;
10use vortex_dtype::Nullability;
11use vortex_error::VortexResult;
12use vortex_error::vortex_err;
13use vortex_scalar::Scalar;
14use vortex_scalar::ScalarValue;
15
16use crate::SequenceArray;
17use crate::SequenceVTable;
18
19impl CastKernel for SequenceVTable {
20 fn cast(&self, array: &SequenceArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
21 let DType::Primitive(target_ptype, target_nullability) = dtype else {
25 return Ok(None);
26 };
27
28 if !target_ptype.is_int() {
29 return Ok(None);
30 }
31
32 if array.ptype() == *target_ptype && array.dtype().nullability() != *target_nullability {
34 return Ok(Some(
37 SequenceArray::new(
38 array.base(),
39 array.multiplier(),
40 *target_ptype,
41 *target_nullability,
42 array.len(),
43 )?
44 .into_array(),
45 ));
46 }
47
48 if array.ptype() != *target_ptype {
50 let base_scalar = Scalar::new(
52 DType::Primitive(array.ptype(), Nullability::NonNullable),
53 ScalarValue::from(array.base()),
54 );
55 let multiplier_scalar = Scalar::new(
56 DType::Primitive(array.ptype(), Nullability::NonNullable),
57 ScalarValue::from(array.multiplier()),
58 );
59
60 let new_base_scalar =
61 base_scalar.cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
62 let new_multiplier_scalar = multiplier_scalar
63 .cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
64
65 let new_base = new_base_scalar
67 .as_primitive()
68 .pvalue()
69 .ok_or_else(|| vortex_err!("Cast resulted in null base value"))?;
70 let new_multiplier = new_multiplier_scalar
71 .as_primitive()
72 .pvalue()
73 .ok_or_else(|| vortex_err!("Cast resulted in null multiplier value"))?;
74
75 return Ok(Some(
76 SequenceArray::new(
77 new_base,
78 new_multiplier,
79 *target_ptype,
80 *target_nullability,
81 array.len(),
82 )?
83 .into_array(),
84 ));
85 }
86
87 Ok(None)
88 }
89}
90
// Register the `CastKernel` impl for `SequenceVTable` (wrapped in `CastKernelAdapter`) so the
// generic `cast` compute function can dispatch to it for sequence arrays.
register_kernel!(CastKernelAdapter(SequenceVTable).lift());
92
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::SequenceArray;

    #[test]
    fn test_cast_sequence_nullability() {
        let sequence = SequenceArray::typed_new(0u32, 1u32, Nullability::NonNullable, 4).unwrap();

        let casted = cast(
            sequence.as_ref(),
            &DType::Primitive(PType::U32, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );

        // Only nullability changed; the decoded values must be untouched.
        let decoded = casted.to_primitive();
        assert_arrays_eq!(
            decoded,
            PrimitiveArray::from_option_iter([Some(0u32), Some(1), Some(2), Some(3)])
        );
    }

    #[test]
    fn test_cast_sequence_u32_to_i64() {
        let sequence =
            SequenceArray::typed_new(100u32, 10u32, Nullability::NonNullable, 4).unwrap();

        let casted = cast(
            sequence.as_ref(),
            &DType::Primitive(PType::I64, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([100i64, 110, 120, 130]));
    }

    #[test]
    fn test_cast_sequence_u32_to_u8_narrowing() {
        // Base, multiplier and every element fit in `u8`, so the narrowing cast succeeds.
        let sequence = SequenceArray::typed_new(1u32, 2u32, Nullability::NonNullable, 3).unwrap();

        let casted = cast(
            sequence.as_ref(),
            &DType::Primitive(PType::U8, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U8, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1u8, 3, 5]));
    }

    #[test]
    fn test_cast_sequence_i16_to_i32_nullable() {
        let sequence = SequenceArray::typed_new(5i16, 3i16, Nullability::NonNullable, 3).unwrap();

        let casted = cast(
            sequence.as_ref(),
            &DType::Primitive(PType::I32, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I32, Nullability::Nullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(
            decoded,
            PrimitiveArray::from_option_iter([Some(5i32), Some(8), Some(11)])
        );
    }

    #[test]
    fn test_cast_sequence_to_float_delegates_to_canonical() {
        let sequence = SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 5).unwrap();

        let casted = cast(
            sequence.as_ref(),
            &DType::Primitive(PType::F32, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::F32, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(
            decoded,
            PrimitiveArray::from_iter([0.0f32, 1.0, 2.0, 3.0, 4.0])
        );
    }

    #[rstest]
    #[case::i32(SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 5).unwrap())]
    #[case::u64(SequenceArray::typed_new(1000u64, 100u64, Nullability::NonNullable, 4).unwrap())]
    #[case::single(SequenceArray::typed_new(42i64, 0i64, Nullability::NonNullable, 1).unwrap())]
    #[case::constant(SequenceArray::typed_new(100i32, 0i32, Nullability::NonNullable, 5).unwrap())]
    fn test_cast_sequence_conformance(#[case] sequence: SequenceArray) {
        test_cast_conformance(sequence.as_ref());
    }
}