vortex_sequence/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::compute::CastKernel;
7use vortex_array::compute::CastKernelAdapter;
8use vortex_array::register_kernel;
9use vortex_dtype::DType;
10use vortex_dtype::Nullability;
11use vortex_error::VortexResult;
12use vortex_error::vortex_err;
13use vortex_scalar::Scalar;
14use vortex_scalar::ScalarValue;
15
16use crate::SequenceArray;
17use crate::SequenceVTable;
18
19impl CastKernel for SequenceVTable {
20    fn cast(&self, array: &SequenceArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
21        // SequenceArray represents arithmetic sequences (base + i * multiplier) which
22        // only makes sense for integer types. Floating-point sequences would accumulate
23        // rounding errors, and other types don't support arithmetic operations.
24        let DType::Primitive(target_ptype, target_nullability) = dtype else {
25            return Ok(None);
26        };
27
28        if !target_ptype.is_int() {
29            return Ok(None);
30        }
31
32        // Check if this is just a nullability change
33        if array.ptype() == *target_ptype && array.dtype().nullability() != *target_nullability {
34            // For SequenceArray, we can just create a new one with the same parameters
35            // but different nullability
36            return Ok(Some(
37                SequenceArray::new(
38                    array.base(),
39                    array.multiplier(),
40                    *target_ptype,
41                    *target_nullability,
42                    array.len(),
43                )?
44                .into_array(),
45            ));
46        }
47
48        // For type changes, we need to cast the base and multiplier
49        if array.ptype() != *target_ptype {
50            // Create scalars from PValues and cast them
51            let base_scalar = Scalar::new(
52                DType::Primitive(array.ptype(), Nullability::NonNullable),
53                ScalarValue::from(array.base()),
54            );
55            let multiplier_scalar = Scalar::new(
56                DType::Primitive(array.ptype(), Nullability::NonNullable),
57                ScalarValue::from(array.multiplier()),
58            );
59
60            let new_base_scalar =
61                base_scalar.cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
62            let new_multiplier_scalar = multiplier_scalar
63                .cast(&DType::Primitive(*target_ptype, Nullability::NonNullable))?;
64
65            // Extract PValues from the casted scalars
66            let new_base = new_base_scalar
67                .as_primitive()
68                .pvalue()
69                .ok_or_else(|| vortex_err!("Cast resulted in null base value"))?;
70            let new_multiplier = new_multiplier_scalar
71                .as_primitive()
72                .pvalue()
73                .ok_or_else(|| vortex_err!("Cast resulted in null multiplier value"))?;
74
75            return Ok(Some(
76                SequenceArray::new(
77                    new_base,
78                    new_multiplier,
79                    *target_ptype,
80                    *target_nullability,
81                    array.len(),
82                )?
83                .into_array(),
84            ));
85        }
86
87        Ok(None)
88    }
89}
90
91register_kernel!(CastKernelAdapter(SequenceVTable).lift());
92
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::SequenceArray;

    /// Builds the primitive dtype used as a cast target in the tests below.
    fn primitive(ptype: PType, nullability: Nullability) -> DType {
        DType::Primitive(ptype, nullability)
    }

    /// Changing only the nullability keeps the ptype and yields the nullable dtype.
    #[test]
    fn test_cast_sequence_nullability() {
        let target = primitive(PType::U32, Nullability::Nullable);
        let input = SequenceArray::typed_new(0u32, 1u32, Nullability::NonNullable, 4).unwrap();

        let result = cast(input.as_ref(), &target).unwrap();

        assert_eq!(result.dtype(), &target);
    }

    /// Widening u32 -> i64 preserves every element of the sequence.
    #[test]
    fn test_cast_sequence_u32_to_i64() {
        let target = primitive(PType::I64, Nullability::NonNullable);
        let input = SequenceArray::typed_new(100u32, 10u32, Nullability::NonNullable, 4).unwrap();

        let result = cast(input.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);

        let values = result.to_primitive();
        assert_arrays_eq!(values, PrimitiveArray::from_iter([100i64, 110, 120, 130]));
    }

    /// A single cast may change both the ptype and the nullability.
    #[test]
    fn test_cast_sequence_i16_to_i32_nullable() {
        let target = primitive(PType::I32, Nullability::Nullable);
        let input = SequenceArray::typed_new(5i16, 3i16, Nullability::NonNullable, 3).unwrap();

        let result = cast(input.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);

        let values = result.to_primitive();
        assert_arrays_eq!(
            values,
            PrimitiveArray::from_option_iter([Some(5i32), Some(8), Some(11)])
        );
    }

    /// SequenceArray does not support floats, so a float cast is expected to be
    /// delegated to canonical and still succeed.
    #[test]
    fn test_cast_sequence_to_float_delegates_to_canonical() {
        let target = primitive(PType::F32, Nullability::NonNullable);
        let input = SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 5).unwrap();

        let result = cast(input.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);

        // The elements must have been converted correctly.
        let values = result.to_primitive();
        assert_arrays_eq!(
            values,
            PrimitiveArray::from_iter([0.0f32, 1.0, 2.0, 3.0, 4.0])
        );
    }

    #[rstest]
    #[case::i32(SequenceArray::typed_new(0i32, 1i32, Nullability::NonNullable, 5).unwrap())]
    #[case::u64(SequenceArray::typed_new(1000u64, 100u64, Nullability::NonNullable, 4).unwrap())]
    // TODO(DK): SequenceArray does not actually conform. You cannot cast this array to u8 even
    // though all its values are representable therein.
    //
    // #[case::negative_step(SequenceArray::typed_new(100i32, -10i32, Nullability::NonNullable,
    // 5).unwrap())]
    #[case::single(SequenceArray::typed_new(42i64, 0i64, Nullability::NonNullable, 1).unwrap())]
    #[case::constant(SequenceArray::typed_new(
        100i32,
        0i32, // a multiplier of 0 yields a constant array
        Nullability::NonNullable,
        5,
    ).unwrap())]
    fn test_cast_sequence_conformance(#[case] sequence: SequenceArray) {
        test_cast_conformance(sequence.as_ref());
    }
}