vortex_sequence/compute/
cast.rs1use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::dtype::DType;
7use vortex_array::dtype::Nullability;
8use vortex_array::scalar::Scalar;
9use vortex_array::scalar::ScalarValue;
10use vortex_array::scalar_fn::fns::cast::CastReduce;
11use vortex_error::VortexResult;
12use vortex_error::vortex_err;
13
14use crate::Sequence;
15use crate::SequenceArray;
16
17impl CastReduce for Sequence {
18 fn cast(array: &SequenceArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
19 let DType::Primitive(target_ptype, target_nullability) = dtype else {
23 return Ok(None);
24 };
25
26 if !target_ptype.is_int() {
27 return Ok(None);
28 }
29
30 if array.ptype() == *target_ptype && array.dtype().nullability() != *target_nullability {
32 return Ok(Some(
35 SequenceArray::try_new(
36 array.base(),
37 array.multiplier(),
38 *target_ptype,
39 *target_nullability,
40 array.len(),
41 )?
42 .into_array(),
43 ));
44 }
45
46 if array.ptype() != *target_ptype {
48 let base_scalar = Scalar::try_new(
50 DType::Primitive(array.ptype(), Nullability::NonNullable),
51 Some(ScalarValue::Primitive(array.base())),
52 )?;
53 let multiplier_scalar = Scalar::try_new(
54 DType::Primitive(array.ptype(), Nullability::NonNullable),
55 Some(ScalarValue::Primitive(array.multiplier())),
56 )?;
57
58 let new_base_scalar =
59 base_scalar.cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
60 let new_multiplier_scalar = multiplier_scalar
61 .cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
62
63 let new_base = new_base_scalar
65 .as_primitive()
66 .pvalue()
67 .ok_or_else(|| vortex_err!("Cast resulted in null base value"))?;
68 let new_multiplier = new_multiplier_scalar
69 .as_primitive()
70 .pvalue()
71 .ok_or_else(|| vortex_err!("Cast resulted in null multiplier value"))?;
72
73 return Ok(Some(
74 SequenceArray::try_new(
75 new_base,
76 new_multiplier,
77 *target_ptype,
78 *target_nullability,
79 array.len(),
80 )?
81 .into_array(),
82 ));
83 }
84
85 Ok(None)
86 }
87}
88
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;

    use crate::SequenceArray;

    /// Casting to the same primitive type with a different nullability yields the target dtype.
    #[test]
    fn test_cast_sequence_nullability() {
        let target = DType::Primitive(PType::U32, Nullability::Nullable);
        let source =
            SequenceArray::try_new_typed(0u32, 1u32, Nullability::NonNullable, 4).unwrap();

        let result = source.into_array().cast(target.clone()).unwrap();

        assert_eq!(result.dtype(), &target);
    }

    /// Casting `u32` to `i64` keeps the element values.
    #[test]
    fn test_cast_sequence_u32_to_i64() {
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);
        let source =
            SequenceArray::try_new_typed(100u32, 10u32, Nullability::NonNullable, 4).unwrap();

        let result = source.into_array().cast(target.clone()).unwrap();
        assert_eq!(result.dtype(), &target);

        let values = result.to_primitive();
        assert_arrays_eq!(values, PrimitiveArray::from_iter([100i64, 110, 120, 130]));
    }

    /// Casting `i16` to nullable `i32` changes both the type and the nullability.
    #[test]
    fn test_cast_sequence_i16_to_i32_nullable() {
        let target = DType::Primitive(PType::I32, Nullability::Nullable);
        let source =
            SequenceArray::try_new_typed(5i16, 3i16, Nullability::NonNullable, 3).unwrap();

        let result = source.into_array().cast(target.clone()).unwrap();
        assert_eq!(result.dtype(), &target);

        let values = result.to_primitive();
        assert_arrays_eq!(
            values,
            PrimitiveArray::from_option_iter([Some(5i32), Some(8), Some(11)])
        );
    }

    /// Casting to a float type still produces the expected values.
    #[test]
    fn test_cast_sequence_to_float_delegates_to_canonical() {
        let target = DType::Primitive(PType::F32, Nullability::NonNullable);
        let source =
            SequenceArray::try_new_typed(0i32, 1i32, Nullability::NonNullable, 5).unwrap();

        let result = source.into_array().cast(target.clone()).unwrap();
        assert_eq!(result.dtype(), &target);

        let values = result.to_primitive();
        assert_arrays_eq!(
            values,
            PrimitiveArray::from_iter([0.0f32, 1.0, 2.0, 3.0, 4.0])
        );
    }

    /// Runs the shared cast conformance checks over several sequence shapes.
    #[rstest]
    #[case::i32(
        SequenceArray::try_new_typed(0i32, 1i32, Nullability::NonNullable, 5).unwrap()
    )]
    #[case::u64(
        SequenceArray::try_new_typed(1000u64, 100u64, Nullability::NonNullable, 4).unwrap()
    )]
    #[case::single(
        SequenceArray::try_new_typed(42i64, 0i64, Nullability::NonNullable, 1).unwrap()
    )]
    #[case::constant(
        SequenceArray::try_new_typed(100i32, 0i32, Nullability::NonNullable, 5).unwrap()
    )]
    fn test_cast_sequence_conformance(#[case] sequence: SequenceArray) {
        let array = sequence.into_array();
        test_cast_conformance(&array);
    }
}